mirror of
https://github.com/tig-pool-nk/tig-monorepo.git
synced 2026-02-21 16:37:22 +08:00
318 lines
9.5 KiB
Python
318 lines
9.5 KiB
Python
#!/usr/bin/env python3
|
|
|
|
from glob import glob
|
|
import argparse
|
|
import os
|
|
import re
|
|
import subprocess
|
|
import sys
|
|
import tempfile
|
|
|
|
# Name of the challenge being built, taken from the environment.
# The script cannot do anything useful without it, so stop right away
# (exit status 1) when the variable is not defined.
CHALLENGE = os.environ.get("CHALLENGE")
if CHALLENGE is None:
    print("CHALLENGE environment variable must be set!")
    raise SystemExit(1)
|
|
|
|
# Fuel cost charged per PTX opcode (defined inline in this script).
# Keys are the full opcode token as it appears at the start of a PTX
# instruction line; opcodes that are not listed are charged 0 by the
# lookup in inject_fuel_and_runtime_sig.
instruction_fuel_cost = {
    # Addition / subtraction
    "add.u32": 2, "add.u64": 3, "add.f32": 4, "add.f64": 5,
    "add.s32": 2, "add.s64": 3,
    "sub.u32": 2, "sub.u64": 3, "sub.f32": 4, "sub.f64": 5,
    # Multiplication / division
    "mul.u32": 4, "mul.u64": 5, "mul.f32": 5, "mul.f64": 6,
    "div.u32": 10, "div.u64": 12, "div.f32": 15, "div.f64": 20,
    # Widening multiply and multiply-add
    "mul.wide.u32": 6, "mul.wide.u64": 8,
    "mad.wide.u32": 8, "mad.wide.u64": 10,
    # Register moves
    "mov.u32": 1, "mov.u64": 1, "mov.f32": 1, "mov.f64": 1,
    # Bitwise logic
    "and.b32": 1, "and.b64": 1,
    "or.b32": 1, "or.b64": 1,
    "xor.b32": 1, "xor.b64": 1,
    # Shifts
    "shl.b32": 2, "shl.b64": 3,
    "shr.b32": 2, "shr.b64": 3,
    # Conversions
    "cvt.u32.u64": 2, "cvt.f32.f64": 3,
    "cvt.u64.u32": 2, "cvt.f64.f32": 3,
    # Predicate set / select
    "setp.eq.u32": 2, "setp.eq.u64": 3,
    "setp.lt.u32": 2, "setp.lt.u64": 3,
    "setp.gt.u32": 2, "setp.gt.u64": 3,
    "setp.ne.u32": 2, "setp.ne.u64": 3,
    "selp.u32": 3, "selp.u64": 4,
    # abs / min / max
    "abs.s32": 2, "abs.s64": 3, "abs.f32": 3, "abs.f64": 4,
    "min.u32": 2, "min.u64": 3, "min.f32": 3, "min.f64": 4,
    "max.u32": 2, "max.u64": 3, "max.f32": 3, "max.f64": 4,
    # Square roots
    "sqrt.rn.f32": 15, "sqrt.rn.f64": 20,
    "rsqrt.rn.f32": 15, "rsqrt.rn.f64": 20,
    "sqrt.approx.ftz.f32": 8, "sqrt.approx.ftz.f64": 10,
    # Approximate transcendentals
    "sin.approx.f32": 8, "sin.approx.f64": 10,
    "cos.approx.f32": 8, "cos.approx.f64": 10,
    "tanh.approx.f32": 8, "tanh.approx.f64": 10,
    # Half / bfloat16 / tf32 arithmetic and conversions
    "add.f16": 1, "add.f16x2": 1,
    "add.bf16": 1, "add.bf16x2": 1,
    "fma.rn.bf16": 1, "fma.rn.bf16x2": 1,
    "cvt.rn.bf16.f32": 1, "cvt.rn.f32.bf16": 1,
    "cvt.rn.tf32.f32": 1, "cvt.rn.f32.tf32": 1,
    # Atomics
    "atom.add.u32": 8, "atom.add.u64": 10,
    "atom.min.u32": 8, "atom.min.u64": 10,
    "atom.max.u32": 8, "atom.max.u64": 10,
    # Texture fetches
    "tex.1d.v4.f32": 15, "tex.2d.v4.f32": 20, "tex.3d.v4.f32": 25,
    # Parameter / constant space loads and stores
    "ld.param.u32": 3, "ld.param.u64": 4,
    "st.param.u32": 3, "st.param.u64": 4,
    "ld.const.u32": 3, "ld.const.u64": 4,
    # Bit counting / reversal
    "popc.b32": 3, "popc.b64": 4,
    "clz.b32": 3, "clz.b64": 4,
    "brev.b32": 3, "brev.b64": 4,
    # Placeholder entry kept from the original table
    "unused": 1,
}
|
|
|
|
def parse_ptx_code(ptx_code):
    """Split PTX source lines into pass-through lines and kernel records.

    Lines outside a function body are copied to the result unchanged.
    A line whose stripped text starts with ``.visible .entry`` or ``.func``
    opens a kernel record::

        {"definition": [header lines...], "blocks": [[body lines...], ...]}

    ``definition`` collects every line from the header up to (but not
    including) the line that is exactly ``{``.  ``blocks`` stays ``None``
    until that ``{`` is seen; afterwards it holds the body split into
    blocks, each block ending with a ``ret;`` line or a ``bra`` branch
    instruction (a trailing block without either is kept as well).  Blank
    body lines and the brace lines themselves are not stored.  A line that
    is exactly ``}`` closes the kernel.

    Args:
        ptx_code: iterable of PTX source lines (as produced by
            ``readlines()``).

    Returns:
        A list whose items are either the original ``str`` lines or kernel
        ``dict`` records, in source order.
    """
    # Match ``bra`` / ``bra.uni`` only as a standalone opcode token
    # (optionally after a predicate or a label).  A plain substring test
    # also fired on identifiers such as ``zebra_kernel_param_0`` and on
    # trailing comments, splitting blocks at non-branch instructions.
    branch_re = re.compile(r"(?:^|[\s:])bra(?:\.\w+)*(?=\s|;|$)")

    parsed = []
    kernel = None
    block = None
    for line in ptx_code:
        stripped_line = line.strip()

        if kernel is None:
            # Outside any function: either a new header or a plain line.
            if stripped_line.startswith((".visible .entry", ".func")):
                kernel = {"definition": [line], "blocks": None}
                parsed.append(kernel)
            else:
                parsed.append(line)
        elif kernel["blocks"] is None:
            # Inside the header, waiting for the opening brace.
            if stripped_line == "{":
                block = []
                kernel["blocks"] = []
            else:
                kernel["definition"].append(line)
        elif stripped_line == "}":
            # End of body: flush a pending block that had no terminator.
            if block:
                kernel["blocks"].append(block)
            kernel = None
            block = None
        elif stripped_line:
            block.append(line)
            # Ignore anything after ``//`` when looking for a branch.
            code_part = stripped_line.split("//", 1)[0]
            if stripped_line == "ret;" or branch_re.search(code_part):
                kernel["blocks"].append(block)
                block = []
    return parsed
