tig-monorepopool/tig-binary/scripts/build_ptx
2025-05-13 16:18:03 +01:00

312 lines
9.3 KiB
Python

#!/usr/bin/env python3
import argparse
import os
import re
import shutil
import subprocess
import sys
import tempfile
# Fuel cost charged per PTX opcode when kernels are instrumented.
# The table is defined inline in this script (nothing is imported for it).
# inject_fuel_and_runtime_sig() looks up the first whitespace-delimited token
# of every line in a basic block and sums the costs; a token that is not a key
# of this table contributes 0.
# NOTE(review): keys carry no state-space qualifier (e.g. 'atom.add.u64'), so a
# qualified opcode such as 'atom.global.add.u64' is not an exact match and
# would cost 0 - confirm that is intended.
instruction_fuel_cost = {
'add.u32': 2,
'add.u64': 3,
'add.f32': 4,
'add.f64': 5,
'add.s32': 2,
'add.s64': 3,
'sub.u32': 2,
'sub.u64': 3,
'sub.f32': 4,
'sub.f64': 5,
'mul.u32': 4,
'mul.u64': 5,
'mul.f32': 5,
'mul.f64': 6,
'div.u32': 10,
'div.u64': 12,
'div.f32': 15,
'div.f64': 20,
'mul.wide.u32': 6,
'mul.wide.u64': 8,
'mad.wide.u32': 8,
'mad.wide.u64': 10,
'mov.u32': 1,
'mov.u64': 1,
'mov.f32': 1,
'mov.f64': 1,
'and.b32': 1,
'and.b64': 1,
'or.b32': 1,
'or.b64': 1,
'xor.b32': 1,
'xor.b64': 1,
'shl.b32': 2,
'shl.b64': 3,
'shr.b32': 2,
'shr.b64': 3,
'cvt.u32.u64': 2,
'cvt.f32.f64': 3,
'cvt.u64.u32': 2,
'cvt.f64.f32': 3,
'setp.eq.u32': 2,
'setp.eq.u64': 3,
'setp.lt.u32': 2,
'setp.lt.u64': 3,
'setp.gt.u32': 2,
'setp.gt.u64': 3,
'setp.ne.u32': 2,
'setp.ne.u64': 3,
'selp.u32': 3,
'selp.u64': 4,
'abs.s32': 2,
'abs.s64': 3,
'abs.f32': 3,
'abs.f64': 4,
'min.u32': 2,
'min.u64': 3,
'min.f32': 3,
'min.f64': 4,
'max.u32': 2,
'max.u64': 3,
'max.f32': 3,
'max.f64': 4,
'sqrt.rn.f32': 15,
'sqrt.rn.f64': 20,
'rsqrt.rn.f32': 15,
'rsqrt.rn.f64': 20,
'sqrt.approx.ftz.f32': 8,
'sqrt.approx.ftz.f64': 10,
'sin.approx.f32': 8,
'sin.approx.f64': 10,
'cos.approx.f32': 8,
'cos.approx.f64': 10,
'tanh.approx.f32': 8,
'tanh.approx.f64': 10,
'add.f16': 1,
'add.f16x2': 1,
'add.bf16': 1,
'add.bf16x2': 1,
'fma.rn.bf16': 1,
'fma.rn.bf16x2': 1,
'cvt.rn.bf16.f32': 1,
'cvt.rn.f32.bf16': 1,
'cvt.rn.tf32.f32': 1,
'cvt.rn.f32.tf32': 1,
'atom.add.u32': 8,
'atom.add.u64': 10,
'atom.min.u32': 8,
'atom.min.u64': 10,
'atom.max.u32': 8,
'atom.max.u64': 10,
'tex.1d.v4.f32': 15,
'tex.2d.v4.f32': 20,
'tex.3d.v4.f32': 25,
'ld.param.u32': 3,
'ld.param.u64': 4,
'st.param.u32': 3,
'st.param.u64': 4,
'ld.const.u32': 3,
'ld.const.u64': 4,
'popc.b32': 3,
'popc.b64': 4,
'clz.b32': 3,
'clz.b64': 4,
'brev.b32': 3,
'brev.b64': 4,
'unused': 1,
}
def _is_block_terminator(stripped_line):
    """Return True when *stripped_line* ends a basic block.

    A block ends at a bare ``ret;`` or at a ``bra`` instruction. The opcode
    token itself is inspected (after skipping an optional leading label such
    as ``$L1:`` and an optional predicate guard such as ``@%p1``), so an
    operand, callee or parameter name that merely contains the letters "bra"
    does not end the block.
    """
    if stripped_line == "ret;":
        return True
    tokens = stripped_line.split()
    idx = 0
    while idx < len(tokens) and (
        tokens[idx].endswith(":") or tokens[idx].startswith("@")
    ):
        idx += 1
    if idx == len(tokens):
        return False
    opcode = tokens[idx]
    return opcode == "bra" or opcode.startswith("bra.")


