Serialize solutions with u32 instead of f32 for precision.

This commit is contained in:
FiveMovesAhead 2025-08-08 17:04:06 +01:00
parent 585f3762a4
commit 5982163fcd

View File

@ -36,17 +36,18 @@ impl Into<Vec<i32>> for Difficulty {
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Solution {
pub weights: Vec<Vec<Vec<f32>>>,
pub biases: Vec<Vec<f32>>,
pub weights: Vec<Vec<Vec<u32>>>,
pub biases: Vec<Vec<u32>>,
pub epochs_used: usize,
pub bn_weights: Vec<Vec<f32>>,
pub bn_biases: Vec<Vec<f32>>,
pub bn_running_means: Vec<Vec<f32>>,
pub bn_running_vars: Vec<Vec<f32>>,
pub bn_weights: Vec<Vec<u32>>,
pub bn_biases: Vec<Vec<u32>>,
pub bn_running_means: Vec<Vec<u32>>,
pub bn_running_vars: Vec<Vec<u32>>,
}
impl TryFrom<Map<String, Value>> for Solution {
type Error = serde_json::Error;
fn try_from(v: Map<String, Value>) -> Result<Self, Self::Error> {
from_value(Value::Object(v))
}
@ -594,30 +595,45 @@ pub fn training_loop(
Ok((solution, train_losses, validation_losses))
}
fn vec_u32_to_f32(vec: &Vec<u32>) -> Vec<f32> {
vec.iter().map(|u| f32::from_bits(*u)).collect()
}
pub fn load_solution(mlp: &mut MLP, solution: &Solution, stream: Arc<CudaStream>) -> Result<()> {
for (i, layer) in mlp.lin.iter_mut().enumerate() {
let w_flat: Vec<f32> = solution.weights[i].iter().flatten().cloned().collect();
let w_flat: Vec<f32> =
vec_u32_to_f32(&solution.weights[i].iter().flatten().cloned().collect());
stream.memcpy_htod(&w_flat, &mut layer.weight)?;
stream.memcpy_htod(&solution.biases[i], &mut layer.bias)?;
stream.memcpy_htod(&vec_u32_to_f32(&solution.biases[i]), &mut layer.bias)?;
}
for (i, bn) in mlp.bns.iter_mut().enumerate() {
stream.memcpy_htod(&solution.bn_weights[i], &mut bn.weight)?;
stream.memcpy_htod(&solution.bn_biases[i], &mut bn.bias)?;
stream.memcpy_htod(&solution.bn_running_means[i], &mut bn.running_mean)?;
stream.memcpy_htod(&solution.bn_running_vars[i], &mut bn.running_var)?;
stream.memcpy_htod(&vec_u32_to_f32(&solution.bn_weights[i]), &mut bn.weight)?;
stream.memcpy_htod(&vec_u32_to_f32(&solution.bn_biases[i]), &mut bn.bias)?;
stream.memcpy_htod(
&vec_u32_to_f32(&solution.bn_running_means[i]),
&mut bn.running_mean,
)?;
stream.memcpy_htod(
&vec_u32_to_f32(&solution.bn_running_vars[i]),
&mut bn.running_var,
)?;
}
stream.synchronize()?;
Ok(())
}
fn vec_f32_to_u32(vec: Vec<f32>) -> Vec<u32> {
vec.into_iter().map(|f| f.to_bits()).collect()
}
pub fn to_solution(mlp: &MLP, epochs_used: usize, stream: Arc<CudaStream>) -> Result<Solution> {
stream.synchronize()?;
let mut weights = Vec::new();
let mut biases = Vec::new();
for layer in &mlp.lin {
let w = stream.memcpy_dtov(&layer.weight)?;
let b = stream.memcpy_dtov(&layer.bias)?;
let w = vec_f32_to_u32(stream.memcpy_dtov(&layer.weight)?);
let b = vec_f32_to_u32(stream.memcpy_dtov(&layer.bias)?);
weights.push(w.chunks(layer.in_features).map(|c| c.to_vec()).collect());
biases.push(b);
@ -629,10 +645,10 @@ pub fn to_solution(mlp: &MLP, epochs_used: usize, stream: Arc<CudaStream>) -> Re
let mut bn_running_vars = Vec::new();
for bn in &mlp.bns {
bn_weights.push(stream.memcpy_dtov(&bn.weight)?);
bn_biases.push(stream.memcpy_dtov(&bn.bias)?);
bn_running_means.push(stream.memcpy_dtov(&bn.running_mean)?);
bn_running_vars.push(stream.memcpy_dtov(&bn.running_var)?);
bn_weights.push(vec_f32_to_u32(stream.memcpy_dtov(&bn.weight)?));
bn_biases.push(vec_f32_to_u32(stream.memcpy_dtov(&bn.bias)?));
bn_running_means.push(vec_f32_to_u32(stream.memcpy_dtov(&bn.running_mean)?));
bn_running_vars.push(vec_f32_to_u32(stream.memcpy_dtov(&bn.running_var)?));
}
Ok(Solution {