fix: compute intrinsic locking

This commit is contained in:
Cassandra Heart 2025-10-08 22:08:03 -05:00
parent d5b7679e70
commit 76405d7876
No known key found for this signature in database
GPG Key ID: 371083BFA6C240AA
4 changed files with 305 additions and 78 deletions

View File

@ -1082,31 +1082,86 @@ func (c *ComputeIntrinsic) InvokeStep(
}
// Lock implements intrinsics.Intrinsic.
func (c *ComputeIntrinsic) Lock(
writeAddresses [][]byte,
readAddresses [][]byte,
) error {
c.lockedReadsMx.Lock()
c.lockedWritesMx.Lock()
defer c.lockedReadsMx.Unlock()
defer c.lockedWritesMx.Unlock()
func (a *ComputeIntrinsic) Lock(frameNumber uint64, input []byte) error {
a.lockedReadsMx.Lock()
a.lockedWritesMx.Lock()
defer a.lockedReadsMx.Unlock()
defer a.lockedWritesMx.Unlock()
if c.lockedReads == nil {
c.lockedReads = make(map[string]int)
if a.lockedReads == nil {
a.lockedReads = make(map[string]int)
}
if c.lockedWrites == nil {
c.lockedWrites = make(map[string]struct{})
if a.lockedWrites == nil {
a.lockedWrites = make(map[string]struct{})
}
for _, address := range writeAddresses {
if _, ok := c.lockedWrites[string(address)]; ok {
// Check type prefix to determine request type
if len(input) < 4 {
observability.LockErrors.WithLabelValues(
"compute",
"invalid_input",
).Inc()
return errors.Wrap(errors.New("input too short"), "lock")
}
// Read the type prefix
typePrefix := binary.BigEndian.Uint32(input[:4])
var reads, writes [][]byte
var err error
// Handle each type based on type prefix
switch typePrefix {
case protobufs.CodeDeploymentType:
reads, writes, err = a.tryLockCodeDeployment(frameNumber, input)
if err != nil {
return err
}
observability.LockTotal.WithLabelValues(
"compute",
"code_deployment",
).Inc()
case protobufs.CodeExecuteType:
reads, writes, err = a.tryLockCodeExecute(frameNumber, input)
if err != nil {
return err
}
observability.LockTotal.WithLabelValues("compute", "code_execute").Inc()
case protobufs.CodeFinalizeType:
reads, writes, err = a.tryLockCodeFinalize(frameNumber, input)
if err != nil {
return err
}
observability.LockTotal.WithLabelValues(
"compute",
"code_finalize",
).Inc()
default:
observability.LockErrors.WithLabelValues(
"compute",
"unknown_type",
).Inc()
return errors.Wrap(
errors.New("unknown compute request type"),
"lock",
)
}
for _, address := range writes {
if _, ok := a.lockedWrites[string(address)]; ok {
return errors.Wrap(
fmt.Errorf("address %x is already locked for writing", address),
"lock",
)
}
if _, ok := c.lockedReads[string(address)]; ok {
if _, ok := a.lockedReads[string(address)]; ok {
return errors.Wrap(
fmt.Errorf("address %x is already locked for reading", address),
"lock",
@ -1114,8 +1169,8 @@ func (c *ComputeIntrinsic) Lock(
}
}
for _, address := range readAddresses {
if _, ok := c.lockedWrites[string(address)]; ok {
for _, address := range reads {
if _, ok := a.lockedWrites[string(address)]; ok {
return errors.Wrap(
fmt.Errorf("address %x is already locked for writing", address),
"lock",
@ -1123,76 +1178,27 @@ func (c *ComputeIntrinsic) Lock(
}
}
for _, address := range writeAddresses {
c.lockedWrites[string(address)] = struct{}{}
c.lockedReads[string(address)] = c.lockedReads[string(address)] + 1
for _, address := range writes {
a.lockedWrites[string(address)] = struct{}{}
a.lockedReads[string(address)] = a.lockedReads[string(address)] + 1
}
for _, address := range readAddresses {
c.lockedReads[string(address)] = c.lockedReads[string(address)] + 1
for _, address := range reads {
a.lockedReads[string(address)] = a.lockedReads[string(address)] + 1
}
return nil
}
// Unlock implements intrinsics.Intrinsic.
func (c *ComputeIntrinsic) Unlock(
writeAddresses [][]byte,
readAddresses [][]byte,
) error {
c.lockedReadsMx.Lock()
c.lockedWritesMx.Lock()
defer c.lockedReadsMx.Unlock()
defer c.lockedWritesMx.Unlock()
func (a *ComputeIntrinsic) Unlock() error {
a.lockedReadsMx.Lock()
a.lockedWritesMx.Lock()
defer a.lockedReadsMx.Unlock()
defer a.lockedWritesMx.Unlock()
if c.lockedReads == nil {
c.lockedReads = make(map[string]int)
}
if c.lockedWrites == nil {
c.lockedWrites = make(map[string]struct{})
}
alteredWriteLocks := make(map[string]struct{})
for k, v := range c.lockedWrites {
alteredWriteLocks[k] = v
}
for _, address := range writeAddresses {
delete(alteredWriteLocks, string(address))
}
for _, address := range readAddresses {
if _, ok := alteredWriteLocks[string(address)]; ok {
return errors.Wrap(
fmt.Errorf("address %x is still locked for writing", address),
"unlock",
)
}
}
for _, address := range writeAddresses {
delete(c.lockedWrites, string(address))
i, ok := c.lockedReads[string(address)]
if ok {
if i <= 1 {
delete(c.lockedReads, string(address))
} else {
c.lockedReads[string(address)] = i - 1
}
}
}
for _, address := range readAddresses {
i, ok := c.lockedReads[string(address)]
if ok {
if i <= 1 {
delete(c.lockedReads, string(address))
} else {
c.lockedReads[string(address)] = i - 1
}
}
}
a.lockedReads = make(map[string]int)
a.lockedWrites = make(map[string]struct{})
return nil
}
@ -1332,4 +1338,137 @@ type ComputeUpdate struct {
OwnerSignature *protobufs.BLS48581AggregateSignature
}
func (c *ComputeIntrinsic) tryLockCodeDeployment(
frameNumber uint64,
input []byte,
) (
[][]byte,
[][]byte,
error,
) {
codeDeployment := &CodeDeployment{}
if err := codeDeployment.FromBytes(input, c.compiler); err != nil {
observability.LockErrors.WithLabelValues(
"compute",
"code_deployment",
).Inc()
return nil, nil, errors.Wrap(err, "lock")
}
reads, err := codeDeployment.GetReadAddresses(frameNumber)
if err != nil {
observability.LockErrors.WithLabelValues(
"compute",
"code_deployment",
).Inc()
return nil, nil, errors.Wrap(err, "lock")
}
writes, err := codeDeployment.GetWriteAddresses(frameNumber)
if err != nil {
observability.LockErrors.WithLabelValues(
"compute",
"code_deployment",
).Inc()
return nil, nil, errors.Wrap(err, "lock")
}
return reads, writes, nil
}
func (c *ComputeIntrinsic) tryLockCodeExecute(
frameNumber uint64,
input []byte,
) (
[][]byte,
[][]byte,
error,
) {
codeExecute := &CodeExecute{}
if err := codeExecute.FromBytes(
input,
c.hypergraph,
c.bulletproofProver,
c.inclusionProver,
c.verEnc,
c.decafConstructor,
c.keyManager,
c.rdfMultiprover,
); err != nil {
observability.LockErrors.WithLabelValues(
"compute",
"code_execute",
).Inc()
return nil, nil, errors.Wrap(err, "lock")
}
reads, err := codeExecute.GetReadAddresses(frameNumber)
if err != nil {
observability.LockErrors.WithLabelValues(
"compute",
"code_execute",
).Inc()
return nil, nil, errors.Wrap(err, "lock")
}
writes, err := codeExecute.GetWriteAddresses(frameNumber)
if err != nil {
observability.LockErrors.WithLabelValues(
"compute",
"code_execute",
).Inc()
return nil, nil, errors.Wrap(err, "lock")
}
return reads, writes, nil
}
func (c *ComputeIntrinsic) tryLockCodeFinalize(
frameNumber uint64,
input []byte,
) (
[][]byte,
[][]byte,
error,
) {
codeFinalize := &CodeFinalize{}
if err := codeFinalize.FromBytes(
input,
c.domain,
c.hypergraph,
c.bulletproofProver,
c.inclusionProver,
c.verEnc,
c.keyManager,
c.config,
nil,
); err != nil {
observability.LockErrors.WithLabelValues(
"compute",
"code_finalize",
).Inc()
return nil, nil, errors.Wrap(err, "lock")
}
reads, err := codeFinalize.GetReadAddresses(frameNumber)
if err != nil {
observability.LockErrors.WithLabelValues(
"compute",
"code_finalize",
).Inc()
return nil, nil, errors.Wrap(err, "lock")
}
writes, err := codeFinalize.GetWriteAddresses(frameNumber)
if err != nil {
observability.LockErrors.WithLabelValues(
"compute",
"code_finalize",
).Inc()
return nil, nil, errors.Wrap(err, "lock")
}
return reads, writes, nil
}
var _ intrinsics.Intrinsic = (*ComputeIntrinsic)(nil)

View File

@ -85,6 +85,37 @@ func (c *CodeDeployment) Prove(frameNumber uint64) (err error) {
return nil
}
func (c *CodeDeployment) GetReadAddresses(
frameNumber uint64,
) ([][]byte, error) {
return nil, nil
}
func (c *CodeDeployment) GetWriteAddresses(
frameNumber uint64,
) ([][]byte, error) {
// Get the domain from the hypergraph
domain := c.Domain
// Generate a unique address for this code file
codeAddressBI, err := poseidon.HashBytes(
slices.Concat(
domain[:],
c.Circuit,
),
)
if err != nil {
return nil, errors.Wrap(err, "get write addresses")
}
codeAddress := codeAddressBI.FillBytes(make([]byte, 32))
codeFullAddress := [64]byte{}
copy(codeFullAddress[:32], c.Domain[:])
copy(codeFullAddress[32:], codeAddress)
return [][]byte{codeFullAddress[:]}, nil
}
// Verify implements intrinsics.IntrinsicOperation
func (c *CodeDeployment) Verify(frameNumber uint64) (bool, error) {
buf := bytes.NewReader(c.Circuit)

View File

@ -5,6 +5,7 @@ import (
"encoding/binary"
"fmt"
"math/big"
"slices"
"github.com/pkg/errors"
hg "source.quilibrium.com/quilibrium/monorepo/node/execution/state/hypergraph"
@ -330,6 +331,21 @@ func (c *CodeExecute) Prove(frameNumber uint64) error {
return nil
}
func (c *CodeExecute) GetReadAddresses(
frameNumber uint64,
) ([][]byte, error) {
return nil, nil
}
func (c *CodeExecute) GetWriteAddresses(
frameNumber uint64,
) ([][]byte, error) {
return [][]byte{slices.Concat(
c.Domain[:],
c.Rendezvous[:],
)}, nil
}
// Verify implements intrinsics.IntrinsicOperation.
func (c *CodeExecute) Verify(frameNumber uint64) (bool, error) {
if !bytes.Equal(c.ProofOfPayment[0], make([]byte, 56)) {

View File

@ -302,6 +302,47 @@ func (c *CodeFinalize) Prove(frameNumber uint64) error {
return errors.Wrap(err, "prove")
}
func (c *CodeFinalize) GetReadAddresses(
frameNumber uint64,
) ([][]byte, error) {
return nil, nil
}
func (c *CodeFinalize) GetWriteAddresses(
frameNumber uint64,
) ([][]byte, error) {
// Generate results address
resultsBI, err := poseidon.HashBytes(slices.Concat(
c.Rendezvous[:],
[]byte("RESULTS_CODE_FINALIZE"),
))
if err != nil {
return nil, errors.Wrap(err, "get write addresses")
}
resultsAddress := resultsBI.FillBytes(make([]byte, 32))
// Generate state changes address similar to results address
changesBI, err := poseidon.HashBytes(slices.Concat(
c.Rendezvous[:],
[]byte("STATE_CHANGES_CODE_FINALIZE"),
))
if err != nil {
return nil, errors.Wrap(err, "get write addresses")
}
changesAddress := changesBI.FillBytes(make([]byte, 32))
return [][]byte{
slices.Concat(
c.domain[:],
resultsAddress,
),
slices.Concat(
c.domain[:],
changesAddress,
),
}, nil
}
// Verify implements intrinsics.IntrinsicOperation.
func (c *CodeFinalize) Verify(frameNumber uint64) (bool, error) {
// Verify all results are present