mirror of
https://github.com/tig-foundation/tig-monorepo.git
synced 2026-02-21 10:27:49 +08:00
Optimizer solution size.
Some checks failed
Test Workspace / Test Workspace (push) Has been cancelled
Some checks failed
Test Workspace / Test Workspace (push) Has been cancelled
This commit is contained in:
parent
49176788c3
commit
49106e998f
14
Cargo.lock
generated
14
Cargo.lock
generated
@ -154,6 +154,15 @@ version = "0.22.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51"
|
||||
|
||||
[[package]]
|
||||
name = "bincode"
|
||||
version = "1.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "1.3.2"
|
||||
@ -352,7 +361,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cudarc"
|
||||
version = "0.16.4"
|
||||
source = "git+https://github.com/tig-foundation/cudarc.git?branch=runtime-fuel%2Fcudnn-cublas#76a6231512aabd410377abbbf89bc8deefc48e54"
|
||||
source = "git+https://github.com/tig-foundation/cudarc.git?branch=runtime-fuel%2Fcudnn-cublas#b3fccf5003c6e356bdd36e6808bf6e66f08f98d2"
|
||||
dependencies = [
|
||||
"libloading",
|
||||
]
|
||||
@ -2047,7 +2056,10 @@ name = "tig-challenges"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"base64 0.21.7",
|
||||
"bincode",
|
||||
"cudarc",
|
||||
"flate2",
|
||||
"ndarray",
|
||||
"rand",
|
||||
"serde",
|
||||
|
||||
@ -29,6 +29,8 @@ extern "C" __global__ void finalize_kernel(
|
||||
u_int64_t *errorstat_ptr // RETURNED: (64-bit) Error status
|
||||
)
|
||||
{
|
||||
gbl_FUELUSAGE += fuelusage_ptr[0];
|
||||
gbl_SIGNATURE ^= signature_ptr[0];
|
||||
fuelusage_ptr[0] = gbl_FUELUSAGE; // RETURNED: (64-bit) Fuel usage
|
||||
signature_ptr[0] = gbl_SIGNATURE; // RETURNED: (64-bit) Run-time signature
|
||||
errorstat_ptr[0] = gbl_ERRORSTAT; // RETURNED: (64-bit) Error status -- set to non-zero if fuel runs out
|
||||
|
||||
@ -20,8 +20,12 @@ rand = { version = "0.8.5", default-features = false, features = [
|
||||
serde = { version = "1.0.196", features = ["derive"] }
|
||||
serde_json = { version = "1.0.113" }
|
||||
statrs = { version = "0.18.0" }
|
||||
bincode = { version = "1.3", optional = true }
|
||||
flate2 = { version = "1.0", optional = true }
|
||||
base64 = { version = "0.21", optional = true }
|
||||
|
||||
[features]
|
||||
bincode_solution = ["bincode", "flate2", "base64"]
|
||||
c001 = []
|
||||
satisfiability = ["c001"]
|
||||
c002 = []
|
||||
@ -32,5 +36,5 @@ c004 = ["cudarc"]
|
||||
vector_search = ["c004"]
|
||||
c005 = ["cudarc"]
|
||||
hypergraph = ["c005"]
|
||||
c006 = ["cudarc", "cudarc/cublas", "cudarc/cudnn"]
|
||||
c006 = ["cudarc", "cudarc/cublas", "cudarc/cudnn", "bincode_solution"]
|
||||
neuralnet_optimizer = ["c006"]
|
||||
|
||||
@ -1,14 +1,24 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
|
||||
use cudarc::{
|
||||
cublas::CudaBlas,
|
||||
cudnn::Cudnn,
|
||||
driver::{CudaModule, CudaSlice, CudaStream, CudaView, LaunchConfig, PushKernelArg},
|
||||
runtime::sys::cudaDeviceProp,
|
||||
};
|
||||
use flate2::{read::GzDecoder, write::GzEncoder, Compression};
|
||||
use rand::{prelude::*, rngs::StdRng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde::{
|
||||
de::{self, Visitor},
|
||||
Deserialize, Deserializer, Serialize, Serializer,
|
||||
};
|
||||
use serde_json::{from_value, Map, Value};
|
||||
use std::{any::Any, sync::Arc};
|
||||
use std::{
|
||||
any::Any,
|
||||
fmt,
|
||||
io::{Read, Write},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use crate::neuralnet::MLP;
|
||||
|
||||
@ -34,22 +44,126 @@ impl Into<Vec<i32>> for Difficulty {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Solution {
|
||||
pub weights: Vec<Vec<Vec<u32>>>,
|
||||
pub biases: Vec<Vec<u32>>,
|
||||
pub weights: Vec<Vec<Vec<f32>>>,
|
||||
pub biases: Vec<Vec<f32>>,
|
||||
pub epochs_used: usize,
|
||||
pub bn_weights: Vec<Vec<u32>>,
|
||||
pub bn_biases: Vec<Vec<u32>>,
|
||||
pub bn_running_means: Vec<Vec<u32>>,
|
||||
pub bn_running_vars: Vec<Vec<u32>>,
|
||||
pub bn_weights: Vec<Vec<f32>>,
|
||||
pub bn_biases: Vec<Vec<f32>>,
|
||||
pub bn_running_means: Vec<Vec<f32>>,
|
||||
pub bn_running_vars: Vec<Vec<f32>>,
|
||||
}
|
||||
|
||||
// Helper struct for (de)serialization
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct SolutionData {
|
||||
weights: Vec<Vec<Vec<f32>>>,
|
||||
biases: Vec<Vec<f32>>,
|
||||
epochs_used: usize,
|
||||
bn_weights: Vec<Vec<f32>>,
|
||||
bn_biases: Vec<Vec<f32>>,
|
||||
bn_running_means: Vec<Vec<f32>>,
|
||||
bn_running_vars: Vec<Vec<f32>>,
|
||||
}
|
||||
|
||||
impl Serialize for Solution {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
// Serialize with bincode
|
||||
let bincode_data = bincode::serialize(&SolutionData {
|
||||
weights: self.weights.clone(),
|
||||
biases: self.biases.clone(),
|
||||
epochs_used: self.epochs_used,
|
||||
bn_weights: self.bn_weights.clone(),
|
||||
bn_biases: self.bn_biases.clone(),
|
||||
bn_running_means: self.bn_running_means.clone(),
|
||||
bn_running_vars: self.bn_running_vars.clone(),
|
||||
})
|
||||
.map_err(|e| serde::ser::Error::custom(format!("Bincode serialization failed: {}", e)))?;
|
||||
|
||||
// Compress with gzip
|
||||
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
|
||||
encoder
|
||||
.write_all(&bincode_data)
|
||||
.map_err(|e| serde::ser::Error::custom(format!("Compression failed: {}", e)))?;
|
||||
let compressed_data = encoder
|
||||
.finish()
|
||||
.map_err(|e| serde::ser::Error::custom(format!("Compression finish failed: {}", e)))?;
|
||||
|
||||
// Encode as base64
|
||||
let base64_string = BASE64.encode(&compressed_data);
|
||||
|
||||
// Serialize the base64 string
|
||||
serializer.serialize_str(&base64_string)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Solution {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Solution, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
struct SolutionVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for SolutionVisitor {
|
||||
type Value = Solution;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("a base64 encoded, compressed, bincode serialized Solution")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, value: &str) -> Result<Solution, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
// Decode from base64
|
||||
let compressed_data = BASE64
|
||||
.decode(value)
|
||||
.map_err(|e| E::custom(format!("Base64 decode failed: {}", e)))?;
|
||||
|
||||
// Decompress
|
||||
let mut decoder = GzDecoder::new(&compressed_data[..]);
|
||||
let mut bincode_data = Vec::new();
|
||||
decoder
|
||||
.read_to_end(&mut bincode_data)
|
||||
.map_err(|e| E::custom(format!("Decompression failed: {}", e)))?;
|
||||
|
||||
// Deserialize with bincode
|
||||
let solution: SolutionData = bincode::deserialize(&bincode_data)
|
||||
.map_err(|e| E::custom(format!("Bincode deserialization failed: {}", e)))?;
|
||||
|
||||
Ok(Solution {
|
||||
weights: solution.weights,
|
||||
biases: solution.biases,
|
||||
epochs_used: solution.epochs_used,
|
||||
bn_weights: solution.bn_weights,
|
||||
bn_biases: solution.bn_biases,
|
||||
bn_running_means: solution.bn_running_means,
|
||||
bn_running_vars: solution.bn_running_vars,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_str(SolutionVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Map<String, Value>> for Solution {
|
||||
type Error = serde_json::Error;
|
||||
|
||||
fn try_from(v: Map<String, Value>) -> Result<Self, Self::Error> {
|
||||
from_value(Value::Object(v))
|
||||
let base64_value = v
|
||||
.get("base64")
|
||||
.ok_or_else(|| de::Error::custom("Missing 'base64' field"))?;
|
||||
|
||||
let base64_string = base64_value
|
||||
.as_str()
|
||||
.ok_or_else(|| de::Error::custom("'base64' field must be a string"))?;
|
||||
|
||||
from_value(Value::String(base64_string.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
@ -595,45 +709,30 @@ pub fn training_loop(
|
||||
Ok((solution, train_losses, validation_losses))
|
||||
}
|
||||
|
||||
fn vec_u32_to_f32(vec: &Vec<u32>) -> Vec<f32> {
|
||||
vec.iter().map(|u| f32::from_bits(*u)).collect()
|
||||
}
|
||||
|
||||
pub fn load_solution(mlp: &mut MLP, solution: &Solution, stream: Arc<CudaStream>) -> Result<()> {
|
||||
for (i, layer) in mlp.lin.iter_mut().enumerate() {
|
||||
let w_flat: Vec<f32> =
|
||||
vec_u32_to_f32(&solution.weights[i].iter().flatten().cloned().collect());
|
||||
let w_flat: Vec<f32> = solution.weights[i].iter().flatten().cloned().collect();
|
||||
stream.memcpy_htod(&w_flat, &mut layer.weight)?;
|
||||
|
||||
stream.memcpy_htod(&vec_u32_to_f32(&solution.biases[i]), &mut layer.bias)?;
|
||||
stream.memcpy_htod(&solution.biases[i], &mut layer.bias)?;
|
||||
}
|
||||
for (i, bn) in mlp.bns.iter_mut().enumerate() {
|
||||
stream.memcpy_htod(&vec_u32_to_f32(&solution.bn_weights[i]), &mut bn.weight)?;
|
||||
stream.memcpy_htod(&vec_u32_to_f32(&solution.bn_biases[i]), &mut bn.bias)?;
|
||||
stream.memcpy_htod(
|
||||
&vec_u32_to_f32(&solution.bn_running_means[i]),
|
||||
&mut bn.running_mean,
|
||||
)?;
|
||||
stream.memcpy_htod(
|
||||
&vec_u32_to_f32(&solution.bn_running_vars[i]),
|
||||
&mut bn.running_var,
|
||||
)?;
|
||||
stream.memcpy_htod(&solution.bn_weights[i], &mut bn.weight)?;
|
||||
stream.memcpy_htod(&solution.bn_biases[i], &mut bn.bias)?;
|
||||
stream.memcpy_htod(&solution.bn_running_means[i], &mut bn.running_mean)?;
|
||||
stream.memcpy_htod(&solution.bn_running_vars[i], &mut bn.running_var)?;
|
||||
}
|
||||
stream.synchronize()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn vec_f32_to_u32(vec: Vec<f32>) -> Vec<u32> {
|
||||
vec.into_iter().map(|f| f.to_bits()).collect()
|
||||
}
|
||||
|
||||
pub fn to_solution(mlp: &MLP, epochs_used: usize, stream: Arc<CudaStream>) -> Result<Solution> {
|
||||
stream.synchronize()?;
|
||||
let mut weights = Vec::new();
|
||||
let mut biases = Vec::new();
|
||||
for layer in &mlp.lin {
|
||||
let w = vec_f32_to_u32(stream.memcpy_dtov(&layer.weight)?);
|
||||
let b = vec_f32_to_u32(stream.memcpy_dtov(&layer.bias)?);
|
||||
let w = stream.memcpy_dtov(&layer.weight)?;
|
||||
let b = stream.memcpy_dtov(&layer.bias)?;
|
||||
|
||||
weights.push(w.chunks(layer.in_features).map(|c| c.to_vec()).collect());
|
||||
biases.push(b);
|
||||
@ -645,10 +744,10 @@ pub fn to_solution(mlp: &MLP, epochs_used: usize, stream: Arc<CudaStream>) -> Re
|
||||
let mut bn_running_vars = Vec::new();
|
||||
|
||||
for bn in &mlp.bns {
|
||||
bn_weights.push(vec_f32_to_u32(stream.memcpy_dtov(&bn.weight)?));
|
||||
bn_biases.push(vec_f32_to_u32(stream.memcpy_dtov(&bn.bias)?));
|
||||
bn_running_means.push(vec_f32_to_u32(stream.memcpy_dtov(&bn.running_mean)?));
|
||||
bn_running_vars.push(vec_f32_to_u32(stream.memcpy_dtov(&bn.running_var)?));
|
||||
bn_weights.push(stream.memcpy_dtov(&bn.weight)?);
|
||||
bn_biases.push(stream.memcpy_dtov(&bn.bias)?);
|
||||
bn_running_means.push(stream.memcpy_dtov(&bn.running_mean)?);
|
||||
bn_running_vars.push(stream.memcpy_dtov(&bn.running_var)?);
|
||||
}
|
||||
|
||||
Ok(Solution {
|
||||
|
||||
@ -250,11 +250,18 @@ pub fn compute_solution(
|
||||
Some(s) => {
|
||||
match challenge.verify_solution(&s, module.clone(), stream.clone(), &prop) {
|
||||
Ok(_) => (
|
||||
serde_json::to_value(&s)
|
||||
.unwrap()
|
||||
.as_object()
|
||||
.unwrap()
|
||||
.to_owned(),
|
||||
match serde_json::to_value(&s).unwrap() {
|
||||
serde_json::Value::String(s) => {
|
||||
let mut map = serde_json::Map::new();
|
||||
map.insert(
|
||||
"base64".to_string(),
|
||||
serde_json::Value::String(s),
|
||||
);
|
||||
map
|
||||
}
|
||||
serde_json::Value::Object(map) => map,
|
||||
_ => panic!("Expected String or Object from to_value"),
|
||||
},
|
||||
None,
|
||||
),
|
||||
Err(e) => (Solution::new(), Some(e.to_string())),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user