mirror of
https://github.com/tig-pool-nk/tig-monorepo.git
synced 2026-02-21 18:07:22 +08:00
Add tests to ensure python and rust versions are the same.
This commit is contained in:
parent
3cb372f9d5
commit
46e44fa3ee
4
.gitignore
vendored
4
.gitignore
vendored
@ -13,4 +13,6 @@ Cargo.lock
|
||||
# MSVC Windows builds of rustc generate these, which store debugging information
|
||||
*.pdb
|
||||
|
||||
.vscode/
|
||||
.vscode/
|
||||
|
||||
__pycache__/
|
||||
1
tig-benchmarker/.gitignore
vendored
1
tig-benchmarker/.gitignore
vendored
@ -1 +0,0 @@
|
||||
__pycache__
|
||||
@ -1,307 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Set, Any, Tuple
|
||||
|
||||
Point = Tuple[int, ...]
|
||||
Frontier = Set[Point]
|
||||
|
||||
@dataclass
|
||||
class AlgorithmDetails:
|
||||
name: str
|
||||
player_id: str
|
||||
challenge_id: str
|
||||
tx_hash: str
|
||||
|
||||
@dataclass
|
||||
class AlgorithmState:
|
||||
block_confirmed: Optional[int] = None
|
||||
round_submitted: Optional[int] = None
|
||||
round_pushed: Optional[int] = None
|
||||
round_merged: Optional[int] = None
|
||||
banned: bool = False
|
||||
|
||||
@dataclass
|
||||
class AlgorithmBlockData:
|
||||
num_qualifiers_by_player: Optional[Dict[str, int]] = None
|
||||
adoption: Optional[int] = None
|
||||
merge_points: Optional[int] = None
|
||||
reward: Optional[int] = None
|
||||
round_earnings: Optional[int] = None
|
||||
|
||||
@dataclass
|
||||
class Algorithm:
|
||||
id: str
|
||||
details: AlgorithmDetails
|
||||
state: Optional[AlgorithmState] = None
|
||||
block_data: Optional[AlgorithmBlockData] = None
|
||||
code: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> "Algorithm":
|
||||
data = d.pop("block_data")
|
||||
return cls(
|
||||
id=d.pop("id"),
|
||||
details=AlgorithmDetails(**d.pop("details")),
|
||||
state=AlgorithmState(**d.pop("state")),
|
||||
block_data=AlgorithmBlockData(**data) if data else None,
|
||||
code=d.pop("code", None)
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class BenchmarkSettings:
|
||||
player_id: str
|
||||
block_id: str
|
||||
challenge_id: str
|
||||
algorithm_id: str
|
||||
difficulty: List[int]
|
||||
|
||||
@dataclass
|
||||
class BenchmarkDetails:
|
||||
block_started: int
|
||||
num_solutions: int
|
||||
|
||||
@dataclass
|
||||
class BenchmarkState:
|
||||
block_confirmed: Optional[int] = None
|
||||
sampled_nonces: Optional[List[int]] = None
|
||||
|
||||
@dataclass
|
||||
class SolutionMetaData:
|
||||
nonce: int
|
||||
solution_signature: int
|
||||
|
||||
@dataclass
|
||||
class SolutionData:
|
||||
nonce: int
|
||||
runtime_signature: int
|
||||
fuel_consumed: int
|
||||
solution: Dict[str, Any]
|
||||
|
||||
@dataclass
|
||||
class Benchmark:
|
||||
id: str
|
||||
settings: BenchmarkSettings
|
||||
details: BenchmarkDetails
|
||||
state: Optional[BenchmarkState] = None
|
||||
solutions_meta_data: Optional[List[SolutionMetaData]] = None
|
||||
solution_data: Optional[SolutionData] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> "Benchmark":
|
||||
solution_data = d.pop("solution_data", None)
|
||||
solutions_meta_data = d.pop("solutions_meta_data")
|
||||
return cls(
|
||||
id=d.pop("id"),
|
||||
settings=BenchmarkSettings(**d.pop("settings")),
|
||||
details=BenchmarkDetails(**d.pop("details")),
|
||||
state=BenchmarkState(**d.pop("state")),
|
||||
solutions_meta_data=[SolutionMetaData(**s) for s in solutions_meta_data] if solutions_meta_data else None,
|
||||
solution_data=SolutionData(**solution_data) if solution_data else None
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class BlockDetails:
|
||||
prev_block_id: str
|
||||
height: int
|
||||
round: int
|
||||
eth_block_num: Optional[str] = None
|
||||
|
||||
@dataclass
|
||||
class BlockData:
|
||||
mempool_challenge_ids: Set[str] = field(default_factory=set)
|
||||
mempool_algorithm_ids: Set[str] = field(default_factory=set)
|
||||
mempool_benchmark_ids: Set[str] = field(default_factory=set)
|
||||
mempool_proof_ids: Set[str] = field(default_factory=set)
|
||||
mempool_fraud_ids: Set[str] = field(default_factory=set)
|
||||
mempool_wasm_ids: Set[str] = field(default_factory=set)
|
||||
active_challenge_ids: Set[str] = field(default_factory=set)
|
||||
active_algorithm_ids: Set[str] = field(default_factory=set)
|
||||
active_benchmark_ids: Set[str] = field(default_factory=set)
|
||||
active_player_ids: Set[str] = field(default_factory=set)
|
||||
|
||||
@dataclass
|
||||
class Block:
|
||||
id: str
|
||||
details: BlockDetails
|
||||
config: dict
|
||||
data: Optional[BlockData] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> "Block":
|
||||
data = d.pop("data", None)
|
||||
return cls(
|
||||
id=d.pop("id"),
|
||||
details=BlockDetails(**d.pop("details")),
|
||||
config=d.pop("config"),
|
||||
data=BlockData(**data) if data else None
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class ChallengeDetails:
|
||||
name: str
|
||||
|
||||
@dataclass
|
||||
class ChallengeState:
|
||||
block_confirmed: Optional[int] = None
|
||||
round_active: Optional[int] = None
|
||||
|
||||
@dataclass
|
||||
class ChallengeBlockData:
|
||||
solution_signature_threshold: Optional[int] = None
|
||||
num_qualifiers: Optional[int] = None
|
||||
qualifier_difficulties: Optional[Set[Point]] = None
|
||||
base_frontier: Optional[Frontier] = None
|
||||
cutoff_frontier: Optional[Frontier] = None
|
||||
scaled_frontier: Optional[Frontier] = None
|
||||
scaling_factor: Optional[float] = None
|
||||
|
||||
@dataclass
|
||||
class Challenge:
|
||||
id: str
|
||||
details: ChallengeDetails
|
||||
state: Optional[ChallengeState] = None
|
||||
block_data: Optional[ChallengeBlockData] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> "Challenge":
|
||||
block_data = d.pop("block_data", None)
|
||||
return cls(
|
||||
id=d.pop("id"),
|
||||
details=ChallengeDetails(**d.pop("details")),
|
||||
state=ChallengeState(**d.pop("state")),
|
||||
block_data=ChallengeBlockData(**block_data) if block_data else None
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class PlayerDetails:
|
||||
name: str
|
||||
is_multisig: bool
|
||||
|
||||
@dataclass
|
||||
class PlayerBlockData:
|
||||
num_qualifiers_by_challenge: Optional[Dict[str, int]] = None
|
||||
cutoff: Optional[int] = None
|
||||
deposit: Optional[int] = None
|
||||
rolling_deposit: Optional[int] = None
|
||||
imbalance: Optional[int] = None
|
||||
imbalance_penalty: Optional[int] = None
|
||||
influence: Optional[int] = None
|
||||
reward: Optional[int] = None
|
||||
round_earnings: Optional[int] = None
|
||||
qualifying_percent_rolling_deposit: Optional[int] = None
|
||||
|
||||
@dataclass
|
||||
class Player:
|
||||
id: str
|
||||
details: PlayerDetails
|
||||
block_data: Optional[PlayerBlockData] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> "Player":
|
||||
data = d.pop("block_data")
|
||||
return cls(
|
||||
id=d.pop("id"),
|
||||
details=PlayerDetails(**d.pop("details")),
|
||||
block_data=PlayerBlockData(**data) if data else None
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class ProofState:
|
||||
block_confirmed: Optional[int] = None
|
||||
submission_delay: Optional[int] = None
|
||||
|
||||
@dataclass
|
||||
class Proof:
|
||||
benchmark_id: str
|
||||
state: Optional[ProofState] = None
|
||||
solutions_data: Optional[List[SolutionData]] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> "Proof":
|
||||
solutions_data = d.pop("solutions_data")
|
||||
return cls(
|
||||
benchmark_id=d.pop("benchmark_id"),
|
||||
state=ProofState(**d.pop("state")),
|
||||
solutions_data=[SolutionData(**s) for s in solutions_data] if solutions_data else None
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class FraudState:
|
||||
block_confirmed: Optional[int] = None
|
||||
|
||||
@dataclass
|
||||
class Fraud:
|
||||
benchmark_id: str
|
||||
state: Optional[FraudState] = None
|
||||
allegation: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> "Fraud":
|
||||
return cls(
|
||||
benchmark_id=d.pop("benchmark_id"),
|
||||
state=FraudState(**d.pop("state")),
|
||||
allegation=d.pop("allegation", None)
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class WasmDetails:
|
||||
compile_success: bool
|
||||
download_url: Optional[str] = None
|
||||
checksum: Optional[str] = None
|
||||
|
||||
@dataclass
|
||||
class WasmState:
|
||||
block_confirmed: Optional[int] = None
|
||||
|
||||
@dataclass
|
||||
class Wasm:
|
||||
algorithm_id: str
|
||||
details: WasmDetails
|
||||
state: Optional[WasmState] = None
|
||||
wasm_blob: Optional[bytes] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> "Wasm":
|
||||
return cls(
|
||||
algorithm_id=d.pop("algorithm_id"),
|
||||
details=WasmDetails(**d.pop("details")),
|
||||
state=WasmState(**d.pop("state")),
|
||||
wasm_blob=d.pop("wasm_blob", None)
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class QueryData:
|
||||
block: Block
|
||||
algorithms: Dict[str, Algorithm]
|
||||
wasms: Dict[str, Wasm]
|
||||
player: Optional[Player]
|
||||
benchmarks: Dict[str, Benchmark]
|
||||
proofs: Dict[str, Proof]
|
||||
frauds: Dict[str, Fraud]
|
||||
challenges: Dict[str, Challenge]
|
||||
|
||||
@dataclass
|
||||
class Timestamps:
|
||||
start: int
|
||||
end: int
|
||||
submit: int
|
||||
|
||||
@dataclass
|
||||
class Job:
|
||||
download_url: str
|
||||
benchmark_id: str
|
||||
settings: BenchmarkSettings
|
||||
solution_signature_threshold: int
|
||||
sampled_nonces: Optional[List[int]]
|
||||
wasm_vm_config: dict
|
||||
weight: float
|
||||
timestamps: Timestamps
|
||||
solutions_data: Dict[int, SolutionData]
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
query_data: QueryData
|
||||
available_jobs: Dict[str, Job]
|
||||
pending_benchmark_jobs: Dict[str, Job]
|
||||
pending_proof_jobs: Dict[str, Job]
|
||||
submitted_proof_ids: Set[str]
|
||||
difficulty_samplers: dict
|
||||
@ -1,9 +0,0 @@
|
||||
from hashlib import md5
|
||||
from datetime import datetime
|
||||
|
||||
def now() -> int:
|
||||
return int(datetime.now().timestamp() * 1000)
|
||||
|
||||
def u32_from_str(input_str: str) -> int:
|
||||
result = md5(input_str.encode('utf-8')).digest()
|
||||
return int.from_bytes(result[-4:], byteorder='little', signed=False)
|
||||
@ -2,4 +2,5 @@ aiohttp
|
||||
asyncio
|
||||
quart
|
||||
hypercorn
|
||||
dataclasses
|
||||
dataclasses
|
||||
randomname
|
||||
112
tig-benchmarker/slave.py
Normal file
112
tig-benchmarker/slave.py
Normal file
@ -0,0 +1,112 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import randomname
|
||||
import requests
|
||||
import subprocess
|
||||
import time
|
||||
from dataclasses import dataclass, asdict
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
@dataclass
|
||||
class BenchmarkSettings:
|
||||
player_id: str
|
||||
block_id: str
|
||||
challenge_id: str
|
||||
algorithm_id: str
|
||||
difficulty: List[int]
|
||||
|
||||
def main(
|
||||
master_ip: str,
|
||||
tig_worker_path: str,
|
||||
wasm_folder: str,
|
||||
num_workers: int,
|
||||
slave_name: str,
|
||||
master_port: int,
|
||||
api_url: str
|
||||
):
|
||||
if not os.path.exists(tig_worker_path):
|
||||
raise FileNotFoundError(f"tig-worker not found at path: {tig_worker_path}")
|
||||
if not os.path.exists(wasm_folder):
|
||||
raise FileNotFoundError(f"WASM folder not found at path: {wasm_folder}")
|
||||
|
||||
headers = {
|
||||
"User-Agent": slave_name
|
||||
}
|
||||
get_job_url = f"http://{master_ip}:{master_port}/get-job"
|
||||
submit_results_url = f"http://{master_ip}:{master_port}/submit-results"
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Step 1: Query for job
|
||||
start = datetime.now()
|
||||
print(f"Fetching Job: url={get_job_url}, headers={headers}")
|
||||
response = requests.get(get_job_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
job = response.json()
|
||||
job = Job(settings=BenchmarkSettings(**job.pop("settings")), **job)
|
||||
print(f"Fetching Job: took {(datetime.now() - start).total_seconds()} seconds")
|
||||
print(f"Job: {job}")
|
||||
|
||||
# Step 2: Download WASM
|
||||
wasm_path = os.path.join(wasm_folder, f"{job.settings.algorithm_id}.wasm")
|
||||
if not os.path.exists(wasm_path):
|
||||
start = datetime.now()
|
||||
download_url = f"{api_url}/get-wasm-blob?algorithm_id={job.settings.algorithm_id}"
|
||||
print(f"Downloading WASM: {download_url}")
|
||||
response = requests.get(download_url)
|
||||
response.raise_for_status()
|
||||
with open(wasm_path, 'wb') as f:
|
||||
f.write(response.content)
|
||||
print(f"Downloading WASM: took {(datetime.now() - start).total_seconds()} seconds")
|
||||
print(f"WASM Path: {wasm_path}")
|
||||
|
||||
# Step 3: Run tig-worker
|
||||
start = datetime.now()
|
||||
cmd = [
|
||||
tig_worker_path, "compute_batch",
|
||||
json.dumps(asdict(job.settings)),
|
||||
job.rand_hash,
|
||||
str(job.start_nonce),
|
||||
str(job.num_nonces),
|
||||
str(job.batch_size),
|
||||
wasm_path,
|
||||
"--mem", str(job.wasm_vm_config["max_mem"]),
|
||||
"--fuel", str(job.wasm_vm_config["max_fuel"]),
|
||||
"--workers", str(num_workers),
|
||||
]
|
||||
if job.sampled_nonces:
|
||||
cmd += ["--sampled", *map(str, job.sampled_nonces)]
|
||||
print(f"Running Command: {' '.join(cmd)}")
|
||||
cmd_start = datetime.now()
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
|
||||
result = json.loads(result.stdout)
|
||||
print(f"Running Command: took {(datetime.now() - cmd_start).total_seconds()} seconds")
|
||||
print(f"Result: {result}")
|
||||
|
||||
# Step 4: Submit results
|
||||
start = datetime.now()
|
||||
print(f"Submitting Results: url={submit_results_url}/{job.id}, headers={headers}")
|
||||
submit_url = f"{submit_results_url}/{job.id}"
|
||||
submit_response = requests.post(submit_url, json=result, headers=headers)
|
||||
submit_response.raise_for_status()
|
||||
print(f"Submitting Results: took {(datetime.now() - cmd_start).total_seconds()} seconds")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {str(e)}")
|
||||
time.sleep(5)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="TIG Slave Benchmarker")
|
||||
parser.add_argument("master_ip", help="IP address of the master")
|
||||
parser.add_argument("tig_worker_path", help="Path to tig-worker executable")
|
||||
parser.add_argument("wasm_folder", help="Path to folder to download WASMs")
|
||||
parser.add_argument("--workers", type=int, default=8, help="Number of workers (default: 8)")
|
||||
parser.add_argument("--name", type=str, default=randomname.get_name(), help="Name for the slave (default: randomly generated)")
|
||||
parser.add_argument("--port", type=int, default=5115, help="Port for master (default: 5115)")
|
||||
parser.add_argument("--api", type=str, default="https://mainnet-api.tig.foundation", help="TIG API URL (default: https://mainnet-api.tig.foundation)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args.master_ip, args.tig_worker_path, args.wasm_folder, args.workers, args.name, args.port, args.api)
|
||||
70
tig-benchmarker/tests/data.py
Normal file
70
tig-benchmarker/tests/data.py
Normal file
@ -0,0 +1,70 @@
|
||||
import unittest
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
from tig_benchmarker.utils import u64s_from_str, u8s_from_str, jsonify
|
||||
from tig_benchmarker.merkle_tree import MerkleHash
|
||||
from tig_benchmarker.data import BenchmarkSettings, OutputData
|
||||
|
||||
class TestData(unittest.TestCase):
|
||||
def test_calc_solution_signature(self):
|
||||
solution = {
|
||||
"data_x": 42,
|
||||
"data_y": "test"
|
||||
}
|
||||
|
||||
output_data = OutputData(
|
||||
nonce=123,
|
||||
runtime_signature=456,
|
||||
fuel_consumed=789,
|
||||
solution=solution
|
||||
)
|
||||
|
||||
# Assert same as Rust version: tig-structs/tests/core.rs
|
||||
self.assertEqual(output_data.calc_solution_signature(), 11549591319018095145)
|
||||
|
||||
def test_calc_seed(self):
|
||||
settings = BenchmarkSettings(
|
||||
player_id="some_player",
|
||||
block_id="some_block",
|
||||
challenge_id="some_challenge",
|
||||
algorithm_id="some_algorithm",
|
||||
difficulty=[1, 2, 3]
|
||||
)
|
||||
|
||||
rand_hash = "random_hash"
|
||||
nonce = 1337
|
||||
|
||||
# Assert same as Rust version: tig-structs/tests/core.rs
|
||||
expected = bytes([
|
||||
135, 168, 152, 35, 57, 28, 184, 91, 10, 189, 139, 111, 171, 82, 156, 14,
|
||||
165, 68, 80, 41, 169, 236, 42, 41, 198, 73, 124, 78, 130, 216, 168, 67
|
||||
])
|
||||
self.assertEqual(settings.calc_seed(rand_hash, nonce), expected)
|
||||
|
||||
def test_outputdata_to_merklehash(self):
|
||||
solution = {
|
||||
"data_x": 42,
|
||||
"data_y": "test"
|
||||
}
|
||||
|
||||
output_data = OutputData(
|
||||
nonce=123,
|
||||
runtime_signature=456,
|
||||
fuel_consumed=789,
|
||||
solution=solution
|
||||
)
|
||||
|
||||
merkle_hash = output_data.to_merkle_hash()
|
||||
|
||||
# Assert same as Rust version: tig-structs/tests/core.rs
|
||||
expected = MerkleHash(bytes([
|
||||
207, 29, 184, 163, 158, 22, 137, 73, 72, 58, 24, 246, 67, 9, 44, 20,
|
||||
32, 22, 86, 206, 191, 5, 52, 241, 41, 113, 198, 85, 11, 53, 190, 57
|
||||
]))
|
||||
self.assertEqual(merkle_hash, expected)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
109
tig-benchmarker/tests/merkle_tree.py
Normal file
109
tig-benchmarker/tests/merkle_tree.py
Normal file
@ -0,0 +1,109 @@
|
||||
import unittest
|
||||
import sys
|
||||
import os
|
||||
from blake3 import blake3
|
||||
from typing import List
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
from tig_benchmarker.merkle_tree import MerkleHash, MerkleTree, MerkleBranch
|
||||
|
||||
def create_test_hashes() -> List[MerkleHash]:
|
||||
return [MerkleHash(blake3(i.to_bytes(4, 'big')).digest()) for i in range(9)]
|
||||
|
||||
class TestMerkleTree(unittest.TestCase):
|
||||
def test_merkle_tree(self):
|
||||
hashes = create_test_hashes()
|
||||
|
||||
tree = MerkleTree(hashes, 16)
|
||||
root = tree.calc_merkle_root()
|
||||
# Assert same as Rust version: tig-utils/tests/merkle_tree.rs
|
||||
self.assertEqual(root, MerkleHash(bytes.fromhex("fa6d5e8cb2667f5e340b8d1a145891859ad34391cd232f4fbc8d28d8d6284e15")))
|
||||
|
||||
branch = tree.calc_merkle_branch(7)
|
||||
self.assertEqual(len(branch.stems), 4)
|
||||
leaf_hash = hashes[7]
|
||||
calculated_root = branch.calc_merkle_root(leaf_hash, 7)
|
||||
self.assertEqual(root, calculated_root)
|
||||
|
||||
branch = tree.calc_merkle_branch(8)
|
||||
self.assertEqual(len(branch.stems), 1)
|
||||
leaf_hash = hashes[8]
|
||||
calculated_root = branch.calc_merkle_root(leaf_hash, 8)
|
||||
self.assertEqual(root, calculated_root)
|
||||
|
||||
def test_batched_tree(self):
|
||||
hashes = create_test_hashes()
|
||||
tree = MerkleTree(hashes, 16)
|
||||
|
||||
batches = [MerkleTree(hashes[i:i+4], 4) for i in range(0, len(hashes), 4)]
|
||||
batch_roots = [batch.calc_merkle_root() for batch in batches]
|
||||
batch_tree = MerkleTree(batch_roots, 4)
|
||||
root = tree.calc_merkle_root()
|
||||
self.assertEqual(root, batch_tree.calc_merkle_root())
|
||||
# Assert same as Rust version: tig-utils/tests/merkle_tree.rs
|
||||
self.assertEqual(root, MerkleHash(bytes.fromhex("fa6d5e8cb2667f5e340b8d1a145891859ad34391cd232f4fbc8d28d8d6284e15")))
|
||||
|
||||
branch = tree.calc_merkle_branch(7)
|
||||
batch_branch = batches[1].calc_merkle_branch(3)
|
||||
batch_branch.stems.extend(
|
||||
[(d + 2, h) for d, h in batch_tree.calc_merkle_branch(1).stems]
|
||||
)
|
||||
self.assertEqual(branch.stems, batch_branch.stems)
|
||||
|
||||
branch = tree.calc_merkle_branch(8)
|
||||
batch_branch = batches[2].calc_merkle_branch(0)
|
||||
batch_branch.stems.extend(
|
||||
[(d + 2, h) for d, h in batch_tree.calc_merkle_branch(2).stems]
|
||||
)
|
||||
self.assertEqual(branch.stems, batch_branch.stems)
|
||||
|
||||
def test_invalid_tree_size(self):
|
||||
hashes = create_test_hashes()
|
||||
with self.assertRaises(ValueError):
|
||||
MerkleTree(hashes, 8)
|
||||
|
||||
def test_invalid_branch_index(self):
|
||||
hashes = create_test_hashes()
|
||||
tree = MerkleTree(hashes, 16)
|
||||
with self.assertRaises(ValueError):
|
||||
tree.calc_merkle_branch(16)
|
||||
|
||||
def test_invalid_branch(self):
|
||||
hashes = create_test_hashes()
|
||||
tree = MerkleTree(hashes, 16)
|
||||
branch = tree.calc_merkle_branch(7)
|
||||
branch.stems[0] = (10, branch.stems[0][1]) # Modify depth to an invalid value
|
||||
with self.assertRaises(ValueError):
|
||||
branch.calc_merkle_root(hashes[7], 7)
|
||||
|
||||
def test_serialization(self):
|
||||
hashes = create_test_hashes()
|
||||
tree = MerkleTree(hashes, 16)
|
||||
branch = tree.calc_merkle_branch(7)
|
||||
|
||||
tree_str = tree.to_str()
|
||||
# Assert same as Rust version: tig-utils/tests/merkle_tree.rs
|
||||
self.assertEqual(tree_str, "0000000000000010ec2bd03bf86b935fa34d71ad7ebb049f1f10f87d343e521511d8f9e6625620cda4b6064b23dbaa408b171b0fed5628afa267ef40a4f5a806ae2405e85fa6f1c460604abfd7695c05c911fd1ba39654b8381bcee3797692bb863134aa16b68a2c5882f75066fd0398619cdfe6fcfa463ad254ebdecc381c10dd328cb07b498486988d142bfec4b57545a44b809984ab6bee66df2f6d3fb349532199a9daf6a7a2d2f2ce2738e64d2dd1c507c90673c5a3b7d0bb3077a3947a4aa17aa24dc2c48db8c9e67f5bdeaf090a49c34b6fb567d1fa6ffaee939a2c875c510a1d1e6d4a6cb9d8db6bb71b4287b682b768b62a83a92da369d8d66a10980e5e32e4e429aea50cfe342e104404324f40468de99d6f9ad7b8ae4ab228cf1ccd84b4963b12aea5")
|
||||
deserialized_tree = MerkleTree.from_str(tree_str)
|
||||
self.assertEqual(tree.calc_merkle_root(), deserialized_tree.calc_merkle_root())
|
||||
|
||||
branch_str = branch.to_str()
|
||||
# Assert same as Rust version: tig-utils/tests/merkle_tree.rs
|
||||
self.assertEqual(branch_str, "00b8c9e67f5bdeaf090a49c34b6fb567d1fa6ffaee939a2c875c510a1d1e6d4a6c01897c33b84ad3657652be252aae642f7c5e1bdf4e22231d013907254e817753d602f94c4d317f59fd4df80655d879260ce43279ae1962953d79c90d6fb26970b27a030cfe342e104404324f40468de99d6f9ad7b8ae4ab228cf1ccd84b4963b12aea5")
|
||||
deserialized_branch = MerkleBranch.from_str(branch_str)
|
||||
self.assertEqual(
|
||||
branch.calc_merkle_root(hashes[7], 7),
|
||||
deserialized_branch.calc_merkle_root(hashes[7], 7)
|
||||
)
|
||||
|
||||
def test_merkle_hash_serialization(self):
|
||||
hash = MerkleHash(bytes([1] * 32))
|
||||
serialized = hash.to_str()
|
||||
# Assert same as Rust version: tig-utils/tests/merkle_tree.rs
|
||||
self.assertEqual(serialized, "0101010101010101010101010101010101010101010101010101010101010101")
|
||||
deserialized = MerkleHash.from_str(serialized)
|
||||
self.assertEqual(hash, deserialized)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
306
tig-benchmarker/tig_benchmarker/data.py
Normal file
306
tig-benchmarker/tig_benchmarker/data.py
Normal file
@ -0,0 +1,306 @@
|
||||
from tig_benchmarker.merkle_tree import MerkleHash, MerkleBranch
|
||||
from tig_benchmarker.utils import FromDict, u64s_from_str, u8s_from_str, jsonify
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Set, Any, Tuple
|
||||
|
||||
Point = Tuple[int, ...]
|
||||
Frontier = Set[Point]
|
||||
|
||||
@dataclass
|
||||
class AlgorithmDetails(FromDict):
|
||||
name: str
|
||||
player_id: str
|
||||
challenge_id: str
|
||||
tx_hash: str
|
||||
|
||||
@dataclass
|
||||
class AlgorithmState(FromDict):
|
||||
block_confirmed: int
|
||||
round_submitted: int
|
||||
round_pushed: Optional[int]
|
||||
round_merged: Optional[int]
|
||||
banned: bool
|
||||
|
||||
@dataclass
|
||||
class AlgorithmBlockData(FromDict):
|
||||
num_qualifiers_by_player: Dict[str, int]
|
||||
adoption: int
|
||||
merge_points: int
|
||||
reward: int
|
||||
round_earnings: int
|
||||
|
||||
@dataclass
|
||||
class Algorithm(FromDict):
|
||||
id: str
|
||||
details: AlgorithmDetails
|
||||
state: AlgorithmState
|
||||
block_data: Optional[AlgorithmBlockData]
|
||||
code: Optional[str]
|
||||
|
||||
@dataclass
|
||||
class BenchmarkSettings(FromDict):
|
||||
player_id: str
|
||||
block_id: str
|
||||
challenge_id: str
|
||||
algorithm_id: str
|
||||
difficulty: List[int]
|
||||
|
||||
def calc_seed(self, rand_hash: str, nonce: int) -> bytes:
|
||||
return u8s_from_str(f"{jsonify(self)}_{rand_hash}_{nonce}")
|
||||
|
||||
@dataclass
|
||||
class PrecommitDetails(FromDict):
|
||||
block_started: int
|
||||
num_nonces: Optional[int] # Optional for backwards compatibility
|
||||
fee_paid: Optional[int] # Optional for backwards compatibility
|
||||
|
||||
@dataclass
|
||||
class PrecommitState(FromDict):
|
||||
block_confirmed: int
|
||||
rand_hash: Optional[str] # Optional for backwards compatibility
|
||||
|
||||
@dataclass
|
||||
class Precommit(FromDict):
|
||||
benchmark_id: str
|
||||
details: PrecommitDetails
|
||||
settings: BenchmarkSettings
|
||||
state: PrecommitState
|
||||
|
||||
@dataclass
|
||||
class BenchmarkDetails(FromDict):
|
||||
num_solutions: int
|
||||
merkle_root: Optional[MerkleHash] # Optional for backwards compatibility
|
||||
|
||||
@dataclass
|
||||
class BenchmarkState(FromDict):
|
||||
block_confirmed: int
|
||||
sampled_nonces: List[int]
|
||||
|
||||
@dataclass
|
||||
class Benchmark(FromDict):
|
||||
id: str
|
||||
details: BenchmarkDetails
|
||||
state: BenchmarkState
|
||||
solution_nonces: Optional[Set[int]]
|
||||
|
||||
@dataclass
|
||||
class OutputMetaData(FromDict):
|
||||
nonce: int
|
||||
runtime_signature: int
|
||||
fuel_consumed: int
|
||||
solution_signature: int
|
||||
|
||||
@classmethod
|
||||
def from_output_data(cls, output_data: 'OutputData') -> 'OutputMetaData':
|
||||
return OutputData.to_output_metadata()
|
||||
|
||||
def to_merkle_hash(self) -> MerkleHash:
|
||||
return MerkleHash(u8s_from_str(jsonify(self)))
|
||||
|
||||
@dataclass
|
||||
class OutputData(FromDict):
|
||||
nonce: int
|
||||
runtime_signature: int
|
||||
fuel_consumed: int
|
||||
solution: Dict[str, Any]
|
||||
|
||||
def calc_solution_signature(self) -> int:
|
||||
return u64s_from_str(jsonify(self.solution))[0]
|
||||
|
||||
def to_output_metadata(self) -> OutputMetaData:
|
||||
return OutputMetaData(
|
||||
nonce=self.nonce,
|
||||
runtime_signature=self.runtime_signature,
|
||||
fuel_consumed=self.fuel_consumed,
|
||||
solution_signature=self.calc_solution_signature()
|
||||
)
|
||||
|
||||
def to_merkle_hash(self) -> MerkleHash:
|
||||
return self.to_output_metadata().to_merkle_hash()
|
||||
|
||||
@dataclass
|
||||
class MerkleProof(FromDict):
|
||||
leaf: OutputData
|
||||
branch: Optional[MerkleBranch] # Optional for backwards compatibility
|
||||
|
||||
@dataclass
|
||||
class ProofState(FromDict):
|
||||
block_confirmed: int
|
||||
submission_delay: int
|
||||
|
||||
@dataclass
|
||||
class Proof(FromDict):
|
||||
benchmark_id: str
|
||||
state: Optional[ProofState]
|
||||
merkle_proofs: Optional[List[MerkleProof]]
|
||||
|
||||
@dataclass
|
||||
class FraudState(FromDict):
|
||||
block_confirmed: int
|
||||
|
||||
@dataclass
|
||||
class Fraud(FromDict):
|
||||
benchmark_id: str
|
||||
state: FraudState
|
||||
allegation: Optional[str]
|
||||
|
||||
@dataclass
|
||||
class BlockDetails(FromDict):
|
||||
prev_block_id: str
|
||||
height: int
|
||||
round: int
|
||||
eth_block_num: Optional[str] # Optional for backwards compatability
|
||||
fees_paid: Optional[int] # Optional for backwards compatability
|
||||
num_confirmed_challenges: Optional[int] # Optional for backwards compatability
|
||||
num_confirmed_algorithms: Optional[int] # Optional for backwards compatability
|
||||
num_confirmed_benchmarks: Optional[int] # Optional for backwards compatability
|
||||
num_confirmed_precommits: Optional[int] # Optional for backwards compatability
|
||||
num_confirmed_proofs: Optional[int] # Optional for backwards compatability
|
||||
num_confirmed_frauds: Optional[int] # Optional for backwards compatability
|
||||
num_confirmed_topups: Optional[int] # Optional for backwards compatability
|
||||
num_confirmed_wasms: Optional[int] # Optional for backwards compatability
|
||||
num_active_challenges: Optional[int] # Optional for backwards compatability
|
||||
num_active_algorithms: Optional[int] # Optional for backwards compatability
|
||||
num_active_benchmarks: Optional[int] # Optional for backwards compatability
|
||||
num_active_players: Optional[int] # Optional for backwards compatability
|
||||
|
||||
@dataclass
|
||||
class BlockData(FromDict):
|
||||
confirmed_challenge_ids: Set[int]
|
||||
confirmed_algorithm_ids: Set[int]
|
||||
confirmed_benchmark_ids: Set[int]
|
||||
confirmed_precommit_ids: Set[int]
|
||||
confirmed_proof_ids: Set[int]
|
||||
confirmed_fraud_ids: Set[int]
|
||||
confirmed_topup_ids: Set[int]
|
||||
confirmed_wasm_ids: Set[int]
|
||||
active_challenge_ids: Set[int]
|
||||
active_algorithm_ids: Set[int]
|
||||
active_benchmark_ids: Set[int]
|
||||
active_player_ids: Set[int]
|
||||
|
||||
@dataclass
|
||||
class Block(FromDict):
|
||||
id: str
|
||||
details: BlockDetails
|
||||
config: dict
|
||||
data: Optional[BlockData]
|
||||
|
||||
@dataclass
|
||||
class ChallengeDetails(FromDict):
|
||||
name: str
|
||||
|
||||
@dataclass
|
||||
class ChallengeState(FromDict):
|
||||
block_confirmed: int
|
||||
round_active: Optional[int]
|
||||
|
||||
@dataclass
|
||||
class ChallengeBlockData(FromDict):
|
||||
solution_signature_threshold: int
|
||||
num_qualifiers: int
|
||||
qualifier_difficulties: Set[Point]
|
||||
base_frontier: Frontier
|
||||
scaled_frontier: Frontier
|
||||
scaling_factor: float
|
||||
|
||||
@dataclass
|
||||
class Challenge(FromDict):
|
||||
id: str
|
||||
details: ChallengeDetails
|
||||
state: ChallengeState
|
||||
block_data: Optional[ChallengeBlockData]
|
||||
|
||||
@dataclass
|
||||
class PlayerDetails(FromDict):
|
||||
name: str
|
||||
is_multisig: bool
|
||||
|
||||
@dataclass
|
||||
class PlayerBlockData(FromDict):
|
||||
num_qualifiers_by_challenge: Optional[Dict[str, int]]
|
||||
cutoff: Optional[int]
|
||||
deposit: Optional[int]
|
||||
rolling_deposit: Optional[int]
|
||||
qualifying_percent_rolling_deposit: Optional[int]
|
||||
imbalance: Optional[int]
|
||||
imbalance_penalty: Optional[int]
|
||||
influence: Optional[int]
|
||||
reward: Optional[int]
|
||||
round_earnings: int
|
||||
|
||||
@dataclass
|
||||
class PlayerState(FromDict):
|
||||
total_fees_paid: int
|
||||
available_fee_balance: int
|
||||
|
||||
@dataclass
|
||||
class Player(FromDict):
|
||||
id: str
|
||||
details: PlayerDetails
|
||||
state: Optional[PlayerState]
|
||||
block_data: Optional[PlayerBlockData]
|
||||
|
||||
@dataclass
|
||||
class WasmDetails(FromDict):
|
||||
compile_success: bool
|
||||
download_url: Optional[str]
|
||||
checksum: Optional[str]
|
||||
|
||||
@dataclass
|
||||
class WasmState(FromDict):
|
||||
block_confirmed: int
|
||||
|
||||
@dataclass
|
||||
class Wasm(FromDict):
|
||||
algorithm_id: str
|
||||
details: WasmDetails
|
||||
state: WasmState
|
||||
wasm_blob: Optional[bytes]
|
||||
|
||||
@dataclass
|
||||
class TopUpDetails(FromDict):
|
||||
player_id: str
|
||||
amount: int
|
||||
|
||||
@dataclass
|
||||
class TopUpState(FromDict):
|
||||
block_confirmed: int
|
||||
|
||||
@dataclass
|
||||
class TopUp(FromDict):
|
||||
id: str
|
||||
details: TopUpDetails
|
||||
state: TopUpState
|
||||
|
||||
@dataclass
|
||||
class QueryData(FromDict):
|
||||
block: Block
|
||||
algorithms: Dict[str, Algorithm]
|
||||
wasms: Dict[str, Wasm]
|
||||
player: Optional[Player]
|
||||
precommits: Dict[str, Precommit]
|
||||
benchmarks: Dict[str, Benchmark]
|
||||
proofs: Dict[str, Proof]
|
||||
frauds: Dict[str, Fraud]
|
||||
challenges: Dict[str, Challenge]
|
||||
|
||||
@dataclass
|
||||
class Job(FromDict):
|
||||
id: str
|
||||
settings: BenchmarkSettings
|
||||
rand_hash: str
|
||||
start_nonce: int
|
||||
num_nonces: int
|
||||
batch_size: int
|
||||
wasm_vm_config: dict
|
||||
sampled_nonces: Optional[List[int]]
|
||||
|
||||
# @dataclass
|
||||
# class State:
|
||||
# query_data: QueryData
|
||||
# available_jobs: Dict[str, Job]
|
||||
# pending_benchmark_jobs: Dict[str, Job]
|
||||
# pending_proof_jobs: Dict[str, Job]
|
||||
# submitted_proof_ids: Set[str]
|
||||
# difficulty_samplers: dict
|
||||
152
tig-benchmarker/tig_benchmarker/merkle_tree.py
Normal file
152
tig-benchmarker/tig_benchmarker/merkle_tree.py
Normal file
@ -0,0 +1,152 @@
|
||||
from blake3 import blake3
|
||||
from typing import List, Tuple
|
||||
from tig_benchmarker.utils import FromStr, u8s_from_str
|
||||
|
||||
class MerkleHash(FromStr):
|
||||
def __init__(self, value: bytes):
|
||||
if len(value) != 32:
|
||||
raise ValueError("MerkleHash must be exactly 32 bytes")
|
||||
self.value = value
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, str: str):
|
||||
return cls(bytes.fromhex(str))
|
||||
|
||||
@classmethod
|
||||
def null(cls):
|
||||
return cls(bytes([0] * 32))
|
||||
|
||||
def to_str(self):
|
||||
return self.value.hex()
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, MerkleHash) and self.value == other.value
|
||||
|
||||
def __repr__(self):
|
||||
return f"MerkleHash({self.to_str()})"
|
||||
|
||||
class MerkleTree(FromStr):
|
||||
def __init__(self, hashed_leafs: List[MerkleHash], n: int):
|
||||
if len(hashed_leafs) > n:
|
||||
raise ValueError("Invalid tree size")
|
||||
if n & (n - 1) != 0:
|
||||
raise ValueError("n must be a power of 2")
|
||||
self.hashed_leafs = hashed_leafs
|
||||
self.n = n
|
||||
|
||||
def to_str(self):
|
||||
"""Serializes the MerkleTree to a string"""
|
||||
n_hex = f"{self.n:016x}"
|
||||
hashes_hex = ''.join([h.to_str() for h in self.hashed_leafs])
|
||||
return n_hex + hashes_hex
|
||||
|
||||
def __repr__(self):
|
||||
return f"MerkleTree([{', '.join([str(h) for h in self.hashed_leafs])}], {self.n})"
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, s: str):
|
||||
"""Deserializes a MerkleTree from a string"""
|
||||
if len(s) < 16 or (len(s) - 16) % 64 != 0:
|
||||
raise ValueError("Invalid MerkleTree string length")
|
||||
|
||||
n_hex = s[:16]
|
||||
n = int(n_hex, 16)
|
||||
|
||||
hashes_hex = s[16:]
|
||||
hashed_leafs = [
|
||||
MerkleHash.from_str(hashes_hex[i:i + 64])
|
||||
for i in range(0, len(hashes_hex), 64)
|
||||
]
|
||||
|
||||
return cls(hashed_leafs, n)
|
||||
|
||||
def calc_merkle_root(self) -> MerkleHash:
|
||||
hashes = self.hashed_leafs[:]
|
||||
|
||||
while len(hashes) > 1:
|
||||
new_hashes = []
|
||||
for i in range(0, len(hashes), 2):
|
||||
left = hashes[i]
|
||||
result = MerkleHash(left.value)
|
||||
if i + 1 < len(hashes):
|
||||
right = hashes[i + 1]
|
||||
combined = left.value + right.value
|
||||
result = MerkleHash(blake3(combined).digest())
|
||||
new_hashes.append(result)
|
||||
hashes = new_hashes
|
||||
|
||||
return hashes[0]
|
||||
|
||||
def calc_merkle_branch(self, branch_idx: int) -> 'MerkleBranch':
|
||||
if branch_idx >= self.n:
|
||||
raise ValueError("Invalid branch index")
|
||||
|
||||
hashes = self.hashed_leafs[:]
|
||||
branch = []
|
||||
idx = branch_idx
|
||||
depth = 0
|
||||
|
||||
while len(hashes) > 1:
|
||||
new_hashes = []
|
||||
for i in range(0, len(hashes), 2):
|
||||
left = hashes[i]
|
||||
result = MerkleHash(left.value)
|
||||
if i + 1 < len(hashes):
|
||||
right = hashes[i + 1]
|
||||
if idx // 2 == i // 2:
|
||||
branch.append((depth, right if idx % 2 == 0 else left))
|
||||
combined = left.value + right.value
|
||||
result = MerkleHash(blake3(combined).digest())
|
||||
new_hashes.append(result)
|
||||
hashes = new_hashes
|
||||
idx //= 2
|
||||
depth += 1
|
||||
|
||||
return MerkleBranch(branch)
|
||||
|
||||
class MerkleBranch:
|
||||
def __init__(self, stems: List[Tuple[int, MerkleHash]]):
|
||||
self.stems = stems
|
||||
|
||||
def calc_merkle_root(self, hashed_leaf: MerkleHash, branch_idx: int) -> MerkleHash:
|
||||
root = hashed_leaf
|
||||
idx = branch_idx
|
||||
curr_depth = 0
|
||||
|
||||
for depth, hash in self.stems:
|
||||
if curr_depth > depth:
|
||||
raise ValueError("Invalid branch")
|
||||
while curr_depth != depth:
|
||||
idx //= 2
|
||||
curr_depth += 1
|
||||
|
||||
if idx % 2 == 0:
|
||||
combined = root.value + hash.value
|
||||
else:
|
||||
combined = hash.value + root.value
|
||||
root = MerkleHash(blake3(combined).digest())
|
||||
idx //= 2
|
||||
curr_depth += 1
|
||||
|
||||
return root
|
||||
|
||||
def to_str(self):
|
||||
"""Serializes the MerkleBranch to a hex string"""
|
||||
return ''.join([f"{depth:02x}{hash.to_str()}" for depth, hash in self.stems])
|
||||
|
||||
def __repr__(self):
|
||||
return f"MerkleBranch([{', '.join([f'({depth}, {hash})' for depth, hash in self.stems])}])"
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, s: str):
|
||||
"""Deserializes a MerkleBranch from a hex string"""
|
||||
if len(s) % 66 != 0:
|
||||
raise ValueError("Invalid MerkleBranch string length")
|
||||
|
||||
stems = []
|
||||
for i in range(0, len(s), 66):
|
||||
depth = int(s[i:i+2], 16)
|
||||
hash_hex = s[i+2:i+66]
|
||||
stems.append((depth, MerkleHash.from_str(hash_hex)))
|
||||
|
||||
return cls(stems)
|
||||
115
tig-benchmarker/tig_benchmarker/utils.py
Normal file
115
tig-benchmarker/tig_benchmarker/utils.py
Normal file
@ -0,0 +1,115 @@
|
||||
import json
|
||||
from abc import ABC, abstractclassmethod, abstractmethod
|
||||
from blake3 import blake3
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass, fields, is_dataclass
|
||||
from hashlib import md5
|
||||
from typing import TypeVar, Type, Dict, Any, List, Union, Optional, get_origin, get_args
|
||||
|
||||
T = TypeVar('T', bound='DataclassBase')
|
||||
|
||||
class FromStr(ABC):
|
||||
@abstractclassmethod
|
||||
def from_str(cls, s: str):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def to_str(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@dataclass
|
||||
class FromDict:
|
||||
@classmethod
|
||||
def from_dict(cls: Type[T], d: Dict[str, Any]) -> T:
|
||||
field_types = {f.name: f.type for f in fields(cls)}
|
||||
kwargs = {}
|
||||
|
||||
for field in fields(cls):
|
||||
value = d.pop(field.name, None)
|
||||
field_type = field_types[field.name]
|
||||
origin_type = get_origin(field_type)
|
||||
|
||||
is_optional = origin_type is Union and type(None) in get_args(field_type)
|
||||
|
||||
if value is None:
|
||||
if not is_optional:
|
||||
raise ValueError(f"Missing required field: {field.name}")
|
||||
kwargs[field.name] = None
|
||||
continue
|
||||
|
||||
if is_optional:
|
||||
field_type = next(arg for arg in get_args(field_type) if arg is not type(None))
|
||||
|
||||
kwargs[field.name] = cls._process_value(value, field_type)
|
||||
|
||||
return cls(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def _process_value(cls, value: Any, field_type: Type) -> Any:
|
||||
if hasattr(field_type, 'from_dict') and isinstance(value, dict):
|
||||
return field_type.from_dict(value)
|
||||
elif hasattr(field_type, 'from_str') and isinstance(value, str):
|
||||
return field_type.from_str(value)
|
||||
elif get_origin(field_type) in (list, set):
|
||||
elem_type = get_args(field_type)[0]
|
||||
return get_origin(field_type)(cls._process_value(item, elem_type) for item in value)
|
||||
elif get_origin(field_type) is dict:
|
||||
key_type, val_type = get_args(field_type)
|
||||
return {k: cls._process_value(v, val_type) for k, v in value.items()}
|
||||
else:
|
||||
return value
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
d = {}
|
||||
for field in fields(self):
|
||||
value = getattr(self, field.name)
|
||||
if value is not None:
|
||||
if hasattr(value, 'to_dict'):
|
||||
d[field.name] = value.to_dict()
|
||||
elif hasattr(value, 'to_str'):
|
||||
d[field.name] = value.to_str()
|
||||
elif isinstance(value, (list, set)):
|
||||
d[field.name] = [
|
||||
item.to_dict() if hasattr(item, 'to_dict')
|
||||
else item.to_str() if hasattr(item, 'to_str')
|
||||
else item
|
||||
for item in value
|
||||
]
|
||||
elif isinstance(value, dict):
|
||||
d[field.name] = {
|
||||
k: (v.to_dict() if hasattr(v, 'to_dict')
|
||||
else v.to_str() if hasattr(v, 'to_str')
|
||||
else v)
|
||||
for k, v in value.items()
|
||||
}
|
||||
elif is_dataclass(value):
|
||||
d[field.name] = asdict(value)
|
||||
else:
|
||||
d[field.name] = value
|
||||
return d
|
||||
|
||||
def now() -> int:
|
||||
return int(datetime.now().timestamp() * 1000)
|
||||
|
||||
def jsonify(obj: Any) -> str:
|
||||
if hasattr(obj, 'to_dict'):
|
||||
obj = obj.to_dict()
|
||||
return json.dumps(obj, sort_keys=True, separators=(',', ':'))
|
||||
|
||||
def u8s_from_str(input: str) -> bytes:
|
||||
return blake3(input.encode()).digest()
|
||||
|
||||
def u64s_from_str(input: str) -> List[int]:
|
||||
u8s = u8s_from_str(input)
|
||||
return [
|
||||
int.from_bytes(
|
||||
u8s[i * 8:(i + 1) * 8],
|
||||
byteorder='little',
|
||||
signed=False
|
||||
)
|
||||
for i in range(4)
|
||||
]
|
||||
|
||||
def u32_from_str(input_str: str) -> int:
|
||||
result = md5(input_str.encode('utf-8')).digest()
|
||||
return int.from_bytes(result[-4:], byteorder='little', signed=False)
|
||||
@ -168,6 +168,11 @@ impl From<OutputMetaData> for MerkleHash {
|
||||
MerkleHash(u8s_from_str(&jsonify(&data)))
|
||||
}
|
||||
}
|
||||
impl From<OutputData> for MerkleHash {
|
||||
fn from(data: OutputData) -> Self {
|
||||
MerkleHash::from(OutputMetaData::from(data))
|
||||
}
|
||||
}
|
||||
|
||||
// Block child structs
|
||||
serializable_struct_with_getters! {
|
||||
|
||||
76
tig-structs/tests/core.rs
Normal file
76
tig-structs/tests/core.rs
Normal file
@ -0,0 +1,76 @@
|
||||
use serde_json::json;
|
||||
use tig_structs::core::{BenchmarkSettings, OutputData};
|
||||
use tig_utils::MerkleHash;
|
||||
|
||||
#[test]
|
||||
fn test_calc_solution_signature() {
|
||||
let solution = json!({
|
||||
"data_x": 42,
|
||||
"data_y": "test"
|
||||
})
|
||||
.as_object()
|
||||
.unwrap()
|
||||
.clone();
|
||||
|
||||
let output_data = OutputData {
|
||||
nonce: 123,
|
||||
runtime_signature: 456,
|
||||
fuel_consumed: 789,
|
||||
solution: solution.clone(),
|
||||
};
|
||||
|
||||
// Assert same as Python version: tig-benchmarker/tests/core.rs
|
||||
assert_eq!(output_data.calc_solution_signature(), 11549591319018095145);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_calc_seed() {
|
||||
let settings = BenchmarkSettings {
|
||||
player_id: "some_player".to_string(),
|
||||
block_id: "some_block".to_string(),
|
||||
challenge_id: "some_challenge".to_string(),
|
||||
algorithm_id: "some_algorithm".to_string(),
|
||||
difficulty: vec![1, 2, 3],
|
||||
};
|
||||
|
||||
let rand_hash = "random_hash".to_string();
|
||||
let nonce = 1337;
|
||||
|
||||
// Assert same as Python version: tig-benchmarker/tests/core.rs
|
||||
assert_eq!(
|
||||
settings.calc_seed(&rand_hash, nonce),
|
||||
[
|
||||
135, 168, 152, 35, 57, 28, 184, 91, 10, 189, 139, 111, 171, 82, 156, 14, 165, 68, 80,
|
||||
41, 169, 236, 42, 41, 198, 73, 124, 78, 130, 216, 168, 67
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_outputdata_to_merklehash() {
|
||||
let solution = json!({
|
||||
"data_x": 42,
|
||||
"data_y": "test"
|
||||
})
|
||||
.as_object()
|
||||
.unwrap()
|
||||
.clone();
|
||||
|
||||
let output_data = OutputData {
|
||||
nonce: 123,
|
||||
runtime_signature: 456,
|
||||
fuel_consumed: 789,
|
||||
solution: solution.clone(),
|
||||
};
|
||||
|
||||
let merkle_hash: MerkleHash = output_data.into();
|
||||
|
||||
// Assert same as Python version: tig-benchmarker/tests/core.rs
|
||||
assert_eq!(
|
||||
merkle_hash,
|
||||
MerkleHash([
|
||||
207, 29, 184, 163, 158, 22, 137, 73, 72, 58, 24, 246, 67, 9, 44, 20, 32, 22, 86, 206,
|
||||
191, 5, 52, 241, 41, 113, 198, 85, 11, 53, 190, 57
|
||||
])
|
||||
);
|
||||
}
|
||||
@ -1,201 +0,0 @@
|
||||
import blake3
|
||||
import binascii
|
||||
from typing import List
|
||||
|
||||
|
||||
class MerkleHash:
|
||||
def __init__(self, value: bytes):
|
||||
if len(value) != 32:
|
||||
raise ValueError("MerkleHash must be exactly 32 bytes")
|
||||
self.value = value
|
||||
|
||||
@classmethod
|
||||
def from_hex(cls, hex_str: str):
|
||||
return cls(binascii.unhexlify(hex_str))
|
||||
|
||||
|
||||
@classmethod
|
||||
def null(cls):
|
||||
return cls(bytes([0] * 32))
|
||||
|
||||
def to_hex(self):
|
||||
return binascii.hexlify(self.value).decode()
|
||||
|
||||
def __str__(self):
|
||||
return self.to_hex()
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, MerkleHash) and self.value == other.value
|
||||
|
||||
def __repr__(self):
|
||||
return f"MerkleHash({self.to_hex()})"
|
||||
|
||||
class MerkleTree:
|
||||
def __init__(self, hashed_leafs: List[MerkleHash], n: int):
|
||||
if len(hashed_leafs) > n:
|
||||
raise ValueError("Invalid tree size")
|
||||
self.hashed_leafs = hashed_leafs
|
||||
self.n = n
|
||||
|
||||
def serialize(self):
|
||||
"""Serializes the MerkleTree to a string"""
|
||||
# Convert 'n' to a 16-character hexadecimal string (padded)
|
||||
n_hex = f"{self.n:016x}"
|
||||
# Convert all MerkleHash objects to hex and concatenate
|
||||
hashes_hex = ''.join([h.to_hex() for h in self.hashed_leafs])
|
||||
# Return the serialized string
|
||||
return n_hex + hashes_hex
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, serialized_str: str):
|
||||
"""Deserializes a MerkleTree from a string"""
|
||||
if len(serialized_str) < 16:
|
||||
raise ValueError("Invalid MerkleTree string length")
|
||||
|
||||
# Extract the first 16 characters as the hex-encoded size 'n'
|
||||
n_hex = serialized_str[:16]
|
||||
n = int(n_hex, 16)
|
||||
|
||||
# Extract the remaining part as hex-encoded MerkleHash values
|
||||
hashes_hex = serialized_str[16:]
|
||||
|
||||
if len(hashes_hex) % 64 != 0:
|
||||
raise ValueError("Invalid MerkleTree hashes length")
|
||||
|
||||
# Split the string into 64-character chunks and convert them to MerkleHash objects
|
||||
hashed_leafs = [
|
||||
MerkleHash.from_hex(hashes_hex[i:i + 64])
|
||||
for i in range(0, len(hashes_hex), 64)
|
||||
]
|
||||
|
||||
return cls(hashed_leafs, n)
|
||||
|
||||
def calc_merkle_root(self) -> MerkleHash:
|
||||
null_hash = MerkleHash.null()
|
||||
hashes = self.hashed_leafs[:]
|
||||
|
||||
while len(hashes) > 1:
|
||||
new_hashes = []
|
||||
for i in range(0, len(hashes), 2):
|
||||
left = hashes[i]
|
||||
right = hashes[i+1] if i+1 < len(hashes) else null_hash
|
||||
combined = left.value + right.value
|
||||
new_hashes.append(MerkleHash(blake3.blake3(combined).digest()))
|
||||
hashes = new_hashes
|
||||
|
||||
return hashes[0]
|
||||
|
||||
def calc_merkle_proof(self, branch_idx: int):
|
||||
if branch_idx >= self.n:
|
||||
raise ValueError("Invalid branch index")
|
||||
|
||||
hashes = self.hashed_leafs[:]
|
||||
null_hash = MerkleHash.null()
|
||||
proof = []
|
||||
idx = branch_idx
|
||||
|
||||
while len(hashes) > 1:
|
||||
new_hashes = []
|
||||
for i in range(0, len(hashes), 2):
|
||||
left = hashes[i]
|
||||
right = hashes[i+1] if i+1 < len(hashes) else null_hash
|
||||
|
||||
if idx // 2 == i // 2:
|
||||
proof.append(right if idx % 2 == 0 else left)
|
||||
|
||||
combined = left.value + right.value
|
||||
new_hashes.append(MerkleHash(blake3.blake3(combined).digest()))
|
||||
hashes = new_hashes
|
||||
idx //= 2
|
||||
|
||||
return MerkleBranch(proof)
|
||||
|
||||
class MerkleBranch:
|
||||
def __init__(self, proof_hashes: List[MerkleHash]):
|
||||
self.proof_hashes = proof_hashes
|
||||
|
||||
def calc_merkle_root(self, hashed_leaf: MerkleHash, branch_idx: int) -> MerkleHash:
|
||||
root = hashed_leaf
|
||||
idx = branch_idx
|
||||
|
||||
for hash in self.proof_hashes:
|
||||
if idx % 2 == 0:
|
||||
combined = root.value + hash.value
|
||||
else:
|
||||
combined = hash.value + root.value
|
||||
root = MerkleHash(blake3.blake3(combined).digest())
|
||||
idx //= 2
|
||||
|
||||
return root
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, serialized_str: str):
|
||||
"""Deserializes a MerkleBranch from a hex string of concatenated MerkleHash values"""
|
||||
if len(serialized_str) % 64 != 0:
|
||||
raise ValueError("Invalid MerkleProof string length")
|
||||
|
||||
# Split the string into 64-character chunks (32 bytes represented as 64 hex characters)
|
||||
hashes = [
|
||||
MerkleHash.from_hex(serialized_str[i:i + 64])
|
||||
for i in range(0, len(serialized_str), 64)
|
||||
]
|
||||
|
||||
return cls(hashes)
|
||||
|
||||
def __repr__(self):
|
||||
return f"MerkleBranch({[str(h) for h in self.proof_hashes]})"
|
||||
|
||||
|
||||
# Example usage:
|
||||
import json
|
||||
# Example list of hashed leaves
|
||||
print("Hashes:")
|
||||
hashed_leafs = [MerkleHash(blake3.blake3(f"leaf {i}".encode()).digest()) for i in range(14)]
|
||||
for hashleaf in hashed_leafs:
|
||||
print(hashleaf.to_hex())
|
||||
n = len(hashed_leafs)
|
||||
|
||||
# Build the Merkle tree
|
||||
merkle_tree = MerkleTree(hashed_leafs, n)
|
||||
|
||||
# Calculate Merkle root
|
||||
root = merkle_tree.calc_merkle_root()
|
||||
|
||||
print("\nMerkle Root:\n", root)
|
||||
|
||||
# Generate Merkle proof for a specific leaf
|
||||
proof = merkle_tree.calc_merkle_proof(2)
|
||||
print("\nMerkle Proof:")
|
||||
for node in proof.proof_hashes:
|
||||
print(node.to_hex())
|
||||
|
||||
print("\nUsing serialized strings from rust: ")
|
||||
|
||||
serialized_root = '"bb3b20745d03ce3eaa4603a19056be544bba00f036725d9025205b883c0bf54e"'
|
||||
serialized_proof = '"ceb50f111fece8844fe4432ed3d19cbce3f54c2ba3994dcd37fe2ceca29791a4af311d272dc334e92c7d626141fa11430dc3b8f55a4911ae1b2542124bdbbef20c2467559ed3061deac0779b0e035514576e2910872b85a84a769087588149a9da007281955a8ed1cbcf3a6f28ec3eb41a385193a7a3a507299032effed88c77"'
|
||||
|
||||
|
||||
# Deserialize Merkle root
|
||||
root_hex = json.loads(serialized_root)
|
||||
merkle_root = MerkleHash.from_hex(root_hex)
|
||||
print("\nDeserialized Merkle Root:", merkle_root)
|
||||
|
||||
# Deserialize Merkle proof
|
||||
proof_str = json.loads(serialized_proof)
|
||||
proof = MerkleBranch.deserialize(proof_str)
|
||||
print("\nDeserialized Merkle Proof:")
|
||||
for node in proof.proof_hashes:
|
||||
print(node.to_hex())
|
||||
|
||||
|
||||
# # Verify Merkle proof and calculate root from the proof
|
||||
calculated_root = proof.calc_merkle_root(hashed_leafs[2], 2)
|
||||
print("\nCalculated Root from Proof:", calculated_root)
|
||||
|
||||
# Check if the root matches
|
||||
assert calculated_root == root
|
||||
|
||||
mt_ser = merkle_tree.serialize()
|
||||
|
||||
merkle_tree = MerkleTree.deserialize(mt_ser)
|
||||
assert merkle_tree.calc_merkle_root() == root
|
||||
@ -14,6 +14,16 @@ fn test_merkle_tree() {
|
||||
|
||||
let tree = MerkleTree::new(hashes.clone(), 16).unwrap();
|
||||
let root = tree.calc_merkle_root();
|
||||
// Assert same as Python version: tig-benchmarker/tests/merkle_tree.rs
|
||||
assert_eq!(
|
||||
root,
|
||||
MerkleHash(
|
||||
hex::decode("fa6d5e8cb2667f5e340b8d1a145891859ad34391cd232f4fbc8d28d8d6284e15")
|
||||
.unwrap()
|
||||
.try_into()
|
||||
.unwrap()
|
||||
)
|
||||
);
|
||||
|
||||
let branch = tree.calc_merkle_branch(7).unwrap();
|
||||
assert_eq!(branch.0.len(), 4);
|
||||
@ -42,7 +52,18 @@ fn test_batched_tree() {
|
||||
.map(|tree| tree.calc_merkle_root())
|
||||
.collect::<Vec<MerkleHash>>();
|
||||
let batch_tree = MerkleTree::new(batch_roots.clone(), 4).unwrap();
|
||||
assert_eq!(tree.calc_merkle_root(), batch_tree.calc_merkle_root());
|
||||
let root = tree.calc_merkle_root();
|
||||
assert_eq!(root, batch_tree.calc_merkle_root());
|
||||
// Assert same as Python version: tig-benchmarker/tests/merkle_tree.rs
|
||||
assert_eq!(
|
||||
root,
|
||||
MerkleHash(
|
||||
hex::decode("fa6d5e8cb2667f5e340b8d1a145891859ad34391cd232f4fbc8d28d8d6284e15")
|
||||
.unwrap()
|
||||
.try_into()
|
||||
.unwrap()
|
||||
)
|
||||
);
|
||||
|
||||
let branch = tree.calc_merkle_branch(7).unwrap();
|
||||
let mut batch_branch = batches[1].calc_merkle_branch(3).unwrap();
|
||||
@ -105,6 +126,8 @@ fn test_serialization() {
|
||||
let branch = tree.calc_merkle_branch(7).unwrap();
|
||||
|
||||
let tree_json = serde_json::to_string(&tree).unwrap();
|
||||
// Assert same as Python version: tig-benchmarker/tests/merkle_tree.rs
|
||||
assert_eq!(&tree_json, "\"0000000000000010ec2bd03bf86b935fa34d71ad7ebb049f1f10f87d343e521511d8f9e6625620cda4b6064b23dbaa408b171b0fed5628afa267ef40a4f5a806ae2405e85fa6f1c460604abfd7695c05c911fd1ba39654b8381bcee3797692bb863134aa16b68a2c5882f75066fd0398619cdfe6fcfa463ad254ebdecc381c10dd328cb07b498486988d142bfec4b57545a44b809984ab6bee66df2f6d3fb349532199a9daf6a7a2d2f2ce2738e64d2dd1c507c90673c5a3b7d0bb3077a3947a4aa17aa24dc2c48db8c9e67f5bdeaf090a49c34b6fb567d1fa6ffaee939a2c875c510a1d1e6d4a6cb9d8db6bb71b4287b682b768b62a83a92da369d8d66a10980e5e32e4e429aea50cfe342e104404324f40468de99d6f9ad7b8ae4ab228cf1ccd84b4963b12aea5\"");
|
||||
let deserialized_tree: MerkleTree = serde_json::from_str(&tree_json).unwrap();
|
||||
assert_eq!(
|
||||
tree.calc_merkle_root(),
|
||||
@ -112,6 +135,8 @@ fn test_serialization() {
|
||||
);
|
||||
|
||||
let branch_json = serde_json::to_string(&branch).unwrap();
|
||||
// Assert same as Python version: tig-benchmarker/tests/merkle_tree.rs
|
||||
assert_eq!(&branch_json, "\"00b8c9e67f5bdeaf090a49c34b6fb567d1fa6ffaee939a2c875c510a1d1e6d4a6c01897c33b84ad3657652be252aae642f7c5e1bdf4e22231d013907254e817753d602f94c4d317f59fd4df80655d879260ce43279ae1962953d79c90d6fb26970b27a030cfe342e104404324f40468de99d6f9ad7b8ae4ab228cf1ccd84b4963b12aea5\"");
|
||||
let deserialized_branch: MerkleBranch = serde_json::from_str(&branch_json).unwrap();
|
||||
assert_eq!(
|
||||
branch.calc_merkle_root(&hashes[7], 7).unwrap(),
|
||||
@ -123,6 +148,11 @@ fn test_serialization() {
|
||||
fn test_merkle_hash_serialization() {
|
||||
let hash = MerkleHash([1; 32]);
|
||||
let serialized = serde_json::to_string(&hash).unwrap();
|
||||
// Assert same as Python version: tig-benchmarker/tests/merkle_tree.rs
|
||||
assert_eq!(
|
||||
&serialized,
|
||||
"\"0101010101010101010101010101010101010101010101010101010101010101\""
|
||||
);
|
||||
let deserialized: MerkleHash = serde_json::from_str(&serialized).unwrap();
|
||||
assert_eq!(hash, deserialized);
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user