fix: token intrinsic locking

This commit is contained in:
Cassandra Heart 2025-10-08 22:44:21 -05:00
parent 9a1a9743a6
commit ff449ecc98
No known key found for this signature in database
GPG Key ID: 371083BFA6C240AA
5 changed files with 431 additions and 78 deletions

View File

@ -1124,7 +1124,7 @@ func (h *HypergraphIntrinsic) Lock(frameNumber uint64, input []byte) error {
// Check type prefix to determine request type
if len(input) < 4 {
observability.LockErrors.WithLabelValues(
"compute",
"hypergraph",
"invalid_input",
).Inc()
return errors.Wrap(errors.New("input too short"), "lock")
@ -1181,7 +1181,7 @@ func (h *HypergraphIntrinsic) Lock(frameNumber uint64, input []byte) error {
default:
observability.LockErrors.WithLabelValues(
"compute",
"hypergraph",
"unknown_type",
).Inc()
return errors.Wrap(

View File

@ -750,10 +750,7 @@ func (t *TokenIntrinsic) InvokeStep(
}
// Lock implements intrinsics.Intrinsic.
func (t *TokenIntrinsic) Lock(
writeAddresses [][]byte,
readAddresses [][]byte,
) error {
func (t *TokenIntrinsic) Lock(frameNumber uint64, input []byte) error {
t.lockedReadsMx.Lock()
t.lockedWritesMx.Lock()
defer t.lockedReadsMx.Unlock()
@ -767,7 +764,65 @@ func (t *TokenIntrinsic) Lock(
t.lockedWrites = make(map[string]struct{})
}
for _, address := range writeAddresses {
// Check type prefix to determine request type
if len(input) < 4 {
observability.LockErrors.WithLabelValues(
"token",
"invalid_input",
).Inc()
return errors.Wrap(errors.New("input too short"), "lock")
}
// Read the type prefix
typePrefix := binary.BigEndian.Uint32(input[:4])
var reads, writes [][]byte
var err error
// Handle each type based on type prefix
switch typePrefix {
case protobufs.TransactionType:
reads, writes, err = t.tryLockTransaction(frameNumber, input)
if err != nil {
return err
}
observability.LockTotal.WithLabelValues("token", "transaction").Inc()
case protobufs.PendingTransactionType:
reads, writes, err = t.tryLockPendingTransaction(frameNumber, input)
if err != nil {
return err
}
observability.LockTotal.WithLabelValues(
"token",
"pending_transaction",
).Inc()
case protobufs.MintTransactionType:
reads, writes, err = t.tryLockMintTransaction(frameNumber, input)
if err != nil {
return err
}
observability.LockTotal.WithLabelValues(
"token",
"mint_transaction",
).Inc()
default:
observability.LockErrors.WithLabelValues(
"token",
"unknown_type",
).Inc()
return errors.Wrap(
errors.New("unknown compute request type"),
"lock",
)
}
for _, address := range writes {
if _, ok := t.lockedWrites[string(address)]; ok {
return errors.Wrap(
fmt.Errorf("address %x is already locked for writing", address),
@ -782,7 +837,7 @@ func (t *TokenIntrinsic) Lock(
}
}
for _, address := range readAddresses {
for _, address := range reads {
if _, ok := t.lockedWrites[string(address)]; ok {
return errors.Wrap(
fmt.Errorf("address %x is already locked for writing", address),
@ -791,12 +846,12 @@ func (t *TokenIntrinsic) Lock(
}
}
for _, address := range writeAddresses {
for _, address := range writes {
t.lockedWrites[string(address)] = struct{}{}
t.lockedReads[string(address)] = t.lockedReads[string(address)] + 1
}
for _, address := range readAddresses {
for _, address := range reads {
t.lockedReads[string(address)] = t.lockedReads[string(address)] + 1
}
@ -804,67 +859,165 @@ func (t *TokenIntrinsic) Lock(
}
// Unlock implements intrinsics.Intrinsic.
func (t *TokenIntrinsic) Unlock(
writeAddresses [][]byte,
readAddresses [][]byte,
) error {
func (t *TokenIntrinsic) Unlock() error {
t.lockedReadsMx.Lock()
t.lockedWritesMx.Lock()
defer t.lockedReadsMx.Unlock()
defer t.lockedWritesMx.Unlock()
if t.lockedReads == nil {
t.lockedReads = make(map[string]int)
}
if t.lockedWrites == nil {
t.lockedWrites = make(map[string]struct{})
}
alteredWriteLocks := make(map[string]struct{})
for k, v := range t.lockedWrites {
alteredWriteLocks[k] = v
}
for _, address := range writeAddresses {
delete(alteredWriteLocks, string(address))
}
for _, address := range readAddresses {
if _, ok := alteredWriteLocks[string(address)]; ok {
return errors.Wrap(
fmt.Errorf("address %x is still locked for writing", address),
"unlock",
)
}
}
for _, address := range writeAddresses {
delete(t.lockedWrites, string(address))
i, ok := t.lockedReads[string(address)]
if ok {
if i <= 1 {
delete(t.lockedReads, string(address))
} else {
t.lockedReads[string(address)] = i - 1
}
}
}
for _, address := range readAddresses {
i, ok := t.lockedReads[string(address)]
if ok {
if i <= 1 {
delete(t.lockedReads, string(address))
} else {
t.lockedReads[string(address)] = i - 1
}
}
}
t.lockedReads = make(map[string]int)
t.lockedWrites = make(map[string]struct{})
return nil
}
func (t *TokenIntrinsic) tryLockTransaction(
frameNumber uint64,
input []byte,
) (
[][]byte,
[][]byte,
error,
) {
tx := &Transaction{}
if err := tx.FromBytes(
input,
t.config,
t.hypergraph,
t.bulletproofProver,
t.inclusionProver,
t.verEnc,
t.decafConstructor,
keys.ToKeyRing(t.keyManager, true),
"",
t.rdfMultiprover,
); err != nil {
observability.LockErrors.WithLabelValues(
"token",
"transaction",
).Inc()
return nil, nil, errors.Wrap(err, "lock")
}
reads, err := tx.GetReadAddresses(frameNumber)
if err != nil {
observability.LockErrors.WithLabelValues(
"token",
"transaction",
).Inc()
return nil, nil, errors.Wrap(err, "lock")
}
writes, err := tx.GetWriteAddresses(frameNumber)
if err != nil {
observability.LockErrors.WithLabelValues(
"token",
"transaction",
).Inc()
return nil, nil, errors.Wrap(err, "lock")
}
return reads, writes, nil
}
func (t *TokenIntrinsic) tryLockPendingTransaction(
frameNumber uint64,
input []byte,
) (
[][]byte,
[][]byte,
error,
) {
pendingTx := &PendingTransaction{}
if err := pendingTx.FromBytes(
input,
t.config,
t.hypergraph,
t.bulletproofProver,
t.inclusionProver,
t.verEnc,
t.decafConstructor,
keys.ToKeyRing(t.keyManager, true),
"",
t.rdfMultiprover,
); err != nil {
observability.LockErrors.WithLabelValues(
"token",
"pending_transaction",
).Inc()
return nil, nil, errors.Wrap(err, "lock")
}
reads, err := pendingTx.GetReadAddresses(frameNumber)
if err != nil {
observability.LockErrors.WithLabelValues(
"token",
"pending_transaction",
).Inc()
return nil, nil, errors.Wrap(err, "lock")
}
writes, err := pendingTx.GetWriteAddresses(frameNumber)
if err != nil {
observability.LockErrors.WithLabelValues(
"token",
"pending_transaction",
).Inc()
return nil, nil, errors.Wrap(err, "lock")
}
return reads, writes, nil
}
func (t *TokenIntrinsic) tryLockMintTransaction(
frameNumber uint64,
input []byte,
) (
[][]byte,
[][]byte,
error,
) {
mintTx := &MintTransaction{}
if err := mintTx.FromBytes(
input,
t.config,
t.hypergraph,
t.bulletproofProver,
t.inclusionProver,
t.verEnc,
t.decafConstructor,
keys.ToKeyRing(t.keyManager, true),
"",
t.rdfMultiprover,
); err != nil {
observability.LockErrors.WithLabelValues(
"token",
"mint_transaction",
).Inc()
return nil, nil, errors.Wrap(err, "lock")
}
reads, err := mintTx.GetReadAddresses(frameNumber)
if err != nil {
observability.LockErrors.WithLabelValues(
"token",
"mint_transaction",
).Inc()
return nil, nil, errors.Wrap(err, "lock")
}
writes, err := mintTx.GetWriteAddresses(frameNumber)
if err != nil {
observability.LockErrors.WithLabelValues(
"token",
"mint_transaction",
).Inc()
return nil, nil, errors.Wrap(err, "lock")
}
return reads, writes, nil
}
func (t *TokenIntrinsic) GetRDFSchemaDocument() string {
return t.rdfHypergraphSchema
}

View File

@ -2113,8 +2113,17 @@ func (tx *MintTransaction) GetCost() (*big.Int, error) {
return size, nil
}
func (tx *MintTransaction) GetReadAddresses(
frameNumber uint64,
) ([][]byte, error) {
return nil, nil
}
// GetWriteAddresses implements intrinsics.IntrinsicOperation.
func (tx *MintTransaction) GetWriteAddresses() [][]byte {
func (tx *MintTransaction) GetWriteAddresses(frameNumber uint64) (
[][]byte,
error,
) {
addresses := [][]byte{}
// Each output creates a new coin, which is written to an address based on
@ -2142,26 +2151,28 @@ func (tx *MintTransaction) GetWriteAddresses() [][]byte {
for i := range tx.Inputs {
proverRootDomain := [32]byte(tx.Domain)
proverAddress := slices.Concat(
proverRootDomain[:],
tx.Inputs[i].Proofs[1][:32],
)
// Check if not already in addresses
found := false
for _, addr := range addresses {
if bytes.Equal(addr, proverAddress) {
found = true
break
rewardAddress := []byte{}
if bytes.Equal(tx.Domain[:], QUIL_TOKEN_ADDRESS) {
// Special case: PoMW mints under QUIL use global records for proofs
proverRootDomain = intrinsics.GLOBAL_INTRINSIC_ADDRESS
rewardAddressBI, err := poseidon.HashBytes(slices.Concat(
QUIL_TOKEN_ADDRESS[:],
tx.Inputs[i].Proofs[1][:32],
))
if err != nil {
return nil, errors.Wrap(err, "materialize")
}
rewardAddress = rewardAddressBI.FillBytes(make([]byte, 32))
}
if !found {
addresses = append(addresses, proverAddress)
}
addresses = append(addresses, slices.Concat(
proverRootDomain[:],
rewardAddress,
))
}
}
return addresses
return addresses, nil
}
// Materialize implements intrinsics.IntrinsicOperation.

View File

@ -1504,6 +1504,30 @@ func (tx *PendingTransaction) Prove(frameNumber uint64) error {
return nil
}
func (tx *PendingTransaction) GetReadAddresses(
frameNumber uint64,
) ([][]byte, error) {
return nil, nil
}
func (tx *PendingTransaction) GetWriteAddresses(
frameNumber uint64,
) ([][]byte, error) {
addresses := [][]byte{}
// Build the trees if not already built
if err := tx.buildPendingTransactionTrees(); err != nil {
return nil, errors.Wrap(err, "get write addresses")
}
// Add pending transactions using cached trees
for i := range tx.cachedTrees {
addresses = append(addresses, tx.cachedAddresses[i])
}
return addresses, nil
}
func (tx *PendingTransaction) GetChallenge() ([]byte, error) {
transcript := []byte{}
transcript = append(transcript, tx.Domain[:]...)

View File

@ -1271,6 +1271,171 @@ func (tx *Transaction) Prove(frameNumber uint64) error {
return nil
}
func (tx *Transaction) GetReadAddresses(
frameNumber uint64,
) ([][]byte, error) {
return nil, nil
}
func (tx *Transaction) GetWriteAddresses(
frameNumber uint64,
) ([][]byte, error) {
// Create the coin type hash
coinTypeBI, err := poseidon.HashBytes(
slices.Concat(tx.Domain[:], []byte("coin:Coin")),
)
if err != nil {
return nil, errors.Wrap(err, "get write addresses")
}
coinTypeBytes := coinTypeBI.FillBytes(make([]byte, 32))
addresses := [][]byte{}
// For each output, create a coin
for _, output := range tx.Outputs {
// Create coin tree
coinTree := &qcrypto.VectorCommitmentTree{}
// Index 0: FrameNumber
if err := coinTree.Insert(
[]byte{0},
output.FrameNumber,
nil,
big.NewInt(8),
); err != nil {
return nil, errors.Wrap(err, "get write addresses")
}
// Index 1: Commitment
if err := coinTree.Insert(
[]byte{1 << 2},
output.Commitment,
nil,
big.NewInt(56),
); err != nil {
return nil, errors.Wrap(err, "get write addresses")
}
// Index 2: OneTimeKey
if err := coinTree.Insert(
[]byte{2 << 2},
output.RecipientOutput.OneTimeKey,
nil,
big.NewInt(56),
); err != nil {
return nil, errors.Wrap(err, "get write addresses")
}
// Index 3: VerificationKey
if err := coinTree.Insert(
[]byte{3 << 2},
output.RecipientOutput.VerificationKey,
nil,
big.NewInt(56),
); err != nil {
return nil, errors.Wrap(err, "get write addresses")
}
// Index 4: CoinBalance (encrypted)
if err := coinTree.Insert(
[]byte{4 << 2},
output.RecipientOutput.CoinBalance,
nil,
big.NewInt(56),
); err != nil {
return nil, errors.Wrap(err, "get write addresses")
}
// Index 5: Mask (encrypted)
if err := coinTree.Insert(
[]byte{5 << 2},
output.RecipientOutput.Mask,
nil,
big.NewInt(56),
); err != nil {
return nil, errors.Wrap(err, "get write addresses")
}
// Index 6 & 7: Additional references (for non-divisible tokens)
if len(output.RecipientOutput.AdditionalReference) == 64 &&
len(output.RecipientOutput.AdditionalReferenceKey) == 56 {
if err := coinTree.Insert(
[]byte{6 << 2},
output.RecipientOutput.AdditionalReference,
nil,
big.NewInt(56),
); err != nil {
return nil, errors.Wrap(err, "get write addresses")
}
if err := coinTree.Insert(
[]byte{7 << 2},
output.RecipientOutput.AdditionalReferenceKey,
nil,
big.NewInt(56),
); err != nil {
return nil, errors.Wrap(err, "get write addresses")
}
}
// Type marker at max index
if err := coinTree.Insert(
bytes.Repeat([]byte{0xff}, 32),
coinTypeBytes,
nil,
big.NewInt(32),
); err != nil {
return nil, errors.Wrap(err, "get write addresses")
}
// Compute address and add to state
commit := coinTree.Commit(tx.inclusionProver, false)
outAddrBI, err := poseidon.HashBytes(commit)
if err != nil {
return nil, errors.Wrap(err, "get write addresses")
}
coinAddress := outAddrBI.FillBytes(make([]byte, 32))
addresses = append(addresses, slices.Concat(
tx.Domain[:],
coinAddress,
))
}
// Mark inputs as spent
for _, input := range tx.Inputs {
if len(input.Signature) == 336 {
// Standard format
verificationKey := input.Signature[56*4 : 56*5]
spendCheckBI, err := poseidon.HashBytes(verificationKey)
if err != nil {
return nil, errors.Wrap(err, "get write addresses")
}
// Create spent marker
spentTree := &qcrypto.VectorCommitmentTree{}
if err := spentTree.Insert(
[]byte{0},
[]byte{0x01},
nil,
big.NewInt(0),
); err != nil {
return nil, errors.Wrap(err, "get write addresses")
}
spentAddress := spendCheckBI.FillBytes(make([]byte, 32))
addresses = append(addresses, slices.Concat(
tx.Domain[:],
spentAddress,
))
}
}
return addresses, nil
}
func (tx *Transaction) GetChallenge() ([]byte, error) {
transcript := []byte{}
transcript = append(transcript, tx.Domain[:]...)