mirror of
https://github.com/QuilibriumNetwork/ceremonyclient.git
synced 2026-02-21 18:37:26 +08:00
441 lines
8.4 KiB
Go
441 lines
8.4 KiB
Go
//
|
|
// Copyright (c) 2020-2021, 2023 Markku Rossi
|
|
//
|
|
// All rights reserved.
|
|
//
|
|
|
|
package circuit
|
|
|
|
import (
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"crypto/rand"
|
|
"fmt"
|
|
"time"
|
|
|
|
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
|
|
"source.quilibrium.com/quilibrium/monorepo/bedlam/p2p"
|
|
)
|
|
|
|
// StreamDebug enables diagnostic printing from the streaming garbler
// when set to true.
const StreamDebug = false
|
|
|
|
// Streaming is a streaming garbled circuit garbler.
|
|
type Streaming struct {
|
|
conn *p2p.Conn
|
|
key []byte
|
|
alg cipher.Block
|
|
r ot.Label
|
|
wires []ot.Wire
|
|
tmp []ot.Wire
|
|
in []Wire
|
|
out []Wire
|
|
firstTmp Wire
|
|
firstOut Wire
|
|
}
|
|
|
|
// NewStreaming creates a new streaming garbled circuit garbler. Gate
// rows are encrypted with AES under key, and garbled gates are written
// to conn. A fresh random wire-label offset is drawn from crypto/rand,
// and every wire listed in inputs gets a new label pair.
func NewStreaming(key []byte, inputs []Wire, conn *p2p.Conn) (
	*Streaming, error) {

	offset, err := ot.NewLabel(rand.Reader)
	if err != nil {
		return nil, err
	}
	// Set the S bit of the offset; presumably this makes the two labels
	// of each wire differ in their S (permute) bit.
	offset.SetS(true)

	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}

	stream := &Streaming{
		conn: conn,
		key:  key,
		alg:  block,
		r:    offset,
	}
	stream.ensureWires(maxWire(0, inputs))

	// Assign fresh labels to all input wires.
	for _, input := range inputs {
		labels, err := makeLabels(stream.r)
		if err != nil {
			return nil, err
		}
		stream.wires[input] = labels
	}
	return stream, nil
}
|
|
|
|
func maxWire(max Wire, wires []Wire) Wire {
|
|
for _, w := range wires {
|
|
if w > max {
|
|
max = w
|
|
}
|
|
}
|
|
return max
|
|
}
|
|
|
|
func (stream *Streaming) ensureWires(max Wire) {
|
|
// Verify that wires is big enough.
|
|
if len(stream.wires) <= int(max) {
|
|
var i int
|
|
for i = 65536; i <= int(max); i <<= 1 {
|
|
}
|
|
n := make([]ot.Wire, i)
|
|
copy(n, stream.wires)
|
|
stream.wires = n
|
|
}
|
|
}
|
|
|
|
func (stream *Streaming) initCircuit(c *Circuit, in, out []Wire) {
|
|
stream.ensureWires(maxWire(maxWire(0, in), out))
|
|
|
|
if len(stream.tmp) < c.NumWires {
|
|
stream.tmp = make([]ot.Wire, c.NumWires)
|
|
}
|
|
|
|
stream.in = in
|
|
stream.out = out
|
|
|
|
stream.firstTmp = Wire(len(in))
|
|
stream.firstOut = Wire(c.NumWires - len(out))
|
|
}
|
|
|
|
// GetInput gets the value of the input wire.
|
|
func (stream *Streaming) GetInput(w Wire) ot.Wire {
|
|
return stream.wires[w]
|
|
}
|
|
|
|
// GetInputs gets the specified input wire range.
|
|
func (stream *Streaming) GetInputs(offset, count int) []ot.Wire {
|
|
return stream.wires[offset : offset+count]
|
|
}
|
|
|
|
// Get gets the value of the wire.
|
|
func (stream *Streaming) Get(w Wire) (ot.Wire, Wire, bool) {
|
|
if w < stream.firstTmp {
|
|
index := stream.in[w]
|
|
return stream.wires[index], index, false
|
|
} else if w >= stream.firstOut {
|
|
index := stream.out[w-stream.firstOut]
|
|
return stream.wires[index], index, false
|
|
} else {
|
|
return stream.tmp[w], w, true
|
|
}
|
|
}
|
|
|
|
// Set sets the value of the wire.
|
|
func (stream *Streaming) Set(w Wire, val ot.Wire) (index Wire, tmp bool) {
|
|
if w < stream.firstTmp {
|
|
index = stream.in[w]
|
|
stream.wires[index] = val
|
|
} else if w >= stream.firstOut {
|
|
index = stream.out[w-stream.firstOut]
|
|
stream.wires[index] = val
|
|
} else {
|
|
index = w
|
|
tmp = true
|
|
stream.tmp[w] = val
|
|
}
|
|
return index, tmp
|
|
}
|
|
|
|
// Garble garbles circuit c, whose inputs and outputs are bound to the
// global wires in and out. It streams each garbled gate into the
// connection's write buffer. It returns the time spent preparing the
// circuit and the time spent garbling its gates.
func (stream *Streaming) Garble(c *Circuit, in, out []Wire) (
	time.Duration, time.Duration, error) {
	if StreamDebug {
		fmt.Printf(" - Streaming.Garble: in=%v, out=%v\n", in, out)
	}

	start := time.Now()
	stream.initCircuit(c, in, out)

	// Scratch state shared by all gates of this circuit.
	//
	// NOTE(review): the tweak counter gateID restarts at zero on every
	// Garble call, while the offset r and the AES key persist on the
	// stream. Tweaks therefore repeat when several circuits are garbled
	// on one stream. Confirm this is intended. The evaluator must number
	// gates the same way, so changing it here alone would break the
	// protocol.
	var (
		data   ot.LabelData
		gateID uint32
		table  [4]ot.Label
	)

	mid := time.Now()
	for i := range c.Gates {
		// Reserve room for the largest possible gate encoding before
		// writing into the connection buffer directly.
		if err := stream.conn.NeedSpace(512); err != nil {
			return 0, 0, err
		}
		if err := stream.garbleGate(&c.Gates[i], &gateID, table[:], &data,
			stream.conn.WriteBuf, &stream.conn.WritePos); err != nil {
			return 0, 0, err
		}
	}
	return mid.Sub(start), time.Since(mid), nil
}
|
|
|
|
// garbleGate garbles gate g and writes its encoding into buf, starting
// at *bufpos, and advances *bufpos past the written bytes. The function
// does not check buf's capacity; Garble reserves space with
// conn.NeedSpace before each call.
//
// Encoding:
//   - One opcode byte. It holds byte(g.Op) (presumably g.Op fits in the
//     low four bits) plus flag bits. Bits 7, 6 and 5 are set when input
//     0, input 1 and the output, respectively, live in the temporary
//     wire table. Bit 4 is set when every index fits in 16 bits.
//   - The wire indices, written with bo as uint16 values when bit 4 is
//     set and as uint32 values otherwise. Binary gates write the input
//     0, input 1 and output indices; INV writes input 0 and the output.
//   - The table rows. XOR and XNOR write none, because their output
//     labels are derived from the inputs and the offset r. AND writes
//     its two half-gate rows. OR writes three rows and INV writes one;
//     their first row is zero after row reduction and is not sent.
//
// idp is the gate tweak counter. It is advanced by the number of tweaks
// the gate consumes: two for AND, one for OR and INV. table must have
// capacity for at least 4 labels; it and data are scratch space.
func (stream *Streaming) garbleGate(g *Gate, idp *uint32,
	table []ot.Label, data *ot.LabelData, buf []byte, bufpos *int) error {

	var a, b, c ot.Wire
	var aIndex, bIndex, cIndex Wire
	var aTmp, bTmp, cTmp bool

	table = table[0:4]
	var tableStart, tableCount, wireCount int

	// Inputs.
	switch g.Op {
	case XOR, XNOR, AND, OR:
		b, bIndex, bTmp = stream.Get(g.Input1)
		fallthrough

	case INV:
		a, aIndex, aTmp = stream.Get(g.Input0)

	default:
		return fmt.Errorf("invalid gate type %s", g.Op)
	}

	// Output labels of the gates that need no row reduction.
	switch g.Op {
	case XOR:
		// Free XOR: c0 = a0 ^ b0 and c1 = c0 ^ r.
		l0 := a.L0
		l0.Xor(b.L0)

		l1 := l0
		l1.Xor(stream.r)
		c = ot.Wire{
			L0: l0,
			L1: l1,
		}

	case XNOR:
		// Free XOR with the two output labels swapped.
		l0 := a.L0
		l0.Xor(b.L0)

		l1 := l0
		l1.Xor(stream.r)
		c = ot.Wire{
			L0: l1,
			L1: l0,
		}

	case AND:
		// Half-gates garbling: one half gate uses tweak j0 over a's
		// labels, the other uses tweak j1 over b's labels.
		pa := a.L0.S()
		pb := b.L0.S()

		j0 := *idp
		j1 := *idp + 1
		*idp = *idp + 2

		// First half gate.
		tg := encryptHalf(stream.alg, a.L0, j0, data)
		tg.Xor(encryptHalf(stream.alg, a.L1, j0, data))
		if pb {
			tg.Xor(stream.r)
		}
		wg0 := encryptHalf(stream.alg, a.L0, j0, data)
		if pa {
			wg0.Xor(tg)
		}

		// Second half gate.
		te := encryptHalf(stream.alg, b.L0, j1, data)
		te.Xor(encryptHalf(stream.alg, b.L1, j1, data))
		te.Xor(a.L0)
		we0 := encryptHalf(stream.alg, b.L0, j1, data)
		if pb {
			we0.Xor(te)
			we0.Xor(a.L0)
		}

		// Combine halves.
		l0 := wg0
		l0.Xor(we0)

		l1 := l0
		l1.Xor(stream.r)

		c = ot.Wire{
			L0: l0,
			L1: l1,
		}
		table[0] = tg
		table[1] = te
		tableCount = 2

	case OR, INV:
		// Row reduction below derives the output labels from the first
		// table row.

	default:
		return fmt.Errorf("invalid gate type %s", g.Op)
	}

	switch g.Op {
	case XOR, XNOR:
		// Free XOR.
		wireCount = 3

	case AND:
		// Half AND garbled above.
		wireCount = 3

	case OR:
		// a b c
		// -----
		// 0 0 0
		// 0 1 1
		// 1 0 1
		// 1 1 1
		id := *idp
		*idp = *idp + 1

		// Rows are encrypted with a zero output label. Row reduction
		// below XORs the real output label into each row.
		zero := ot.Label{}
		table[idx(a.L0, b.L0)] = encrypt(stream.alg, a.L0, b.L0, zero, id, data)
		table[idx(a.L0, b.L1)] = encrypt(stream.alg, a.L0, b.L1, zero, id, data)
		table[idx(a.L1, b.L0)] = encrypt(stream.alg, a.L1, b.L0, zero, id, data)
		table[idx(a.L1, b.L1)] = encrypt(stream.alg, a.L1, b.L1, zero, id, data)

		// Row reduction: choose the output labels so that the first
		// row becomes all zero and does not need to be transmitted.
		// Row 0 is the (0,0) row, with output 0, only when l0Index is
		// 0; otherwise it is an output-1 row.
		l0Index := idx(a.L0, b.L0)

		c.L0 = table[0]
		c.L1 = table[0]

		if l0Index == 0 {
			c.L1.Xor(stream.r)
		} else {
			c.L0.Xor(stream.r)
		}
		for i := 0; i < 4; i++ {
			if i == l0Index {
				table[i].Xor(c.L0)
			} else {
				table[i].Xor(c.L1)
			}
		}

		tableStart = 1
		tableCount = 3
		wireCount = 3

	case INV:
		// a c
		// ---
		// 0 1
		// 1 0
		zero := ot.Label{}
		id := *idp
		*idp = *idp + 1

		// Rows are encrypted with a zero output label. Row reduction
		// below XORs the real output label into each row.
		table[idxUnary(a.L0)] = encrypt(stream.alg, a.L0, zero, zero, id, data)
		table[idxUnary(a.L1)] = encrypt(stream.alg, a.L1, zero, zero, id, data)

		// Row reduction: row 0 becomes all zero. Input 0 maps to
		// output 1, so the a.L0 row carries c.L1.
		l0Index := idxUnary(a.L0)

		c.L0 = table[0]
		c.L1 = table[0]

		if l0Index == 0 {
			c.L0.Xor(stream.r)
		} else {
			c.L1.Xor(stream.r)
		}
		for i := 0; i < 2; i++ {
			if i == l0Index {
				table[i].Xor(c.L1)
			} else {
				table[i].Xor(c.L0)
			}
		}

		tableStart = 1
		tableCount = 1
		wireCount = 2

	default:
		return fmt.Errorf("invalid operand %s", g.Op)
	}

	// Store the output labels exactly as Set does, and remember where
	// they went for the encoding below.
	cIndex, cTmp = stream.Set(g.Output, c)

	op := byte(g.Op)
	if aTmp {
		op |= 0b10000000
	}
	if bTmp {
		op |= 0b01000000
	}
	if cTmp {
		op |= 0b00100000
	}
	if aIndex <= 0xffff && bIndex <= 0xffff && cIndex <= 0xffff {
		// Short form: 16-bit wire indices.
		op |= 0b00010000
		buf[*bufpos] = op
		*bufpos = *bufpos + 1

		switch wireCount {
		case 3:
			bo.PutUint16(buf[*bufpos+0:], uint16(aIndex))
			bo.PutUint16(buf[*bufpos+2:], uint16(bIndex))
			bo.PutUint16(buf[*bufpos+4:], uint16(cIndex))
			*bufpos = *bufpos + 6

		case 2:
			bo.PutUint16(buf[*bufpos+0:], uint16(aIndex))
			bo.PutUint16(buf[*bufpos+2:], uint16(cIndex))
			*bufpos = *bufpos + 4

		default:
			panic(fmt.Sprintf("invalid wire count: %d", wireCount))
		}
	} else {
		// Long form: 32-bit wire indices.
		buf[*bufpos] = op
		*bufpos = *bufpos + 1

		switch wireCount {
		case 3:
			bo.PutUint32(buf[*bufpos+0:], uint32(aIndex))
			bo.PutUint32(buf[*bufpos+4:], uint32(bIndex))
			bo.PutUint32(buf[*bufpos+8:], uint32(cIndex))
			*bufpos = *bufpos + 12

		case 2:
			bo.PutUint32(buf[*bufpos+0:], uint32(aIndex))
			bo.PutUint32(buf[*bufpos+4:], uint32(cIndex))
			*bufpos = *bufpos + 8

		default:
			panic(fmt.Sprintf("invalid wire count: %d", wireCount))
		}
	}

	// Table rows, skipping any rows removed by row reduction.
	for i := 0; i < tableCount; i++ {
		bytes := table[tableStart+i].Bytes(data)
		copy(buf[*bufpos:], bytes)
		*bufpos = *bufpos + len(bytes)
	}

	return nil
}
|