mirror of https://github.com/tig-foundation/tig-monorepo.git
synced 2026-02-21 10:27:49 +08:00

Submitted neuralnet_optimizer/neural_advanced

parent 8ccedeb43e
commit 4400e7c34b
@ -8,7 +8,8 @@

// c006_a005

// c006_a006
pub mod neural_advanced;
pub use neural_advanced as c006_a006;

// c006_a007
@ -0,0 +1,45 @@
# TIG Code Submission

## Submission Details

* **Challenge Name:** neuralnet_optimizer
* **Algorithm Name:** neural_advanced
* **Copyright:** 2025 Rootz
* **Identity of Submitter:** Rootz
* **Identity of Creator of Algorithmic Method:** Rootz
* **Unique Algorithm Identifier (UAI):** null

## Additional Notes

Here I present my latest neuralnet optimizer, **Neural Advanced**.

**Key Features:**
- **Adaptive Noise Variance**: Automatically learns when the network is close to its best possible accuracy and slows down to avoid overshooting (see the sketch after this list)
- **Adaptive Spectral Boost**: Dynamically adjusts the learning rate - speeds up when making good progress, slows down when struggling
- **Adaptive Beta1**: Adjusts how much the optimizer "remembers" previous updates based on whether training is smooth or chaotic
- **Stability Detection**: Monitors whether gradients are consistent to decide between fast aggressive updates or careful conservative steps
- **Robust Fisher Diagonal**: Estimates the loss landscape curvature while filtering out misleading extreme values
- **Adaptive Plateau Escape**: Detects when training gets stuck and automatically increases the learning rate to break free
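
Of these, the noise-variance estimate and the plateau escape are the most distinctive. The sketch below is a simplified, self-contained Rust reduction of the corresponding logic in `optimizer_step` in `mod.rs`, using the same constants; the struct and method names are illustrative only:

```rust
/// Illustrative reduction of the adaptive noise-variance estimate and the
/// plateau escape in `optimizer_step` (names here are for the sketch only).
struct AdaptiveSketch {
    noise_variance: f32,        // EMA estimate of the irreducible loss floor
    val_loss_history: Vec<f32>, // sliding window of recent validation losses
    plateau_count: usize,       // steps without meaningful improvement
    lr_boost: f32,              // multiplicative learning-rate pulse
}

impl AdaptiveSketch {
    /// Track validation loss and pull `noise_variance` toward an estimate
    /// of the best achievable loss, so later damping can key off it.
    fn observe_val_loss(&mut self, loss: f32) {
        self.val_loss_history.push(loss);
        if self.val_loss_history.len() > 12 {
            self.val_loss_history.remove(0);
        }
        if self.val_loss_history.len() >= 6 {
            let min_loss = self
                .val_loss_history
                .iter()
                .copied()
                .fold(f32::INFINITY, f32::min);
            let k = self.val_loss_history.len().min(10);
            let recent_avg =
                self.val_loss_history.iter().rev().take(k).sum::<f32>() / k as f32;
            let target = (min_loss / 5.0).min(recent_avg / 8.0);
            self.noise_variance =
                (0.85 * self.noise_variance + 0.15 * target).clamp(0.0, 0.05);
        }
    }

    /// If training has stalled well above the estimated floor, pulse the
    /// learning-rate boost to break out of the plateau.
    fn maybe_escape_plateau(&mut self, current_loss: f32) {
        if self.plateau_count >= 25 && current_loss > self.noise_variance * 4.0 {
            self.lr_boost = (self.lr_boost * 1.12).min(1.60);
            self.plateau_count = 0;
        }
    }
}
```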

**Hyperparameters you can use to tune the algorithm for your specific hardware:**
- `threads_per_block` (default: 128) - Try 64, 128, 256, 512
- `blocks_per_sm` (default: 4) - Try 2-8 for GPU occupancy tuning
- `total_steps` (default: 450) - Instance-dependent; more complex instances may require more steps
- `warmup_steps` (default: 40) - Try 50-150 before adaptation kicks in; together with `total_steps` it shapes the learning-rate schedule (see the sketch after this list)
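
For context on how `warmup_steps` and `total_steps` interact, the sketch below mirrors the `spectral_phase_lr` function in `mod.rs`: a linear warmup, then a cosine decay with a slowly relaxing spectral boost. The standalone signature is adapted for the sketch:

```rust
use std::f32::consts::PI;

// Mirrors `spectral_phase_lr` in mod.rs, unbundled from the optimizer
// state: linear warmup, then cosine decay with a relaxing spectral boost.
fn spectral_phase_lr(
    step: usize,
    warmup_steps: usize,
    total_steps: usize,
    base_lr: f32,
    spectral_boost: f32,
) -> f32 {
    let t = step as f32;
    let warm = warmup_steps as f32;
    let total = total_steps as f32;

    if t <= warm {
        // Linear ramp from 0 to base_lr * spectral_boost.
        return base_lr * (t / warm.max(1.0)) * spectral_boost;
    }
    let progress = ((t - warm) / (total - warm).max(1.0)).min(1.0);
    let cosine_factor = 0.5 * (1.0 + (PI * progress).cos());
    // The boost itself relaxes by up to 30% as training progresses.
    base_lr * cosine_factor * spectral_boost * (1.0 - 0.3 * progress)
}
```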

All other hyperparameters are adaptive or work well at their current defaults.
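
To make the expected shape concrete, here is a hedged example of building a hyperparameter map with the keys that `optimizer_init_state` reads. How the map reaches `solve_challenge` depends on your benchmarker setup, so the helper below is purely illustrative:

```rust
use serde_json::{json, Map, Value};

// Hypothetical helper: builds a hyperparameter map using the keys that
// optimizer_init_state reads. Omitted keys fall back to the in-code defaults.
fn example_hyperparameters() -> Map<String, Value> {
    json!({
        "threads_per_block": 256, // try 64-512 depending on the GPU
        "blocks_per_sm": 4,       // 2-8 for occupancy tuning
        "total_steps": 450,
        "warmup_steps": 40
    })
    .as_object()
    .cloned()
    .expect("json! object literal always yields a map")
}
```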

**Important: this algorithm was predominantly tested on an RTX 5070 Ti (12 GB). Please make sure you test on your own specific hardware and tune parameters where necessary to obtain the best results.**

## License

The files in this folder are under the following licenses:
* TIG Benchmarker Outbound License
* TIG Commercial License
* TIG Inbound Game License
* TIG Innovator Outbound Game License
* TIG Open Data License
* TIG THV Game License

Copies of the licenses can be obtained at:
https://github.com/tig-foundation/tig-monorepo/tree/main/docs/licenses
632 tig-algorithms/src/neuralnet_optimizer/neural_advanced/mod.rs Normal file
@ -0,0 +1,632 @@
use anyhow::Result;
use cudarc::{
    driver::{CudaModule, CudaSlice, CudaStream, LaunchConfig, PushKernelArg},
    runtime::sys::cudaDeviceProp,
};
use std::sync::Arc;
use tig_challenges::neuralnet_optimizer::*;
use serde_json::{Map, Value};

// Hyperparameters are stashed thread-locally by `solve_challenge` so that
// `optimizer_init_state`, which takes no hyperparameter argument, can read
// them back.
thread_local! {
    static HYPERPARAMETERS: std::cell::RefCell<Map<String, Value>> = std::cell::RefCell::new(Map::new());
}

#[derive(Clone)]
struct DualPhaseConsensusState {
    m: Vec<CudaSlice<f32>>,
    v: Vec<CudaSlice<f32>>,
    prev_g: Vec<CudaSlice<f32>>,
    prev_u: Vec<CudaSlice<f32>>,
    slow_u: Vec<CudaSlice<f32>>,
    f: Vec<CudaSlice<f32>>,
    ef: Vec<CudaSlice<f32>>,
    upd: Vec<CudaSlice<f32>>,
    cfgs: Vec<LaunchConfig>,
    layer_lrs: Vec<f32>,
    spectral_boost: f32,

    step_count: usize,
    warmup_steps: usize,
    total_steps: usize,

    noise_variance: f32,
    val_loss_history: Vec<f32>,

    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,

    bn_layer_boost: f32,
    output_layer_damping: f32,

    prev_val_loss: Option<f32>,
    best_val_loss: Option<f32>,
    plateau_count: usize,
    slope_ema: f32,
    lr_boost: f32,
    last_pulse_step: usize,
    last_epoch: usize,
    steps_in_epoch: usize,
    bpe_ema: f32,
    phase_tempo: f32,
}

impl OptimizerStateTrait for DualPhaseConsensusState {
    fn as_any(&self) -> &dyn std::any::Any { self }
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any { self }
    fn box_clone(&self) -> Box<dyn OptimizerStateTrait> { Box::new(self.clone()) }
}

pub fn solve_challenge(
    challenge: &Challenge,
    save_solution: &dyn Fn(&Solution) -> Result<()>,
    hyperparameters: &Option<Map<String, Value>>,
    module: Arc<CudaModule>,
    stream: Arc<CudaStream>,
    prop: &cudaDeviceProp,
) -> Result<()> {
    // Load both kernels up front so a missing kernel fails fast.
    let _k_fast = module.load_function("dual_consensus_fisher_kernel")?;
    let _k_robust = module.load_function("sign_ef_consensus_kernel")?;

    let hp = hyperparameters.clone().unwrap_or_default();
    HYPERPARAMETERS.with(|h| {
        *h.borrow_mut() = hp;
    });

    training_loop(
        challenge,
        save_solution,
        module,
        stream,
        prop,
        optimizer_init_state,
        optimizer_query_at_params,
        optimizer_step,
    )?;

    Ok(())
}

fn optimizer_init_state(
    _seed: [u8; 32],
    param_sizes: &[usize],
    stream: Arc<CudaStream>,
    _module: Arc<CudaModule>,
    prop: &cudaDeviceProp,
) -> Result<Box<dyn OptimizerStateTrait>> {
    let (threads_per_block, blocks_per_sm, total_steps, warmup_steps, noise_variance, spectral_boost, beta1, beta2, eps, weight_decay, bn_layer_boost, output_layer_damping) =
        HYPERPARAMETERS.with(|h| {
            let hp = h.borrow();
            let threads_per_block = hp.get("threads_per_block")
                .and_then(|v| v.as_u64())
                .unwrap_or(128) as u32;

            let blocks_per_sm = hp.get("blocks_per_sm")
                .and_then(|v| v.as_u64())
                .unwrap_or(4) as u32;

            let total_steps = hp.get("total_steps")
                .and_then(|v| v.as_u64())
                .unwrap_or(450) as usize;

            let warmup_steps = hp.get("warmup_steps")
                .and_then(|v| v.as_u64())
                .unwrap_or(40) as usize;

            let noise_variance = hp.get("noise_variance")
                .and_then(|v| v.as_f64())
                .unwrap_or(0.040) as f32;

            let spectral_boost = hp.get("spectral_boost")
                .and_then(|v| v.as_f64())
                .unwrap_or(1.1) as f32;

            let beta1 = hp.get("beta1")
                .and_then(|v| v.as_f64())
                .unwrap_or(0.92) as f32;

            let beta2 = hp.get("beta2")
                .and_then(|v| v.as_f64())
                .unwrap_or(0.997) as f32;

            let eps = hp.get("eps")
                .and_then(|v| v.as_f64())
                .unwrap_or(1e-8) as f32;

            let weight_decay = hp.get("weight_decay")
                .and_then(|v| v.as_f64())
                .unwrap_or(0.0025) as f32;

            let bn_layer_boost = hp.get("bn_layer_boost")
                .and_then(|v| v.as_f64())
                .unwrap_or(1.35) as f32;

            let output_layer_damping = hp.get("output_layer_damping")
                .and_then(|v| v.as_f64())
                .unwrap_or(0.8) as f32;

            (threads_per_block, blocks_per_sm, total_steps, warmup_steps, noise_variance, spectral_boost, beta1, beta2, eps, weight_decay, bn_layer_boost, output_layer_damping)
        });

    let mut m = Vec::new();
    let mut v = Vec::new();
    let mut prev_g = Vec::new();
    let mut prev_u = Vec::new();
    let mut slow_u = Vec::new();
    let mut f = Vec::new();
    let mut ef = Vec::new();
    let mut upd = Vec::new();
    for &n in param_sizes {
        m.push(stream.alloc_zeros::<f32>(n)?);
        v.push(stream.alloc_zeros::<f32>(n)?);
        prev_g.push(stream.alloc_zeros::<f32>(n)?);
        prev_u.push(stream.alloc_zeros::<f32>(n)?);
        slow_u.push(stream.alloc_zeros::<f32>(n)?);
        let mut fisher_init = stream.alloc_zeros::<f32>(n)?;
        let init_fisher = vec![1e-4f32; n];
        stream.memcpy_htod(&init_fisher, &mut fisher_init)?;
        f.push(fisher_init);
        ef.push(stream.alloc_zeros::<f32>(n)?);
        upd.push(unsafe { stream.alloc::<f32>(n)? });
    }

    let sm_blocks = (prop.multiProcessorCount as u32).saturating_mul(blocks_per_sm).max(1);
    let mut cfgs = Vec::with_capacity(param_sizes.len());
    for &n in param_sizes {
        let calc_blocks = (n as u32 + threads_per_block - 1) / threads_per_block;
        let grid_dim = calc_blocks.min(sm_blocks).max(1);
        cfgs.push(LaunchConfig {
            grid_dim: (grid_dim, 1, 1),
            block_dim: (threads_per_block, 1, 1),
            shared_mem_bytes: 0,
        });
    }

    let mut layer_lrs = Vec::with_capacity(param_sizes.len());
    for (i, &param_size) in param_sizes.iter().enumerate() {
        let mut lr = 0.0012f32;
        if i == 0 { lr = 0.0011; }
        if param_size <= 512 { lr = 0.0018; }
        if param_size > 50000 { lr = 0.0009; }
        if i == param_sizes.len() - 1 { lr = 0.0007; }
        layer_lrs.push(lr);
    }

    let state = DualPhaseConsensusState {
        m,
        v,
        prev_g,
        prev_u,
        slow_u,
        f,
        ef,
        upd,
        cfgs,
        layer_lrs,
        spectral_boost,
        step_count: 0,
        warmup_steps,
        total_steps,
        noise_variance,
        val_loss_history: Vec::new(),
        beta1,
        beta2,
        eps,
        weight_decay,
        bn_layer_boost,
        output_layer_damping,
        prev_val_loss: None,
        best_val_loss: None,
        plateau_count: 0,
        slope_ema: 0.0,
        lr_boost: 1.0,
        last_pulse_step: 0,
        last_epoch: 0,
        steps_in_epoch: 0,
        bpe_ema: 1.0,
        phase_tempo: 1.0,
    };

    Ok(Box::new(state))
}

fn optimizer_query_at_params(
    _optimizer_state: &dyn OptimizerStateTrait,
    _model_params: &[CudaSlice<f32>],
    _epoch: usize,
    _train_loss: Option<f32>,
    _val_loss: Option<f32>,
    _stream: Arc<CudaStream>,
    _module: Arc<CudaModule>,
    _prop: &cudaDeviceProp,
) -> Result<Option<Vec<CudaSlice<f32>>>> {
    Ok(None)
}

#[inline]
fn spectral_phase_lr(s: &DualPhaseConsensusState, base_lr: f32) -> f32 {
    let t = s.step_count as f32;
    let warm = s.warmup_steps as f32;
    let total = s.total_steps as f32;

    if t <= warm {
        return base_lr * (t / warm.max(1.0)) * s.spectral_boost;
    }

    let progress = ((t - warm) / (total - warm).max(1.0)).min(1.0);

    let cosine_factor = 0.5 * (1.0 + (std::f32::consts::PI * progress).cos());
    let spec_boost = s.spectral_boost * (1.0 - 0.3 * progress);

    base_lr * cosine_factor * spec_boost
}

#[inline]
fn compute_blends(s: &DualPhaseConsensusState, val_loss: Option<f32>) -> (f32, f32, f32, f32, f32, f32, f32) {
    let t = s.step_count as f32;
    let warm = s.warmup_steps as f32;
    let total = s.total_steps as f32;
    let progress = (t / total.max(1.0)).min(1.0);

    let (mut blend_adam, mut blend_norm, mut blend_sign, gamma, bb_blend, mut lookahead_alpha, mut lookahead_tau): (f32, f32, f32, f32, f32, f32, f32) = if t <= warm {
        (0.35, 0.65, 0.0, 0.22, 0.6, 0.0, 0.2)
    } else {
        let mut trend = 0.0f32;
        if let (Some(prev), Some(curr)) = (s.prev_val_loss, val_loss) {
            trend = prev - curr;
        }

        if trend > 1e-3 {
            (0.60, 0.35, 0.05, 0.28, 0.40, 0.15, 0.15)
        } else if trend.abs() < 1e-4 {
            (0.55, 0.35, 0.10, 0.15, 0.45, 0.30, 0.22)
        } else {
            (0.50, 0.35, 0.15, 0.20, 0.50, 0.20, 0.20)
        }
    };

    if t > warm {
        if let Some(curr) = val_loss {
            if curr <= s.noise_variance * 5.0 {
                blend_sign = (blend_sign + 0.2).min(0.6);
                lookahead_alpha = lookahead_alpha.max(0.45);
                lookahead_tau = (lookahead_tau + 0.05).min(0.35);
                blend_adam *= 0.9;
                blend_norm *= 0.9;
            } else if curr >= s.noise_variance * 6.2 && curr <= s.noise_variance * 8.6 {
                blend_sign = blend_sign.max(0.35);
                blend_adam = (blend_adam * 0.95).max(0.25);
                blend_norm = (blend_norm * 0.95).max(0.15);
                lookahead_alpha = (lookahead_alpha * 0.85).min(0.35);
                lookahead_tau = (lookahead_tau * 0.85).min(0.30);
            }
        }

        if progress > 0.8 {
            blend_norm = blend_norm.max(0.35);
            blend_sign *= 0.9;
            lookahead_alpha = lookahead_alpha.max(0.5);
            lookahead_tau = (lookahead_tau + 0.05).min(0.4);
        }
    }

    let sum = (blend_adam + blend_norm + blend_sign).max(1e-8);
    (
        blend_adam / sum,
        blend_norm / sum,
        blend_sign / sum,
        gamma,
        bb_blend,
        lookahead_alpha,
        lookahead_tau,
    )
}

fn optimizer_step(
    optimizer_state: &mut dyn OptimizerStateTrait,
    _model_params: &[CudaSlice<f32>],
    gradients: &[CudaSlice<f32>],
    epoch: usize,
    _train_loss: Option<f32>,
    val_loss: Option<f32>,
    stream: Arc<CudaStream>,
    module: Arc<CudaModule>,
    _prop: &cudaDeviceProp,
) -> Result<Vec<CudaSlice<f32>>> {
    let s = optimizer_state.as_any_mut().downcast_mut::<DualPhaseConsensusState>().unwrap();
    s.step_count += 1;
    if s.step_count == 1 {
        s.last_epoch = epoch;
    }
    if s.last_epoch != epoch {
        if s.steps_in_epoch > 0 {
            s.bpe_ema = 0.9 * s.bpe_ema + 0.1 * (s.steps_in_epoch as f32);
        }
        s.steps_in_epoch = 0;
        s.last_epoch = epoch;
    }
    s.steps_in_epoch = s.steps_in_epoch.saturating_add(1);
    let tempo = (1.0 + 0.30 * s.bpe_ema.ln()).clamp(1.0, 2.2);
    s.phase_tempo = tempo;

    if let Some(loss) = val_loss {
        if s.step_count > s.warmup_steps {
            s.val_loss_history.push(loss);
            if s.val_loss_history.len() > 12 {
                s.val_loss_history.remove(0);
            }

            if s.val_loss_history.len() >= 6 {
                let min_loss = s.val_loss_history.iter().copied().fold(f32::INFINITY, f32::min);
                // Average over however many of the last 10 losses exist
                // (the window may hold as few as 6 entries here).
                let k = s.val_loss_history.len().min(10);
                let recent_avg = s.val_loss_history.iter().rev().take(k).sum::<f32>() / k as f32;
                let target_nv = (min_loss / 5.0).min(recent_avg / 8.0);
                s.noise_variance = 0.85 * s.noise_variance + 0.15 * target_nv;
                s.noise_variance = s.noise_variance.clamp(0.0, 0.05);
            }
        }
    }

    if let (Some(prev), Some(curr)) = (s.prev_val_loss, val_loss) {
        if s.step_count > s.warmup_steps && s.step_count > 20 {
            let improvement = prev - curr;
            let relative_improvement = improvement / prev.abs().max(1e-8);

            if relative_improvement > 0.008 {
                s.spectral_boost = (s.spectral_boost * 1.015).min(1.5);
            } else if relative_improvement < -0.003 {
                s.spectral_boost *= 0.97;
            } else if relative_improvement.abs() < 0.0005 && s.plateau_count > 15 {
                s.spectral_boost = (s.spectral_boost * 1.008).min(1.5);
            }

            s.spectral_boost = s.spectral_boost.clamp(0.85, 1.5);
        }
    }

    if s.step_count > s.warmup_steps && s.val_loss_history.len() >= 8 {
        let recent_avg = s.val_loss_history.iter().rev().take(5).sum::<f32>() / 5.0;
        // The older half of the window may hold 3-5 entries; divide by the
        // actual count rather than a fixed 5.
        let older: Vec<f32> = s.val_loss_history.iter().rev().skip(5).take(5).copied().collect();
        let older_avg = older.iter().sum::<f32>() / older.len().max(1) as f32;
        let trend = older_avg - recent_avg;

        let target_beta1 = if trend > 0.02 {
            0.94
        } else if trend < -0.02 {
            0.88
        } else {
            0.91
        };

        s.beta1 = 0.85 * s.beta1 + 0.15 * target_beta1;
        s.beta1 = s.beta1.clamp(0.87, 0.94);
    }

    let mut global_damp = 1.0f32;

    if let (Some(prev), Some(curr)) = (s.prev_val_loss, val_loss) {
        let improvement = prev - curr;
        s.slope_ema = 0.85 * s.slope_ema + 0.15 * improvement;
        if s.step_count > s.warmup_steps {
            let is_stagnant = improvement <= 1.0e-4 && s.slope_ema < 2.0e-4;
            let is_declining = improvement < 0.0 && s.slope_ema < 0.0;

            if is_stagnant || is_declining {
                s.plateau_count += 1;
            } else if improvement > 5.0e-5 {
                s.plateau_count = 0;
            } else if s.plateau_count > 0 {
                s.plateau_count = s.plateau_count.saturating_sub(1);
            }
            if s.plateau_count >= 25 {
                if curr > s.noise_variance * 4.0 {
                    s.lr_boost = (s.lr_boost * 1.12).min(1.60);
                    s.last_pulse_step = s.step_count;
                    s.plateau_count = 0;
                }
            } else if s.plateau_count >= 15 && curr > s.noise_variance * 8.0 {
                s.lr_boost = (s.lr_boost * 1.15).min(1.70);
                s.last_pulse_step = s.step_count;
                s.plateau_count = 0;
            } else if s.plateau_count >= 18 && curr > s.noise_variance * 5.0 {
                s.lr_boost = (s.lr_boost * 1.10).min(1.45);
                s.last_pulse_step = s.step_count;
                s.plateau_count = 0;
            } else if s.lr_boost > 1.0 {
                let relative_improvement = improvement / curr.abs().max(1e-8);
                let decay = if relative_improvement > 0.01 {
                    0.75
                } else if relative_improvement > 0.001 {
                    0.85
                } else if improvement > 0.0 {
                    0.93
                } else {
                    0.97
                };
                s.lr_boost = 1.0 + (s.lr_boost - 1.0) * decay;
                // A second, coarser decay compounds with the one above.
                let decay = if improvement > 5.0e-5 { 0.82 } else { 0.92 };
                s.lr_boost = 1.0 + (s.lr_boost - 1.0) * decay;
                if s.step_count.saturating_sub(s.last_pulse_step) > 80 {
                    s.lr_boost *= 0.96;
                }
                if s.lr_boost < 1.02 { s.lr_boost = 1.0; }
            }
        }
    }

    if let Some(loss) = val_loss {
        let dynamic_threshold = s.noise_variance * (1.1 + 0.1 * (s.step_count as f32 / s.total_steps as f32));
        if loss <= dynamic_threshold && s.step_count > s.warmup_steps {
            let proximity = (loss / dynamic_threshold).clamp(0.4, 1.0);
            let plateau_factor: f32 = if s.plateau_count > 10 { 1.2 } else { 1.0 };
            // Cap the combined damping factor at 0.9 (previously the cap was
            // applied to plateau_factor alone, making it a constant 0.9).
            global_damp *= ((0.25 + 0.35 * proximity) * plateau_factor).min(0.9);
        }

        if loss <= s.noise_variance * 5.0 {
            let noise_proximity = (loss / (s.noise_variance * 5.0)).min(1.0);
            let steepness = 1.0 + 0.5 * (1.0 - noise_proximity);
            let noise_damping = 0.70 + 0.30 * noise_proximity.powf(steepness);
            global_damp *= noise_damping;
        }
    }

    let t = s.step_count as i32;
    let bias_correction1 = 1.0 - s.beta1.powi(t);
    let bias_correction2 = 1.0 - s.beta2.powi(t);

    let (blend_adam, blend_norm, blend_sign, nesterov_gamma, bb_blend, lookahead_alpha, lookahead_tau) = compute_blends(s, val_loss);
    let near_floor = val_loss.map_or(false, |loss| loss <= s.noise_variance * 3.0);
    let late_phase = s.step_count > s.total_steps * 3 / 4;
    let use_robust = s.step_count > s.warmup_steps && (near_floor || late_phase);

    let (in_precision_zone, precision_gain, gate_lo, gate_hi, forward_gain): (bool, f32, f32, f32, f32) = if let Some(loss) = val_loss {
        if s.step_count > s.warmup_steps {
            let z_lo = s.noise_variance * 6.2;
            let z_hi = s.noise_variance * 8.6;
            if loss >= z_lo && loss <= z_hi {
                let pos = ((z_hi - loss) / (z_hi - z_lo + 1e-8)).clamp(0.0, 1.0);
                let pg = 1.02 + 0.06 * pos;
                let gate_lo = 0.70 + 0.02 * pos;
                let gate_hi = 1.50 + 0.05 * pos;
                let forward_gain = if let Some(prev) = s.prev_val_loss {
                    let rel = ((prev - loss).max(0.0)) / (prev.abs() + 1e-6);
                    1.0 + (0.75 * rel).min(0.015)
                } else { 1.0 };
                (true, pg, gate_lo, gate_hi, forward_gain)
            } else { (false, 1.0, 0.66, 1.50, 1.0) }
        } else { (false, 1.0, 0.66, 1.50, 1.0) }
    } else { (false, 1.0, 0.66, 1.50, 1.0) };

    let beta1_eff: f32 = if in_precision_zone { (s.beta1 + 0.02).min(0.995) } else { s.beta1 };
    let beta2_eff: f32 = s.beta2;
    let eps_eff: f32 = if in_precision_zone { s.eps * 0.9 } else { s.eps };
    let mut wd_eff: f32 = if in_precision_zone { s.weight_decay * 1.05 } else { s.weight_decay };
    if s.step_count > s.warmup_steps {
        if near_floor {
            wd_eff *= 1.10;
        } else if s.plateau_count >= 20 {
            wd_eff *= 0.50;
        }
    }
    wd_eff *= (1.0 / s.phase_tempo).clamp(0.6, 1.0);

    let trust_backoff: f32 = if let (Some(prev), Some(curr)) = (s.prev_val_loss, val_loss) {
        let delta = curr - prev;
        if delta > 2e-4 {
            1.0 / (1.0 + 1.5 * (delta / (prev.abs() + 1e-8)).min(0.02))
        } else {
            1.0
        }
    } else {
        1.0
    };

    let k_fast = module.load_function("dual_consensus_fisher_kernel")?;
    let k_robust = module.load_function("sign_ef_consensus_kernel")?;

    let mut updates = Vec::with_capacity(gradients.len());

    for (i, g) in gradients.iter().enumerate() {
        let n = g.len();
        if n == 0 {
            updates.push(stream.alloc_zeros::<f32>(0)?);
            continue;
        }

        let base_lr = s.layer_lrs[i];
        let tempo_lr = (1.0 / s.phase_tempo.powf(0.35)).max(0.6);
        let lr = spectral_phase_lr(s, base_lr) * global_damp * s.lr_boost * tempo_lr;

        let layer_multiplier = if i == gradients.len() - 1 {
            if let Some(loss) = val_loss {
                if s.step_count > s.warmup_steps + 30 {
                    let loss_ratio = (loss / (s.noise_variance * 6.0)).min(1.0);
                    let adaptive_damping = s.output_layer_damping * (0.7 + 0.3 * loss_ratio);
                    adaptive_damping.max(s.output_layer_damping)
                } else {
                    s.output_layer_damping
                }
            } else {
                s.output_layer_damping
            }
        } else if n <= 512 {
            s.bn_layer_boost
        } else {
            1.0
        };

        let effective_lr = lr * layer_multiplier * precision_gain * forward_gain * trust_backoff;

        let cfg = s.cfgs[i];

        let update_buf_ref = &mut s.upd[i];

        unsafe {
            if use_robust {
                stream
                    .launch_builder(&k_robust)
                    .arg(g)
                    .arg(&mut s.f[i])
                    .arg(&mut s.ef[i])
                    .arg(&mut s.slow_u[i])
                    .arg(update_buf_ref)
                    .arg(&(n as u32))
                    .arg(&effective_lr)
                    .arg(&eps_eff)
                    .arg(&lookahead_alpha)
                    .arg(&lookahead_tau)
                    .arg(&gate_lo)
                    .arg(&gate_hi)
                    .launch(cfg)?;
            } else {
                stream
                    .launch_builder(&k_fast)
                    .arg(g)
                    .arg(&mut s.m[i])
                    .arg(&mut s.v[i])
                    .arg(&mut s.prev_g[i])
                    .arg(&mut s.prev_u[i])
                    .arg(&mut s.slow_u[i])
                    .arg(&mut s.f[i])
                    .arg(update_buf_ref)
                    .arg(&(n as u32))
                    .arg(&effective_lr)
                    .arg(&beta1_eff)
                    .arg(&beta2_eff)
                    .arg(&eps_eff)
                    .arg(&wd_eff)
                    .arg(&bias_correction1)
                    .arg(&bias_correction2)
                    .arg(&blend_adam)
                    .arg(&blend_norm)
                    .arg(&blend_sign)
                    .arg(&nesterov_gamma)
                    .arg(&bb_blend)
                    .arg(&lookahead_alpha)
                    .arg(&lookahead_tau)
                    .arg(&gate_lo)
                    .arg(&gate_hi)
                    .launch(cfg)?;
            }
        }

        updates.push(s.upd[i].clone());
    }

    if let Some(curr) = val_loss {
        s.best_val_loss = Some(match s.best_val_loss {
            Some(b) => if curr < b { curr } else { b },
            None => curr,
        });
    }
    s.prev_val_loss = val_loss;
    Ok(updates)
}

pub fn help() {
    println!("No help information available.");
}
@ -0,0 +1,186 @@
|
||||
extern "C" __global__ __launch_bounds__(512, 4) void sign_ef_consensus_kernel(
|
||||
const float* __restrict__ gradients,
|
||||
float* __restrict__ fisher_diag,
|
||||
float* __restrict__ ef_residual,
|
||||
float* __restrict__ slow_update,
|
||||
float* __restrict__ updates,
|
||||
const unsigned int n,
|
||||
const float lr,
|
||||
const float eps,
|
||||
const float lookahead_alpha,
|
||||
const float lookahead_tau,
|
||||
const float gate_lo,
|
||||
const float gate_hi
|
||||
) {
|
||||
const unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (tid >= n) return;
|
||||
const unsigned int stride = blockDim.x * gridDim.x;
|
||||
|
||||
const float one_minus_la_tau = 1.0f - lookahead_tau;
|
||||
const float one_minus_la_alpha = 1.0f - lookahead_alpha;
|
||||
const float fisher_beta = 0.98f;
|
||||
const float one_minus_fisher_beta = 1.0f - fisher_beta;
|
||||
const float inv_lr = 1.0f / fmaxf(lr, 1.0e-8f);
|
||||
|
||||
#pragma unroll 2
|
||||
for (unsigned int idx = tid; idx < n; idx += stride) {
|
||||
float fd = fisher_diag[idx];
|
||||
const float g = gradients[idx];
|
||||
|
||||
const float fd_std = sqrtf(fd) + eps;
|
||||
const float g_clipped = fminf(fmaxf(g, -4.0f * fd_std), 4.0f * fd_std);
|
||||
fd = fisher_beta * fd + one_minus_fisher_beta * g_clipped * g_clipped;
|
||||
fisher_diag[idx] = fd;
|
||||
|
||||
const float rms = sqrtf(fd) + eps;
|
||||
const float g_n = g / rms;
|
||||
|
||||
const float ef_old = ef_residual[idx];
|
||||
const float u_desired = -lr * g_n + ef_old;
|
||||
|
||||
const float u_quant = -lr * copysignf(1.0f, g_n + ef_old * inv_lr);
|
||||
|
||||
float ef_delta = u_desired - u_quant;
|
||||
const float ef_cap = 6.0f * lr;
|
||||
ef_delta = ef_delta / (1.0f + fabsf(ef_delta) / fmaxf(ef_cap, 1.0e-8f));
|
||||
const float ef_new = 0.98f * ef_old + ef_delta;
|
||||
ef_residual[idx] = ef_new;
|
||||
|
||||
float su = slow_update[idx];
|
||||
su = one_minus_la_tau * su + lookahead_tau * u_quant;
|
||||
const float final_update = one_minus_la_alpha * u_quant + lookahead_alpha * su;
|
||||
|
||||
const float target = lr * fabsf(g_n);
|
||||
const float uabs = fabsf(final_update);
|
||||
const float scale = fminf(fmaxf(target / fmaxf(uabs, 1.0e-12f), gate_lo), gate_hi);
|
||||
const float adj_update = final_update * scale;
|
||||
|
||||
slow_update[idx] = su;
|
||||
updates[idx] = adj_update;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" __global__ __launch_bounds__(512, 4) void dual_consensus_fisher_kernel(
|
||||
const float* __restrict__ gradients,
|
||||
float* __restrict__ momentum,
|
||||
float* __restrict__ velocity,
|
||||
float* __restrict__ prev_grad,
|
||||
float* __restrict__ prev_update,
|
||||
float* __restrict__ slow_update,
|
||||
float* __restrict__ fisher_diag,
|
||||
float* __restrict__ updates,
|
||||
const unsigned int n,
|
||||
const float lr,
|
||||
const float beta1,
|
||||
const float beta2,
|
||||
const float eps,
|
||||
const float weight_decay,
|
||||
const float bias_correction1,
|
||||
const float bias_correction2,
|
||||
const float blend_adam,
|
||||
const float blend_norm,
|
||||
const float blend_sign,
|
||||
const float nesterov_gamma,
|
||||
const float bb_blend,
|
||||
const float lookahead_alpha,
|
||||
const float lookahead_tau,
|
||||
const float gate_lo,
|
||||
const float gate_hi
|
||||
) {
|
||||
const unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const unsigned int stride = blockDim.x * gridDim.x;
|
||||
|
||||
const float inv_bc1 = 1.0f / fmaxf(bias_correction1, 1.0e-8f);
|
||||
const float inv_bc2 = 1.0f / fmaxf(bias_correction2, 1.0e-8f);
|
||||
const float one_minus_beta1 = 1.0f - beta1;
|
||||
const float one_minus_beta2 = 1.0f - beta2;
|
||||
const float one_minus_la_tau = 1.0f - lookahead_tau;
|
||||
const float one_minus_la_alpha = 1.0f - lookahead_alpha;
|
||||
const float fisher_beta = 0.98f;
|
||||
const float one_minus_fisher_beta = 1.0f - fisher_beta;
|
||||
const float ortho_mix = fminf(fmaxf(0.2f + 0.4f * blend_sign, 0.0f), 0.6f);
|
||||
|
||||
#pragma unroll 2
|
||||
for (unsigned int idx = tid; idx < n; idx += stride) {
|
||||
const float g = gradients[idx];
|
||||
const float pg = prev_grad[idx];
|
||||
|
||||
const float gamma_local = (g * pg >= 0.0f) ? nesterov_gamma : (0.25f * nesterov_gamma);
|
||||
const float g_pred = g + gamma_local * (g - pg);
|
||||
|
||||
float m = momentum[idx];
|
||||
float v = velocity[idx];
|
||||
|
||||
m = beta1 * m + one_minus_beta1 * g_pred;
|
||||
|
||||
const float err = g_pred - m;
|
||||
v = beta2 * v + one_minus_beta2 * err * err;
|
||||
|
||||
const float m_hat = m * inv_bc1;
|
||||
const float v_hat = v * inv_bc2;
|
||||
|
||||
const float sqrt_v = sqrtf(fmaxf(v_hat, 0.0f));
|
||||
const float adaptive_eps = eps * (1.0f + 0.1f * sqrt_v);
|
||||
const float denom = sqrt_v + adaptive_eps;
|
||||
const float inv_denom = 1.0f / fmaxf(denom, 1.0e-12f);
|
||||
|
||||
const float adam_update = -lr * (m_hat * inv_denom + weight_decay * g_pred);
|
||||
const float g_over_denom = g_pred * inv_denom;
|
||||
const float norm_update = -lr * g_over_denom;
|
||||
const float sign_update = -lr * copysignf(1.0f, m_hat);
|
||||
float base_update = blend_adam * adam_update
|
||||
+ blend_norm * norm_update
|
||||
+ blend_sign * sign_update;
|
||||
|
||||
const float overlap = copysignf(fminf(fabsf(g_pred), fabsf(m_hat)), m_hat);
|
||||
const float g_ortho = g_pred - overlap;
|
||||
const float ortho_update = -lr * (g_ortho / denom);
|
||||
base_update = (1.0f - ortho_mix) * base_update + ortho_mix * ortho_update;
|
||||
|
||||
const float s = prev_update[idx];
|
||||
const float s_mag = fabsf(s);
|
||||
const float bb_scale = (s_mag > 1e-6f) ? fminf(s_mag * 2.0f, 2.5f) : 1.0f;
|
||||
base_update *= (1.0f - bb_blend * 0.3f) + (bb_blend * 0.3f) * bb_scale;
|
||||
|
||||
float fd = fisher_diag[idx];
|
||||
const float sqrt_v_for_clip = sqrtf(fmaxf(v, 0.0f));
|
||||
const float grad_std = sqrt_v_for_clip + eps;
|
||||
const float g_clipped = fminf(fmaxf(g_pred, -5.0f * grad_std), 5.0f * grad_std);
|
||||
fd = fisher_beta * fd + one_minus_fisher_beta * g_clipped * g_clipped;
|
||||
const float fisher_rms = sqrtf(fd) + eps;
|
||||
|
||||
const float fisher_norm_update = -lr * (g_pred / fisher_rms);
|
||||
const float robust_track = 0.5f * sign_update + 0.5f * fisher_norm_update;
|
||||
|
||||
const float flip = (g * pg < 0.0f) ? 1.0f : 0.0f;
|
||||
const float vol = fminf(sqrt_v * 0.33333334f, 1.0f);
|
||||
const float agree = (base_update * robust_track >= 0.0f) ? 1.0f : 0.0f;
|
||||
|
||||
const float grad_mom_align = (g_pred * m_hat >= 0.0f) ? 1.0f : 0.0f;
|
||||
const float stability = grad_mom_align * (1.0f - flip);
|
||||
|
||||
float consensus_mix = 0.35f * (1.0f - agree) + 0.25f * vol + 0.25f * blend_sign + 0.15f * (1.0f - stability);
|
||||
consensus_mix = fminf(fmaxf(consensus_mix, 0.0f), 1.0f);
|
||||
float chosen_update = (1.0f - consensus_mix) * base_update + consensus_mix * robust_track;
|
||||
|
||||
float trust = 1.0f / (1.0f + 0.5f * vol + 0.5f * flip);
|
||||
chosen_update *= trust;
|
||||
|
||||
const float target = lr * fabsf(g_over_denom);
|
||||
const float uabs = fabsf(chosen_update);
|
||||
const float scale = fminf(fmaxf(target / fmaxf(uabs, 1.0e-12f), gate_lo), gate_hi);
|
||||
chosen_update *= scale;
|
||||
|
||||
float su = slow_update[idx];
|
||||
su = one_minus_la_tau * su + lookahead_tau * chosen_update;
|
||||
const float final_update = one_minus_la_alpha * chosen_update + lookahead_alpha * su;
|
||||
|
||||
momentum[idx] = m;
|
||||
velocity[idx] = v;
|
||||
prev_grad[idx] = g;
|
||||
prev_update[idx] = final_update;
|
||||
slow_update[idx] = su;
|
||||
fisher_diag[idx] = fd;
|
||||
updates[idx] = final_update;
|
||||
}
|
||||
}
|
||||