|
|
|
|
def inject_fuel_and_runtime_sig(parsed, kernels_to_ignore):
    """Rebuild PTX text with fuel accounting and signature code injected.

    Walks the output of ``parse_ptx_code``.  Plain lines are copied through.
    Kernels whose name is in ``kernels_to_ignore`` are re-emitted without
    changes.  Every other kernel gets:

    * a prologue declaring and initialising the ``r_signature``,
      ``r_sig_addr``, ``r_temp_fuel``, ``r_fuel_usage``, ``r_fuel_addr``
      registers and the ``p_fuel`` predicate;
    * before the last line of each block, an ``xor`` of the running
      signature into ``r_signature`` and an ``add`` of the block's fuel
      cost (sum of ``instruction_fuel_cost`` for the first token of each
      line, 0 when unlisted) into ``r_fuel_usage``;
    * before each ``ret;``, an epilogue that updates the global fuel
      counter, flags ``gbl_ERRORSTAT`` when the comparison against
      ``0xdeadbeefdeadbeef`` fails, and folds the signature into
      ``gbl_SIGNATURE``.

    The emitted PTX refers to ``gbl_SIGNATURE``, ``gbl_FUELUSAGE`` and
    ``gbl_ERRORSTAT``; these are presumably declared by the framework
    source compiled alongside — TODO confirm.

    Args:
        parsed: list of ``str`` lines and kernel ``dict`` records as
            returned by ``parse_ptx_code``.
        kernels_to_ignore: collection of kernel names to leave untouched.

    Returns:
        A list of strings to be written out with ``writelines``.
    """
    mask_64 = 0xFFFFFFFFFFFFFFFF
    prologue = (
        "\n"
        "\t.reg .u64 \tr_signature;\n"
        "\t.reg .u64 \tr_sig_addr;\n"
        "\t.reg .u64 \tr_temp_fuel;\n"
        "\t.reg .u64 \tr_fuel_usage;\n"
        "\t.reg .u64 \tr_fuel_addr;\n"
        "\t.reg .pred \tp_fuel;\n"
        "\tmov.u64 \tr_signature, 0xa1b2c3d4e5f6a7b8;\n"
        "\tmov.u64 \tr_sig_addr, gbl_SIGNATURE;\n"
        "\tmov.u64 \tr_temp_fuel, 0;\n"
        "\tmov.u64 \tr_fuel_usage, 0;\n"
        "\tmov.u64 \tr_fuel_addr, gbl_FUELUSAGE;\n"
    )

    modified_code = []
    # Running signature, carried across all lines and kernels of the file.
    # NOTE(review): built-in hash() of a str is randomised per interpreter
    # run unless PYTHONHASHSEED is fixed, so these values differ between
    # builds of the same input — confirm that this is intended.
    block_sig = 0

    for line in parsed:
        if not isinstance(line, dict):
            block_sig ^= hash(line) & mask_64
            modified_code.append(line)
            continue

        kernel = line
        block_sig ^= hash(kernel["definition"][0]) & mask_64
        name = (
            kernel["definition"][0]  # func sig in first line
            .split()[-1]             # func name is last token
            .split("(")[0]           # func name is before the first (
        )
        if name in kernels_to_ignore:
            print(f"kernel: {name}, #blocks: {len(kernel['blocks'])}, status: SKIPPED")
            modified_code.extend(kernel["definition"])
            modified_code.append("{")
            for block in kernel["blocks"]:
                modified_code.extend(block)
            modified_code.append("}")
            continue

        print(f"kernel: {name}, #blocks: {len(kernel['blocks'])}, status: PROCESSING")
        modified_code.extend(kernel["definition"])
        modified_code.append("{")
        modified_code.append(prologue)

        # Number of exit epilogues emitted so far in this kernel.  The first
        # keeps the historical label ``$NORMAL_EXIT``; later ones get a
        # numeric suffix so a kernel with several ``ret;`` statements does
        # not define the same label twice.
        exit_count = 0
        for i, block in enumerate(kernel["blocks"]):
            block_sig ^= hash(block[0]) & mask_64
            block_fuel = sum(
                instruction_fuel_cost.get(instr.split()[0], 0)
                for instr in block
            )
            print(f"\tblock {i}: fuel_usage: {block_fuel}, signature: 0x{block_sig:016x}")

            # The accounting code goes just before the block terminator so
            # it still executes when the terminator is a branch or return.
            modified_code.extend(block[:-1])
            modified_code.append(
                "\n"
                f"\txor.b64 \tr_signature, r_signature, 0x{block_sig:016x};\n"
                f"\tadd.u64 \tr_fuel_usage, r_fuel_usage, {block_fuel};\n"
            )
            if block[-1].strip() == "ret;":
                if exit_count == 0:
                    exit_label = "$NORMAL_EXIT"
                else:
                    exit_label = f"$NORMAL_EXIT_{exit_count}"
                exit_count += 1
                modified_code.append(
                    "\n"
                    "\tatom.global.add.u64 \tr_temp_fuel, [r_fuel_addr], r_fuel_usage;\n"
                    "\tsetp.lt.u64 \tp_fuel, r_temp_fuel, 0xdeadbeefdeadbeef;\n"
                    f"\t@p_fuel bra {exit_label};\n"
                    "\tst.global.u64 \t[gbl_ERRORSTAT], 1;\n"
                    f"{exit_label}:\n"
                    "\tatom.global.xor.b64 \tr_sig_addr, [r_sig_addr], r_signature;\n"
                    "\tatom.global.add.u64 \tr_fuel_addr, [r_fuel_addr], r_fuel_usage;\n"
                )
            modified_code.append(block[-1])
        modified_code.append("}")
    return modified_code
