Working nn_training challenge.

This commit is contained in:
FiveMovesAhead 2025-07-16 21:38:05 +01:00
parent 075ada5796
commit 0ea3676730
22 changed files with 4130 additions and 13 deletions

View File

@@ -8,16 +8,13 @@ on:
- 'knapsack/*'
- 'vector_search/*'
- 'hypergraph/*'
- 'nn_training/*'
- 'test/satisfiability/*'
- 'test/vehicle_routing/*'
- 'test/knapsack/*'
- 'test/vector_search/*'
- 'test/hypergraph/*'
- 'dev/satisfiability/*'
- 'dev/vehicle_routing/*'
- 'dev/knapsack/*'
- 'dev/vector_search/*'
- 'dev/hypergraph/*'
- 'test/nn_training/*'
jobs:
init:

View File

@@ -95,7 +95,7 @@ f"""Library not found at {so_path}:
"knapsack": "c003",
"vector_search": "c004",
"hypergraph": "c005",
"optimiser": "c006",
"nn_training": "c006",
}
challenge_id = challenge_ids[CHALLENGE]

View File

@@ -23,14 +23,15 @@ tig-challenges = { path = "../tig-challenges" }
crate-type = ["cdylib", "rlib"]
[features]
cuda = ["cudarc"]
c001 = ["tig-challenges/c001"]
satisfiability = ["c001"]
c002 = ["tig-challenges/c002"]
vehicle_routing = ["c002"]
c003 = ["tig-challenges/c003"]
knapsack = ["c003"]
c004 = ["cuda", "tig-challenges/c004"]
c004 = ["cudarc", "tig-challenges/c004"]
vector_search = ["c004"]
c005 = ["cuda", "tig-challenges/c005"]
c005 = ["cudarc", "tig-challenges/c005"]
hypergraph = ["c005"]
c006 = ["cudarc", "tig-challenges/c006"]
optimizer = ["c006"]

View File

@@ -31,3 +31,7 @@ pub use vector_search as c004;
pub mod hypergraph;
#[cfg(feature = "c005")]
pub use hypergraph as c005;
#[cfg(feature = "c006")]
pub mod nn_training;
#[cfg(feature = "c006")]
pub use nn_training as c006;

File diff suppressed because it is too large

View File

@@ -30,3 +30,5 @@ c004 = ["cuda", "tig-algorithms/c004", "tig-challenges/c004"]
vector_search = ["c004"]
c005 = ["cuda", "tig-algorithms/c005", "tig-challenges/c005"]
hypergraph = ["c005"]
c006 = ["cuda", "tig-algorithms/c006", "tig-challenges/c006"]
nn_training = ["c006"]

View File

@@ -41,8 +41,13 @@ case "$CHALLENGE" in
build_so $ALGORITHM
build_ptx $ALGORITHM
;;
nn_training)
echo "Building ALGORITHM '$ALGORITHM' for CHALLENGE 'nn_training'"
build_so $ALGORITHM
build_ptx $ALGORITHM --extra-cu-files tig-challenges/src/nn/kernels.cu
;;
*)
echo "Error: Invalid CHALLENGE value. Must be one of: satisfiability, knapsack, vehicle_routing, vector_search, hypergraph"
echo "Error: Invalid CHALLENGE value. Must be one of: satisfiability, knapsack, vehicle_routing, vector_search, hypergraph, nn_training"
exit 1
;;
esac

View File

@@ -237,6 +237,7 @@ $NORMAL_EXIT:
def main():
parser = argparse.ArgumentParser(description='Compile PTX with injected runtime signature')
parser.add_argument('algorithm', help='Algorithm name')
parser.add_argument('--extra-cu-files', nargs='*', default=[], help='Additional .cu files to include in the compilation')
args = parser.parse_args()
@@ -272,6 +273,11 @@ def main():
code = f.read() + "\n"
with open(challenge_cu, 'r') as f:
code += f.read() + "\n"
for extra_cu in args.extra_cu_files:
if not os.path.exists(extra_cu):
raise FileNotFoundError(f"Extra .cu file does not exist: {extra_cu}")
with open(extra_cu, 'r') as f:
code += f.read() + "\n"
kernel_regex = r'(?:extern\s+"C"\s+__global__|__device__)\s+\w+\s+(?P<func>\w+)\s*\('
kernels_to_ignore = [match.group('func') for match in re.finditer(kernel_regex, code)]
with open(algorithm_cu, 'r') as f:

View File

@@ -22,14 +22,15 @@ serde_json = { version = "1.0.113" }
statrs = { version = "0.18.0" }
[features]
cuda = ["cudarc"]
c001 = []
satisfiability = ["c001"]
c002 = []
vehicle_routing = ["c002"]
c003 = []
knapsack = ["c003"]
c004 = ["cuda"]
c004 = ["cudarc"]
vector_search = ["c004"]
c005 = ["cuda"]
c005 = ["cudarc"]
hypergraph = ["c005"]
c006 = ["cudarc", "cudarc/cublas", "cudarc/cudnn"]
nn_training = ["c006"]

View File

@@ -20,3 +20,9 @@ pub use vector_search as c004;
pub mod hypergraph;
#[cfg(feature = "c005")]
pub use hypergraph as c005;
#[cfg(feature = "c006")]
pub(crate) mod nn;
#[cfg(feature = "c006")]
pub mod nn_training;
#[cfg(feature = "c006")]
pub use nn_training as c006;

View File

@@ -0,0 +1,338 @@
use super::CudnnTensorDescriptor;
use anyhow::Result;
use cudarc::{
cudnn::{result::CudnnError, sys::*, Cudnn},
driver::{
CudaModule, CudaSlice, CudaStream, DevicePtr, DevicePtrMut, LaunchConfig, PushKernelArg,
},
};
use std::sync::Arc;
const THREADS_PER_BLOCK: u32 = 1024;
pub struct BatchNorm1d {
num_features: usize,
momentum: f64,
eps: f64,
pub weight: CudaSlice<f32>,
pub bias: CudaSlice<f32>,
pub running_mean: CudaSlice<f32>,
pub running_var: CudaSlice<f32>,
pub requires_grad: bool,
pub weight_grad: Option<CudaSlice<f32>>,
pub bias_grad: Option<CudaSlice<f32>>,
// cuDNN specific cache for backward pass
saved_mean: CudaSlice<f32>,
saved_inv_variance: CudaSlice<f32>,
}
impl BatchNorm1d {
pub fn new(
num_features: usize,
momentum: f64,
eps: f64,
requires_grad: bool,
stream: Arc<CudaStream>,
) -> Result<Self> {
let weight = stream.memcpy_stod(&vec![1.0; num_features])?; // Init with ones (scale)
let bias = stream.alloc_zeros::<f32>(num_features)?;
let running_mean = stream.alloc_zeros::<f32>(num_features)?;
let running_var = stream.memcpy_stod(&vec![1.0; num_features])?; // Init with ones
let (weight_grad, bias_grad) = if requires_grad {
(
Some(stream.alloc_zeros::<f32>(num_features)?),
Some(stream.alloc_zeros::<f32>(num_features)?),
)
} else {
(None, None)
};
// These are populated by forward pass for use in backward pass
let saved_mean = stream.alloc_zeros::<f32>(num_features)?;
let saved_inv_variance = stream.alloc_zeros::<f32>(num_features)?;
Ok(Self {
num_features,
momentum,
eps,
weight,
bias,
running_mean,
running_var,
requires_grad,
weight_grad,
bias_grad,
saved_mean,
saved_inv_variance,
})
}
pub fn forward<I: DevicePtr<f32>>(
&mut self,
input: &I,
training: bool,
stream: Arc<CudaStream>,
cudnn: &Cudnn,
) -> Result<CudaSlice<f32>> {
let batch_size = input.len() / self.num_features;
let mut output = stream.alloc_zeros::<f32>(input.len())?;
// For 1D batch norm, set up tensors as (N, C, 1, 1) but use SPATIAL mode
let mut x_desc = CudnnTensorDescriptor::new()?;
x_desc.set_4d(
cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
cudnnDataType_t::CUDNN_DATA_FLOAT,
batch_size as i32,
self.num_features as i32,
1,
1,
)?;
let mut y_desc = CudnnTensorDescriptor::new()?;
y_desc.set_4d(
cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
cudnnDataType_t::CUDNN_DATA_FLOAT,
batch_size as i32,
self.num_features as i32,
1,
1,
)?;
let mut derived_bn_desc = CudnnTensorDescriptor::new()?;
derived_bn_desc.set_4d(
cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
cudnnDataType_t::CUDNN_DATA_FLOAT,
1,
self.num_features as i32,
1,
1,
)?;
let alpha = 1.0f32;
let beta = 0.0f32;
let alpha_ptr = &alpha as *const f32 as *const std::ffi::c_void;
let beta_ptr = &beta as *const f32 as *const std::ffi::c_void;
// Use SPATIAL mode for 1D batch normalization
let mode = cudnnBatchNormMode_t::CUDNN_BATCHNORM_SPATIAL;
let status = if training {
unsafe {
cudnnBatchNormalizationForwardTraining(
cudnn.handle,
mode,
alpha_ptr,
beta_ptr,
*x_desc,
input.device_ptr(&stream).0 as *const _,
*y_desc,
output.device_ptr_mut(&stream).0 as *mut _,
*derived_bn_desc,
self.weight.device_ptr(&stream).0 as *const _,
self.bias.device_ptr(&stream).0 as *const _,
self.momentum,
self.running_mean.device_ptr_mut(&stream).0 as *mut _,
self.running_var.device_ptr_mut(&stream).0 as *mut _,
self.eps,
self.saved_mean.device_ptr_mut(&stream).0 as *mut _,
self.saved_inv_variance.device_ptr_mut(&stream).0 as *mut _,
)
}
} else {
unsafe {
cudnnBatchNormalizationForwardInference(
cudnn.handle,
mode,
alpha_ptr,
beta_ptr,
*x_desc,
input.device_ptr(&stream).0 as *const _,
*y_desc,
output.device_ptr_mut(&stream).0 as *mut _,
*derived_bn_desc,
self.weight.device_ptr(&stream).0 as *const _,
self.bias.device_ptr(&stream).0 as *const _,
self.running_mean.device_ptr(&stream).0 as *const _,
self.running_var.device_ptr(&stream).0 as *const _,
self.eps,
)
}
};
// Debug: Check saved_mean and saved_inv_variance after forward pass if training
if training {
let mut saved_mean_sample = vec![0.0f32; 5.min(self.saved_mean.len())];
let mut saved_inv_variance_sample = vec![0.0f32; 5.min(self.saved_inv_variance.len())];
stream.memcpy_dtoh(
&self.saved_mean.slice(0..5.min(self.saved_mean.len())),
&mut saved_mean_sample,
)?;
stream.memcpy_dtoh(
&self
.saved_inv_variance
.slice(0..5.min(self.saved_inv_variance.len())),
&mut saved_inv_variance_sample,
)?;
stream.synchronize()?;
}
if status == cudnnStatus_t::CUDNN_STATUS_SUCCESS {
Ok(output)
} else {
Err(CudnnError(status).into())
}
}
pub fn backward(
&mut self,
input: &CudaSlice<f32>,
grad_output: &CudaSlice<f32>,
should_accumulate_gradients: bool,
stream: Arc<CudaStream>,
cudnn: &Cudnn,
_module: Arc<CudaModule>,
) -> Result<CudaSlice<f32>> {
let batch_size = input.len() / self.num_features;
let mut grad_input = stream.alloc_zeros::<f32>(input.len())?;
// Set up tensor descriptors (same as forward pass)
let mut x_desc = CudnnTensorDescriptor::new()?;
x_desc.set_4d(
cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
cudnnDataType_t::CUDNN_DATA_FLOAT,
batch_size as i32,
self.num_features as i32,
1,
1,
)?;
let mut dy_desc = CudnnTensorDescriptor::new()?;
dy_desc.set_4d(
cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
cudnnDataType_t::CUDNN_DATA_FLOAT,
batch_size as i32,
self.num_features as i32,
1,
1,
)?;
let mut dx_desc = CudnnTensorDescriptor::new()?;
dx_desc.set_4d(
cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
cudnnDataType_t::CUDNN_DATA_FLOAT,
batch_size as i32,
self.num_features as i32,
1,
1,
)?;
let mut derived_bn_desc = CudnnTensorDescriptor::new()?;
derived_bn_desc.set_4d(
cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
cudnnDataType_t::CUDNN_DATA_FLOAT,
1,
self.num_features as i32,
1,
1,
)?;
let alpha_data = 1.0f32;
let beta_data = 0.0f32;
let alpha_param = 1.0f32;
let beta_param = if self.requires_grad && should_accumulate_gradients {
1.0f32
} else {
0.0f32
}; // Accumulate if trainable
let alpha_data_ptr = &alpha_data as *const f32 as *const std::ffi::c_void;
let beta_data_ptr = &beta_data as *const f32 as *const std::ffi::c_void;
let alpha_param_ptr = &alpha_param as *const f32 as *const std::ffi::c_void;
let beta_param_ptr = &beta_param as *const f32 as *const std::ffi::c_void;
// Use SPATIAL mode (same as forward)
let mode = cudnnBatchNormMode_t::CUDNN_BATCHNORM_SPATIAL;
let (mut temp_wg, mut temp_bg); // Must live long enough
let (wg, bg) = if self.requires_grad {
(
self.weight_grad.as_mut().unwrap(),
self.bias_grad.as_mut().unwrap(),
)
} else {
// Use temporary buffers if grads are not required for this layer
temp_wg = Some(stream.alloc_zeros::<f32>(self.num_features)?);
temp_bg = Some(stream.alloc_zeros::<f32>(self.num_features)?);
(temp_wg.as_mut().unwrap(), temp_bg.as_mut().unwrap())
};
let status = unsafe {
cudnnBatchNormalizationBackward(
cudnn.handle,
mode,
alpha_data_ptr, // alphaDataDiff
beta_data_ptr, // betaDataDiff
alpha_param_ptr, // alphaParamDiff
beta_param_ptr, // betaParamDiff (use 1.0 to accumulate, 0.0 to overwrite)
*x_desc, // xDesc
input.device_ptr(&stream).0 as *const _, // x
*dy_desc, // dyDesc
grad_output.device_ptr(&stream).0 as *const _, // dy
*dx_desc, // dxDesc
grad_input.device_ptr_mut(&stream).0 as *mut _, // dx
*derived_bn_desc, // dBnScaleBiasDesc
self.weight.device_ptr(&stream).0 as *const _, // bnScale
wg.device_ptr_mut(&stream).0 as *mut _, // dBnScaleResult (weight gradients)
bg.device_ptr_mut(&stream).0 as *mut _, // dBnBiasResult (bias gradients)
self.eps, // epsilon
self.saved_mean.device_ptr(&stream).0 as *const _, // savedMean
self.saved_inv_variance.device_ptr(&stream).0 as *const _, // savedInvVariance
)
};
if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
return Err(CudnnError(status).into());
}
Ok(grad_input)
}
pub fn zero_grad(&mut self, stream: Arc<CudaStream>, module: Arc<CudaModule>) -> Result<()> {
let zero_kernel = module.load_function("zero_out")?;
if let Some(wg) = self.weight_grad.as_mut() {
let n = wg.len() as u32;
let grid_dim = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
let cfg = LaunchConfig {
grid_dim: (grid_dim, 1, 1),
block_dim: (THREADS_PER_BLOCK, 1, 1),
shared_mem_bytes: 0,
};
let n_i32 = n as i32;
unsafe {
stream
.launch_builder(&zero_kernel)
.arg(wg)
.arg(&n_i32)
.launch(cfg)?;
};
}
if let Some(bg) = self.bias_grad.as_mut() {
let n = bg.len() as u32;
let grid_dim = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
let cfg = LaunchConfig {
grid_dim: (grid_dim, 1, 1),
block_dim: (THREADS_PER_BLOCK, 1, 1),
shared_mem_bytes: 0,
};
let n_i32 = n as i32;
unsafe {
stream
.launch_builder(&zero_kernel)
.arg(bg)
.arg(&n_i32)
.launch(cfg)?;
};
}
Ok(())
}
}
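For reference, the cuDNN calls above implement standard 1D batch normalization over tensors shaped (N, C, 1, 1) in SPATIAL mode, with weight acting as the scale γ and bias as the shift β. A sketch of what the training-mode forward pass computes per feature c, where m is the momentum passed to BatchNorm1d::new (0.1 in MLP::new) and ε the epsilon (1e-5):

\mu_c = \frac{1}{N}\sum_{i=1}^{N} x_{i,c}, \qquad \sigma_c^2 = \frac{1}{N}\sum_{i=1}^{N} (x_{i,c}-\mu_c)^2
y_{i,c} = \gamma_c \,\frac{x_{i,c}-\mu_c}{\sqrt{\sigma_c^2+\varepsilon}} + \beta_c
\hat{\mu}_c \leftarrow (1-m)\,\hat{\mu}_c + m\,\mu_c, \qquad \hat{\sigma}_c^2 \leftarrow (1-m)\,\hat{\sigma}_c^2 + m\,\sigma_c^2

saved_mean and saved_inv_variance cache \mu_c and 1/\sqrt{\sigma_c^2+\varepsilon} for cudnnBatchNormalizationBackward; inference mode normalizes with the running statistics instead of the batch statistics.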

View File

@@ -0,0 +1,157 @@
#include <cuda_runtime.h>
#include <device_launch_parameters.h>
#include <curand_kernel.h>
#include <cfloat>
#include <cmath>
__device__ float relu(float x) {
return fmaxf(x, 0.0f);
}
extern "C" __global__ void init_linear_layer(
unsigned long long seed,
int out_features,
int in_features,
float* weights, // (out_features, in_features)
float* biases // (out_features)
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= out_features * in_features) return;
curandState state;
curand_init(seed + idx, 0, 0, &state);
float fan_in = (float)in_features;
float fan_out = (float)out_features;
float limit = sqrtf(2.0f / (fan_in + fan_out)) * 0.5f;
weights[idx] = curand_uniform(&state) * 2.0f * limit - limit;
// Initialize biases to zero (can be done by one thread or memset)
if (idx < out_features) {
biases[idx] = 0.0f;
}
}
extern "C" __global__ void zero_out(float* data, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
data[idx] = 0.0f;
}
}
extern "C" __global__ void add_bias_forward(float* output, const float* bias, int batch_size, int features) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= batch_size * features) return;
int feature_idx = idx % features;
output[idx] += bias[feature_idx];
}
extern "C" __global__ void activation_forward(float* data, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= n) return;
data[idx] = relu(data[idx]);
}
extern "C" __global__ void loss_mse(
const float* output, // (batch_size, out_features)
const float* target, // (batch_size, out_features)
int batch_size,
int out_features,
float* grad_loss, // (batch_size, out_features)
float* total_loss_out // single element
) {
extern __shared__ float s_loss[]; // size = blockDim.x
int tid = threadIdx.x;
int idx = blockIdx.x * blockDim.x + tid;
float element_loss_sum = 0.0f;
int n = batch_size * out_features;
for (int i = idx; i < n; i += gridDim.x * blockDim.x) {
float diff = output[i] - target[i];
grad_loss[i] = 2.0f * diff;// / batch_size;
element_loss_sum += diff * diff;
}
s_loss[tid] = element_loss_sum;
__syncthreads();
// Reduction in shared memory
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
s_loss[tid] += s_loss[tid + s];
}
__syncthreads();
}
if (tid == 0) {
atomicAdd(total_loss_out, s_loss[0] / n);
}
}
extern "C" __global__ void activation_backward(
const float* grad_in,
const float* pre_act_vals, // Input to the activation function from forward pass
int n,
float* grad_out
) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) {
// If pre-activation value was positive, gradient passes through. Otherwise, it's zero.
grad_out[i] = pre_act_vals[i] > 0.0f ? grad_in[i] : 0.0f;
}
}
extern "C" __global__ void backward_bias(
const float* grad_output, // (batch_size, out_features)
float* bias_grad, // (out_features)
int batch_size,
int out_features
) {
extern __shared__ float s_grad_sum[]; // size = blockDim.x
int feature_idx = blockIdx.x; // Each block responsible for one feature
if (feature_idx >= out_features) return;
float sum = 0.0f;
for(int i = threadIdx.x; i < batch_size; i += blockDim.x) {
sum += grad_output[i * out_features + feature_idx];
}
s_grad_sum[threadIdx.x] = sum;
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (threadIdx.x < s) {
s_grad_sum[threadIdx.x] += s_grad_sum[threadIdx.x + s];
}
__syncthreads();
}
if(threadIdx.x == 0) {
atomicAdd(&bias_grad[feature_idx], s_grad_sum[0]);
}
}
extern "C" __global__ void apply_parameter_updates_direct(
float* params,
const float* updates,
int n
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
params[idx] += updates[idx]; // Direct addition (updates already scaled)
}
}
extern "C" __global__ void copy_tensor(
float* dst,
const float* src,
int n
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
dst[idx] = src[idx];
}
}
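For reference, loss_mse above computes a mean squared error over all n = batch_size * out_features elements: each block reduces its partial sum in shared memory and atomically adds s_loss[0]/n to total_loss_out. The per-element gradient it writes is the derivative of the summed squared error; the 1/batch_size factor is the commented-out division:

L = \frac{1}{n}\sum_{i=1}^{n} (\hat{y}_i - y_i)^2, \qquad n = \text{batch\_size}\cdot\text{out\_features}
\texttt{grad\_loss}[i] = \frac{\partial}{\partial \hat{y}_i}\sum_j (\hat{y}_j - y_j)^2 = 2\,(\hat{y}_i - y_i)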

View File

@@ -0,0 +1,251 @@
use anyhow::Result;
use cudarc::{
cublas::{sys::cublasOperation_t, CudaBlas, Gemm, GemmConfig},
driver::{CudaModule, CudaSlice, CudaStream, DevicePtr, LaunchConfig, PushKernelArg},
};
use std::sync::Arc;
const THREADS_PER_BLOCK: u32 = 1024;
pub struct Linear {
pub in_features: usize,
pub out_features: usize,
pub weight: CudaSlice<f32>,
pub bias: CudaSlice<f32>,
pub requires_grad: bool,
pub weight_grad: Option<CudaSlice<f32>>,
pub bias_grad: Option<CudaSlice<f32>>,
}
impl Linear {
pub fn new(
in_features: usize,
out_features: usize,
requires_grad: bool,
stream: Arc<CudaStream>,
) -> Result<Self> {
let weight = stream.alloc_zeros::<f32>(out_features * in_features)?;
let bias = stream.alloc_zeros::<f32>(out_features)?;
let (weight_grad, bias_grad) = if requires_grad {
(
Some(stream.alloc_zeros::<f32>(out_features * in_features)?),
Some(stream.alloc_zeros::<f32>(out_features)?),
)
} else {
(None, None)
};
Ok(Self {
in_features,
out_features,
weight,
bias,
requires_grad,
weight_grad,
bias_grad,
})
}
pub fn init_weights(
&mut self,
seed: [u8; 32],
stream: Arc<CudaStream>,
module: Arc<CudaModule>,
) -> Result<()> {
let kernel = module.load_function("init_linear_layer")?;
let n = (self.out_features * self.in_features) as u32;
let grid_dim = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
let cfg = LaunchConfig {
grid_dim: (grid_dim, 1, 1),
block_dim: (THREADS_PER_BLOCK, 1, 1),
shared_mem_bytes: 0,
};
let d_seed = stream.memcpy_stod(&seed)?;
let out_f = self.out_features as i32;
let in_f = self.in_features as i32;
unsafe {
stream
.launch_builder(&kernel)
.arg(&d_seed)
.arg(&out_f)
.arg(&in_f)
.arg(&mut self.weight)
.arg(&mut self.bias)
.launch(cfg)?
};
Ok(())
}
pub fn forward<I: DevicePtr<f32>>(
&self,
input_batch: &I,
stream: Arc<CudaStream>,
module: Arc<CudaModule>,
cublas: &CudaBlas,
) -> Result<CudaSlice<f32>> {
let batch_size = input_batch.len() / self.in_features;
let mut output_batch = stream.alloc_zeros::<f32>(batch_size * self.out_features)?;
let gemm_config = GemmConfig {
transa: cublasOperation_t::CUBLAS_OP_N, // Don't transpose input
transb: cublasOperation_t::CUBLAS_OP_N, // Don't transpose weight
m: self.out_features as i32, // Changed: rows of output
n: batch_size as i32, // Changed: cols of output
k: self.in_features as i32, // Same: inner dimension
alpha: 1.0f32,
lda: self.out_features as i32, // Changed: leading dim of weight
ldb: self.in_features as i32, // Changed: leading dim of input
beta: 0.0f32,
ldc: self.out_features as i32, // Changed: leading dim of output
};
unsafe {
cublas.gemm(gemm_config, &self.weight, input_batch, &mut output_batch)?;
}
let kernel = module.load_function("add_bias_forward")?;
let n = (batch_size * self.out_features) as u32;
let grid_dim = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
let cfg = LaunchConfig {
grid_dim: (grid_dim, 1, 1),
block_dim: (THREADS_PER_BLOCK, 1, 1),
shared_mem_bytes: 0,
};
let bs = batch_size as i32;
let of = self.out_features as i32;
unsafe {
stream
.launch_builder(&kernel)
.arg(&mut output_batch)
.arg(&self.bias)
.arg(&bs)
.arg(&of)
.launch(cfg)?
};
Ok(output_batch)
}
pub fn backward(
&mut self,
input_from_cache: &CudaSlice<f32>,
grad_output_batch: &CudaSlice<f32>,
should_accumulate_gradients: bool,
stream: Arc<CudaStream>,
module: Arc<CudaModule>,
cublas: &CudaBlas,
) -> Result<CudaSlice<f32>> {
let batch_size = input_from_cache.len() / self.in_features;
if self.requires_grad {
let wg = self.weight_grad.as_mut().unwrap();
let bg = self.bias_grad.as_mut().unwrap();
// Correctly computes dW = d(Y^T) * X^T for column-major layout.
// dW(out,in) = d(Y^T)(out,batch) * X(in,batch)^T
let gemm_config_wg = GemmConfig {
transa: cublasOperation_t::CUBLAS_OP_N,
transb: cublasOperation_t::CUBLAS_OP_T,
m: self.out_features as i32,
n: self.in_features as i32,
k: batch_size as i32,
alpha: 1.0f32,
lda: self.out_features as i32,
ldb: self.in_features as i32,
beta: if should_accumulate_gradients {
1.0f32
} else {
0.0f32
},
ldc: self.out_features as i32,
};
unsafe {
cublas.gemm(gemm_config_wg, grad_output_batch, input_from_cache, wg)?;
}
let kernel = module.load_function("backward_bias")?;
let threads_per_block = 256u32;
let grid_dim = (self.out_features as u32, 1, 1);
let cfg = LaunchConfig {
grid_dim: grid_dim,
block_dim: (threads_per_block, 1, 1),
shared_mem_bytes: threads_per_block * 4,
};
let bs = batch_size as i32;
let of = self.out_features as i32;
unsafe {
stream
.launch_builder(&kernel)
.arg(grad_output_batch)
.arg(bg)
.arg(&bs)
.arg(&of)
.launch(cfg)?
};
}
let mut grad_input_batch = stream.alloc_zeros::<f32>(batch_size * self.in_features)?;
// Correctly computes dX = W^T * d(Y^T) for column-major layout.
// dX(in,batch) = W(out,in)^T * d(Y^T)(out,batch)
let gemm_config_d_input = GemmConfig {
transa: cublasOperation_t::CUBLAS_OP_T,
transb: cublasOperation_t::CUBLAS_OP_N,
m: self.in_features as i32,
n: batch_size as i32,
k: self.out_features as i32,
alpha: 1.0f32,
lda: self.out_features as i32,
ldb: self.out_features as i32,
beta: 0.0f32,
ldc: self.in_features as i32,
};
unsafe {
cublas.gemm(
gemm_config_d_input,
&self.weight,
grad_output_batch,
&mut grad_input_batch,
)?;
}
Ok(grad_input_batch)
}
pub fn zero_grad(&mut self, stream: Arc<CudaStream>, module: Arc<CudaModule>) -> Result<()> {
let zero_kernel = module.load_function("zero_out")?;
if let Some(wg) = self.weight_grad.as_mut() {
let n = wg.len() as u32;
let grid_dim = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
let cfg = LaunchConfig {
grid_dim: (grid_dim, 1, 1),
block_dim: (THREADS_PER_BLOCK, 1, 1),
shared_mem_bytes: 0,
};
let n_i32 = n as i32;
unsafe {
stream
.launch_builder(&zero_kernel)
.arg(wg)
.arg(&n_i32)
.launch(cfg)?;
};
}
if let Some(bg) = self.bias_grad.as_mut() {
let n = bg.len() as u32;
let grid_dim = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
let cfg = LaunchConfig {
grid_dim: (grid_dim, 1, 1),
block_dim: (THREADS_PER_BLOCK, 1, 1),
shared_mem_bytes: 0,
};
let n_i32 = n as i32;
unsafe {
stream
.launch_builder(&zero_kernel)
.arg(bg)
.arg(&n_i32)
.launch(cfg)?;
};
}
Ok(())
}
}
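For reference, the three GEMMs above are the usual linear-layer algebra written in cuBLAS's column-major view, where each sample is one column: W is out_features × in_features, X is in_features × batch_size, and the output Y is out_features × batch_size:

Y = W X \quad (\text{then } y \mathrel{+}= b \text{ per output feature via add\_bias\_forward})
\frac{\partial L}{\partial W} = \frac{\partial L}{\partial Y}\, X^{\top} \quad (\beta = 1 \text{ accumulates into weight\_grad})
\frac{\partial L}{\partial b} = \sum_{\text{batch}} \frac{\partial L}{\partial Y} \quad (\text{backward\_bias kernel})
\frac{\partial L}{\partial X} = W^{\top}\, \frac{\partial L}{\partial Y}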

View File

@@ -0,0 +1,516 @@
use super::{BatchNorm1d, Linear};
use anyhow::Result;
use cudarc::{
cublas::CudaBlas,
cudnn::Cudnn,
driver::{
CudaModule, CudaSlice, CudaStream, CudaView, DevicePtr, DeviceRepr, LaunchConfig,
PushKernelArg,
},
};
use rand::{prelude::*, rngs::StdRng};
use std::sync::Arc;
const THREADS_PER_BLOCK: u32 = 1024;
pub struct MLP {
pub lin: Vec<Linear>,
pub bns: Vec<BatchNorm1d>,
pub layer_cnt: usize,
}
#[derive(Clone)]
pub struct ForwardCache<T: DeviceRepr> {
pub input: CudaSlice<T>,
pub linear_output: CudaSlice<T>,
pub activated_output: Option<CudaSlice<T>>,
pub bn_input: Option<CudaSlice<T>>,
}
impl MLP {
pub fn new(
layer_sizes: &[usize],
frozen_layers: usize,
stream: Arc<CudaStream>,
) -> Result<Self> {
let layer_cnt = layer_sizes.len() - 1;
let mut lin = Vec::with_capacity(layer_cnt);
let mut bns = Vec::with_capacity(layer_cnt - 1);
for l in 0..layer_cnt {
let requires_grad = l < layer_cnt.saturating_sub(frozen_layers);
lin.push(Linear::new(
layer_sizes[l],
layer_sizes[l + 1],
requires_grad,
stream.clone(),
)?);
if l < layer_cnt - 1 {
bns.push(BatchNorm1d::new(
layer_sizes[l + 1],
0.1,
1e-5,
requires_grad,
stream.clone(),
)?);
}
}
Ok(Self {
lin,
bns,
layer_cnt,
})
}
pub fn init_weights(
&mut self,
seed: [u8; 32],
stream: Arc<CudaStream>,
module: Arc<CudaModule>,
) -> Result<()> {
let mut rng = StdRng::from_seed(seed);
for layer in &mut self.lin {
layer.init_weights(rng.gen(), stream.clone(), module.clone())?;
}
Ok(())
}
pub fn zero_grad(&mut self, stream: Arc<CudaStream>, module: Arc<CudaModule>) -> Result<()> {
for layer in &mut self.lin {
layer.zero_grad(stream.clone(), module.clone())?;
}
for bn in &mut self.bns {
bn.zero_grad(stream.clone(), module.clone())?;
}
Ok(())
}
pub fn forward<I: DevicePtr<f32>>(
&mut self,
input: &I,
training: bool,
stream: Arc<CudaStream>,
module: Arc<CudaModule>,
cublas: &CudaBlas,
cudnn: &Cudnn,
) -> Result<(CudaSlice<f32>, Vec<ForwardCache<f32>>)> {
let mut x = stream.alloc_zeros::<f32>(input.len())?;
stream.memcpy_dtod(input, &mut x)?;
let mut caches = Vec::with_capacity(self.layer_cnt);
for l in 0..self.layer_cnt {
let input_cache = x.clone();
let linear_output =
self.lin[l].forward(&input_cache, stream.clone(), module.clone(), cublas)?;
if l < self.layer_cnt - 1 {
let mut activated = linear_output.clone();
let act_fwd_kernel = module.load_function("activation_forward")?;
let n = activated.len() as u32;
let grid_dim = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
let cfg = LaunchConfig {
grid_dim: (grid_dim, 1, 1),
block_dim: (THREADS_PER_BLOCK, 1, 1),
shared_mem_bytes: 0,
};
let n_i32 = n as i32;
unsafe {
stream
.launch_builder(&act_fwd_kernel)
.arg(&mut activated)
.arg(&n_i32)
.launch(cfg)?
};
let bn_input_cache = activated.clone();
let bn_output = self.bns[l].forward(&activated, training, stream.clone(), cudnn)?;
caches.push(ForwardCache {
input: input_cache,
linear_output,
activated_output: Some(activated),
bn_input: Some(bn_input_cache),
});
x = bn_output;
} else {
caches.push(ForwardCache {
input: input_cache,
linear_output: linear_output.clone(),
activated_output: None,
bn_input: None,
});
x = linear_output;
}
}
Ok((x, caches))
}
pub fn backward(
&mut self,
grad: &CudaSlice<f32>,
forward_caches: &[ForwardCache<f32>],
should_accumulate_gradients: bool,
stream: Arc<CudaStream>,
module: Arc<CudaModule>,
cublas: &CudaBlas,
cudnn: &Cudnn,
) -> Result<()> {
let mut current_grad = grad.clone();
for i in (0..self.lin.len()).rev() {
let mut grad_to_pass_to_linear = current_grad.clone();
// For intermediate layers, backpropagate through BN and Activation in reverse order.
if i < self.bns.len() {
// Step 1: Backpropagate through BatchNorm.
// The input to the BN's forward pass was the *activated* output of the linear layer.
let bn_input = forward_caches[i].activated_output.as_ref().unwrap();
let grad_after_bn = self.bns[i].backward(
bn_input,
&current_grad,
should_accumulate_gradients,
stream.clone(),
cudnn,
module.clone(),
)?;
// Step 2: Backpropagate through Activation.
// The input to the activation's forward pass was the direct output of the linear layer.
let pre_activation_values = &forward_caches[i].linear_output;
let mut grad_after_activation = stream.alloc_zeros::<f32>(grad_after_bn.len())?;
let kernel = module.load_function("activation_backward")?;
let cfg = LaunchConfig {
grid_dim: ((grad_after_bn.len() as u32 + 255) / 256, 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
stream
.launch_builder(&kernel)
.arg(&grad_after_bn)
.arg(pre_activation_values)
.arg(&(grad_after_bn.len() as i32))
.arg(&mut grad_after_activation)
.launch(cfg)?;
}
grad_to_pass_to_linear = grad_after_activation;
}
// Step 3: Backpropagate through the linear layer.
// The input to the linear layer's forward pass is stored in the cache for this layer.
let input_to_linear = &forward_caches[i].input;
let grad_after_linear = self.lin[i].backward(
input_to_linear,
&grad_to_pass_to_linear,
should_accumulate_gradients,
stream.clone(),
module.clone(),
cublas,
)?;
current_grad = grad_after_linear;
}
Ok(())
}
pub fn loss_and_grad(
&self,
output: &CudaSlice<f32>,
target: &CudaView<'_, f32>,
stream: Arc<CudaStream>,
module: Arc<CudaModule>,
) -> Result<(CudaSlice<f32>, CudaSlice<f32>)> {
let mut grad = stream.alloc_zeros::<f32>(output.len())?;
let mut loss = stream.alloc_zeros::<f32>(1)?;
let loss_kernel = module.load_function("loss_mse")?;
let total_elements = output.len() as u32;
let threads_per_block = 256u32;
let grid_dim = (total_elements + threads_per_block - 1) / threads_per_block;
let cfg = LaunchConfig {
grid_dim: (grid_dim, 1, 1),
block_dim: (threads_per_block, 1, 1),
shared_mem_bytes: threads_per_block * 4,
};
let batch_size = (output.len() / self.lin.last().unwrap().out_features) as i32;
let out_features = self.lin.last().unwrap().out_features as i32;
unsafe {
stream
.launch_builder(&loss_kernel)
.arg(output)
.arg(target)
.arg(&batch_size)
.arg(&out_features)
.arg(&mut grad)
.arg(&mut loss)
.launch(cfg)?;
}
Ok((loss, grad))
}
/// Extract all model parameters into a flat vector of CudaSlices
pub fn extract_parameters(&self, _stream: Arc<CudaStream>) -> Result<Vec<CudaSlice<f32>>> {
let mut params = Vec::new();
// Linear layer parameters
for layer in &self.lin {
params.push(layer.weight.clone());
params.push(layer.bias.clone());
}
// BatchNorm parameters
for bn in &self.bns {
params.push(bn.weight.clone());
params.push(bn.bias.clone());
params.push(bn.running_mean.clone());
params.push(bn.running_var.clone());
}
Ok(params)
}
/// Extract all model gradients into a flat vector of CudaSlices
pub fn extract_gradients(&self, stream: Arc<CudaStream>) -> Result<Vec<CudaSlice<f32>>> {
let mut grads = Vec::new();
// Linear layer gradients
for layer in &self.lin {
if layer.requires_grad {
grads.push(layer.weight_grad.as_ref().unwrap().clone());
grads.push(layer.bias_grad.as_ref().unwrap().clone());
} else {
// Create zero tensors for non-trainable parameters
grads.push(stream.alloc_zeros::<f32>(layer.weight.len())?);
grads.push(stream.alloc_zeros::<f32>(layer.bias.len())?);
}
}
// BatchNorm gradients
for bn in &self.bns {
if bn.requires_grad {
grads.push(bn.weight_grad.as_ref().unwrap().clone());
grads.push(bn.bias_grad.as_ref().unwrap().clone());
} else {
grads.push(stream.alloc_zeros::<f32>(bn.weight.len())?);
grads.push(stream.alloc_zeros::<f32>(bn.bias.len())?);
}
// No gradients for running_mean and running_var
grads.push(stream.alloc_zeros::<f32>(bn.running_mean.len())?);
grads.push(stream.alloc_zeros::<f32>(bn.running_var.len())?);
}
Ok(grads)
}
/// Get parameter sizes for optimizer initialization
pub fn get_parameter_sizes(&self) -> Vec<usize> {
let mut sizes = Vec::new();
for layer in &self.lin {
sizes.push(layer.weight.len());
sizes.push(layer.bias.len());
}
for bn in &self.bns {
sizes.push(bn.weight.len());
sizes.push(bn.bias.len());
sizes.push(bn.running_mean.len());
sizes.push(bn.running_var.len());
}
sizes
}
/// Apply parameter updates from optimizer
pub fn apply_optimizer_updates(
&mut self,
updates: &[CudaSlice<f32>],
stream: Arc<CudaStream>,
module: Arc<CudaModule>,
) -> Result<()> {
let kernel = module.load_function("apply_parameter_updates_direct")?;
let mut update_idx = 0;
// Apply to linear layers
for layer in &mut self.lin {
if layer.requires_grad {
// Weight update
let n = layer.weight.len() as u32;
let grid_dim = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
let cfg = LaunchConfig {
grid_dim: (grid_dim, 1, 1),
block_dim: (THREADS_PER_BLOCK, 1, 1),
shared_mem_bytes: 0,
};
let n_i32 = n as i32;
unsafe {
stream
.launch_builder(&kernel)
.arg(&mut layer.weight)
.arg(&updates[update_idx])
.arg(&n_i32)
.launch(cfg)?;
}
}
update_idx += 1;
if layer.requires_grad {
// Bias update
let n = layer.bias.len() as u32;
let grid_dim = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
let cfg = LaunchConfig {
grid_dim: (grid_dim, 1, 1),
block_dim: (THREADS_PER_BLOCK, 1, 1),
shared_mem_bytes: 0,
};
let n_i32 = n as i32;
unsafe {
stream
.launch_builder(&kernel)
.arg(&mut layer.bias)
.arg(&updates[update_idx])
.arg(&n_i32)
.launch(cfg)?;
}
}
update_idx += 1;
}
// Apply to BatchNorm layers
for bn in &mut self.bns {
if bn.requires_grad {
// Weight update
let n = bn.weight.len() as u32;
let grid_dim = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
let cfg = LaunchConfig {
grid_dim: (grid_dim, 1, 1),
block_dim: (THREADS_PER_BLOCK, 1, 1),
shared_mem_bytes: 0,
};
let n_i32 = n as i32;
unsafe {
stream
.launch_builder(&kernel)
.arg(&mut bn.weight)
.arg(&updates[update_idx])
.arg(&n_i32)
.launch(cfg)?;
}
}
update_idx += 1;
if bn.requires_grad {
// Bias update
let n = bn.bias.len() as u32;
let grid_dim = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
let cfg = LaunchConfig {
grid_dim: (grid_dim, 1, 1),
block_dim: (THREADS_PER_BLOCK, 1, 1),
shared_mem_bytes: 0,
};
let n_i32 = n as i32;
unsafe {
stream
.launch_builder(&kernel)
.arg(&mut bn.bias)
.arg(&updates[update_idx])
.arg(&n_i32)
.launch(cfg)?;
}
}
update_idx += 1;
// Skip running_mean and running_var (they're not trainable)
update_idx += 2;
}
Ok(())
}
/// Set model parameters from a vector of CudaSlices
pub fn set_parameters(
&mut self,
params: &[CudaSlice<f32>],
stream: Arc<CudaStream>,
module: Arc<CudaModule>,
) -> Result<()> {
let copy_kernel = module.load_function("copy_tensor")?;
let mut param_idx = 0;
for layer in &mut self.lin {
// Copy weights
let n = layer.weight.len() as u32;
let grid_dim = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
let cfg = LaunchConfig {
grid_dim: (grid_dim, 1, 1),
block_dim: (THREADS_PER_BLOCK, 1, 1),
shared_mem_bytes: 0,
};
let n_i32 = n as i32;
unsafe {
stream
.launch_builder(&copy_kernel)
.arg(&mut layer.weight)
.arg(&params[param_idx])
.arg(&n_i32)
.launch(cfg)?;
}
param_idx += 1;
// Copy biases
let n = layer.bias.len() as u32;
let grid_dim = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
let cfg = LaunchConfig {
grid_dim: (grid_dim, 1, 1),
block_dim: (THREADS_PER_BLOCK, 1, 1),
shared_mem_bytes: 0,
};
let n_i32 = n as i32;
unsafe {
stream
.launch_builder(&copy_kernel)
.arg(&mut layer.bias)
.arg(&params[param_idx])
.arg(&n_i32)
.launch(cfg)?;
}
param_idx += 1;
}
for bn in &mut self.bns {
// Copy BN parameters
for target in [
&mut bn.weight,
&mut bn.bias,
&mut bn.running_mean,
&mut bn.running_var,
] {
let n = target.len() as u32;
let grid_dim = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
let cfg = LaunchConfig {
grid_dim: (grid_dim, 1, 1),
block_dim: (THREADS_PER_BLOCK, 1, 1),
shared_mem_bytes: 0,
};
let n_i32 = n as i32;
unsafe {
stream
.launch_builder(&copy_kernel)
.arg(target)
.arg(&params[param_idx])
.arg(&n_i32)
.launch(cfg)?;
}
param_idx += 1;
}
}
Ok(())
}
}
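extract_parameters, extract_gradients, apply_optimizer_updates and set_parameters above all assume one flat tensor ordering: weight then bias for every Linear layer, followed by weight, bias, running_mean, running_var for every BatchNorm layer. A minimal standalone sketch (a hypothetical helper, not part of this commit) that spells out the contract:

// Hypothetical helper: mirrors the flat tensor ordering assumed by
// MLP::extract_parameters / extract_gradients / apply_optimizer_updates /
// set_parameters above. Illustration only.
fn flat_tensor_layout(num_linear: usize, num_bn: usize) -> Vec<&'static str> {
    let mut layout = Vec::new();
    for _ in 0..num_linear {
        layout.push("linear.weight");
        layout.push("linear.bias");
    }
    for _ in 0..num_bn {
        // The last two are not trainable; apply_optimizer_updates skips them,
        // but they still occupy slots so that indices stay aligned.
        layout.push("bn.weight");
        layout.push("bn.bias");
        layout.push("bn.running_mean");
        layout.push("bn.running_var");
    }
    layout
}

fn main() {
    // e.g. layer_dims = [1, 256, 256, 2] -> 3 Linear layers, 2 BatchNorm layers
    let layout = flat_tensor_layout(3, 2);
    assert_eq!(layout.len(), 3 * 2 + 2 * 4);
    println!("{:?}", layout);
}

Keeping placeholder slots for the non-trainable running statistics is what lets apply_optimizer_updates simply advance update_idx past them while extract_gradients fills those slots with zeros.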

View File

@@ -0,0 +1,8 @@
mod batch_norm;
pub use batch_norm::*;
mod linear;
pub use linear::*;
mod mlp;
pub use mlp::*;
mod tensor;
pub use tensor::*;

View File

@@ -0,0 +1,49 @@
use anyhow::Result;
use cudarc::cudnn::{self, result::CudnnError};
use std::ops::Deref;
pub struct CudnnTensorDescriptor(cudnn::sys::cudnnTensorDescriptor_t);
impl CudnnTensorDescriptor {
pub fn new() -> Result<Self, CudnnError> {
let mut desc = std::ptr::null_mut();
unsafe {
match cudnn::sys::cudnnCreateTensorDescriptor(&mut desc) {
cudnn::sys::cudnnStatus_t::CUDNN_STATUS_SUCCESS => Ok(Self(desc)),
e => Err(CudnnError(e)),
}
}
}
pub fn set_4d(
&mut self,
format: cudnn::sys::cudnnTensorFormat_t,
data_type: cudnn::sys::cudnnDataType_t,
n: i32,
c: i32,
h: i32,
w: i32,
) -> Result<(), CudnnError> {
unsafe {
match cudnn::sys::cudnnSetTensor4dDescriptor(self.0, format, data_type, n, c, h, w) {
cudnn::sys::cudnnStatus_t::CUDNN_STATUS_SUCCESS => Ok(()),
e => Err(CudnnError(e)),
}
}
}
}
impl Deref for CudnnTensorDescriptor {
type Target = cudnn::sys::cudnnTensorDescriptor_t;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Drop for CudnnTensorDescriptor {
fn drop(&mut self) {
unsafe {
cudnn::sys::cudnnDestroyTensorDescriptor(self.0);
}
}
}

View File

@@ -0,0 +1,115 @@
#include <cuda_runtime.h>
#include <device_launch_parameters.h>
#include <curand_kernel.h>
#include <cfloat>
#include <cmath>
#ifndef M_PI
#define M_PI 3.14159265358979323846f
#endif
__device__ float box_muller(curandState* state, float& z1) {
float u1, u2;
// Prevent u1 from being 0 to avoid log(0) -> -inf
do {
u1 = curand_uniform(state);
} while (u1 == 0.0f);
u2 = curand_uniform(state);
float mag = sqrtf(-2.0f * logf(u1));
float z0 = mag * cosf(2.0f * M_PI * u2);
z1 = mag * sinf(2.0f * M_PI * u2);
return z0;
}
extern "C" __global__ void generate_rff_params(
unsigned char* seed,
int output_dims,
int input_dims,
int k_rff,
float lengthscale,
float* a_params, // (output_dims, k_rff)
float* b_params, // (output_dims, k_rff)
float* w_params // (output_dims, k_rff, input_dims)
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= output_dims * k_rff) return;
curandState state;
curand_init(
*((unsigned long long*)seed) + idx,
0, 0, &state
);
int out_dim = idx / k_rff;
int k_idx = idx % k_rff;
// Box-Muller generates two samples; cache one
float z1;
a_params[idx] = box_muller(&state, z1);
// Note: this is not perfectly efficient; it could be improved by having half the threads write z1.
b_params[idx] = curand_uniform(&state) * 2.0f * M_PI;
float lengthscale_inv_sq = 1.0f / (lengthscale * lengthscale);
for(int in_dim = 0; in_dim < input_dims; ++in_dim) {
int w_idx = out_dim * k_rff * input_dims + k_idx * input_dims + in_dim;
float z_w1;
w_params[w_idx] = lengthscale_inv_sq * box_muller(&state, z_w1);
}
}
extern "C" __global__ void generate_dataset(
unsigned char* seed,
int num_samples,
int input_dims,
int output_dims,
int k_rff,
float scaling_factor,
float noise_std,
const float* a_params,
const float* b_params,
const float* w_params,
float* out_inputs, // (num_samples, input_dims)
float* out_targets_noisy, // (num_samples, output_dims)
float* out_targets_true // (num_samples, output_dims)
) {
int sample_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (sample_idx >= num_samples) return;
curandState state;
curand_init(
*((unsigned long long*)seed) + sample_idx + num_samples, // Offset seed
0, 0, &state
);
// Generate input sample
for (int i = 0; i < input_dims; ++i) {
out_inputs[sample_idx * input_dims + i] = curand_uniform(&state) * 2.0f - 1.0f;
}
// Generate targets
for (int out_dim = 0; out_dim < output_dims; ++out_dim) {
float f_val = 0.0f;
for (int k_idx = 0; k_idx < k_rff; ++k_idx) {
float wx_sum = 0.0f;
for (int in_dim = 0; in_dim < input_dims; ++in_dim) {
float w = w_params[out_dim * k_rff * input_dims + k_idx * input_dims + in_dim];
float x = out_inputs[sample_idx * input_dims + in_dim];
wx_sum += w * x;
}
float b = b_params[out_dim * k_rff + k_idx];
float a = a_params[out_dim * k_rff + k_idx];
f_val += a * cosf(wx_sum + b);
}
f_val *= scaling_factor;
float z_noise1;
float noise = noise_std * box_muller(&state, z_noise1);
out_targets_true[sample_idx * output_dims + out_dim] = f_val;
out_targets_noisy[sample_idx * output_dims + out_dim] = f_val + noise;
}
}
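For reference, the two kernels above generate a random-Fourier-feature style target. With a_{d,k} ~ N(0,1), b_{d,k} ~ U(0, 2π) and, as coded, w_{d,k} = z / \ell^2 with z ~ N(0,1) elementwise, each output dimension d of an input x ~ U(-1,1)^{input_dims} is

f_d(x) = s \sum_{k=1}^{K} a_{d,k}\,\cos\!\big(w_{d,k}^{\top} x + b_{d,k}\big), \qquad y_d = f_d(x) + \epsilon, \quad \epsilon \sim \mathcal{N}(0, \sigma_{\text{noise}}^2),

where K = k_rff, \ell = lengthscale, \sigma_noise = noise_std, and s is the scaling_factor supplied by the host (A\sqrt{2/K} with A = RFF_AMPLITUDE_PER_FUNC in generate_instance below). The Gaussian draws come from the Box-Muller transform:

z_0 = \sqrt{-2\ln u_1}\,\cos(2\pi u_2), \qquad z_1 = \sqrt{-2\ln u_1}\,\sin(2\pi u_2), \qquad u_1, u_2 \in (0,1].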

View File

@@ -0,0 +1,648 @@
use anyhow::{anyhow, Result};
use cudarc::{
cublas::CudaBlas,
cudnn::Cudnn,
driver::{CudaModule, CudaSlice, CudaStream, CudaView, LaunchConfig, PushKernelArg},
runtime::sys::cudaDeviceProp,
};
use rand::{prelude::*, rngs::StdRng};
use serde::{Deserialize, Serialize};
use serde_json::{from_value, Map, Value};
use std::{any::Any, sync::Arc};
use crate::nn::MLP;
const THREADS_PER_BLOCK: u32 = 1024;
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Difficulty {
pub num_hidden_layers: usize,
pub accuracy_factor: u32,
}
impl From<Vec<i32>> for Difficulty {
fn from(arr: Vec<i32>) -> Self {
Self {
num_hidden_layers: arr[0] as usize,
accuracy_factor: arr[1] as u32,
}
}
}
impl Into<Vec<i32>> for Difficulty {
fn into(self) -> Vec<i32> {
vec![self.num_hidden_layers as i32, self.accuracy_factor as i32]
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Solution {
pub weights: Vec<Vec<Vec<f32>>>,
pub biases: Vec<Vec<f32>>,
pub epochs_used: usize,
pub bn_weights: Vec<Vec<f32>>,
pub bn_biases: Vec<Vec<f32>>,
pub bn_running_means: Vec<Vec<f32>>,
pub bn_running_vars: Vec<Vec<f32>>,
}
impl TryFrom<Map<String, Value>> for Solution {
type Error = serde_json::Error;
fn try_from(v: Map<String, Value>) -> Result<Self, Self::Error> {
from_value(Value::Object(v))
}
}
pub struct Dataset {
pub inputs: CudaSlice<f32>,
pub targets_noisy: CudaSlice<f32>,
pub targets_true_f: CudaSlice<f32>,
pub train_size: usize,
pub validation_size: usize,
pub test_size: usize,
pub input_dims: usize,
pub output_dims: usize,
}
impl Dataset {
pub fn train_inputs(&self) -> CudaView<f32> {
self.inputs.slice(0..self.train_size * self.input_dims)
}
pub fn train_targets_noisy(&self) -> CudaView<f32> {
self.targets_noisy
.slice(0..self.train_size * self.output_dims)
}
pub fn train_targets_true_f(&self) -> CudaView<f32> {
self.targets_true_f
.slice(0..self.train_size * self.output_dims)
}
pub fn validation_inputs(&self) -> CudaView<f32> {
self.inputs.slice(
self.train_size * self.input_dims
..(self.train_size + self.validation_size) * self.input_dims,
)
}
pub fn validation_targets_noisy(&self) -> CudaView<f32> {
self.targets_noisy.slice(
self.train_size * self.output_dims
..(self.train_size + self.validation_size) * self.output_dims,
)
}
pub fn validation_targets_true_f(&self) -> CudaView<f32> {
self.targets_true_f.slice(
self.train_size * self.output_dims
..(self.train_size + self.validation_size) * self.output_dims,
)
}
pub fn test_inputs(&self) -> CudaView<f32> {
self.inputs.slice(
(self.train_size + self.validation_size) * self.input_dims
..(self.train_size + self.validation_size + self.test_size) * self.input_dims,
)
}
pub fn test_targets_noisy(&self) -> CudaView<f32> {
self.targets_noisy.slice(
(self.train_size + self.validation_size) * self.output_dims
..(self.train_size + self.validation_size + self.test_size) * self.output_dims,
)
}
pub fn test_targets_true_f(&self) -> CudaView<f32> {
self.targets_true_f.slice(
(self.train_size + self.validation_size) * self.output_dims
..(self.train_size + self.validation_size + self.test_size) * self.output_dims,
)
}
}
pub struct Challenge {
pub seed: [u8; 32],
pub difficulty: Difficulty,
pub hidden_layers_dims: usize,
pub batch_size: usize,
pub max_epochs: usize,
pub patience: usize,
pub min_loss_delta: f32,
pub num_frozen_layers: usize,
pub dataset: Dataset,
}
impl Challenge {
pub fn generate_instance(
seed: &[u8; 32],
difficulty: &Difficulty,
module: Arc<CudaModule>,
stream: Arc<CudaStream>,
_prop: &cudaDeviceProp,
) -> Result<Self> {
const K_RFF: usize = 128;
const RFF_AMPLITUDE_PER_FUNC: f32 = 1.0;
const RFF_LENGTHSCALE_PER_INPUT_DIM: f32 = 0.3;
const NOISE_STD: f32 = 0.2;
const INPUT_DIMS: usize = 1;
const OUTPUT_DIMS: usize = 2;
const TRAIN_SIZE: usize = 1000;
const VALIDATION_SIZE: usize = 200;
const TEST_SIZE: usize = 250;
let scaling_factor = RFF_AMPLITUDE_PER_FUNC * (2.0 / K_RFF as f32).sqrt();
let d_seed = stream.memcpy_stod(seed)?;
// Allocate memory for RFF params
let mut a_params = stream.alloc_zeros::<f32>(OUTPUT_DIMS * K_RFF)?;
let mut b_params = stream.alloc_zeros::<f32>(OUTPUT_DIMS * K_RFF)?;
let mut w_params = stream.alloc_zeros::<f32>(OUTPUT_DIMS * K_RFF * INPUT_DIMS)?;
// Generate RFF params
let generate_rff_params_kernel = module.load_function("generate_rff_params")?;
let n = (OUTPUT_DIMS * K_RFF) as u32;
let grid_dim = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
let cfg = LaunchConfig {
grid_dim: (grid_dim, 1, 1),
block_dim: (THREADS_PER_BLOCK, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
stream
.launch_builder(&generate_rff_params_kernel)
.arg(&d_seed)
.arg(&(OUTPUT_DIMS as i32))
.arg(&(INPUT_DIMS as i32))
.arg(&(K_RFF as i32))
.arg(&RFF_LENGTHSCALE_PER_INPUT_DIM)
.arg(&mut a_params)
.arg(&mut b_params)
.arg(&mut w_params)
.launch(cfg)?;
}
// Generate datasets
let generate_dataset_kernel = module.load_function("generate_dataset")?;
// Training data
let total_samples = TRAIN_SIZE + VALIDATION_SIZE + TEST_SIZE;
let mut inputs = stream.alloc_zeros::<f32>(total_samples * INPUT_DIMS)?;
let mut targets_noisy = stream.alloc_zeros::<f32>(total_samples * OUTPUT_DIMS)?;
let mut targets_true_f = stream.alloc_zeros::<f32>(total_samples * OUTPUT_DIMS)?;
let grid_dim = (total_samples as u32 + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
let cfg_train = LaunchConfig {
grid_dim: (grid_dim, 1, 1),
block_dim: (THREADS_PER_BLOCK, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
stream
.launch_builder(&generate_dataset_kernel)
.arg(&d_seed)
.arg(&(total_samples as i32))
.arg(&(INPUT_DIMS as i32))
.arg(&(OUTPUT_DIMS as i32))
.arg(&(K_RFF as i32))
.arg(&scaling_factor)
.arg(&NOISE_STD)
.arg(&a_params)
.arg(&b_params)
.arg(&w_params)
.arg(&mut inputs)
.arg(&mut targets_noisy)
.arg(&mut targets_true_f)
.launch(cfg_train)?;
}
stream.synchronize()?;
Ok(Self {
seed: *seed,
difficulty: difficulty.clone(),
hidden_layers_dims: 256,
batch_size: 128,
max_epochs: 1000,
patience: 50,
min_loss_delta: 1e-7,
num_frozen_layers: 2,
dataset: Dataset {
inputs,
targets_noisy,
targets_true_f,
train_size: TRAIN_SIZE,
validation_size: VALIDATION_SIZE,
test_size: TEST_SIZE,
input_dims: INPUT_DIMS,
output_dims: OUTPUT_DIMS,
},
})
}
pub fn verify_solution(
&self,
solution: &Solution,
module: Arc<CudaModule>,
stream: Arc<CudaStream>,
_prop: &cudaDeviceProp,
) -> Result<()> {
let cublas = CudaBlas::new(stream.clone())?;
let cudnn = Cudnn::new(stream.clone())?;
let mut model = MLP::new(&self.layer_dims(), self.num_frozen_layers, stream.clone())?;
load_solution(&mut model, solution, stream.clone())?;
let (output, _) = model.forward(
&self.dataset.test_inputs(),
false,
stream.clone(),
module.clone(),
&cublas,
&cudnn,
)?;
let (loss, _) = model.loss_and_grad(
&output,
&self.dataset.test_targets_noisy(),
stream.clone(),
module.clone(),
)?;
let avg_model_loss_on_test = stream.memcpy_dtov(&loss)?[0];
// Calculate baseline error epsilon_star_squared
let alpha = 4.0 - self.difficulty.accuracy_factor as f32 / 1000.0;
let y_h = stream.memcpy_dtov(&self.dataset.test_targets_noisy())?;
let f_h = stream.memcpy_dtov(&self.dataset.test_targets_true_f())?;
stream.synchronize()?;
let sum_sq_diff_true_vs_noisy: f32 = y_h
.iter()
.zip(f_h.iter())
.map(|(y, f)| (*y - *f).powi(2))
.sum();
let epsilon_star_squared =
(alpha / self.dataset.test_size as f32) * sum_sq_diff_true_vs_noisy;
if avg_model_loss_on_test <= epsilon_star_squared {
Ok(())
} else {
Err(anyhow!(
"Model test loss ({:.4e}) exceeds target baseline epsilon_star_squared ({:.4e})",
avg_model_loss_on_test,
epsilon_star_squared
))
}
}
pub fn layer_dims(&self) -> Vec<usize> {
let mut layer_dims = vec![self.hidden_layers_dims; self.difficulty.num_hidden_layers];
layer_dims.insert(0, self.dataset.input_dims);
layer_dims.push(self.dataset.output_dims);
layer_dims
}
}
pub trait CudaOptimizerState: Any + Send + Sync {
fn as_any(&self) -> &dyn Any;
fn as_any_mut(&mut self) -> &mut dyn Any;
fn box_clone(&self) -> Box<dyn CudaOptimizerState>;
}
impl Clone for Box<dyn CudaOptimizerState> {
fn clone(&self) -> Self {
self.box_clone()
}
}
/// Function type for initializing optimizer state
pub type CudaOptimizerInitStateFn = fn(
seed: [u8; 32],
param_sizes: &[usize], // Sizes of all parameter tensors
stream: Arc<CudaStream>,
module: Arc<CudaModule>,
prop: &cudaDeviceProp,
) -> Result<Box<dyn CudaOptimizerState>>;
/// Function type for querying optimizer at specific parameters (like parameter prediction)
pub type CudaOptimizerQueryAtParamsFn = fn(
optimizer_state: &dyn CudaOptimizerState,
model_params: &[CudaSlice<f32>],
gradients: Option<&[CudaSlice<f32>]>,
epoch: usize,
train_loss: Option<f32>,
val_loss: Option<f32>,
stream: Arc<CudaStream>,
module: Arc<CudaModule>,
prop: &cudaDeviceProp,
) -> Result<Option<Vec<CudaSlice<f32>>>>;
/// Function type for optimizer step (computes parameter updates)
pub type CudaOptimizerStepFn = fn(
optimizer_state: &mut dyn CudaOptimizerState,
gradients: &[CudaSlice<f32>],
epoch: usize,
train_loss: Option<f32>,
val_loss: Option<f32>,
stream: Arc<CudaStream>,
module: Arc<CudaModule>,
prop: &cudaDeviceProp,
) -> Result<Vec<CudaSlice<f32>>>;
pub fn training_loop(
challenge: &Challenge,
module: Arc<CudaModule>,
stream: Arc<CudaStream>,
prop: &cudaDeviceProp,
optimizer_init_state: CudaOptimizerInitStateFn,
optimizer_query_at_params: CudaOptimizerQueryAtParamsFn,
optimizer_step: CudaOptimizerStepFn,
) -> Result<(Solution, Vec<f32>, Vec<f32>)> {
let Challenge {
batch_size,
max_epochs,
min_loss_delta,
patience,
dataset:
Dataset {
train_size,
validation_size,
input_dims,
output_dims,
..
},
..
} = *challenge;
let cublas = CudaBlas::new(stream.clone())?;
let cudnn = Cudnn::new(stream.clone())?;
let mut model = MLP::new(
&challenge.layer_dims(),
challenge.num_frozen_layers,
stream.clone(),
)?;
model.init_weights(challenge.seed, stream.clone(), module.clone())?;
// Initialize optimizer
let param_sizes = model.get_parameter_sizes();
let mut optimizer_state = optimizer_init_state(
challenge.seed,
&param_sizes,
stream.clone(),
module.clone(),
prop,
)?;
let mut lowest_loss = f32::INFINITY;
let mut _best_epoch = 0;
let mut epochs_no_improvement = 0;
let mut best_model_solution: Option<Solution> = None;
let mut prev_train_loss = None;
let mut prev_validation_loss = None;
let mut train_losses = Vec::with_capacity(max_epochs);
let mut validation_losses = Vec::with_capacity(max_epochs);
let num_train_batches = (train_size + batch_size - 1) / batch_size;
let num_val_batches = (validation_size + batch_size - 1) / batch_size;
// Initialize RNG for shuffling
let mut rng = StdRng::from_seed(challenge.seed);
// Copy training data to host for shuffled batch creation
let train_inputs = stream.memcpy_dtov(&challenge.dataset.train_inputs())?;
let train_targets = stream.memcpy_dtov(&challenge.dataset.train_targets_noisy())?;
let validation_inputs_d = challenge.dataset.validation_inputs();
let validation_targets_d = challenge.dataset.validation_targets_noisy();
stream.synchronize()?;
for epoch in 0..max_epochs {
// --- Shuffle training data indices each epoch ---
let mut train_indices: Vec<usize> = (0..train_size).collect();
train_indices.shuffle(&mut rng);
// --- Training Phase ---
let mut epoch_train_loss_sum = 0.0;
for i in 0..num_train_batches {
let batch_start_idx = i * batch_size;
let current_batch_size = (train_size - batch_start_idx).min(batch_size);
if current_batch_size == 0 {
continue;
}
model.zero_grad(stream.clone(), module.clone())?;
// Create shuffled batch data
let mut input_batch_data = vec![0.0f32; current_batch_size * input_dims];
let mut target_batch_data = vec![0.0f32; current_batch_size * output_dims];
// Gather shuffled batch data
for batch_offset in 0..current_batch_size {
let shuffled_sample_idx = train_indices[batch_start_idx + batch_offset];
// Copy input data for this sample
let input_start = shuffled_sample_idx * input_dims;
let batch_input_start = batch_offset * input_dims;
for d in 0..input_dims {
input_batch_data[batch_input_start + d] = train_inputs[input_start + d];
}
// Copy target data for this sample
let target_start = shuffled_sample_idx * output_dims;
let batch_target_start = batch_offset * output_dims;
for d in 0..output_dims {
target_batch_data[batch_target_start + d] = train_targets[target_start + d];
}
}
// Upload shuffled batch to GPU
let mut d_input_batch = stream.alloc_zeros::<f32>(current_batch_size * input_dims)?;
let mut d_target_batch = stream.alloc_zeros::<f32>(current_batch_size * output_dims)?;
stream.memcpy_htod(&input_batch_data, &mut d_input_batch)?;
stream.memcpy_htod(&target_batch_data, &mut d_target_batch)?;
// Query optimizer for parameter modifications before forward pass
let model_params = model.extract_parameters(stream.clone())?;
let original_params = if let Some(modified_params) = optimizer_query_at_params(
optimizer_state.as_ref(),
&model_params,
None,
epoch,
prev_train_loss,
prev_validation_loss,
stream.clone(),
module.clone(),
prop,
)? {
let backup = model_params.clone();
model.set_parameters(&modified_params, stream.clone(), module.clone())?;
Some(backup)
} else {
None
};
let (output, caches) = model.forward(
&d_input_batch,
true,
stream.clone(),
module.clone(),
&cublas,
&cudnn,
)?;
let (loss, grad) = model.loss_and_grad(
&output,
&d_target_batch.as_view(),
stream.clone(),
module.clone(),
)?;
model.backward(
&grad,
&caches,
false,
stream.clone(),
module.clone(),
&cublas,
&cudnn,
)?;
// Restore original parameters if they were modified
if let Some(params_to_restore) = original_params {
model.set_parameters(&params_to_restore, stream.clone(), module.clone())?;
}
// Get gradients and apply optimizer step
let gradients = model.extract_gradients(stream.clone())?;
let param_updates = optimizer_step(
optimizer_state.as_mut(),
&gradients,
epoch,
prev_train_loss,
prev_validation_loss,
stream.clone(),
module.clone(),
prop,
)?;
model.apply_optimizer_updates(&param_updates, stream.clone(), module.clone())?;
let batch_loss = stream.memcpy_dtov(&loss)?[0];
epoch_train_loss_sum += batch_loss * current_batch_size as f32;
}
stream.synchronize()?;
let avg_train_loss = epoch_train_loss_sum / train_size as f32;
prev_train_loss = Some(avg_train_loss);
train_losses.push(avg_train_loss);
// --- Validation Phase ---
let mut epoch_val_loss_sum = 0.0;
if validation_size > 0 {
for i in 0..num_val_batches {
let batch_start = i * batch_size;
let current_batch_size = (validation_size - batch_start).min(batch_size);
if current_batch_size == 0 {
continue;
}
let d_input_batch = validation_inputs_d.slice(
batch_start * input_dims..(batch_start + current_batch_size) * input_dims,
);
let d_target_batch = validation_targets_d.slice(
batch_start * output_dims..(batch_start + current_batch_size) * output_dims,
);
let (output, _) = model.forward(
&d_input_batch,
false,
stream.clone(),
module.clone(),
&cublas,
&cudnn,
)?;
let (loss, _) = model.loss_and_grad(
&output,
&d_target_batch,
stream.clone(),
module.clone(),
)?;
let mut batch_loss_h = vec![0.0; 1];
stream.memcpy_dtoh(&loss, &mut batch_loss_h)?;
epoch_val_loss_sum += batch_loss_h[0] * current_batch_size as f32;
}
}
stream.synchronize()?;
let avg_val_loss = if validation_size > 0 {
epoch_val_loss_sum / validation_size as f32
} else {
avg_train_loss
};
prev_validation_loss = Some(avg_val_loss);
validation_losses.push(avg_val_loss);
// --- Early Stopping ---
if avg_val_loss < lowest_loss - min_loss_delta {
lowest_loss = avg_val_loss;
_best_epoch = epoch;
best_model_solution = Some(to_solution(&model, epoch + 1, stream.clone())?);
epochs_no_improvement = 0;
} else {
epochs_no_improvement += 1;
if epochs_no_improvement >= patience {
break;
}
}
}
stream.synchronize()?;
let solution = best_model_solution.ok_or_else(|| anyhow!("No valid solution found during training. Validation loss may have been NaN or never improved."))?;
Ok((solution, train_losses, validation_losses))
}
pub fn load_solution(mlp: &mut MLP, solution: &Solution, stream: Arc<CudaStream>) -> Result<()> {
for (i, layer) in mlp.lin.iter_mut().enumerate() {
let w_flat: Vec<f32> = solution.weights[i].iter().flatten().cloned().collect();
stream.memcpy_htod(&w_flat, &mut layer.weight)?;
stream.memcpy_htod(&solution.biases[i], &mut layer.bias)?;
}
for (i, bn) in mlp.bns.iter_mut().enumerate() {
stream.memcpy_htod(&solution.bn_weights[i], &mut bn.weight)?;
stream.memcpy_htod(&solution.bn_biases[i], &mut bn.bias)?;
stream.memcpy_htod(&solution.bn_running_means[i], &mut bn.running_mean)?;
stream.memcpy_htod(&solution.bn_running_vars[i], &mut bn.running_var)?;
}
stream.synchronize()?;
Ok(())
}
pub fn to_solution(mlp: &MLP, epochs_used: usize, stream: Arc<CudaStream>) -> Result<Solution> {
stream.synchronize()?;
let mut weights = Vec::new();
let mut biases = Vec::new();
for layer in &mlp.lin {
let mut w_h = vec![0.0; layer.weight.len()];
stream.memcpy_dtoh(&layer.weight, &mut w_h)?;
let mut b_h = vec![0.0; layer.bias.len()];
stream.memcpy_dtoh(&layer.bias, &mut b_h)?;
weights.push(w_h.chunks(layer.in_features).map(|c| c.to_vec()).collect());
biases.push(b_h);
}
let mut bn_weights = Vec::new();
let mut bn_biases = Vec::new();
let mut bn_running_means = Vec::new();
let mut bn_running_vars = Vec::new();
for bn in &mlp.bns {
bn_weights.push(stream.memcpy_dtov(&bn.weight)?);
bn_biases.push(stream.memcpy_dtov(&bn.bias)?);
bn_running_means.push(stream.memcpy_dtov(&bn.running_mean)?);
bn_running_vars.push(stream.memcpy_dtov(&bn.running_var)?);
}
Ok(Solution {
weights,
biases,
epochs_used,
bn_weights,
bn_biases,
bn_running_means,
bn_running_vars,
})
}
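For reference, verify_solution above accepts a trained model when its mean squared error on the noisy test targets does not exceed a baseline derived from the noise injected into those targets:

\text{MSE}_{\text{test}} = \frac{1}{N_{\text{test}} D}\sum_{i=1}^{N_{\text{test}}}\sum_{d=1}^{D} (\hat{y}_{i,d} - y_{i,d})^2 \;\le\; \varepsilon_\star^2 = \frac{\alpha}{N_{\text{test}}}\sum_{i=1}^{N_{\text{test}}}\sum_{d=1}^{D} (y_{i,d} - f_{i,d})^2, \qquad \alpha = 4 - \frac{\texttt{accuracy\_factor}}{1000},

where D = output_dims, y are the noisy targets, f the true function values, and MSE_test is exactly what loss_mse returns for the single test-set forward pass (a mean over all N_test * D elements). Raising accuracy_factor lowers α and therefore tightens the acceptance threshold.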

View File

@@ -32,3 +32,5 @@ c004 = ["cuda", "tig-challenges/c004"]
vector_search = ["c004"]
c005 = ["cuda", "tig-challenges/c005"]
hypergraph = ["c005"]
c006 = ["cuda", "tig-challenges/c006"]
nn_training = ["c006"]

View File

@@ -296,6 +296,12 @@ pub fn compute_solution(
#[cfg(feature = "c005")]
dispatch_challenge!(c005, gpu)
}
"c006" => {
#[cfg(not(feature = "c006"))]
panic!("tig-runtime was not compiled with '--features c006'");
#[cfg(feature = "c006")]
dispatch_challenge!(c006, gpu)
}
_ => panic!("Unsupported challenge"),
}
};

View File

@@ -30,3 +30,5 @@ c004 = ["cuda", "tig-challenges/c004"]
vector_search = ["c004"]
c005 = ["cuda", "tig-challenges/c005"]
hypergraph = ["c005"]
c006 = ["cuda", "tig-challenges/c006"]
nn_training = ["c006"]

View File

@@ -187,6 +187,12 @@ pub fn verify_solution(
#[cfg(feature = "c005")]
dispatch_challenge!(c005, gpu)
}
"c006" => {
#[cfg(not(feature = "c006"))]
panic!("tig-verifier was not compiled with '--features c006'");
#[cfg(feature = "c006")]
dispatch_challenge!(c006, gpu)
}
_ => panic!("Unsupported challenge"),
}