diff --git a/tig-challenges/Cargo.toml b/tig-challenges/Cargo.toml index 2d0f9f9..290a0ad 100644 --- a/tig-challenges/Cargo.toml +++ b/tig-challenges/Cargo.toml @@ -13,7 +13,10 @@ cudarc = { version = "0.12.0", features = [ "cuda-version-from-build-system", ], optional = true } ndarray = "0.15.6" -rand = { version = "0.8.5", default-features = false, features = ["std_rng"] } +rand = { version = "0.8.5", default-features = false, features = [ + "std_rng", + "small_rng", +] } serde = { version = "1.0.196", features = ["derive"] } serde_json = { version = "1.0.113" } diff --git a/tig-challenges/src/lib.rs b/tig-challenges/src/lib.rs index 42b27e1..5b30761 100644 --- a/tig-challenges/src/lib.rs +++ b/tig-challenges/src/lib.rs @@ -1,5 +1,5 @@ use anyhow::{anyhow, Result}; -use rand::{rngs::StdRng, Rng, SeedableRng}; +use rand::{rngs::SmallRng, SeedableRng}; use serde::de::DeserializeOwned; use serde::Serialize; @@ -95,18 +95,18 @@ pub struct CudaKernel { } pub struct RngArray { - rngs: [StdRng; 8], - index: u32, + rngs: [SmallRng; 8], + index: usize, } impl RngArray { pub fn new(seeds: [u64; 8]) -> Self { - let rngs = seeds.map(StdRng::seed_from_u64); + let rngs = seeds.map(SmallRng::seed_from_u64); RngArray { rngs, index: 0 } } - pub fn get_mut(&mut self) -> &mut StdRng { - self.index = (&mut self.rngs[self.index as usize]).gen_range(0..8); - &mut self.rngs[self.index as usize] + pub fn get_mut(&mut self) -> &mut SmallRng { + self.index = (self.index + 1) % 8; + &mut self.rngs[self.index] } }