diff --git a/.github/workflows/build_algorithm.yml b/.github/workflows/build_algorithm.yml index 4be6f202..f9fcf5a6 100644 --- a/.github/workflows/build_algorithm.yml +++ b/.github/workflows/build_algorithm.yml @@ -68,7 +68,7 @@ jobs: ghcr.io/tig-foundation/tig-monorepo/dev:0.0.1-aarch64 \ bash -c "RUST_TARGET=aarch64-unknown-linux-gnu \ LD_LIBRARY_PATH=\${LD_LIBRARY_PATH}:/usr/local/lib/rust \ - build_so.sh \$CHALLENGE \$ALGORITHM" + build_so \$CHALLENGE \$ALGORITHM" - name: Build GPU Algorithm if: needs.init.outputs.GPU == 'true' @@ -81,8 +81,8 @@ jobs: ghcr.io/tig-foundation/tig-monorepo/dev:0.0.1-aarch64-cuda12.6.3 \ bash -c "RUST_TARGET=aarch64-unknown-linux-gnu \ LD_LIBRARY_PATH=\${LD_LIBRARY_PATH}:/usr/local/lib/rust \ - build_so.sh \$CHALLENGE \$ALGORITHM --cuda && \ - build_ptx.py \$CHALLENGE \$ALGORITHM" + build_so \$CHALLENGE \$ALGORITHM --cuda && \ + build_ptx \$CHALLENGE \$ALGORITHM" - name: Upload Artifact uses: actions/upload-artifact@v4 @@ -111,7 +111,7 @@ jobs: ghcr.io/tig-foundation/tig-monorepo/dev:0.0.1-amd64 \ bash -c "RUST_TARGET=x86_64-unknown-linux-gnu \ LD_LIBRARY_PATH=\${LD_LIBRARY_PATH}:/usr/local/lib/rust \ - build_so.sh \$CHALLENGE \$ALGORITHM" + build_so \$CHALLENGE \$ALGORITHM" - name: Build GPU Algorithm if: needs.init.outputs.GPU == 'true' @@ -124,8 +124,8 @@ jobs: ghcr.io/tig-foundation/tig-monorepo/dev:0.0.1-amd64-cuda12.6.3 \ bash -c "RUST_TARGET=x86_64-unknown-linux-gnu \ LD_LIBRARY_PATH=\${LD_LIBRARY_PATH}:/usr/local/lib/rust \ - build_so.sh \$CHALLENGE \$ALGORITHM --cuda && \ - build_ptx.py \$CHALLENGE \$ALGORITHM" + build_so \$CHALLENGE \$ALGORITHM --cuda && \ + build_ptx \$CHALLENGE \$ALGORITHM" - name: Upload Artifact uses: actions/upload-artifact@v4 diff --git a/Dockerfile.dev b/Dockerfile.dev index d9825fd5..7b9529ea 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -63,7 +63,7 @@ RUN if command -v nvcc > /dev/null 2>&1; then \ rm -rf tig-monorepo COPY tig-binary/scripts /usr/local/bin/ -RUN chmod +x 
/usr/local/bin/build_so.sh && \ - chmod +x /usr/local/bin/build_ptx.py +RUN chmod +x /usr/local/bin/build_so && \ + chmod +x /usr/local/bin/build_ptx WORKDIR /app diff --git a/docs/guides/innovating.md b/docs/guides/innovating.md index 8c21c6ab..82943261 100644 --- a/docs/guides/innovating.md +++ b/docs/guides/innovating.md @@ -41,8 +41,8 @@ Each algorithm branch will have 6 key files (11 if there is CUDA code): **READ THE IMPORTANT NOTES AT THE BOTTOM OF THIS SECTION** 1. Pick a challenge (``) to develop an algorithm for -2. Make a copy of `tig-algorithms//template.rs` or an existing algorithm (see notes) - * (Optional) for Cuda, additionally make a copy of `tig-algorithms//template.cu` +2. Make a copy of `tig-algorithms/src//template.rs` or an existing algorithm (see notes) + * (Optional) for Cuda, additionally make a copy of `tig-algorithms/src//template.cu` 3. Make sure your file has the following notice in its header if you intend to submit it to TIG: ``` Copyright [year copyright work created] [name of copyright owner] @@ -181,9 +181,9 @@ language governing permissions and limitations under the License. # example docker run -it -v $(pwd):/app --gpus all ghcr.io/tig-foundation/tig-monorepo/dev:0.0.1-amd64-cuda12.6.3 ``` -6. If you have Cuda code, use `build_ptx.py` to compile it +6. If you have Cuda code, use `build_ptx` to compile it ``` - build_ptx.py + build_ptx ``` 7. Run the test * No cuda: @@ -203,8 +203,9 @@ language governing permissions and limitations under the License. * If you are copying and modifying an algorithm that has been submitted to TIG, make sure to use the `innovator_outbound` version * Do not include tests in your algorithm file. TIG will reject your algorithm submission. * Only your algorithm's code gets submitted. You should not be modifying `Cargo.toml` in `tig-algorithms`. 
Any extra dependencies you add will not be available when TIG compiles your algorithm -* If you need to use random number generation, ensure that it is seeded so that your algorithm is deterministic. - * Suggest to use `let mut rng = SmallRng::from_seed(StdRng::from_seed(challenge.seed).gen())` +* There are comments with more tips inside the templates! + * Rust: `tig-algorithms/src//template.rs` + * Cuda: `tig-algorithms/src//template.cu` ## Locally Compiling Your Algorithm into Shared Object diff --git a/tig-algorithms/src/hypergraph/template.cu b/tig-algorithms/src/hypergraph/template.cu index bb473871..bd0c6dc4 100644 --- a/tig-algorithms/src/hypergraph/template.cu +++ b/tig-algorithms/src/hypergraph/template.cu @@ -36,10 +36,35 @@ acknowledgments below: */ // License must be the same as the rust code -// You can import any libraries available in nvidia/cuda:12.6.3-cudnn-devel-ubuntu24.04 -#include -#include -#include -#include - -// Any functions available in the .cu file will be available here +// IMPORTANT NOTES: +// 1. You can import any libraries available in nvidia/cuda:12.6.3-cudnn-devel-ubuntu24.04 +// Example: +// #include +// #include +// #include +// #include +// +// 2. If you launch a kernel with multiple blocks, any writes should be to non-overlapping parts of the memory +// Example: +// arr[blockIdx.x] = 1; // This IS deterministic +// arr[0] = 1; // This is NOT deterministic +// +// 3. Any kernel available in .cu will be available here +// +// 4. If you need to use random numbers, you can use the CURAND library and seed it with challenge.seed. +// Example rust: +// let d_seed = stream.memcpy_stod(seed)?; +// stream +// .launch_builder(&my_kernel) +// .arg(&d_seed) +// ... +// +// Example cuda: +// extern "C" __global__ void my_kernel( +// const uint8_t *seed, +// ... +// ) { +// curandState state; +// curand_init(((uint64_t *)(seed))[0], 0, 0, &state); +// ... 
+// } diff --git a/tig-algorithms/src/hypergraph/template.rs b/tig-algorithms/src/hypergraph/template.rs index 148149c4..bc72bf4e 100644 --- a/tig-algorithms/src/hypergraph/template.rs +++ b/tig-algorithms/src/hypergraph/template.rs @@ -70,6 +70,18 @@ pub fn solve_sub_instance( stream: Arc, prop: &cudaDeviceProp, ) -> anyhow::Result> { + // If you need random numbers, recommend using SmallRng with challenge.seed: + // use rand::{rngs::SmallRng, Rng, SeedableRng}; + // let mut rng = SmallRng::from_seed(challenge.seed); + + // when launching kernels, you should hardcode the LaunchConfig for determinism: + // Example: + // LaunchConfig { + // grid_dim: ((arr_len + 1023) / 1024, 1, 1), + // block_dim: (1024, 1, 1), // do not exceed 1024 threads per block for compatibility with compute 3.6 + // shared_mem_bytes: 400, + // } + // return Err() if your algorithm encounters an error // return Ok(None) if your algorithm finds no solution or needs to exit early // return Ok(SubSolution { .. }) if your algorithm finds a solution diff --git a/tig-algorithms/src/knapsack/template.rs b/tig-algorithms/src/knapsack/template.rs index 480ec4ce..881feae5 100644 --- a/tig-algorithms/src/knapsack/template.rs +++ b/tig-algorithms/src/knapsack/template.rs @@ -55,6 +55,10 @@ pub fn solve_challenge(challenge: &Challenge) -> anyhow::Result } pub fn solve_sub_instance(instance: &SubInstance) -> Result> { + // If you need random numbers, recommend using SmallRng with instance.seed: + // use rand::{rngs::SmallRng, Rng, SeedableRng}; + // let mut rng = SmallRng::from_seed(instance.seed); + // return Err() if your algorithm encounters an error // return Ok(None) if your algorithm finds no solution or needs to exit early // return Ok(SubSolution { .. 
}) if your algorithm finds a solution diff --git a/tig-algorithms/src/satisfiability/template.rs b/tig-algorithms/src/satisfiability/template.rs index f19cba21..238ac721 100644 --- a/tig-algorithms/src/satisfiability/template.rs +++ b/tig-algorithms/src/satisfiability/template.rs @@ -40,6 +40,10 @@ use anyhow::{anyhow, Result}; use tig_challenges::satisfiability::*; pub fn solve_challenge(challenge: &Challenge) -> Result> { + // If you need random numbers, recommend using SmallRng with challenge.seed: + // use rand::{rngs::SmallRng, Rng, SeedableRng}; + // let mut rng = SmallRng::from_seed(challenge.seed); + // return Err() if your algorithm encounters an error // return Ok(None) if your algorithm finds no solution or needs to exit early // return Ok(Solution { .. }) if your algorithm finds a solution diff --git a/tig-algorithms/src/vector_search/template.cu b/tig-algorithms/src/vector_search/template.cu index bb473871..b2c520d8 100644 --- a/tig-algorithms/src/vector_search/template.cu +++ b/tig-algorithms/src/vector_search/template.cu @@ -36,10 +36,35 @@ acknowledgments below: */ // License must be the same as the rust code -// You can import any libraries available in nvidia/cuda:12.6.3-cudnn-devel-ubuntu24.04 -#include -#include -#include -#include - -// Any functions available in the .cu file will be available here +// IMPORTANT NOTES: +// 1. You can import any libraries available in nvidia/cuda:12.6.3-cudnn-devel-ubuntu24.04 +// Example: +// #include +// #include +// #include +// #include +// +// 2. If you launch a kernel with multiple blocks, any writes should be to non-overlapping parts of the memory +// Example: +// arr[blockIdx.x] = 1; // This IS deterministic +// arr[0] = 1; // This is NOT deterministic as multiple blocks are writing to the same location +// +// 3. Any kernel available in .cu will be available here +// +// 4. If you need to use random numbers, you can use the CURAND library and seed it with challenge.seed. 
+// Example rust: +// let d_seed = stream.memcpy_stod(seed)?; +// stream +// .launch_builder(&my_kernel) +// .arg(&d_seed) +// ... +// +// Example cuda: +// extern "C" __global__ void my_kernel( +// const uint8_t *seed, +// ... +// ) { +// curandState state; +// curand_init(((uint64_t *)(seed))[0], 0, 0, &state); +// ... +// } \ No newline at end of file diff --git a/tig-algorithms/src/vector_search/template.rs b/tig-algorithms/src/vector_search/template.rs index a34237bb..c29bf910 100644 --- a/tig-algorithms/src/vector_search/template.rs +++ b/tig-algorithms/src/vector_search/template.rs @@ -44,12 +44,27 @@ use cudarc::{ use std::sync::Arc; use tig_challenges::vector_search::*; +// when launching kernels, you should not exceed this const or else it may not be deterministic +const MAX_THREADS_PER_BLOCK: u32 = 1024; + pub fn solve_challenge( challenge: &Challenge, module: Arc, stream: Arc, prop: &cudaDeviceProp, ) -> anyhow::Result> { + // If you need random numbers, recommend using SmallRng with challenge.seed: + // use rand::{rngs::SmallRng, Rng, SeedableRng}; + // let mut rng = SmallRng::from_seed(challenge.seed); + + // when launching kernels, you should hardcode the LaunchConfig for determinism: + // Example: + // LaunchConfig { + // grid_dim: ((arr_len + 1023) / 1024, 1, 1), + // block_dim: (1024, 1, 1), // do not exceed 1024 threads per block for compatibility with compute 3.6 + // shared_mem_bytes: 400, + // } + // return Err() if your algorithm encounters an error // return Ok(None) if your algorithm finds no solution or needs to exit early // return Ok(Solution { .. 
}) if your algorithm finds a solution diff --git a/tig-algorithms/src/vehicle_routing/template.rs b/tig-algorithms/src/vehicle_routing/template.rs index f141e136..4cfa6469 100644 --- a/tig-algorithms/src/vehicle_routing/template.rs +++ b/tig-algorithms/src/vehicle_routing/template.rs @@ -55,6 +55,10 @@ pub fn solve_challenge(challenge: &Challenge) -> anyhow::Result } pub fn solve_sub_instance(instance: &SubInstance) -> Result> { + // If you need random numbers, recommend using SmallRng with instance.seed: + // use rand::{rngs::SmallRng, Rng, SeedableRng}; + // let mut rng = SmallRng::from_seed(instance.seed); + // return Err() if your algorithm encounters an error // return Ok(None) if your algorithm finds no solution or needs to exit early // return Ok(SubSolution { .. }) if your algorithm finds a solution diff --git a/tig-binary/README.md b/tig-binary/README.md index 66c6c161..adf224e6 100644 --- a/tig-binary/README.md +++ b/tig-binary/README.md @@ -16,22 +16,22 @@ For CUDA, TIG uses `nvcc` to generate ptx using target version `sm_70/compute_70 ``` # example docker run -it -v $(pwd):/app ghcr.io/tig-foundation/tig-monorepo/dev:0.0.1-aarch64 -# scripts build_so.sh and build_ptx.py are on PATH +# scripts build_so and build_ptx are on PATH ``` -2. Build shared object using `build_so.sh` script: +2. Build shared object using `build_so` script: * Expects `tig_algorithm::::::solve_challenge` to be importable * Outputs to `tig-algorithms/lib///.so`, where `ARCH` is aarch64 or amd64 ``` # add '--cuda' flag if building cuda algorithm - build_so.sh $CHALLENGE $ALGORITHM + build_so $CHALLENGE $ALGORITHM ``` -3. (Optional) Build ptx using `build_ptx.py` script: +3. 
(Optional) Build ptx using `build_ptx` script: * Expects `tig_algorithm/src//.cu` or `tig_algorithm/src///benchmarker_outbound.cu` file to exist * Outputs to `tig-algorithms/lib//ptx/.ptx` ``` -build_ptx.py $CHALLENGE $ALGORITHM +build_ptx $CHALLENGE $ALGORITHM ``` # License diff --git a/tig-binary/scripts/build_ptx b/tig-binary/scripts/build_ptx new file mode 100644 index 00000000..fff90995 --- /dev/null +++ b/tig-binary/scripts/build_ptx @@ -0,0 +1,312 @@ +#!/usr/bin/env python3 + +import argparse +import os +import re +import shutil +import subprocess +import sys +import tempfile + +# Import the dictionary from ptx_instructions.py +instruction_fuel_cost = { + 'add.u32': 2, + 'add.u64': 3, + 'add.f32': 4, + 'add.f64': 5, + 'add.s32': 2, + 'add.s64': 3, + 'sub.u32': 2, + 'sub.u64': 3, + 'sub.f32': 4, + 'sub.f64': 5, + 'mul.u32': 4, + 'mul.u64': 5, + 'mul.f32': 5, + 'mul.f64': 6, + 'div.u32': 10, + 'div.u64': 12, + 'div.f32': 15, + 'div.f64': 20, + 'mul.wide.u32': 6, + 'mul.wide.u64': 8, + 'mad.wide.u32': 8, + 'mad.wide.u64': 10, + 'mov.u32': 1, + 'mov.u64': 1, + 'mov.f32': 1, + 'mov.f64': 1, + 'and.b32': 1, + 'and.b64': 1, + 'or.b32': 1, + 'or.b64': 1, + 'xor.b32': 1, + 'xor.b64': 1, + 'shl.b32': 2, + 'shl.b64': 3, + 'shr.b32': 2, + 'shr.b64': 3, + 'cvt.u32.u64': 2, + 'cvt.f32.f64': 3, + 'cvt.u64.u32': 2, + 'cvt.f64.f32': 3, + 'setp.eq.u32': 2, + 'setp.eq.u64': 3, + 'setp.lt.u32': 2, + 'setp.lt.u64': 3, + 'setp.gt.u32': 2, + 'setp.gt.u64': 3, + 'setp.ne.u32': 2, + 'setp.ne.u64': 3, + 'selp.u32': 3, + 'selp.u64': 4, + 'abs.s32': 2, + 'abs.s64': 3, + 'abs.f32': 3, + 'abs.f64': 4, + 'min.u32': 2, + 'min.u64': 3, + 'min.f32': 3, + 'min.f64': 4, + 'max.u32': 2, + 'max.u64': 3, + 'max.f32': 3, + 'max.f64': 4, + 'sqrt.rn.f32': 15, + 'sqrt.rn.f64': 20, + 'rsqrt.rn.f32': 15, + 'rsqrt.rn.f64': 20, + 'sqrt.approx.ftz.f32': 8, + 'sqrt.approx.ftz.f64': 10, + 'sin.approx.f32': 8, + 'sin.approx.f64': 10, + 'cos.approx.f32': 8, + 'cos.approx.f64': 10, + 'tanh.approx.f32': 8, + 
'tanh.approx.f64': 10, + 'add.f16': 1, + 'add.f16x2': 1, + 'add.bf16': 1, + 'add.bf16x2': 1, + 'fma.rn.bf16': 1, + 'fma.rn.bf16x2': 1, + 'cvt.rn.bf16.f32': 1, + 'cvt.rn.f32.bf16': 1, + 'cvt.rn.tf32.f32': 1, + 'cvt.rn.f32.tf32': 1, + 'atom.add.u32': 8, + 'atom.add.u64': 10, + 'atom.min.u32': 8, + 'atom.min.u64': 10, + 'atom.max.u32': 8, + 'atom.max.u64': 10, + 'tex.1d.v4.f32': 15, + 'tex.2d.v4.f32': 20, + 'tex.3d.v4.f32': 25, + 'ld.param.u32': 3, + 'ld.param.u64': 4, + 'st.param.u32': 3, + 'st.param.u64': 4, + 'ld.const.u32': 3, + 'ld.const.u64': 4, + 'popc.b32': 3, + 'popc.b64': 4, + 'clz.b32': 3, + 'clz.b64': 4, + 'brev.b32': 3, + 'brev.b64': 4, + 'unused': 1, +} + +def parse_ptx_code(ptx_code): + parsed = [] + kernel = None + block = None + for line in ptx_code: + stripped_line = line.strip() + if kernel is None: + if (stripped_line.startswith(".visible .entry") or stripped_line.startswith(".func")): + kernel = { + "definition": [line], + "blocks": None + } + parsed.append(kernel) + else: + parsed.append(line) + elif kernel["blocks"] is None: + if stripped_line == "{": + block = [] + kernel["blocks"] = [] + else: + kernel["definition"].append(line) + else: + if stripped_line == "}": + if len(block) > 0: + kernel["blocks"].append(block) + kernel = None + block = None + elif stripped_line != "": + block.append(line) + if ( + stripped_line == "ret;" or + ("bra" in stripped_line and not stripped_line.startswith("//")) or + (stripped_line.startswith("@") and "bra" in stripped_line) + ): + kernel["blocks"].append(block) + block = [] + return parsed + +def inject_fuel_and_runtime_sig(parsed, kernels_to_ignore): + modified_code = [] + block_sig = 0 + + for line in parsed: + if not isinstance(line, dict): + block_sig ^= hash(line) & 0xFFFFFFFFFFFFFFFF + modified_code.append(line) + continue + + kernel = line + block_sig ^= hash(kernel["definition"][0]) & 0xFFFFFFFFFFFFFFFF + name = ( + kernel["definition"][0] # func sig in first line + .split()[-1] # func name is last 
token + .split("(")[0] # func name is before the first ( + ) + if name in kernels_to_ignore: + print(f"kernel: {name}, #blocks: {len(kernel['blocks'])}, status: SKIPPED") + modified_code.extend(kernel["definition"]) + modified_code.append("{") + for block in kernel["blocks"]: + modified_code.extend(block) + modified_code.append("}") + continue + + print(f"kernel: {name}, #blocks: {len(kernel['blocks'])}, status: PROCESSING") + modified_code.extend(kernel["definition"]) + modified_code.append("{") + modified_code.append( +""" +\t.reg .u64 \tr_signature; +\t.reg .u64 \tr_sig_addr; +\t.reg .u64 \tr_temp_fuel; +\t.reg .u64 \tr_fuel_usage; +\t.reg .u64 \tr_fuel_addr; +\t.reg .pred \tp_fuel; +\tmov.u64 \tr_signature, 0xa1b2c3d4e5f6a7b8; +\tmov.u64 \tr_sig_addr, gbl_SIGNATURE; +\tmov.u64 \tr_temp_fuel, 0; +\tmov.u64 \tr_fuel_usage, 0; +\tmov.u64 \tr_fuel_addr, gbl_FUELUSAGE; +""" + ) + for i, block in enumerate(kernel["blocks"]): + block_sig ^= hash(block[0]) & 0xFFFFFFFFFFFFFFFF + block_fuel = sum( + instruction_fuel_cost.get(instr.split()[0], 0) + for instr in block + ) + print(f"\tblock {i}: fuel_usage: {block_fuel}, signature: 0x{block_sig:016x}") + modified_code.extend(block[:-1]) + modified_code.append( +f""" +\txor.b64 \tr_signature, r_signature, 0x{block_sig:016x}; +\tadd.u64 \tr_fuel_usage, r_fuel_usage, {block_fuel}; +""" +) + if block[-1].strip() == "ret;": + modified_code.append( +""" +\tatom.global.add.u64 \tr_temp_fuel, [r_fuel_addr], r_fuel_usage; +\tsetp.lt.u64 \tp_fuel, r_temp_fuel, 0xdeadbeefdeadbeef; +\t@p_fuel bra $NORMAL_EXIT; +\tst.global.u64 \t[gbl_ERRORSTAT], 1; +$NORMAL_EXIT: +\tatom.global.xor.b64 \tr_sig_addr, [r_sig_addr], r_signature; +\tatom.global.add.u64 \tr_fuel_addr, [r_fuel_addr], r_fuel_usage; +""" + ) + modified_code.append(block[-1]) + modified_code.append("}") + return modified_code + +def main(): + parser = argparse.ArgumentParser(description='Compile PTX with injected runtime signature') + parser.add_argument('challenge', 
help='Challenge name') + parser.add_argument('algorithm', help='Algorithm name') + + args = parser.parse_args() + + print(f"Compiling .ptx for {args.challenge}/{args.algorithm}") + + framework_cu = "tig-binary/src/framework.cu" + if not os.path.exists(framework_cu): + raise FileNotFoundError( + f"Framework code does not exist @ '{framework_cu}'. This script must be run from the root of tig-monorepo" + ) + + challenge_cu = f"tig-challenges/src/{args.challenge}.cu" + if not os.path.exists(challenge_cu): + raise FileNotFoundError( + f"Challenge code does not exist @ '{challenge_cu}'. Is the challenge name correct?" + ) + + algorithm_cu = f"tig-algorithms/src/{args.challenge}/{args.algorithm}.cu" + algorithm_cu2 = f"tig-algorithms/src/{args.challenge}/{args.algorithm}/benchmarker_outbound.cu" + if not os.path.exists(algorithm_cu) and not os.path.exists(algorithm_cu2): + raise FileNotFoundError( + f"Algorithm code does not exist @ '{algorithm_cu}' or '{algorithm_cu2}'. Is the algorithm name correct?" 
+ ) + if not os.path.exists(algorithm_cu): + algorithm_cu = algorithm_cu2 + + # Combine .cu source files into a temporary file + with tempfile.TemporaryDirectory() as temp_dir: + temp_cu = os.path.join(temp_dir, "temp.cu") + temp_ptx = os.path.join(temp_dir, "temp.ptx") + + with open(framework_cu, 'r') as f: + code = f.read() + "\n" + with open(challenge_cu, 'r') as f: + code += f.read() + "\n" + kernel_regex = r'(?:extern\s+"C"\s+__global__|__device__)\s+\w+\s+(?P\w+)\s*\(' + kernels_to_ignore = [match.group('func') for match in re.finditer(kernel_regex, code)] + with open(algorithm_cu, 'r') as f: + code += f.read() + with open(temp_cu, 'w') as f: + f.write(code) + + # Compile the temporary .cu file into a .ptx file using nvcc + nvcc_command = [ + "nvcc", "-ptx", temp_cu, "-o", temp_ptx, + "-arch", "compute_70", + "-code", "sm_70", + "--use_fast_math", + "-dopt=on" + ] + + print(f"Running nvcc command: {' '.join(nvcc_command)}") + subprocess.run(nvcc_command, check=True) + print(f"Successfully compiled") + + print("Adding runtime signature opcodes") + with open(temp_ptx, 'r') as f: + ptx_code = f.readlines() + parsed = parse_ptx_code(ptx_code) + modified_code = inject_fuel_and_runtime_sig(parsed, kernels_to_ignore) + + output_path = f"tig-algorithms/lib/{args.challenge}/ptx/{args.algorithm}.ptx" + os.makedirs(os.path.dirname(output_path), exist_ok=True) + # with open(output_path, 'w') as f: + # f.writelines(ptx_code) + with open(output_path, 'w') as f: + f.writelines(modified_code) + print(f"Wrote ptx to {output_path}") + print(f"Done") + +if __name__ == "__main__": + try: + main() + except Exception as e: + print(f"Error: {e}") + sys.exit(1) \ No newline at end of file diff --git a/tig-binary/scripts/build_ptx.py b/tig-binary/scripts/build_ptx.py deleted file mode 100644 index 750d1ef3..00000000 --- a/tig-binary/scripts/build_ptx.py +++ /dev/null @@ -1,394 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import os -import re -import shutil -import subprocess 
-import sys -import tempfile - -# Import the dictionary from ptx_instructions.py -instruction_prime_map = { - 'add.u32': (0xEDC1A7F47ACD2EE3, 2), - 'add.u64': (0xBE8098FE5CFCC7B5, 3), - 'add.f32': (0x90B2AD995AF897CD, 4), - 'add.f64': (0xC3EA137E165AECED, 5), - 'add.s32': (0xE620952D37DDFF91, 2), - 'add.s64': (0x9BD972EB4677C3F1, 3), - 'sub.u32': (0x87353D33C2976737, 2), - 'sub.u64': (0xE155AD1E824B25E3, 3), - 'sub.f32': (0xC256B33A0EF0745F, 4), - 'sub.f64': (0x8C4AAED901B0472B, 5), - 'mul.u32': (0x91C6A1D9B4EDC1AB, 4), - 'mul.u64': (0x9A16DFEE16D4BBF5, 5), - 'mul.f32': (0x9AF1495384190163, 5), - 'mul.f64': (0x883C081D0CDE3065, 6), - 'div.u32': (0xFD9B02FDC35C72B3, 10), - 'div.u64': (0xF5F2BCF367F68ED1, 12), - 'div.f32': (0xB161AB7FB04F004D, 15), - 'div.f64': (0x96E8220B9C5E8C97, 20), - 'mul.wide.u32': (0x9B8C768C1401BA7B, 6), - 'mul.wide.u64': (0xBC49DE9B3A818899, 8), - 'mad.wide.u32': (0xA17576DF232E7669, 8), - 'mad.wide.u64': (0xD298F1C0BF940F95, 10), - 'mov.u32': (0xED571105D6049077, 1), - 'mov.u64': (0xEF0F967771856DC3, 1), - 'mov.f32': (0xC73484F9874FAEDF, 1), - 'mov.f64': (0xC41D2D54E517F109, 1), - 'and.b32': (0x80B429DFBFC318EB, 1), - 'and.b64': (0xBDA9ECF782F0D3FB, 1), - 'or.b32': (0xC26355832083F5A1, 1), - 'or.b64': (0xEB578CCCB8F3BD37, 1), - 'xor.b32': (0xE58E31633D73A4A5, 1), - 'xor.b64': (0xB62B0F67BFA44C95, 1), - 'shl.b32': (0x88330E6E1BFA5411, 2), - 'shl.b64': (0x8374F43D807A6F91, 3), - 'shr.b32': (0xF1B8109D2F948463, 2), - 'shr.b64': (0xA3230556089777C7, 3), - 'cvt.u32.u64': (0xCEDC3D307D8D8683, 2), - 'cvt.f32.f64': (0x9DA4540AE2D7A161, 3), - 'cvt.u64.u32': (0x9A9E961B54B29955, 2), - 'cvt.f64.f32': (0x8B1B3D4D77BC7BB3, 3), - 'setp.eq.u32': (0xC161CEDAF256D5E5, 2), - 'setp.eq.u64': (0xD1E5DA2FDCD5E157, 3), - 'setp.lt.u32': (0x986CCCCFA9F5B10B, 2), - 'setp.lt.u64': (0x8CC7C690FF547D63, 3), - 'setp.gt.u32': (0x80FF2A3B14A4D19D, 2), - 'setp.gt.u64': (0xE0E9526F53C79197, 3), - 'setp.ne.u32': (0xB19319E767B773DF, 2), - 'setp.ne.u64': (0x8AC38410037C32D5, 
3), - 'selp.u32': (0xAB0A8BAC52D5D76B, 3), - 'selp.u64': (0x9CDBFC00628D628B, 4), - 'abs.s32': (0x84378B096D13B6A3, 2), - 'abs.s64': (0xC04FBAAA56FA0DAB, 3), - 'abs.f32': (0xCE9AA2EB4B22456B, 3), - 'abs.f64': (0xF165D7826D16DF47, 4), - 'min.u32': (0xE95BA56F275D3EC5, 2), - 'min.u64': (0x8C3F2F6F2EB0C34F, 3), - 'min.f32': (0xBF2A12007525FEED, 3), - 'min.f64': (0xF3A8D718FEC1B393, 4), - 'max.u32': (0xF4692B3CA0566779, 2), - 'max.u64': (0xADB80126D86C7295, 3), - 'max.f32': (0xE6FD4C5BCBC70E3D, 3), - 'max.f64': (0xD89E33A78DBF9527, 4), - 'sqrt.rn.f32': (0xECE9FBD3A6D77023, 15), - 'sqrt.rn.f64': (0xC7D7E4D1245D5CCD, 20), - 'rsqrt.rn.f32': (0xBFA9439C30B70919, 15), - 'rsqrt.rn.f64': (0xDF9DC483E11B08A5, 20), - 'sqrt.approx.ftz.f32': (0x8062539FCA30F685, 8), - 'sqrt.approx.ftz.f64': (0xBEAA214049557A1B, 10), - 'sin.approx.f32': (0x8062539FCA30F685, 8), - 'sin.approx.f64': (0xBEAA214049557A1B, 10), - 'cos.approx.f32': (0x854FD92C1227AF13, 8), - 'cos.approx.f64': (0xE4ED779797574CBD, 10), - 'tanh.approx.f32': (0x8062539FCA30F685, 8), - 'tanh.approx.f64': (0xBEAA214049557A1B, 10), - 'add.f16': (0xA5234567890ABCDE, 1), - 'add.f16x2': (0xB5234567890ABCDE, 1), - 'add.bf16': (0xC5234567890ABCDE, 1), - 'add.bf16x2': (0xD5234567890ABCDE, 1), - 'fma.rn.bf16': (0xE5234567890ABCDE, 1), - 'fma.rn.bf16x2': (0xF5234567890ABCDE, 1), - 'cvt.rn.bf16.f32': (0xA6234567890ABCDE, 1), - 'cvt.rn.f32.bf16': (0xB6234567890ABCDE, 1), - 'cvt.rn.tf32.f32': (0xC6234567890ABCDE, 1), - 'cvt.rn.f32.tf32': (0xD6234567890ABCDE, 1), - 'atom.add.u32': (0xF2D231CD2E1FBA23, 8), - 'atom.add.u64': (0xE511FE4AEA87A429, 10), - 'atom.min.u32': (0xC4FA795EB531A38B, 8), - 'atom.min.u64': (0xF0FB61E6281360FD, 10), - 'atom.max.u32': (0x8F28A03020BF8813, 8), - 'atom.max.u64': (0xD80622EB6110253F, 10), - 'tex.1d.v4.f32': (0xC56DD2999CD1234F, 15), - 'tex.2d.v4.f32': (0x918C3A31D782B0E5, 20), - 'tex.3d.v4.f32': (0xC002087960B604D5, 25), - 'ld.param.u32': (0xE97ADA4F02ABD567, 3), - 'ld.param.u64': (0x8521CA309251BB1D, 4), - 
'st.param.u32': (0xF9DD000BE29F68F1, 3), - 'st.param.u64': (0xEC242CB6C99E502D, 4), - 'ld.const.u32': (0xCCBED9D942A60229, 3), - 'ld.const.u64': (0x8E8D513AAC06F061, 4), - 'popc.b32': (0xC6051FFCF3752D2B, 3), - 'popc.b64': (0xE94FDBF317D00AB7, 4), - 'clz.b32': (0xA2E613C950EA7F17, 3), - 'clz.b64': (0x8DB562EEC3BBD64F, 4), - 'brev.b32': (0xAE6A290706DD70E7, 3), - 'brev.b64': (0xD4FDC03C4401C533, 4), - 'unused': (0xA3E6942DA60A926F, 1), - 'unused': (0xBCB5CD76C9DC3253, 1), - 'unused': (0x966C40D3884717FF, 1), - 'unused': (0xFED4DB903BB79241, 1), - 'unused': (0xDA31E485E5E0C445, 1), - 'unused': (0xB8290AAC45720989, 1), - 'unused': (0xD0C39B467979E695, 1), - 'unused': (0xE12B4AF73AAABEEF, 1), - 'unused': (0xC1B8CDB5496BAFD7, 1), -} - -def read_ptx_file(file_path): - with open(file_path, 'r') as file: - ptx_code = file.readlines() - return ptx_code - -def write_ptx_file(modified_code, output_path): - with open(output_path, 'w') as file: - file.writelines(modified_code) - -def rotate_left(x, n): - n = n & 0x3F - return ((x << n) | (x >> (64 - n))) & 0xFFFFFFFFFFFFFFFF - -def rotate_right(x, n): - n = n & 0x3F - return ((x >> n) | (x << (64 - n))) & 0xFFFFFFFFFFFFFFFF - -def get_position_modifier(bb_idx, inst_idx, func_hash): - left_rot = ((bb_idx * 7 + inst_idx * 13 + func_hash * 17) & 0x3F) - right_rot = ((bb_idx * 11 + inst_idx * 17 + func_hash * 23) & 0x3F) - return left_rot, right_rot - -# Function to modify the PTX code by inserting XOR commands after certain instructions -def add_xor_commands(ptx_code, instruction_prime_map, ignore_patterns): - modified_code = [] - inside_kernel_function = False - skip_kernel = False - r_signature_inserted = False - current_block_signature = 0 - current_block_fuel = 0 - bb_idx = 0 - inst_idx = 0 - func_hash = 0 - current_func = "" - - # Track branch targets and their frequency - branch_targets = {} - # First pass - identify potential loop headers - for line in ptx_code: - stripped_line = line.strip() - if 
stripped_line.startswith("$L") or stripped_line.startswith("BB"): - label = stripped_line.split(":")[0] - branch_targets[label] = 0 - elif "bra" in stripped_line: - target = stripped_line.split()[-1].rstrip(";") - if target in branch_targets: - branch_targets[target] += 1 - - # Labels with multiple branches to them are likely loop headers - loop_headers = {label for label, count in branch_targets.items() if count > 0} - print(f"Identified potential loop headers: {loop_headers}") - - # Main processing pass - for line in enumerate(ptx_code): - stripped_line = line[1].strip() - - if stripped_line.startswith(".visible .entry") or stripped_line.startswith(".func"): - current_func = stripped_line.split()[-1] - func_hash = hash(current_func) & 0xFFFFFFFFFFFFFFFF - print(f"\nProcessing function: {current_func} (hash: {func_hash:016x})") - - # Check if current function should be skipped - skip_kernel = current_func in ignore_patterns - if skip_kernel: - print(f"Skipping kernel: {current_func[:-1]}") - - inside_kernel_function = True - r_signature_inserted = False - bb_idx = 0 - inst_idx = 0 - - if inside_kernel_function and skip_kernel: - modified_code.append(line[1]) - if stripped_line == "}": - inside_kernel_function = False - continue - - if inside_kernel_function and stripped_line == "{" and not r_signature_inserted: - modified_code.append(line[1]) - modified_code.append("\t.reg .u64 \tr_signature;\n") - modified_code.append("\t.reg .u64 \tr_fuelusage;\n") - modified_code.append("\t.reg .u64 \tr_fuel_backup;\n") - modified_code.append("\t.reg .u64 \tr_fuel_addr;\n") - modified_code.append("\t.reg .u64 \tr_temp_fuel;\n") - modified_code.append("\t.reg .u64 \tr_sig_addr;\n") - modified_code.append("\t.reg .pred \tp_fuel;\n") - modified_code.append("\tmov.u64 \tr_signature, 0x1111111111111111;\n") - modified_code.append("\tmov.u64 \tr_fuelusage, 0;\n") - modified_code.append("\tmov.u64 \tr_temp_fuel, 0;\n") - modified_code.append("\tmov.u64 \tr_sig_addr, gbl_SIGNATURE;\n") 
- modified_code.append("\tmov.u64 \tr_fuel_addr, gbl_FUELUSAGE;\n") - r_signature_inserted = True - continue - - # Handle branch instructions - if inside_kernel_function and not skip_kernel: - is_branch = ("bra" in stripped_line and not stripped_line.startswith("//")) or \ - (stripped_line.startswith("@") and "bra" in stripped_line) - is_return = stripped_line == "ret;" - - if (is_branch or is_return): - if current_block_signature != 0 or current_block_fuel != 0: - left_rot, right_rot = get_position_modifier(bb_idx, inst_idx, func_hash) - rotated_sig = rotate_left(rotate_right(current_block_signature, right_rot), left_rot) - modified_code.append(f"\txor.b64 \tr_signature, r_signature, 0x{rotated_sig:016x};\n") - modified_code.append(f"\tadd.u64 \tr_fuelusage, r_fuelusage, {current_block_fuel};\n") - - if is_return: - modified_code.append("\tmov.u64 \tr_fuel_backup, r_fuelusage;\n") - modified_code.append("\tatom.global.add.u64 \tr_temp_fuel, [r_fuel_addr], r_fuelusage;\n") - modified_code.append("\tadd.u64 \tr_temp_fuel, r_temp_fuel, r_fuel_backup;\n") - modified_code.append("\tmov.u64 \tr_fuelusage, 0;\n") - modified_code.append("\tsetp.gt.u64 p_fuel, r_temp_fuel, 0xdeadbeefdeadbeef;\n") - modified_code.append("\t@p_fuel bra $FUEL_EXCEEDED;\n") - modified_code.append("\tbra $NORMAL_EXIT;\n") - - modified_code.append("$FUEL_EXCEEDED:\n") - modified_code.append("\tmov.u64 \tr_temp_fuel, 1;\n") - modified_code.append("\tst.global.u64 \t[gbl_ERRORSTAT], r_temp_fuel;\n") - - modified_code.append("$NORMAL_EXIT:\n") - modified_code.append("\tatom.global.xor.b64 \tr_sig_addr, [r_sig_addr], r_signature;\n") - modified_code.append("\tatom.global.add.u64 \tr_fuel_addr, [r_fuel_addr], r_fuelusage;\n") - modified_code.append("\tret;\n") - continue - - modified_code.append(line[1]) - current_block_signature = 0 - current_block_fuel = 0 - bb_idx += 1 - inst_idx = 0 - continue - - if stripped_line.startswith("$L") or stripped_line.startswith("BB"): - if inside_kernel_function and 
not skip_kernel: - if current_block_signature != 0 or current_block_fuel != 0: - left_rot, right_rot = get_position_modifier(bb_idx, inst_idx, func_hash) - rotated_sig = rotate_left(rotate_right(current_block_signature, right_rot), left_rot) - print(f"BB{bb_idx}: Signature: {current_block_signature:016x} -> {rotated_sig:016x} (rotl: {left_rot}, rotr: {right_rot})") - print(f"BB{bb_idx}: Fuel: {current_block_fuel}") - modified_code.append(f"\txor.b64 \tr_signature, r_signature, 0x{rotated_sig:016x};\n") - modified_code.append(f"\tadd.u64 \tr_fuelusage, r_fuelusage, {current_block_fuel};\n") - current_block_signature = 0 - current_block_fuel = 0 - bb_idx += 1 - inst_idx = 0 - - if inside_kernel_function and not skip_kernel: - if "trap;" in stripped_line: - modified_code.append("\tmov.u64 \tr_fuel_backup, r_fuelusage;\n") - modified_code.append("\tatom.global.add.u64 \tr_temp_fuel, [r_fuel_addr], r_fuelusage;\n") - modified_code.append("\tadd.u64 \tr_temp_fuel, r_temp_fuel, r_fuel_backup;\n") - modified_code.append("\tmov.u64 \tr_fuelusage, 0;\n") - modified_code.append("\tsetp.gt.u64 p_fuel, r_temp_fuel, 0xdeadbeefdeadbeef;\n") - modified_code.append("\t@p_fuel bra $FUEL_EXCEEDED;\n") - continue - - for instr, (prime, fuel_cost) in instruction_prime_map.items(): - if stripped_line.lstrip().startswith(f"{instr} "): - # XOR with rotation to prevent nullification - left_rot, right_rot = get_position_modifier(bb_idx, inst_idx, func_hash) - rotated_prime = rotate_left(rotate_right(prime, right_rot), left_rot) - current_block_signature ^= rotated_prime - current_block_fuel += fuel_cost - inst_idx += 1 - print(f"Found instruction {instr} in BB{bb_idx} (idx: {inst_idx}, cost: {fuel_cost}, prime: {prime:016x} -> {rotated_prime:016x}, rotl: {left_rot}, rotr: {right_rot})") - break - - if inside_kernel_function and stripped_line == "}": - if not skip_kernel: - if current_block_signature != 0 or current_block_fuel != 0: - left_rot, right_rot = get_position_modifier(bb_idx, 
inst_idx, func_hash) - rotated_sig = rotate_left(rotate_right(current_block_signature, right_rot), left_rot) - modified_code.append(f"\txor.b64 \tr_signature, r_signature, 0x{rotated_sig:016x};\n") - modified_code.append(f"\tadd.u64 \tr_fuelusage, r_fuelusage, {current_block_fuel};\n") - - modified_code.append(line[1]) - inside_kernel_function = False - continue - - modified_code.append(line[1]) - - return modified_code - -def main(): - parser = argparse.ArgumentParser(description='Compile PTX with injected runtime signature') - parser.add_argument('challenge', help='Challenge name') - parser.add_argument('algorithm', help='Algorithm name') - - args = parser.parse_args() - - print(f"Compiling .ptx for {args.challenge}/{args.algorithm}") - - - framework_cu = "tig-binary/src/framework.cu" - if not os.path.exists(framework_cu): - raise FileNotFoundError( - f"Framework code does not exist @ '{framework_cu}'. This script must be run from the root of tig-monorepo" - ) - - challenge_cu = f"tig-challenges/src/{args.challenge}.cu" - if not os.path.exists(challenge_cu): - raise FileNotFoundError( - f"Challenge code does not exist @ '{challenge_cu}'. Is the challenge name correct?" - ) - - algorithm_cu = f"tig-algorithms/src/{args.challenge}/{args.algorithm}.cu" - algorithm_cu2 = f"tig-algorithms/src/{args.challenge}/{args.algorithm}/benchmarker_outbound.cu" - if not os.path.exists(algorithm_cu) and not os.path.exists(algorithm_cu2): - raise FileNotFoundError( - f"Algorithm code does not exist @ '{algorithm_cu}' or '{algorithm_cu2}'. Is the algorithm name correct?" 
- ) - if not os.path.exists(algorithm_cu): - algorithm_cu = algorithm_cu2 - - # Combine .cu source files into a temporary file - with tempfile.TemporaryDirectory() as temp_dir: - temp_cu = os.path.join(temp_dir, "temp.cu") - temp_ptx = os.path.join(temp_dir, "temp.ptx") - - with open(framework_cu, 'r') as f: - code = f.read() + "\n" - with open(challenge_cu, 'r') as f: - code += f.read() + "\n" - func_regex = r'(?:extern\s+"C"\s+__global__|__device__)\s+\w+\s+(?P\w+)\s*\(' - funcs_to_ignore = [match.group('func') for match in re.finditer(func_regex, code)] - with open(algorithm_cu, 'r') as f: - code += f.read() - with open(temp_cu, 'w') as f: - f.write(code) - - # Compile the temporary .cu file into a .ptx file using nvcc - nvcc_command = [ - "nvcc", "-ptx", temp_cu, "-o", temp_ptx, - "-arch", "compute_70", - "-code", "sm_70", - "--use_fast_math", - "-dopt=on" - ] - - print(f"Running nvcc command: {' '.join(nvcc_command)}") - subprocess.run(nvcc_command, check=True) - print(f"Successfully compiled") - - print("Adding runtime signature opcodes") - with open(temp_ptx, 'r') as f: - ptx_code = f.readlines() - modified_ptx_code = add_xor_commands( - ptx_code, - instruction_prime_map, - set(f"{x}(" for x in funcs_to_ignore) - ) - - output_path = f"tig-algorithms/lib/{args.challenge}/ptx/{args.algorithm}.ptx" - os.makedirs(os.path.dirname(output_path), exist_ok=True) - with open(output_path, 'w') as f: - f.writelines(modified_ptx_code) - print(f"Wrote ptx to {output_path}") - print(f"Done") - -if __name__ == "__main__": - try: - main() - except Exception as e: - print(f"Error: {e}") - sys.exit(1) \ No newline at end of file diff --git a/tig-binary/scripts/build_so.sh b/tig-binary/scripts/build_so similarity index 100% rename from tig-binary/scripts/build_so.sh rename to tig-binary/scripts/build_so diff --git a/tig-challenges/src/hypergraph.cu b/tig-challenges/src/hypergraph.cu index 125652ca..6ebe44d0 100644 --- a/tig-challenges/src/hypergraph.cu +++ 
b/tig-challenges/src/hypergraph.cu @@ -454,6 +454,7 @@ extern "C" __global__ void greedy_bipartition( const int *node_offsets, const int *sorted_nodes, const int *node_degrees, + const int *curr_partition, int *partition, unsigned long long *left_hyperedge_flags, unsigned long long *right_hyperedge_flags @@ -466,7 +467,7 @@ extern "C" __global__ void greedy_bipartition( } __syncthreads(); for (int v = threadIdx.x; v < num_nodes; v += blockDim.x) { - if (partition[v] == p) { + if (curr_partition[v] == p) { atomicAdd(&count, 1); } } @@ -492,7 +493,7 @@ extern "C" __global__ void greedy_bipartition( for (int idx = 0; idx < num_nodes; idx++) { int v = sorted_nodes[idx]; - if (partition[v] != p) continue; + if (curr_partition[v] != p) continue; // Get range of hyperedges for this node int start_pos = node_offsets[v]; diff --git a/tig-challenges/src/hypergraph.rs b/tig-challenges/src/hypergraph.rs index 07f97606..565fc4d5 100644 --- a/tig-challenges/src/hypergraph.rs +++ b/tig-challenges/src/hypergraph.rs @@ -71,6 +71,7 @@ pub struct SubInstance { } pub const NUM_SUB_INSTANCES: usize = 16; +pub const MAX_THREADS_PER_BLOCK: u32 = 1024; impl Challenge { pub fn generate_instance( @@ -167,7 +168,7 @@ impl SubInstance { let finalize_shuffle_kernel = module.load_function("finalize_shuffle")?; let calc_connectivity_metric_kernel = module.load_function("calc_connectivity_metric")?; - let block_size = prop.maxThreadsPerBlock as u32; + let block_size = MAX_THREADS_PER_BLOCK; let cfg = LaunchConfig { grid_dim: ((num_hyperedges + block_size - 1) / block_size, 1, 1), block_dim: (block_size, 1, 1), @@ -299,6 +300,7 @@ impl SubInstance { let num_flags = (num_hyperedges + 63) / 64 * num_parts_this_level; let mut d_left_hyperedge_flags = stream.alloc_zeros::(num_flags as usize)?; let mut d_right_hyperedge_flags = stream.alloc_zeros::(num_flags as usize)?; + let d_curr_partition = d_partition.clone(); unsafe { stream @@ -310,6 +312,7 @@ impl SubInstance { .arg(&d_node_offsets) 
.arg(&d_sorted_nodes) .arg(&d_node_degrees) + .arg(&d_curr_partition) .arg(&mut d_partition) .arg(&mut d_left_hyperedge_flags) .arg(&mut d_right_hyperedge_flags) @@ -455,7 +458,7 @@ impl SubInstance { let calc_connectivity_metric_kernel = module.load_function("calc_connectivity_metric")?; let count_nodes_in_part_kernel = module.load_function("count_nodes_in_part")?; - let block_size = prop.maxThreadsPerBlock as u32; + let block_size = MAX_THREADS_PER_BLOCK; let grid_size = (self.difficulty.num_hyperedges + block_size - 1) / block_size; let cfg = LaunchConfig { diff --git a/tig-challenges/src/vector_search.rs b/tig-challenges/src/vector_search.rs index d0ac890a..8d0cd194 100644 --- a/tig-challenges/src/vector_search.rs +++ b/tig-challenges/src/vector_search.rs @@ -51,6 +51,8 @@ pub struct Challenge { pub max_distance: f32, } +pub const MAX_THREADS_PER_BLOCK: u32 = 1024; + impl Challenge { pub fn generate_instance( seed: &[u8; 32], @@ -75,7 +77,7 @@ impl Challenge { let generate_clusters_kernel = module.load_function("generate_clusters")?; let generate_vectors_kernel = module.load_function("generate_vectors")?; - let block_size = prop.maxThreadsPerBlock as u32; + let block_size = MAX_THREADS_PER_BLOCK; let d_seed = stream.memcpy_stod(seed).unwrap(); let mut d_cluster_means = stream @@ -181,7 +183,7 @@ impl Challenge { let mut d_total_distance = stream.alloc_zeros::(1)?; let mut errorflag = stream.alloc_zeros::(1)?; - let threads_per_block = prop.maxThreadsPerBlock as u32; + let threads_per_block = MAX_THREADS_PER_BLOCK; let blocks = (self.difficulty.num_queries as u32 + threads_per_block - 1) / threads_per_block;