Migrate to new slave py
This commit is contained in:
parent
109dd3c235
commit
d1f4c4f6e6
152
tig-benchmarker/common/merkle_tree.py
Normal file
152
tig-benchmarker/common/merkle_tree.py
Normal file
@ -0,0 +1,152 @@
|
||||
from blake3 import blake3
|
||||
from typing import List, Tuple
|
||||
from .utils import FromStr, u8s_from_str
|
||||
|
||||
class MerkleHash(FromStr):
|
||||
def __init__(self, value: bytes):
|
||||
if len(value) != 32:
|
||||
raise ValueError("MerkleHash must be exactly 32 bytes")
|
||||
self.value = value
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, str: str):
|
||||
return cls(bytes.fromhex(str))
|
||||
|
||||
@classmethod
|
||||
def null(cls):
|
||||
return cls(bytes([0] * 32))
|
||||
|
||||
def to_str(self):
|
||||
return self.value.hex()
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, MerkleHash) and self.value == other.value
|
||||
|
||||
def __repr__(self):
|
||||
return f"MerkleHash({self.to_str()})"
|
||||
|
||||
class MerkleTree(FromStr):
|
||||
def __init__(self, hashed_leafs: List[MerkleHash], n: int):
|
||||
if len(hashed_leafs) > n:
|
||||
raise ValueError("Invalid tree size")
|
||||
if n & (n - 1) != 0:
|
||||
raise ValueError("n must be a power of 2")
|
||||
self.hashed_leafs = hashed_leafs
|
||||
self.n = n
|
||||
|
||||
def to_str(self):
|
||||
"""Serializes the MerkleTree to a string"""
|
||||
n_hex = f"{self.n:016x}"
|
||||
hashes_hex = ''.join([h.to_str() for h in self.hashed_leafs])
|
||||
return n_hex + hashes_hex
|
||||
|
||||
def __repr__(self):
|
||||
return f"MerkleTree([{', '.join([str(h) for h in self.hashed_leafs])}], {self.n})"
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, s: str):
|
||||
"""Deserializes a MerkleTree from a string"""
|
||||
if len(s) < 16 or (len(s) - 16) % 64 != 0:
|
||||
raise ValueError("Invalid MerkleTree string length")
|
||||
|
||||
n_hex = s[:16]
|
||||
n = int(n_hex, 16)
|
||||
|
||||
hashes_hex = s[16:]
|
||||
hashed_leafs = [
|
||||
MerkleHash.from_str(hashes_hex[i:i + 64])
|
||||
for i in range(0, len(hashes_hex), 64)
|
||||
]
|
||||
|
||||
return cls(hashed_leafs, n)
|
||||
|
||||
def calc_merkle_root(self) -> MerkleHash:
|
||||
hashes = self.hashed_leafs[:]
|
||||
|
||||
while len(hashes) > 1:
|
||||
new_hashes = []
|
||||
for i in range(0, len(hashes), 2):
|
||||
left = hashes[i]
|
||||
result = MerkleHash(left.value)
|
||||
if i + 1 < len(hashes):
|
||||
right = hashes[i + 1]
|
||||
combined = left.value + right.value
|
||||
result = MerkleHash(blake3(combined).digest())
|
||||
new_hashes.append(result)
|
||||
hashes = new_hashes
|
||||
|
||||
return hashes[0]
|
||||
|
||||
def calc_merkle_branch(self, branch_idx: int) -> 'MerkleBranch':
|
||||
if branch_idx >= self.n:
|
||||
raise ValueError("Invalid branch index")
|
||||
|
||||
hashes = self.hashed_leafs[:]
|
||||
branch = []
|
||||
idx = branch_idx
|
||||
depth = 0
|
||||
|
||||
while len(hashes) > 1:
|
||||
new_hashes = []
|
||||
for i in range(0, len(hashes), 2):
|
||||
left = hashes[i]
|
||||
result = MerkleHash(left.value)
|
||||
if i + 1 < len(hashes):
|
||||
right = hashes[i + 1]
|
||||
if idx // 2 == i // 2:
|
||||
branch.append((depth, right if idx % 2 == 0 else left))
|
||||
combined = left.value + right.value
|
||||
result = MerkleHash(blake3(combined).digest())
|
||||
new_hashes.append(result)
|
||||
hashes = new_hashes
|
||||
idx //= 2
|
||||
depth += 1
|
||||
|
||||
return MerkleBranch(branch)
|
||||
|
||||
class MerkleBranch:
|
||||
def __init__(self, stems: List[Tuple[int, MerkleHash]]):
|
||||
self.stems = stems
|
||||
|
||||
def calc_merkle_root(self, hashed_leaf: MerkleHash, branch_idx: int) -> MerkleHash:
|
||||
root = hashed_leaf
|
||||
idx = branch_idx
|
||||
curr_depth = 0
|
||||
|
||||
for depth, hash in self.stems:
|
||||
if curr_depth > depth:
|
||||
raise ValueError("Invalid branch")
|
||||
while curr_depth != depth:
|
||||
idx //= 2
|
||||
curr_depth += 1
|
||||
|
||||
if idx % 2 == 0:
|
||||
combined = root.value + hash.value
|
||||
else:
|
||||
combined = hash.value + root.value
|
||||
root = MerkleHash(blake3(combined).digest())
|
||||
idx //= 2
|
||||
curr_depth += 1
|
||||
|
||||
return root
|
||||
|
||||
def to_str(self):
|
||||
"""Serializes the MerkleBranch to a hex string"""
|
||||
return ''.join([f"{depth:02x}{hash.to_str()}" for depth, hash in self.stems])
|
||||
|
||||
def __repr__(self):
|
||||
return f"MerkleBranch([{', '.join([f'({depth}, {hash})' for depth, hash in self.stems])}])"
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, s: str):
|
||||
"""Deserializes a MerkleBranch from a hex string"""
|
||||
if len(s) % 66 != 0:
|
||||
raise ValueError("Invalid MerkleBranch string length")
|
||||
|
||||
stems = []
|
||||
for i in range(0, len(s), 66):
|
||||
depth = int(s[i:i+2], 16)
|
||||
hash_hex = s[i+2:i+66]
|
||||
stems.append((depth, MerkleHash.from_str(hash_hex)))
|
||||
|
||||
return cls(stems)
|
||||
297
tig-benchmarker/common/structs.py
Normal file
297
tig-benchmarker/common/structs.py
Normal file
@ -0,0 +1,297 @@
|
||||
from .merkle_tree import MerkleHash, MerkleBranch
|
||||
from .utils import FromDict, u64s_from_str, u8s_from_str, jsonify, PreciseNumber
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Set, Any, Tuple
|
||||
|
||||
Point = Tuple[int, ...]
|
||||
Frontier = Set[Point]
|
||||
|
||||
@dataclass
|
||||
class AlgorithmDetails(FromDict):
|
||||
name: str
|
||||
player_id: str
|
||||
challenge_id: str
|
||||
breakthrough_id: Optional[str]
|
||||
type: str
|
||||
fee_paid: PreciseNumber
|
||||
|
||||
@dataclass
|
||||
class AlgorithmState(FromDict):
|
||||
block_confirmed: int
|
||||
round_submitted: int
|
||||
round_pushed: Optional[int]
|
||||
round_active: Optional[int]
|
||||
round_merged: Optional[int]
|
||||
banned: bool
|
||||
|
||||
@dataclass
|
||||
class AlgorithmBlockData(FromDict):
|
||||
num_qualifiers_by_player: Dict[str, int]
|
||||
adoption: PreciseNumber
|
||||
merge_points: int
|
||||
reward: PreciseNumber
|
||||
|
||||
@dataclass
|
||||
class Algorithm(FromDict):
|
||||
id: str
|
||||
details: AlgorithmDetails
|
||||
state: AlgorithmState
|
||||
block_data: Optional[AlgorithmBlockData]
|
||||
|
||||
@dataclass
|
||||
class BenchmarkSettings(FromDict):
|
||||
player_id: str
|
||||
block_id: str
|
||||
challenge_id: str
|
||||
algorithm_id: str
|
||||
difficulty: List[int]
|
||||
|
||||
def calc_seed(self, rand_hash: str, nonce: int) -> bytes:
|
||||
return u8s_from_str(f"{jsonify(self)}_{rand_hash}_{nonce}")
|
||||
|
||||
@dataclass
|
||||
class PrecommitDetails(FromDict):
|
||||
block_started: int
|
||||
num_nonces: int
|
||||
rand_hash: str
|
||||
fee_paid: PreciseNumber
|
||||
|
||||
@dataclass
|
||||
class PrecommitState(FromDict):
|
||||
block_confirmed: int
|
||||
|
||||
@dataclass
|
||||
class Precommit(FromDict):
|
||||
benchmark_id: str
|
||||
details: PrecommitDetails
|
||||
settings: BenchmarkSettings
|
||||
state: PrecommitState
|
||||
|
||||
@dataclass
|
||||
class BenchmarkDetails(FromDict):
|
||||
num_solutions: int
|
||||
merkle_root: MerkleHash
|
||||
sampled_nonces: List[int]
|
||||
|
||||
@dataclass
|
||||
class BenchmarkState(FromDict):
|
||||
block_confirmed: int
|
||||
|
||||
@dataclass
|
||||
class Benchmark(FromDict):
|
||||
id: str
|
||||
details: BenchmarkDetails
|
||||
state: BenchmarkState
|
||||
solution_nonces: Optional[Set[int]]
|
||||
|
||||
@dataclass
|
||||
class OutputMetaData(FromDict):
|
||||
nonce: int
|
||||
runtime_signature: int
|
||||
fuel_consumed: int
|
||||
solution_signature: int
|
||||
|
||||
@classmethod
|
||||
def from_output_data(cls, output_data: 'OutputData') -> 'OutputMetaData':
|
||||
return OutputData.to_output_metadata()
|
||||
|
||||
def to_merkle_hash(self) -> MerkleHash:
|
||||
return MerkleHash(u8s_from_str(jsonify(self)))
|
||||
|
||||
@dataclass
|
||||
class OutputData(FromDict):
|
||||
nonce: int
|
||||
runtime_signature: int
|
||||
fuel_consumed: int
|
||||
solution: dict
|
||||
|
||||
def calc_solution_signature(self) -> int:
|
||||
return u64s_from_str(jsonify(self.solution))[0]
|
||||
|
||||
def to_output_metadata(self) -> OutputMetaData:
|
||||
return OutputMetaData(
|
||||
nonce=self.nonce,
|
||||
runtime_signature=self.runtime_signature,
|
||||
fuel_consumed=self.fuel_consumed,
|
||||
solution_signature=self.calc_solution_signature()
|
||||
)
|
||||
|
||||
def to_merkle_hash(self) -> MerkleHash:
|
||||
return self.to_output_metadata().to_merkle_hash()
|
||||
|
||||
@dataclass
|
||||
class MerkleProof(FromDict):
|
||||
leaf: OutputData
|
||||
branch: MerkleBranch
|
||||
|
||||
@dataclass
|
||||
class ProofDetails(FromDict):
|
||||
submission_delay: int
|
||||
block_active: int
|
||||
|
||||
@dataclass
|
||||
class ProofState(FromDict):
|
||||
block_confirmed: int
|
||||
|
||||
@dataclass
|
||||
class Proof(FromDict):
|
||||
benchmark_id: str
|
||||
details: ProofDetails
|
||||
state: ProofState
|
||||
merkle_proofs: Optional[List[MerkleProof]]
|
||||
|
||||
@dataclass
|
||||
class FraudState(FromDict):
|
||||
block_confirmed: int
|
||||
|
||||
@dataclass
|
||||
class Fraud(FromDict):
|
||||
benchmark_id: str
|
||||
state: FraudState
|
||||
allegation: Optional[str]
|
||||
|
||||
@dataclass
|
||||
class BlockDetails(FromDict):
|
||||
prev_block_id: str
|
||||
height: int
|
||||
round: int
|
||||
timestamp: int
|
||||
num_confirmed: Dict[str, int]
|
||||
num_active: Dict[str, int]
|
||||
|
||||
@dataclass
|
||||
class BlockData(FromDict):
|
||||
confirmed_ids: Dict[str, Set[int]]
|
||||
active_ids: Dict[str, Set[int]]
|
||||
|
||||
@dataclass
|
||||
class Block(FromDict):
|
||||
id: str
|
||||
details: BlockDetails
|
||||
config: dict
|
||||
data: Optional[BlockData]
|
||||
|
||||
@dataclass
|
||||
class ChallengeDetails(FromDict):
|
||||
name: str
|
||||
|
||||
@dataclass
|
||||
class ChallengeState(FromDict):
|
||||
round_active: int
|
||||
|
||||
@dataclass
|
||||
class ChallengeBlockData(FromDict):
|
||||
num_qualifiers: int
|
||||
qualifier_difficulties: Set[Point]
|
||||
base_frontier: Frontier
|
||||
scaled_frontier: Frontier
|
||||
scaling_factor: float
|
||||
base_fee: PreciseNumber
|
||||
per_nonce_fee: PreciseNumber
|
||||
|
||||
@dataclass
|
||||
class Challenge(FromDict):
|
||||
id: str
|
||||
details: ChallengeDetails
|
||||
state: ChallengeState
|
||||
block_data: Optional[ChallengeBlockData]
|
||||
|
||||
@dataclass
|
||||
class OPoWBlockData(FromDict):
|
||||
num_qualifiers_by_challenge: Dict[str, int]
|
||||
cutoff: int
|
||||
delegated_weighted_deposit: PreciseNumber
|
||||
delegators: Set[str]
|
||||
reward_share: float
|
||||
imbalance: PreciseNumber
|
||||
influence: PreciseNumber
|
||||
reward: PreciseNumber
|
||||
|
||||
@dataclass
|
||||
class OPoW(FromDict):
|
||||
player_id: str
|
||||
block_data: Optional[OPoWBlockData]
|
||||
|
||||
@dataclass
|
||||
class PlayerDetails(FromDict):
|
||||
name: Optional[str]
|
||||
is_multisig: bool
|
||||
|
||||
@dataclass
|
||||
class PlayerState(FromDict):
|
||||
total_fees_paid: PreciseNumber
|
||||
available_fee_balance: PreciseNumber
|
||||
delegatee: Optional[dict]
|
||||
votes: dict
|
||||
reward_share: Optional[dict]
|
||||
|
||||
@dataclass
|
||||
class PlayerBlockData(FromDict):
|
||||
delegatee: Optional[str]
|
||||
reward_by_type: Dict[str, PreciseNumber]
|
||||
deposit_by_locked_period: List[PreciseNumber]
|
||||
weighted_deposit: PreciseNumber
|
||||
|
||||
@dataclass
|
||||
class Player(FromDict):
|
||||
id: str
|
||||
details: PlayerDetails
|
||||
state: PlayerState
|
||||
block_data: Optional[PlayerBlockData]
|
||||
|
||||
@dataclass
|
||||
class BinaryDetails(FromDict):
|
||||
compile_success: bool
|
||||
download_url: Optional[str]
|
||||
|
||||
@dataclass
|
||||
class BinaryState(FromDict):
|
||||
block_confirmed: int
|
||||
|
||||
@dataclass
|
||||
class Binary(FromDict):
|
||||
algorithm_id: str
|
||||
details: BinaryDetails
|
||||
state: BinaryState
|
||||
|
||||
@dataclass
|
||||
class TopUpDetails(FromDict):
|
||||
player_id: str
|
||||
amount: PreciseNumber
|
||||
log_idx: int
|
||||
tx_hash: str
|
||||
|
||||
@dataclass
|
||||
class TopUpState(FromDict):
|
||||
block_confirmed: int
|
||||
|
||||
@dataclass
|
||||
class TopUp(FromDict):
|
||||
id: str
|
||||
details: TopUpDetails
|
||||
state: TopUpState
|
||||
|
||||
@dataclass
|
||||
class DifficultyData(FromDict):
|
||||
num_solutions: int
|
||||
num_nonces: int
|
||||
difficulty: Point
|
||||
|
||||
@dataclass
|
||||
class DepositDetails(FromDict):
|
||||
player_id: str
|
||||
amount: PreciseNumber
|
||||
log_idx: int
|
||||
tx_hash: str
|
||||
start_timestamp: int
|
||||
end_timestamp: int
|
||||
|
||||
@dataclass
|
||||
class DepositState(FromDict):
|
||||
block_confirmed: int
|
||||
|
||||
@dataclass
|
||||
class Deposit(FromDict):
|
||||
id: str
|
||||
details: DepositDetails
|
||||
state: DepositState
|
||||
216
tig-benchmarker/common/utils.py
Normal file
216
tig-benchmarker/common/utils.py
Normal file
@ -0,0 +1,216 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractclassmethod, abstractmethod
|
||||
from blake3 import blake3
|
||||
from dataclasses import dataclass, fields, is_dataclass, asdict
|
||||
from typing import TypeVar, Type, Dict, Any, List, Union, Optional, get_origin, get_args
|
||||
import json
|
||||
import time
|
||||
|
||||
T = TypeVar('T', bound='DataclassBase')
|
||||
|
||||
class FromStr(ABC):
|
||||
@abstractclassmethod
|
||||
def from_str(cls, s: str):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def to_str(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@dataclass
|
||||
class FromDict:
|
||||
@classmethod
|
||||
def from_dict(cls: Type[T], d: Dict[str, Any]) -> T:
|
||||
field_types = {f.name: f.type for f in fields(cls)}
|
||||
kwargs = {}
|
||||
|
||||
for field in fields(cls):
|
||||
value = d.pop(field.name, None)
|
||||
field_type = field_types[field.name]
|
||||
|
||||
if value is None:
|
||||
if cls._is_optional(field_type):
|
||||
kwargs[field.name] = None
|
||||
else:
|
||||
raise ValueError(f"Missing required field: {field.name}")
|
||||
continue
|
||||
|
||||
kwargs[field.name] = cls._process_value(value, field_type)
|
||||
|
||||
return cls(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def _process_value(cls, value: Any, field_type: Type) -> Any:
|
||||
origin_type = get_origin(field_type)
|
||||
|
||||
if cls._is_optional(field_type):
|
||||
if value is None:
|
||||
return None
|
||||
non_none_type = next(arg for arg in get_args(field_type) if arg is not type(None))
|
||||
return cls._process_value(value, non_none_type)
|
||||
|
||||
if hasattr(field_type, 'from_dict') and isinstance(value, dict):
|
||||
return field_type.from_dict(value)
|
||||
elif hasattr(field_type, 'from_str') and isinstance(value, str):
|
||||
return field_type.from_str(value)
|
||||
elif origin_type in (list, set, tuple):
|
||||
elem_type = get_args(field_type)[0]
|
||||
return origin_type(cls._process_value(item, elem_type) for item in value)
|
||||
elif origin_type is dict:
|
||||
key_type, val_type = get_args(field_type)
|
||||
return {cls._process_value(k, key_type): cls._process_value(v, val_type) for k, v in value.items()}
|
||||
else:
|
||||
return field_type(value)
|
||||
|
||||
@staticmethod
|
||||
def _is_optional(field_type: Type) -> bool:
|
||||
return get_origin(field_type) is Union and type(None) in get_args(field_type)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
d = {}
|
||||
for field in fields(self):
|
||||
value = getattr(self, field.name)
|
||||
if value is not None:
|
||||
if hasattr(value, 'to_dict'):
|
||||
d[field.name] = value.to_dict()
|
||||
elif hasattr(value, 'to_str'):
|
||||
d[field.name] = value.to_str()
|
||||
elif isinstance(value, (list, set, tuple)):
|
||||
d[field.name] = [
|
||||
item.to_dict() if hasattr(item, 'to_dict')
|
||||
else item.to_str() if hasattr(item, 'to_str')
|
||||
else item
|
||||
for item in value
|
||||
]
|
||||
elif isinstance(value, dict):
|
||||
d[field.name] = {
|
||||
k: (v.to_dict() if hasattr(v, 'to_dict')
|
||||
else v.to_str() if hasattr(v, 'to_str')
|
||||
else v)
|
||||
for k, v in value.items()
|
||||
}
|
||||
elif is_dataclass(value):
|
||||
d[field.name] = asdict(value)
|
||||
else:
|
||||
d[field.name] = value
|
||||
return d
|
||||
|
||||
|
||||
class PreciseNumber(FromStr):
|
||||
PRECISION = 10**18 # 18 decimal places of precision
|
||||
|
||||
def __init__(self, value: Union[int, float, str, PreciseNumber]):
|
||||
if isinstance(value, PreciseNumber):
|
||||
self._value = value._value
|
||||
elif isinstance(value, int):
|
||||
self._value = value * self.PRECISION
|
||||
elif isinstance(value, float):
|
||||
self._value = int(value * self.PRECISION)
|
||||
elif isinstance(value, str):
|
||||
self._value = int(value)
|
||||
else:
|
||||
raise TypeError(f"Unsupported type for PreciseNumber: {type(value)}")
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, s: str) -> 'PreciseNumber':
|
||||
return cls(s)
|
||||
|
||||
def to_str(self) -> str:
|
||||
return str(self._value)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"PreciseNumber({self.to_float()})"
|
||||
|
||||
def to_float(self) -> float:
|
||||
return self._value / self.PRECISION
|
||||
|
||||
def __add__(self, other: Union[PreciseNumber, int, float]) -> PreciseNumber:
|
||||
if isinstance(other, (int, float)):
|
||||
other = PreciseNumber(other)
|
||||
return PreciseNumber(self._value + other._value)
|
||||
|
||||
def __sub__(self, other: Union[PreciseNumber, int, float]) -> PreciseNumber:
|
||||
if isinstance(other, (int, float)):
|
||||
other = PreciseNumber(other)
|
||||
return PreciseNumber(self._value - other._value)
|
||||
|
||||
def __mul__(self, other: Union[PreciseNumber, int, float]) -> PreciseNumber:
|
||||
if isinstance(other, (int, float)):
|
||||
other = PreciseNumber(other)
|
||||
return PreciseNumber((self._value * other._value) // self.PRECISION)
|
||||
|
||||
def __truediv__(self, other: Union[PreciseNumber, int, float]) -> PreciseNumber:
|
||||
if isinstance(other, (int, float)):
|
||||
other = PreciseNumber(other)
|
||||
if other._value == 0:
|
||||
raise ZeroDivisionError
|
||||
return PreciseNumber((self._value * self.PRECISION) // other._value)
|
||||
|
||||
def __floordiv__(self, other: Union[PreciseNumber, int, float]) -> PreciseNumber:
|
||||
if isinstance(other, (int, float)):
|
||||
other = PreciseNumber(other)
|
||||
if other._value == 0:
|
||||
raise ZeroDivisionError
|
||||
return PreciseNumber((self._value * self.PRECISION // other._value))
|
||||
|
||||
def __eq__(self, other: Union[PreciseNumber, int, float]) -> bool:
|
||||
if isinstance(other, (int, float)):
|
||||
other = PreciseNumber(other)
|
||||
return self._value == other._value
|
||||
|
||||
def __lt__(self, other: Union[PreciseNumber, int, float]) -> bool:
|
||||
if isinstance(other, (int, float)):
|
||||
other = PreciseNumber(other)
|
||||
return self._value < other._value
|
||||
|
||||
def __le__(self, other: Union[PreciseNumber, int, float]) -> bool:
|
||||
if isinstance(other, (int, float)):
|
||||
other = PreciseNumber(other)
|
||||
return self._value <= other._value
|
||||
|
||||
def __gt__(self, other: Union[PreciseNumber, int, float]) -> bool:
|
||||
if isinstance(other, (int, float)):
|
||||
other = PreciseNumber(other)
|
||||
return self._value > other._value
|
||||
|
||||
def __ge__(self, other: Union[PreciseNumber, int, float]) -> bool:
|
||||
if isinstance(other, (int, float)):
|
||||
other = PreciseNumber(other)
|
||||
return self._value >= other._value
|
||||
|
||||
def __radd__(self, other: Union[int, float]) -> PreciseNumber:
|
||||
return self + other
|
||||
|
||||
def __rsub__(self, other: Union[int, float]) -> PreciseNumber:
|
||||
return PreciseNumber(other) - self
|
||||
|
||||
def __rmul__(self, other: Union[int, float]) -> PreciseNumber:
|
||||
return self * other
|
||||
|
||||
def __rtruediv__(self, other: Union[int, float]) -> PreciseNumber:
|
||||
return PreciseNumber(other) / self
|
||||
|
||||
def __rfloordiv__(self, other: Union[int, float]) -> PreciseNumber:
|
||||
return PreciseNumber(other) // self
|
||||
|
||||
def jsonify(obj: Any) -> str:
|
||||
if hasattr(obj, 'to_dict'):
|
||||
obj = obj.to_dict()
|
||||
return json.dumps(obj, sort_keys=True, separators=(',', ':'))
|
||||
|
||||
def u8s_from_str(input: str) -> bytes:
|
||||
return blake3(input.encode()).digest()
|
||||
|
||||
def u64s_from_str(input: str) -> List[int]:
|
||||
u8s = u8s_from_str(input)
|
||||
return [
|
||||
int.from_bytes(
|
||||
u8s[i * 8:(i + 1) * 8],
|
||||
byteorder='little',
|
||||
signed=False
|
||||
)
|
||||
for i in range(4)
|
||||
]
|
||||
|
||||
def now():
|
||||
return int(time.time() * 1000)
|
||||
@ -3,138 +3,263 @@ import json
|
||||
import os
|
||||
import logging
|
||||
import randomname
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import requests
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from threading import Thread
|
||||
from common.structs import OutputData, MerkleProof
|
||||
from common.merkle_tree import MerkleTree
|
||||
|
||||
logger = logging.getLogger(os.path.splitext(os.path.basename(__file__))[0])
|
||||
PENDING_BATCH_IDS = set()
|
||||
PROCESSING_BATCH_IDS = set()
|
||||
READY_BATCH_IDS = set()
|
||||
FINISHED_BATCH_IDS = {}
|
||||
|
||||
def now():
|
||||
return int(time.time() * 1000)
|
||||
|
||||
async def download_wasm(session, download_url, wasm_path):
|
||||
def download_wasm(session, download_url, wasm_path):
|
||||
if not os.path.exists(wasm_path):
|
||||
start = now()
|
||||
logger.info(f"downloading WASM from {download_url}")
|
||||
async with session.get(download_url) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"status {resp.status} when downloading WASM: {await resp.text()}")
|
||||
with open(wasm_path, 'wb') as f:
|
||||
f.write(await resp.read())
|
||||
resp = session.get(download_url)
|
||||
if resp.status_code != 200:
|
||||
raise Exception(f"status {resp.status_code} when downloading WASM: {resp.text}")
|
||||
with open(wasm_path, 'wb') as f:
|
||||
f.write(resp.content)
|
||||
logger.debug(f"downloading WASM: took {now() - start}ms")
|
||||
logger.debug(f"WASM Path: {wasm_path}")
|
||||
|
||||
async def run_tig_worker(tig_worker_path, batch, wasm_path, num_workers):
|
||||
|
||||
def run_tig_worker(tig_worker_path, batch, wasm_path, num_workers, output_path):
|
||||
start = now()
|
||||
cmd = [
|
||||
tig_worker_path, "compute_batch",
|
||||
json.dumps(batch["settings"]),
|
||||
batch["rand_hash"],
|
||||
str(batch["start_nonce"]),
|
||||
json.dumps(batch["settings"]),
|
||||
batch["rand_hash"],
|
||||
str(batch["start_nonce"]),
|
||||
str(batch["num_nonces"]),
|
||||
str(batch["batch_size"]),
|
||||
str(batch["batch_size"]),
|
||||
wasm_path,
|
||||
"--mem", str(batch["runtime_config"]["max_memory"]),
|
||||
"--fuel", str(batch["runtime_config"]["max_fuel"]),
|
||||
"--workers", str(num_workers),
|
||||
"--output", f"{output_path}/{batch['id']}",
|
||||
]
|
||||
if batch["sampled_nonces"]:
|
||||
cmd += ["--sampled", *map(str, batch["sampled_nonces"])]
|
||||
logger.info(f"computing batch: {' '.join(cmd)}")
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
process = subprocess.Popen(
|
||||
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
stdout, stderr = process.communicate()
|
||||
if process.returncode != 0:
|
||||
raise Exception(f"tig-worker failed: {stderr.decode()}")
|
||||
result = json.loads(stdout.decode())
|
||||
logger.info(f"computing batch took {now() - start}ms")
|
||||
logger.debug(f"batch result: {result}")
|
||||
return result
|
||||
with open(f"{output_path}/{batch['id']}/result.json", "w") as f:
|
||||
json.dump(result, f)
|
||||
|
||||
PROCESSING_BATCH_IDS.remove(batch["id"])
|
||||
READY_BATCH_IDS.add(batch["id"])
|
||||
|
||||
|
||||
async def process_batch(session, master_ip, master_port, tig_worker_path, download_wasms_folder, num_workers, batch, headers):
|
||||
def purge_folders(output_path):
|
||||
n = now()
|
||||
purge_batch_ids = [
|
||||
batch_id
|
||||
for batch_id, finish_time in FINISHED_BATCH_IDS.items()
|
||||
if n >= finish_time + 300000
|
||||
]
|
||||
if len(purge_batch_ids) == 0:
|
||||
time.sleep(5)
|
||||
return
|
||||
|
||||
for batch_id in purge_batch_ids:
|
||||
shutil.rmtree(f"{output_path}/{batch_id}")
|
||||
FINISHED_BATCH_IDS.pop(batch_id)
|
||||
|
||||
|
||||
def send_results(session, master_ip, master_port, tig_worker_path, download_wasms_folder, num_workers, output_path):
|
||||
try:
|
||||
batch_id = f"{batch['benchmark_id']}_{batch['start_nonce']}"
|
||||
logger.info(f"Processing batch {batch_id}: {batch}")
|
||||
batch_id = READY_BATCH_IDS.pop()
|
||||
except KeyError:
|
||||
logger.debug("No pending batches")
|
||||
time.sleep(1)
|
||||
return
|
||||
|
||||
output_folder = f"{output_path}/{batch_id}"
|
||||
with open(f"{output_folder}/batch.json") as f:
|
||||
batch = json.load(f)
|
||||
|
||||
# Step 2: Download WASM
|
||||
wasm_path = os.path.join(download_wasms_folder, f"{batch['settings']['algorithm_id']}.wasm")
|
||||
await download_wasm(session, batch['download_url'], wasm_path)
|
||||
if (
|
||||
not os.path.exists(f"{output_folder}/result.json")
|
||||
or not all(
|
||||
os.path.exists(f"{output_folder}/{nonce}.json")
|
||||
for nonce in range(batch["start_nonce"], batch["start_nonce"] + batch["num_nonces"])
|
||||
)
|
||||
):
|
||||
if os.path.exists(f"{output_folder}/result.json"):
|
||||
os.remove(f"{output_folder}/result.json")
|
||||
logger.debug(f"Batch {batch_id} flagged as ready, but missing nonce files")
|
||||
PENDING_BATCH_IDS.add(batch_id)
|
||||
return
|
||||
|
||||
# Step 3: Run tig-worker
|
||||
result = await run_tig_worker(tig_worker_path, batch, wasm_path, num_workers)
|
||||
with open(f"{output_folder}/result.json") as f:
|
||||
result = json.load(f)
|
||||
merkle_proofs = None
|
||||
|
||||
# Step 4: Submit results
|
||||
start = now()
|
||||
submit_url = f"http://{master_ip}:{master_port}/submit-batch-result/{batch_id}"
|
||||
logger.info(f"posting results to {submit_url}")
|
||||
async with session.post(submit_url, json=result, headers=headers) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"status {resp.status} when posting results to master: {await resp.text()}")
|
||||
logger.debug(f"posting results took {now() - start} ms")
|
||||
if batch["sampled_nonces"] is not None:
|
||||
leafs = {}
|
||||
for nonce in range(batch["start_nonce"], batch["start_nonce"] + batch["num_nonces"]):
|
||||
with open(f"{output_folder}/{nonce}.json") as f:
|
||||
leafs[nonce] = OutputData.from_dict(json.load(f))
|
||||
|
||||
merkle_tree = MerkleTree(
|
||||
[x.to_merkle_hash() for x in leafs.values()],
|
||||
batch["batch_size"]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing batch {batch_id}: {e}")
|
||||
merkle_proofs = [
|
||||
MerkleProof(
|
||||
leaf=leafs[n],
|
||||
branch=merkle_tree.calc_merkle_branch(branch_idx=n - batch["start_nonce"])
|
||||
).to_dict()
|
||||
for n in batch["sampled_nonces"]
|
||||
]
|
||||
|
||||
result["merkle_proof"] = merkle_proofs if merkle_proofs is not None else []
|
||||
|
||||
submit_url = f"http://{master_ip}:{master_port}/submit-batch-result/{batch_id}"
|
||||
logger.info(f"posting proofs to {submit_url}")
|
||||
resp = session.post(submit_url, json=result})
|
||||
if resp.status_code == 200:
|
||||
FINISHED_BATCH_IDS[batch_id] = now()
|
||||
logger.info(f"successfully posted proofs for batch {batch_id}")
|
||||
elif resp.status_code == 408: # took too long
|
||||
FINISHED_BATCH_IDS[batch_id] = 0
|
||||
logger.error(f"status {resp.status_code} when posting proofs for batch {batch_id} to master: {resp.text}")
|
||||
else:
|
||||
logger.error(f"status {resp.status_code} when posting proofs for batch {batch_id} to master: {resp.text}")
|
||||
READY_BATCH_IDS.add(batch_id) # requeue
|
||||
time.sleep(2)
|
||||
|
||||
async def main(
|
||||
|
||||
def process_batch(session, tig_worker_path, download_wasms_folder, num_workers, output_path):
|
||||
try:
|
||||
batch_id = PENDING_BATCH_IDS.pop()
|
||||
except KeyError:
|
||||
logger.debug("No pending batches")
|
||||
time.sleep(1)
|
||||
return
|
||||
|
||||
if (
|
||||
batch_id in PROCESSING_BATCH_IDS or
|
||||
batch_id in READY_BATCH_IDS
|
||||
):
|
||||
return
|
||||
|
||||
if os.path.exists(f"{output_path}/{batch_id}/result.json"):
|
||||
logger.info(f"Batch {batch_id} already processed")
|
||||
READY_BATCH_IDS.add(batch_id)
|
||||
return
|
||||
|
||||
with open(f"{output_path}/{batch_id}/batch.json") as f:
|
||||
batch = json.load(f)
|
||||
|
||||
wasm_path = os.path.join(download_wasms_folder, f"{batch['settings']['algorithm_id']}.wasm")
|
||||
download_wasm(session, batch['download_url'], wasm_path)
|
||||
|
||||
PROCESSING_BATCH_IDS.add(batch_id)
|
||||
Thread(
|
||||
target=run_tig_worker,
|
||||
args=(tig_worker_path, batch, wasm_path, num_workers, output_path)
|
||||
).start()
|
||||
|
||||
|
||||
def poll_batch(session, master_ip, master_port, output_path):
|
||||
get_batch_url = f"http://{master_ip}:{master_port}/get-batches"
|
||||
logger.info(f"fetching job from {get_batch_url}")
|
||||
resp = session.get(get_batch_url)
|
||||
|
||||
if resp.status_code == 200:
|
||||
batch = resp.json()
|
||||
logger.info(f"fetched batch: {batch}")
|
||||
output_folder = f"{output_path}/{batch['id']}"
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
with open(f"{output_folder}/batch.json", "w") as f:
|
||||
json.dump(batch, f)
|
||||
PENDING_BATCH_IDS.add(batch['id'])
|
||||
time.sleep(0.2)
|
||||
|
||||
elif resp.status_code == 425: # too early
|
||||
batches = resp.json()
|
||||
batch_ids = [batch['id'] for batch in batches]
|
||||
logger.info(f"max concurrent batches reached: {batch_ids}")
|
||||
for batch in batches:
|
||||
output_folder = f"{output_path}/{batch['id']}"
|
||||
if os.path.exists(output_folder):
|
||||
continue
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
with open(f"{output_folder}/batch.json", "w") as f:
|
||||
json.dump(batch, f)
|
||||
PENDING_BATCH_IDS.clear()
|
||||
PENDING_BATCH_IDS.update(batch_ids)
|
||||
time.sleep(5)
|
||||
|
||||
else:
|
||||
logger.error(f"status {resp.status_code} when fetching batch: {resp.text}")
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
def wrap_thread(func, *args):
|
||||
logger.info(f"Starting thread for {func.__name__}")
|
||||
while True:
|
||||
try:
|
||||
func(*args)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in {func.__name__}: {e}")
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
def main(
|
||||
master_ip: str,
|
||||
tig_worker_path: str,
|
||||
download_wasms_folder: str,
|
||||
num_workers: int,
|
||||
slave_name: str,
|
||||
master_port: int
|
||||
master_port: int,
|
||||
output_path: str,
|
||||
):
|
||||
print(f"Starting slave {slave_name}")
|
||||
|
||||
if not os.path.exists(tig_worker_path):
|
||||
raise FileNotFoundError(f"tig-worker not found at path: {tig_worker_path}")
|
||||
os.makedirs(download_wasms_folder, exist_ok=True)
|
||||
|
||||
headers = {
|
||||
session = requests.Session()
|
||||
session.headers.update({
|
||||
"User-Agent": slave_name
|
||||
}
|
||||
})
|
||||
|
||||
Thread(
|
||||
target=wrap_thread,
|
||||
args=(process_batch, session, tig_worker_path, download_wasms_folder, num_workers, output_path)
|
||||
).start()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
while True:
|
||||
try:
|
||||
# Step 1: Query for job test maj
|
||||
start = now()
|
||||
get_batch_url = f"http://{master_ip}:{master_port}/get-batches"
|
||||
logger.info(f"fetching job from {get_batch_url}")
|
||||
try:
|
||||
resp = await asyncio.wait_for(session.get(get_batch_url, headers=headers), timeout=5)
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
if resp.status == 404 and text.strip() == "No batches available":
|
||||
# Retry with master_port - 1
|
||||
new_port = master_port - 1
|
||||
get_batch_url = f"http://{master_ip}:{new_port}/get-batches"
|
||||
logger.info(f"No batches available on port {master_port}, trying port {new_port}")
|
||||
resp_retry = await asyncio.wait_for(session.get(get_batch_url, headers=headers), timeout=10)
|
||||
if resp_retry.status != 200:
|
||||
raise Exception(f"status {resp_retry.status} when fetching job: {await resp_retry.text()}")
|
||||
master_port_w = new_port
|
||||
batches = await resp_retry.json(content_type=None)
|
||||
else:
|
||||
raise Exception(f"status {resp.status} when fetching job: {text}")
|
||||
else:
|
||||
master_port_w = master_port
|
||||
batches = await resp.json(content_type=None)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Timeout occurred when fetching job from {get_batch_url}")
|
||||
continue
|
||||
logger.debug(f"fetching job: took {now() - start}ms")
|
||||
Thread(
|
||||
target=wrap_thread,
|
||||
args=(send_results, session, master_ip, master_port, tig_worker_path, download_wasms_folder, num_workers, output_path)
|
||||
).start()
|
||||
|
||||
# Process batches concurrently
|
||||
tasks = [
|
||||
process_batch(session, master_ip, master_port_w, tig_worker_path, download_wasms_folder, num_workers, batch, headers)
|
||||
for batch in batches
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
Thread(
|
||||
target=wrap_thread,
|
||||
args=(purge_folders, output_path)
|
||||
).start()
|
||||
|
||||
wrap_thread(poll_batch, session, master_ip, master_port, output_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
await asyncio.sleep(2)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="TIG Slave Benchmarker")
|
||||
@ -145,12 +270,13 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--name", type=str, default=randomname.get_name(), help="Name for the slave (default: randomly generated)")
|
||||
parser.add_argument("--port", type=int, default=5115, help="Port for master (default: 5115)")
|
||||
parser.add_argument("--verbose", action='store_true', help="Print debug logs")
|
||||
|
||||
parser.add_argument("--output", type=str, default="results", help="Folder to output results to (default: results)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format='%(levelname)s - [%(name)s] - %(message)s',
|
||||
level=logging.DEBUG if args.verbose else logging.INFO
|
||||
)
|
||||
|
||||
asyncio.run(main(args.master_ip, args.tig_worker_path, args.download, args.workers, args.name, args.port))
|
||||
main(args.master, args.tig_worker_path, args.download, args.workers, args.name, args.port, args.output)
|
||||
Loading…
Reference in New Issue
Block a user