def parse_ptx_code(ptx_code):
    """Split PTX source lines into pass-through lines and kernel records.

    Args:
        ptx_code: Iterable of PTX source lines (line endings preserved).

    Returns:
        A list whose items are either

        * ``str`` - a line outside any recognised kernel, kept verbatim, or
        * ``dict`` - one kernel/function whose first line starts with
          ``.visible .entry`` or ``.func``, with keys ``"definition"`` (header
          lines up to, but excluding, the opening ``{``) and ``"blocks"``
          (list of basic blocks, each a list of non-blank body lines whose
          last line is the ``ret;``/``bra`` that ended it, except possibly the
          final block before ``}``).

        The lone ``{`` and ``}`` lines and blank body lines are not stored.
    """
    parsed = []
    kernel = None
    block = None
    for line in ptx_code:
        stripped_line = line.strip()
        if kernel is None:
            starts_function = stripped_line.startswith((".visible .entry", ".func"))
            # A one-line prototype ("...;") has no body to instrument.
            if starts_function and not stripped_line.endswith(";"):
                kernel = {
                    "definition": [line],
                    "blocks": None,
                }
                parsed.append(kernel)
            else:
                parsed.append(line)
        elif kernel["blocks"] is None:
            if stripped_line == "{":
                block = []
                kernel["blocks"] = []
            elif stripped_line.endswith(";"):
                # The header was terminated by ';' before any '{': this is a
                # forward declaration. Demote it to plain lines so it cannot
                # swallow the header and body of the next function.
                parsed.pop()  # the kernel dict is always the last item here
                parsed.extend(kernel["definition"])
                parsed.append(line)
                kernel = None
            else:
                kernel["definition"].append(line)
        elif stripped_line == "}":
            if block:
                kernel["blocks"].append(block)
            kernel = None
            block = None
        elif stripped_line:
            block.append(line)
            if _is_block_terminator(stripped_line):
                kernel["blocks"].append(block)
                block = []
    return parsed
def _stable_hash64(text):
    """Return a 64-bit digest of *text* that is identical in every run.

    The built-in ``hash()`` is salted per interpreter process for ``str``
    (see ``PYTHONHASHSEED``), which would make the emitted signature constants
    differ between two builds of the same input.
    """
    import hashlib  # function-scope import: not part of the file's import header

    digest = hashlib.blake2b(text.encode("utf-8"), digest_size=8).digest()
    return int.from_bytes(digest, "big")


