#!/usr/bin/env python3
"""Compile a challenge algorithm's CUDA sources to PTX and instrument the PTX.

The script concatenates the framework, challenge and algorithm ``.cu`` files,
compiles them with ``nvcc -ptx`` and then rewrites every kernel that does not
come from the framework/challenge sources so that it

* accumulates a static "fuel" cost per basic block, and
* XORs a per-block constant into a running 64-bit signature.

It must be run from the root of tig-monorepo with the ``CHALLENGE``
environment variable set.
"""
from glob import glob
import argparse
import os
import re
import subprocess
import sys
import tempfile

# Read once at import time; validated at the start of main() so that importing
# this module (e.g. to reuse the parser) does not terminate the interpreter.
CHALLENGE = os.getenv("CHALLENGE")

# Static fuel cost charged per PTX opcode. The lookup key is the first
# whitespace-separated token of an instruction line; opcodes that are not
# listed here cost 0.
instruction_fuel_cost = {
    'add.u32': 1,
    'add.u64': 5,
    'add.f32': 1,
    'add.f64': 5,
    'add.s32': 1,
    'add.s64': 5,
    'sub.u32': 1,
    'sub.u64': 5,
    'sub.f32': 1,
    'sub.f64': 5,
    'mul.u32': 2,
    'mul.u64': 5,
    'mul.f32': 1,
    'mul.f64': 8,
    'div.u32': 10,
    'div.u64': 12,
    'div.f32': 15,
    'div.f64': 20,
    'mul.wide.u32': 6,
    'mul.wide.u64': 8,
    'mad.wide.u32': 8,
    'mad.wide.u64': 10,
    'mov.u32': 1,
    'mov.u64': 1,
    'mov.f32': 1,
    'mov.f64': 1,
    'and.b32': 1,
    'and.b64': 1,
    'or.b32': 1,
    'or.b64': 1,
    'xor.b32': 1,
    'xor.b64': 1,
    'shl.b32': 1,
    'shl.b64': 3,
    'shr.b32': 1,
    'shr.b64': 3,
    'cvt.u32.u64': 2,
    'cvt.f32.f64': 2,
    'cvt.u64.u32': 2,
    'cvt.f64.f32': 2,
    'setp.eq.u32': 1,
    'setp.eq.u64': 2,
    'setp.lt.u32': 1,
    'setp.lt.u64': 2,
    'setp.gt.u32': 1,
    'setp.gt.u64': 2,
    'setp.ne.u32': 1,
    'setp.ne.u64': 2,
    'selp.u32': 2,
    'selp.u64': 3,
    'abs.s32': 1,
    'abs.s64': 2,
    'abs.f32': 1,
    'abs.f64': 2,
    'min.u32': 1,
    'min.u64': 3,
    'min.f32': 1,
    'min.f64': 4,
    'max.u32': 1,
    'max.u64': 3,
    'max.f32': 1,
    'max.f64': 4,
    'sqrt.rn.f32': 15,
    'sqrt.rn.f64': 20,
    'rsqrt.rn.f32': 15,
    'rsqrt.rn.f64': 20,
    'sqrt.approx.ftz.f32': 8,
    'sqrt.approx.ftz.f64': 10,
    'sin.approx.f32': 8,
    'sin.approx.f64': 10,
    'cos.approx.f32': 8,
    'cos.approx.f64': 10,
    'tanh.approx.f32': 8,
    'tanh.approx.f64': 10,
    'add.f16': 1,
    'add.f16x2': 1,
    'add.bf16': 1,
    'add.bf16x2': 1,
    'fma.rn.bf16': 1,
    'fma.rn.bf16x2': 1,
    'cvt.rn.bf16.f32': 1,
    'cvt.rn.f32.bf16': 1,
    'cvt.rn.tf32.f32': 1,
    'cvt.rn.f32.tf32': 1,
    'atom.add.u32': 8,
    'atom.add.u64': 10,
    'atom.min.u32': 8,
    'atom.min.u64': 10,
    'atom.max.u32': 8,
    'atom.max.u64': 10,
    'tex.1d.v4.f32': 15,
    'tex.2d.v4.f32': 20,
    'tex.3d.v4.f32': 25,
    'ld.param.u32': 3,
    'ld.param.u64': 4,
    'st.param.u32': 3,
    'st.param.u64': 4,
    'ld.const.u32': 3,
    'ld.const.u64': 4,
    'popc.b32': 1,
    'popc.b64': 2,
    'clz.b32': 3,
    'clz.b64': 4,
    'brev.b32': 3,
    'brev.b64': 4,
    'unused': 1,
}

