"Player 0x2ab284965628b39927c69855af0c7a2536b30eff submitted code vector_search/stat_filter_sigma"

This commit is contained in:
0x2ab284965628b39927c69855af0c7a2536b30eff 2025-12-31 16:53:49 +00:00
parent bd333f8f72
commit 09c8a2af24
4 changed files with 1710 additions and 1 deletion

View File

@ -156,7 +156,8 @@
// c004_a079
// c004_a080
pub mod stat_filter_sigma;
pub use stat_filter_sigma as c004_a080;
// c004_a081

View File

@ -0,0 +1,23 @@
# TIG Code Submission
## Submission Details
* **Challenge Name:** vector_search
* **Algorithm Name:** stat_filter_sigma
* **Copyright:** 2025 The Granite Labs LLC
* **Identity of Submitter:** Granite Labs LLC
* **Identity of Creator of Algorithmic Method:** Granite Labs LLC
* **Unique Algorithm Identifier (UAI):** c004_a072
## License
The files in this folder are under the following licenses:
* TIG Benchmarker Outbound License
* TIG Commercial License
* TIG Inbound Game License
* TIG Innovator Outbound Game License
* TIG Open Data License
* TIG THV Game License
Copies of the licenses can be obtained at:
https://github.com/tig-foundation/tig-monorepo/tree/main/docs/licenses

View File

@ -0,0 +1,920 @@
/*!
Copyright 2025 The Granite Labs LLC
Identity of Submitter Granite Labs LLC
Identity of Creator of Algorithmic Method Granite Labs LLC
UAI c004_a072
Licensed under the TIG Inbound Game License v3.0 or (at your option) any later
version (the "License"); you may not use this file except in compliance with the
License. You may obtain a copy of the License at
https://github.com/tig-foundation/tig-monorepo/tree/main/docs/licenses
Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the specific
language governing permissions and limitations under the License.
*/
//
// stat_filter_sigma
//
// Filtering based on Median Absolute Deviation (MAD):
// We compute the median of all L2 norms, then calculate the MAD (median of
// absolute deviations from the median). The threshold is set to:
// norm_threshold = scale_factor × MAD × 1.4826
// The factor 1.4826 scales MAD to match the standard deviation for normally
// distributed data. This makes the filter more robust to outliers compared to
// filtering methods based on mean and standard deviation, which are more
// sensitive to extreme values.
//
// Reference:
// - NIST Engineering Statistics Handbook:
// https://www.itl.nist.gov/div898/handbook/eda/section3/eda35h.htm
// - See also: https://www.itl.nist.gov/div898/handbook/eda/section3/eda356.htm
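//
// Worked example (illustrative): for norms {9, 10, 10, 11, 30} the median is 10
// and the absolute deviations {1, 0, 0, 1, 20} have median (MAD) 1, so with
// scale_factor = 2 the threshold is 2 * 1 * 1.4826 ≈ 2.97. The single outlier
// (30) barely moves the MAD, whereas it would inflate a mean/std-dev threshold.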
//
#include <float.h>
#include <math_constants.h> // defines CUDART_INF_F, CUDART_NAN_F, etc.
//-------------------- Dimension Stats --------------------------
__device__ inline void atomicMaxFloat(float* addr, float val) {
int* addr_i = reinterpret_cast<int*>(addr);
int old = *addr_i, assumed;
do {
assumed = old;
if (__int_as_float(assumed) >= val) break;
old = atomicCAS(addr_i, assumed, __float_as_int(val));
} while (assumed != old);
}
__device__ inline void atomicMinFloat(float* addr, float val) {
int* addr_i = reinterpret_cast<int*>(addr);
int old = *addr_i, assumed;
do {
assumed = old;
if (__int_as_float(assumed) <= val) break;
old = atomicCAS(addr_i, assumed, __float_as_int(val));
} while (assumed != old);
}
// Initialize out_min/out_max
extern "C" __global__ void init_minmax_kernel(
float* __restrict__ out_min,
float* __restrict__ out_max,
int dims,
float min_init, // e.g., +INF
float max_init) // e.g., -INF (or 0 if you know data is >=0)
{
int d = blockIdx.x * blockDim.x + threadIdx.x;
if (d < dims) {
out_min[d] = min_init;
out_max[d] = max_init;
}
}
// Compute per-dim min and max over all vectors
extern "C" __global__ void compute_dim_stats_kernel(
const float* __restrict__ db, // [num_vecs * dims]
float* __restrict__ out_min, // [dims]
float* __restrict__ out_max, // [dims]
int num_vecs,
int dims)
{
int d = threadIdx.x;
if (d >= dims) return; // launched with blockDim.x == dims; guard keeps out-of-range threads safe
for (int v = 0; v < num_vecs; ++v) {
float x = db[(size_t)v * dims + d];
atomicMinFloat(&out_min[d], x);
atomicMaxFloat(&out_max[d], x);
}
}
//-------------------- Calculate Dimension Divisors -------------
// Build per-dimension divisors from min/max.
// Scale the min/max down so we throw away outliers.
#ifndef FRAC_OF_MIN_MAX
//#define FRAC_OF_MIN_MAX 0.90f
#define FRAC_OF_MIN_MAX 0.80f
#endif
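// Example (illustrative): with dim_min = 0, dim_max = 32 and FRAC_OF_MIN_MAX = 0.80,
// the range is trimmed to [3.2, 28.8]; the u4 step becomes 25.6 / 16 = 1.6, and
// values outside the trimmed range clamp to codes 0 or 15 during quantization.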
extern "C" __global__ void build_u4_divisors_from_minmax_kernel(
float* __restrict__ dim_min, // [dims]
float* __restrict__ dim_max, // [dims]
float* __restrict__ s, // [dims]
int dims)
{
int d = blockIdx.x * blockDim.x + threadIdx.x;
if (d >= dims) return;
float mn = dim_min[d];
float mx = dim_max[d];
float range = mx - mn;
if (!isfinite(range) || range <= 0.0f) {
// Constant or degenerate dim: mark with s[d] = 0 to signal "constant"
s[d] = 0.0f;
return;
}
// Shrink to the central FRAC_OF_MIN_MAX of the range
float mid = 0.5f * (mx + mn);
float half = 0.5f * FRAC_OF_MIN_MAX * range;
mn = mid - half;
mx = mid + half;
// Write back the trimmed bounds so quantization uses them too
dim_min[d] = mn;
dim_max[d] = mx;
// Normal scale: map (trimmed) range into ~16 buckets
float trimmed_range = mx - mn; // == FRAC_OF_MIN_MAX * original range
float step = trimmed_range / 16.0f;
s[d] = step;
}
extern "C" __global__ void build_u2_divisors_from_minmax_kernel(
float* __restrict__ dim_min, // [dims]
float* __restrict__ dim_max, // [dims]
float* __restrict__ s, // [dims]
int dims)
{
int d = blockIdx.x * blockDim.x + threadIdx.x;
if (d >= dims) return;
float mn = dim_min[d];
float mx = dim_max[d];
float range = mx - mn;
if (!isfinite(range) || range <= 0.0f) {
// Constant or degenerate dim: mark with s[d] = 0 to signal "constant"
s[d] = 0.0f;
return;
}
// Same symmetric shrink
float mid = 0.5f * (mx + mn);
float half = 0.5f * FRAC_OF_MIN_MAX * range;
mn = mid - half;
mx = mid + half;
dim_min[d] = mn;
dim_max[d] = mx;
float trimmed_range = mx - mn; // FRAC_OF_MIN_MAX * original
float step = trimmed_range / 4.0f; // 4 levels for u2
s[d] = step;
}
//-------------------- Dimension Aware Conversion ---------------
// Packs two 4-bit codes per byte: even dim -> low nibble, odd dim -> high nibble.
// out size per row = (dims + 1) >> 1 bytes.
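// Example: codes q0 = 3 (0b0011, even dim) and q1 = 12 (0b1100, odd dim)
// pack to (12 << 4) | 3 = 0xC3.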
extern "C" __global__ void f32_to_u4_packed_perdim_kernel(
const float* __restrict__ in, // [num_vecs * dims], original floats
const float* __restrict__ dim_min, // [dims], per-dim min
const float* __restrict__ s, // [dims], per-dim step = range/16 (or 0)
uint8_t* __restrict__ out, // [num_vecs * ((dims+1)>>1)], packed u4
int num_vecs,
int dims)
{
int row_bytes = (dims + 1) >> 1; // 2 dims per byte
int bi = blockIdx.x * blockDim.x + threadIdx.x;
int total_bytes = num_vecs * row_bytes;
if (bi >= total_bytes) return;
int v = bi / row_bytes; // vector id
int b = bi % row_bytes; // byte index within row
int j0 = (b << 1); // even dim
int j1 = j0 + 1; // odd dim
const float* vin = in + (size_t)v * dims;
// ---- Dim j0 -> low nibble ----
int q0 = 0;
if (j0 < dims) {
float x0 = vin[j0]; // original value (can be negative)
float mn0 = dim_min[j0];
float sj0 = s[j0]; // step = (max-min)/16 or 0
if (sj0 <= 0.0f || !isfinite(sj0)) {
// Degenerate / constant dimension: treat as uninformative, code = 0
q0 = 0;
} else {
float t0 = (x0 - mn0) / sj0; // in ~[0,16]
int q0_lin = __float2int_rn(t0); // 4-bit linear bin
q0 = max(0, min(15, q0_lin));
}
}
// ---- Dim j1 -> high nibble ----
int q1 = 0;
if (j1 < dims) {
float x1 = vin[j1];
float mn1 = dim_min[j1];
float sj1 = s[j1];
if (sj1 <= 0.0f || !isfinite(sj1)) {
q1 = 0;
} else {
float t1 = (x1 - mn1) / sj1;
int q1_lin = __float2int_rn(t1);
q1 = max(0, min(15, q1_lin));
}
}
uint8_t nibble0 = (uint8_t)(q0 & 0x0F); // low nibble
uint8_t nibble1 = (uint8_t)(q1 & 0x0F); // high nibble
out[(size_t)v * row_bytes + b] = (uint8_t)((nibble1 << 4) | nibble0);
}
// Packs four 2-bit codes per byte: dims j0..j3 -> bits [1:0], [3:2], [5:4], [7:6].
// out size per row = (dims + 3) >> 2 bytes.
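// Example: codes {1, 2, 3, 0} for dims j0..j3 pack to
// 0b00_11_10_01 = 0x39 (j0 occupies the lowest 2-bit field).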
extern "C" __global__ void f32_to_u2_packed_perdim_kernel(
const float* __restrict__ in, // [num_vecs * dims], original floats
const float* __restrict__ dim_min, // [dims], per-dim min
const float* __restrict__ s, // [dims], per-dim step = range/4 or 0
uint8_t* __restrict__ out, // [num_vecs * ((dims+3)>>2)], packed u2
int num_vecs,
int dims)
{
int row_bytes = (dims + 3) >> 2; // 4 dims per byte
int bi = blockIdx.x * blockDim.x + threadIdx.x;
int total_bytes = num_vecs * row_bytes;
if (bi >= total_bytes) return;
int v = bi / row_bytes; // vector id
int b = bi % row_bytes; // byte index within row
int j0 = (b << 2); // first dim for this byte
const float* vin = in + (size_t)v * dims;
uint8_t packed = 0;
#pragma unroll
for (int k = 0; k < 4; ++k) {
int j = j0 + k;
int q = 0;
if (j < dims) {
float x = vin[j]; // original value
float mn = dim_min[j];
float sj = s[j]; // step = (max-min)/4 or 0
if (sj <= 0.0f || !isfinite(sj)) {
// Degenerate scale: constant dimension -> uninformative
q = 0;
} else {
float t = (x - mn) / sj; // ~[0,4]
int q_lin = __float2int_rn(t); // 0..3
q = max(0, min(3, q_lin));
}
}
packed |= (uint8_t)((q & 0x3) << (2 * k)); // 2 bits each
}
out[(size_t)v * row_bytes + b] = packed;
}
//----------------- Vector Stats Before Conversion ---------------
extern "C" __global__ void compute_vector_stats_kernel(
const float* vectors,
float* norm_l2,
float* norm_l2_squared,
int num_vectors,
const int vector_size
)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;
double norm_sq = 0.0;
if (i < num_vectors) {
int idx = i * vector_size;
for (int j = 0; j < vector_size; j++) {
double v = vectors[idx + j];
norm_sq = fma(v, v, norm_sq); // double-precision fma; fmaf here would truncate to float
}
norm_l2_squared[i] = (float)norm_sq;
norm_l2[i] = (float)sqrt(norm_sq);
}
}
//----------------- Vector Stats After Conversion ---------------
extern "C" __global__ void compute_vector_stats_u4_packed_kernel(
const uint8_t* __restrict__ vectors_packed, // [num_vecs * ((dims+1)>>1)]
float* __restrict__ norm_l2, // [num_vecs]
float* __restrict__ norm_l2_squared, // [num_vecs]
int num_vecs,
int dims)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= num_vecs) return;
const int row_bytes = (dims + 1) >> 1; // 2 dims per byte
const uint8_t* row = vectors_packed + (size_t)i * row_bytes;
double acc = 0.0;
int j = 0;
// Process full bytes
for (int by = 0; by < row_bytes; ++by) {
uint8_t b = row[by];
// low nibble -> dim j
if (j < dims) {
double v = (double)(b & 0x0Fu);
acc = fma(v, v, acc);
++j;
}
// high nibble -> dim j
if (j < dims) {
double v = (double)(b >> 4);
acc = fma(v, v, acc);
++j;
}
}
float accf = (float)acc;
norm_l2_squared[i] = accf;
norm_l2[i] = sqrtf(accf);
}
extern "C" __global__ void compute_vector_stats_u2_packed_kernel(
const uint8_t* __restrict__ vectors_packed, // [num_vecs * ((dims+3)>>2)]
float* __restrict__ norm_l2, // [num_vecs]
float* __restrict__ norm_l2_squared, // [num_vecs]
int num_vecs,
int dims)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= num_vecs) return;
const int row_bytes = (dims + 3) >> 2; // 4 dims per byte
const uint8_t* row = vectors_packed + (size_t)i * row_bytes;
double acc = 0.0;
int j = 0;
for (int by = 0; by < row_bytes; ++by) {
uint8_t b = row[by];
// dim j (bits 1:0)
if (j < dims) {
double v = (double)((b ) & 0x3u);
acc = fma(v, v, acc);
++j;
}
// dim j+1 (bits 3:2)
if (j < dims) {
double v = (double)((b >> 2) & 0x3u);
acc = fma(v, v, acc);
++j;
}
// dim j+2 (bits 5:4)
if (j < dims) {
double v = (double)((b >> 4) & 0x3u);
acc = fma(v, v, acc);
++j;
}
// dim j+3 (bits 7:6)
if (j < dims) {
double v = (double)((b >> 6) & 0x3u);
acc = fma(v, v, acc);
++j;
}
}
float accf = (float)acc;
norm_l2_squared[i] = accf;
norm_l2[i] = sqrtf(accf);
}
//
//----------------- Nearest Neighbor Search ---------------------
//
#ifndef KMAX
#define KMAX 64
#endif
__device__ __forceinline__ void topk_try_insert(float d, int i, float* best_d, int* best_i, int K) {
if (d >= best_d[K-1]) return;
int pos = K-1;
while (pos > 0 && d < best_d[pos-1]) {
best_d[pos] = best_d[pos-1];
best_i[pos] = best_i[pos-1];
--pos;
}
best_d[pos] = d; best_i[pos] = i;
}
//------------------- 4-BIT bit-sliced -------------------------
extern "C" __global__ void u4_packed_to_bitplanes_rowwise(
const uint8_t* __restrict__ packed, // [num_vecs][(D+1)>>1] ; 2 dims/byte (lo nibble, hi nibble)
unsigned long long* __restrict__ out_b0, // [num_vecs][W] ; bit 0 plane
unsigned long long* __restrict__ out_b1, // [num_vecs][W] ; bit 1 plane
unsigned long long* __restrict__ out_b2, // [num_vecs][W] ; bit 2 plane
unsigned long long* __restrict__ out_b3, // [num_vecs][W] ; bit 3 plane (MSB)
int num_vecs, int D, int W)
{
int v = blockIdx.x * blockDim.x + threadIdx.x;
if (v >= num_vecs) return;
const uint8_t* row = packed + (size_t)v * ((D + 1) >> 1);
for (int w = 0; w < W; ++w) {
unsigned long long b0 = 0ULL, b1 = 0ULL, b2 = 0ULL, b3 = 0ULL;
int j_base = w << 6; // 64 dims per 64b word
#pragma unroll
for (int t = 0; t < 64; ++t) {
int j = j_base + t;
if (j >= D) break;
int by = j >> 1; // 2 dims per byte
uint8_t code = (j & 1)
? (row[by] >> 4) & 0xF // high nibble
: (row[by] ) & 0xF; // low nibble
if (code & 0x1) b0 |= (1ULL << t);
if (code & 0x2) b1 |= (1ULL << t);
if (code & 0x4) b2 |= (1ULL << t);
if (code & 0x8) b3 |= (1ULL << t);
}
out_b0[(size_t)v * W + w] = b0;
out_b1[(size_t)v * W + w] = b1;
out_b2[(size_t)v * W + w] = b2;
out_b3[(size_t)v * W + w] = b3;
}
}
// 4-bit bit-sliced Top-K kernel
extern "C" __global__ void find_topk_neighbors_u4_bitsliced_kernel(
const unsigned long long* __restrict__ q0, // [M][W]
const unsigned long long* __restrict__ q1, // [M][W]
const unsigned long long* __restrict__ q2, // [M][W]
const unsigned long long* __restrict__ q3, // [M][W]
const unsigned long long* __restrict__ x0, // [N][W]
const unsigned long long* __restrict__ x1, // [N][W]
const unsigned long long* __restrict__ x2, // [N][W]
const unsigned long long* __restrict__ x3, // [N][W]
const float* __restrict__ norm_l2, // [N] (bin-space)
const float* __restrict__ norm_l2_squared, // [N] (bin-space)
int* __restrict__ topk_indices, // [M*K]
//float* __restrict__ topk_distances, // [M*K] // Not needed in upper level
const int K,
const float max_distance,
const int vector_database_len, // N
const int query_vectors_len, // M
const int vector_size, // D
const float precomputed_threshold,
const float* __restrict__ query_norm_l2, // [M] (bin-space)
const float* __restrict__ query_norm_l2_squared, // [M] (bin-space)
const int W // words per plane
)
{
int q = blockIdx.x;
if (q >= query_vectors_len) return;
if (K > KMAX) return;
// shared: per-thread heaps + query planes
extern __shared__ unsigned char smem[];
int* sm_idx = (int*)smem;
float* sm_dist = (float*)(sm_idx + blockDim.x * K);
unsigned long long* sm_q0 = (unsigned long long*)(sm_dist + blockDim.x * K);
unsigned long long* sm_q1 = sm_q0 + W;
unsigned long long* sm_q2 = sm_q1 + W;
unsigned long long* sm_q3 = sm_q2 + W;
__shared__ float norm_threshold;
__shared__ float query_norm, query_norm_sq;
__shared__ unsigned long long tail_mask;
if (threadIdx.x == 0) {
norm_threshold = precomputed_threshold;
query_norm_sq = query_norm_l2_squared[q];
query_norm = query_norm_l2[q];
int tail_bits = vector_size & 63;
tail_mask = (tail_bits == 0) ? 0xFFFFFFFFFFFFFFFFULL : ((1ULL << tail_bits) - 1ULL);
}
__syncthreads();
// load query bitplanes into shared
const unsigned long long* Q0 = q0 + (size_t)q * W;
const unsigned long long* Q1 = q1 + (size_t)q * W;
const unsigned long long* Q2 = q2 + (size_t)q * W;
const unsigned long long* Q3 = q3 + (size_t)q * W;
for (int w = threadIdx.x; w < W; w += blockDim.x) {
unsigned long long m = (w == W-1) ? tail_mask : 0xFFFFFFFFFFFFFFFFULL;
sm_q0[w] = Q0[w] & m;
sm_q1[w] = Q1[w] & m;
sm_q2[w] = Q2[w] & m;
sm_q3[w] = Q3[w] & m;
}
__syncthreads();
// thread-local top-K
float tk_dist[KMAX];
int tk_idx[KMAX];
#pragma unroll
for (int t = 0; t < K; ++t) { tk_dist[t] = CUDART_INF_F; tk_idx[t] = -1; }
// scan DB rows owned by this thread
for (int i = threadIdx.x; i < vector_database_len; i += blockDim.x) {
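// Norm prefilter: by the reverse triangle inequality, ||x - q|| >= | ||x|| - ||q|| |
// (in the same bin space), so rows whose norm differs from the query norm by more
// than the MAD-derived threshold are skipped. The pruning is exact only when the
// true neighbor falls within the threshold; the cut trades recall for speed.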
float norm_diff = fabsf(norm_l2[i] - query_norm);
if (norm_diff > norm_threshold) continue;
const unsigned long long* X0 = x0 + (size_t)i * W;
const unsigned long long* X1 = x1 + (size_t)i * W;
const unsigned long long* X2 = x2 + (size_t)i * W;
const unsigned long long* X3 = x3 + (size_t)i * W;
int c00=0,c01=0,c02=0,c03=0,
c10=0,c11=0,c12=0,c13=0,
c20=0,c21=0,c22=0,c23=0,
c30=0,c31=0,c32=0,c33=0;
#pragma unroll
for (int w = 0; w < W; ++w) {
unsigned long long m = (w == W-1) ? tail_mask : 0xFFFFFFFFFFFFFFFFULL;
unsigned long long q0w = sm_q0[w];
unsigned long long q1w = sm_q1[w];
unsigned long long q2w = sm_q2[w];
unsigned long long q3w = sm_q3[w];
unsigned long long x0w = X0[w] & m;
unsigned long long x1w = X1[w] & m;
unsigned long long x2w = X2[w] & m;
unsigned long long x3w = X3[w] & m;
c00 += __popcll(q0w & x0w);
c01 += __popcll(q0w & x1w);
c02 += __popcll(q0w & x2w);
c03 += __popcll(q0w & x3w);
c10 += __popcll(q1w & x0w);
c11 += __popcll(q1w & x1w);
c12 += __popcll(q1w & x2w);
c13 += __popcll(q1w & x3w);
c20 += __popcll(q2w & x0w);
c21 += __popcll(q2w & x1w);
c22 += __popcll(q2w & x2w);
c23 += __popcll(q2w & x3w);
c30 += __popcll(q3w & x0w);
c31 += __popcll(q3w & x1w);
c32 += __popcll(q3w & x2w);
c33 += __popcll(q3w & x3w);
}
// dot = Σ_{i=0..3} Σ_{j=0..3} 2^(i+j) * cij
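// (With a = Σ_i 2^i·a_i and b = Σ_j 2^j·b_j per dimension, the dot product
// expands to Σ_{i,j} 2^(i+j)·popcount(plane_i(a) & plane_j(b)), which yields
// the weights 1, 2, 4, 8, 16, 32, 64 below.)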
int dot_i =
(1 * c00)
+ (2 * (c01 + c10))
+ (4 * (c02 + c20 + c11))
+ (8 * (c03 + c30 + c12 + c21))
+ (16 * (c13 + c31 + c22))
+ (32 * (c23 + c32))
+ (64 * c33);
float dot = (float)dot_i;
float d2 = query_norm_sq + norm_l2_squared[i] - 2.0f * dot;
d2 = fmaxf(d2, 0.0f);
topk_try_insert(d2, i, tk_dist, tk_idx, K);
}
// spill & merge per-thread candidates
int base = threadIdx.x * K;
#pragma unroll
for (int t = 0; t < K; ++t) {
sm_idx [base + t] = tk_idx[t];
sm_dist[base + t] = tk_dist[t];
}
__syncthreads();
if (threadIdx.x == 0) {
float best_d[KMAX];
int best_i[KMAX];
#pragma unroll
for (int t = 0; t < K; ++t) { best_d[t] = CUDART_INF_F; best_i[t] = -1; }
int Nspill = blockDim.x * K;
for (int n = 0; n < Nspill; ++n) {
float d = sm_dist[n];
int i = sm_idx[n];
if (i >= 0 && isfinite(d)) topk_try_insert(d, i, best_d, best_i, K);
}
for (int a = 0; a < K-1; ++a)
for (int b = a+1; b < K; ++b)
if (best_d[b] < best_d[a]) {
float td=best_d[a]; best_d[a]=best_d[b]; best_d[b]=td;
int ti=best_i[a]; best_i[a]=best_i[b]; best_i[b]=ti;
}
int out = q * K;
for (int t = 0; t < K; ++t) {
topk_indices [out + t] = best_i[t];
//topk_distances[out + t] = best_d[t];
}
}
}
//------------------- 2-BIT bit-sliced -------------------------
// packed: 4 dims per byte, low→high 2b fields
extern "C" __global__ void u2_packed_to_bitplanes_rowwise(
const uint8_t* __restrict__ packed, // [num_vecs][(D+3)>>2]
unsigned long long* __restrict__ out_b0, // [num_vecs][W]
unsigned long long* __restrict__ out_b1, // [num_vecs][W]
int num_vecs, int D, int W)
{
int v = blockIdx.x * blockDim.x + threadIdx.x;
if (v >= num_vecs) return;
const uint8_t* row = packed + (size_t)v * ((D + 3) >> 2);
unsigned long long* b0 = out_b0 + (size_t)v * W;
unsigned long long* b1 = out_b1 + (size_t)v * W;
for (int w = 0; w < W; ++w) {
unsigned long long word0 = 0ULL, word1 = 0ULL;
int j_base = w << 6; // 64 dims per word
for (int t = 0; t < 64; ++t) {
int j = j_base + t;
if (j >= D) break;
int by = j >> 2; // 4 dims per byte
int off = (j & 3) << 1; // 2 bits per dim
uint8_t code = (row[by] >> off) & 0x3;
if (code & 0x1) word0 |= (1ULL << t);
if (code & 0x2) word1 |= (1ULL << t);
}
b0[w] = word0;
b1[w] = word1;
}
}
// Each vector uses two bitplanes:
// - plane 0 (LSB): B0, plane 1 (MSB): B1
// Layout per set (queries or DB): [num_vecs][words_per_plane],
// where words_per_plane = ceil(vector_size / 64.0).
// For SIFT-128 → words_per_plane = 2.
extern "C" __global__ void find_topk_neighbors_u2_bitsliced_kernel(
const unsigned long long* __restrict__ query_b0, // [M][W]
const unsigned long long* __restrict__ query_b1, // [M][W]
const unsigned long long* __restrict__ db_b0, // [N][W]
const unsigned long long* __restrict__ db_b1, // [N][W]
const float* __restrict__ norm_l2, // [N] (bin-space norms, sqrt)
const float* __restrict__ norm_l2_squared, // [N] (bin-space norms^2)
int* __restrict__ topk_indices, // [M*K]
//float* __restrict__ topk_distances, // [M*K]
const int K,
const float max_distance,
const int vector_database_len, // N
const int query_vectors_len, // M
const int vector_size, // D
const float precomputed_threshold,
const float* __restrict__ query_norm_l2, // [M] (bin-space)
const float* __restrict__ query_norm_l2_squared, // [M] (bin-space)
const int words_per_plane // W = (D+63)>>6
)
{
const int q = blockIdx.x; // one query per block
if (q >= query_vectors_len) return;
if (K > KMAX) return; // or assert
// Shared: per-thread candidate spill + query bitplanes
extern __shared__ unsigned char smem[];
int* sm_idx = (int*)smem;
float* sm_dist = (float*)(sm_idx + blockDim.x * K);
// Place query planes after the per-thread scratch:
unsigned long long* sm_q0 = (unsigned long long*)(sm_dist + blockDim.x * K);
unsigned long long* sm_q1 = sm_q0 + words_per_plane;
__shared__ float norm_threshold;
__shared__ float query_norm_sq;
__shared__ float query_norm;
__shared__ unsigned long long tail_mask;
if (threadIdx.x == 0) {
norm_threshold = precomputed_threshold;
query_norm_sq = query_norm_l2_squared[q];
query_norm = query_norm_l2[q];
int tail_bits = vector_size & 63; // D % 64
tail_mask = (tail_bits == 0) ? 0xFFFFFFFFFFFFFFFFULL
: ((1ULL << tail_bits) - 1ULL);
}
__syncthreads();
// Load query bitplanes into shared (rowwise)
const unsigned long long* q0 = query_b0 + (size_t)q * words_per_plane;
const unsigned long long* q1 = query_b1 + (size_t)q * words_per_plane;
for (int w = threadIdx.x; w < words_per_plane; w += blockDim.x) {
unsigned long long q0w = q0[w];
unsigned long long q1w = q1[w];
// Mask tail word to ensure no stray bits are counted
if (w == words_per_plane - 1) {
q0w &= tail_mask;
q1w &= tail_mask;
}
sm_q0[w] = q0w;
sm_q1[w] = q1w;
}
__syncthreads();
// Thread-local Top-K
float tk_dist[KMAX];
int tk_idx[KMAX];
#pragma unroll
for (int t = 0; t < K; ++t) { tk_dist[t] = CUDART_INF_F; tk_idx[t] = -1; }
// Scan DB rows owned by this thread
for (int i = threadIdx.x; i < vector_database_len; i += blockDim.x) {
// -- Rowwise --
// Norm prefilter in the SAME (bin) space
float norm_diff = fabsf(norm_l2[i] - query_norm);
if (norm_diff > norm_threshold) continue;
const unsigned long long* x0 = db_b0 + (size_t)i * words_per_plane;
const unsigned long long* x1 = db_b1 + (size_t)i * words_per_plane;
int c00 = 0, c01 = 0, c10 = 0, c11 = 0;
// Two 64b words for SIFT-128 (general W supported)
#pragma unroll
for (int w = 0; w < words_per_plane; ++w) {
// Mask tail for the last word
unsigned long long mask = (w == words_per_plane - 1) ? tail_mask : 0xFFFFFFFFFFFFFFFFULL;
unsigned long long qw0 = sm_q0[w];
unsigned long long qw1 = sm_q1[w];
unsigned long long xw0 = x0[w] & mask;
unsigned long long xw1 = x1[w] & mask;
c00 += __popcll(qw0 & xw0);
c01 += __popcll(qw0 & xw1);
c10 += __popcll(qw1 & xw0);
c11 += __popcll(qw1 & xw1);
}
// 2-bit dot product in bin space
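// With a = 2·a1 + a0 and b = 2·b1 + b0 per dim: a·b = a0·b0 + 2·(a0·b1 + a1·b0) + 4·a1·b1.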
int dot_i = c00 + 2 * (c01 + c10) + 4 * c11;
float dot = (float)dot_i;
float d2 = query_norm_sq + norm_l2_squared[i] - 2.0f * dot;
d2 = fmaxf(d2, 0.0f); // robust to tiny underflow
topk_try_insert(d2, i, tk_dist, tk_idx, K);
} // end for i
// Spill per-thread candidates
int base = threadIdx.x * K;
#pragma unroll
for (int t = 0; t < K; ++t) {
sm_idx [base + t] = tk_idx[t];
sm_dist[base + t] = tk_dist[t];
}
__syncthreads();
// Merge to block top-K
if (threadIdx.x == 0) {
float best_d[KMAX];
int best_i[KMAX];
#pragma unroll
for (int t = 0; t < K; ++t) { best_d[t] = CUDART_INF_F; best_i[t] = -1; }
int Nspill = blockDim.x * K;
for (int n = 0; n < Nspill; ++n) {
float d = sm_dist[n];
int i = sm_idx[n];
if (i >= 0 && isfinite(d)) topk_try_insert(d, i, best_d, best_i, K);
}
// Optional small sort for tidy output
for (int a = 0; a < K-1; ++a)
for (int b = a+1; b < K; ++b)
if (best_d[b] < best_d[a]) {
float td=best_d[a]; best_d[a]=best_d[b]; best_d[b]=td;
int ti=best_i[a]; best_i[a]=best_i[b]; best_i[b]=ti;
}
int out_base = q * K;
for (int t = 0; t < K; ++t) {
topk_indices [out_base + t] = best_i[t];
//topk_distances[out_base + t] = best_d[t];
}
}
}
//-------------------- Rerank Top K --------------------------
extern "C" __global__ void refine_topk_rerank_kernel(
const float* __restrict__ query_vectors, // [num_queries * dim]
const float* __restrict__ db_vectors, // [db_len * dim]
const int* __restrict__ candidates, // [num_queries * K]
int* __restrict__ out_index, // [num_queries]
float* __restrict__ out_distance, // [num_queries] (squared L2)
const int num_queries,
const int dim,
const int K
)
{
int q = blockIdx.x;
if (q >= num_queries) return;
extern __shared__ unsigned char shared[];
float* sm_q = reinterpret_cast<float*>(shared);
float* red = sm_q + dim; // reduction buffer, length = blockDim.x
// Cache query vector into shared memory
for (int j = threadIdx.x; j < dim; j += blockDim.x) {
sm_q[j] = query_vectors[q * dim + j];
}
__syncthreads();
float best_d = FLT_MAX;
int best_i = -1;
// For each candidate, compute exact squared L2 distance in parallel
for (int t = 0; t < K; ++t) {
int db_idx = candidates[q * K + t];
if (db_idx < 0) continue;
const float* db = &db_vectors[db_idx * dim];
// Partial sum over dimensions (strided by thread)
float sum = 0.0f;
for (int j = threadIdx.x; j < dim; j += blockDim.x) {
float diff = sm_q[j] - db[j];
sum = fmaf(diff, diff, sum);
}
// Block-wide reduction into red[0]
red[threadIdx.x] = sum;
__syncthreads();
for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
if (threadIdx.x < stride) {
red[threadIdx.x] += red[threadIdx.x + stride];
}
__syncthreads();
}
if (threadIdx.x == 0) {
float d = red[0];
if (d < best_d) { best_d = d; best_i = db_idx; }
}
__syncthreads();
}
if (threadIdx.x == 0) {
out_index[q] = best_i;
out_distance[q] = best_d;
}
}

View File

@ -0,0 +1,765 @@
// TIG's UI uses the pattern `tig_challenges::<challenge_name>` to automatically detect your algorithm's challenge.
// When launching kernels, do not exceed MAX_THREADS_PER_BLOCK (defined below), or else results may not be deterministic.
//
// stat_filter_sigma
//
// Filtering based on Median Absolute Deviation (MAD):
// We compute the median of all L2 norms, then calculate the MAD (median of
// absolute deviations from the median). The threshold is set to:
// norm_threshold = scale_factor × MAD × 1.4826
// The factor 1.4826 scales MAD to match the standard deviation for normally
// distributed data. This makes the filter more robust to outliers compared to
// filtering methods based on mean and standard deviation, which are more
// sensitive to extreme values.
//
// Reference:
// - NIST Engineering Statistics Handbook:
// https://www.itl.nist.gov/div898/handbook/eda/section3/eda35h.htm
// - See also: https://www.itl.nist.gov/div898/handbook/eda/section3/eda356.htm
//
//use crate::{seeded_hasher, HashMap, HashSet};
use anyhow::{anyhow, Result};
use cudarc::driver::PushKernelArg;
use cudarc::{
driver::{CudaModule, CudaStream, LaunchConfig},
runtime::sys::cudaDeviceProp,
};
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use tig_challenges::vector_search::*;
// use std::env;
const MAD_SCALE_NORMAL: f32 = 1.4826;
const MAX_THREADS_PER_BLOCK: u32 = 1024;
// Default K for Top-K retrieval (must be <= kernel KMAX)
pub const DEFAULT_TOP_K: usize = 10;
// Maximum K supported by the CUDA kernels (must match KMAX in kernels.cu).
const TOPK_MAX: usize = 64;
//pub const DEFAULT_TOP_K: usize = 25;
//pub const DEFAULT_TOP_K: usize = 40;
// Default bit mode (4 or 2)
//pub const DEFAULT_BIT_MODE: usize = 4;
#[derive(Copy, Clone, Debug)]
enum BitMode {
U2,
U4,
}
impl BitMode {
// #[inline]
// fn from_env() -> Self {
// match std::env::var("STATFILT_BIT_MODE").ok().and_then(|s| s.trim().parse::<usize>().ok()) {
// Some(2) => BitMode::U2,
// _ => BitMode::U4, // default
// }
// }
#[inline]
fn bits(self) -> usize {
match self {
BitMode::U2 => 2,
BitMode::U4 => 4,
}
}
#[inline]
fn planes(self) -> usize {
self.bits()
} // one bit-plane per bit
#[inline]
fn as_str(self) -> &'static str {
match self {
BitMode::U2 => "2",
BitMode::U4 => "4",
}
}
#[inline]
fn k(self, template: &str) -> String {
// Replace "{b}" with "2" or "4" in a kernel name template.
template.replace("{b}", self.as_str())
}
}
//
// Hyperparameter configuration passed from TIG (MAD scale, bit mode, internal top-k).
//
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
struct Hyperparameters {
/// Internal top-k used during the bitsliced search and refinement.
top_k: usize,
/// Number of bits per dimension to use (2 or 4).
bit_mode: usize,
/// MAD scale factor. Values in (0, 5) enable MAD filtering,
/// values >= 5 disable it. A value of 0.0 enables "legacy auto-MAD"
/// mode, which computes the scale dynamically based on num_queries.
mad_scale: f32,
}
impl Default for Hyperparameters {
fn default() -> Self {
Self {
top_k: DEFAULT_TOP_K, // 10
bit_mode: 2, // 2-bit
// 5.0 effectively disables MAD filtering with the current threshold logic.
mad_scale: 5.0,
}
}
}
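// For reference (shape assumed from the struct above): a hyperparameter map like
// {"top_k": 10, "bit_mode": 4, "mad_scale": 1.0}
// selects the 4-bit path with a tighter MAD cut; fields omitted from the map
// fall back to the defaults via #[serde(default)].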
//
// ================ Solve Challenge Function ================
pub fn solve_challenge(
challenge: &Challenge,
save_solution: &dyn Fn(&Solution) -> anyhow::Result<()>,
hyperparameters: &Option<Map<String, Value>>,
module: Arc<CudaModule>,
stream: Arc<CudaStream>,
prop: &cudaDeviceProp,
) -> anyhow::Result<()> {
// Parse hyperparameters (if provided) or fall back to defaults.
let hparams: Hyperparameters = match hyperparameters {
Some(hyperparameters) => {
serde_json::from_value::<Hyperparameters>(Value::Object(hyperparameters.clone()))
.map_err(|e| anyhow!("Failed to parse hyperparameters: {}", e))?
}
None => Hyperparameters::default(),
};
// Validate bit_mode; we only support 2-bit and 4-bit variants.
if hparams.bit_mode != 2 && hparams.bit_mode != 4 {
return Err(anyhow!(
"Invalid bit_mode: {}. Must be 2 or 4.",
hparams.bit_mode
));
}
println!("Searching {} DB vectors of length {} for {} queries",challenge.database_size,challenge.vector_dims,challenge.difficulty.num_queries);
// let start_time_total = std::time::Instant::now();
// Get top-k value to use (hyperparameter, clamped to what the kernels support)
let mut topk = hparams.top_k.max(1);
let max_topk = std::cmp::min(challenge.database_size as usize, TOPK_MAX);
if topk > max_topk {
topk = max_topk;
}
// Get bit mode value to use from hyperparameters (2-bit or 4-bit)
let mode = match hparams.bit_mode {
2 => BitMode::U2,
4 => BitMode::U4,
_ => unreachable!(), // validated above
};
//println!("mode = {} bits; topk = {}", mode.bits(), topk);
// Allocations for dimension statistics
let mut d_db_dim_min = stream.alloc_zeros::<f32>(challenge.vector_dims as usize)?;
let mut d_db_dim_max = stream.alloc_zeros::<f32>(challenge.vector_dims as usize)?;
let d_s = stream.alloc_zeros::<f32>(challenge.vector_dims as usize)?;
// Allocations for norms
let d_db_norm_l2 = stream.alloc_zeros::<f32>(challenge.database_size as usize)?;
let d_db_norm_l2_squared = stream.alloc_zeros::<f32>(challenge.database_size as usize)?;
let d_query_norm_l2 = stream.alloc_zeros::<f32>(challenge.difficulty.num_queries as usize)?;
let d_query_norm_l2_squared =
stream.alloc_zeros::<f32>(challenge.difficulty.num_queries as usize)?;
// Total number of elements in DB and queries
//let num_db_el = challenge.database_size * challenge.vector_dims;
//let num_qv_el = challenge.difficulty.num_queries * challenge.vector_dims;
// ---------- Packed buffers ----------
let dims = challenge.vector_dims as usize;
let n_db = challenge.database_size as usize;
let n_q = challenge.difficulty.num_queries as usize;
// Packed bytes per row for N-bit values = ceil(dims * bits / 8)
let row_bytes = (dims * mode.bits() + 7) >> 3;
let num_db_bytes = n_db * row_bytes;
let num_qv_bytes = n_q * row_bytes;
// Allocate packed outputs
let d_db_packed = stream.alloc_zeros::<u8>(num_db_bytes)?;
let d_qv_packed = stream.alloc_zeros::<u8>(num_qv_bytes)?;
/*
let d_db_packed = alloc_uninit::<u8>(&stream, num_db_bytes)?;
let d_qv_packed = alloc_uninit::<u8>(&stream, num_qv_bytes)?;
*/
// load kernels
let init_minmax_kernel = module.load_function("init_minmax_kernel")?;
let compute_dim_stats_kernel = module.load_function("compute_dim_stats_kernel")?;
// launch config for initializing per-dimension min/max (one thread per DIMENSION)
let threads_init: u32 = 256;
let blocks_init: u32 = ((challenge.vector_dims as u32) + threads_init - 1) / threads_init;
// initialize min/max arrays on device
let min_init: f32 = f32::INFINITY;
let max_init: f32 = f32::NEG_INFINITY; // or 0.0 if values are known >= 0
unsafe {
stream
.launch_builder(&init_minmax_kernel)
.arg(&mut d_db_dim_min)
.arg(&mut d_db_dim_max)
.arg(&challenge.vector_dims)
.arg(&min_init)
.arg(&max_init)
.launch(LaunchConfig {
grid_dim: (blocks_init, 1, 1),
block_dim: (threads_init, 1, 1),
shared_mem_bytes: 0,
})?;
}
// compute per-dim min & max... scan original data
unsafe {
stream
.launch_builder(&compute_dim_stats_kernel)
.arg(&challenge.d_database_vectors) // const float* db
.arg(&mut d_db_dim_min) // float* out_min
.arg(&mut d_db_dim_max) // float* out_max
.arg(&challenge.database_size) // num_vecs
.arg(&challenge.vector_dims) // dims
.launch(LaunchConfig {
grid_dim: (1, 1, 1),
block_dim: (challenge.vector_dims as u32, 1, 1),
shared_mem_bytes: 0,
})?;
}
//stream.synchronize()?;
//
// ---------- Compute Dimensional Stats ----------
//
let threads_db: u32 = 256;
let blocks_db: u32 = ((challenge.vector_dims as u32) + threads_db - 1) / threads_db;
// Calculate the per-dim divisors based on min/max
let build_divisors_from_minmax_kernel =
module.load_function(&mode.k("build_u{b}_divisors_from_minmax_kernel"))?;
let cfg_db_dm = LaunchConfig {
grid_dim: (blocks_db, 1, 1),
block_dim: (threads_db, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
stream
.launch_builder(&build_divisors_from_minmax_kernel)
.arg(&d_db_dim_min)
.arg(&d_db_dim_max)
.arg(&d_s)
.arg(&challenge.vector_dims)
.launch(cfg_db_dm)?;
}
//stream.synchronize()?;
//
// ---------- Convert input data by packing into bits ----------
//
let f32_to_packed_perdim_kernel =
module.load_function(&mode.k("f32_to_u{b}_packed_perdim_kernel"))?;
// DB
let threads_db: u32 = 256;
let blocks_db: u32 = ((num_db_bytes as u32) + threads_db - 1) / threads_db;
let cfg_db = LaunchConfig {
grid_dim: (blocks_db, 1, 1),
block_dim: (threads_db, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
stream
.launch_builder(&f32_to_packed_perdim_kernel)
.arg(&challenge.d_database_vectors) // const float* in [num_db * D]
.arg(&d_db_dim_min) // float* in_min
.arg(&d_s) // const float* s [D]
.arg(&d_db_packed) // uint8_t* out [num_db * ((D+1)>>1)]
.arg(&challenge.database_size) // num_vecs
.arg(&challenge.vector_dims) // dims
.launch(cfg_db)?;
}
// Queries
let threads_qv: u32 = 256;
let blocks_qv: u32 = ((num_qv_bytes as u32) + threads_qv - 1) / threads_qv;
let cfg_qv = LaunchConfig {
grid_dim: (blocks_qv, 1, 1),
block_dim: (threads_qv, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
stream
.launch_builder(&f32_to_packed_perdim_kernel)
.arg(&challenge.d_query_vectors) // const float* in [num_query * D]
.arg(&d_db_dim_min) // float* in_min
.arg(&d_s)
.arg(&d_qv_packed)
.arg(&challenge.difficulty.num_queries)
.arg(&challenge.vector_dims)
.launch(cfg_qv)?;
}
//stream.synchronize()?;
//
// ---------- Compute Vector Stats ----------
//
let compute_vector_stats_packed_kernel =
module.load_function(&mode.k("compute_vector_stats_u{b}_packed_kernel"))?;
let threads_per_block_stats = prop.maxThreadsPerBlock as u32;
let num_blocks_db =
(challenge.database_size + threads_per_block_stats - 1) / threads_per_block_stats;
let cfg_stats = LaunchConfig {
grid_dim: (num_blocks_db, 1, 1),
block_dim: (threads_per_block_stats, 1, 1),
shared_mem_bytes: 0,
};
// DB norms
unsafe {
stream
.launch_builder(&compute_vector_stats_packed_kernel)
.arg(&d_db_packed) // const uint8_t* packed [num_db * ((D+1)>>1)]
.arg(&d_db_norm_l2) // float* norm_l2 [num_db]
.arg(&d_db_norm_l2_squared) // float* norm_l2_sq [num_db]
.arg(&challenge.database_size) // num_vecs
.arg(&challenge.vector_dims) // dims
.launch(cfg_stats)?;
}
// Query norms
let num_blocks_qv =
(challenge.difficulty.num_queries + threads_per_block_stats - 1) / threads_per_block_stats;
let cfg_stats_qv = LaunchConfig {
grid_dim: (num_blocks_qv, 1, 1),
block_dim: (threads_per_block_stats, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
stream
.launch_builder(&compute_vector_stats_packed_kernel)
.arg(&d_qv_packed)
.arg(&d_query_norm_l2)
.arg(&d_query_norm_l2_squared)
.arg(&challenge.difficulty.num_queries)
.arg(&challenge.vector_dims)
.launch(cfg_stats_qv)?;
}
//stream.synchronize()?;
// let elapsed_time_ms_1 = start_time_total.elapsed().as_micros() as f32 / 1000.0;
//
// ---------- Compute MAD Stats ----------
//
let mut norm_threshold: f32 = f32::MAX;
// Determine MAD scale: 0.0 means "legacy auto-MAD" (compute from num_queries)
let scale = if hparams.mad_scale == 0.0 {
scale_factor(challenge.difficulty.num_queries as usize)
} else {
hparams.mad_scale
};
//println!("stat_filter scale: {}", scale);
// Only compute and apply MAD if within range
if scale > 0.0 && scale < 5.0 {
// MAD threshold on DB norms (unchanged logic)
let mut h_norms = stream.memcpy_dtov(&d_db_norm_l2)?;
h_norms.sort_by(|a, b| a.partial_cmp(b).unwrap());
let mid = h_norms.len() / 2;
let median = if h_norms.len() % 2 == 0 {
(h_norms[mid - 1] + h_norms[mid]) / 2.0
} else {
h_norms[mid]
};
let mut deviations: Vec<f32> = h_norms.iter().map(|&x| (x - median).abs()).collect();
deviations.sort_by(|a, b| a.partial_cmp(b).unwrap());
let mad = if deviations.len() % 2 == 0 {
(deviations[mid - 1] + deviations[mid]) / 2.0
} else {
deviations[mid]
};
norm_threshold = scale * mad * MAD_SCALE_NORMAL;
}
// let elapsed_time_ms_2 = start_time_total.elapsed().as_micros() as f32 / 1000.0;
//
// ---------- Search ----------
//
// --- TopK outputs ---
let mut d_topk_indices =
stream.alloc_zeros::<i32>((challenge.difficulty.num_queries as usize) * topk)?;
// Save some memory -- we don't use this output
//let mut d_topk_dist =
// stream.alloc_zeros::<f32>((challenge.difficulty.num_queries as usize) * topk)?;
// --- Geometry ---
let words_per_plane = ((dims + 63) >> 6) as usize; // W
let words_per_plane_i32 = words_per_plane as i32;
// --- Shared memory sizing for Top-K ---
// Per-thread spill for heap:
let per_thread_bytes = topk * (std::mem::size_of::<i32>() + std::mem::size_of::<f32>());
// 4 planes * W words * 8B per word
//let base_query_bytes = 4 * words_per_plane * std::mem::size_of::<u64>();
let base_query_bytes = mode.planes() * words_per_plane * std::mem::size_of::<u64>();
let smem_limit = prop.sharedMemPerBlock as usize;
let mut threads_per_block: usize = 256;
while base_query_bytes + threads_per_block * per_thread_bytes > smem_limit
&& threads_per_block > 32
{
threads_per_block >>= 1;
}
if base_query_bytes + threads_per_block * per_thread_bytes > smem_limit {
return Err(anyhow!(
"Insufficient shared memory for topk={} with dims={} (need ~{}B, have {}B)",
topk,
challenge.vector_dims,
base_query_bytes + threads_per_block * per_thread_bytes,
smem_limit
));
}
let threads_per_block = threads_per_block as u32;
let shared_mem_bytes =
(base_query_bytes + (threads_per_block as usize) * per_thread_bytes) as u32;
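// Worked example (illustrative): for SIFT-like dims = 128 in 2-bit mode with
// topk = 10, W = 2, so base_query_bytes = 2 planes * 2 words * 8 B = 32 B and
// per_thread_bytes = 10 * (4 + 4) = 80 B; at 256 threads the block needs
// 32 + 256 * 80 = 20_512 B of shared memory, comfortably under a typical 48 KiB limit.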
let cfg_topk = LaunchConfig {
grid_dim: (challenge.difficulty.num_queries, 1, 1),
block_dim: (threads_per_block, 1, 1),
shared_mem_bytes,
};
let k_i32: i32 = topk as i32;
// --- Convert packed -> bitplanes ---
let packed_to_bitplanes_rowwise = module.load_function(&mode.k("u{b}_packed_to_bitplanes_rowwise"))?;
let blk_conv = (256u32, 1u32, 1u32);
let grd_db = (((n_db as u32) + 255) / 256, 1, 1);
let grd_q = (((n_q as u32) + 255) / 256, 1, 1);
let find_topk_neighbors_bitsliced_kernel =
module.load_function(&mode.k("find_topk_neighbors_u{b}_bitsliced_kernel"))?;
// let mut elapsed_time_ms_3 = start_time_total.elapsed().as_micros() as f32 / 1000.0;
match mode {
BitMode::U2 => {
let mut d_db_b0 = stream.alloc_zeros::<u64>(n_db * words_per_plane)?;
let mut d_db_b1 = stream.alloc_zeros::<u64>(n_db * words_per_plane)?;
let mut d_q_b0 = stream.alloc_zeros::<u64>(n_q * words_per_plane)?;
let mut d_q_b1 = stream.alloc_zeros::<u64>(n_q * words_per_plane)?;
/*
let mut d_db_b0 = alloc_uninit::<u64>(&stream, n_db * words_per_plane)?;
let mut d_db_b1 = alloc_uninit::<u64>(&stream, n_db * words_per_plane)?;
let mut d_q_b0 = alloc_uninit::<u64>(&stream, n_q * words_per_plane)?;
let mut d_q_b1 = alloc_uninit::<u64>(&stream, n_q * words_per_plane)?;
*/
unsafe {
// DB
stream
.launch_builder(&packed_to_bitplanes_rowwise)
.arg(&d_db_packed)
.arg(&mut d_db_b0)
.arg(&mut d_db_b1)
.arg(&(challenge.database_size))
.arg(&(challenge.vector_dims))
.arg(&words_per_plane_i32)
.launch(LaunchConfig {
grid_dim: grd_db,
block_dim: blk_conv,
shared_mem_bytes: 0,
})?;
// Q
stream
.launch_builder(&packed_to_bitplanes_rowwise)
.arg(&d_qv_packed)
.arg(&mut d_q_b0)
.arg(&mut d_q_b1)
.arg(&(challenge.difficulty.num_queries))
.arg(&(challenge.vector_dims))
.arg(&words_per_plane_i32)
.launch(LaunchConfig {
grid_dim: grd_q,
block_dim: blk_conv,
shared_mem_bytes: 0,
})?;
}
// elapsed_time_ms_3 = start_time_total.elapsed().as_micros() as f32 / 1000.0;
// launch top-k with 2 plane args
unsafe {
stream
.launch_builder(&find_topk_neighbors_bitsliced_kernel)
.arg(&d_q_b0)
.arg(&d_q_b1)
.arg(&d_db_b0)
.arg(&d_db_b1)
.arg(&d_db_norm_l2)
.arg(&d_db_norm_l2_squared)
.arg(&mut d_topk_indices)
//.arg(&mut d_topk_dist) // Save some memory -- we don't use this output
.arg(&k_i32)
.arg(&challenge.max_distance)
.arg(&challenge.database_size)
.arg(&challenge.difficulty.num_queries)
.arg(&challenge.vector_dims)
.arg(&norm_threshold)
.arg(&d_query_norm_l2)
.arg(&d_query_norm_l2_squared)
.arg(&words_per_plane_i32)
.launch(cfg_topk)?;
}
}
BitMode::U4 => {
let mut d_db_b0 = stream.alloc_zeros::<u64>(n_db * words_per_plane)?;
let mut d_db_b1 = stream.alloc_zeros::<u64>(n_db * words_per_plane)?;
let mut d_db_b2 = stream.alloc_zeros::<u64>(n_db * words_per_plane)?;
let mut d_db_b3 = stream.alloc_zeros::<u64>(n_db * words_per_plane)?;
let mut d_q_b0 = stream.alloc_zeros::<u64>(n_q * words_per_plane)?;
let mut d_q_b1 = stream.alloc_zeros::<u64>(n_q * words_per_plane)?;
let mut d_q_b2 = stream.alloc_zeros::<u64>(n_q * words_per_plane)?;
let mut d_q_b3 = stream.alloc_zeros::<u64>(n_q * words_per_plane)?;
/*
let mut d_db_b0 = alloc_uninit::<u64>(&stream, n_db * words_per_plane)?;
let mut d_db_b1 = alloc_uninit::<u64>(&stream, n_db * words_per_plane)?;
let mut d_db_b2 = alloc_uninit::<u64>(&stream, n_db * words_per_plane)?;
let mut d_db_b3 = alloc_uninit::<u64>(&stream, n_db * words_per_plane)?;
let mut d_q_b0 = alloc_uninit::<u64>(&stream, n_q * words_per_plane)?;
let mut d_q_b1 = alloc_uninit::<u64>(&stream, n_q * words_per_plane)?;
let mut d_q_b2 = alloc_uninit::<u64>(&stream, n_q * words_per_plane)?;
let mut d_q_b3 = alloc_uninit::<u64>(&stream, n_q * words_per_plane)?;
*/
unsafe {
// DB
stream
.launch_builder(&packed_to_bitplanes_rowwise)
.arg(&d_db_packed)
.arg(&mut d_db_b0)
.arg(&mut d_db_b1)
.arg(&mut d_db_b2)
.arg(&mut d_db_b3)
.arg(&(challenge.database_size))
.arg(&(challenge.vector_dims))
.arg(&words_per_plane_i32)
.launch(LaunchConfig {
grid_dim: grd_db,
block_dim: blk_conv,
shared_mem_bytes: 0,
})?;
// Q
stream
.launch_builder(&packed_to_bitplanes_rowwise)
.arg(&d_qv_packed)
.arg(&mut d_q_b0)
.arg(&mut d_q_b1)
.arg(&mut d_q_b2)
.arg(&mut d_q_b3)
.arg(&(challenge.difficulty.num_queries))
.arg(&(challenge.vector_dims))
.arg(&words_per_plane_i32)
.launch(LaunchConfig {
grid_dim: grd_q,
block_dim: blk_conv,
shared_mem_bytes: 0,
})?;
}
// elapsed_time_ms_3 = start_time_total.elapsed().as_micros() as f32 / 1000.0;
// launch top-k with 4 plane args
unsafe {
stream
.launch_builder(&find_topk_neighbors_bitsliced_kernel)
.arg(&d_q_b0)
.arg(&d_q_b1)
.arg(&d_q_b2)
.arg(&d_q_b3)
.arg(&d_db_b0)
.arg(&d_db_b1)
.arg(&d_db_b2)
.arg(&d_db_b3)
.arg(&d_db_norm_l2)
.arg(&d_db_norm_l2_squared)
.arg(&mut d_topk_indices)
//.arg(&mut d_topk_dist) // Save some memory -- we don't use this output
.arg(&k_i32)
.arg(&challenge.max_distance)
.arg(&challenge.database_size)
.arg(&challenge.difficulty.num_queries)
.arg(&challenge.vector_dims)
.arg(&norm_threshold)
.arg(&d_query_norm_l2)
.arg(&d_query_norm_l2_squared)
.arg(&words_per_plane_i32)
.launch(cfg_topk)?;
}
}
}
// Top-K indices stay on the device; the FP32 re-rank below consumes them
// directly and yields the final Top-1 per query.
// let elapsed_time_ms_4 = start_time_total.elapsed().as_micros() as f32 / 1000.0;
//
// === Re-rank Top-K on FP32 ===
//
// NOTE: We only return the best match, not an array. This is an "internal" top-k.
//
let refine_fn = module.load_function("refine_topk_rerank_kernel")?;
let threads_refine: u32 = 128;
let grid_refine = challenge.difficulty.num_queries;
let shared_refine = (challenge.vector_dims as usize * std::mem::size_of::<f32>()
+ threads_refine as usize * std::mem::size_of::<f32>()) as u32;
let mut d_refined_index =
stream.alloc_zeros::<i32>(challenge.difficulty.num_queries as usize)?;
let mut d_refined_distance =
stream.alloc_zeros::<f32>(challenge.difficulty.num_queries as usize)?;
let cfg_refine = LaunchConfig {
grid_dim: (grid_refine, 1, 1),
block_dim: (threads_refine, 1, 1),
shared_mem_bytes: shared_refine,
};
unsafe {
stream
.launch_builder(&refine_fn)
.arg(&challenge.d_query_vectors) // Original FP32 queries
.arg(&challenge.d_database_vectors) // Original FP32 DB
.arg(&d_topk_indices) // [num_queries * K] (i32)
.arg(&mut d_refined_index) // OUT best index per query
.arg(&mut d_refined_distance) // OUT best distance per query
.arg(&challenge.difficulty.num_queries) // num_queries
.arg(&challenge.vector_dims) // original vector dim
.arg(&k_i32) // K
.launch(cfg_refine)?;
}
//stream.synchronize()?;
// Use refined Top-1 as the final Solution
let top1_refined: Vec<i32> = stream.memcpy_dtov(&d_refined_index)?;
let mut final_idxs = Vec::<usize>::with_capacity(top1_refined.len());
for &idx in &top1_refined {
final_idxs.push(idx as usize);
}
// let elapsed_time_ms = start_time_total.elapsed().as_micros() as f32 / 1000.0;
// Internal timing statistics
// println!("===== stat_filter bitslice {}-bit ( Top-{} ) =====", mode.bits(), topk);
// println!(
// "Time for nonce: {:.3} ms (sum+stats: {:.3} ms + mad_sort: {:.3} ms + slice: {:.3} ms + search: {:.3} ms + rerank {:.3} ms)",
// elapsed_time_ms,
// elapsed_time_ms_1,
// elapsed_time_ms_2 - elapsed_time_ms_1,
// elapsed_time_ms_3 - elapsed_time_ms_2,
// elapsed_time_ms_4 - elapsed_time_ms_3,
// elapsed_time_ms - elapsed_time_ms_4
// );
save_solution(&Solution {
indexes: final_idxs,
})?;
Ok(())
}
//------------ MAD Scale Factor Adjustment -------------
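// Piecewise-linear schedule; e.g. num_queries = 1200 gives
// 0.30 + 200 * (0.20 / 500) = 0.38. Note that above 2500 the scale
// drops back to 1.00 as written.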
fn scale_factor(num_queries: usize) -> f32 {
match num_queries {
q if q <= 700 => 0.20,
q if q <= 1000 => 0.20 + (q as f32 - 700.0) * (0.10 / 300.0), // 0.30 at 1000
q if q <= 1500 => 0.30 + (q as f32 - 1000.0) * (0.20 / 500.0), // 0.50 at 1500
q if q <= 2000 => 0.50 + (q as f32 - 1500.0) * (0.44 / 500.0), // 0.94 at 2000
q if q <= 2500 => 0.94 + (q as f32 - 2000.0) * (1.08 / 500.0), // 2.02 at 2500
_ => 1.00,
}
}
//----------------- Env Variables -------------------
// fn read_topk() -> usize {
// env::var("STATFILT_TOP_K")
// .ok()
// .and_then(|s| s.trim().parse::<usize>().ok())
// .filter(|&v| v > 0)
// .unwrap_or(DEFAULT_TOP_K)
// }
//----------------- Alloc Not Zeroed ----------------
/*
use cudarc::driver::CudaSlice;
use cudarc::driver::DriverError;
fn alloc_uninit<T: cudarc::driver::DeviceRepr>(
stream: &Arc<CudaStream>,
len: usize,
) -> Result<CudaSlice<T>, DriverError> {
unsafe { stream.alloc::<T>(len) }
}
*/