mirror of
https://github.com/tig-pool-nk/tig-monorepo.git
synced 2026-02-21 15:17:22 +08:00
Serialize solutions with u32 instead of f32 for precision.
This commit is contained in:
parent
585f3762a4
commit
5982163fcd
@ -36,17 +36,18 @@ impl Into<Vec<i32>> for Difficulty {
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct Solution {
|
||||
pub weights: Vec<Vec<Vec<f32>>>,
|
||||
pub biases: Vec<Vec<f32>>,
|
||||
pub weights: Vec<Vec<Vec<u32>>>,
|
||||
pub biases: Vec<Vec<u32>>,
|
||||
pub epochs_used: usize,
|
||||
pub bn_weights: Vec<Vec<f32>>,
|
||||
pub bn_biases: Vec<Vec<f32>>,
|
||||
pub bn_running_means: Vec<Vec<f32>>,
|
||||
pub bn_running_vars: Vec<Vec<f32>>,
|
||||
pub bn_weights: Vec<Vec<u32>>,
|
||||
pub bn_biases: Vec<Vec<u32>>,
|
||||
pub bn_running_means: Vec<Vec<u32>>,
|
||||
pub bn_running_vars: Vec<Vec<u32>>,
|
||||
}
|
||||
|
||||
impl TryFrom<Map<String, Value>> for Solution {
|
||||
type Error = serde_json::Error;
|
||||
|
||||
fn try_from(v: Map<String, Value>) -> Result<Self, Self::Error> {
|
||||
from_value(Value::Object(v))
|
||||
}
|
||||
@ -594,30 +595,45 @@ pub fn training_loop(
|
||||
Ok((solution, train_losses, validation_losses))
|
||||
}
|
||||
|
||||
fn vec_u32_to_f32(vec: &Vec<u32>) -> Vec<f32> {
|
||||
vec.iter().map(|u| f32::from_bits(*u)).collect()
|
||||
}
|
||||
|
||||
pub fn load_solution(mlp: &mut MLP, solution: &Solution, stream: Arc<CudaStream>) -> Result<()> {
|
||||
for (i, layer) in mlp.lin.iter_mut().enumerate() {
|
||||
let w_flat: Vec<f32> = solution.weights[i].iter().flatten().cloned().collect();
|
||||
let w_flat: Vec<f32> =
|
||||
vec_u32_to_f32(&solution.weights[i].iter().flatten().cloned().collect());
|
||||
stream.memcpy_htod(&w_flat, &mut layer.weight)?;
|
||||
|
||||
stream.memcpy_htod(&solution.biases[i], &mut layer.bias)?;
|
||||
stream.memcpy_htod(&vec_u32_to_f32(&solution.biases[i]), &mut layer.bias)?;
|
||||
}
|
||||
for (i, bn) in mlp.bns.iter_mut().enumerate() {
|
||||
stream.memcpy_htod(&solution.bn_weights[i], &mut bn.weight)?;
|
||||
stream.memcpy_htod(&solution.bn_biases[i], &mut bn.bias)?;
|
||||
stream.memcpy_htod(&solution.bn_running_means[i], &mut bn.running_mean)?;
|
||||
stream.memcpy_htod(&solution.bn_running_vars[i], &mut bn.running_var)?;
|
||||
stream.memcpy_htod(&vec_u32_to_f32(&solution.bn_weights[i]), &mut bn.weight)?;
|
||||
stream.memcpy_htod(&vec_u32_to_f32(&solution.bn_biases[i]), &mut bn.bias)?;
|
||||
stream.memcpy_htod(
|
||||
&vec_u32_to_f32(&solution.bn_running_means[i]),
|
||||
&mut bn.running_mean,
|
||||
)?;
|
||||
stream.memcpy_htod(
|
||||
&vec_u32_to_f32(&solution.bn_running_vars[i]),
|
||||
&mut bn.running_var,
|
||||
)?;
|
||||
}
|
||||
stream.synchronize()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn vec_f32_to_u32(vec: Vec<f32>) -> Vec<u32> {
|
||||
vec.into_iter().map(|f| f.to_bits()).collect()
|
||||
}
|
||||
|
||||
pub fn to_solution(mlp: &MLP, epochs_used: usize, stream: Arc<CudaStream>) -> Result<Solution> {
|
||||
stream.synchronize()?;
|
||||
let mut weights = Vec::new();
|
||||
let mut biases = Vec::new();
|
||||
for layer in &mlp.lin {
|
||||
let w = stream.memcpy_dtov(&layer.weight)?;
|
||||
let b = stream.memcpy_dtov(&layer.bias)?;
|
||||
let w = vec_f32_to_u32(stream.memcpy_dtov(&layer.weight)?);
|
||||
let b = vec_f32_to_u32(stream.memcpy_dtov(&layer.bias)?);
|
||||
|
||||
weights.push(w.chunks(layer.in_features).map(|c| c.to_vec()).collect());
|
||||
biases.push(b);
|
||||
@ -629,10 +645,10 @@ pub fn to_solution(mlp: &MLP, epochs_used: usize, stream: Arc<CudaStream>) -> Re
|
||||
let mut bn_running_vars = Vec::new();
|
||||
|
||||
for bn in &mlp.bns {
|
||||
bn_weights.push(stream.memcpy_dtov(&bn.weight)?);
|
||||
bn_biases.push(stream.memcpy_dtov(&bn.bias)?);
|
||||
bn_running_means.push(stream.memcpy_dtov(&bn.running_mean)?);
|
||||
bn_running_vars.push(stream.memcpy_dtov(&bn.running_var)?);
|
||||
bn_weights.push(vec_f32_to_u32(stream.memcpy_dtov(&bn.weight)?));
|
||||
bn_biases.push(vec_f32_to_u32(stream.memcpy_dtov(&bn.bias)?));
|
||||
bn_running_means.push(vec_f32_to_u32(stream.memcpy_dtov(&bn.running_mean)?));
|
||||
bn_running_vars.push(vec_f32_to_u32(stream.memcpy_dtov(&bn.running_var)?));
|
||||
}
|
||||
|
||||
Ok(Solution {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user