ceremonyclient/bedlam/circuit/stream_evaluator.go
Cassandra Heart dbd95bd9e9
v2.1.0 (#439)
* v2.1.0 [omit consensus and adjacent] - this commit will be amended with the full release after the file copy is complete

* 2.1.0 main node rollup
2025-09-30 02:48:15 -05:00

502 lines
11 KiB
Go

//
// Copyright (c) 2020-2024 Markku Rossi
//
// All rights reserved.
//
package circuit
import (
"crypto/aes"
"crypto/cipher"
"fmt"
"math/big"
"time"
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
"source.quilibrium.com/quilibrium/monorepo/bedlam/p2p"
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
)
// Protocol operation codes.
const (
OpResult = iota
OpCircuit
OpReturn
)
// StreamEval is a streaming garbled circuit evaluator.
type StreamEval struct {
key []byte
alg cipher.Block
wires []ot.Label
tmp []ot.Label
}
// NewStreamEval creates a new streaming garbled circuit evaluator.
func NewStreamEval(key []byte, numInputs, numOutputs int) (*StreamEval, error) {
alg, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return &StreamEval{
key: key,
alg: alg,
wires: make([]ot.Label, numInputs+numOutputs),
}, nil
}
// Get gets the value of the wire.
func (stream *StreamEval) Get(tmp bool, w int) ot.Label {
if tmp {
return stream.tmp[w]
}
return stream.wires[w]
}
// GetInputs gets the specified input wire range.
func (stream *StreamEval) GetInputs(offset, count int) []ot.Label {
return stream.wires[offset : offset+count]
}
// Set sets the value of the wire.
func (stream *StreamEval) Set(tmp bool, w int, label ot.Label) {
if tmp {
stream.tmp[w] = label
} else {
stream.wires[w] = label
}
}
// InitCircuit initializes the stream evaluator with wires.
func (stream *StreamEval) InitCircuit(numWires, numTmpWires int) {
if numWires > len(stream.wires) {
var size int
for size = 1024; size < numWires; size *= 2 {
}
n := make([]ot.Label, size)
copy(n, stream.wires)
stream.wires = n
}
if numTmpWires > len(stream.tmp) {
var size int
for size = 1024; size < numTmpWires; size *= 2 {
}
stream.tmp = make([]ot.Label, size)
}
}
// StreamEvaluator runs the stream evaluator on the connection.
func StreamEvaluator(conn *p2p.Conn, oti ot.OT, inputFlag []string,
verbose bool) (IO, []*big.Int, error) {
timing := NewTiming()
// Receive program info.
if verbose {
fmt.Printf(" - Waiting for program info...\n")
}
key, err := conn.ReceiveData()
if err != nil {
return nil, nil, err
}
alg, err := aes.NewCipher(key)
if err != nil {
return nil, nil, err
}
// Peer input.
in1, err := receiveArgument(conn)
if err != nil {
return nil, nil, err
}
// Our input.
in2, err := receiveArgument(conn)
if err != nil {
return nil, nil, err
}
inputs, err := in2.Parse(inputFlag)
if err != nil {
return nil, nil, err
}
// Program outputs.
numOutputs, err := conn.ReceiveUint32()
if err != nil {
return nil, nil, err
}
var outputs IO
for i := 0; i < numOutputs; i++ {
out, err := receiveArgument(conn)
if err != nil {
return nil, nil, err
}
outputs = append(outputs, out)
}
numSteps, err := conn.ReceiveUint32()
if err != nil {
return nil, nil, err
}
fmt.Printf(" - In1: %s\n", in1)
fmt.Printf(" + In2: %s\n", in2)
fmt.Printf(" - Out: %s\n", outputs)
fmt.Printf(" - In: %s\n", inputFlag)
streaming, err := NewStreamEval(key, int(in1.Type.Bits+in2.Type.Bits),
outputs.Size())
if err != nil {
return nil, nil, err
}
// Receive peer inputs.
var label ot.Label
var labelData ot.LabelData
for w := 0; w < int(in1.Type.Bits); w++ {
err := conn.ReceiveLabel(&label, &labelData)
if err != nil {
return nil, nil, err
}
streaming.Set(false, w, label)
}
// Init oblivious transfer.
err = oti.InitReceiver(conn)
if err != nil {
return nil, nil, err
}
ioStats := conn.Stats.Sum()
timing.Sample("Init", []string{FileSize(ioStats).String()})
// Query our inputs.
if verbose {
fmt.Printf(" - Querying our inputs...\n")
}
flags := make([]bool, in2.Type.Bits)
for i := 0; i < int(in2.Type.Bits); i++ {
if inputs.Bit(i) == 1 {
flags[i] = true
}
}
inputLabels := streaming.GetInputs(int(in1.Type.Bits), int(in2.Type.Bits))
if err := oti.Receive(flags, inputLabels); err != nil {
return nil, nil, err
}
xfer := conn.Stats.Sum() - ioStats
ioStats = conn.Stats.Sum()
timing.Sample("Inputs", []string{FileSize(xfer).String()})
ws := func(i int, tmp bool) string {
if tmp {
return fmt.Sprintf("~%d", i)
}
return fmt.Sprintf("w%d", i)
}
// Evaluate program.
if verbose {
fmt.Printf(" - Evaluating program...\n")
}
var garbled [4]ot.Label
var lastStep int
var rawResult *big.Int
start := time.Now()
lastReport := start
loop:
for {
op, err := conn.ReceiveUint32()
if err != nil {
return nil, nil, err
}
switch op {
case OpCircuit:
step, err := conn.ReceiveUint32()
if err != nil {
return nil, nil, err
}
numGates, err := conn.ReceiveUint32()
if err != nil {
return nil, nil, err
}
numTmpWires, err := conn.ReceiveUint32()
if err != nil {
return nil, nil, err
}
numWires, err := conn.ReceiveUint32()
if err != nil {
return nil, nil, err
}
if step-lastStep >= 10 && verbose {
lastStep = step
now := time.Now()
if now.Sub(lastReport) > time.Second*5 {
lastReport = now
elapsed := time.Now().Sub(start)
done := float64(step) / float64(numSteps)
if done > 0 {
total := time.Duration(float64(elapsed) / done)
progress := fmt.Sprintf("%d/%d", step, numSteps)
remaining := fmt.Sprintf("%24s", total-elapsed)
fmt.Printf("%-14s\t%s remaining\tETA %s\n",
progress, remaining,
start.Add(total).Format(time.Stamp))
} else {
fmt.Printf("%d/%d\n", step, numSteps)
}
}
}
streaming.InitCircuit(numWires, numTmpWires)
var id uint32
for i := 0; i < numGates; i++ {
gop, err := conn.ReceiveByte()
if err != nil {
return nil, nil, err
}
var aTmp, bTmp, cTmp bool
if gop&0b10000000 != 0 {
aTmp = true
}
if gop&0b01000000 != 0 {
bTmp = true
}
if gop&0b00100000 != 0 {
cTmp = true
}
var recvWire func() (int, error)
if gop&0b00010000 != 0 {
recvWire = conn.ReceiveUint16
} else {
recvWire = conn.ReceiveUint32
}
gop &^= 0b11110000
var aIndex, bIndex, cIndex int
var tableCount int
switch Operation(gop) {
case XOR, XNOR, AND, OR:
aIndex, err = recvWire()
if err != nil {
return nil, nil, err
}
bIndex, err = recvWire()
if err != nil {
return nil, nil, err
}
cIndex, err = recvWire()
if err != nil {
return nil, nil, err
}
case INV:
aIndex, err = recvWire()
if err != nil {
return nil, nil, err
}
cIndex, err = recvWire()
if err != nil {
return nil, nil, err
}
default:
return nil, nil, fmt.Errorf("invalid operation %s",
Operation(gop))
}
switch Operation(gop) {
case XOR, XNOR:
tableCount = 0
case INV:
tableCount = 1
case AND:
tableCount = 2
case OR:
tableCount = 3
}
for c := 0; c < tableCount; c++ {
err = conn.ReceiveLabel(&label, &labelData)
if err != nil {
return nil, nil, err
}
garbled[c] = label
}
var a, b, c ot.Label
switch Operation(gop) {
case XOR, XNOR, AND, OR:
if StreamDebug {
fmt.Printf("Gate%d:\t %s %s %s %s\n", i,
ws(aIndex, aTmp), ws(bIndex, bTmp),
Operation(gop), ws(cIndex, cTmp))
}
a = streaming.Get(aTmp, aIndex)
b = streaming.Get(bTmp, bIndex)
case INV:
if StreamDebug {
fmt.Printf("Gate%d:\t %s %s %s\n", i,
ws(aIndex, aTmp), Operation(gop), ws(bIndex, bTmp))
}
a = streaming.Get(aTmp, aIndex)
}
var output ot.Label
switch Operation(gop) {
case XOR, XNOR:
a.Xor(b)
output = a
case AND:
if tableCount != 2 {
return nil, nil,
fmt.Errorf("corrupted ciruit: AND table size: %d",
tableCount)
}
sa := a.S()
sb := b.S()
j0 := id
j1 := id + 1
id += 2
tg := garbled[0]
te := garbled[1]
wg := encryptHalf(alg, a, j0, &labelData)
if sa {
wg.Xor(tg)
}
we := encryptHalf(alg, b, j1, &labelData)
if sb {
we.Xor(te)
we.Xor(a)
}
output = wg
output.Xor(we)
case OR:
index := idx(a, b)
if index > 0 {
// First row is zero and not transmitted.
index--
if index >= tableCount {
return nil, nil,
fmt.Errorf("corrupted circuit: index %d >= %d",
index, tableCount)
}
c = garbled[index]
}
output = decrypt(alg, a, b, id, c, &labelData)
id++
case INV:
index := idxUnary(a)
if index > 0 {
// First row is zero and not transmitted.
index--
if index >= tableCount {
return nil, nil,
fmt.Errorf("corrupted circuit: index %d >= %d",
index, tableCount)
}
c = garbled[index]
}
output = decrypt(alg, a, b, id, c, &labelData)
id++
}
streaming.Set(cTmp, cIndex, output)
}
case OpReturn:
xfer := conn.Stats.Sum() - ioStats
ioStats = conn.Stats.Sum()
timing.Sample("Eval", []string{FileSize(xfer).String()})
var labels []ot.Label
for i := 0; i < outputs.Size(); i++ {
id, err := conn.ReceiveUint32()
if err != nil {
return nil, nil, err
}
label := streaming.Get(false, id)
labels = append(labels, label)
}
// Resolve result values.
if err := conn.SendUint32(OpResult); err != nil {
return nil, nil, err
}
var labelData ot.LabelData
for _, l := range labels {
if err := conn.SendLabel(l, &labelData); err != nil {
return nil, nil, err
}
}
if err := conn.Flush(); err != nil {
return nil, nil, err
}
result, err := conn.ReceiveData()
if err != nil {
return nil, nil, err
}
rawResult = new(big.Int).SetBytes(result)
break loop
default:
return nil, nil, fmt.Errorf("unknown operation %d", op)
}
}
xfer = conn.Stats.Sum() - ioStats
timing.Sample("Result", []string{FileSize(xfer).String()})
if verbose {
timing.Print(conn.Stats)
}
return outputs, outputs.Split(rawResult), nil
}
func receiveArgument(conn *p2p.Conn) (arg IOArg, err error) {
name, err := conn.ReceiveString()
if err != nil {
return arg, err
}
t, err := conn.ReceiveString()
if err != nil {
return arg, err
}
size, err := conn.ReceiveUint32()
if err != nil {
return arg, err
}
arg.Name = name
arg.Type, err = types.Parse(t)
if err != nil {
return arg, err
}
arg.Type.Bits = types.Size(size)
if arg.Type.Type == types.TSlice {
arg.Type.ArraySize = arg.Type.Bits / arg.Type.ElementType.Bits
}
count, err := conn.ReceiveUint32()
if err != nil {
return arg, err
}
for i := 0; i < count; i++ {
a, err := receiveArgument(conn)
if err != nil {
return arg, err
}
arg.Compound = append(arg.Compound, a)
}
return arg, nil
}