Remove sub-instances, and apply bincode serialization to Solution.

This commit is contained in:
FiveMovesAhead 2025-10-06 10:40:13 +01:00
parent 2f0a76f046
commit b5280ab287
8 changed files with 352 additions and 613 deletions

View File

@ -9,9 +9,13 @@ edition.workspace = true
[dependencies]
anyhow = "1.0.81"
base64 = "0.21"
bincode = "1.3"
cudarc = { git = "https://github.com/tig-foundation/cudarc.git", branch = "runtime-fuel/cudnn-cublas", features = [
"cuda-version-from-build-system",
], optional = true }
flate2 = "1.0"
paste = "1.0.15"
ndarray = "0.15.6"
rand = { version = "0.8.5", default-features = false, features = [
"std_rng",
@ -20,13 +24,9 @@ rand = { version = "0.8.5", default-features = false, features = [
serde = { version = "1.0.196", features = ["derive"] }
serde_json = { version = "1.0.113" }
statrs = { version = "0.18.0" }
bincode = { version = "1.3", optional = true }
flate2 = { version = "1.0", optional = true }
base64 = { version = "0.21", optional = true }
[features]
hide_verification = []
bincode_solution = ["bincode", "flate2", "base64"]
c001 = []
satisfiability = ["c001"]
c002 = []
@ -37,5 +37,5 @@ c004 = ["cudarc"]
vector_search = ["c004"]
c005 = ["cudarc"]
hypergraph = ["c005"]
c006 = ["cudarc", "cudarc/cublas", "cudarc/cudnn", "bincode_solution"]
c006 = ["cudarc", "cudarc/cublas", "cudarc/cudnn"]
neuralnet_optimizer = ["c006"]

View File

@ -3,7 +3,6 @@ use cudarc::driver::*;
use cudarc::runtime::sys::cudaDeviceProp;
use rand::{rngs::StdRng, Rng, SeedableRng};
use serde::{Deserialize, Serialize};
use serde_json::{from_value, Map, Value};
use std::sync::Arc;
#[derive(Serialize, Deserialize, Debug, Clone)]
@ -29,39 +28,21 @@ impl Into<Vec<i32>> for Difficulty {
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Solution {
pub sub_solutions: Vec<SubSolution>,
impl_base64_serde! {
Solution {
partition: Vec<u32>,
}
}
impl Solution {
pub fn new() -> Self {
Self {
sub_solutions: Vec::new(),
partition: Vec::new(),
}
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SubSolution {
pub partition: Vec<u32>,
}
impl TryFrom<Map<String, Value>> for Solution {
type Error = serde_json::Error;
fn try_from(v: Map<String, Value>) -> Result<Self, Self::Error> {
from_value(Value::Object(v))
}
}
pub struct Challenge {
pub seed: [u8; 32],
pub difficulty: Difficulty,
pub sub_instances: Vec<SubInstance>,
}
pub struct SubInstance {
pub seed: [u8; 32],
pub difficulty: Difficulty,
pub num_nodes: u32,
@ -84,83 +65,9 @@ pub struct SubInstance {
baseline_connectivity_metric: u32,
}
pub const NUM_SUB_INSTANCES: usize = 4;
pub const MAX_THREADS_PER_BLOCK: u32 = 1024;
impl Challenge {
    /// Generates a challenge consisting of `NUM_SUB_INSTANCES` sub-instances.
    ///
    /// A `StdRng` seeded from `seed` is used to draw one fresh seed per sub-instance, so the
    /// whole challenge is a deterministic function of `seed` and `difficulty`. Every
    /// sub-instance is generated with the same `difficulty`, CUDA module, stream and device
    /// properties.
    ///
    /// # Errors
    /// Returns the first error produced by `SubInstance::generate_instance`.
    pub fn generate_instance(
        seed: &[u8; 32],
        difficulty: &Difficulty,
        module: Arc<CudaModule>,
        stream: Arc<CudaStream>,
        prop: &cudaDeviceProp,
    ) -> Result<Challenge> {
        let mut rng = StdRng::from_seed(*seed);
        let mut sub_instances = Vec::with_capacity(NUM_SUB_INSTANCES);
        for _ in 0..NUM_SUB_INSTANCES {
            sub_instances.push(SubInstance::generate_instance(
                &rng.gen(),
                difficulty,
                module.clone(),
                stream.clone(),
                prop,
            )?);
        }

        Ok(Challenge {
            seed: *seed,
            difficulty: difficulty.clone(),
            sub_instances,
        })
    }

    conditional_pub!(
        /// Verifies every sub-solution and checks the aggregate improvement over the baseline.
        ///
        /// For each sub-instance the ratio `connectivity_metric / baseline_connectivity_metric`
        /// is collected; the score is `1 - sqrt(mean(ratio^2))` and must be at least
        /// `difficulty.better_than_baseline / 1000`.
        ///
        /// # Errors
        /// * the number of sub-solutions differs from the number of sub-instances,
        /// * any sub-solution fails its own verification (reported as `Instance <i>: ...`),
        /// * the aggregate score is below the threshold.
        fn verify_solution(
            &self,
            solution: &Solution,
            module: Arc<CudaModule>,
            stream: Arc<CudaStream>,
            prop: &cudaDeviceProp,
        ) -> Result<()> {
            // `zip` stops at the shorter side, so without this check a solution carrying fewer
            // sub-solutions would be scored on just those (and an empty one would yield NaN).
            if solution.sub_solutions.len() != self.sub_instances.len() {
                return Err(anyhow!(
                    "Invalid number of sub-solutions. Expected: {}, Actual: {}",
                    self.sub_instances.len(),
                    solution.sub_solutions.len()
                ));
            }

            let mut better_than_baselines = Vec::with_capacity(self.sub_instances.len());
            for (i, (sub_instance, sub_solution)) in self
                .sub_instances
                .iter()
                .zip(&solution.sub_solutions)
                .enumerate()
            {
                match sub_instance.verify_solution(
                    sub_solution,
                    module.clone(),
                    stream.clone(),
                    prop,
                ) {
                    Ok(connectivity_metric) => better_than_baselines.push(
                        connectivity_metric as f64
                            / sub_instance.baseline_connectivity_metric as f64,
                    ),
                    Err(e) => return Err(anyhow!("Instance {}: {}", i, e.to_string())),
                }
            }

            // Root-mean-square of the ratios; lower connectivity is better, hence `1 - rms`.
            let average = 1.0
                - (better_than_baselines.iter().map(|x| x * x).sum::<f64>()
                    / better_than_baselines.len() as f64)
                    .sqrt();
            let threshold = self.difficulty.better_than_baseline as f64 / 1000.0;
            if average >= threshold {
                Ok(())
            } else {
                Err(anyhow!(
                    "Average better_than_baseline ({}) is less than ({})",
                    average,
                    threshold
                ))
            }
        }
    );
}
impl SubInstance {
pub fn generate_instance(
seed: &[u8; 32],
difficulty: &Difficulty,
@ -479,99 +386,123 @@ impl SubInstance {
})
}
/// Checks that `solution` is a valid partition for this instance and returns its
/// connectivity metric as computed by the `calc_connectivity_metric` CUDA kernel.
///
/// The work is done in three kernel launches, each followed by a stream synchronisation:
/// 1. `validate_partition` raises an error flag for an invalid assignment,
/// 2. `count_nodes_in_part` fills a per-part node counter that is checked on the host,
/// 3. `calc_connectivity_metric` writes the resulting metric.
///
/// # Errors
/// * the partition length differs from `num_nodes`,
/// * the validation kernel sets the error flag,
/// * any part holds fewer than 1 or more than `max_part_size` nodes,
/// * any CUDA call (function lookup, copy, allocation, launch, synchronise) fails.
pub fn calc_connectivity_metric(
    &self,
    solution: &Solution,
    module: Arc<CudaModule>,
    stream: Arc<CudaStream>,
    _prop: &cudaDeviceProp,
) -> Result<u32> {
    let assigned = solution.partition.len();
    if assigned != self.num_nodes as usize {
        return Err(anyhow!(
            "Invalid number of partitions. Expected: {}, Actual: {}",
            self.num_nodes,
            assigned
        ));
    }

    // Resolve all kernels up front, before touching device memory.
    let validate_partition_kernel = module.load_function("validate_partition")?;
    let calc_connectivity_metric_kernel = module.load_function("calc_connectivity_metric")?;
    let count_nodes_in_part_kernel = module.load_function("count_nodes_in_part")?;

    // One launch shape is shared by all three kernels; it is sized from the hyperedge count.
    let threads = MAX_THREADS_PER_BLOCK;
    let blocks = (self.difficulty.num_hyperedges + threads - 1) / threads;
    let cfg = LaunchConfig {
        grid_dim: (blocks, 1, 1),
        block_dim: (threads, 1, 1),
        shared_mem_bytes: 0,
    };

    // 1.1 Check if all nodes are assigned to a part
    let d_partition = stream.memcpy_stod(&solution.partition)?;
    let mut d_error_flag = stream.alloc_zeros::<u32>(1)?;
    // NOTE(review): the kernel signature is not visible here; argument order is assumed to match.
    unsafe {
        stream
            .launch_builder(&validate_partition_kernel)
            .arg(&self.num_nodes)
            .arg(&self.num_parts)
            .arg(&d_partition)
            .arg(&mut d_error_flag)
            .launch(cfg)?;
    }
    stream.synchronize()?;
    let error_flag = stream.memcpy_dtov(&d_error_flag)?[0];
    if error_flag != 0 {
        return Err(anyhow!(
            "Invalid partition. All nodes must be assigned to one of {} parts",
            self.num_parts
        ));
    }

    // 1.2 Check if any partition exceeds the maximum size
    let mut d_nodes_in_part = stream.alloc_zeros::<u32>(self.num_parts as usize)?;
    unsafe {
        stream
            .launch_builder(&count_nodes_in_part_kernel)
            .arg(&self.num_nodes)
            .arg(&self.num_parts)
            .arg(&d_partition)
            .arg(&mut d_nodes_in_part)
            .launch(cfg)?;
    }
    stream.synchronize()?;
    let part_sizes = stream.memcpy_dtov(&d_nodes_in_part)?;
    let every_part_in_range = part_sizes
        .iter()
        .all(|&size| size >= 1 && size <= self.max_part_size);
    if !every_part_in_range {
        return Err(anyhow!(
            "Each part must have at least 1 and at most {} nodes",
            self.max_part_size
        ));
    }

    // 1.3 Calculate connectivity
    let mut d_connectivity_metric = stream.alloc_zeros::<u32>(1)?;
    unsafe {
        stream
            .launch_builder(&calc_connectivity_metric_kernel)
            .arg(&self.difficulty.num_hyperedges)
            .arg(&self.d_hyperedge_offsets)
            .arg(&self.d_hyperedge_nodes)
            .arg(&d_partition)
            .arg(&mut d_connectivity_metric)
            .launch(cfg)?;
    }
    stream.synchronize()?;

    Ok(stream.memcpy_dtov(&d_connectivity_metric)?[0])
}
conditional_pub!(
fn verify_solution(
&self,
solution: &SubSolution,
solution: &Solution,
module: Arc<CudaModule>,
stream: Arc<CudaStream>,
_prop: &cudaDeviceProp,
) -> Result<u32> {
if solution.partition.len() != self.num_nodes as usize {
return Err(anyhow!(
"Invalid number of partitions. Expected: {}, Actual: {}",
self.num_nodes,
solution.partition.len()
));
) -> Result<()> {
let connectivity_metric =
self.calc_connectivity_metric(solution, module, stream, _prop)?;
let btb = self.difficulty.better_than_baseline as f64 / 1000.0;
let connectivity_metric_threshold =
(self.baseline_connectivity_metric as f64 * (1.0 - btb)).ceil() as u32;
if connectivity_metric > connectivity_metric_threshold {
Err(anyhow!(
"connectivity_metric {} is greater than threshold {} (baseline: {}, better_than_baseline: {}%)",
connectivity_metric,
connectivity_metric_threshold,
self.baseline_connectivity_metric,
btb * 100.0
))
} else {
Ok(())
}
// Get the kernels
let validate_partition_kernel = module.load_function("validate_partition")?;
let calc_connectivity_metric_kernel =
module.load_function("calc_connectivity_metric")?;
let count_nodes_in_part_kernel = module.load_function("count_nodes_in_part")?;
let block_size = MAX_THREADS_PER_BLOCK;
let grid_size = (self.difficulty.num_hyperedges + block_size - 1) / block_size;
let cfg = LaunchConfig {
grid_dim: (grid_size, 1, 1),
block_dim: (block_size, 1, 1),
shared_mem_bytes: 0,
};
// 1.1 Check if all nodes are assigned to a part
let d_partition = stream.memcpy_stod(&solution.partition)?;
let mut d_error_flag = stream.alloc_zeros::<u32>(1)?;
unsafe {
stream
.launch_builder(&validate_partition_kernel)
.arg(&self.num_nodes)
.arg(&self.num_parts)
.arg(&d_partition)
.arg(&mut d_error_flag)
.launch(cfg)?;
}
stream.synchronize()?;
if stream.memcpy_dtov(&d_error_flag)?[0] != 0 {
return Err(anyhow!(
"Invalid partition. All nodes must be assigned to one of {} parts",
self.num_parts
));
};
// 1.2 Check if any partition exceeds the maximum size
let mut d_nodes_in_part = stream.alloc_zeros::<u32>(self.num_parts as usize)?;
unsafe {
stream
.launch_builder(&count_nodes_in_part_kernel)
.arg(&self.num_nodes)
.arg(&self.num_parts)
.arg(&d_partition)
.arg(&mut d_nodes_in_part)
.launch(cfg.clone())?;
}
stream.synchronize()?;
let nodes_in_partition = stream.memcpy_dtov(&d_nodes_in_part)?;
if nodes_in_partition
.iter()
.any(|&x| x < 1 || x > self.max_part_size)
{
return Err(anyhow!(
"Each part must have at least 1 and at most {} nodes",
self.max_part_size
));
}
// 1.3 Calculate connectivity
let mut d_connectivity_metric = stream.alloc_zeros::<u32>(1)?;
unsafe {
stream
.launch_builder(&calc_connectivity_metric_kernel)
.arg(&self.difficulty.num_hyperedges)
.arg(&self.d_hyperedge_offsets)
.arg(&self.d_hyperedge_nodes)
.arg(&d_partition)
.arg(&mut d_connectivity_metric)
.launch(cfg.clone())?;
}
stream.synchronize()?;
let connectivity_metric = stream.memcpy_dtov(&d_connectivity_metric)?[0];
Ok(connectivity_metric)
}
);
}

View File

@ -1,10 +1,6 @@
use anyhow::{anyhow, Result};
use rand::{
rngs::{SmallRng, StdRng},
Rng, SeedableRng,
};
use rand::{rngs::SmallRng, Rng, SeedableRng};
use serde::{Deserialize, Serialize};
use serde_json::{from_value, Map, Value};
use std::collections::HashSet;
#[derive(Serialize, Deserialize, Debug, Clone)]
@ -31,41 +27,20 @@ impl Into<Vec<i32>> for Difficulty {
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Solution {
pub sub_solutions: Vec<SubSolution>,
impl_base64_serde! {
Solution {
items: Vec<usize>,
}
}
impl Solution {
pub fn new() -> Self {
Self {
sub_solutions: Vec::new(),
}
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SubSolution {
pub items: Vec<usize>,
}
impl TryFrom<Map<String, Value>> for Solution {
type Error = serde_json::Error;
fn try_from(v: Map<String, Value>) -> Result<Self, Self::Error> {
from_value(Value::Object(v))
Self { items: Vec::new() }
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Challenge {
pub seed: [u8; 32],
pub difficulty: Difficulty,
pub sub_instances: Vec<SubInstance>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SubInstance {
pub seed: [u8; 32],
pub difficulty: Difficulty,
pub weights: Vec<u32>,
@ -78,58 +53,8 @@ pub struct SubInstance {
baseline_value: u32,
}
pub const NUM_SUB_INSTANCES: usize = 16;
impl Challenge {
    /// Generates a challenge consisting of `NUM_SUB_INSTANCES` sub-instances.
    ///
    /// A `StdRng` seeded from `seed` supplies one fresh seed per sub-instance; all
    /// sub-instances share the same `difficulty`.
    ///
    /// # Errors
    /// Returns the first error produced by `SubInstance::generate_instance`.
    pub fn generate_instance(seed: &[u8; 32], difficulty: &Difficulty) -> Result<Challenge> {
        let mut rng = StdRng::from_seed(*seed);
        let mut sub_instances = Vec::with_capacity(NUM_SUB_INSTANCES);
        for _ in 0..NUM_SUB_INSTANCES {
            sub_instances.push(SubInstance::generate_instance(&rng.gen(), difficulty)?);
        }

        Ok(Challenge {
            seed: *seed,
            difficulty: difficulty.clone(),
            sub_instances,
        })
    }

    conditional_pub!(
        /// Verifies every sub-solution and checks the aggregate improvement over the baseline.
        ///
        /// For each sub-instance the ratio `total_value / baseline_value` is collected; the
        /// score is `sqrt(mean(ratio^2)) - 1` and must be at least
        /// `difficulty.better_than_baseline / 10000`.
        ///
        /// # Errors
        /// * the number of sub-solutions differs from the number of sub-instances,
        /// * any sub-solution fails its own verification (reported as `Instance <i>: ...`),
        /// * the aggregate score is below the threshold.
        fn verify_solution(&self, solution: &Solution) -> Result<()> {
            // `zip` stops at the shorter side, so without this check a solution carrying fewer
            // sub-solutions would be scored on just those (and an empty one would yield NaN).
            if solution.sub_solutions.len() != self.sub_instances.len() {
                return Err(anyhow!(
                    "Invalid number of sub-solutions. Expected: {}, Actual: {}",
                    self.sub_instances.len(),
                    solution.sub_solutions.len()
                ));
            }

            let mut better_than_baselines = Vec::with_capacity(self.sub_instances.len());
            for (i, (sub_instance, sub_solution)) in self
                .sub_instances
                .iter()
                .zip(&solution.sub_solutions)
                .enumerate()
            {
                match sub_instance.verify_solution(sub_solution) {
                    Ok(total_value) => better_than_baselines
                        .push(total_value as f64 / sub_instance.baseline_value as f64),
                    Err(e) => return Err(anyhow!("Instance {}: {}", i, e.to_string())),
                }
            }

            // Root-mean-square of the ratios; higher value is better, hence `rms - 1`.
            let average = (better_than_baselines.iter().map(|x| x * x).sum::<f64>()
                / better_than_baselines.len() as f64)
                .sqrt()
                - 1.0;
            let threshold = self.difficulty.better_than_baseline as f64 / 10000.0;
            if average >= threshold {
                Ok(())
            } else {
                Err(anyhow!(
                    "Average better_than_baseline ({}) is less than ({})",
                    average,
                    threshold
                ))
            }
        }
    );
}
impl SubInstance {
pub fn generate_instance(seed: &[u8; 32], difficulty: &Difficulty) -> Result<SubInstance> {
pub fn generate_instance(seed: &[u8; 32], difficulty: &Difficulty) -> Result<Self> {
let mut rng = SmallRng::from_seed(seed.clone());
// Set constant density for value generation
let density = 0.25;
@ -308,7 +233,7 @@ impl SubInstance {
let baseline_value = calculate_total_value(&selected_items, &values, &interaction_values);
Ok(SubInstance {
Ok(Challenge {
seed: seed.clone(),
difficulty: difficulty.clone(),
weights,
@ -319,36 +244,53 @@ impl SubInstance {
})
}
/// Validates the selected items and returns the total value of the selection.
///
/// The selection is rejected when it contains duplicates, references an item index outside
/// `self.weights`, or when the summed weight exceeds `self.max_weight`. Otherwise the value
/// is computed by the free function `calculate_total_value` from `self.values` and
/// `self.interaction_values`.
pub fn calculate_total_value(&self, solution: &Solution) -> Result<u32> {
    // A set smaller than the input list means at least one index was repeated.
    let unique_items: HashSet<usize> = solution.items.iter().copied().collect();
    if unique_items.len() != solution.items.len() {
        return Err(anyhow!("Duplicate items selected."));
    }

    // Look up every weight first so an out-of-range index is reported before any summing.
    let mut item_weights = Vec::with_capacity(unique_items.len());
    for &item in &unique_items {
        if item >= self.weights.len() {
            return Err(anyhow!("Item ({}) is out of bounds", item));
        }
        item_weights.push(self.weights[item]);
    }

    let total_weight = item_weights.iter().sum::<u32>();
    if total_weight > self.max_weight {
        return Err(anyhow!(
            "Total weight ({}) exceeded max weight ({})",
            total_weight,
            self.max_weight
        ));
    }

    let chosen: Vec<usize> = unique_items.into_iter().collect();
    Ok(calculate_total_value(
        &chosen,
        &self.values,
        &self.interaction_values,
    ))
}
conditional_pub!(
fn verify_solution(&self, solution: &SubSolution) -> Result<u32> {
let selected_items: HashSet<usize> = solution.items.iter().cloned().collect();
if selected_items.len() != solution.items.len() {
return Err(anyhow!("Duplicate items selected."));
fn verify_solution(&self, solution: &Solution) -> Result<()> {
let total_value = self.calculate_total_value(solution)?;
let btb = self.difficulty.better_than_baseline as f64 / 10000.0;
let total_value_threshold = (self.baseline_value as f64 * (1.0 + btb)).floor() as u32;
if total_value < total_value_threshold {
Err(anyhow!(
"Total value ({}) is less than threshold ({}) (baseline: {}, better_than_baseline: {}%)",
total_value,
total_value_threshold,
self.baseline_value,
btb * 100.0
))
} else {
Ok(())
}
let total_weight = selected_items
.iter()
.map(|&item| {
if item >= self.weights.len() {
return Err(anyhow!("Item ({}) is out of bounds", item));
}
Ok(self.weights[item])
})
.collect::<Result<Vec<_>, _>>()?
.iter()
.sum::<u32>();
if total_weight > self.max_weight {
return Err(anyhow!(
"Total weight ({}) exceeded max weight ({})",
total_weight,
self.max_weight
));
}
let selected_items_vec: Vec<usize> = selected_items.into_iter().collect();
let total_value =
calculate_total_value(&selected_items_vec, &self.values, &self.interaction_values);
Ok(total_value)
}
);
}

View File

@ -10,6 +10,98 @@ macro_rules! conditional_pub {
};
}
/// Declares a public struct `$name` with the given public fields and gives it a compact
/// string-based serde representation.
///
/// Wire format (both directions): the fields are encoded with `bincode`, the bytes are
/// gzip-compressed with `flate2`, and the result is written as one standard-alphabet base64
/// string. Deserialization therefore only accepts a string and reverses those three steps.
///
/// Generated items:
/// * `pub struct $name` deriving `Debug` and `Clone`,
/// * a private owned mirror `<$name>Data` used as the bincode target when decoding,
/// * a private borrowed mirror `<$name>DataRef` used when encoding, so serialization does not
///   deep-clone every field (bincode writes a `&T` exactly like a `T`, and writes struct
///   fields positionally, so the bytes are the same as for the owned mirror),
/// * `serde::Serialize` / `serde::Deserialize` impls for `$name`.
///
/// Requires the `paste`, `serde`, `bincode`, `flate2` and `base64` crates at the call site.
macro_rules! impl_base64_serde {
    ($name:ident { $( $field:ident : $ty:ty ),* $(,)? }) => {
        paste::paste! {
            #[derive(Debug, Clone)]
            pub struct $name {
                $( pub $field : $ty ),*
            }

            #[derive(serde::Serialize, serde::Deserialize)]
            struct [<$name Data>] {
                $( $field : $ty ),*
            }

            // Borrowed view of `$name` with the same field order as the owned mirror.
            // The skipped marker keeps `'a` used even when the field list is empty.
            #[derive(serde::Serialize)]
            struct [<$name DataRef>]<'a> {
                $( $field : &'a $ty, )*
                #[serde(skip_serializing)]
                _borrow: std::marker::PhantomData<&'a ()>,
            }

            impl serde::Serialize for $name {
                fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
                where
                    S: serde::Serializer,
                {
                    use flate2::{write::GzEncoder, Compression};
                    use base64::engine::general_purpose::STANDARD as BASE64;
                    use base64::Engine;
                    use std::io::Write;

                    let borrowed = [<$name DataRef>] {
                        $( $field: &self.$field, )*
                        _borrow: std::marker::PhantomData,
                    };

                    let bincode_data = bincode::serialize(&borrowed)
                        .map_err(|e| serde::ser::Error::custom(format!("Bincode serialization failed: {}", e)))?;

                    let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
                    encoder
                        .write_all(&bincode_data)
                        .map_err(|e| serde::ser::Error::custom(format!("Compression failed: {}", e)))?;
                    let compressed_data = encoder
                        .finish()
                        .map_err(|e| serde::ser::Error::custom(format!("Compression finish failed: {}", e)))?;

                    let encoded = BASE64.encode(&compressed_data);
                    serializer.serialize_str(&encoded)
                }
            }

            impl<'de> serde::Deserialize<'de> for $name {
                fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
                where
                    D: serde::Deserializer<'de>,
                {
                    use flate2::read::GzDecoder;
                    use base64::engine::general_purpose::STANDARD as BASE64;
                    use base64::Engine;
                    use std::io::Read;
                    use std::fmt;

                    struct VisitorImpl;

                    impl<'de> serde::de::Visitor<'de> for VisitorImpl {
                        type Value = $name;

                        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                            write!(f, "a base64 encoded, compressed, bincode serialized {}", stringify!($name))
                        }

                        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
                        where
                            E: serde::de::Error,
                        {
                            let compressed = BASE64.decode(v)
                                .map_err(|e| E::custom(format!("Base64 decode failed: {}", e)))?;

                            // NOTE(review): the inflated size is not bounded here. If this
                            // string can come from an untrusted party, a small payload may
                            // expand to a very large buffer; consider capping the read with
                            // `Read::take` once an upper bound for a valid solution is known.
                            let mut decoder = GzDecoder::new(&compressed[..]);
                            let mut decompressed = Vec::new();
                            decoder
                                .read_to_end(&mut decompressed)
                                .map_err(|e| E::custom(format!("Decompression failed: {}", e)))?;

                            let data: [<$name Data>] = bincode::deserialize(&decompressed)
                                .map_err(|e| E::custom(format!("Bincode deserialization failed: {}", e)))?;

                            Ok($name {
                                $( $field: data.$field ),*
                            })
                        }
                    }

                    deserializer.deserialize_str(VisitorImpl)
                }
            }
        }
    };
}
#[cfg(feature = "c001")]
pub mod satisfiability;
#[cfg(feature = "c001")]

View File

@ -47,15 +47,16 @@ impl Into<Vec<i32>> for Difficulty {
}
}
#[derive(Debug, Clone)]
pub struct Solution {
pub weights: Vec<Vec<Vec<f32>>>,
pub biases: Vec<Vec<f32>>,
pub epochs_used: usize,
pub bn_weights: Vec<Vec<f32>>,
pub bn_biases: Vec<Vec<f32>>,
pub bn_running_means: Vec<Vec<f32>>,
pub bn_running_vars: Vec<Vec<f32>>,
impl_base64_serde! {
Solution {
weights: Vec<Vec<Vec<f32>>>,
biases: Vec<Vec<f32>>,
epochs_used: usize,
bn_weights: Vec<Vec<f32>>,
bn_biases: Vec<Vec<f32>>,
bn_running_means: Vec<Vec<f32>>,
bn_running_vars: Vec<Vec<f32>>,
}
}
impl Solution {
@ -72,118 +73,6 @@ impl Solution {
}
}
// Helper struct for (de)serialization.
// Private mirror of `Solution` with the same seven fields. Unlike `Solution` it derives the
// standard serde impls, so it is the value that is actually handed to / read back from
// bincode inside the hand-written `Serialize` / `Deserialize` impls for `Solution`.
#[derive(Serialize, Deserialize)]
struct SolutionData {
weights: Vec<Vec<Vec<f32>>>,
biases: Vec<Vec<f32>>,
epochs_used: usize,
bn_weights: Vec<Vec<f32>>,
bn_biases: Vec<Vec<f32>>,
bn_running_means: Vec<Vec<f32>>,
bn_running_vars: Vec<Vec<f32>>,
}
impl Serialize for Solution {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
// Serialize with bincode
let bincode_data = bincode::serialize(&SolutionData {
weights: self.weights.clone(),
biases: self.biases.clone(),
epochs_used: self.epochs_used,
bn_weights: self.bn_weights.clone(),
bn_biases: self.bn_biases.clone(),
bn_running_means: self.bn_running_means.clone(),
bn_running_vars: self.bn_running_vars.clone(),
})
.map_err(|e| serde::ser::Error::custom(format!("Bincode serialization failed: {}", e)))?;
// Compress with gzip
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
encoder
.write_all(&bincode_data)
.map_err(|e| serde::ser::Error::custom(format!("Compression failed: {}", e)))?;
let compressed_data = encoder
.finish()
.map_err(|e| serde::ser::Error::custom(format!("Compression finish failed: {}", e)))?;
// Encode as base64
let base64_string = BASE64.encode(&compressed_data);
// Serialize the base64 string
serializer.serialize_str(&base64_string)
}
}
impl<'de> Deserialize<'de> for Solution {
fn deserialize<D>(deserializer: D) -> Result<Solution, D::Error>
where
D: Deserializer<'de>,
{
struct SolutionVisitor;
impl<'de> Visitor<'de> for SolutionVisitor {
type Value = Solution;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a base64 encoded, compressed, bincode serialized Solution")
}
fn visit_str<E>(self, value: &str) -> Result<Solution, E>
where
E: de::Error,
{
// Decode from base64
let compressed_data = BASE64
.decode(value)
.map_err(|e| E::custom(format!("Base64 decode failed: {}", e)))?;
// Decompress
let mut decoder = GzDecoder::new(&compressed_data[..]);
let mut bincode_data = Vec::new();
decoder
.read_to_end(&mut bincode_data)
.map_err(|e| E::custom(format!("Decompression failed: {}", e)))?;
// Deserialize with bincode
let solution: SolutionData = bincode::deserialize(&bincode_data)
.map_err(|e| E::custom(format!("Bincode deserialization failed: {}", e)))?;
Ok(Solution {
weights: solution.weights,
biases: solution.biases,
epochs_used: solution.epochs_used,
bn_weights: solution.bn_weights,
bn_biases: solution.bn_biases,
bn_running_means: solution.bn_running_means,
bn_running_vars: solution.bn_running_vars,
})
}
}
deserializer.deserialize_str(SolutionVisitor)
}
}
impl TryFrom<Map<String, Value>> for Solution {
type Error = serde_json::Error;
fn try_from(v: Map<String, Value>) -> Result<Self, Self::Error> {
let base64_value = v
.get("base64")
.ok_or_else(|| de::Error::custom("Missing 'base64' field"))?;
let base64_string = base64_value
.as_str()
.ok_or_else(|| de::Error::custom("'base64' field must be a string"))?;
from_value(Value::String(base64_string.to_string()))
}
}
pub struct Dataset {
pub inputs: CudaSlice<f32>,
pub targets_noisy: CudaSlice<f32>,

View File

@ -5,12 +5,7 @@ use rand::{
rngs::{SmallRng, StdRng},
Rng, SeedableRng,
};
use serde::{
de::{self, SeqAccess, Visitor},
ser::SerializeSeq,
Deserialize, Deserializer, Serialize, Serializer,
};
use serde_json::{from_value, Map, Value};
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Debug, Copy, Clone)]
pub struct Difficulty {
@ -36,10 +31,10 @@ impl Into<Vec<i32>> for Difficulty {
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Solution {
#[serde(with = "bool_vec_as_u8")]
pub variables: Vec<bool>,
impl_base64_serde! {
Solution {
variables: Vec<bool>,
}
}
impl Solution {
@ -50,14 +45,6 @@ impl Solution {
}
}
impl TryFrom<Map<String, Value>> for Solution {
type Error = serde_json::Error;
fn try_from(v: Map<String, Value>) -> Result<Self, Self::Error> {
from_value(Value::Object(v))
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Challenge {
pub seed: [u8; 32],
@ -129,52 +116,3 @@ impl Challenge {
}
);
}
/// serde `with`-module that stores a `Vec<bool>` as a sequence of integers `0` / `1`.
///
/// On input both integers `0`/`1` and JSON booleans are accepted for each element.
mod bool_vec_as_u8 {
    use super::*;
    use std::fmt;

    /// Writes every flag as the integer `1` (true) or `0` (false).
    pub fn serialize<S>(data: &Vec<bool>, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut seq = serializer.serialize_seq(Some(data.len()))?;
        data.iter()
            .try_for_each(|&flag| seq.serialize_element(&i32::from(flag)))?;
        seq.end()
    }

    /// Sequence visitor that maps each element to a `bool`.
    struct BoolVecVisitor;

    impl<'de> Visitor<'de> for BoolVecVisitor {
        type Value = Vec<bool>;

        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            formatter.write_str("a sequence of booleans or integers 0/1")
        }

        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
        where
            A: SeqAccess<'de>,
        {
            let mut flags = Vec::new();
            loop {
                let element = match seq.next_element::<serde_json::Value>()? {
                    Some(element) => element,
                    None => break,
                };
                let flag = match element {
                    serde_json::Value::Bool(b) => b,
                    serde_json::Value::Number(n) => match n.as_u64() {
                        Some(1) => true,
                        Some(0) => false,
                        _ => return Err(de::Error::custom("expected 0, 1, true, or false")),
                    },
                    _ => return Err(de::Error::custom("expected 0, 1, true, or false")),
                };
                flags.push(flag);
            }
            Ok(flags)
        }
    }

    /// Reads a sequence of `0`/`1`/`true`/`false` elements into a `Vec<bool>`.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<bool>, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserializer.deserialize_seq(BoolVecVisitor)
    }
}

View File

@ -31,9 +31,10 @@ impl Into<Vec<i32>> for Difficulty {
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Solution {
pub indexes: Vec<usize>,
impl_base64_serde! {
Solution {
indexes: Vec<usize>,
}
}
impl Solution {
@ -44,14 +45,6 @@ impl Solution {
}
}
impl TryFrom<Map<String, Value>> for Solution {
type Error = serde_json::Error;
fn try_from(v: Map<String, Value>) -> Result<Self, Self::Error> {
from_value(Value::Object(v))
}
}
pub struct Challenge {
pub seed: [u8; 32],
pub difficulty: Difficulty,
@ -170,6 +163,26 @@ impl Challenge {
});
}
/// Returns the average distance for `solution.indexes`, as computed by the free function
/// `calc_average_distance` from this challenge's query and database vectors.
///
/// This is a thin adapter: it only gathers the instance parameters and forwards them
/// together with the CUDA module, stream and device properties.
pub fn calc_average_distance(
    &self,
    solution: &Solution,
    module: Arc<CudaModule>,
    stream: Arc<CudaStream>,
    prop: &cudaDeviceProp,
) -> Result<f32> {
    let num_queries = self.difficulty.num_queries;
    let chosen_indexes = &solution.indexes;

    calc_average_distance(
        num_queries,
        self.vector_dims,
        self.database_size,
        &self.d_query_vectors,
        &self.d_database_vectors,
        chosen_indexes,
        module,
        stream,
        prop,
    )
}
conditional_pub!(
fn verify_solution(
&self,
@ -177,26 +190,17 @@ impl Challenge {
module: Arc<CudaModule>,
stream: Arc<CudaStream>,
prop: &cudaDeviceProp,
) -> Result<f32> {
let avg_dist = calc_average_distance(
self.difficulty.num_queries,
self.vector_dims,
self.database_size,
&self.d_query_vectors,
&self.d_database_vectors,
&solution.indexes,
module.clone(),
stream.clone(),
prop,
)?;
) -> Result<()> {
let avg_dist = self.calc_average_distance(solution, module, stream, prop)?;
if avg_dist > self.max_distance {
return Err(anyhow!(
"Average query vector distance is '{}'. Max dist: '{}'",
avg_dist,
self.max_distance
));
} else {
Ok(())
}
Ok(avg_dist)
}
);
}

View File

@ -1,10 +1,6 @@
use anyhow::{anyhow, Result};
use rand::{
rngs::{SmallRng, StdRng},
Rng, SeedableRng,
};
use rand::{rngs::SmallRng, Rng, SeedableRng};
use serde::{Deserialize, Serialize};
use serde_json::{from_value, Map, Value};
use statrs::function::erf::{erf, erf_inv};
use std::collections::{HashMap, HashSet};
@ -32,41 +28,20 @@ impl Into<Vec<i32>> for Difficulty {
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Solution {
pub sub_solutions: Vec<SubSolution>,
impl_base64_serde! {
Solution {
routes: Vec<Vec<usize>>,
}
}
impl Solution {
pub fn new() -> Self {
Self {
sub_solutions: Vec::new(),
}
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SubSolution {
pub routes: Vec<Vec<usize>>,
}
impl TryFrom<Map<String, Value>> for Solution {
type Error = serde_json::Error;
fn try_from(v: Map<String, Value>) -> Result<Self, Self::Error> {
from_value(Value::Object(v))
Self { routes: Vec::new() }
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Challenge {
pub seed: [u8; 32],
pub difficulty: Difficulty,
pub sub_instances: Vec<SubInstance>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SubInstance {
pub seed: [u8; 32],
pub difficulty: Difficulty,
pub demands: Vec<i32>,
@ -82,58 +57,8 @@ pub struct SubInstance {
pub due_times: Vec<i32>,
}
pub const NUM_SUB_INSTANCES: usize = 16;
impl Challenge {
    /// Generates a challenge consisting of `NUM_SUB_INSTANCES` sub-instances.
    ///
    /// A `StdRng` seeded from `seed` supplies one fresh seed per sub-instance; all
    /// sub-instances share the same `difficulty`.
    ///
    /// # Errors
    /// Returns the first error produced by `SubInstance::generate_instance`.
    pub fn generate_instance(seed: &[u8; 32], difficulty: &Difficulty) -> Result<Challenge> {
        let mut rng = StdRng::from_seed(*seed);
        let mut sub_instances = Vec::with_capacity(NUM_SUB_INSTANCES);
        for _ in 0..NUM_SUB_INSTANCES {
            sub_instances.push(SubInstance::generate_instance(&rng.gen(), difficulty)?);
        }

        Ok(Challenge {
            seed: *seed,
            difficulty: difficulty.clone(),
            sub_instances,
        })
    }

    conditional_pub!(
        /// Verifies every sub-solution and checks the aggregate improvement over the baseline.
        ///
        /// For each sub-instance the ratio `total_distance / baseline_total_distance` is
        /// collected; the score is `1 - sqrt(mean(ratio^2))` and must be at least
        /// `difficulty.better_than_baseline / 1000`.
        ///
        /// # Errors
        /// * the number of sub-solutions differs from the number of sub-instances,
        /// * any sub-solution fails its own verification (reported as `Instance <i>: ...`),
        /// * the aggregate score is below the threshold.
        fn verify_solution(&self, solution: &Solution) -> Result<()> {
            // `zip` stops at the shorter side, so without this check a solution carrying fewer
            // sub-solutions would be scored on just those (and an empty one would yield NaN).
            if solution.sub_solutions.len() != self.sub_instances.len() {
                return Err(anyhow!(
                    "Invalid number of sub-solutions. Expected: {}, Actual: {}",
                    self.sub_instances.len(),
                    solution.sub_solutions.len()
                ));
            }

            let mut better_than_baselines = Vec::with_capacity(self.sub_instances.len());
            for (i, (sub_instance, sub_solution)) in self
                .sub_instances
                .iter()
                .zip(&solution.sub_solutions)
                .enumerate()
            {
                match sub_instance.verify_solution(sub_solution) {
                    Ok(total_distance) => better_than_baselines
                        .push(total_distance as f64 / sub_instance.baseline_total_distance as f64),
                    Err(e) => return Err(anyhow!("Instance {}: {}", i, e.to_string())),
                }
            }

            // Root-mean-square of the ratios; shorter distance is better, hence `1 - rms`.
            let average = 1.0
                - (better_than_baselines.iter().map(|x| x * x).sum::<f64>()
                    / better_than_baselines.len() as f64)
                    .sqrt();
            let threshold = self.difficulty.better_than_baseline as f64 / 1000.0;
            if average >= threshold {
                Ok(())
            } else {
                Err(anyhow!(
                    "Average better_than_baseline ({}) is less than ({})",
                    average,
                    threshold
                ))
            }
        }
    );
}
impl SubInstance {
pub fn generate_instance(seed: &[u8; 32], difficulty: &Difficulty) -> Result<SubInstance> {
pub fn generate_instance(seed: &[u8; 32], difficulty: &Difficulty) -> Result<Self> {
let mut rng = SmallRng::from_seed(seed.clone());
let num_nodes = difficulty.num_nodes;
let max_capacity = 200;
@ -252,7 +177,7 @@ impl SubInstance {
&due_times,
)?;
Ok(SubInstance {
Ok(Challenge {
seed: seed.clone(),
difficulty: difficulty.clone(),
demands,
@ -266,26 +191,44 @@ impl SubInstance {
})
}
/// Returns the total distance of `solution.routes` for this instance.
///
/// The route count is first checked against `self.fleet_size`; the distance itself (and any
/// further validation) is delegated to the free function `calc_routes_total_distance`, whose
/// error is propagated unchanged.
pub fn calc_routes_total_distance(&self, solution: &Solution) -> Result<i32> {
    let routes = &solution.routes;

    if routes.len() <= self.fleet_size {
        let distance = calc_routes_total_distance(
            self.difficulty.num_nodes,
            self.max_capacity,
            &self.demands,
            &self.distance_matrix,
            routes,
            self.service_time,
            &self.ready_times,
            &self.due_times,
        )?;
        return Ok(distance);
    }

    Err(anyhow!(
        "Number of routes ({}) exceeds fleet size ({})",
        routes.len(),
        self.fleet_size
    ))
}
conditional_pub!(
fn verify_solution(&self, solution: &SubSolution) -> Result<i32> {
if solution.routes.len() > self.fleet_size {
return Err(anyhow!(
"Number of routes ({}) exceeds fleet size ({})",
solution.routes.len(),
self.fleet_size
));
fn verify_solution(&self, solution: &Solution) -> Result<()> {
let total_distance = self.calc_routes_total_distance(solution)?;
let btb = self.difficulty.better_than_baseline as f64 / 1000.0;
let total_distance_threshold =
(self.baseline_total_distance as f64 * (1.0 - btb)).ceil() as i32;
if total_distance > total_distance_threshold {
Err(anyhow!(
"Total distance {} is greater than threshold {} (baseline: {}, better_than_baseline: {}%)",
total_distance,
total_distance_threshold,
self.baseline_total_distance,
btb * 100.0
))
} else {
Ok(())
}
let total_distance = calc_routes_total_distance(
self.difficulty.num_nodes,
self.max_capacity,
&self.demands,
&self.distance_matrix,
&solution.routes,
self.service_time,
&self.ready_times,
&self.due_times,
)?;
Ok(total_distance)
}
);
}