|
|
|
|
def main():
    """Compile an algorithm's CUDA sources to PTX and instrument the result.

    Steps, in order:

    1. Parse the command line (``algorithm`` plus optional
       ``--extra-cu-files``).
    2. Check that the framework, challenge and algorithm ``.cu`` sources
       exist relative to the current directory.
    3. Concatenate framework, challenge, extra and algorithm sources into
       one temporary ``.cu`` file.  Functions declared in everything but
       the algorithm sources are recorded so they are not instrumented.
    4. Run ``nvcc -ptx`` on the combined file.
    5. Parse the PTX, inject fuel / signature code and write the result to
       ``tig-algorithms/lib/<CHALLENGE>/ptx/<algorithm>.ptx``.

    Raises:
        FileNotFoundError: if any required or extra source file is missing.
        subprocess.CalledProcessError: if ``nvcc`` exits with a non-zero
            status.
    """
    parser = argparse.ArgumentParser(description='Compile PTX with injected runtime signature')
    parser.add_argument('algorithm', help='Algorithm name')
    parser.add_argument('--extra-cu-files', nargs='*', default=[], help='Additional .cu files to include in the compilation')

    args = parser.parse_args()

    print(f"Compiling .ptx for {CHALLENGE}/{args.algorithm}")

    framework_cu = "tig-binary/src/framework.cu"
    if not os.path.exists(framework_cu):
        raise FileNotFoundError(
            f"Framework code does not exist @ '{framework_cu}'. This script must be run from the root of tig-monorepo"
        )

    challenge_cu = f"tig-challenges/src/{CHALLENGE}.cu"
    if not os.path.exists(challenge_cu):
        raise FileNotFoundError(
            f"Challenge code does not exist @ '{challenge_cu}'. Is the challenge name correct?"
        )

    algorithm_glob = f"tig-algorithms/src/{CHALLENGE}/{args.algorithm}/*.cu"
    algorithm_cus = glob(algorithm_glob)
    if not algorithm_cus:
        # Report the pattern that was searched; the (empty) match list
        # would only ever render as '[]' here.
        raise FileNotFoundError(
            f"Algorithm code does not exist @ '{algorithm_glob}'. Is the algorithm name correct?"
        )

    # Combine .cu source files into a temporary file
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_cu = os.path.join(temp_dir, "temp.cu")
        temp_ptx = os.path.join(temp_dir, "temp.ptx")

        with open(framework_cu, 'r') as f:
            code = f.read() + "\n"
        with open(challenge_cu, 'r') as f:
            code += f.read() + "\n"
        for extra_cu in args.extra_cu_files:
            if not os.path.exists(extra_cu):
                raise FileNotFoundError(f"Extra .cu file does not exist: {extra_cu}")
            with open(extra_cu, 'r') as f:
                code += f.read() + "\n"

        # Collect function names from the non-algorithm sources gathered so
        # far; these are passed on as kernels that must not be instrumented.
        kernel_regex = r'(?:extern\s+"C"\s+__global__|__device__)\s+\w+\s+(?P<func>\w+)\s*\('
        kernels_to_ignore = [match.group('func') for match in re.finditer(kernel_regex, code)]

        for algorithm_cu in algorithm_cus:
            with open(algorithm_cu, 'r') as f:
                code += f.read() + "\n\n"
        with open(temp_cu, 'w') as f:
            f.write(code)

        # Compile the temporary .cu file into a .ptx file using nvcc
        nvcc_command = [
            "nvcc", "-ptx", temp_cu, "-o", temp_ptx,
            "-arch", "compute_70",
            "-code", "sm_70",
            "--use_fast_math",
            "-dopt=on"
        ]

        print(f"Running nvcc command: {' '.join(nvcc_command)}")
        subprocess.run(nvcc_command, check=True)
        print("Successfully compiled")

        print("Adding runtime signature opcodes")
        with open(temp_ptx, 'r') as f:
            ptx_code = f.readlines()
        parsed = parse_ptx_code(ptx_code)
        modified_code = inject_fuel_and_runtime_sig(parsed, kernels_to_ignore)

        output_path = f"tig-algorithms/lib/{CHALLENGE}/ptx/{args.algorithm}.ptx"
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        with open(output_path, 'w') as f:
            f.writelines(modified_code)
        print(f"Wrote ptx to {output_path}")
        print("Done")
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the build and turn any failure into a short
    # one-line message plus exit status 1 instead of a traceback.
    try:
        main()
    except Exception as exc:
        print(f"Error: {exc}")
        raise SystemExit(1)