diff --git a/tig-benchmarker/common/merkle_tree.py b/tig-benchmarker/common/merkle_tree.py
new file mode 100644
index 0000000..202a4de
--- /dev/null
+++ b/tig-benchmarker/common/merkle_tree.py
@@ -0,0 +1,152 @@
+from blake3 import blake3
+from typing import List, Tuple
+from .utils import FromStr, u8s_from_str
+
+class MerkleHash(FromStr):
+    def __init__(self, value: bytes):
+        if len(value) != 32:
+            raise ValueError("MerkleHash must be exactly 32 bytes")
+        self.value = value
+
+    @classmethod
+    def from_str(cls, s: str):
+        return cls(bytes.fromhex(s))
+
+    @classmethod
+    def null(cls):
+        return cls(bytes([0] * 32))
+
+    def to_str(self):
+        return self.value.hex()
+
+    def __eq__(self, other):
+        return isinstance(other, MerkleHash) and self.value == other.value
+
+    def __repr__(self):
+        return f"MerkleHash({self.to_str()})"
+
+class MerkleTree(FromStr):
+    def __init__(self, hashed_leafs: List[MerkleHash], n: int):
+        if len(hashed_leafs) > n:
+            raise ValueError("Invalid tree size")
+        if n & (n - 1) != 0:
+            raise ValueError("n must be a power of 2")
+        self.hashed_leafs = hashed_leafs
+        self.n = n
+
+    def to_str(self):
+        """Serializes the MerkleTree to a hex string"""
+        n_hex = f"{self.n:016x}"
+        hashes_hex = ''.join([h.to_str() for h in self.hashed_leafs])
+        return n_hex + hashes_hex
+
+    def __repr__(self):
+        return f"MerkleTree([{', '.join([str(h) for h in self.hashed_leafs])}], {self.n})"
+
+    @classmethod
+    def from_str(cls, s: str):
+        """Deserializes a MerkleTree from a hex string"""
+        if len(s) < 16 or (len(s) - 16) % 64 != 0:
+            raise ValueError("Invalid MerkleTree string length")
+
+        n_hex = s[:16]
+        n = int(n_hex, 16)
+
+        hashes_hex = s[16:]
+        hashed_leafs = [
+            MerkleHash.from_str(hashes_hex[i:i + 64])
+            for i in range(0, len(hashes_hex), 64)
+        ]
+
+        return cls(hashed_leafs, n)
+
+    def calc_merkle_root(self) -> MerkleHash:
+        hashes = self.hashed_leafs[:]
+
+        while len(hashes) > 1:
+            new_hashes = []
+            for i in range(0, len(hashes), 2):
+                left = hashes[i]
+                result = MerkleHash(left.value)
+                if i + 1 < len(hashes):
+                    right = hashes[i + 1]
+                    combined = left.value + right.value
+                    result = MerkleHash(blake3(combined).digest())
+                new_hashes.append(result)
+            hashes = new_hashes
+
+        return hashes[0]
+
+    def calc_merkle_branch(self, branch_idx: int) -> 'MerkleBranch':
+        if branch_idx >= self.n:
+            raise ValueError("Invalid branch index")
+
+        hashes = self.hashed_leafs[:]
+        branch = []
+        idx = branch_idx
+        depth = 0
+
+        while len(hashes) > 1:
+            new_hashes = []
+            for i in range(0, len(hashes), 2):
+                left = hashes[i]
+                result = MerkleHash(left.value)
+                if i + 1 < len(hashes):
+                    right = hashes[i + 1]
+                    if idx // 2 == i // 2:
+                        branch.append((depth, right if idx % 2 == 0 else left))
+                    combined = left.value + right.value
+                    result = MerkleHash(blake3(combined).digest())
+                new_hashes.append(result)
+            hashes = new_hashes
+            idx //= 2
+            depth += 1
+
+        return MerkleBranch(branch)
+
+class MerkleBranch:
+    def __init__(self, stems: List[Tuple[int, MerkleHash]]):
+        self.stems = stems
+
+    def calc_merkle_root(self, hashed_leaf: MerkleHash, branch_idx: int) -> MerkleHash:
+        root = hashed_leaf
+        idx = branch_idx
+        curr_depth = 0
+
+        for depth, h in self.stems:
+            if curr_depth > depth:
+                raise ValueError("Invalid branch")
+            while curr_depth != depth:
+                idx //= 2
+                curr_depth += 1
+
+            if idx % 2 == 0:
+                combined = root.value + h.value
+            else:
+                combined = h.value + root.value
+            root = MerkleHash(blake3(combined).digest())
+            idx //= 2
+            curr_depth += 1
+
+        return root
+
+    def to_str(self):
+        """Serializes the MerkleBranch to a hex string"""
+        return ''.join([f"{depth:02x}{h.to_str()}" for depth, h in self.stems])
+
+    def __repr__(self):
+        return f"MerkleBranch([{', '.join([f'({depth}, {h})' for depth, h in self.stems])}])"
+
+    @classmethod
+    def from_str(cls, s: str):
+        """Deserializes a MerkleBranch from a hex string"""
+        if len(s) % 66 != 0:
+            raise ValueError("Invalid MerkleBranch string length")
+
+        stems = []
+        for i in range(0, len(s), 66):
+            depth = int(s[i:i+2], 16)
+            hash_hex = s[i+2:i+66]
+            stems.append((depth, MerkleHash.from_str(hash_hex)))
+
+        return cls(stems)
\ No newline at end of file
diff --git a/tig-benchmarker/common/structs.py b/tig-benchmarker/common/structs.py
new file mode 100644
index 0000000..985d0b4
--- /dev/null
+++ b/tig-benchmarker/common/structs.py
@@ -0,0 +1,297 @@
+from .merkle_tree import MerkleHash, MerkleBranch
+from .utils import FromDict, u64s_from_str, u8s_from_str, jsonify, PreciseNumber
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Set, Any, Tuple
+
+Point = Tuple[int, ...]
+Frontier = Set[Point]
+
+@dataclass
+class AlgorithmDetails(FromDict):
+    name: str
+    player_id: str
+    challenge_id: str
+    breakthrough_id: Optional[str]
+    type: str
+    fee_paid: PreciseNumber
+
+@dataclass
+class AlgorithmState(FromDict):
+    block_confirmed: int
+    round_submitted: int
+    round_pushed: Optional[int]
+    round_active: Optional[int]
+    round_merged: Optional[int]
+    banned: bool
+
+@dataclass
+class AlgorithmBlockData(FromDict):
+    num_qualifiers_by_player: Dict[str, int]
+    adoption: PreciseNumber
+    merge_points: int
+    reward: PreciseNumber
+
+@dataclass
+class Algorithm(FromDict):
+    id: str
+    details: AlgorithmDetails
+    state: AlgorithmState
+    block_data: Optional[AlgorithmBlockData]
+
+@dataclass
+class BenchmarkSettings(FromDict):
+    player_id: str
+    block_id: str
+    challenge_id: str
+    algorithm_id: str
+    difficulty: List[int]
+
+    def calc_seed(self, rand_hash: str, nonce: int) -> bytes:
+        return u8s_from_str(f"{jsonify(self)}_{rand_hash}_{nonce}")
+
+@dataclass
+class PrecommitDetails(FromDict):
+    block_started: int
+    num_nonces: int
+    rand_hash: str
+    fee_paid: PreciseNumber
+
+@dataclass
+class PrecommitState(FromDict):
+    block_confirmed: int
+
+@dataclass
+class Precommit(FromDict):
+    benchmark_id: str
+    details: PrecommitDetails
+    settings: BenchmarkSettings
+    state: PrecommitState
+
+@dataclass
+class BenchmarkDetails(FromDict):
+    num_solutions: int
+    merkle_root: MerkleHash
+    sampled_nonces: List[int]
+
+@dataclass
+class BenchmarkState(FromDict):
+    block_confirmed: int
+
+@dataclass
+class Benchmark(FromDict):
+    id: str
+    details: BenchmarkDetails
+    state: BenchmarkState
+    solution_nonces: Optional[Set[int]]
+
+@dataclass
+class OutputMetaData(FromDict):
+    nonce: int
+    runtime_signature: int
+    fuel_consumed: int
+    solution_signature: int
+
+    @classmethod
+    def from_output_data(cls, output_data: 'OutputData') -> 'OutputMetaData':
+        return output_data.to_output_metadata()
+
+    def to_merkle_hash(self) -> MerkleHash:
+        return MerkleHash(u8s_from_str(jsonify(self)))
+
+@dataclass
+class OutputData(FromDict):
+    nonce: int
+    runtime_signature: int
+    fuel_consumed: int
+    solution: dict
+
+    def calc_solution_signature(self) -> int:
+        return u64s_from_str(jsonify(self.solution))[0]
+
+    def to_output_metadata(self) -> OutputMetaData:
+        return OutputMetaData(
+            nonce=self.nonce,
+            runtime_signature=self.runtime_signature,
+            fuel_consumed=self.fuel_consumed,
+            solution_signature=self.calc_solution_signature()
+        )
+
+    def to_merkle_hash(self) -> MerkleHash:
+        return self.to_output_metadata().to_merkle_hash()
+
+@dataclass
+class MerkleProof(FromDict):
+    leaf: OutputData
+    branch: MerkleBranch
+
+@dataclass
+class ProofDetails(FromDict):
+    submission_delay: int
+    block_active: int
+
+@dataclass
+class ProofState(FromDict):
+    block_confirmed: int
+
+@dataclass
+class Proof(FromDict):
+    benchmark_id: str
+    details: ProofDetails
+    state: ProofState
+    merkle_proofs: Optional[List[MerkleProof]]
+
+@dataclass
+class FraudState(FromDict):
+    block_confirmed: int
+
+@dataclass
+class Fraud(FromDict):
+    benchmark_id: str
+    state: FraudState
+    allegation: Optional[str]
+
+@dataclass
+class BlockDetails(FromDict):
+    prev_block_id: str
+    height: int
+    round: int
+    timestamp: int
+    num_confirmed: Dict[str, int]
+    num_active: Dict[str, int]
+
+@dataclass
+class BlockData(FromDict):
+    confirmed_ids: Dict[str, Set[int]]
+    active_ids: Dict[str, Set[int]]
+
+@dataclass
+class Block(FromDict):
+    id: str
+    details: BlockDetails
+    config: dict
+    data: Optional[BlockData]
+
+@dataclass
+class ChallengeDetails(FromDict):
+    name: str
+
+@dataclass
+class ChallengeState(FromDict):
+    round_active: int
+
+@dataclass
+class ChallengeBlockData(FromDict):
+    num_qualifiers: int
+    qualifier_difficulties: Set[Point]
+    base_frontier: Frontier
+    scaled_frontier: Frontier
+    scaling_factor: float
+    base_fee: PreciseNumber
+    per_nonce_fee: PreciseNumber
+
+@dataclass
+class Challenge(FromDict):
+    id: str
+    details: ChallengeDetails
+    state: ChallengeState
+    block_data: Optional[ChallengeBlockData]
+
+@dataclass
+class OPoWBlockData(FromDict):
+    num_qualifiers_by_challenge: Dict[str, int]
+    cutoff: int
+    delegated_weighted_deposit: PreciseNumber
+    delegators: Set[str]
+    reward_share: float
+    imbalance: PreciseNumber
+    influence: PreciseNumber
+    reward: PreciseNumber
+
+@dataclass
+class OPoW(FromDict):
+    player_id: str
+    block_data: Optional[OPoWBlockData]
+
+@dataclass
+class PlayerDetails(FromDict):
+    name: Optional[str]
+    is_multisig: bool
+
+@dataclass
+class PlayerState(FromDict):
+    total_fees_paid: PreciseNumber
+    available_fee_balance: PreciseNumber
+    delegatee: Optional[dict]
+    votes: dict
+    reward_share: Optional[dict]
+
+@dataclass
+class PlayerBlockData(FromDict):
+    delegatee: Optional[str]
+    reward_by_type: Dict[str, PreciseNumber]
+    deposit_by_locked_period: List[PreciseNumber]
+    weighted_deposit: PreciseNumber
+
+@dataclass
+class Player(FromDict):
+    id: str
+    details: PlayerDetails
+    state: PlayerState
+    block_data: Optional[PlayerBlockData]
+
+@dataclass
+class BinaryDetails(FromDict):
+    compile_success: bool
+    download_url: Optional[str]
+
+@dataclass
+class BinaryState(FromDict):
+    block_confirmed: int
+
+@dataclass
+class Binary(FromDict):
+    algorithm_id: str
+    details: BinaryDetails
+    state: BinaryState
+
+@dataclass
+class TopUpDetails(FromDict):
+    player_id: str
+    amount: PreciseNumber
+    log_idx: int
+    tx_hash: str
+
+@dataclass
+class TopUpState(FromDict):
+    block_confirmed: int
+
+@dataclass
+class TopUp(FromDict):
+    id: str
+    details: TopUpDetails
+    state: TopUpState
+
+@dataclass
+class DifficultyData(FromDict):
+    num_solutions: int
+    num_nonces: int
+    difficulty: Point
+
+@dataclass
+class DepositDetails(FromDict):
+    player_id: str
+    amount: PreciseNumber
+    log_idx: int
+    tx_hash: str
+    start_timestamp: int
+    end_timestamp: int
+
+@dataclass
+class DepositState(FromDict):
+    block_confirmed: int
+
+@dataclass
+class Deposit(FromDict):
+    id: str
+    details: DepositDetails
+    state: DepositState
\ No newline at end of file
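[Review note] A minimal usage sketch (not part of the patch) tying the two modules above together. The `OutputData` field values are made up, and it assumes the script is run from the `tig-benchmarker` directory so that `common` is importable:

```python
from common.merkle_tree import MerkleBranch, MerkleTree
from common.structs import OutputData

# three dummy outputs for nonces 0..2, padded into a tree of batch_size 4
leafs = [
    OutputData(nonce=n, runtime_signature=0, fuel_consumed=0, solution={"x": n}).to_merkle_hash()
    for n in range(3)
]
tree = MerkleTree(leafs, 4)
root = tree.calc_merkle_root()

# a branch re-derives the same root from just one leaf plus its stems
branch = tree.calc_merkle_branch(2)
assert branch.calc_merkle_root(leafs[2], 2) == root

# both structures round-trip through the hex encodings used on the wire
assert MerkleTree.from_str(tree.to_str()).calc_merkle_root() == root
assert MerkleBranch.from_str(branch.to_str()).calc_merkle_root(leafs[2], 2) == root
```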
diff --git a/tig-benchmarker/common/utils.py b/tig-benchmarker/common/utils.py
new file mode 100644
index 0000000..362a644
--- /dev/null
+++ b/tig-benchmarker/common/utils.py
@@ -0,0 +1,216 @@
+from __future__ import annotations
+from abc import ABC, abstractclassmethod, abstractmethod
+from blake3 import blake3
+from dataclasses import dataclass, fields, is_dataclass, asdict
+from typing import TypeVar, Type, Dict, Any, List, Union, Optional, get_origin, get_args
+import json
+import time
+
+T = TypeVar('T', bound='FromDict')
+
+class FromStr(ABC):
+    @abstractclassmethod
+    def from_str(cls, s: str):
+        raise NotImplementedError
+
+    @abstractmethod
+    def to_str(self) -> str:
+        raise NotImplementedError
+
+@dataclass
+class FromDict:
+    @classmethod
+    def from_dict(cls: Type[T], d: Dict[str, Any]) -> T:
+        field_types = {f.name: f.type for f in fields(cls)}
+        kwargs = {}
+
+        for field in fields(cls):
+            value = d.pop(field.name, None)
+            field_type = field_types[field.name]
+
+            if value is None:
+                if cls._is_optional(field_type):
+                    kwargs[field.name] = None
+                else:
+                    raise ValueError(f"Missing required field: {field.name}")
+                continue
+
+            kwargs[field.name] = cls._process_value(value, field_type)
+
+        return cls(**kwargs)
+
+    @classmethod
+    def _process_value(cls, value: Any, field_type: Type) -> Any:
+        origin_type = get_origin(field_type)
+
+        if cls._is_optional(field_type):
+            if value is None:
+                return None
+            non_none_type = next(arg for arg in get_args(field_type) if arg is not type(None))
+            return cls._process_value(value, non_none_type)
+
+        if hasattr(field_type, 'from_dict') and isinstance(value, dict):
+            return field_type.from_dict(value)
+        elif hasattr(field_type, 'from_str') and isinstance(value, str):
+            return field_type.from_str(value)
+        elif origin_type in (list, set, tuple):
+            elem_type = get_args(field_type)[0]
+            return origin_type(cls._process_value(item, elem_type) for item in value)
+        elif origin_type is dict:
+            key_type, val_type = get_args(field_type)
+            return {cls._process_value(k, key_type): cls._process_value(v, val_type) for k, v in value.items()}
+        else:
+            return field_type(value)
+
+    @staticmethod
+    def _is_optional(field_type: Type) -> bool:
+        return get_origin(field_type) is Union and type(None) in get_args(field_type)
+
+    def to_dict(self) -> Dict[str, Any]:
+        d = {}
+        for field in fields(self):
+            value = getattr(self, field.name)
+            if value is not None:
+                if hasattr(value, 'to_dict'):
+                    d[field.name] = value.to_dict()
+                elif hasattr(value, 'to_str'):
+                    d[field.name] = value.to_str()
+                elif isinstance(value, (list, set, tuple)):
+                    d[field.name] = [
+                        item.to_dict() if hasattr(item, 'to_dict')
+                        else item.to_str() if hasattr(item, 'to_str')
+                        else item
+                        for item in value
+                    ]
+                elif isinstance(value, dict):
+                    d[field.name] = {
+                        k: (v.to_dict() if hasattr(v, 'to_dict')
+                            else v.to_str() if hasattr(v, 'to_str')
+                            else v)
+                        for k, v in value.items()
+                    }
+                elif is_dataclass(value):
+                    d[field.name] = asdict(value)
+                else:
+                    d[field.name] = value
+        return d
+
+
+class PreciseNumber(FromStr):
+    PRECISION = 10**18  # 18 decimal places of precision
+
+    def __init__(self, value: Union[int, float, str, PreciseNumber]):
+        if isinstance(value, PreciseNumber):
+            self._value = value._value
+        elif isinstance(value, int):
+            self._value = value * self.PRECISION
+        elif isinstance(value, float):
+            self._value = int(value * self.PRECISION)
+        elif isinstance(value, str):
+            self._value = int(value)
+        else:
+            raise TypeError(f"Unsupported type for PreciseNumber: {type(value)}")
+
+    @classmethod
+    def _from_raw(cls, raw: int) -> PreciseNumber:
+        # wraps an already-scaled fixed-point integer without re-scaling it
+        n = cls(0)
+        n._value = raw
+        return n
+
+    @classmethod
+    def from_str(cls, s: str) -> 'PreciseNumber':
+        return cls(s)
+
+    def to_str(self) -> str:
+        return str(self._value)
+
+    def __repr__(self) -> str:
+        return f"PreciseNumber({self.to_float()})"
+
+    def to_float(self) -> float:
+        return self._value / self.PRECISION
+
+    def __add__(self, other: Union[PreciseNumber, int, float]) -> PreciseNumber:
+        if isinstance(other, (int, float)):
+            other = PreciseNumber(other)
+        return PreciseNumber._from_raw(self._value + other._value)
+
+    def __sub__(self, other: Union[PreciseNumber, int, float]) -> PreciseNumber:
+        if isinstance(other, (int, float)):
+            other = PreciseNumber(other)
+        return PreciseNumber._from_raw(self._value - other._value)
+
+    def __mul__(self, other: Union[PreciseNumber, int, float]) -> PreciseNumber:
+        if isinstance(other, (int, float)):
+            other = PreciseNumber(other)
+        return PreciseNumber._from_raw((self._value * other._value) // self.PRECISION)
+
+    def __truediv__(self, other: Union[PreciseNumber, int, float]) -> PreciseNumber:
+        if isinstance(other, (int, float)):
+            other = PreciseNumber(other)
+        if other._value == 0:
+            raise ZeroDivisionError
+        return PreciseNumber._from_raw((self._value * self.PRECISION) // other._value)
+
+    def __floordiv__(self, other: Union[PreciseNumber, int, float]) -> PreciseNumber:
+        if isinstance(other, (int, float)):
+            other = PreciseNumber(other)
+        if other._value == 0:
+            raise ZeroDivisionError
+        return PreciseNumber._from_raw(self._value * self.PRECISION // other._value)
+
+    def __eq__(self, other: Union[PreciseNumber, int, float]) -> bool:
+        if isinstance(other, (int, float)):
+            other = PreciseNumber(other)
+        return self._value == other._value
+
+    def __lt__(self, other: Union[PreciseNumber, int, float]) -> bool:
+        if isinstance(other, (int, float)):
+            other = PreciseNumber(other)
+        return self._value < other._value
+
+    def __le__(self, other: Union[PreciseNumber, int, float]) -> bool:
+        if isinstance(other, (int, float)):
+            other = PreciseNumber(other)
+        return self._value <= other._value
+
+    def __gt__(self, other: Union[PreciseNumber, int, float]) -> bool:
+        if isinstance(other, (int, float)):
+            other = PreciseNumber(other)
+        return self._value > other._value
+
+    def __ge__(self, other: Union[PreciseNumber, int, float]) -> bool:
+        if isinstance(other, (int, float)):
+            other = PreciseNumber(other)
+        return self._value >= other._value
+
+    def __radd__(self, other: Union[int, float]) -> PreciseNumber:
+        return self + other
+
+    def __rsub__(self, other: Union[int, float]) -> PreciseNumber:
+        return PreciseNumber(other) - self
+
+    def __rmul__(self, other: Union[int, float]) -> PreciseNumber:
+        return self * other
+
+    def __rtruediv__(self, other: Union[int, float]) -> PreciseNumber:
+        return PreciseNumber(other) / self
+
+    def __rfloordiv__(self, other: Union[int, float]) -> PreciseNumber:
+        return PreciseNumber(other) // self
+
+def jsonify(obj: Any) -> str:
+    if hasattr(obj, 'to_dict'):
+        obj = obj.to_dict()
+    return json.dumps(obj, sort_keys=True, separators=(',', ':'))
+
+def u8s_from_str(input: str) -> bytes:
+    return blake3(input.encode()).digest()
+
+def u64s_from_str(input: str) -> List[int]:
+    u8s = u8s_from_str(input)
+    return [
+        int.from_bytes(
+            u8s[i * 8:(i + 1) * 8],
+            byteorder='little',
+            signed=False
+        )
+        for i in range(4)
+    ]
+
+def now():
+    return int(time.time() * 1000)
diff --git a/tig-benchmarker/slave.py b/tig-benchmarker/slave.py
index 6ca7f31..1e12789 100644
--- a/tig-benchmarker/slave.py
+++ b/tig-benchmarker/slave.py
@@ -3,138 +3,263 @@ import json
 import os
 import logging
 import randomname
-import aiohttp
-import asyncio
+import requests
+import shutil
 import subprocess
 import time
+from threading import Thread
+from common.structs import OutputData, MerkleProof
+from common.merkle_tree import MerkleTree
 
 logger = logging.getLogger(os.path.splitext(os.path.basename(__file__))[0])
 
+PENDING_BATCH_IDS = set()
+PROCESSING_BATCH_IDS = set()
+READY_BATCH_IDS = set()
+FINISHED_BATCH_IDS = {}
 
 def now():
     return int(time.time() * 1000)
 
-async def download_wasm(session, download_url, wasm_path):
+def download_wasm(session, download_url, wasm_path):
     if not os.path.exists(wasm_path):
         start = now()
         logger.info(f"downloading WASM from {download_url}")
-        async with session.get(download_url) as resp:
-            if resp.status != 200:
-                raise Exception(f"status {resp.status} when downloading WASM: {await resp.text()}")
-            with open(wasm_path, 'wb') as f:
-                f.write(await resp.read())
+        resp = session.get(download_url)
+        if resp.status_code != 200:
+            raise Exception(f"status {resp.status_code} when downloading WASM: {resp.text}")
+        with open(wasm_path, 'wb') as f:
+            f.write(resp.content)
         logger.debug(f"downloading WASM: took {now() - start}ms")
         logger.debug(f"WASM Path: {wasm_path}")
 
-async def run_tig_worker(tig_worker_path, batch, wasm_path, num_workers):
+
+def run_tig_worker(tig_worker_path, batch, wasm_path, num_workers, output_path):
     start = now()
     cmd = [
         tig_worker_path,
         "compute_batch",
         json.dumps(batch["settings"]),
         batch["rand_hash"],
         str(batch["start_nonce"]),
         str(batch["num_nonces"]),
         str(batch["batch_size"]),
         wasm_path,
         "--mem", str(batch["runtime_config"]["max_memory"]),
         "--fuel", str(batch["runtime_config"]["max_fuel"]),
         "--workers", str(num_workers),
+        "--output", f"{output_path}/{batch['id']}",
     ]
-    if batch["sampled_nonces"]:
-        cmd += ["--sampled", *map(str, batch["sampled_nonces"])]
     logger.info(f"computing batch: {' '.join(cmd)}")
-    process = await asyncio.create_subprocess_exec(
-        *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
+    process = subprocess.Popen(
+        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
     )
-    stdout, stderr = await process.communicate()
+    stdout, stderr = process.communicate()
     if process.returncode != 0:
         raise Exception(f"tig-worker failed: {stderr.decode()}")
     result = json.loads(stdout.decode())
     logger.info(f"computing batch took {now() - start}ms")
     logger.debug(f"batch result: {result}")
-    return result
+    with open(f"{output_path}/{batch['id']}/result.json", "w") as f:
+        json.dump(result, f)
+
+    PROCESSING_BATCH_IDS.remove(batch["id"])
+    READY_BATCH_IDS.add(batch["id"])
+
 
-async def process_batch(session, master_ip, master_port, tig_worker_path, download_wasms_folder, num_workers, batch, headers):
+def purge_folders(output_path):
+    n = now()
+    purge_batch_ids = [
+        batch_id
+        for batch_id, finish_time in FINISHED_BATCH_IDS.items()
+        if n >= finish_time + 300000
+    ]
+    if len(purge_batch_ids) == 0:
+        time.sleep(5)
+        return
+
+    for batch_id in purge_batch_ids:
+        shutil.rmtree(f"{output_path}/{batch_id}")
+        FINISHED_BATCH_IDS.pop(batch_id)
+
+
+def send_results(session, master_ip, master_port, tig_worker_path, download_wasms_folder, num_workers, output_path):
     try:
-        batch_id = f"{batch['benchmark_id']}_{batch['start_nonce']}"
-        logger.info(f"Processing batch {batch_id}: {batch}")
+        batch_id = READY_BATCH_IDS.pop()
+    except KeyError:
+        logger.debug("No ready batches")
+        time.sleep(1)
+        return
+
+    output_folder = f"{output_path}/{batch_id}"
+    with open(f"{output_folder}/batch.json") as f:
+        batch = json.load(f)
 
-        # Step 2: Download WASM
-        wasm_path = os.path.join(download_wasms_folder, f"{batch['settings']['algorithm_id']}.wasm")
-        await download_wasm(session, batch['download_url'], wasm_path)
+    if (
+        not os.path.exists(f"{output_folder}/result.json")
+        or not all(
+            os.path.exists(f"{output_folder}/{nonce}.json")
+            for nonce in range(batch["start_nonce"], batch["start_nonce"] + batch["num_nonces"])
+        )
+    ):
+        if os.path.exists(f"{output_folder}/result.json"):
+            os.remove(f"{output_folder}/result.json")
+        logger.debug(f"Batch {batch_id} flagged as ready, but missing nonce files")
+        PENDING_BATCH_IDS.add(batch_id)
+        return
 
-        # Step 3: Run tig-worker
-        result = await run_tig_worker(tig_worker_path, batch, wasm_path, num_workers)
+    with open(f"{output_folder}/result.json") as f:
+        result = json.load(f)
+    merkle_proofs = None
 
-        # Step 4: Submit results
-        start = now()
-        submit_url = f"http://{master_ip}:{master_port}/submit-batch-result/{batch_id}"
-        logger.info(f"posting results to {submit_url}")
-        async with session.post(submit_url, json=result, headers=headers) as resp:
-            if resp.status != 200:
-                raise Exception(f"status {resp.status} when posting results to master: {await resp.text()}")
-        logger.debug(f"posting results took {now() - start} ms")
+    if batch["sampled_nonces"] is not None:
+        leafs = {}
+        for nonce in range(batch["start_nonce"], batch["start_nonce"] + batch["num_nonces"]):
+            with open(f"{output_folder}/{nonce}.json") as f:
+                leafs[nonce] = OutputData.from_dict(json.load(f))
+
+        merkle_tree = MerkleTree(
+            [x.to_merkle_hash() for x in leafs.values()],
+            batch["batch_size"]
+        )
 
-    except Exception as e:
-        logger.error(f"Error processing batch {batch_id}: {e}")
+        merkle_proofs = [
+            MerkleProof(
+                leaf=leafs[n],
+                branch=merkle_tree.calc_merkle_branch(branch_idx=n - batch["start_nonce"])
+            ).to_dict()
+            for n in batch["sampled_nonces"]
+        ]
+
+    result["merkle_proof"] = merkle_proofs if merkle_proofs is not None else []
+
+    submit_url = f"http://{master_ip}:{master_port}/submit-batch-result/{batch_id}"
+    logger.info(f"posting proofs to {submit_url}")
+    resp = session.post(submit_url, json=result)
+    if resp.status_code == 200:
+        FINISHED_BATCH_IDS[batch_id] = now()
+        logger.info(f"successfully posted proofs for batch {batch_id}")
+    elif resp.status_code == 408:  # took too long
+        FINISHED_BATCH_IDS[batch_id] = 0
+        logger.error(f"status {resp.status_code} when posting proofs for batch {batch_id} to master: {resp.text}")
+    else:
+        logger.error(f"status {resp.status_code} when posting proofs for batch {batch_id} to master: {resp.text}")
+        READY_BATCH_IDS.add(batch_id)  # requeue
+        time.sleep(2)
 
-async def main(
+
+def process_batch(session, tig_worker_path, download_wasms_folder, num_workers, output_path):
+    try:
+        batch_id = PENDING_BATCH_IDS.pop()
+    except KeyError:
+        logger.debug("No pending batches")
+        time.sleep(1)
+        return
+
+    if (
+        batch_id in PROCESSING_BATCH_IDS or
+        batch_id in READY_BATCH_IDS
+    ):
+        return
+
+    if os.path.exists(f"{output_path}/{batch_id}/result.json"):
+        logger.info(f"Batch {batch_id} already processed")
+        READY_BATCH_IDS.add(batch_id)
+        return
+
+    with open(f"{output_path}/{batch_id}/batch.json") as f:
+        batch = json.load(f)
+
+    wasm_path = os.path.join(download_wasms_folder, f"{batch['settings']['algorithm_id']}.wasm")
+    download_wasm(session, batch['download_url'], wasm_path)
+
+    PROCESSING_BATCH_IDS.add(batch_id)
+    Thread(
+        target=run_tig_worker,
+        args=(tig_worker_path, batch, wasm_path, num_workers, output_path)
+    ).start()
+
+
+def poll_batch(session, master_ip, master_port, output_path):
+    get_batch_url = f"http://{master_ip}:{master_port}/get-batches"
+    logger.info(f"fetching job from {get_batch_url}")
+    resp = session.get(get_batch_url)
+
+    if resp.status_code == 200:
+        batch = resp.json()
+        logger.info(f"fetched batch: {batch}")
+        output_folder = f"{output_path}/{batch['id']}"
+        os.makedirs(output_folder, exist_ok=True)
+        with open(f"{output_folder}/batch.json", "w") as f:
+            json.dump(batch, f)
+        PENDING_BATCH_IDS.add(batch['id'])
+        time.sleep(0.2)
+
+    elif resp.status_code == 425:  # too early
+        batches = resp.json()
+        batch_ids = [batch['id'] for batch in batches]
+        logger.info(f"max concurrent batches reached: {batch_ids}")
+        for batch in batches:
+            output_folder = f"{output_path}/{batch['id']}"
+            if os.path.exists(output_folder):
+                continue
+            os.makedirs(output_folder, exist_ok=True)
+            with open(f"{output_folder}/batch.json", "w") as f:
+                json.dump(batch, f)
+        PENDING_BATCH_IDS.clear()
+        PENDING_BATCH_IDS.update(batch_ids)
+        time.sleep(5)
+
+    else:
+        logger.error(f"status {resp.status_code} when fetching batch: {resp.text}")
+        time.sleep(5)
+
+
+def wrap_thread(func, *args):
+    logger.info(f"Starting thread for {func.__name__}")
+    while True:
+        try:
+            func(*args)
+        except Exception as e:
+            logger.error(f"Error in {func.__name__}: {e}")
+            time.sleep(5)
+
+
+def main(
     master_ip: str,
     tig_worker_path: str,
     download_wasms_folder: str,
     num_workers: int,
     slave_name: str,
-    master_port: int
+    master_port: int,
+    output_path: str,
 ):
+    print(f"Starting slave {slave_name}")
+
     if not os.path.exists(tig_worker_path):
         raise FileNotFoundError(f"tig-worker not found at path: {tig_worker_path}")
     os.makedirs(download_wasms_folder, exist_ok=True)
 
-    headers = {
+    session = requests.Session()
+    session.headers.update({
         "User-Agent": slave_name
-    }
+    })
+
+    Thread(
+        target=wrap_thread,
+        args=(process_batch, session, tig_worker_path, download_wasms_folder, num_workers, output_path)
+    ).start()
 
-    async with aiohttp.ClientSession() as session:
-        while True:
-            try:
-                # Step 1: Query for job test maj
-                start = now()
-                get_batch_url = f"http://{master_ip}:{master_port}/get-batches"
-                logger.info(f"fetching job from {get_batch_url}")
-                try:
-                    resp = await asyncio.wait_for(session.get(get_batch_url, headers=headers), timeout=5)
-                    if resp.status != 200:
-                        text = await resp.text()
-                        if resp.status == 404 and text.strip() == "No batches available":
-                            # Retry with master_port - 1
-                            new_port = master_port - 1
-                            get_batch_url = f"http://{master_ip}:{new_port}/get-batches"
-                            logger.info(f"No batches available on port {master_port}, trying port {new_port}")
-                            resp_retry = await asyncio.wait_for(session.get(get_batch_url, headers=headers), timeout=10)
-                            if resp_retry.status != 200:
-                                raise Exception(f"status {resp_retry.status} when fetching job: {await resp_retry.text()}")
-                            master_port_w = new_port
-                            batches = await resp_retry.json(content_type=None)
-                        else:
-                            raise Exception(f"status {resp.status} when fetching job: {text}")
-                    else:
-                        master_port_w = master_port
-                        batches = await resp.json(content_type=None)
-                except asyncio.TimeoutError:
-                    logger.error(f"Timeout occurred when fetching job from {get_batch_url}")
-                    continue
-                logger.debug(f"fetching job: took {now() - start}ms")
+    Thread(
+        target=wrap_thread,
+        args=(send_results, session, master_ip, master_port, tig_worker_path, download_wasms_folder, num_workers, output_path)
+    ).start()
 
-                # Process batches concurrently
-                tasks = [
-                    process_batch(session, master_ip, master_port_w, tig_worker_path, download_wasms_folder, num_workers, batch, headers)
-                    for batch in batches
-                ]
-                await asyncio.gather(*tasks)
+    Thread(
+        target=wrap_thread,
+        args=(purge_folders, output_path)
+    ).start()
+
+    wrap_thread(poll_batch, session, master_ip, master_port, output_path)
 
-            except Exception as e:
-                logger.error(e)
-                await asyncio.sleep(2)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="TIG Slave Benchmarker")
@@ -145,12 +270,13 @@ if __name__ == "__main__":
     parser.add_argument("--name", type=str, default=randomname.get_name(), help="Name for the slave (default: randomly generated)")
     parser.add_argument("--port", type=int, default=5115, help="Port for master (default: 5115)")
     parser.add_argument("--verbose", action='store_true', help="Print debug logs")
+    parser.add_argument("--output", type=str, default="results", help="Folder to output results to (default: results)")
 
     args = parser.parse_args()
 
     logging.basicConfig(
         format='%(levelname)s - [%(name)s] - %(message)s',
         level=logging.DEBUG if args.verbose else logging.INFO
     )
-    asyncio.run(main(args.master_ip, args.tig_worker_path, args.download, args.workers, args.name, args.port))
\ No newline at end of file
+    main(args.master_ip, args.tig_worker_path, args.download, args.workers, args.name, args.port, args.output)
\ No newline at end of file
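[Review note] With asyncio removed, slave.py is now four cooperating loops that hand batches off through the module-level ID sets. A condensed map of the lifecycle (file layout assumes the new `--output` folder):

```python
# poll_batch:      GET /get-batches -> write {output}/{batch_id}/batch.json -> PENDING_BATCH_IDS
# process_batch:   PENDING -> PROCESSING; downloads the WASM, spawns a run_tig_worker thread
# run_tig_worker:  runs `tig-worker compute_batch ... --output`, writes result.json -> READY_BATCH_IDS
# send_results:    rebuilds Merkle proofs for sampled nonces, POSTs to /submit-batch-result;
#                  200 or 408 -> FINISHED_BATCH_IDS, other errors requeue into READY,
#                  missing nonce files requeue into PENDING
# purge_folders:   rmtree()s a batch folder 300000 ms (5 min) after it lands in FINISHED
```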