From 49106e998fde8f6cb9d205bc92bf7fcd14c82265 Mon Sep 17 00:00:00 2001 From: FiveMovesAhead Date: Fri, 12 Sep 2025 09:28:05 +0100 Subject: [PATCH] Optimizer solution size. --- Cargo.lock | 14 +- tig-binary/src/framework.cu | 2 + tig-challenges/Cargo.toml | 6 +- tig-challenges/src/neuralnet_optimizer.rs | 173 +++++++++++++++++----- tig-runtime/src/main.rs | 17 ++- 5 files changed, 168 insertions(+), 44 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2c10ab45..edc36c92 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -154,6 +154,15 @@ version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51" +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -352,7 +361,7 @@ dependencies = [ [[package]] name = "cudarc" version = "0.16.4" -source = "git+https://github.com/tig-foundation/cudarc.git?branch=runtime-fuel%2Fcudnn-cublas#76a6231512aabd410377abbbf89bc8deefc48e54" +source = "git+https://github.com/tig-foundation/cudarc.git?branch=runtime-fuel%2Fcudnn-cublas#b3fccf5003c6e356bdd36e6808bf6e66f08f98d2" dependencies = [ "libloading", ] @@ -2047,7 +2056,10 @@ name = "tig-challenges" version = "0.1.0" dependencies = [ "anyhow", + "base64 0.21.7", + "bincode", "cudarc", + "flate2", "ndarray", "rand", "serde", diff --git a/tig-binary/src/framework.cu b/tig-binary/src/framework.cu index c2e2a13f..22f63f5c 100644 --- a/tig-binary/src/framework.cu +++ b/tig-binary/src/framework.cu @@ -29,6 +29,8 @@ extern "C" __global__ void finalize_kernel( u_int64_t *errorstat_ptr // RETURNED: (64-bit) Error status ) { + gbl_FUELUSAGE += fuelusage_ptr[0]; + gbl_SIGNATURE ^= signature_ptr[0]; fuelusage_ptr[0] = gbl_FUELUSAGE; // RETURNED: (64-bit) Fuel usage signature_ptr[0] = gbl_SIGNATURE; // RETURNED: (64-bit) Run-time signature errorstat_ptr[0] = gbl_ERRORSTAT; // RETURNED: (64-bit) Error status -- set to non-zero if fuel runs out diff --git a/tig-challenges/Cargo.toml b/tig-challenges/Cargo.toml index dcc568d8..64511a69 100644 --- a/tig-challenges/Cargo.toml +++ b/tig-challenges/Cargo.toml @@ -20,8 +20,12 @@ rand = { version = "0.8.5", default-features = false, features = [ serde = { version = "1.0.196", features = ["derive"] } serde_json = { version = "1.0.113" } statrs = { version = "0.18.0" } +bincode = { version = "1.3", optional = true } +flate2 = { version = "1.0", optional = true } +base64 = { version = "0.21", optional = true } [features] +bincode_solution = ["bincode", "flate2", "base64"] c001 = [] satisfiability = ["c001"] c002 = [] @@ -32,5 +36,5 @@ c004 = ["cudarc"] vector_search = ["c004"] c005 = ["cudarc"] hypergraph = ["c005"] -c006 = ["cudarc", "cudarc/cublas", "cudarc/cudnn"] +c006 = ["cudarc", "cudarc/cublas", "cudarc/cudnn", "bincode_solution"] neuralnet_optimizer = ["c006"] diff --git a/tig-challenges/src/neuralnet_optimizer.rs b/tig-challenges/src/neuralnet_optimizer.rs index 5eacf2fd..68a33750 100644 --- a/tig-challenges/src/neuralnet_optimizer.rs +++ b/tig-challenges/src/neuralnet_optimizer.rs @@ -1,14 +1,24 @@ use anyhow::{anyhow, Result}; +use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; use cudarc::{ cublas::CudaBlas, cudnn::Cudnn, driver::{CudaModule, CudaSlice, CudaStream, CudaView, LaunchConfig, PushKernelArg}, runtime::sys::cudaDeviceProp, }; +use flate2::{read::GzDecoder, write::GzEncoder, Compression}; use rand::{prelude::*, rngs::StdRng}; -use serde::{Deserialize, Serialize}; +use serde::{ + de::{self, Visitor}, + Deserialize, Deserializer, Serialize, Serializer, +}; use serde_json::{from_value, Map, Value}; -use std::{any::Any, sync::Arc}; +use std::{ + any::Any, + fmt, + io::{Read, Write}, + sync::Arc, +}; use crate::neuralnet::MLP; @@ -34,22 +44,126 @@ impl Into> for Difficulty { } } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Debug, Clone)] pub struct Solution { - pub weights: Vec>>, - pub biases: Vec>, + pub weights: Vec>>, + pub biases: Vec>, pub epochs_used: usize, - pub bn_weights: Vec>, - pub bn_biases: Vec>, - pub bn_running_means: Vec>, - pub bn_running_vars: Vec>, + pub bn_weights: Vec>, + pub bn_biases: Vec>, + pub bn_running_means: Vec>, + pub bn_running_vars: Vec>, +} + +// Helper struct for (de)serialization +#[derive(Serialize, Deserialize)] +struct SolutionData { + weights: Vec>>, + biases: Vec>, + epochs_used: usize, + bn_weights: Vec>, + bn_biases: Vec>, + bn_running_means: Vec>, + bn_running_vars: Vec>, +} + +impl Serialize for Solution { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + // Serialize with bincode + let bincode_data = bincode::serialize(&SolutionData { + weights: self.weights.clone(), + biases: self.biases.clone(), + epochs_used: self.epochs_used, + bn_weights: self.bn_weights.clone(), + bn_biases: self.bn_biases.clone(), + bn_running_means: self.bn_running_means.clone(), + bn_running_vars: self.bn_running_vars.clone(), + }) + .map_err(|e| serde::ser::Error::custom(format!("Bincode serialization failed: {}", e)))?; + + // Compress with gzip + let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); + encoder + .write_all(&bincode_data) + .map_err(|e| serde::ser::Error::custom(format!("Compression failed: {}", e)))?; + let compressed_data = encoder + .finish() + .map_err(|e| serde::ser::Error::custom(format!("Compression finish failed: {}", e)))?; + + // Encode as base64 + let base64_string = BASE64.encode(&compressed_data); + + // Serialize the base64 string + serializer.serialize_str(&base64_string) + } +} + +impl<'de> Deserialize<'de> for Solution { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct SolutionVisitor; + + impl<'de> Visitor<'de> for SolutionVisitor { + type Value = Solution; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a base64 encoded, compressed, bincode serialized Solution") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + // Decode from base64 + let compressed_data = BASE64 + .decode(value) + .map_err(|e| E::custom(format!("Base64 decode failed: {}", e)))?; + + // Decompress + let mut decoder = GzDecoder::new(&compressed_data[..]); + let mut bincode_data = Vec::new(); + decoder + .read_to_end(&mut bincode_data) + .map_err(|e| E::custom(format!("Decompression failed: {}", e)))?; + + // Deserialize with bincode + let solution: SolutionData = bincode::deserialize(&bincode_data) + .map_err(|e| E::custom(format!("Bincode deserialization failed: {}", e)))?; + + Ok(Solution { + weights: solution.weights, + biases: solution.biases, + epochs_used: solution.epochs_used, + bn_weights: solution.bn_weights, + bn_biases: solution.bn_biases, + bn_running_means: solution.bn_running_means, + bn_running_vars: solution.bn_running_vars, + }) + } + } + + deserializer.deserialize_str(SolutionVisitor) + } } impl TryFrom> for Solution { type Error = serde_json::Error; fn try_from(v: Map) -> Result { - from_value(Value::Object(v)) + let base64_value = v + .get("base64") + .ok_or_else(|| de::Error::custom("Missing 'base64' field"))?; + + let base64_string = base64_value + .as_str() + .ok_or_else(|| de::Error::custom("'base64' field must be a string"))?; + + from_value(Value::String(base64_string.to_string())) } } @@ -595,45 +709,30 @@ pub fn training_loop( Ok((solution, train_losses, validation_losses)) } -fn vec_u32_to_f32(vec: &Vec) -> Vec { - vec.iter().map(|u| f32::from_bits(*u)).collect() -} - pub fn load_solution(mlp: &mut MLP, solution: &Solution, stream: Arc) -> Result<()> { for (i, layer) in mlp.lin.iter_mut().enumerate() { - let w_flat: Vec = - vec_u32_to_f32(&solution.weights[i].iter().flatten().cloned().collect()); + let w_flat: Vec = solution.weights[i].iter().flatten().cloned().collect(); stream.memcpy_htod(&w_flat, &mut layer.weight)?; - stream.memcpy_htod(&vec_u32_to_f32(&solution.biases[i]), &mut layer.bias)?; + stream.memcpy_htod(&solution.biases[i], &mut layer.bias)?; } for (i, bn) in mlp.bns.iter_mut().enumerate() { - stream.memcpy_htod(&vec_u32_to_f32(&solution.bn_weights[i]), &mut bn.weight)?; - stream.memcpy_htod(&vec_u32_to_f32(&solution.bn_biases[i]), &mut bn.bias)?; - stream.memcpy_htod( - &vec_u32_to_f32(&solution.bn_running_means[i]), - &mut bn.running_mean, - )?; - stream.memcpy_htod( - &vec_u32_to_f32(&solution.bn_running_vars[i]), - &mut bn.running_var, - )?; + stream.memcpy_htod(&solution.bn_weights[i], &mut bn.weight)?; + stream.memcpy_htod(&solution.bn_biases[i], &mut bn.bias)?; + stream.memcpy_htod(&solution.bn_running_means[i], &mut bn.running_mean)?; + stream.memcpy_htod(&solution.bn_running_vars[i], &mut bn.running_var)?; } stream.synchronize()?; Ok(()) } -fn vec_f32_to_u32(vec: Vec) -> Vec { - vec.into_iter().map(|f| f.to_bits()).collect() -} - pub fn to_solution(mlp: &MLP, epochs_used: usize, stream: Arc) -> Result { stream.synchronize()?; let mut weights = Vec::new(); let mut biases = Vec::new(); for layer in &mlp.lin { - let w = vec_f32_to_u32(stream.memcpy_dtov(&layer.weight)?); - let b = vec_f32_to_u32(stream.memcpy_dtov(&layer.bias)?); + let w = stream.memcpy_dtov(&layer.weight)?; + let b = stream.memcpy_dtov(&layer.bias)?; weights.push(w.chunks(layer.in_features).map(|c| c.to_vec()).collect()); biases.push(b); @@ -645,10 +744,10 @@ pub fn to_solution(mlp: &MLP, epochs_used: usize, stream: Arc) -> Re let mut bn_running_vars = Vec::new(); for bn in &mlp.bns { - bn_weights.push(vec_f32_to_u32(stream.memcpy_dtov(&bn.weight)?)); - bn_biases.push(vec_f32_to_u32(stream.memcpy_dtov(&bn.bias)?)); - bn_running_means.push(vec_f32_to_u32(stream.memcpy_dtov(&bn.running_mean)?)); - bn_running_vars.push(vec_f32_to_u32(stream.memcpy_dtov(&bn.running_var)?)); + bn_weights.push(stream.memcpy_dtov(&bn.weight)?); + bn_biases.push(stream.memcpy_dtov(&bn.bias)?); + bn_running_means.push(stream.memcpy_dtov(&bn.running_mean)?); + bn_running_vars.push(stream.memcpy_dtov(&bn.running_var)?); } Ok(Solution { diff --git a/tig-runtime/src/main.rs b/tig-runtime/src/main.rs index 4292accd..5ab4c6a8 100644 --- a/tig-runtime/src/main.rs +++ b/tig-runtime/src/main.rs @@ -250,11 +250,18 @@ pub fn compute_solution( Some(s) => { match challenge.verify_solution(&s, module.clone(), stream.clone(), &prop) { Ok(_) => ( - serde_json::to_value(&s) - .unwrap() - .as_object() - .unwrap() - .to_owned(), + match serde_json::to_value(&s).unwrap() { + serde_json::Value::String(s) => { + let mut map = serde_json::Map::new(); + map.insert( + "base64".to_string(), + serde_json::Value::String(s), + ); + map + } + serde_json::Value::Object(map) => map, + _ => panic!("Expected String or Object from to_value"), + }, None, ), Err(e) => (Solution::new(), Some(e.to_string())),