mirror of https://github.com/tig-foundation/tig-monorepo.git (synced 2026-02-21 10:27:49 +08:00)
This commit is contained in:
parent d1774ed6ec
commit 1fff31b630
@ -26,6 +26,10 @@ tig-challenges = { path = "../tig-challenges", features = [
[lib]
crate-type = ["cdylib", "rlib"]

[[example]]
name = "test_leverage"
required-features = ["cur_decomposition"]

[features]
c001 = ["tig-challenges/c001"]
satisfiability = ["c001"]
369 tig-algorithms/examples/test_leverage.rs Normal file
@ -0,0 +1,369 @@
use anyhow::{anyhow, Result};
use cudarc::{
    driver::{CudaContext, CudaModule, CudaStream},
    nvrtc::Ptx,
    runtime::{result::device::get_device_prop, sys::cudaDeviceProp},
};
use std::{cell::RefCell, sync::Arc, time::Instant};
use tig_challenges::cur_decomposition::*;

#[path = "../src/cur_decomposition/leverage/mod.rs"]
mod leverage;

// ─── Stats helper ────────────────────────────────────────────────────────────

struct Stats {
    values: Vec<f64>,
}

impl Stats {
    fn new() -> Self {
        Self { values: Vec::new() }
    }
    fn push(&mut self, v: f64) {
        self.values.push(v);
    }
    fn is_empty(&self) -> bool {
        self.values.is_empty()
    }
    fn mean(&self) -> f64 {
        self.values.iter().sum::<f64>() / self.values.len() as f64
    }
    fn min(&self) -> f64 {
        self.values.iter().cloned().fold(f64::INFINITY, f64::min)
    }
    fn max(&self) -> f64 {
        self.values.iter().cloned().fold(f64::NEG_INFINITY, f64::max)
    }
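    // Population standard deviation (divides by n rather than n - 1); adequate for
    // summarising a fixed set of benchmark runs.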
    fn std(&self) -> f64 {
        let m = self.mean();
        (self.values.iter().map(|x| (x - m).powi(2)).sum::<f64>() / self.values.len() as f64)
            .sqrt()
    }
}

// ─── Helpers ─────────────────────────────────────────────────────────────────

/// The singular values used in challenge generation are strictly decreasing
/// (exp(-l5*sqrt(j+1)/sqrt(T)) for j=0..T), so their sorted order is fixed
/// regardless of the seed shuffle. optimal_fnorm is purely a function of T and K.
fn compute_optimal_fnorm(true_rank: i32, target_rank: i32) -> f32 {
    let l5: f32 = 2.0;
    (0..true_rank)
        .skip(target_rank as usize)
        .map(|j| {
            let s = (-l5 * ((j + 1) as f32).sqrt() / (true_rank as f32).sqrt()).exp();
            s * s
        })
        .sum::<f32>()
        .sqrt()
}
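// In math form (a sketch of what the helper above encodes, assuming the generator's
// spectrum really is sigma_{j+1} = exp(-l5 * sqrt(j + 1) / sqrt(T)) with l5 = 2):
//
//     optimal_fnorm = ||A - A_K||_F = sqrt( sum_{j=K}^{T-1} sigma_{j+1}^2 )
//
// i.e. the Eckart–Young best rank-K Frobenius error, a function of T and K only.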

fn make_seed(index: u64) -> [u8; 32] {
    let mut seed = [0u8; 32];
    seed[0..8].copy_from_slice(&index.to_le_bytes());
    seed
}

// ─── Per-config benchmark ─────────────────────────────────────────────────────

fn run_config(
    m: i32,
    n: i32,
    true_rank: i32,
    target_rank: i32,
    num_seeds: usize,
    hyperparameters: &Option<serde_json::Map<String, serde_json::Value>>,
    module: Arc<CudaModule>,
    stream: Arc<CudaStream>,
    prop: &cudaDeviceProp,
) -> Result<()> {
    println!("\n╔══ M={} N={} T={} K={} ══", m, n, true_rank, target_rank);
    println!(
        "║ {:>5} {:>10} {:>10} {:>9} {:>9}",
        "seed", "gen_ms", "algo_ms", "quality", "fnorm/opt"
    );
    println!("╟{}", "─".repeat(52));

    let track = Track { m, n };
    let optimal = compute_optimal_fnorm(true_rank, target_rank);
    let baseline = 50.0 * optimal;

    let mut gen_ms = Stats::new();
    let mut algo_ms = Stats::new();
    let mut quality_stats = Stats::new();
    let mut ratio_stats = Stats::new();
    let mut passes = 0usize;
    let mut attempted = 0usize;

    for s in 0..num_seeds {
        let seed = make_seed(s as u64);

        // ── Generation ─────────────────────────────────────────────────────
        let t0 = Instant::now();
        let challenge = match Challenge::generate_single_instance(
            &seed,
            &track,
            true_rank,
            target_rank,
            module.clone(),
            stream.clone(),
            prop,
        ) {
            Ok(c) => c,
            Err(e) => {
                println!("║ {:>5} gen error: {}", s, e);
                continue;
            }
        };
        let g_ms = t0.elapsed().as_secs_f64() * 1000.0;
        gen_ms.push(g_ms);

        // ── Algorithm ──────────────────────────────────────────────────────
        let best: RefCell<Option<Solution>> = RefCell::new(None);
        let t1 = Instant::now();
        let result = {
            let save_fn = |sol: &Solution| -> Result<()> {
                *best.borrow_mut() = Some(sol.clone());
                Ok(())
            };
            leverage::solve_challenge(
                &challenge,
                &save_fn,
                hyperparameters,
                module.clone(),
                stream.clone(),
                prop,
            )
        };
        let a_ms = t1.elapsed().as_secs_f64() * 1000.0;
        algo_ms.push(a_ms);
        attempted += 1;

        let solution = match result {
            Ok(sol) => sol.or_else(|| best.into_inner()),
            Err(e) => {
                println!("║ {:>5} {:>10.1} algo error: {}", s, g_ms, e);
                continue;
            }
        };

        match solution {
            None => {
                println!(
                    "║ {:>5} {:>10.1} {:>10.1} no solution",
                    s, g_ms, a_ms
                );
            }
            Some(sol) => {
                let fnorm =
                    challenge.evaluate_fnorm(&sol, module.clone(), stream.clone(), prop)?;
                let q = (baseline - fnorm) / (baseline - optimal);
                let ratio = fnorm / optimal;
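                // With baseline = 50 * optimal, q is 1.0 when fnorm reaches the optimal
                // rank-K error and 0.0 at the baseline; q <= 0 is reported as a FAIL.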

                quality_stats.push(q as f64);
                ratio_stats.push(ratio as f64);
                if q > 0.0 {
                    passes += 1;
                }

                println!(
                    "║ {:>5} {:>10.1} {:>10.1} {:>9.4} {:>9.4}{}",
                    s,
                    g_ms,
                    a_ms,
                    q,
                    ratio,
                    if q <= 0.0 { " ← FAIL" } else { "" }
                );
            }
        }
    }

    println!("╟{}", "─".repeat(52));

    if !gen_ms.is_empty() {
        println!(
            "║ gen_ms avg={:8.1} min={:8.1} max={:8.1}",
            gen_ms.mean(),
            gen_ms.min(),
            gen_ms.max()
        );
    }
    if !algo_ms.is_empty() {
        println!(
            "║ algo_ms avg={:8.1} min={:8.1} max={:8.1}",
            algo_ms.mean(),
            algo_ms.min(),
            algo_ms.max()
        );
    }
    if !quality_stats.is_empty() {
        println!(
            "║ quality avg={:8.4} min={:8.4} max={:8.4} std={:.4} pass={}/{}",
            quality_stats.mean(),
            quality_stats.min(),
            quality_stats.max(),
            quality_stats.std(),
            passes,
            attempted
        );
        println!(
            "║ f/opt avg={:8.4} min={:8.4} max={:8.4} std={:.4}",
            ratio_stats.mean(),
            ratio_stats.min(),
            ratio_stats.max(),
            ratio_stats.std()
        );
    }
    println!("╚{}", "═".repeat(52));

    Ok(())
}

// ─── Main ─────────────────────────────────────────────────────────────────────

fn print_usage(prog: &str) {
    eprintln!(
        "Usage: {} <PTX_PATH> <M> <N> <T> <K> [<M> <N> <T> <K> ...] [OPTIONS]",
        prog
    );
    eprintln!();
    eprintln!(" PTX_PATH Path to compiled leverage.ptx");
    eprintln!(" Build with:");
    eprintln!(" CHALLENGE=cur_decomposition python3 tig-binary/scripts/build_ptx leverage");
    eprintln!();
    eprintln!(" M N T K Matrix dimensions and ranks. Multiple groups run sequentially.");
    eprintln!(" M = rows, N = cols, T = true rank, K = target rank (K <= T)");
    eprintln!();
    eprintln!("Options:");
    eprintln!(" --seeds N Number of seeds to test per config (default: 5)");
    eprintln!(" --trials N Leverage score trials per solve (default: algorithm default)");
    eprintln!(" --gpu N GPU device index (default: 0)");
}
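// Example invocation (paths and sizes are illustrative only, not prescribed by the repo):
//
//   cargo run --release --example test_leverage --features cur_decomposition -- \
//       tig-algorithms/lib/cur_decomposition/ptx/leverage.ptx \
//       1000 800 64 16 --seeds 10 --trials 5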

fn main() -> Result<()> {
    let args: Vec<String> = std::env::args().collect();

    if args.len() < 6 {
        print_usage(&args[0]);
        std::process::exit(1);
    }

    let ptx_path = &args[1];

    // Parse positional M N T K groups and named options
    let mut configs: Vec<(i32, i32, i32, i32)> = Vec::new();
    let mut num_seeds: usize = 5;
    let mut num_trials: Option<usize> = None;
    let mut gpu_device: usize = 0;

    let mut i = 2;
    while i < args.len() {
        match args[i].as_str() {
            "--seeds" => {
                i += 1;
                num_seeds = args
                    .get(i)
                    .ok_or_else(|| anyhow!("--seeds requires a value"))?
                    .parse()
                    .map_err(|_| anyhow!("--seeds must be a positive integer"))?;
            }
            "--trials" => {
                i += 1;
                num_trials = Some(
                    args.get(i)
                        .ok_or_else(|| anyhow!("--trials requires a value"))?
                        .parse()
                        .map_err(|_| anyhow!("--trials must be a positive integer"))?,
                );
            }
            "--gpu" => {
                i += 1;
                gpu_device = args
                    .get(i)
                    .ok_or_else(|| anyhow!("--gpu requires a value"))?
                    .parse()
                    .map_err(|_| anyhow!("--gpu must be a non-negative integer"))?;
            }
            _ => {
                // Expect a group of 4 integers: M N T K
                if i + 3 >= args.len() {
                    eprintln!(
                        "Expected M N T K group at position {}, but not enough arguments.",
                        i
                    );
                    print_usage(&args[0]);
                    std::process::exit(1);
                }
                let parse = |s: &str, name: &str| -> Result<i32> {
                    s.parse()
                        .map_err(|_| anyhow!("{} must be a positive integer, got '{}'", name, s))
                };
                let m = parse(&args[i], "M")?;
                let n = parse(&args[i + 1], "N")?;
                let t = parse(&args[i + 2], "T")?;
                let k = parse(&args[i + 3], "K")?;
                configs.push((m, n, t, k));
                i += 3;
            }
        }
        i += 1;
    }

    if configs.is_empty() {
        eprintln!("No M N T K configs provided.");
        print_usage(&args[0]);
        std::process::exit(1);
    }

    let hyperparameters = num_trials.map(|t| {
        let mut map = serde_json::Map::new();
        map.insert("num_trials".to_string(), serde_json::json!(t));
        map
    });

    // ── CUDA setup (shared across all configs) ────────────────────────────
    let ptx_src = std::fs::read_to_string(ptx_path)
        .map_err(|e| anyhow!("Failed to read PTX '{}': {}", ptx_path, e))?;
    // Replace the sentinel fuel-limit with u64::MAX so kernels never abort.
    let ptx_src = ptx_src.replace("0xdeadbeefdeadbeef", "0xffffffffffffffff");
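    // (0xdeadbeefdeadbeef is presumably the placeholder fuel budget baked into the PTX
    // by the build script; lifting it to u64::MAX only matters for standalone benchmarking.)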
    let ptx = Ptx::from_src(ptx_src);

    let num_gpus = CudaContext::device_count()?;
    if num_gpus == 0 {
        return Err(anyhow!("No CUDA devices found"));
    }

    println!("=== CUR Decomposition Benchmark (leverage) ===");
    println!("PTX : {}", ptx_path);
    println!("GPU : device {} of {}", gpu_device, num_gpus);
    println!("Seeds : {}", num_seeds);
    if let Some(t) = num_trials {
        println!("Trials : {}", t);
    }
    println!("Configs: {}", configs.len());

    let ctx = CudaContext::new(gpu_device)?;
    ctx.set_blocking_synchronize()?;
    let module = ctx.load_module(ptx)?;
    let stream = ctx.default_stream();
    let prop = get_device_prop(gpu_device as i32)?;

    // ── Run each config ───────────────────────────────────────────────────
    for (m, n, true_rank, target_rank) in &configs {
        run_config(
            *m,
            *n,
            *true_rank,
            *target_rank,
            num_seeds,
            &hyperparameters,
            module.clone(),
            stream.clone(),
            &prop,
        )?;
    }

    Ok(())
}
1332 tig-algorithms/lib/cur_decomposition/ptx/leverage.ptx Normal file
File diff suppressed because one or more lines are too long
44 tig-algorithms/src/cur_decomposition/leverage/kernels.cu Normal file
@ -0,0 +1,44 @@
#include <curand_kernel.h>
#include <stdint.h>

// Fill `size` elements with iid Gaussian values, mean 0 and standard deviation `scale`.
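// Each thread seeds its own curand state with its global index as the subsequence,
// so draws are independent across elements and reproducible for a given seed.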
extern "C" __global__ void standard_gaussian_kernel(
    float *mat, int size, float scale, uint64_t seed
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= size) return;
    curandState state;
    curand_init((unsigned long long)seed, (unsigned long long)idx, 0, &state);
    mat[idx] = curand_normal(&state) * scale;
}

// Squared L2 norm of every column of a (rows x cols) column-major matrix.
// out[j] = sum_i mat[i + j*rows]^2
extern "C" __global__ void col_sq_norms_kernel(
    const float *mat, float *out, int rows, int cols
) {
    int j = blockIdx.x * blockDim.x + threadIdx.x;
    if (j >= cols) return;
    float sum = 0.0f;
    const float *col = mat + (long long)j * rows;
    for (int i = 0; i < rows; i++) {
        float v = col[i];
        sum += v * v;
    }
    out[j] = sum;
}

// Squared L2 norm of every row of a (rows x cols) column-major matrix.
// out[i] = sum_j mat[i + j*rows]^2
extern "C" __global__ void row_sq_norms_kernel(
    const float *mat, float *out, int rows, int cols
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= rows) return;
    float sum = 0.0f;
    for (int j = 0; j < cols; j++) {
        float v = mat[i + (long long)j * rows];
        sum += v * v;
    }
    out[i] = sum;
}
653 tig-algorithms/src/cur_decomposition/leverage/mod.rs Normal file
@ -0,0 +1,653 @@
// TIG's UI uses the pattern `tig_challenges::<challenge_name>` to automatically detect your algorithm's challenge
use anyhow::{anyhow, Result};
use core::ffi::c_int;
use cudarc::{
    cublas::{
        sys::{self as cublas_sys, cublasOperation_t},
        CudaBlas, Gemm, GemmConfig,
    },
    cusolver::{sys as cusolver_sys, DnHandle},
    driver::{safe::LaunchConfig, CudaModule, CudaSlice, CudaStream, DevicePtr, DevicePtrMut, PushKernelArg},
    runtime::sys::cudaDeviceProp,
};
use rand::{rngs::SmallRng, Rng, SeedableRng};
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use std::sync::Arc;
use tig_challenges::cur_decomposition::*;

const MAX_THREADS: u32 = 1024;

#[derive(Serialize, Deserialize)]
pub struct Hyperparameters {
    pub num_trials: usize,
}

pub fn help() {
    println!("Classic leverage score CUR decomposition (GPU).");
    println!("Hyperparameters:");
    println!(" num_trials: number of random leverage-score trials (default: 3)");
}

// ─── CPU helpers (only for small k×k operations) ─────────────────────────────

/// Gauss-Jordan inversion of an n×n column-major matrix. Returns None if singular.
fn invert(a: &[f32], n: usize) -> Option<Vec<f32>> {
    let mut aug = vec![0.0f32; n * 2 * n];
    for i in 0..n {
        for j in 0..n {
            aug[i * 2 * n + j] = a[i + j * n];
        }
        aug[i * 2 * n + n + i] = 1.0;
    }
    for col in 0..n {
        let (max_row, max_val) = (col..n)
            .map(|r| (r, aug[r * 2 * n + col].abs()))
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
            .unwrap();
        if max_val < 1e-10 {
            return None;
        }
        if max_row != col {
            for j in 0..(2 * n) {
                aug.swap(col * 2 * n + j, max_row * 2 * n + j);
            }
        }
        let pivot = aug[col * 2 * n + col];
        for j in 0..(2 * n) {
            aug[col * 2 * n + j] /= pivot;
        }
        for row in 0..n {
            if row == col {
                continue;
            }
            let factor = aug[row * 2 * n + col];
            for j in 0..(2 * n) {
                let v = aug[col * 2 * n + j];
                aug[row * 2 * n + j] -= factor * v;
            }
        }
    }
    let mut inv = vec![0.0f32; n * n];
    for i in 0..n {
        for j in 0..n {
            inv[i + j * n] = aug[i * 2 * n + n + j];
        }
    }
    Some(inv)
}

/// Column-major matrix multiply for small k×k matrices: C = A(m×p) * B(p×n).
fn matmul(a: &[f32], m: usize, p: usize, b: &[f32], n: usize) -> Vec<f32> {
    let mut c = vec![0.0f32; m * n];
    for j in 0..n {
        for l in 0..p {
            let b_lj = b[l + j * p];
            for i in 0..m {
                c[i + j * m] += a[i + l * m] * b_lj;
            }
        }
    }
    c
}

/// Sample k indices without replacement, proportional to weights (all non-negative).
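/// Draws sequentially: each pick removes the chosen index from the pool, and the
/// 1e-12 floor keeps zero-weight entries selectable.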
fn weighted_sample_k(weights: &[f32], k: usize, rng: &mut SmallRng) -> Vec<usize> {
    let mut pool: Vec<(usize, f32)> = weights
        .iter()
        .enumerate()
        .map(|(i, &w)| (i, w.max(0.0) + 1e-12))
        .collect();
    let mut out = Vec::with_capacity(k);
    for _ in 0..k {
        let total: f32 = pool.iter().map(|(_, w)| w).sum();
        let mut r = rng.gen::<f32>() * total;
        let mut chosen = pool.len() - 1;
        for (idx, &(_, w)) in pool.iter().enumerate() {
            r -= w;
            if r <= 0.0 {
                chosen = idx;
                break;
            }
        }
        out.push(pool[chosen].0);
        pool.swap_remove(chosen);
    }
    out
}

// ─── GPU helpers ─────────────────────────────────────────────────────────────

/// In-place QR decomposition on GPU: d_mat (m×n) is overwritten with Q (m×n, orthonormal cols).
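/// Uses cuSOLVER geqrf to form the Householder factors, then orgqr to expand the
/// explicit thin Q. Assumes n <= m, which both call sites satisfy because the sketch
/// width is clamped to min(m, n).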
fn gpu_qr(
    cusolver: &DnHandle,
    stream: &Arc<CudaStream>,
    d_mat: &mut CudaSlice<f32>,
    m: c_int,
    n: c_int,
) -> Result<()> {
    let min_mn = m.min(n);

    // ── geqrf ──────────────────────────────────────────────────────────────
    let mut lwork = 0 as c_int;
    unsafe {
        let stat = cusolver_sys::cusolverDnSgeqrf_bufferSize(
            cusolver.cu(),
            m,
            n,
            d_mat.device_ptr_mut(stream).0 as *mut f32,
            m,
            &mut lwork as *mut c_int,
        );
        if stat != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
            return Err(anyhow!("cusolverDnSgeqrf_bufferSize failed"));
        }
    }

    let ws = (lwork as usize).max(1);
    let mut d_work = stream.alloc_zeros::<f32>(ws)?;
    let mut d_info = stream.alloc_zeros::<c_int>(1)?;
    let mut d_tau = stream.alloc_zeros::<f32>(min_mn as usize)?;

    unsafe {
        let stat = cusolver_sys::cusolverDnSgeqrf(
            cusolver.cu(),
            m,
            n,
            d_mat.device_ptr_mut(stream).0 as *mut f32,
            m,
            d_tau.device_ptr_mut(stream).0 as *mut f32,
            d_work.device_ptr_mut(stream).0 as *mut f32,
            lwork,
            d_info.device_ptr_mut(stream).0 as *mut c_int,
        );
        if stat != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
            return Err(anyhow!("cusolverDnSgeqrf failed"));
        }
    }
    stream.synchronize()?;

    // ── orgqr ──────────────────────────────────────────────────────────────
    let mut lwork_q = 0 as c_int;
    unsafe {
        let stat = cusolver_sys::cusolverDnSorgqr_bufferSize(
            cusolver.cu(),
            m,
            n,
            min_mn,
            d_mat.device_ptr_mut(stream).0 as *const f32,
            m,
            d_tau.device_ptr_mut(stream).0 as *const f32,
            &mut lwork_q as *mut c_int,
        );
        if stat != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
            return Err(anyhow!("cusolverDnSorgqr_bufferSize failed"));
        }
    }

    let ws_q = (lwork_q as usize).max(1);
    let mut d_work_q = stream.alloc_zeros::<f32>(ws_q)?;

    unsafe {
        let stat = cusolver_sys::cusolverDnSorgqr(
            cusolver.cu(),
            m,
            n,
            min_mn,
            d_mat.device_ptr_mut(stream).0 as *mut f32,
            m,
            d_tau.device_ptr_mut(stream).0 as *const f32,
            d_work_q.device_ptr_mut(stream).0 as *mut f32,
            lwork_q,
            d_info.device_ptr_mut(stream).0 as *mut c_int,
        );
        if stat != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
            return Err(anyhow!("cusolverDnSorgqr failed"));
        }
    }
    stream.synchronize()?;
    Ok(())
}

/// Launch the norm kernel over `count` outputs.
fn launch_norm_kernel(
    stream: &Arc<CudaStream>,
    kernel: &cudarc::driver::CudaFunction,
    d_mat: &CudaSlice<f32>,
    d_out: &mut CudaSlice<f32>,
    rows: i32,
    cols: i32,
    count: u32,
) -> Result<()> {
    unsafe {
        stream
            .launch_builder(kernel)
            .arg(d_mat)
            .arg(d_out)
            .arg(&rows)
            .arg(&cols)
            .launch(LaunchConfig {
                grid_dim: ((count + MAX_THREADS - 1) / MAX_THREADS, 1, 1),
                block_dim: (MAX_THREADS, 1, 1),
                shared_mem_bytes: 0,
            })?;
    }
    Ok(())
}

// ─── Solver ──────────────────────────────────────────────────────────────────

pub fn solve_challenge(
    challenge: &Challenge,
    save_solution: &dyn Fn(&Solution) -> Result<()>,
    hyperparameters: &Option<Map<String, Value>>,
    module: Arc<CudaModule>,
    stream: Arc<CudaStream>,
    _prop: &cudaDeviceProp,
) -> anyhow::Result<Option<Solution>> {
    let hp = match hyperparameters {
        Some(hp) => {
            serde_json::from_value::<Hyperparameters>(Value::Object(hp.clone()))
                .map_err(|e| anyhow!("Failed to parse hyperparameters: {}", e))?
        }
        None => Hyperparameters { num_trials: 3 },
    };

    let m = challenge.m;
    let n = challenge.n;
    let k = challenge.target_k;
    let m_sz = m as usize;
    let n_sz = n as usize;
    let k_sz = k as usize;
    let num_trials = hp.num_trials.max(1);

    let mut rng = SmallRng::from_seed(challenge.seed);
    let seed0 = u64::from_le_bytes(challenge.seed[0..8].try_into()?);
    let seed1 = u64::from_le_bytes(challenge.seed[8..16].try_into()?);

    // Sketch dimension: a bit larger than k for better approximation.
    let s = (k + 10).min(m).min(n);
    let s_sz = s as usize;

    // ── GPU handles and kernels ───────────────────────────────────────────────
    let cublas = CudaBlas::new(stream.clone())?;
    let cusolver = DnHandle::new(stream.clone())?;

    let gaussian_kernel = module.load_function("standard_gaussian_kernel")?;
    let col_norms_kernel = module.load_function("col_sq_norms_kernel")?;
    let row_norms_kernel = module.load_function("row_sq_norms_kernel")?;
    let extract_cols_kernel = module.load_function("extract_columns_kernel")?;
    let extract_rows_kernel = module.load_function("extract_rows_kernel")?;

    // Reusable m×n buffer for the reconstruction check.
    let mut d_cur_buf = stream.alloc_zeros::<f32>(m_sz * n_sz)?;

    let mut best_fnorm = f32::INFINITY;
    let mut best_solution: Option<Solution> = None;

    // ── Helper: extract C/R, solve for U, evaluate fnorm, optionally save ────
    // Returns (c_idxs, u_mat, r_idxs, fnorm), or None if one of the k×k systems is singular.
    let mut run_trial = |c_idxs_i32: Vec<i32>,
                         r_idxs_i32: Vec<i32>,
                         rng_inner: &mut SmallRng|
     -> Result<Option<(Vec<i32>, Vec<f32>, Vec<i32>, f32)>> {
        let d_c_idxs = stream.memcpy_stod(&c_idxs_i32)?;
        let d_r_idxs = stream.memcpy_stod(&r_idxs_i32)?;

        let c_size = m_sz * k_sz;
        let r_size = k_sz * n_sz;
        let mut d_c = stream.alloc_zeros::<f32>(c_size)?;
        let mut d_r = stream.alloc_zeros::<f32>(r_size)?;

        unsafe {
            stream
                .launch_builder(&extract_cols_kernel)
                .arg(&challenge.d_a_mat)
                .arg(&mut d_c)
                .arg(&m)
                .arg(&n)
                .arg(&k)
                .arg(&d_c_idxs)
                .launch(LaunchConfig {
                    grid_dim: ((c_size as u32 + MAX_THREADS - 1) / MAX_THREADS, 1, 1),
                    block_dim: (MAX_THREADS, 1, 1),
                    shared_mem_bytes: 0,
                })?;

            stream
                .launch_builder(&extract_rows_kernel)
                .arg(&challenge.d_a_mat)
                .arg(&mut d_r)
                .arg(&m)
                .arg(&n)
                .arg(&k)
                .arg(&d_r_idxs)
                .launch(LaunchConfig {
                    grid_dim: ((r_size as u32 + MAX_THREADS - 1) / MAX_THREADS, 1, 1),
                    block_dim: (MAX_THREADS, 1, 1),
                    shared_mem_bytes: 0,
                })?;
        }

        // C^T C (k×k)
        let mut d_ctc = stream.alloc_zeros::<f32>(k_sz * k_sz)?;
        unsafe {
            cublas.gemm(
                GemmConfig {
                    transa: cublasOperation_t::CUBLAS_OP_T,
                    transb: cublasOperation_t::CUBLAS_OP_N,
                    m: k, n: k, k: m,
                    alpha: 1.0f32, lda: m, ldb: m, beta: 0.0f32, ldc: k,
                },
                &d_c,
                &d_c,
                &mut d_ctc,
            )?;
        }

        // R R^T (k×k)
        let mut d_rrt = stream.alloc_zeros::<f32>(k_sz * k_sz)?;
        unsafe {
            cublas.gemm(
                GemmConfig {
                    transa: cublasOperation_t::CUBLAS_OP_N,
                    transb: cublasOperation_t::CUBLAS_OP_T,
                    m: k, n: k, k: n,
                    alpha: 1.0f32, lda: k, ldb: k, beta: 0.0f32, ldc: k,
                },
                &d_r,
                &d_r,
                &mut d_rrt,
            )?;
        }

        // C^T A (k×n)
        let mut d_cta = stream.alloc_zeros::<f32>(k_sz * n_sz)?;
        unsafe {
            cublas.gemm(
                GemmConfig {
                    transa: cublasOperation_t::CUBLAS_OP_T,
                    transb: cublasOperation_t::CUBLAS_OP_N,
                    m: k, n, k: m,
                    alpha: 1.0f32, lda: m, ldb: m, beta: 0.0f32, ldc: k,
                },
                &d_c,
                &challenge.d_a_mat,
                &mut d_cta,
            )?;
        }

        // (C^T A) R^T (k×k) = M
        let mut d_m = stream.alloc_zeros::<f32>(k_sz * k_sz)?;
        unsafe {
            cublas.gemm(
                GemmConfig {
                    transa: cublasOperation_t::CUBLAS_OP_N,
                    transb: cublasOperation_t::CUBLAS_OP_T,
                    m: k, n: k, k: n,
                    alpha: 1.0f32, lda: k, ldb: k, beta: 0.0f32, ldc: k,
                },
                &d_cta,
                &d_r,
                &mut d_m,
            )?;
        }

        stream.synchronize()?;
        let ctc = stream.memcpy_dtov(&d_ctc)?;
        let rrt = stream.memcpy_dtov(&d_rrt)?;
        let m_mat = stream.memcpy_dtov(&d_m)?;

        // Invert k×k matrices on CPU (k is small).
        let ctc_inv = match invert(&ctc, k_sz) {
            Some(v) => v,
            None => return Ok(None),
        };
        let rrt_inv = match invert(&rrt, k_sz) {
            Some(v) => v,
            None => return Ok(None),
        };

        // U = (C^T C)^{-1} M (R R^T)^{-1} (k×k, on CPU)
let tmp = matmul(&ctc_inv, k_sz, k_sz, &m_mat, k_sz);
|
||||
let u_mat = matmul(&tmp, k_sz, k_sz, &rrt_inv, k_sz);
|
||||
|
||||
// Upload U and compute CUR = C U R on GPU.
|
||||
let d_u = stream.memcpy_stod(&u_mat)?;
|
||||
|
||||
// CU = C * U (m×k)
|
||||
let mut d_cu = stream.alloc_zeros::<f32>(m_sz * k_sz)?;
|
||||
unsafe {
|
||||
cublas.gemm(
|
||||
GemmConfig {
|
||||
transa: cublasOperation_t::CUBLAS_OP_N,
|
||||
transb: cublasOperation_t::CUBLAS_OP_N,
|
||||
m, n: k, k,
|
||||
alpha: 1.0f32, lda: m, ldb: k, beta: 0.0f32, ldc: m,
|
||||
},
|
||||
&d_c,
|
||||
&d_u,
|
||||
&mut d_cu,
|
||||
)?;
|
||||
}
|
||||
|
||||
// CUR = CU * R (m×n) — written into d_cur_buf
|
||||
unsafe {
|
||||
cublas.gemm(
|
||||
GemmConfig {
|
||||
transa: cublasOperation_t::CUBLAS_OP_N,
|
||||
transb: cublasOperation_t::CUBLAS_OP_N,
|
||||
m, n, k,
|
||||
alpha: 1.0f32, lda: m, ldb: k, beta: 0.0f32, ldc: m,
|
||||
},
|
||||
&d_cu,
|
||||
&d_r,
|
||||
&mut d_cur_buf,
|
||||
)?;
|
||||
}
|
||||
|
||||
// fnorm = ||A - CUR||_F via axpy + nrm2.
|
||||
let mn = (m * n) as c_int;
|
||||
let alpha_neg: f32 = -1.0;
|
||||
let mut fnorm = 0.0f32;
|
||||
unsafe {
|
||||
let (a_ptr, _ag) = challenge.d_a_mat.device_ptr(&stream);
|
||||
let (cur_ptr, _cg) = d_cur_buf.device_ptr_mut(&stream);
|
||||
cublas_sys::cublasSaxpy_v2(
|
||||
*cublas.handle(),
|
||||
mn,
|
||||
&alpha_neg as *const f32,
|
||||
a_ptr as *const f32,
|
||||
1,
|
||||
cur_ptr as *mut f32,
|
||||
1,
|
||||
)
|
||||
.result()?;
|
||||
cublas_sys::cublasSnrm2_v2(
|
||||
*cublas.handle(),
|
||||
mn,
|
||||
cur_ptr as *const f32,
|
||||
1,
|
||||
&mut fnorm as *mut f32,
|
||||
)
|
||||
.result()?;
|
||||
}
|
||||
stream.synchronize()?;
|
||||
|
||||
Ok(Some((c_idxs_i32, u_mat, r_idxs_i32, fnorm)))
|
||||
};
|
||||
|
||||
// ── Warm-start: column/row squared norms of A (no sketch/QR needed) ──────
|
||||
{
|
||||
let mut d_col_norms = stream.alloc_zeros::<f32>(n_sz)?;
|
||||
let mut d_row_norms = stream.alloc_zeros::<f32>(m_sz)?;
|
||||
launch_norm_kernel(&stream, &col_norms_kernel, &challenge.d_a_mat, &mut d_col_norms, m, n, n as u32)?;
|
||||
launch_norm_kernel(&stream, &row_norms_kernel, &challenge.d_a_mat, &mut d_row_norms, m, n, m as u32)?;
|
||||
stream.synchronize()?;
|
||||
let col_norms = stream.memcpy_dtov(&d_col_norms)?;
|
||||
let row_norms = stream.memcpy_dtov(&d_row_norms)?;
|
||||
|
||||
let c_idxs = weighted_sample_k(&col_norms, k_sz, &mut rng);
|
||||
let r_idxs = weighted_sample_k(&row_norms, k_sz, &mut rng);
|
||||
let c_i32: Vec<i32> = c_idxs.iter().map(|&i| i as i32).collect();
|
||||
let r_i32: Vec<i32> = r_idxs.iter().map(|&i| i as i32).collect();
|
||||
|
||||
if let Ok(Some((ci, u, ri, fnorm))) = run_trial(c_i32, r_i32, &mut rng) {
|
||||
let sol = Solution { c_idxs: ci, u_mat: u, r_idxs: ri };
|
||||
save_solution(&sol)?;
|
||||
best_fnorm = fnorm;
|
||||
best_solution = Some(sol);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Leverage score trials ─────────────────────────────────────────────────
|
||||
for trial in 0..num_trials {
|
||||
let col_seed = seed0 ^ (trial as u64).wrapping_mul(0xA1B2_C3D4_E5F6_0718);
|
||||
let row_seed = seed1 ^ (trial as u64).wrapping_mul(0x1827_3645_5463_7281);
|
||||
let scale = 1.0f32 / (s as f32).sqrt();
|
||||
|
||||
// ── Column leverage scores ──────────────────────────────────────────
|
||||
// Omega_c: n×s ~ N(0, 1/sqrt(s))
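        // Randomized range finder: Y_c = A * Omega_c, Q_c = qr(Y_c), then
        // col_lev[j] = ||Q_c^T A[:, j]||^2, the energy of column j inside the sketched
        // top-s subspace, used below as a leverage-style sampling weight.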
        let mut d_omega_c = stream.alloc_zeros::<f32>(n_sz * s_sz)?;
        unsafe {
            stream
                .launch_builder(&gaussian_kernel)
                .arg(&mut d_omega_c)
                .arg(&(n * s))
                .arg(&scale)
                .arg(&col_seed)
                .launch(LaunchConfig {
                    grid_dim: ((n_sz * s_sz) as u32 / MAX_THREADS + 1, 1, 1),
                    block_dim: (MAX_THREADS, 1, 1),
                    shared_mem_bytes: 0,
                })?;
        }

        // Y_c = A * Omega_c (m×s)
        let mut d_y_c = stream.alloc_zeros::<f32>(m_sz * s_sz)?;
        unsafe {
            cublas.gemm(
                GemmConfig {
                    transa: cublasOperation_t::CUBLAS_OP_N,
                    transb: cublasOperation_t::CUBLAS_OP_N,
                    m, n: s, k: n,
                    alpha: 1.0f32, lda: m, ldb: n, beta: 0.0f32, ldc: m,
                },
                &challenge.d_a_mat,
                &d_omega_c,
                &mut d_y_c,
            )?;
        }
        drop(d_omega_c);

        // Q_c = QR(Y_c) in-place (m×s, orthonormal)
        gpu_qr(&cusolver, &stream, &mut d_y_c, m, s)?;
        let d_q_c = d_y_c;

        // Z_c = Q_c^T * A (s×n)
        let mut d_z_c = stream.alloc_zeros::<f32>(s_sz * n_sz)?;
        unsafe {
            cublas.gemm(
                GemmConfig {
                    transa: cublasOperation_t::CUBLAS_OP_T,
                    transb: cublasOperation_t::CUBLAS_OP_N,
                    m: s, n, k: m,
                    alpha: 1.0f32, lda: m, ldb: m, beta: 0.0f32, ldc: s,
                },
                &d_q_c,
                &challenge.d_a_mat,
                &mut d_z_c,
            )?;
        }
        drop(d_q_c);

        // col_lev[j] = ||Z_c[:, j]||²
        let mut d_col_lev = stream.alloc_zeros::<f32>(n_sz)?;
        launch_norm_kernel(&stream, &col_norms_kernel, &d_z_c, &mut d_col_lev, s, n, n as u32)?;
        drop(d_z_c);

        // ── Row leverage scores ─────────────────────────────────────────────
        // Omega_r: m×s ~ N(0, 1/sqrt(s))
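        // Mirror of the column pass, applied to A^T: sketch the row space,
        // orthonormalise, and score each row of A by its energy in that subspace.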
        let mut d_omega_r = stream.alloc_zeros::<f32>(m_sz * s_sz)?;
        unsafe {
            stream
                .launch_builder(&gaussian_kernel)
                .arg(&mut d_omega_r)
                .arg(&(m * s))
                .arg(&scale)
                .arg(&row_seed)
                .launch(LaunchConfig {
                    grid_dim: ((m_sz * s_sz) as u32 / MAX_THREADS + 1, 1, 1),
                    block_dim: (MAX_THREADS, 1, 1),
                    shared_mem_bytes: 0,
                })?;
        }

        // Y_r = A^T * Omega_r (n×s)
        let mut d_y_r = stream.alloc_zeros::<f32>(n_sz * s_sz)?;
        unsafe {
            cublas.gemm(
                GemmConfig {
                    transa: cublasOperation_t::CUBLAS_OP_T,
                    transb: cublasOperation_t::CUBLAS_OP_N,
                    m: n, n: s, k: m,
                    alpha: 1.0f32, lda: m, ldb: m, beta: 0.0f32, ldc: n,
                },
                &challenge.d_a_mat,
                &d_omega_r,
                &mut d_y_r,
            )?;
        }
        drop(d_omega_r);

        // Q_r = QR(Y_r) in-place (n×s, orthonormal)
        gpu_qr(&cusolver, &stream, &mut d_y_r, n, s)?;
        let d_q_r = d_y_r;

        // Z_r = A * Q_r (m×s)
        let mut d_z_r = stream.alloc_zeros::<f32>(m_sz * s_sz)?;
        unsafe {
            cublas.gemm(
                GemmConfig {
                    transa: cublasOperation_t::CUBLAS_OP_N,
                    transb: cublasOperation_t::CUBLAS_OP_N,
                    m, n: s, k: n,
                    alpha: 1.0f32, lda: m, ldb: n, beta: 0.0f32, ldc: m,
                },
                &challenge.d_a_mat,
                &d_q_r,
                &mut d_z_r,
            )?;
        }
        drop(d_q_r);

        // row_lev[i] = ||Z_r[i, :]||²
        let mut d_row_lev = stream.alloc_zeros::<f32>(m_sz)?;
        launch_norm_kernel(&stream, &row_norms_kernel, &d_z_r, &mut d_row_lev, m, s, m as u32)?;
        drop(d_z_r);

        // ── Copy scores to CPU and sample ───────────────────────────────────
        stream.synchronize()?;
        let col_lev = stream.memcpy_dtov(&d_col_lev)?;
        let row_lev = stream.memcpy_dtov(&d_row_lev)?;

        let c_idxs = weighted_sample_k(&col_lev, k_sz, &mut rng);
        let r_idxs = weighted_sample_k(&row_lev, k_sz, &mut rng);
        let c_i32: Vec<i32> = c_idxs.iter().map(|&i| i as i32).collect();
        let r_i32: Vec<i32> = r_idxs.iter().map(|&i| i as i32).collect();

        // ── Compute U, evaluate fnorm, save if improved ─────────────────────
        if let Ok(Some((ci, u, ri, fnorm))) = run_trial(c_i32, r_i32, &mut rng) {
            if fnorm < best_fnorm {
                best_fnorm = fnorm;
                let sol = Solution { c_idxs: ci, u_mat: u, r_idxs: ri };
                save_solution(&sol)?;
                best_solution = Some(sol);
            }
        }
    }

    Ok(best_solution)
}

// Important! Do not include any tests in this file, it will result in your submission being rejected
@ -532,6 +532,7 @@ impl Challenge {
        let src = &matrices[matrix_idx];
        let mut d_a_mat = stream.alloc_zeros::<f32>(mat_size)?;
        let alpha: f32 = 1.0;
        {
            let (src_ptr, _src_record) = src.device_ptr(&stream);
            let (dst_ptr, _dst_record) = d_a_mat.device_ptr_mut(&stream);
            unsafe {
@ -546,6 +547,7 @@ impl Challenge {
                )
                .result()?;
            }
        }
        stream.synchronize()?;

        challenges.push(Challenge {