mirror of
https://github.com/QuilibriumNetwork/ceremonyclient.git
synced 2026-02-21 10:27:26 +08:00
244 lines
6.0 KiB
Go
244 lines
6.0 KiB
Go
package reward
|
||
|
||
import (
|
||
"math/big"
|
||
"runtime"
|
||
"sync"
|
||
|
||
"github.com/pkg/errors"
|
||
"github.com/shopspring/decimal"
|
||
"golang.org/x/sync/errgroup"
|
||
"source.quilibrium.com/quilibrium/monorepo/types/consensus"
|
||
)
|
||
|
||
// OptimizedProofOfMeaningfulWorkRewardIssuance is optimized by caching the
|
||
// decimal square root calculations
|
||
type OptimizedProofOfMeaningfulWorkRewardIssuance struct {
|
||
sqrtCache map[uint64]decimal.Decimal
|
||
sqrtCacheLock sync.RWMutex
|
||
halfDecimal decimal.Decimal
|
||
}
|
||
|
||
func NewOptRewardIssuance() *OptimizedProofOfMeaningfulWorkRewardIssuance {
|
||
halfRat := big.NewRat(1, 2)
|
||
return &OptimizedProofOfMeaningfulWorkRewardIssuance{
|
||
sqrtCache: make(map[uint64]decimal.Decimal),
|
||
halfDecimal: decimal.NewFromBigRat(halfRat, 53),
|
||
}
|
||
}
|
||
|
||
func (p *OptimizedProofOfMeaningfulWorkRewardIssuance) getSqrt(
|
||
shards uint64,
|
||
) (decimal.Decimal, error) {
|
||
p.sqrtCacheLock.RLock()
|
||
if sqrt, exists := p.sqrtCache[shards]; exists {
|
||
p.sqrtCacheLock.RUnlock()
|
||
return sqrt, nil
|
||
}
|
||
p.sqrtCacheLock.RUnlock()
|
||
|
||
p.sqrtCacheLock.Lock()
|
||
defer p.sqrtCacheLock.Unlock()
|
||
|
||
if sqrt, exists := p.sqrtCache[shards]; exists {
|
||
return sqrt, nil
|
||
}
|
||
|
||
shardFactor := decimal.NewFromUint64(shards)
|
||
sqrt, err := shardFactor.PowWithPrecision(p.halfDecimal, 53)
|
||
if err != nil {
|
||
return decimal.Zero, err
|
||
}
|
||
|
||
p.sqrtCache[shards] = sqrt
|
||
return sqrt, nil
|
||
}
|
||
|
||
func (p *OptimizedProofOfMeaningfulWorkRewardIssuance) Calculate(
|
||
difficulty uint64,
|
||
worldStateBytes uint64,
|
||
units uint64,
|
||
provers []map[string]*consensus.ProverAllocation,
|
||
) ([]*big.Int, error) {
|
||
basis := PomwBasis(difficulty, worldStateBytes, units)
|
||
|
||
output := make([]*big.Int, len(provers))
|
||
|
||
eg := errgroup.Group{}
|
||
eg.SetLimit(runtime.GOMAXPROCS(0))
|
||
|
||
for i := range provers {
|
||
proverIndex := i
|
||
eg.Go(func() error {
|
||
output[proverIndex] = big.NewInt(0)
|
||
|
||
for _, alloc := range provers[proverIndex] {
|
||
// Divide by 2^s
|
||
divisor := int64(1)
|
||
for i := uint8(0); i < alloc.Ring+1; i++ {
|
||
divisor <<= 1
|
||
}
|
||
|
||
// shard is oversubscribed, treat as no rewards
|
||
if divisor == 0 {
|
||
continue
|
||
}
|
||
|
||
ringScaled := decimal.NewFromInt(divisor)
|
||
|
||
factor := decimal.NewFromUint64(alloc.StateSize)
|
||
factor = factor.Div(decimal.NewFromUint64(worldStateBytes))
|
||
|
||
result := factor.Mul(decimal.NewFromBigInt(basis, 0))
|
||
result = result.Div(ringScaled)
|
||
|
||
shardFactor, err := p.getSqrt(alloc.Shards)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
if shardFactor.IsZero() {
|
||
return errors.New("divisor is zero")
|
||
}
|
||
|
||
result = result.Div(shardFactor)
|
||
output[proverIndex] = output[proverIndex].Add(
|
||
output[proverIndex],
|
||
result.BigInt(),
|
||
)
|
||
}
|
||
|
||
return nil
|
||
})
|
||
}
|
||
|
||
if err := eg.Wait(); err != nil {
|
||
return nil, errors.Wrap(err, "calculate")
|
||
}
|
||
|
||
return output, nil
|
||
}
|
||
|
||
// MinimalAllocationOptimizedProofOfMeaningfulWorkRewardIssuance reduces
|
||
// allocations while maintaining compatibility – low memory environments may
|
||
// want to use this, but it is left unused.
|
||
type MinimalAllocationOptimizedProofOfMeaningfulWorkRewardIssuance struct {
|
||
sqrtCache map[uint64]decimal.Decimal
|
||
halfDecimal decimal.Decimal
|
||
worldStateCache map[uint64]decimal.Decimal
|
||
cacheLock sync.RWMutex
|
||
}
|
||
|
||
func NewMinOptRewardIssuance() *MinimalAllocationOptimizedProofOfMeaningfulWorkRewardIssuance {
|
||
halfRat := big.NewRat(1, 2)
|
||
return &MinimalAllocationOptimizedProofOfMeaningfulWorkRewardIssuance{
|
||
sqrtCache: make(map[uint64]decimal.Decimal),
|
||
halfDecimal: decimal.NewFromBigRat(halfRat, 53),
|
||
worldStateCache: make(map[uint64]decimal.Decimal),
|
||
}
|
||
}
|
||
|
||
func (p *MinimalAllocationOptimizedProofOfMeaningfulWorkRewardIssuance) Calculate(
|
||
difficulty uint64,
|
||
worldStateBytes uint64,
|
||
units uint64,
|
||
provers []map[string]*consensus.ProverAllocation,
|
||
) ([]*big.Int, error) {
|
||
basis := PomwBasis(difficulty, worldStateBytes, units)
|
||
basisDecimal := decimal.NewFromBigInt(basis, 0)
|
||
|
||
p.cacheLock.RLock()
|
||
worldStateBytesDecimal, exists := p.worldStateCache[worldStateBytes]
|
||
p.cacheLock.RUnlock()
|
||
|
||
if !exists {
|
||
p.cacheLock.Lock()
|
||
worldStateBytesDecimal = decimal.NewFromUint64(worldStateBytes)
|
||
p.worldStateCache[worldStateBytes] = worldStateBytesDecimal
|
||
p.cacheLock.Unlock()
|
||
}
|
||
|
||
uniqueShards := make(map[uint64]bool)
|
||
for _, prover := range provers {
|
||
for _, alloc := range prover {
|
||
uniqueShards[alloc.Shards] = true
|
||
}
|
||
}
|
||
|
||
missingSqrts := make(map[uint64]decimal.Decimal)
|
||
p.cacheLock.RLock()
|
||
for shards := range uniqueShards {
|
||
if _, exists := p.sqrtCache[shards]; !exists {
|
||
missingSqrts[shards] = decimal.Zero // Mark as needing computation
|
||
}
|
||
}
|
||
p.cacheLock.RUnlock()
|
||
|
||
if len(missingSqrts) > 0 {
|
||
for shards := range missingSqrts {
|
||
shardFactor := decimal.NewFromUint64(shards)
|
||
sqrt, err := shardFactor.PowWithPrecision(p.halfDecimal, 53)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
missingSqrts[shards] = sqrt
|
||
}
|
||
|
||
p.cacheLock.Lock()
|
||
for shards, sqrt := range missingSqrts {
|
||
p.sqrtCache[shards] = sqrt
|
||
}
|
||
p.cacheLock.Unlock()
|
||
}
|
||
|
||
output := make([]*big.Int, len(provers))
|
||
|
||
eg := errgroup.Group{}
|
||
eg.SetLimit(runtime.GOMAXPROCS(0))
|
||
|
||
ringDivisors := make([]decimal.Decimal, 256)
|
||
ringDivisors[0] = decimal.NewFromInt(1)
|
||
for i := uint8(1); i < 255; i++ {
|
||
ringDivisors[i] = decimal.NewFromInt(int64(1) << i)
|
||
}
|
||
|
||
for i := range provers {
|
||
proverIndex := i
|
||
eg.Go(func() error {
|
||
accumulator := decimal.Zero
|
||
|
||
for _, alloc := range provers[proverIndex] {
|
||
factor := decimal.NewFromUint64(alloc.StateSize)
|
||
factor = factor.Div(worldStateBytesDecimal)
|
||
|
||
result := factor.Mul(basisDecimal)
|
||
divisor := ringDivisors[alloc.Ring+1]
|
||
if divisor.IsZero() {
|
||
continue
|
||
}
|
||
result = result.Div(divisor)
|
||
|
||
p.cacheLock.RLock()
|
||
shardFactor := p.sqrtCache[alloc.Shards]
|
||
p.cacheLock.RUnlock()
|
||
|
||
if shardFactor.IsZero() {
|
||
return errors.New("divisor is zero")
|
||
}
|
||
|
||
result = result.Div(shardFactor)
|
||
accumulator = accumulator.Add(result)
|
||
}
|
||
|
||
output[proverIndex] = accumulator.BigInt()
|
||
return nil
|
||
})
|
||
}
|
||
|
||
if err := eg.Wait(); err != nil {
|
||
return nil, errors.Wrap(err, "calculate")
|
||
}
|
||
|
||
return output, nil
|
||
}
|