mirror of
https://github.com/QuilibriumNetwork/ceremonyclient.git
synced 2026-02-21 18:37:26 +08:00
625 lines
14 KiB
Go
625 lines
14 KiB
Go
//
// Copyright (c) 2020-2024 Markku Rossi
//
// All rights reserved.
//

package ssa
|
|
|
|
import (
|
|
"fmt"
|
|
|
|
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
|
|
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/circuits"
|
|
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/utils"
|
|
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
|
|
)
|
|
|
|
// CompileCircuit compiles the QCL program into a boolean circuit.
|
|
func (prog *Program) CompileCircuit(params *utils.Params) (
|
|
*circuit.Circuit, error) {
|
|
|
|
calloc := circuits.NewAllocator()
|
|
|
|
cc, err := circuits.NewCompiler(params, calloc, prog.Inputs, prog.Outputs,
|
|
prog.InputWires, prog.OutputWires)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = prog.DefineConstants(cc.ZeroWire(), cc.OneWire())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if params.Verbose {
|
|
fmt.Printf("Creating circuit...\n")
|
|
}
|
|
err = prog.Circuit(cc)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if params.Verbose {
|
|
fmt.Printf("Compiling circuit...\n")
|
|
}
|
|
cc.ConstPropagate()
|
|
cc.ShortCircuitXORZero()
|
|
if params.OptPruneGates {
|
|
orig := float64(len(cc.Gates))
|
|
pruned := cc.Prune()
|
|
if params.Verbose {
|
|
fmt.Printf(" - Pruned %d gates (%.2f%%)\n", pruned,
|
|
float64(pruned)/orig*100)
|
|
}
|
|
}
|
|
circ := cc.Compile()
|
|
if params.CircOut != nil {
|
|
if params.Verbose {
|
|
fmt.Printf("Serializing circuit...\n")
|
|
}
|
|
err = circ.MarshalFormat(params.CircOut, params.CircFormat)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
if params.CircDotOut != nil {
|
|
circ.Dot(params.CircDotOut)
|
|
}
|
|
if params.CircSvgOut != nil {
|
|
circ.Svg(params.CircSvgOut)
|
|
}
|
|
|
|
return circ, nil
|
|
}
|
|
|
|
// Circuit creates the boolean circuits for the program steps.
|
|
func (prog *Program) Circuit(cc *circuits.Compiler) error {
|
|
|
|
for _, step := range prog.Steps {
|
|
instr := step.Instr
|
|
var wires [][]*circuits.Wire
|
|
for idx, in := range instr.In {
|
|
if !in.Type.Concrete() {
|
|
return fmt.Errorf("%s: type %v of input %v not concrete",
|
|
instr, in, idx)
|
|
}
|
|
w, err := prog.walloc.Wires(in, in.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(w) != int(in.Type.Bits) {
|
|
// Const values are cast to different value
|
|
// sizes. Make sure wire length matches type size.
|
|
cw := make([]*circuits.Wire, in.Type.Bits)
|
|
|
|
var pad *circuits.Wire
|
|
if in.Type.Type == types.TInt && len(w) > 0 {
|
|
// Sign extension.
|
|
pad = w[len(w)-1]
|
|
} else {
|
|
pad = cc.ZeroWire()
|
|
}
|
|
|
|
for bit := 0; bit < int(in.Type.Bits); bit++ {
|
|
if bit < len(w) {
|
|
cw[bit] = w[bit]
|
|
} else {
|
|
cw[bit] = pad
|
|
}
|
|
}
|
|
wires = append(wires, cw)
|
|
} else {
|
|
wires = append(wires, w)
|
|
}
|
|
}
|
|
switch instr.Op {
|
|
case Iadd, Uadd:
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = circuits.NewAdder(cc, wires[0], wires[1], o)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case Isub, Usub:
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = circuits.NewSubtractor(cc, wires[0], wires[1], o)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case Imult, Umult:
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = circuits.NewMultiplier(cc, cc.Params.CircMultArrayTreshold,
|
|
wires[0], wires[1], o)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case Idiv:
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = circuits.NewIDivider(cc, wires[0], wires[1], o, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case Udiv:
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = circuits.NewUDivider(cc, wires[0], wires[1], o, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case Imod:
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = circuits.NewIDivider(cc, wires[0], wires[1], nil, o)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case Umod:
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = circuits.NewUDivider(cc, wires[0], wires[1], nil, o)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case Concat:
|
|
o := make([]*circuits.Wire, instr.Out.Type.Bits)
|
|
for i := 0; i < len(wires[0]); i++ {
|
|
o[i] = wires[0][i]
|
|
}
|
|
for i := 0; i < len(wires[1]); i++ {
|
|
o[len(wires[0])+i] = wires[1][i]
|
|
}
|
|
prog.walloc.SetWires(*instr.Out, o)
|
|
|
|
case Lshift:
|
|
count, err := instr.In[1].ConstInt()
|
|
if err != nil {
|
|
return fmt.Errorf("%s: unsupported index type %T: %s",
|
|
instr.Op, instr.In[1], err)
|
|
}
|
|
if count < 0 {
|
|
return fmt.Errorf("%s: negative shift count %d",
|
|
instr.Op, count)
|
|
}
|
|
o := make([]*circuits.Wire, instr.Out.Type.Bits)
|
|
for bit := 0; bit < len(o); bit++ {
|
|
var w *circuits.Wire
|
|
if bit-int(count) >= 0 && bit-int(count) < len(wires[0]) {
|
|
w = wires[0][bit-int(count)]
|
|
} else {
|
|
w = cc.ZeroWire()
|
|
}
|
|
o[bit] = w
|
|
}
|
|
prog.walloc.SetWires(*instr.Out, o)
|
|
|
|
case Rshift, Srshift:
|
|
var signWire *circuits.Wire
|
|
if instr.Op == Srshift {
|
|
signWire = wires[0][len(wires[0])-1]
|
|
} else {
|
|
signWire = cc.ZeroWire()
|
|
}
|
|
count, err := instr.In[1].ConstInt()
|
|
if err != nil {
|
|
return fmt.Errorf("%s: unsupported index type %T: %s",
|
|
instr.Op, instr.In[1], err)
|
|
}
|
|
if count < 0 {
|
|
return fmt.Errorf("%s: negative shift count %d",
|
|
instr.Op, count)
|
|
}
|
|
o := make([]*circuits.Wire, instr.Out.Type.Bits)
|
|
for bit := 0; bit < len(o); bit++ {
|
|
var w *circuits.Wire
|
|
if bit+int(count) < len(wires[0]) {
|
|
w = wires[0][bit+int(count)]
|
|
} else {
|
|
w = signWire
|
|
}
|
|
o[bit] = w
|
|
}
|
|
prog.walloc.SetWires(*instr.Out, o)
|
|
|
|
case Slice:
|
|
from, err := instr.In[1].ConstInt()
|
|
if err != nil {
|
|
return fmt.Errorf("%s: unsupported index type %T: %s",
|
|
instr.Op, instr.In[1], err)
|
|
}
|
|
|
|
to, err := instr.In[2].ConstInt()
|
|
if err != nil {
|
|
return fmt.Errorf("%s: unsupported index type %T: %s",
|
|
instr.Op, instr.In[2], err)
|
|
}
|
|
if from >= to {
|
|
return fmt.Errorf("%s: bounds out of range [%d:%d]",
|
|
instr.Op, from, to)
|
|
}
|
|
o := make([]*circuits.Wire, instr.Out.Type.Bits)
|
|
|
|
for bit := from; bit < to; bit++ {
|
|
var w *circuits.Wire
|
|
if int(bit) < len(wires[0]) {
|
|
w = wires[0][bit]
|
|
} else {
|
|
w = cc.ZeroWire()
|
|
}
|
|
o[bit-from] = w
|
|
}
|
|
// Make sure all output bits are wired.
|
|
for bit := to - from; int(bit) < len(o); bit++ {
|
|
o[bit] = cc.ZeroWire()
|
|
}
|
|
prog.walloc.SetWires(*instr.Out, o)
|
|
|
|
case Index:
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
offset, err := instr.In[1].ConstInt()
|
|
if err != nil {
|
|
return fmt.Errorf("%s: unsupported offset type %T: %s",
|
|
instr.Op, instr.In[1], err)
|
|
}
|
|
err = circuits.NewIndex(cc, int(instr.In[0].Type.ElementType.Bits),
|
|
wires[0][offset:], wires[2], o)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case Ilt, Ult:
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = circuits.NewLtComparator(cc, wires[0], wires[1], o)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case Ile, Ule:
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = circuits.NewLeComparator(cc, wires[0], wires[1], o)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case Igt, Ugt:
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = circuits.NewGtComparator(cc, wires[0], wires[1], o)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case Ige, Uge:
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = circuits.NewGeComparator(cc, wires[0], wires[1], o)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case Eq:
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = circuits.NewEqComparator(cc, wires[0], wires[1], o)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case Neq:
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = circuits.NewNeqComparator(cc, wires[0], wires[1], o)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case Bts:
|
|
index, err := instr.In[1].ConstInt()
|
|
if err != nil {
|
|
return fmt.Errorf("%s unsupported index type %T: %s",
|
|
instr.Op, instr.In[1], err)
|
|
}
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = circuits.NewBitSetTest(cc, wires[0], index, o)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case Btc:
|
|
index, err := instr.In[1].ConstInt()
|
|
if err != nil {
|
|
return fmt.Errorf("%s unsupported index type %T: %s",
|
|
instr.Op, instr.In[1], err)
|
|
}
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = circuits.NewBitClrTest(cc, wires[0], index, o)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case And:
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = circuits.NewLogicalAND(cc, wires[0], wires[1], o)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case Or:
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = circuits.NewLogicalOR(cc, wires[0], wires[1], o)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case Not:
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for i := 0; i < int(instr.Out.Type.Bits); i++ {
|
|
cc.INV(wires[0][i], o[i])
|
|
}
|
|
|
|
case Band:
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = circuits.NewBinaryAND(cc, wires[0], wires[1], o)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case Bclr:
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = circuits.NewBinaryClear(cc, wires[0], wires[1], o)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case Bor:
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = circuits.NewBinaryOR(cc, wires[0], wires[1], o)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case Bxor:
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = circuits.NewBinaryXOR(cc, wires[0], wires[1], o)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case Mov, Smov:
|
|
var signWire *circuits.Wire
|
|
if instr.Op == Smov {
|
|
signWire = wires[0][len(wires[0])-1]
|
|
} else {
|
|
signWire = cc.ZeroWire()
|
|
}
|
|
|
|
o := make([]*circuits.Wire, instr.Out.Type.Bits)
|
|
|
|
for bit := 0; bit < int(instr.Out.Type.Bits); bit++ {
|
|
var w *circuits.Wire
|
|
if bit < len(wires[0]) {
|
|
w = wires[0][bit]
|
|
} else {
|
|
w = signWire
|
|
}
|
|
o[bit] = w
|
|
}
|
|
prog.walloc.SetWires(*instr.Out, o)
|
|
|
|
case Amov:
|
|
// v arr from to:
|
|
// array[from:to] = v
|
|
from, err := instr.In[2].ConstInt()
|
|
if err != nil {
|
|
return fmt.Errorf("%s: unsupported index type %T: %s",
|
|
instr.Op, instr.In[2], err)
|
|
}
|
|
to, err := instr.In[3].ConstInt()
|
|
if err != nil {
|
|
return fmt.Errorf("%s: unsupported index type %T: %s",
|
|
instr.Op, instr.In[3], err)
|
|
}
|
|
if from < 0 || from >= to {
|
|
return fmt.Errorf("%s: bounds out of range [%d:%d]",
|
|
instr.Op, from, to)
|
|
}
|
|
o := make([]*circuits.Wire, instr.Out.Type.Bits)
|
|
|
|
for bit := types.Size(0); bit < instr.Out.Type.Bits; bit++ {
|
|
var w *circuits.Wire
|
|
if bit < from || bit >= to {
|
|
if bit < types.Size(len(wires[1])) {
|
|
w = wires[1][bit]
|
|
} else {
|
|
w = cc.ZeroWire()
|
|
}
|
|
} else {
|
|
idx := bit - from
|
|
if idx < types.Size(len(wires[0])) {
|
|
w = wires[0][idx]
|
|
} else {
|
|
w = cc.ZeroWire()
|
|
}
|
|
}
|
|
o[bit] = w
|
|
}
|
|
prog.walloc.SetWires(*instr.Out, o)
|
|
|
|
case Phi:
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = circuits.NewMUX(cc, wires[0], wires[1], wires[2], o)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case Ret:
|
|
// Assign output wires.
|
|
for _, wg := range wires {
|
|
for _, w := range wg {
|
|
o := cc.Calloc.Wire()
|
|
cc.ID(w, o)
|
|
cc.OutputWires = append(cc.OutputWires, o)
|
|
}
|
|
}
|
|
for _, o := range cc.OutputWires {
|
|
o.SetOutput(true)
|
|
}
|
|
|
|
case Circ:
|
|
var circWires []*circuits.Wire
|
|
|
|
// Flatten input wires.
|
|
for wi, w := range wires {
|
|
circWires = append(circWires, w...)
|
|
for i := len(w); i < int(instr.Circ.Inputs[wi].Type.Bits); i++ {
|
|
// Zeroes for unset input wires.
|
|
zw := cc.ZeroWire()
|
|
circWires = append(circWires, zw)
|
|
}
|
|
}
|
|
|
|
// Flatten output wires.
|
|
var circOut []*circuits.Wire
|
|
|
|
for _, r := range instr.Ret {
|
|
o, err := prog.walloc.Wires(r, r.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
circOut = append(circOut, o...)
|
|
}
|
|
|
|
// Add intermediate wires.
|
|
nint := instr.Circ.NumWires - len(circWires) - len(circOut)
|
|
for i := 0; i < nint; i++ {
|
|
circWires = append(circWires, cc.Calloc.Wire())
|
|
}
|
|
|
|
// Append output wires.
|
|
circWires = append(circWires, circOut...)
|
|
|
|
// Add gates.
|
|
for _, gate := range instr.Circ.Gates {
|
|
switch gate.Op {
|
|
case circuit.XOR:
|
|
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR,
|
|
circWires[gate.Input0],
|
|
circWires[gate.Input1],
|
|
circWires[gate.Output]))
|
|
case circuit.XNOR:
|
|
cc.AddGate(cc.Calloc.BinaryGate(circuit.XNOR,
|
|
circWires[gate.Input0],
|
|
circWires[gate.Input1],
|
|
circWires[gate.Output]))
|
|
case circuit.AND:
|
|
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND,
|
|
circWires[gate.Input0],
|
|
circWires[gate.Input1],
|
|
circWires[gate.Output]))
|
|
case circuit.OR:
|
|
cc.AddGate(cc.Calloc.BinaryGate(circuit.OR,
|
|
circWires[gate.Input0],
|
|
circWires[gate.Input1],
|
|
circWires[gate.Output]))
|
|
case circuit.INV:
|
|
cc.INV(circWires[gate.Input0], circWires[gate.Output])
|
|
default:
|
|
return fmt.Errorf("unknown gate %s", gate)
|
|
}
|
|
}
|
|
|
|
case Builtin:
|
|
o, err := prog.walloc.Wires(*instr.Out, instr.Out.Type.Bits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = instr.Builtin(cc, wires[0], wires[1], o)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
case GC:
|
|
|
|
default:
|
|
return fmt.Errorf("Block.Circuit: %s not implemented yet", instr.Op)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|