ceremonyclient/bedlam/apps/garbled/main.go
Cassandra Heart dbd95bd9e9
v2.1.0 (#439)
* v2.1.0 [omit consensus and adjacent] - this commit will be amended with the full release after the file copy is complete

* 2.1.0 main node rollup
2025-09-30 02:48:15 -05:00

365 lines
7.8 KiB
Go

//
// main.go
//
// Copyright (c) 2019-2024 Markku Rossi
//
// All rights reserved.
//
package main
import (
"flag"
"fmt"
"io"
"log"
"net"
"os"
"runtime"
"runtime/pprof"
"slices"
"strings"
"source.quilibrium.com/quilibrium/monorepo/bedlam"
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/utils"
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
"source.quilibrium.com/quilibrium/monorepo/bedlam/p2p"
)
// verbose mirrors the -v command-line flag. It is assigned once in main
// after flag parsing and is read when loading circuits and when running
// the garbler and evaluator protocols.
var verbose bool
// input collects string values given on the command line. Its pointer
// type provides the String and Set methods that flag.Var requires, so a
// flag of this type can be repeated and each occurrence can carry several
// comma-separated values.
type input []string

// String formats the values collected so far using the default fmt
// rendering of a string slice, e.g. "[1 2 3]".
func (i *input) String() string {
	return fmt.Sprint([]string(*i))
}

// Set splits value at every comma and appends the resulting elements to
// the list. Elements are taken verbatim: no trimming is done and empty
// elements are kept. It always returns nil.
func (i *input) Set(value string) error {
	*i = append(*i, strings.Split(value, ",")...)
	return nil
}
// Circuit input values collected from the command line.
var (
	// inputFlag holds this party's circuit inputs (flag -i).
	inputFlag input
	// peerFlag holds the peer's circuit inputs (flag -pi). In this file
	// it is only consulted when compiling, to size the peer's inputs.
	peerFlag input
)

// init registers the -i and -pi flags on the default flag set so that
// they are known before main calls flag.Parse.
func init() {
	flag.Var(&peerFlag, "pi", "comma-separated list of peer's circuit inputs")
	flag.Var(&inputFlag, "i", "comma-separated list of circuit inputs")
}
// main parses the command line and runs one of three modes:
//
//   - compile mode (-circ and/or -ssa): compiles the files given as
//     arguments and exits;
//   - streaming mode (-stream): runs the streaming garbler or evaluator;
//   - default mode: runs the garbler, or with -e the evaluator, on
//     exactly one circuit or QCL file.
//
// When -cpuprofile is set, CPU profiling covers the whole run. When
// -memprofile is set, a heap profile is written once the selected mode
// has finished. Errors end the process through log.Fatal, which exits
// without running deferred calls.
func main() {
	evaluator := flag.Bool("e", false, "evaluator / garbler mode")
	stream := flag.Bool("stream", false, "streaming mode")
	compile := flag.Bool("circ", false, "compile QCL to circuit")
	circFormat := flag.String("format", "qclc",
		"circuit format: qclc, bristol")
	ssa := flag.Bool("ssa", false, "compile QCL to SSA assembly")
	dot := flag.Bool("dot", false, "create Graphviz DOT output")
	svg := flag.Bool("svg", false, "create SVG output")
	optimize := flag.Int("O", 1, "optimization level")
	address := flag.String("address", "127.0.0.1:8080", "address and port")
	otAddress := flag.String("ot-address", "127.0.0.1:5555", "address and port")
	fVerbose := flag.Bool("v", false, "verbose output")
	fDiagnostics := flag.Bool("d", false, "diagnostics output")
	cpuprofile := flag.String("cpuprofile", "", "write cpu profile to `file`")
	memprofile := flag.String("memprofile", "",
		"write memory profile to `file`")
	qclcErrLoc := flag.Bool("qclc-err-loc", false,
		"print QCLC error locations")
	benchmarkCompile := flag.Bool("benchmark-compile", false,
		"benchmark QCL compilation")
	flag.Parse()

	log.SetFlags(0)

	verbose = *fVerbose

	if len(*cpuprofile) > 0 {
		f, err := os.Create(*cpuprofile)
		if err != nil {
			log.Fatal("could not create CPU profile: ", err)
		}
		defer f.Close()
		if err := pprof.StartCPUProfile(f); err != nil {
			log.Fatal("could not start CPU profile: ", err)
		}
		defer pprof.StopCPUProfile()
	}

	params := utils.NewParams()
	defer params.Close()

	params.Verbose = *fVerbose
	params.Diagnostics = *fDiagnostics
	params.QCLCErrorLoc = *qclcErrLoc
	params.BenchmarkCompile = *benchmarkCompile

	if *optimize > 0 {
		params.OptPruneGates = true
	}
	if *ssa && !*compile {
		params.NoCircCompile = true
	}

	args := flag.Args()

	if *compile || *ssa {
		inputSizes := make([][]int, 2)
		iSizes, err := circuit.InputSizes(inputFlag)
		if err != nil {
			log.Fatal(err)
		}
		pSizes, err := circuit.InputSizes(peerFlag)
		if err != nil {
			log.Fatal(err)
		}
		// Index 0 holds the garbler's input sizes and index 1 the
		// evaluator's, so -e swaps which of -i / -pi goes where.
		if *evaluator {
			inputSizes[0] = pSizes
			inputSizes[1] = iSizes
		} else {
			inputSizes[0] = iSizes
			inputSizes[1] = pSizes
		}
		err = compileFiles(args, params, inputSizes,
			*compile, *ssa, *dot, *svg, *circFormat)
		if err != nil {
			log.Fatalf("compile failed: %s", err)
		}
		memProfile(*memprofile)
		return
	}

	var err error

	// Party 1 is the garbler, party 2 the evaluator.
	party := uint8(1)
	if *evaluator {
		party = 2
	}
	oti := ot.NewFerret(party, *otAddress)

	// With CPU profiling enabled the evaluator is asked to stop after one
	// session instead of serving forever; presumably so that main returns
	// and the deferred StopCPUProfile gets to run.
	profiling := len(*cpuprofile) > 0

	if *stream {
		if *evaluator {
			err = streamEvaluatorMode(oti, inputFlag, profiling, *address)
		} else {
			err = streamGarblerMode(params, oti, inputFlag, args, *address)
		}
		memProfile(*memprofile)
		if err != nil {
			log.Fatal(err)
		}
		return
	}

	if len(args) != 1 {
		log.Fatalf("expected one input file, got %v\n", len(args))
	}
	file := args[0]

	if *evaluator {
		err = evaluatorMode(oti, file, params, profiling, *address)
	} else {
		err = garblerMode(oti, file, params, *address)
	}
	// Honor -memprofile in the default mode too, as the compile and
	// streaming modes do. It is written before the error check because
	// log.Fatal would exit without giving another chance to do so.
	memProfile(*memprofile)
	if err != nil {
		log.Fatal(err)
	}
}
// loadCircuit produces a circuit from file. Files that
// circuit.IsFilename accepts are parsed with circuit.Parse; files with
// the ".qcl" suffix are compiled with the given params and inputSizes;
// any other name is rejected with an "unknown file type" error.
//
// A successfully loaded circuit gets its levels assigned and, when the
// package-level verbose flag is set, is printed to standard output.
func loadCircuit(file string, params *utils.Params, inputSizes [][]int) (
	*circuit.Circuit, error) {

	var (
		loaded *circuit.Circuit
		err    error
	)

	switch {
	case circuit.IsFilename(file):
		loaded, err = circuit.Parse(file)
	case strings.HasSuffix(file, ".qcl"):
		loaded, _, err = compiler.New(params).CompileFile(file, inputSizes)
	default:
		return nil, fmt.Errorf("unknown file type '%s'", file)
	}
	if err != nil {
		return nil, err
	}

	// A nil circuit without an error is passed through unchanged.
	if loaded == nil {
		return loaded, nil
	}

	loaded.AssignLevels()
	if verbose {
		fmt.Printf("circuit: %v\n", loaded)
	}
	return loaded, nil
}
// memProfile writes a heap profile to the named file. An empty name
// means profiling was not requested and the call does nothing. If the
// file cannot be created or the profile cannot be written, the process
// is terminated through log.Fatal.
func memProfile(file string) {
	// The garbage collection before the snapshot is deliberately
	// switched off; the call is kept so it can be enabled by flipping
	// this constant.
	const collectFirst = false

	if file == "" {
		return
	}

	out, err := os.Create(file)
	if err != nil {
		log.Fatal("could not create memory profile: ", err)
	}
	defer out.Close()

	if collectFirst {
		runtime.GC()
	}
	if err = pprof.WriteHeapProfile(out); err != nil {
		log.Fatal("could not write memory profile: ", err)
	}
}
// evaluatorMode runs the evaluator side of the two-party protocol. It
// listens on address and serves one garbler connection at a time.
//
// For each connection it exchanges input sizes with the peer, loads the
// circuit from file, parses this party's inputs from the -i flag,
// evaluates the circuit and prints the results. The circuit is loaded on
// the first connection and reloaded only when the peer's input sizes
// differ from those of the previous connection.
//
// If once is true the function returns nil after the first completed
// session; otherwise it keeps accepting connections until an error
// occurs. The listener is closed on return and every accepted connection
// is closed before the function returns or moves on to the next one.
func evaluatorMode(oti ot.OT, file string, params *utils.Params,
	once bool, address string) error {

	// Index 0 holds the garbler's (peer's) input sizes, index 1 ours.
	inputSizes := make([][]int, 2)

	myInputSizes, err := circuit.InputSizes(inputFlag)
	if err != nil {
		return err
	}
	inputSizes[1] = myInputSizes

	ln, err := net.Listen("tcp", address)
	if err != nil {
		return err
	}
	// Release the listening socket on every return path.
	defer ln.Close()

	fmt.Printf("Listening for connections at %s\n", address)

	var oPeerInputSizes []int
	var circ *circuit.Circuit

	for {
		nc, err := ln.Accept()
		if err != nil {
			return err
		}
		fmt.Printf("New connection from %s\n", nc.RemoteAddr())

		conn := p2p.NewConn(nc)

		// The evaluator sends its sizes first and then reads the peer's;
		// garblerMode performs the exchange in the opposite order.
		err = conn.SendInputSizes(myInputSizes)
		if err != nil {
			conn.Close()
			return err
		}
		err = conn.Flush()
		if err != nil {
			conn.Close()
			return err
		}
		peerInputSizes, err := conn.ReceiveInputSizes()
		if err != nil {
			conn.Close()
			return err
		}
		inputSizes[0] = peerInputSizes

		// Reuse the circuit from the previous session unless the peer's
		// input sizes changed.
		if circ == nil || !slices.Equal(peerInputSizes, oPeerInputSizes) {
			circ, err = loadCircuit(file, params, inputSizes)
			if err != nil {
				conn.Close()
				return err
			}
			oPeerInputSizes = peerInputSizes
		}
		circ.PrintInputs(circuit.IDEvaluator, inputFlag)
		if len(circ.Inputs) != 2 {
			// Do not leak the accepted connection on this error path.
			conn.Close()
			return fmt.Errorf("invalid circuit for 2-party MPC: %d parties",
				len(circ.Inputs))
		}

		input, err := circ.Inputs[1].Parse(inputFlag)
		if err != nil {
			conn.Close()
			return fmt.Errorf("%s: %w", file, err)
		}
		result, err := circuit.Evaluator(conn, oti, circ, input, verbose)
		conn.Close()

		// io.EOF from the evaluator is not treated as a failure; the
		// results are still printed in that case.
		if err != nil && err != io.EOF {
			return err
		}

		bedlam.PrintResults(result, circ.Outputs)
		if once {
			return nil
		}
	}
}
// garblerMode runs the garbler side of the two-party protocol. It
// connects to the evaluator at address, receives the evaluator's input
// sizes and sends its own, loads the circuit from file, parses this
// party's inputs from the -i flag, garbles the circuit and prints the
// results.
//
// The connection is closed exactly once, by the deferred Close, on every
// return path. Progress messages are printed when params.Verbose is set.
func garblerMode(oti ot.OT, file string, params *utils.Params, address string) error {
	// Index 0 holds our (the garbler's) input sizes, index 1 the peer's.
	inputSizes := make([][]int, 2)

	myInputSizes, err := circuit.InputSizes(inputFlag)
	if err != nil {
		return err
	}
	inputSizes[0] = myInputSizes

	nc, err := net.Dial("tcp", address)
	if err != nil {
		return err
	}
	conn := p2p.NewConn(nc)
	defer conn.Close()

	// The garbler reads the peer's sizes first and then sends its own;
	// evaluatorMode performs the exchange in the opposite order.
	if params.Verbose {
		fmt.Println(" - Receiving input sizes")
	}
	peerInputSizes, err := conn.ReceiveInputSizes()
	if err != nil {
		return err
	}

	if params.Verbose {
		fmt.Println(" - Sending input sizes")
	}
	inputSizes[1] = peerInputSizes
	err = conn.SendInputSizes(myInputSizes)
	if err != nil {
		return err
	}

	if params.Verbose {
		fmt.Println(" - Sent input sizes")
	}
	err = conn.Flush()
	if err != nil {
		return err
	}

	circ, err := loadCircuit(file, params, inputSizes)
	if err != nil {
		return err
	}
	circ.PrintInputs(circuit.IDGarbler, inputFlag)
	if len(circ.Inputs) != 2 {
		return fmt.Errorf("invalid circuit for 2-party MPC: %d parties",
			len(circ.Inputs))
	}

	input, err := circ.Inputs[0].Parse(inputFlag)
	if err != nil {
		return fmt.Errorf("%s: %w", file, err)
	}

	if params.Verbose {
		fmt.Println(" - Initiating garbler")
	}
	result, err := circuit.Garbler(conn, oti, circ, input, verbose)
	if err != nil {
		return err
	}
	bedlam.PrintResults(result, circ.Outputs)

	return nil
}