mirror of
https://github.com/tig-pool-nk/tig-monorepo.git
synced 2026-02-21 15:17:22 +08:00
Add calc_average_distance func to vector_search.
This commit is contained in:
parent
541d5ff0d2
commit
7f3fcddc9c
@ -144,7 +144,7 @@ impl SubInstance {
|
||||
difficulty: &Difficulty,
|
||||
module: Arc<CudaModule>,
|
||||
stream: Arc<CudaStream>,
|
||||
prop: &cudaDeviceProp,
|
||||
_prop: &cudaDeviceProp,
|
||||
) -> Result<Self> {
|
||||
let num_hyperedges = difficulty.num_hyperedges;
|
||||
let target_num_nodes = difficulty.num_hyperedges; // actual number may be around 8% less
|
||||
@ -443,7 +443,7 @@ impl SubInstance {
|
||||
solution: &SubSolution,
|
||||
module: Arc<CudaModule>,
|
||||
stream: Arc<CudaStream>,
|
||||
prop: &cudaDeviceProp,
|
||||
_prop: &cudaDeviceProp,
|
||||
) -> Result<u32> {
|
||||
if solution.partition.len() != self.num_nodes as usize {
|
||||
return Err(anyhow!(
|
||||
|
||||
@ -59,7 +59,7 @@ impl Challenge {
|
||||
difficulty: &Difficulty,
|
||||
module: Arc<CudaModule>,
|
||||
stream: Arc<CudaStream>,
|
||||
prop: &cudaDeviceProp,
|
||||
_prop: &cudaDeviceProp,
|
||||
) -> Result<Self> {
|
||||
let mut rng = StdRng::from_seed(seed.clone());
|
||||
let better_than_baseline = difficulty.better_than_baseline;
|
||||
@ -168,61 +168,18 @@ impl Challenge {
|
||||
module: Arc<CudaModule>,
|
||||
stream: Arc<CudaStream>,
|
||||
prop: &cudaDeviceProp,
|
||||
) -> Result<()> {
|
||||
if solution.indexes.len() != self.difficulty.num_queries as usize {
|
||||
return Err(anyhow!(
|
||||
"Invalid number of indexes. Expected: {}, Actual: {}",
|
||||
self.difficulty.num_queries,
|
||||
solution.indexes.len()
|
||||
));
|
||||
}
|
||||
|
||||
let calc_total_distance_kernel = module.load_function("calc_total_distance")?;
|
||||
|
||||
let d_solution_indexes = stream.memcpy_stod(&solution.indexes)?;
|
||||
let mut d_total_distance = stream.alloc_zeros::<f32>(1)?;
|
||||
let mut errorflag = stream.alloc_zeros::<u32>(1)?;
|
||||
|
||||
let threads_per_block = MAX_THREADS_PER_BLOCK;
|
||||
let blocks =
|
||||
(self.difficulty.num_queries as u32 + threads_per_block - 1) / threads_per_block;
|
||||
|
||||
let cfg = LaunchConfig {
|
||||
grid_dim: (blocks, 1, 1),
|
||||
block_dim: (threads_per_block, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
|
||||
unsafe {
|
||||
stream
|
||||
.launch_builder(&calc_total_distance_kernel)
|
||||
.arg(&(self.vector_dims as u32))
|
||||
.arg(&(self.database_size as u32))
|
||||
.arg(&(self.difficulty.num_queries as u32))
|
||||
.arg(&self.d_query_vectors)
|
||||
.arg(&self.d_database_vectors)
|
||||
.arg(&d_solution_indexes)
|
||||
.arg(&mut d_total_distance)
|
||||
.arg(&mut errorflag)
|
||||
.launch(cfg)?;
|
||||
}
|
||||
|
||||
stream.synchronize()?;
|
||||
|
||||
let total_distance = stream.memcpy_dtov(&d_total_distance)?[0];
|
||||
let error_flag = stream.memcpy_dtov(&errorflag)?[0];
|
||||
|
||||
match error_flag {
|
||||
0 => {}
|
||||
1 => {
|
||||
return Err(anyhow!("Invalid index in solution"));
|
||||
}
|
||||
_ => {
|
||||
return Err(anyhow!("Unknown error code: {}", error_flag));
|
||||
}
|
||||
}
|
||||
|
||||
let avg_dist = total_distance / self.difficulty.num_queries as f32;
|
||||
) -> Result<f32> {
|
||||
let avg_dist = calc_average_distance(
|
||||
self.difficulty.num_queries,
|
||||
self.vector_dims,
|
||||
self.database_size,
|
||||
&self.d_query_vectors,
|
||||
&self.d_database_vectors,
|
||||
&solution.indexes,
|
||||
module.clone(),
|
||||
stream.clone(),
|
||||
prop,
|
||||
)?;
|
||||
if avg_dist > self.max_distance {
|
||||
return Err(anyhow!(
|
||||
"Average query vector distance is '{}'. Max dist: '{}'",
|
||||
@ -230,7 +187,73 @@ impl Challenge {
|
||||
self.max_distance
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok(avg_dist)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn calc_average_distance(
|
||||
num_queries: u32,
|
||||
vector_dims: u32,
|
||||
database_size: u32,
|
||||
d_query_vectors: &CudaSlice<f32>,
|
||||
d_database_vectors: &CudaSlice<f32>,
|
||||
indexes: &Vec<usize>,
|
||||
module: Arc<CudaModule>,
|
||||
stream: Arc<CudaStream>,
|
||||
_prop: &cudaDeviceProp,
|
||||
) -> Result<f32> {
|
||||
if indexes.len() != num_queries as usize {
|
||||
return Err(anyhow!(
|
||||
"Invalid number of indexes. Expected: {}, Actual: {}",
|
||||
num_queries,
|
||||
indexes.len()
|
||||
));
|
||||
}
|
||||
|
||||
let calc_total_distance_kernel = module.load_function("calc_total_distance")?;
|
||||
|
||||
let d_solution_indexes = stream.memcpy_stod(indexes)?;
|
||||
let mut d_total_distance = stream.alloc_zeros::<f32>(1)?;
|
||||
let mut errorflag = stream.alloc_zeros::<u32>(1)?;
|
||||
|
||||
let threads_per_block = MAX_THREADS_PER_BLOCK;
|
||||
let blocks = (num_queries + threads_per_block - 1) / threads_per_block;
|
||||
|
||||
let cfg = LaunchConfig {
|
||||
grid_dim: (blocks, 1, 1),
|
||||
block_dim: (threads_per_block, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
|
||||
unsafe {
|
||||
stream
|
||||
.launch_builder(&calc_total_distance_kernel)
|
||||
.arg(&vector_dims)
|
||||
.arg(&database_size)
|
||||
.arg(&num_queries)
|
||||
.arg(d_query_vectors)
|
||||
.arg(d_database_vectors)
|
||||
.arg(&d_solution_indexes)
|
||||
.arg(&mut d_total_distance)
|
||||
.arg(&mut errorflag)
|
||||
.launch(cfg)?;
|
||||
}
|
||||
|
||||
stream.synchronize()?;
|
||||
|
||||
let total_distance = stream.memcpy_dtov(&d_total_distance)?[0];
|
||||
let error_flag = stream.memcpy_dtov(&errorflag)?[0];
|
||||
|
||||
match error_flag {
|
||||
0 => {}
|
||||
1 => {
|
||||
return Err(anyhow!("Invalid index in solution"));
|
||||
}
|
||||
_ => {
|
||||
return Err(anyhow!("Unknown error code: {}", error_flag));
|
||||
}
|
||||
}
|
||||
|
||||
let avg_dist = total_distance / num_queries as f32;
|
||||
Ok(avg_dist)
|
||||
}
|
||||
|
||||
@ -376,7 +376,7 @@ fn find_best_insertion(
|
||||
best
|
||||
}
|
||||
|
||||
pub fn calc_baseline_routes(
|
||||
fn calc_baseline_routes(
|
||||
num_nodes: usize,
|
||||
max_capacity: i32,
|
||||
demands: &Vec<i32>,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user