def inject_fuel_and_runtime_sig(parsed, kernels_to_ignore, fuel_costs=None):
    """Instrument parsed PTX with fuel accounting and a running signature.

    Args:
        parsed: Items as produced by ``parse_ptx_code``: plain source lines
            (``str``) interleaved with kernel dicts holding ``"definition"``
            (header lines) and ``"blocks"`` (lists of body lines).
        kernels_to_ignore: Names of kernels/functions that are copied through
            without instrumentation.
        fuel_costs: Optional mapping from opcode token to fuel cost. Defaults
            to the module-level ``instruction_fuel_cost`` table.

    Returns:
        A list of strings that can be passed to ``file.writelines()``.

    For every instrumented kernel a register prologue is emitted after ``{``.
    Before the last line of each block, the block's signature constant is
    XOR-ed into ``r_signature`` and its fuel is added to ``r_fuel_usage``.
    Blocks ending in ``ret;`` additionally get the fuel check and the global
    signature/fuel updates. Progress is printed to stdout.
    """
    if fuel_costs is None:
        fuel_costs = instruction_fuel_cost

    # Same text the braces + prologue always produced: "{" followed by a
    # newline and the register declarations/initialisation.
    prologue = (
        "\t.reg .u64 \tr_signature;\n"
        "\t.reg .u64 \tr_sig_addr;\n"
        "\t.reg .u64 \tr_temp_fuel;\n"
        "\t.reg .u64 \tr_fuel_usage;\n"
        "\t.reg .u64 \tr_fuel_addr;\n"
        "\t.reg .pred \tp_fuel;\n"
        "\tmov.u64 \tr_signature, 0xa1b2c3d4e5f6a7b8;\n"
        "\tmov.u64 \tr_sig_addr, gbl_SIGNATURE;\n"
        "\tmov.u64 \tr_temp_fuel, 0;\n"
        "\tmov.u64 \tr_fuel_usage, 0;\n"
        "\tmov.u64 \tr_fuel_addr, gbl_FUELUSAGE;\n"
    )

    modified_code = []
    block_sig = 0  # running XOR; carried across kernels, never reset
    for item in parsed:
        if not isinstance(item, dict):
            block_sig ^= _stable_hash64(item)
            modified_code.append(item)
            continue

        kernel = item
        header = kernel["definition"]
        block_sig ^= _stable_hash64(header[0])
        # The name is the last token of the first header line, up to the first "(".
        name = header[0].split()[-1].split("(")[0]

        blocks = kernel["blocks"]
        if blocks is None:
            # Header without a body (no "{" was ever seen): nothing to
            # instrument, so pass the lines through instead of crashing.
            modified_code.extend(header)
            continue

        if name in kernels_to_ignore:
            print(f"kernel: {name}, #blocks: {len(blocks)}, status: SKIPPED")
            modified_code.extend(header)
            # Braces carry their own newline so they are not glued to the
            # neighbouring line when the result is written with writelines().
            modified_code.append("{\n")
            for block in blocks:
                modified_code.extend(block)
            modified_code.append("}\n")
            continue

        print(f"kernel: {name}, #blocks: {len(blocks)}, status: PROCESSING")
        modified_code.extend(header)
        modified_code.append("{\n")
        modified_code.append(prologue)
        for i, block in enumerate(blocks):
            block_sig ^= _stable_hash64(block[0])
            # NOTE(review): the lookup key is the first whitespace-delimited
            # token, so a guarded instruction ("@%p1 add.u32 ...") looks up
            # the guard and costs 0 - confirm that is intended.
            block_fuel = sum(
                fuel_costs.get(instr.split()[0], 0)
                for instr in block
            )
            print(f"\tblock {i}: fuel_usage: {block_fuel}, signature: 0x{block_sig:016x}")
            # Instrumentation goes before the block's last line (its ret/bra).
            modified_code.extend(block[:-1])
            modified_code.append(
                "\n"
                f"\txor.b64 \tr_signature, r_signature, 0x{block_sig:016x};\n"
                f"\tadd.u64 \tr_fuel_usage, r_fuel_usage, {block_fuel};\n"
            )
            if block[-1].strip() == "ret;":
                # One label per ret block: a kernel with several "ret;" would
                # otherwise define the same label more than once.
                exit_label = f"$NORMAL_EXIT_{i}"
                # NOTE(review): r_fuel_usage is atomically added to
                # [r_fuel_addr] twice (before the limit check and again after
                # the label) - confirm the double charge is intended.
                modified_code.append(
                    "\n"
                    "\tatom.global.add.u64 \tr_temp_fuel, [r_fuel_addr], r_fuel_usage;\n"
                    "\tsetp.lt.u64 \tp_fuel, r_temp_fuel, 0xdeadbeefdeadbeef;\n"
                    f"\t@p_fuel bra {exit_label};\n"
                    "\tst.global.u64 \t[gbl_ERRORSTAT], 1;\n"
                    f"{exit_label}:\n"
                    "\tatom.global.xor.b64 \tr_sig_addr, [r_sig_addr], r_signature;\n"
                    "\tatom.global.add.u64 \tr_fuel_addr, [r_fuel_addr], r_fuel_usage;\n"
                )
            modified_code.append(block[-1])
        modified_code.append("}\n")
    return modified_code