# Line prefixes that open a kernel/function definition in the PTX text.
_KERNEL_PREFIXES = (".visible .entry", ".func")

# Signatures are truncated to 64 bits before being XOR-ed together.
_SIG_MASK = 0xFFFFFFFFFFFFFFFF

# Matches `extern "C" __global__ <type> <name>(` and `__device__ <type> <name>(`
# and captures <name> in the group called "func".
KERNEL_REGEX = re.compile(
    r'(?:extern\s+"C"\s+__global__|__device__)\s+\w+\s+(?P<func>\w+)\s*\('
)

# Registers and initial values injected at the top of every processed kernel.
# NOTE(review): 0xa1b2c3d4e5f6a7b8 looks like a placeholder that is patched
# before the PTX is loaded, and gbl_SIGNATURE / gbl_FUELUSAGE are presumably
# globals declared by the framework source -- neither is visible in this file.
_KERNEL_PROLOGUE = (
    "\n"
    "\t.reg .u64 \tr_signature;\n"
    "\t.reg .u64 \tr_sig_addr;\n"
    "\t.reg .u64 \tr_temp_fuel;\n"
    "\t.reg .u64 \tr_fuel_usage;\n"
    "\t.reg .u64 \tr_fuel_addr;\n"
    "\t.reg .pred \tp_fuel;\n"
    "\n"
    "\tmov.u64 \tr_signature, 0xa1b2c3d4e5f6a7b8;\n"
    "\tmov.u64 \tr_sig_addr, gbl_SIGNATURE;\n"
    "\tmov.u64 \tr_temp_fuel, 0;\n"
    "\tmov.u64 \tr_fuel_usage, 0;\n"
    "\tmov.u64 \tr_fuel_addr, gbl_FUELUSAGE;\n"
)


def parse_ptx_code(ptx_code):
    """Split PTX source lines into plain lines and kernel records.

    Args:
        ptx_code: Iterable of PTX source lines (as produced by ``readlines``).

    Returns:
        A list whose items are either the original line (``str``) for text
        outside a kernel, or a dict ``{"definition": [...], "blocks": [...]}``
        for each definition opened by a line starting with ``.visible .entry``
        or ``.func``. ``definition`` holds the lines up to (not including) the
        opening ``{``. ``blocks`` is a list of non-empty line lists; a block
        ends after a ``ret;`` line or any non-comment line containing ``bra``.
        ``blocks`` stays ``None`` if no ``{`` line followed the definition.
        Blank lines inside a body and the ``{`` / ``}`` lines are dropped.
    """
    parsed = []
    kernel = None
    block = None
    for line in ptx_code:
        stripped = line.strip()
        if kernel is None:
            # Outside any kernel: either a new definition starts or the line
            # is passed through untouched.
            if stripped.startswith(_KERNEL_PREFIXES):
                kernel = {"definition": [line], "blocks": None}
                parsed.append(kernel)
            else:
                parsed.append(line)
        elif kernel["blocks"] is None:
            # Still inside the signature / parameter list.
            if stripped == "{":
                block = []
                kernel["blocks"] = []
            else:
                kernel["definition"].append(line)
        elif stripped == "}":
            # End of body: flush a trailing block that has no terminator.
            if block:
                kernel["blocks"].append(block)
            kernel = None
            block = None
        elif stripped:
            block.append(line)
            # The substring test also covers predicated branches ("@p bra").
            if stripped == "ret;" or (
                "bra" in stripped and not stripped.startswith("//")
            ):
                kernel["blocks"].append(block)
                block = []
    return parsed


