Add calc_average_distance func to vector_search.

This commit is contained in:
FiveMovesAhead 2025-05-15 16:30:34 +01:00
parent 541d5ff0d2
commit 7f3fcddc9c
3 changed files with 84 additions and 61 deletions

View File

@ -144,7 +144,7 @@ impl SubInstance {
difficulty: &Difficulty,
module: Arc<CudaModule>,
stream: Arc<CudaStream>,
prop: &cudaDeviceProp,
_prop: &cudaDeviceProp,
) -> Result<Self> {
let num_hyperedges = difficulty.num_hyperedges;
let target_num_nodes = difficulty.num_hyperedges; // actual number may be around 8% less
@ -443,7 +443,7 @@ impl SubInstance {
solution: &SubSolution,
module: Arc<CudaModule>,
stream: Arc<CudaStream>,
prop: &cudaDeviceProp,
_prop: &cudaDeviceProp,
) -> Result<u32> {
if solution.partition.len() != self.num_nodes as usize {
return Err(anyhow!(

View File

@ -59,7 +59,7 @@ impl Challenge {
difficulty: &Difficulty,
module: Arc<CudaModule>,
stream: Arc<CudaStream>,
prop: &cudaDeviceProp,
_prop: &cudaDeviceProp,
) -> Result<Self> {
let mut rng = StdRng::from_seed(seed.clone());
let better_than_baseline = difficulty.better_than_baseline;
@ -168,61 +168,18 @@ impl Challenge {
module: Arc<CudaModule>,
stream: Arc<CudaStream>,
prop: &cudaDeviceProp,
) -> Result<()> {
if solution.indexes.len() != self.difficulty.num_queries as usize {
return Err(anyhow!(
"Invalid number of indexes. Expected: {}, Actual: {}",
self.difficulty.num_queries,
solution.indexes.len()
));
}
let calc_total_distance_kernel = module.load_function("calc_total_distance")?;
let d_solution_indexes = stream.memcpy_stod(&solution.indexes)?;
let mut d_total_distance = stream.alloc_zeros::<f32>(1)?;
let mut errorflag = stream.alloc_zeros::<u32>(1)?;
let threads_per_block = MAX_THREADS_PER_BLOCK;
let blocks =
(self.difficulty.num_queries as u32 + threads_per_block - 1) / threads_per_block;
let cfg = LaunchConfig {
grid_dim: (blocks, 1, 1),
block_dim: (threads_per_block, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
stream
.launch_builder(&calc_total_distance_kernel)
.arg(&(self.vector_dims as u32))
.arg(&(self.database_size as u32))
.arg(&(self.difficulty.num_queries as u32))
.arg(&self.d_query_vectors)
.arg(&self.d_database_vectors)
.arg(&d_solution_indexes)
.arg(&mut d_total_distance)
.arg(&mut errorflag)
.launch(cfg)?;
}
stream.synchronize()?;
let total_distance = stream.memcpy_dtov(&d_total_distance)?[0];
let error_flag = stream.memcpy_dtov(&errorflag)?[0];
match error_flag {
0 => {}
1 => {
return Err(anyhow!("Invalid index in solution"));
}
_ => {
return Err(anyhow!("Unknown error code: {}", error_flag));
}
}
let avg_dist = total_distance / self.difficulty.num_queries as f32;
) -> Result<f32> {
let avg_dist = calc_average_distance(
self.difficulty.num_queries,
self.vector_dims,
self.database_size,
&self.d_query_vectors,
&self.d_database_vectors,
&solution.indexes,
module.clone(),
stream.clone(),
prop,
)?;
if avg_dist > self.max_distance {
return Err(anyhow!(
"Average query vector distance is '{}'. Max dist: '{}'",
@ -230,7 +187,73 @@ impl Challenge {
self.max_distance
));
}
Ok(())
Ok(avg_dist)
}
}
pub fn calc_average_distance(
num_queries: u32,
vector_dims: u32,
database_size: u32,
d_query_vectors: &CudaSlice<f32>,
d_database_vectors: &CudaSlice<f32>,
indexes: &Vec<usize>,
module: Arc<CudaModule>,
stream: Arc<CudaStream>,
_prop: &cudaDeviceProp,
) -> Result<f32> {
if indexes.len() != num_queries as usize {
return Err(anyhow!(
"Invalid number of indexes. Expected: {}, Actual: {}",
num_queries,
indexes.len()
));
}
let calc_total_distance_kernel = module.load_function("calc_total_distance")?;
let d_solution_indexes = stream.memcpy_stod(indexes)?;
let mut d_total_distance = stream.alloc_zeros::<f32>(1)?;
let mut errorflag = stream.alloc_zeros::<u32>(1)?;
let threads_per_block = MAX_THREADS_PER_BLOCK;
let blocks = (num_queries + threads_per_block - 1) / threads_per_block;
let cfg = LaunchConfig {
grid_dim: (blocks, 1, 1),
block_dim: (threads_per_block, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
stream
.launch_builder(&calc_total_distance_kernel)
.arg(&vector_dims)
.arg(&database_size)
.arg(&num_queries)
.arg(d_query_vectors)
.arg(d_database_vectors)
.arg(&d_solution_indexes)
.arg(&mut d_total_distance)
.arg(&mut errorflag)
.launch(cfg)?;
}
stream.synchronize()?;
let total_distance = stream.memcpy_dtov(&d_total_distance)?[0];
let error_flag = stream.memcpy_dtov(&errorflag)?[0];
match error_flag {
0 => {}
1 => {
return Err(anyhow!("Invalid index in solution"));
}
_ => {
return Err(anyhow!("Unknown error code: {}", error_flag));
}
}
let avg_dist = total_distance / num_queries as f32;
Ok(avg_dist)
}

View File

@ -376,7 +376,7 @@ fn find_best_insertion(
best
}
pub fn calc_baseline_routes(
fn calc_baseline_routes(
num_nodes: usize,
max_capacity: i32,
demands: &Vec<i32>,