def main():
    """Build the instrumented ``.ptx`` for one challenge/algorithm pair.

    Reads the challenge and algorithm names from the command line, checks that
    the framework, challenge and algorithm ``.cu`` sources exist (paths are
    relative to the current directory), concatenates them into a temporary
    file, compiles it with ``nvcc -ptx``, instruments the result with
    ``parse_ptx_code`` / ``inject_fuel_and_runtime_sig`` and writes it to
    ``tig-algorithms/lib/<challenge>/ptx/<algorithm>.ptx``.

    Raises:
        FileNotFoundError: If one of the required source files is missing.
        subprocess.CalledProcessError: If ``nvcc`` exits with a non-zero status.
    """
    cli = argparse.ArgumentParser(description='Compile PTX with injected runtime signature')
    cli.add_argument('challenge', help='Challenge name')
    cli.add_argument('algorithm', help='Algorithm name')
    args = cli.parse_args()
    challenge = args.challenge
    algorithm = args.algorithm

    print(f"Compiling .ptx for {challenge}/{algorithm}")

    framework_cu = "tig-binary/src/framework.cu"
    if not os.path.exists(framework_cu):
        raise FileNotFoundError(
            f"Framework code does not exist @ '{framework_cu}'. This script must be run from the root of tig-monorepo"
        )

    challenge_cu = f"tig-challenges/src/{challenge}.cu"
    if not os.path.exists(challenge_cu):
        raise FileNotFoundError(
            f"Challenge code does not exist @ '{challenge_cu}'. Is the challenge name correct?"
        )

    # The algorithm may be a single file or a directory holding
    # benchmarker_outbound.cu; the single file wins when both exist.
    single_file_cu = f"tig-algorithms/src/{challenge}/{algorithm}.cu"
    directory_cu = f"tig-algorithms/src/{challenge}/{algorithm}/benchmarker_outbound.cu"
    if os.path.exists(single_file_cu):
        algorithm_cu = single_file_cu
    elif os.path.exists(directory_cu):
        algorithm_cu = directory_cu
    else:
        raise FileNotFoundError(
            f"Algorithm code does not exist @ '{single_file_cu}' or '{directory_cu}'. Is the algorithm name correct?"
        )

    with tempfile.TemporaryDirectory() as work_dir:
        temp_cu = os.path.join(work_dir, "temp.cu")
        temp_ptx = os.path.join(work_dir, "temp.ptx")

        # Framework first, then challenge, each followed by a newline.
        pieces = []
        for source_path in (framework_cu, challenge_cu):
            with open(source_path, 'r') as src:
                pieces.append(src.read() + "\n")
        code = "".join(pieces)

        # Functions declared by the framework/challenge sources are collected
        # before the algorithm is appended, so only those are left untouched.
        kernel_regex = r'(?:extern\s+"C"\s+__global__|__device__)\s+\w+\s+(?P<func>\w+)\s*\('
        kernels_to_ignore = []
        for found in re.finditer(kernel_regex, code):
            kernels_to_ignore.append(found.group('func'))

        with open(algorithm_cu, 'r') as src:
            code += src.read()
        with open(temp_cu, 'w') as dst:
            dst.write(code)

        # Compile the combined source to PTX.
        nvcc_command = [
            "nvcc", "-ptx", temp_cu, "-o", temp_ptx,
            "-arch", "compute_70",
            "-code", "sm_70",
            "--use_fast_math",
            "-dopt=on",
        ]
        print(f"Running nvcc command: {' '.join(nvcc_command)}")
        subprocess.run(nvcc_command, check=True)
        print("Successfully compiled")

        print("Adding runtime signature opcodes")
        with open(temp_ptx, 'r') as src:
            ptx_code = src.readlines()
        modified_code = inject_fuel_and_runtime_sig(
            parse_ptx_code(ptx_code),
            kernels_to_ignore,
        )

        output_path = f"tig-algorithms/lib/{challenge}/ptx/{algorithm}.ptx"
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        with open(output_path, 'w') as dst:
            dst.writelines(modified_code)
        print(f"Wrote ptx to {output_path}")
        print("Done")
if __name__ == "__main__":
    # Script entry point: any failure is reported as a one-line message and
    # the process exits with status 1.
    try:
        main()
    except Exception as err:
        print(f"Error: {err}")
        raise SystemExit(1)