def _block_accounting(block_sig, block_fuel):
    """Return the PTX that folds one block into the signature and fuel."""
    return (
        "\n"
        f"\txor.b64 \tr_signature, r_signature, 0x{block_sig:016x};\n"
        f"\tadd.u64 \tr_fuel_usage, r_fuel_usage, {block_fuel};\n"
    )


def _ret_epilogue(block_index):
    """Return the PTX emitted immediately before a ``ret;`` instruction.

    The exit label carries the block index so that a kernel containing more
    than one ``ret;`` does not define the same label twice.
    NOTE(review): 0xdeadbeefdeadbeef looks like a fuel-limit placeholder and
    gbl_ERRORSTAT a framework global -- neither is defined in this file.
    """
    label = f"$NORMAL_EXIT_{block_index}"
    return (
        "\n"
        "\tatom.global.add.u64 \tr_temp_fuel, [r_fuel_addr], r_fuel_usage;\n"
        "\tsetp.lt.u64 \tp_fuel, r_temp_fuel, 0xdeadbeefdeadbeef;\n"
        f"\t@p_fuel bra {label};\n"
        "\tst.global.u64 \t[gbl_ERRORSTAT], 1;\n"
        f"{label}:\n"
        "\tatom.global.xor.b64 \tr_sig_addr, [r_sig_addr], r_signature;\n"
        "\tatom.global.add.u64 \tr_fuel_addr, [r_fuel_addr], r_fuel_usage;\n"
    )


def inject_fuel_and_runtime_sig(parsed, kernels_to_ignore):
    """Rebuild the PTX text with fuel and signature instrumentation.

    Args:
        parsed: Output of :func:`parse_ptx_code`.
        kernels_to_ignore: Container of kernel names that are emitted without
            instrumentation.

    Returns:
        A list of strings suitable for ``file.writelines``. Plain lines and
        ignored kernels are reproduced (minus blank body lines); every other
        kernel gets a prologue, a signature/fuel update before the last line
        of each block, and an epilogue before each ``ret;``.

    A status line is printed for every kernel and every processed block.
    """
    modified_code = []
    # Running signature: XOR of hash() of every plain line, every kernel's
    # first definition line and every processed block's first line.
    # NOTE(review): hash() of a str is salted per interpreter process unless
    # PYTHONHASHSEED is fixed, so these constants differ between builds.
    block_sig = 0
    for item in parsed:
        if not isinstance(item, dict):
            block_sig ^= hash(item) & _SIG_MASK
            modified_code.append(item)
            continue

        definition = item["definition"]
        blocks = item["blocks"]
        block_sig ^= hash(definition[0]) & _SIG_MASK

        if blocks is None:
            # Definition without a body (no "{" line was seen): nothing to
            # instrument, so reproduce the lines unchanged.
            modified_code.extend(definition)
            continue

        # The name is the last token of the first definition line, cut at "(".
        name = definition[0].split()[-1].split("(")[0]

        if name in kernels_to_ignore:
            print(f"kernel: {name}, #blocks: {len(blocks)}, status: SKIPPED")
            modified_code.extend(definition)
            modified_code.append("{")
            for block in blocks:
                modified_code.extend(block)
            modified_code.append("}")
            continue

        print(f"kernel: {name}, #blocks: {len(blocks)}, status: PROCESSING")
        modified_code.extend(definition)
        modified_code.append("{")
        modified_code.append(_KERNEL_PROLOGUE)
        for i, block in enumerate(blocks):
            block_sig ^= hash(block[0]) & _SIG_MASK
            block_fuel = sum(
                instruction_fuel_cost.get(instr.split()[0], 0)
                for instr in block
            )
            print(
                f"\tblock {i}: fuel_usage: {block_fuel}, "
                f"signature: 0x{block_sig:016x}"
            )
            # The accounting must run before the block's final instruction,
            # which is the branch/return whenever the block has a terminator.
            modified_code.extend(block[:-1])
            modified_code.append(_block_accounting(block_sig, block_fuel))
            if block[-1].strip() == "ret;":
                modified_code.append(_ret_epilogue(i))
            modified_code.append(block[-1])
        modified_code.append("}")
    return modified_code


def find_kernel_names(code):
    """Return the function names matched by ``KERNEL_REGEX`` in CUDA source."""
    return [match.group('func') for match in KERNEL_REGEX.finditer(code)]


def main():
    """Build ``tig-algorithms/lib/<CHALLENGE>/ptx/<algorithm>.ptx``.

    Raises:
        FileNotFoundError: If the framework, challenge or algorithm sources
            cannot be found relative to the current directory.
        subprocess.CalledProcessError: If ``nvcc`` exits with a non-zero code.

    Exits with status 1 if the ``CHALLENGE`` environment variable is unset.
    """
    if CHALLENGE is None:
        print("CHALLENGE environment variable must be set!")
        sys.exit(1)

    parser = argparse.ArgumentParser(
        description='Compile PTX with injected runtime signature'
    )
    parser.add_argument('algorithm', help='Algorithm name')
    parser.add_argument(
        '--extra-cu-files',
        nargs='*',
        default=[],
        help='Additional .cu files to include in the compilation',
    )
    args = parser.parse_args()

    if args.extra_cu_files:
        # NOTE(review): the option is accepted but nothing below reads it.
        # Whether extra files should be instrumented or ignored is not
        # specified here, so tell the user rather than guess.
        print(
            "Warning: --extra-cu-files is currently ignored by this script",
            file=sys.stderr,
        )

    print(f"Compiling .ptx for {CHALLENGE}/{args.algorithm}")

    framework_cu = "tig-binary/src/framework.cu"
    if not os.path.exists(framework_cu):
        raise FileNotFoundError(
            f"Framework code does not exist @ '{framework_cu}'. This script must be run from the root of tig-monorepo"
        )

    challenge_cus_pattern = f"tig-challenges/src/{CHALLENGE}/**/*.cu"
    challenge_cus = glob(challenge_cus_pattern, recursive=True)
    if not challenge_cus:
        raise FileNotFoundError(
            f"Challenge code does not exist @ '{challenge_cus_pattern}'. Is the challenge name correct?"
        )

    algorithm_cus_pattern = f"tig-algorithms/src/{CHALLENGE}/{args.algorithm}/*.cu"
    algorithm_cus = glob(algorithm_cus_pattern)
    if not algorithm_cus:
        raise FileNotFoundError(
            f"Algorithm code does not exist @ '{algorithm_cus_pattern}'. Is the algorithm name correct?"
        )

    # Combine .cu source files into a temporary file
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_cu = os.path.join(temp_dir, "temp.cu")
        temp_ptx = os.path.join(temp_dir, "temp.ptx")

        with open(framework_cu, 'r') as f:
            code = f.read() + "\n"
        for cu_path in challenge_cus:
            with open(cu_path, 'r') as f:
                code += f.read() + "\n"

        # Only framework + challenge code has been read so far, so every
        # function found here is left uninstrumented; functions from the
        # algorithm sources appended below are the ones that get processed.
        kernels_to_ignore = find_kernel_names(code)

        for algorithm_cu in algorithm_cus:
            with open(algorithm_cu, 'r') as f:
                code += f.read() + "\n\n"

        with open(temp_cu, 'w') as f:
            f.write(code)

        # Compile the temporary .cu file into a .ptx file using nvcc
        nvcc_command = [
            "nvcc", "-ptx", temp_cu, "-o", temp_ptx,
            "-arch", "compute_70",
            "-code", "sm_70",
            "--use_fast_math",
            "-dopt=on",
        ]
        print(f"Running nvcc command: {' '.join(nvcc_command)}")
        subprocess.run(nvcc_command, check=True)
        print("Successfully compiled")

        print("Adding runtime signature opcodes")
        with open(temp_ptx, 'r') as f:
            ptx_code = f.readlines()

    parsed = parse_ptx_code(ptx_code)
    modified_code = inject_fuel_and_runtime_sig(parsed, kernels_to_ignore)

    output_path = f"tig-algorithms/lib/{CHALLENGE}/ptx/{args.algorithm}.ptx"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'w') as f:
        f.writelines(modified_code)
    print(f"Wrote ptx to {output_path}")
    print("Done")


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        # Top-level boundary: report any failure and signal it via exit code.
        print(f"Error: {e}")
        sys.exit(1)