rollup of non-core (e.g. not node consensus, execution, client, or protobufs so nobody can get a head start)

This commit is contained in:
Cassandra Heart 2025-07-10 13:27:28 -05:00
parent c3ebffc519
commit c1b4a86072
No known key found for this signature in database
GPG Key ID: 6352152859385958
1022 changed files with 2636022 additions and 438799 deletions

6
.gitignore vendored
View File

@ -1,4 +1,6 @@
.idea/
node/node
node/lunchtime-*
vouchers/
ceremony-client
.config*
@ -9,6 +11,10 @@ ceremony-client
.task
node-tmp-*
build
cover.out
# Rust
target
# Build outputs
vdf-test*

678
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -16,11 +16,16 @@ members = [
"crates/vdf",
"crates/channel",
"crates/channel-wasm",
"crates/bulletproofs-wasm",
"crates/bls48581-wasm",
"crates/verenc-wasm",
"crates/classgroup",
"crates/bls48581",
"crates/ed448-rust",
"crates/rpm",
"crates/bulletproofs",
"crates/verenc",
"crates/ferret"
]
[profile.release]

28
bedlam/.gitignore vendored Normal file
View File

@ -0,0 +1,28 @@
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
# Test binary, build with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
.DS_Store
,*
*~
*.dot
*.pdf
*.ssa
*.prof
*.bristol
apps/circuit/circuit
apps/garbled/garbled
apps/iotest/iotest
apps/iter/iter
apps/qcldoc/qcldoc
apps/objdump/objdump
apps/ot/ot

21
bedlam/LICENSE Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2019 Markku Rossi
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

3
bedlam/README.md Normal file
View File

@ -0,0 +1,3 @@
# Bedlam
This is a fork of Markku Rossi's MPC engine, retooled for use in Quilibrium and fitted to use FERRET for OT. There are additional divergences, but those are the most critical notes.

1
bedlam/apps/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
*.circ

View File

@ -0,0 +1,65 @@
//
// main.go
//
// Copyright (c) 2019-2022 Markku Rossi
//
// All rights reserved.
//
package main
import (
"flag"
"fmt"
"log"
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
)
func main() {
flag.Parse()
for _, file := range flag.Args() {
c, err := circuit.Parse(file)
if err != nil {
log.Fatal(err)
}
fmt.Printf("digraph circuit\n{\n")
fmt.Printf(" overlap=scale;\n")
fmt.Printf(" node\t[fontname=\"Helvetica\"];\n")
fmt.Printf(" {\n node [shape=plaintext];\n")
for w := 0; w < c.NumWires; w++ {
fmt.Printf(" w%d\t[label=\"%d\"];\n", w, w)
}
fmt.Printf(" }\n")
fmt.Printf(" {\n node [shape=box];\n")
for idx, gate := range c.Gates {
fmt.Printf(" g%d\t[label=\"%s\"];\n", idx, gate.Op)
}
fmt.Printf(" }\n")
if true {
fmt.Printf(" { rank=same")
for w := 0; w < c.Inputs.Size(); w++ {
fmt.Printf("; w%d", w)
}
fmt.Printf(";}\n")
fmt.Printf(" { rank=same")
for w := 0; w < c.Outputs.Size(); w++ {
fmt.Printf("; w%d", c.NumWires-w-1)
}
fmt.Printf(";}\n")
}
for idx, gate := range c.Gates {
for _, i := range gate.Inputs() {
fmt.Printf(" w%d -> g%d;\n", i, idx)
}
fmt.Printf(" g%d -> w%d;\n", idx, gate.Output)
}
fmt.Printf("}\n")
}
}

View File

@ -0,0 +1,122 @@
//
// main.go
//
// Copyright (c) 2019-2024 Markku Rossi
//
// All rights reserved.
//
package main
import (
"bufio"
"fmt"
"io"
"os"
"strings"
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/utils"
)
func compileFiles(files []string, params *utils.Params, inputSizes [][]int,
compile, ssa, dot, svg bool, circFormat string) error {
var circ *circuit.Circuit
var err error
for _, file := range files {
if compile {
params.CircOut, err = makeOutput(file, circFormat)
if err != nil {
return err
}
params.CircFormat = circFormat
if dot {
params.CircDotOut, err = makeOutput(file, "circ.dot")
if err != nil {
return err
}
}
if svg {
params.CircSvgOut, err = makeOutput(file, "circ.svg")
if err != nil {
return err
}
}
}
if circuit.IsFilename(file) {
circ, err = circuit.Parse(file)
if err != nil {
return err
}
if params.CircOut != nil {
if params.Verbose {
fmt.Printf("Serializing circuit...\n")
}
err = circ.MarshalFormat(params.CircOut, params.CircFormat)
if err != nil {
return err
}
}
} else if strings.HasSuffix(file, ".qcl") {
if ssa {
params.SSAOut, err = makeOutput(file, "ssa")
if err != nil {
return err
}
if dot {
params.SSADotOut, err = makeOutput(file, "ssa.dot")
if err != nil {
return err
}
}
}
circ, _, err = compiler.New(params).CompileFile(file, inputSizes)
if err != nil {
return err
}
} else {
return fmt.Errorf("unknown file type '%s'", file)
}
}
return nil
}
func makeOutput(base, suffix string) (io.WriteCloser, error) {
var path string
idx := strings.LastIndexByte(base, '.')
if idx < 0 {
path = base + "." + suffix
} else {
path = base[:idx+1] + suffix
}
f, err := os.Create(path)
if err != nil {
return nil, err
}
return &OutputFile{
File: f,
Buffered: bufio.NewWriter(f),
}, nil
}
// OutputFile implements a buffered output file.
type OutputFile struct {
File *os.File
Buffered *bufio.Writer
}
func (out *OutputFile) Write(p []byte) (nn int, err error) {
return out.Buffered.Write(p)
}
// Close implements io.Closer.Close for the buffered output file.
func (out *OutputFile) Close() error {
if err := out.Buffered.Flush(); err != nil {
return err
}
return out.File.Close()
}

1
bedlam/apps/garbled/data/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
pre*

Binary file not shown.

View File

@ -0,0 +1,2 @@
*.svg
*.qclc

View File

@ -0,0 +1,9 @@
// -*- go -*-
// Sample 3-party circuit where each party provides their input bit
// and the result is bitwise AND of the inputs.
package main
func main(a, b, e uint1) uint {
return a & b & e
}

View File

@ -0,0 +1,26 @@
// -*- go -*-
//
// Example of how to add two uint256 values. You can use any data type for the input as long as it is divisible by 8,
// such as uint512.
//
// To run the Evaluator:
//
// ./garbled -e -i 0xb33d6a91b4ca8ac31c639c6742cba5a74c661a63311548af191c298a945d4891 examples/add.qcl
//
// To run the Garbler:
//
// ./garbled -i 0x5bf6db5927d799cf225f165e9508238edc5a1200fcad08c6411648733eb3100f examples/add.qcl
//
// The expected result should be:
// 6877051328478326342308659403308568813546041258230645271156059935820089415840 (0x0f3445eadca224923ec2b2c5d7d3c93628c02c642dc251755a3271fdd31058a0)
package main
import (
"math"
)
func main(a, b uint256) uint {
return a + b
}

View File

@ -0,0 +1,11 @@
// -*- go -*-
package main
import (
"crypto/aes"
)
func main(key, data [16]byte) []byte {
return aes.EncryptBlock(key, data)
}

View File

@ -0,0 +1,11 @@
// -*- go -*-
package main
import (
"crypto/aes"
)
func main(key, data [16]byte) []byte {
return aes.Block128(key, data)
}

View File

@ -0,0 +1,54 @@
// -*- go -*-
package main
import (
"crypto/cipher/cbc"
)
// Case #1: Encrypting 16 bytes (1 block) using AES-CBC with 128-bit key
//
// Key : 0x06a9214036b8a15b512e03d534120006
// IV : 0x3dafba429d9eb430b422da802c9fac41
// Plaintext : "Single block msg"
// Ciphertext: 0xe353779c1079aeb82708942dbe77181a
//
// Case #2: Encrypting 32 bytes (2 blocks) using AES-CBC with 128-bit key
//
// Key : 0xc286696d887c9aa0611bbb3e2025a45a
// IV : 0x562e17996d093d28ddb3ba695a2e6f58
// Plaintext : 0x000102030405060708090a0b0c0d0e0f
// 101112131415161718191a1b1c1d1e1f
// Ciphertext: 0xd296cd94c2cccf8a3a863028b5e1dc0a
//
// 7586602d253cfff91b8266bea6d61ab1
func main(g, e [16]byte) []byte {
key := []byte{
0x06, 0xa9, 0x21, 0x40, 0x36, 0xb8, 0xa1, 0x5b,
0x51, 0x2e, 0x03, 0xd5, 0x34, 0x12, 0x00, 0x06,
}
iv := []byte{
0x3d, 0xaf, 0xba, 0x42, 0x9d, 0x9e, 0xb4, 0x30,
0xb4, 0x22, 0xda, 0x80, 0x2c, 0x9f, 0xac, 0x41,
}
plain := []byte("Single block msg")
key2 := []byte{
0xc2, 0x86, 0x69, 0x6d, 0x88, 0x7c, 0x9a, 0xa0,
0x61, 0x1b, 0xbb, 0x3e, 0x20, 0x25, 0xa4, 0x5a,
}
iv2 := []byte{
0x56, 0x2e, 0x17, 0x99, 0x6d, 0x09, 0x3d, 0x28,
0xdd, 0xb3, 0xba, 0x69, 0x5a, 0x2e, 0x6f, 0x58,
}
plain2 := []byte{
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
}
//return cbc.EncryptAES128(key, iv, plain)
//return cbc.EncryptAES128(key2, iv2, plain2)
return cbc.EncryptAES128(g, iv, e)
}

View File

@ -0,0 +1,11 @@
// -*- go -*-
package main
import (
"crypto/aes"
)
func main(key, data [16]byte) []uint {
return aes.ExpandEncryptionKey(key)
}

View File

@ -0,0 +1,28 @@
// -*- go -*-
package main
import (
"bytes"
"crypto/cipher/gcm"
)
func main(a, b byte) (string, int) {
key := []byte{
0x06, 0xa9, 0x21, 0x40, 0x36, 0xb8, 0xa1, 0x5b,
0x51, 0x2e, 0x03, 0xd5, 0x34, 0x12, 0x00, 0x06,
}
nonce := []byte{
0x3d, 0xaf, 0xba, 0x42, 0x9d, 0x9e, 0xb4, 0x30,
0xb4, 0x22, 0xda, 0x80,
}
plain := []byte("Single block msgSingle block msg")
additional := []byte("additional data to be authenticated")
c := gcm.EncryptAES128(key, nonce, plain, additional)
p, ok := gcm.DecryptAES128(key, nonce, c, additional)
if !ok {
return "Open failed", 0
}
return string(p), bytes.Compare(plain, p)
}

View File

@ -0,0 +1,7 @@
// -*- go -*-
package main
func main(a, b uint1) uint1 {
return a & b
}

View File

@ -0,0 +1,31 @@
// -*- go -*-
package main
type Size = uint32
type Applicant struct {
male bool
age Size
income Size
}
type Bank struct {
maxAge Size
femaleIncome Size
maleIncome Size
}
func main(applicant Applicant, bank Bank) bool {
// Bank sets the maximum age limit.
if applicant.age > bank.maxAge {
return false
}
if applicant.male {
// Credit criteria for males.
return applicant.age >= 21 && applicant.income >= bank.maleIncome
} else {
// Credit criteria for females.
return applicant.age >= 18 && applicant.income >= bank.femaleIncome
}
}

View File

@ -0,0 +1,14 @@
// -*- go -*-
//
package main
import (
"math"
)
func main(a, b uint64) uint {
return a / b
//return int64(math.DivUint64(uint64(a), uint64(b)))
}

View File

@ -0,0 +1,36 @@
// -*- go -*-
package main
import (
"crypto/curve25519"
)
func main(g, e [32]byte) ([]byte, []byte, []byte) {
// var s [32]byte
// curve25519.ScalarMult(&s, &g, &e)
// return g, e, s
var privateKey [32]byte
for i := 0; i < len(privateKey); i++ {
//privateKey[i] = g[i] ^ e[i]
privateKey[i] = (i % 8) + 1
}
var publicKey [32]byte
curve25519.ScalarBaseMult(&publicKey, &privateKey)
var privateKey2 [32]byte
for i := 0; i < len(privateKey2); i++ {
//privateKey[i] = g[i] ^ e[i]
privateKey2[i] = (i % 16) + 1
}
var publicKey2 [32]byte
curve25519.ScalarBaseMult(&publicKey2, &privateKey2)
var secret [32]byte
curve25519.ScalarMult(&secret, &privateKey, &publicKey2)
return publicKey, publicKey2, secret
}

View File

@ -0,0 +1,105 @@
// -*- go -*-
// This example implements Ed25519 key generation. Both parties
// provide 3 random arguments:
//
// Seed [32]byte
// Split [64]byte
// Mask [64]byte
//
// The key generation seed is g.Seed^e.Seed. The generated private key
// is split into two random shares:
//
// Garbler's share : privG = g.Split^e.Split
// Evaluator's share: privE = privG ^ priv
//
// Garbler's and evaluator's shares are masked with respective masks:
//
// privGM = privG^g.Mask
// privEM = privE^e.Mask
//
// The key generation function returns pub, privGM, and privEM to both
// parties. After the multiparty protocol completes, both Garbler and
// Evaluator reveal their private key share:
//
// privG = privGM^g.Mask
// privE = privEM^e.Mask
//
// The following commands run the key generation algorithm. Both
// Evaluator and Garbler are started with 3 random arguments:
//
// ./garbled -stream -e -v -i 0x784db0ec4ca0cf5338249e6a09139109366dca1fac2838e5f0e5a46f0e191bae,0xd0da45d3c99e756da831d1e7d696eae3fa9fe39d3b1b2618c7ff997d17777989b5cf415b114298c8b10bed0f0eff118e43ab606ab01143151dff89171307dffa,0x44bf09357e19b1f96f9cf6d9e7d25a0e8dd62d6e0d4bba2bec4c59983c7dc84d1486677b6d8837746cd948c881913c36faeaee08e8309afac58be4757a1c544e
// ./garbled -stream -v -i 0x57c0e59c20ac7d75ef7e3188fdd7f5876abee1cab394af8125acaca9760bb54c,0x76b42e6292f4a3dc339d208481abeb9a24e08127c7cd8dbde62abcddc0c0e6f7a0f740e756b44dae137f0e7ff8eae0ceb1a962c130fdcbe8cbee3e31ab55b8dc,0xeb83eb1f5203f5b752c96264a21ff4a27fa60cf2313f5f53c3fa96e0b52a2814b786e43a3af64b66291b5b29f432cb8d5a930e31f4e6f072a6d33b861b5b5f13 examples/ed25519/keygen.qcl
//
// The example values return the following results:
//
// Result[0]: 8ae64963506002e267a59665e9a2e6f9348cc159be53747894478e182ece9fcb
// Result[1]: 4ded80ae09692306c9659307f522f5dba1d96e48cde9f4f6e22fb340629db76aa2bee5867d009e008b6fb85902273acda8910c9a740a788f70c28ca0a3093835
// Result[2]: cd5c37f4497fd56e236aa858442b3ff90f7a6401ee2186ea18d074fe93d8f9d18b582fa47a1ee0f0a9083ddd9e262b8f3c642dfad68f667f87dddd4bec80aca3
//
// The Garbler reveals its private key share by computing Result[1]^g.Mask:
//
// 4ded80ae09692306c9659307f522f5dba1d96e48cde9f4f6e22fb340629db76aa2bee5867d009e008b6fb85902273acda8910c9a740a788f70c28ca0a3093835
// ^ eb83eb1f5203f5b752c96264a21ff4a27fa60cf2313f5f53c3fa96e0b52a2814b786e43a3af64b66291b5b29f432cb8d5a930e31f4e6f072a6d33b861b5b5f13
// = a66e6bb15b6ad6b19bacf163573d0179de7f62bafcd6aba521d525a0d7b79f7e153801bc47f6d566a274e370f615f140f20202ab80ec88fdd611b726b8526726
//
// And Evaluator does the same computation for its values: Result[2]^e.Mask:
//
// cd5c37f4497fd56e236aa858442b3ff90f7a6401ee2186ea18d074fe93d8f9d18b582fa47a1ee0f0a9083ddd9e262b8f3c642dfad68f667f87dddd4bec80aca3
// ^ 44bf09357e19b1f96f9cf6d9e7d25a0e8dd62d6e0d4bba2bec4c59983c7dc84d1486677b6d8837746cd948c881913c36faeaee08e8309afac58be4757a1c544e
// = 89e33ec1376664974cf65e81a3f965f782ac496fe36a3cc1f49c2d66afa5319c9fde48df1796d784c5d175151fb717b9c68ec3f23ebffc854256393e969cf8ed
//
// Now Garbler and Evaluator have their private key shares and public
// key, but they do not know the full combined Ed25519 private
// key. The private key shares can be used in the `sign.qcl' example
// which computes the signature with the combined private key:
//
// privG ^ privE = priv:
// priv: 2f8d55706c0cb226d75aafe2f4c4648e5cd32bd51fbc9764d54908c67812aee2
// 8ae64963506002e267a59665e9a2e6f9348cc159be53747894478e182ece9fcb
// pub : 8ae64963506002e267a59665e9a2e6f9348cc159be53747894478e182ece9fcb
package main
import (
"crypto/ed25519"
)
type Arguments struct {
Seed [32]byte
Split [64]byte
Mask [64]byte
}
func main(g, e Arguments) ([]byte, []byte, []byte) {
var seed [32]byte
// Construct seed from peer seeds.
for i := 0; i < len(seed); i++ {
seed[i] = g.Seed[i] ^ e.Seed[i]
}
pub, priv := ed25519.NewKeyFromSeed(seed)
// Garbler's private key share is random value, constructed from
// peer split values.
var privG [64]byte
for i := 0; i < len(privG); i++ {
privG[i] = g.Split[i] ^ e.Split[i]
}
// Evaluator's private key share is real private key, xor'ed with
// Garbler's private key share.
var privE [64]byte
for i := 0; i < len(privE); i++ {
privE[i] = priv[i] ^ privG[i]
}
// Mask private key shares.
for i := 0; i < len(privG); i++ {
privG[i] ^= g.Mask[i]
}
for i := 0; i < len(privE); i++ {
privE[i] ^= e.Mask[i]
}
return pub, privG, privE
}

View File

@ -0,0 +1,55 @@
// -*- go -*-
// This example implements Ed25519 signature computation. The Ed25519
// keypair is:
//
// pub : 8ae64963506002e267a59665e9a2e6f9348cc159be53747894478e182ece9fcb
//
// priv : 2f8d55706c0cb226d75aafe2f4c4648e5cd32bd51fbc9764d54908c67812aee2
// 8ae64963506002e267a59665e9a2e6f9348cc159be53747894478e182ece9fcb
//
// The Garbler and Evaluator share the private key as two random
// shares. The private key is contructed during the signature
// computation by XOR:ing the random shares together:
//
// privG: a66e6bb15b6ad6b19bacf163573d0179de7f62bafcd6aba521d525a0d7b79f7e
// 153801bc47f6d566a274e370f615f140f20202ab80ec88fdd611b726b8526726
//
//
// privE: 89e33ec1376664974cf65e81a3f965f782ac496fe36a3cc1f49c2d66afa5319c
// 9fde48df1796d784c5d175151fb717b9c68ec3f23ebffc854256393e969cf8ed
//
// priv = privG ^ privE
//
// Run the Evaluator with one input: the Evaluator's private key share:
//
// ./garbled -e -v -stream -i 0x89e33ec1376664974cf65e81a3f965f782ac496fe36a3cc1f49c2d66afa5319c9fde48df1796d784c5d175151fb717b9c68ec3f23ebffc854256393e969cf8ed
//
// The Garbler takes two inputs: the message to sign, and the
// Garbler's private key share:
//
// ./garbled -stream -v -i 0x4d61726b6b7520526f737369203c6d747240696b692e66693e2068747470733a2f2f7777772e6d61726b6b75726f7373692e636f6d2f,0xa66e6bb15b6ad6b19bacf163573d0179de7f62bafcd6aba521d525a0d7b79f7e153801bc47f6d566a274e370f615f140f20202ab80ec88fdd611b726b8526726 examples/ed25519/sign.qcl
//
// The result signature is:
//
// Result[0]: f43c1cf1755345852211942af0838414334eec9cbf36e26a9f9e8d4bb720deb145ffbeec82249c875116757441206bcdc56b501e750f1f590917d772dfee980f
package main
import (
"crypto/ed25519"
)
type Garbler struct {
msg [64]byte
privShare [64]byte
}
func main(g Garbler, privShare [64]byte) []byte {
var priv [64]byte
for i := 0; i < len(priv); i++ {
priv[i] = g.privShare[i] ^ privShare[i]
}
return ed25519.Sign(priv, g.msg)
}

View File

@ -0,0 +1,44 @@
// -*- go -*-
// Example how to encrypt fixed sized data with AES-128-GCM.
//
// Run the Evaluator with two inputs: evaluator's key and nonce shares:
//
// ./garbled -e -i 0x8cd98b88adab08d6d60fe57c8b8a33f3,0xfd5e0f8f155e7102aa526ad0 examples/encrypt.qcl
//
// The Garbler takes three arguments: the message to encrypt, and its
// key and nonce shares:
//
// ./garbled -i 0x48656c6c6f2c20776f726c6421,0xed800b17b0c9d2334b249332155ddef5,0xa300751458c775a08762c2cd examples/encrypt.qcl
package main
import (
"crypto/cipher/gcm"
)
type Garbler struct {
msg [64]byte
keyShare [16]byte
nonceShare [12]byte
}
type Evaluator struct {
keyShare [16]byte
nonceShare [12]byte
}
func main(g Garbler, e Evaluator) []byte {
var key [16]byte
for i := 0; i < len(key); i++ {
key[i] = g.keyShare[i] ^ e.keyShare[i]
}
var nonce [12]byte
for i := 0; i < len(nonce); i++ {
nonce[i] = g.nonceShare[i] ^ e.nonceShare[i]
}
return gcm.EncryptAES128(key, nonce, g.msg, []byte("unused"))
}

View File

@ -0,0 +1,11 @@
// -*- go -*-
package main
import (
"encoding/binary"
)
func main(a, b uint1024) uint {
return binary.HammingDistance(a, b)
}

View File

@ -0,0 +1,61 @@
// -*- go -*-
// This example computes HMAC-SHA256 where the HMAC key is shared as
// two random shares between garbler and evaluator. The garbler's key
// share is:
//
// keyG: 4de216d2fdc9301e5b9c78486f7109a05670d200d9e2f275ec0aad08ec42af47
// fcb59bf460d50b01333a748f3a9efb13e08036d49a26c21ba2e33a5f8a2cf0e7
//
// The evaluator's key share is:
//
// keyE: f87a00ef89c2396de32f6ac0748f6fa1b641013d46f74ce25cc625904215a675
// 01c0c7196a2602f6516527958a82271847933c35d170d98bfdb04d2ddf3bb197
//
// The final HMAC key is keyG ^ keyE:
//
// key : b598163d740b0973b8b312881bfe6601e031d33d9f15be97b0cc8898ae570932
// fd755ced0af309f7625f531ab01cdc0ba7130ae14b561b905f53777255174170
//
// The example uses 32-byte messages (Garbler.msg) so with the message:
//
// msg : Hello, world!...................
// hex : 48656c6c6f2c20776f726c64212e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e
//
// We expect to get the following HMAC-SHA256 output:
//
// sum : 60d27dbd14f1e351f20069171fead00ef557d17ac9a41d02baa488ca4b90171a
//
// Now we can run the MPC computation as follows. First, run the
// evaluator with one input: the evaluator's key share:
//
// ./garbled -e -v -i 0xf87a00ef89c2396de32f6ac0748f6fa1b641013d46f74ce25cc625904215a67501c0c7196a2602f6516527958a82271847933c35d170d98bfdb04d2ddf3bb197 examples/hmac-sha256.qcl
//
// The garbler takes two inputs: the message and the garbler's key
// share:
//
// ./garbled -v -i 0x48656c6c6f2c20776f726c64212e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e,0x4de216d2fdc9301e5b9c78486f7109a05670d200d9e2f275ec0aad08ec42af47fcb59bf460d50b01333a748f3a9efb13e08036d49a26c21ba2e33a5f8a2cf0e7 examples/hmac-sha256.qcl
//
// The MCP computation providers the expected HMAC result:
//
// Result[0]: 60d27dbd14f1e351f20069171fead00ef557d17ac9a41d02baa488ca4b90171a
package main
import (
"crypto/hmac"
)
type Garbler struct {
msg []byte
keyShare [64]byte
}
func main(g Garbler, eKeyShare [64]byte) []byte {
var key [64]byte
for i := 0; i < len(key); i++ {
key[i] = g.keyShare[i] ^ eKeyShare[i]
}
return hmac.SumSHA256(g.msg, key)
}

View File

@ -0,0 +1,62 @@
// -*- go -*-
// This example computes HMAC-SHA256 where the HMAC key is shared as
// two random shares between garbler and evaluator. The garbler's key
// share is:
//
// keyG: 4de216d2fdc9301e5b9c78486f7109a05670d200d9e2f275ec0aad08ec42af47
// fcb59bf460d50b01333a748f3a9efb13e08036d49a26c21ba2e33a5f8a2cf0e7
//
// The evaluator's key share is:
//
// keyE: f87a00ef89c2396de32f6ac0748f6fa1b641013d46f74ce25cc625904215a675
// 01c0c7196a2602f6516527958a82271847933c35d170d98bfdb04d2ddf3bb197
//
// The final HMAC key is keyG ^ keyE:
//
// key : b598163d740b0973b8b312881bfe6601e031d33d9f15be97b0cc8898ae570932
// fd755ced0af309f7625f531ab01cdc0ba7130ae14b561b905f53777255174170
//
// The example uses 32-byte messages (Garbler.msg) so with the message:
//
// msg : Hello, world!...................
// hex : 48656c6c6f2c20776f726c64212e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e
//
// We expect to get the following HMAC-SHA256 output:
//
// sum : 89c648c4d7b4220f6767706dec64f69bbdb2725d062a09a447b9cf32af8636ee8853f92ca59c0e81712b72c79f3503f7f2131d20c7a5dfae87b79f839cecf2c4
//
// Now we can run the MPC computation as follows. First, run the
// evaluator with one input: the evaluator's key share:
//
// ./garbled -e -v -i 0xf87a00ef89c2396de32f6ac0748f6fa1b641013d46f74ce25cc625904215a67501c0c7196a2602f6516527958a82271847933c35d170d98bfdb04d2ddf3bb197 examples/hmac-sha512.qcl
//
// The garbler takes two inputs: the message and the garbler's key
// share:
//
// ./garbled -v -i 0x48656c6c6f2c20776f726c64212e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e,0x4de216d2fdc9301e5b9c78486f7109a05670d200d9e2f275ec0aad08ec42af47fcb59bf460d50b01333a748f3a9efb13e08036d49a26c21ba2e33a5f8a2cf0e7 examples/hmac-sha512.qcl
//
// The MCP computation providers the expected HMAC result:
//
// Result[0]: 89c648c4d7b4220f6767706dec64f69bbdb2725d062a09a447b9cf32af8636ee8853f92ca59c0e81712b72c79f3503f7f2131d20c7a5dfae87b79f839cecf2c4
package main
import (
"crypto/hmac"
)
type Garbler struct {
msg []byte
keyShare [64]byte
}
func main(g Garbler, eKeyShare [64]byte) []byte {
var key [64]byte
for i := 0; i < len(key); i++ {
key[i] = g.keyShare[i] ^ eKeyShare[i]
}
return hmac.SumSHA512(g.msg, key)
}

View File

@ -0,0 +1,102 @@
// -*- go -*-
// This example shows how a key can be imported to MCP peers garbler
// and evaluator so that after the import operation both peers hold a
// random share of the key, and shareG^shareE=key.
//
// The garbler provides three arguments:
//
// key [64]byte: the key to import
// split [64]byte: random input for the key split value
// mask [64]byte: random mask to mask garbler's key share result
//
// The evaluator provides two arguments: the split and the mask values.
//
// The MPC program splits the key into two random shares:
//
// garbler's share : keyG = g.Split ^ e.Split
// evaluator's share: keyE = keyG ^ g.Key
//
// The result shares are masked with respective masks and returned as
// the result of the computation:
//
// keyGM = keyG ^ g.Mask
// keyEM = keyE ^ e.Mask
//
// Finally, both parties can extract their key shares because the know
// their mask values:
//
// keyG = keyGM ^ g.Mask
// keyE = keyEM ^ e.Mask
//
// Using the example values:
//
// key: a968f050ebd5c4ed2ddf9717f0f0fd9325b07c68ff5d62094800f5b69464bab9
// 8dd886a7c49460503fafa75f5f7430f2cdda7bd5cb60c1cbd471e35d67432d58
// splitG: 7e1d9bb27838f5c8481b7194f07b5f3059f9471ae8e69ea3fe79c629a92588d9
// 524a6e4364e77d222210135f6c5435a8be52fc99ad8fc8280e8207cac91fc7b3
// maskG: bed1bc2a3e6089bd016ff0175c62346438a9eb7b741f41787e5f7aad1720ee08
// 233a89e81e3bbd5eef26d158750a0fdd47471ded518d781f23de6d4346ea68ad
//
// splitE: b2146ed7385d63a76f599b27f03e83971149208b0c41604eea010806460a3266
// 93820075bc25c485b2bfcb9226488ba961eeb07980f8ab374b38f793e41e5247
// maskE: a0698ff8e72f51bf3bff3895c80a8ba8a527abaa5a7603391545ed5dcebb22b5
// a2f191bcb3ac3a543cfdba99bded67a3ac6f5f254ff7e5c34520312c9b91f672
//
// we run the evaluator with two arguments:
//
// ./garbled -e -v -i 0xb2146ed7385d63a76f599b27f03e83971149208b0c41604eea010806460a326693820075bc25c485b2bfcb9226488ba961eeb07980f8ab374b38f793e41e5247,0xa0698ff8e72f51bf3bff3895c80a8ba8a527abaa5a7603391545ed5dcebb22b5a2f191bcb3ac3a543cfdba99bded67a3ac6f5f254ff7e5c34520312c9b91f672 examples/key-import.qcl
//
// and the garbler with three arguments:
//
// ./garbled -v -i 0xa968f050ebd5c4ed2ddf9717f0f0fd9325b07c68ff5d62094800f5b69464bab98dd886a7c49460503fafa75f5f7430f2cdda7bd5cb60c1cbd471e35d67432d58,0x7e1d9bb27838f5c8481b7194f07b5f3059f9471ae8e69ea3fe79c629a92588d9524a6e4364e77d222210135f6c5435a8be52fc99ad8fc8280e8207cac91fc7b3,0xbed1bc2a3e6089bd016ff0175c62346438a9eb7b741f41787e5f7aad1720ee08233a89e81e3bbd5eef26d158750a0fdd47471ded518d781f23de6d4346ea68ad examples/key-import.qcl
//
// The program returns two values: garbler's and evaluator's masked
// key shares:
//
// Result[0]: 72d8494f7e051fd2262d1aa45c27e8c370198cea90b8bf956a27b482f80f54b7e2f2e7dec6f904f97f8909953f16b1dc98fb510d7cfa1b0066649d1a6bebfd59
// Result[1]: c5088acd4c9f033d3162453138bfaa9cc827b053418c9fdd493dd6c4b5f022b3eee1792daffae3a393fdc50ba885e950be096810a9e04717d4eb2228d1d34ede
//
// and both peers can extract their key shares by XOR:ing their result
// with their mask value:
//
// shareG: cc09f5654065966f2742eab30045dca748b06791e4a7feed1478ce2fef2fbabf
// c1c86e36d8c2b9a790afd8cd4a1cbe01dfbc4ce02d77631f45baf0592d0195f4
// shareE: 65610535abb052820a9d7da4f0b521346d001bf91bfa9ce45c783b997b4b0006
// 4c10e8911c56d9f7af007f9215688ef312663735e617a2d491cb13044a42b8ac
//
// and we see that shareG^shareE = key
package main
type Garbler struct {
Key [64]byte
Split [64]byte
Mask [64]byte
}
type Evaluator struct {
Split [64]byte
Mask [64]byte
}
func main(g Garbler, e Evaluator) ([]byte, []byte) {
var keyG [64]byte
var keyE [64]byte
var keyGM [64]byte
var keyEM [64]byte
for i := 0; i < len(keyG); i++ {
keyG[i] = g.Split[i] ^ e.Split[i]
}
for i := 0; i < len(keyE); i++ {
keyE[i] = keyG[i] ^ g.Key[i]
}
for i := 0; i < len(keyGM); i++ {
keyGM[i] = keyG[i] ^ g.Mask[i]
}
for i := 0; i < len(keyEM); i++ {
keyEM[i] = keyE[i] ^ e.Mask[i]
}
return keyGM, keyEM
}

View File

@ -0,0 +1,13 @@
// -*- go -*-
//
// Yao's Millionaires' problem with int64 values.
package main
func main(a, b int64) bool {
if a > b {
return true
} else {
return false
}
}

View File

@ -0,0 +1,39 @@
// -*- go -*-
//
// RSA encryption with Montgomery modular multiplication.
//
// ./garbled -e -v -i 0x321af130 examples/montgomery.qcl
// ./garbled -v -i 0x6d7472,9,0xd60b2b09,0x10001 examples/montgomery.qcl
package main
import (
"math"
)
type Size = uint64
type Garbler struct {
msg Size
privShare Size
pubN Size
pubE Size
}
func main(g Garbler, privShare Size) (uint, uint) {
priv := g.privShare + privShare
cipher := Encrypt(g.msg, g.pubE, g.pubN)
plain := Decrypt(cipher, priv, g.pubN)
return cipher, plain
}
func Encrypt(msg, e, n uint) uint {
return math.ExpMontgomery(msg, e, n)
}
func Decrypt(cipher, d, n uint) uint {
return math.ExpMontgomery(cipher, d, n)
}

View File

@ -0,0 +1,8 @@
// -*- go -*-
//
package main
func main(a, b uint64) uint {
return a * b
}

View File

@ -0,0 +1,8 @@
// -*- go -*-
//
package main
func main(a, b int1024) int1024 {
return a * b
}

View File

@ -0,0 +1,15 @@
// -*- go -*-
//
// Rock, Paper, Scissors
package main
func main(g, e int8) string {
if g == e {
return "-"
}
if (g+1)%3 == e {
return "evaluator"
}
return "garbler"
}

View File

@ -0,0 +1,45 @@
// -*- go -*-
//
// 32-bit RSA encryption and decryption.
//
// The key parameters are:
//
// d: 0x321af139
// n: 0xd60b2b09
// e: 0x10001
//
// private: d, n
// public: e, n
//
// msg: 0x6d7472
// cipher: 0x61f9ef88
//
// Run garbler and evaluator as follows:
//
// ./garbled -e -v -i 9 examples/rsa.qcl
// ./garbled -v -i 0x6d7472,0x321af130,0xd60b2b09,0x10001 examples/rsa.qcl
package main
import (
"crypto/rsa"
)
type Size = uint32
type Garbler struct {
msg Size
privShare Size
pubN Size
pubE Size
}
func main(g Garbler, privShare Size) (uint, uint) {
priv := g.privShare + privShare
cipher := rsa.Encrypt(g.msg, g.pubE, g.pubN)
plain := rsa.Decrypt(cipher, priv, g.pubN)
return cipher, plain
}

View File

@ -0,0 +1,42 @@
// -*- go -*-
//
// RSA signature with Size bits.
//
// The key parameters are:
//
// d: 0x321af139
// n: 0xd60b2b09
// e: 0x10001
//
// private: d, n
// public: e, n
//
// msg: 0x6d7472
// signature: 0x55a83b79
//
// Run garbler and evaluator as follows:
//
// ./garbled -e -v -i 9 examples/rsasign.qcl
// ./garbled -v -i 0x6d7472,0x321af130,0xd60b2b09,0x10001 examples/rsasign.qcl
package main
import (
"crypto/rsa"
)
type Size = uint512
type Garbler struct {
msg Size
privShare Size
pubN Size
pubE Size
}
func main(g Garbler, privShare Size) uint {
priv := g.privShare + privShare
return rsa.Decrypt(g.msg, priv, g.pubN)
}

View File

@ -0,0 +1,33 @@
// -*- go -*-
package main
import (
"sort"
)
var input = []int9{
136, 142, 146, 165, 183, 189, 220, 223, 232, 235, 67, 73, 77, 88,
91, 93, 95, 97, 98, 132, 5, 6, 7, 10, 14, 18, 18, 37, 50, 64, 245,
249, 252, 136, 142, 146, 165, 183, 189, 220, 223, 232, 235, 67,
73, 77, 88, 91, 93, 95, 97, 98, 132, 136, 142, 146, 165, 183, 189,
220, 223, 232, 235, 67, 73, 77, 88, 91, 93, 95, 97, 98, 132, 5, 6,
7, 10, 14, 18, 18, 37, 50, 64, 245, 249, 252, 136, 142, 146, 165,
183, 189, 220, 223, 232, 235, 67, 73, 77, 88, 91, 93, 95, 97, 98,
132, 136, 142, 146, 165, 183, 189, 220, 223, 232, 235, 67, 73, 77,
88, 91, 93, 95, 97, 98, 132, 5, 6, 7, 10, 14, 18, 18, 37, 50, 64,
245, 249, 252, 136, 142, 146, 165, 183, 189, 220, 223, 232, 235,
67, 73, 77, 88, 91, 93, 95, 97, 98, 132, 5, 6, 7, 10, 14, 18, 18,
37, 50, 64, 245, 249, 252, 136, 142, 146, 165, 183, 189, 220, 223,
232, 235, 67, 73, 77, 88, 91, 93, 95, 97, 98, 132, 136, 142, 146,
165, 183, 189, 220, 223, 232, 235, 67, 73, 77, 88, 91, 93, 95, 97,
98, 132, 5, 6, 7, 10, 14, 18, 18, 37, 50, 64, 245, 249, 252, 136,
142, 146, 165, 183, 189, 220, 223, 232, 235, 67, 73, 77, 88, 91,
93, 95, 97, 98, 132, 136, 142, 146, 165, 183, 189, 220, 223, 232,
235, 67, 73, 77, 88, 91, 93, 95, 97, 98, 132, 5, 6, 7, 10, 14, 18,
18, 37, 50, 64, 245, 249, 252,
}
func main(g, e byte) []int {
return sort.Reverse(sort.Slice(input))
}

364
bedlam/apps/garbled/main.go Normal file
View File

@ -0,0 +1,364 @@
//
// main.go
//
// Copyright (c) 2019-2024 Markku Rossi
//
// All rights reserved.
//
package main
import (
"flag"
"fmt"
"io"
"log"
"net"
"os"
"runtime"
"runtime/pprof"
"slices"
"strings"
"source.quilibrium.com/quilibrium/monorepo/bedlam"
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/utils"
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
"source.quilibrium.com/quilibrium/monorepo/bedlam/p2p"
)
var (
verbose = false
)
type input []string
func (i *input) String() string {
return fmt.Sprint(*i)
}
func (i *input) Set(value string) error {
for _, v := range strings.Split(value, ",") {
*i = append(*i, v)
}
return nil
}
var inputFlag, peerFlag input
func init() {
flag.Var(&inputFlag, "i", "comma-separated list of circuit inputs")
flag.Var(&peerFlag, "pi", "comma-separated list of peer's circuit inputs")
}
func main() {
evaluator := flag.Bool("e", false, "evaluator / garbler mode")
stream := flag.Bool("stream", false, "streaming mode")
compile := flag.Bool("circ", false, "compile QCL to circuit")
circFormat := flag.String("format", "qclc",
"circuit format: qclc, bristol")
ssa := flag.Bool("ssa", false, "compile QCL to SSA assembly")
dot := flag.Bool("dot", false, "create Graphviz DOT output")
svg := flag.Bool("svg", false, "create SVG output")
optimize := flag.Int("O", 1, "optimization level")
address := flag.String("address", "127.0.0.1:8080", "address and port")
otAddress := flag.String("ot-address", "127.0.0.1:5555", "address and port")
fVerbose := flag.Bool("v", false, "verbose output")
fDiagnostics := flag.Bool("d", false, "diagnostics output")
cpuprofile := flag.String("cpuprofile", "", "write cpu profile to `file`")
memprofile := flag.String("memprofile", "",
"write memory profile to `file`")
qclcErrLoc := flag.Bool("qclc-err-loc", false,
"print QCLC error locations")
benchmarkCompile := flag.Bool("benchmark-compile", false,
"benchmark QCL compilation")
flag.Parse()
log.SetFlags(0)
verbose = *fVerbose
if len(*cpuprofile) > 0 {
f, err := os.Create(*cpuprofile)
if err != nil {
log.Fatal("could not create CPU profile: ", err)
}
defer f.Close()
if err := pprof.StartCPUProfile(f); err != nil {
log.Fatal("could not start CPU profile: ", err)
}
defer pprof.StopCPUProfile()
}
params := utils.NewParams()
defer params.Close()
params.Verbose = *fVerbose
params.Diagnostics = *fDiagnostics
params.QCLCErrorLoc = *qclcErrLoc
params.BenchmarkCompile = *benchmarkCompile
if *optimize > 0 {
params.OptPruneGates = true
}
if *ssa && !*compile {
params.NoCircCompile = true
}
if *compile || *ssa {
inputSizes := make([][]int, 2)
iSizes, err := circuit.InputSizes(inputFlag)
if err != nil {
log.Fatal(err)
}
pSizes, err := circuit.InputSizes(peerFlag)
if err != nil {
log.Fatal(err)
}
if *evaluator {
inputSizes[0] = pSizes
inputSizes[1] = iSizes
} else {
inputSizes[0] = iSizes
inputSizes[1] = pSizes
}
err = compileFiles(flag.Args(), params, inputSizes,
*compile, *ssa, *dot, *svg, *circFormat)
if err != nil {
log.Fatalf("compile failed: %s", err)
}
memProfile(*memprofile)
return
}
var err error
party := uint8(1)
if *evaluator {
party = 2
}
oti := ot.NewFerret(party, *otAddress)
if *stream {
if *evaluator {
err = streamEvaluatorMode(oti, inputFlag, len(*cpuprofile) > 0, *address)
} else {
err = streamGarblerMode(params, oti, inputFlag, flag.Args(), *address)
}
memProfile(*memprofile)
if err != nil {
log.Fatal(err)
}
return
}
if len(flag.Args()) != 1 {
log.Fatalf("expected one input file, got %v\n", len(flag.Args()))
}
file := flag.Args()[0]
if *evaluator {
err = evaluatorMode(oti, file, params, len(*cpuprofile) > 0, *address)
} else {
err = garblerMode(oti, file, params, *address)
}
if err != nil {
log.Fatal(err)
}
}
func loadCircuit(file string, params *utils.Params, inputSizes [][]int) (
*circuit.Circuit, error) {
var circ *circuit.Circuit
var err error
if circuit.IsFilename(file) {
circ, err = circuit.Parse(file)
if err != nil {
return nil, err
}
} else if strings.HasSuffix(file, ".qcl") {
circ, _, err = compiler.New(params).CompileFile(file, inputSizes)
if err != nil {
return nil, err
}
} else {
return nil, fmt.Errorf("unknown file type '%s'", file)
}
if circ != nil {
circ.AssignLevels()
if verbose {
fmt.Printf("circuit: %v\n", circ)
}
}
return circ, err
}
func memProfile(file string) {
if len(file) == 0 {
return
}
f, err := os.Create(file)
if err != nil {
log.Fatal("could not create memory profile: ", err)
}
defer f.Close()
if false {
runtime.GC()
}
if err := pprof.WriteHeapProfile(f); err != nil {
log.Fatal("could not write memory profile: ", err)
}
}
func evaluatorMode(oti ot.OT, file string, params *utils.Params,
once bool, address string) error {
inputSizes := make([][]int, 2)
myInputSizes, err := circuit.InputSizes(inputFlag)
if err != nil {
return err
}
inputSizes[1] = myInputSizes
ln, err := net.Listen("tcp", address)
if err != nil {
return err
}
fmt.Printf("Listening for connections at %s\n", address)
var oPeerInputSizes []int
var circ *circuit.Circuit
for {
nc, err := ln.Accept()
if err != nil {
return err
}
fmt.Printf("New connection from %s\n", nc.RemoteAddr())
conn := p2p.NewConn(nc)
err = conn.SendInputSizes(myInputSizes)
if err != nil {
conn.Close()
return err
}
err = conn.Flush()
if err != nil {
conn.Close()
return err
}
peerInputSizes, err := conn.ReceiveInputSizes()
if err != nil {
conn.Close()
return err
}
inputSizes[0] = peerInputSizes
if circ == nil || slices.Compare(peerInputSizes, oPeerInputSizes) != 0 {
circ, err = loadCircuit(file, params, inputSizes)
if err != nil {
conn.Close()
return err
}
oPeerInputSizes = peerInputSizes
}
circ.PrintInputs(circuit.IDEvaluator, inputFlag)
if len(circ.Inputs) != 2 {
return fmt.Errorf("invalid circuit for 2-party MPC: %d parties",
len(circ.Inputs))
}
input, err := circ.Inputs[1].Parse(inputFlag)
if err != nil {
conn.Close()
return fmt.Errorf("%s: %v", file, err)
}
result, err := circuit.Evaluator(conn, oti, circ, input, verbose)
conn.Close()
if err != nil && err != io.EOF {
return err
}
bedlam.PrintResults(result, circ.Outputs)
if once {
return nil
}
}
}
func garblerMode(oti ot.OT, file string, params *utils.Params, address string) error {
inputSizes := make([][]int, 2)
myInputSizes, err := circuit.InputSizes(inputFlag)
if err != nil {
return err
}
inputSizes[0] = myInputSizes
nc, err := net.Dial("tcp", address)
if err != nil {
return err
}
conn := p2p.NewConn(nc)
defer conn.Close()
if params.Verbose {
fmt.Println(" - Receiving input sizes")
}
peerInputSizes, err := conn.ReceiveInputSizes()
if err != nil {
conn.Close()
return err
}
if params.Verbose {
fmt.Println(" - Sending input sizes")
}
inputSizes[1] = peerInputSizes
err = conn.SendInputSizes(myInputSizes)
if err != nil {
conn.Close()
return err
}
if params.Verbose {
fmt.Println(" - Sent input sizes")
}
err = conn.Flush()
if err != nil {
conn.Close()
return err
}
circ, err := loadCircuit(file, params, inputSizes)
if err != nil {
return err
}
circ.PrintInputs(circuit.IDGarbler, inputFlag)
if len(circ.Inputs) != 2 {
return fmt.Errorf("invalid circuit for 2-party MPC: %d parties",
len(circ.Inputs))
}
input, err := circ.Inputs[0].Parse(inputFlag)
if err != nil {
return fmt.Errorf("%s: %v", file, err)
}
if params.Verbose {
fmt.Println(" - Initiating garbler")
}
result, err := circuit.Garbler(conn, oti, circ, input, verbose)
if err != nil {
return err
}
bedlam.PrintResults(result, circ.Outputs)
return nil
}

View File

@ -0,0 +1,104 @@
//
// Copyright (c) 2020-2023 Markku Rossi
//
// All rights reserved.
//
package main
import (
"fmt"
"io"
"net"
"strings"
"source.quilibrium.com/quilibrium/monorepo/bedlam"
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/utils"
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
"source.quilibrium.com/quilibrium/monorepo/bedlam/p2p"
)
func streamEvaluatorMode(oti ot.OT, input input, once bool, address string) error {
inputSizes, err := circuit.InputSizes(input)
if err != nil {
return err
}
ln, err := net.Listen("tcp", address)
if err != nil {
return err
}
fmt.Printf("Listening for connections at %s\n", address)
for {
nc, err := ln.Accept()
if err != nil {
return err
}
fmt.Printf("New connection from %s\n", nc.RemoteAddr())
conn := p2p.NewConn(nc)
err = conn.SendInputSizes(inputSizes)
if err != nil {
conn.Close()
return err
}
err = conn.Flush()
if err != nil {
conn.Close()
return err
}
outputs, result, err := circuit.StreamEvaluator(conn, oti, input,
verbose)
conn.Close()
if err != nil && err != io.EOF {
return fmt.Errorf("%s: %v", nc.RemoteAddr(), err)
}
bedlam.PrintResults(result, outputs)
if once {
return nil
}
}
}
func streamGarblerMode(params *utils.Params, oti ot.OT, input input,
args []string, address string) error {
inputSizes := make([][]int, 2)
sizes, err := circuit.InputSizes(input)
if err != nil {
return err
}
inputSizes[0] = sizes
if len(args) != 1 || !strings.HasSuffix(args[0], ".qcl") {
return fmt.Errorf("streaming mode takes single QCL file")
}
nc, err := net.Dial("tcp", address)
if err != nil {
return err
}
conn := p2p.NewConn(nc)
defer conn.Close()
sizes, err = conn.ReceiveInputSizes()
if err != nil {
return err
}
inputSizes[1] = sizes
outputs, result, err := compiler.New(params).StreamFile(
conn, oti, args[0], input, inputSizes)
if err != nil {
return err
}
bedlam.PrintResults(result, outputs)
return nil
}

View File

@ -0,0 +1,83 @@
//
// main.go
//
// Copyright (c) 2023 Markku Rossi
//
// All rights reserved.
//
package main
import (
"fmt"
"io"
"net"
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
"source.quilibrium.com/quilibrium/monorepo/bedlam/p2p"
)
func evaluatorTestIO(size int64, once bool) error {
ln, err := net.Listen("tcp", port)
if err != nil {
return err
}
fmt.Printf("Listening for connections at %s\n", port)
for {
nc, err := ln.Accept()
if err != nil {
return err
}
fmt.Printf("New connection from %s\n", nc.RemoteAddr())
conn := p2p.NewConn(nc)
for {
var label ot.Label
var labelData ot.LabelData
err = conn.ReceiveLabel(&label, &labelData)
if err != nil {
if err == io.EOF {
break
}
return err
}
}
fmt.Printf("Received: %v\n",
circuit.FileSize(conn.Stats.Sum()).String())
if once {
return nil
}
}
}
func garblerTestIO(size int64) error {
nc, err := net.Dial("tcp", port)
if err != nil {
return err
}
conn := p2p.NewConn(nc)
var sent int64
var label ot.Label
var labelData ot.LabelData
for sent < size {
err = conn.SendLabel(label, &labelData)
if err != nil {
return err
}
sent += int64(len(labelData))
}
if err := conn.Flush(); err != nil {
return err
}
if err := conn.Close(); err != nil {
return err
}
fmt.Printf("Sent: %v\n", circuit.FileSize(conn.Stats.Sum()).String())
return nil
}

View File

@ -0,0 +1,56 @@
//
// main.go
//
// Copyright (c) 2019-2023 Markku Rossi
//
// All rights reserved.
//
package main
import (
"flag"
"log"
"os"
"runtime/pprof"
)
var (
port = ":8080"
)
func main() {
evaluator := flag.Bool("e", false, "evaluator / garbler mode")
cpuprofile := flag.String("cpuprofile", "", "write cpu profile to `file`")
testIO := flag.Int64("test-io", 0, "test I/O performance")
flag.Parse()
log.SetFlags(0)
if len(*cpuprofile) > 0 {
f, err := os.Create(*cpuprofile)
if err != nil {
log.Fatal("could not create CPU profile: ", err)
}
defer f.Close()
if err := pprof.StartCPUProfile(f); err != nil {
log.Fatal("could not start CPU profile: ", err)
}
defer pprof.StopCPUProfile()
}
if *testIO > 0 {
if *evaluator {
err := evaluatorTestIO(*testIO, len(*cpuprofile) > 0)
if err != nil {
log.Fatal(err)
}
} else {
err := garblerTestIO(*testIO)
if err != nil {
log.Fatal(err)
}
}
return
}
}

175
bedlam/apps/iter/main.go Normal file
View File

@ -0,0 +1,175 @@
//
// main.go
//
// Copyright (c) 2019-2023 Markku Rossi
//
// All rights reserved.
//
package main
import (
"flag"
"fmt"
"log"
"math"
"os"
"runtime/pprof"
"strings"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/utils"
)
var template = `
package main
func main(a, b int%d) int {
return a * b
}
`
type result struct {
bits int
bestLimit int
bestCost uint64
worstLimit int
worstCost uint64
costs []uint64
}
func main() {
numWorkers := flag.Int("workers", 8, "number of workers")
startBits := flag.Int("start", 8, "start bit count")
endBits := flag.Int("end", 0xffffffff, "end bit count")
minLimit := flag.Int("min", 8, "treshold minimum limit")
maxLimit := flag.Int("max", 22, "treshold maximum limit")
cpuprofile := flag.String("cpuprofile", "", "write cpu profile to `file`")
flag.Parse()
if len(*cpuprofile) > 0 {
f, err := os.Create(*cpuprofile)
if err != nil {
log.Fatal("could not create CPU profile: ", err)
}
defer f.Close()
if err := pprof.StartCPUProfile(f); err != nil {
log.Fatal("could not start CPU profile: ", err)
}
defer pprof.StopCPUProfile()
}
results := make(map[int]*result)
ch := make(chan *result)
for i := 0; i < *numWorkers; i++ {
go func(bits int) {
for ; bits <= *endBits; bits += *numWorkers {
code := fmt.Sprintf(template, bits)
var bestLimit int
var bestCost uint64
var worstLimit int
var worstCost uint64
var costs []uint64
params := utils.NewParams()
for limit := *minLimit; limit <= *maxLimit; limit++ {
params.CircMultArrayTreshold = limit
circ, _, err := compiler.New(params).Compile(code, nil)
if err != nil {
log.Fatalf("Compilation %d:%d failed: %s\n%s",
bits, limit, err, code)
}
cost := circ.Cost()
costs = append(costs, cost)
if bestCost == 0 || cost < bestCost ||
(limit == 21 && cost <= bestCost) {
bestCost = cost
bestLimit = limit
}
if cost > worstCost {
worstCost = cost
worstLimit = limit
}
}
ch <- &result{
bits: bits,
bestLimit: bestLimit,
bestCost: bestCost,
worstLimit: worstLimit,
worstCost: worstCost,
costs: costs,
}
}
}(*startBits + i)
}
next := *startBits
outer:
for result := range ch {
results[result.bits] = result
for {
r, ok := results[next]
if !ok {
break
}
if r.bestLimit == 21 {
fmt.Printf("\t// %d: %d, %10d\t%.4f\t%s\n",
r.bits, r.bestLimit,
r.bestCost, float64(r.bestCost)/float64(r.worstCost),
Sparkline(r.costs))
} else {
fmt.Printf("\t%d: %d, // %10d\t%.4f\t%s\n",
r.bits, r.bestLimit,
r.bestCost, float64(r.bestCost)/float64(r.worstCost),
Sparkline(r.costs))
}
if next >= *endBits {
break outer
}
next++
}
}
}
// Sparkline creates a histogram chart of values. The chart is scaled
// to [min...max] containing differences between values.
func Sparkline(values []uint64) string {
if len(values) == 0 {
return ""
}
var min uint64 = math.MaxUint64
var max uint64
for _, v := range values {
if v < min {
min = v
}
if v > max {
max = v
}
}
delta := max - min
var sb strings.Builder
for _, v := range values {
var tick uint64
if delta == 0 {
tick = 4
} else {
tick = (v - min) * 7 / delta
}
if v == min && false {
sb.WriteString("\x1b[92m")
sb.WriteRune(rune(0x2581 + tick))
sb.WriteString("\x1b[0m")
} else {
sb.WriteRune(rune(0x2581 + tick))
}
}
return sb.String()
}

View File

@ -0,0 +1,30 @@
//
// main.go
//
// Copyright (c) 2019-2023 Markku Rossi
//
// All rights reserved.
//
package main
import (
"flag"
"fmt"
"log"
"os"
)
func main() {
flag.Parse()
log.SetFlags(0)
if len(flag.Args()) == 0 {
fmt.Printf("no files specified\n")
os.Exit(1)
}
if err := dumpObjects(flag.Args()); err != nil {
log.Fatal(err)
}
}

View File

@ -0,0 +1,61 @@
//
// main.go
//
// Copyright (c) 2019-2022 Markku Rossi
//
// All rights reserved.
//
package main
import (
"os"
"github.com/markkurossi/tabulate"
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
)
func dumpObjects(files []string) error {
type oCircuit struct {
name string
circuit *circuit.Circuit
}
var circuits []oCircuit
for _, file := range files {
if circuit.IsFilename(file) {
c, err := circuit.Parse(file)
if err != nil {
return err
}
circuits = append(circuits, oCircuit{
name: file,
circuit: c,
})
}
}
if len(circuits) > 0 {
tab := tabulate.New(tabulate.Github)
tab.Header("File")
tab.Header("XOR").SetAlign(tabulate.MR)
tab.Header("XNOR").SetAlign(tabulate.MR)
tab.Header("AND").SetAlign(tabulate.MR)
tab.Header("OR").SetAlign(tabulate.MR)
tab.Header("INV").SetAlign(tabulate.MR)
tab.Header("Gates").SetAlign(tabulate.MR)
tab.Header("xor").SetAlign(tabulate.MR)
tab.Header("!xor").SetAlign(tabulate.MR)
tab.Header("Wires").SetAlign(tabulate.MR)
for _, c := range circuits {
row := tab.Row()
row.Column(c.name)
c.circuit.TabulateRow(row)
}
tab.Print(os.Stdout)
}
return nil
}

36
bedlam/build.sh Executable file
View File

@ -0,0 +1,36 @@
#!/bin/bash
set -euxo pipefail
# This script builds the node binary for the current platform and statically links it with VDF static lib.
# Assumes that the VDF library has been built by running the generate.sh script in the `../vdf` directory.
ROOT_DIR="${ROOT_DIR:-$( cd "$(dirname "$(realpath "$( dirname "${BASH_SOURCE[0]}" )")")" >/dev/null 2>&1 && pwd )}"
NODE_DIR="$ROOT_DIR/bedlam"
BINARIES_DIR="$ROOT_DIR/target/release"
pushd "$NODE_DIR" > /dev/null
export CGO_ENABLED=1
os_type="$(uname)"
case "$os_type" in
"Darwin")
# Check if the architecture is ARM
if [[ "$(uname -m)" == "arm64" ]]; then
# MacOS ld doesn't support -Bstatic and -Bdynamic, so it's important that there is only a static version of the library
go build -ldflags "-linkmode 'external' -extldflags '-L$BINARIES_DIR -L/usr/local/lib/ -L/opt/homebrew/Cellar/openssl@3/3.4.1/lib -lstdc++ -lferret -ldl -lm -lcrypto -lssl'" "$@"
else
echo "Unsupported platform"
exit 1
fi
;;
"Linux")
export CGO_LDFLAGS="-L/usr/local/lib -ldl -lm -L$BINARIES_DIR -lstdc++ -lcrypto -lssl -lferret -static"
go build -ldflags "-linkmode 'external'" "$@"
;;
*)
echo "Unsupported platform"
exit 1
;;
esac

1
bedlam/circuit/aesni/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
c/aesni

View File

@ -0,0 +1,9 @@
all: benchmarks
benchmarks: c/aesni
go test -bench=.
c/aesni
c/aesni: c/aesni.c
gcc -Wall -maes -march=native -o $@ $+

View File

@ -0,0 +1,197 @@
/*
* Circuit garbling using AES-NI instructions.
*/
#include <wmmintrin.h>
#include <emmintrin.h>
#include <stdio.h>
#include <stdint.h>
#include <string.h>
#include <sys/time.h>
#define AES_BLOCK_SIZE 16
struct cipher
{
__m128i key[15];
};
typedef struct cipher cipher;
struct label
{
uint64_t d0;
uint64_t d1;
};
typedef struct label label;
static inline void
label_xor(label *l, label *o)
{
l->d0 ^= o->d0;
l->d1 ^= o->d1;
}
static inline void
label_mul2(label *l)
{
l->d0 <<= 1;
l->d0 |= (l->d1 >> 63);
l->d1 <<= 1;
}
static inline void
label_mul4(label *l)
{
l->d0 <<= 2;
l->d0 |= (l->d1 >> 62);
l->d1 <<= 2;
}
static inline label *
make_k(label *a, label *b, uint32_t t)
{
label tweak = {(uint64_t) t, 0};
label_mul2(a);
label_mul4(b);
label_xor(a, b);
label_xor(a, &tweak);
return a;
}
void
garble2(cipher *cipher, label *a, label *b, label *c, uint32_t t, label *ret)
{
label *label_k;
__m128i k;
__m128i pi;
label_k = make_k(a, b, t);
k = _mm_set_epi64x(label_k->d0, label_k->d1);
/* Perform the AES encryption rounds. */
pi = _mm_xor_si128(k, cipher->key[0]);
pi = _mm_aesenc_si128(pi, cipher->key[1]);
pi = _mm_aesenc_si128(pi, cipher->key[2]);
pi = _mm_aesenc_si128(pi, cipher->key[3]);
pi = _mm_aesenc_si128(pi, cipher->key[4]);
pi = _mm_aesenc_si128(pi, cipher->key[5]);
pi = _mm_aesenc_si128(pi, cipher->key[6]);
pi = _mm_aesenc_si128(pi, cipher->key[7]);
pi = _mm_aesenc_si128(pi, cipher->key[8]);
pi = _mm_aesenc_si128(pi, cipher->key[9]);
pi = _mm_aesenc_si128(pi, cipher->key[10]);
pi = _mm_aesenc_si128(pi, cipher->key[11]);
pi = _mm_aesenc_si128(pi, cipher->key[12]);
pi = _mm_aesenc_si128(pi, cipher->key[13]);
pi = _mm_aesenclast_si128(pi, cipher->key[14]);
ret->d0 = ((uint64_t *) &pi)[0];
ret->d1 = ((uint64_t *) &pi)[1];
label_xor(ret, label_k);
label_xor(ret, c);
}
void
make_cipher(cipher *cipher, const unsigned char key[AES_BLOCK_SIZE])
{
cipher->key[0] = _mm_loadu_si128((__m128i *)key);
cipher->key[1] = _mm_aeskeygenassist_si128(cipher->key[0], 0x01);
cipher->key[2] = _mm_aeskeygenassist_si128(cipher->key[1], 0x02);
cipher->key[3] = _mm_aeskeygenassist_si128(cipher->key[2], 0x04);
cipher->key[4] = _mm_aeskeygenassist_si128(cipher->key[3], 0x08);
cipher->key[5] = _mm_aeskeygenassist_si128(cipher->key[4], 0x10);
cipher->key[6] = _mm_aeskeygenassist_si128(cipher->key[5], 0x20);
cipher->key[7] = _mm_aeskeygenassist_si128(cipher->key[6], 0x40);
cipher->key[8] = _mm_aeskeygenassist_si128(cipher->key[7], 0x80);
cipher->key[9] = _mm_aeskeygenassist_si128(cipher->key[8], 0x1B);
cipher->key[10] = _mm_aeskeygenassist_si128(cipher->key[9], 0x36);
cipher->key[11] = _mm_aeskeygenassist_si128(cipher->key[10], 0x6C);
cipher->key[12] = _mm_aeskeygenassist_si128(cipher->key[11], 0xD8);
cipher->key[13] = _mm_aeskeygenassist_si128(cipher->key[12], 0xAB);
cipher->key[14] = _mm_aeskeygenassist_si128(cipher->key[13], 0x4D);
}
static inline __m128i
garble(cipher *cipher, __m128i a, __m128i b, __m128i c, uint32_t t)
{
__m128i k;
__m128i pi;
k = _mm_xor_si128(_mm_xor_si128(_mm_slli_epi32(a, 1),
_mm_slli_epi32(b, 2)),
_mm_set1_epi32(t));
/* Perform the AES encryption rounds. */
pi = _mm_xor_si128(k, cipher->key[0]);
pi = _mm_aesenc_si128(pi, cipher->key[1]);
pi = _mm_aesenc_si128(pi, cipher->key[2]);
pi = _mm_aesenc_si128(pi, cipher->key[3]);
pi = _mm_aesenc_si128(pi, cipher->key[4]);
pi = _mm_aesenc_si128(pi, cipher->key[5]);
pi = _mm_aesenc_si128(pi, cipher->key[6]);
pi = _mm_aesenc_si128(pi, cipher->key[7]);
pi = _mm_aesenc_si128(pi, cipher->key[8]);
pi = _mm_aesenc_si128(pi, cipher->key[9]);
pi = _mm_aesenc_si128(pi, cipher->key[10]);
pi = _mm_aesenc_si128(pi, cipher->key[11]);
pi = _mm_aesenc_si128(pi, cipher->key[12]);
pi = _mm_aesenc_si128(pi, cipher->key[13]);
pi = _mm_aesenclast_si128(pi, cipher->key[14]);
return _mm_xor_si128(_mm_xor_si128(pi, k), c);
}
void
report(char *l, struct timeval *begin, struct timeval *end, int rounds)
{
int64_t d;
d = (int64_t) end->tv_sec - (int64_t) begin->tv_sec;
d *= 1000000;
d += (int64_t) end->tv_usec - (int64_t) begin->tv_usec;
d *= 1000;
d *= 100;
d /= rounds;
printf("%-20s\t%d\t\t%.2f ns/op\n", l, rounds, (double) d / 100);
}
int
main()
{
const unsigned char key[AES_BLOCK_SIZE] = "0123456789ABCDEF";
cipher cipher = {0};
struct timeval begin, end;
int rounds = 23802664;
__m128i a = _mm_set1_epi32(42);
__m128i b = _mm_set1_epi32(43);
__m128i c = _mm_set1_epi32(44);
label la = {42, 0};
label lb = {43, 0};
label lc = {44, 0};
label lr;
make_cipher(&cipher, key);
gettimeofday(&begin, NULL);
for (int i = 0; i < rounds; i++)
garble(&cipher, a, b, c, (uint32_t) i);
gettimeofday(&end, NULL);
report("AES-NI", &begin, &end, rounds);
gettimeofday(&begin, NULL);
for (int i = 0; i < rounds; i++)
garble2(&cipher, &la, &lb, &lc, (uint32_t) i, &lr);
gettimeofday(&end, NULL);
report("AES-NI+C", &begin, &end, rounds);
return 0;
}

View File

@ -0,0 +1,144 @@
//
// enc_test.go
//
// Copyright (c) 2019-2023 Markku Rossi
//
// All rights reserved.
//
package aesni
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"io"
"testing"
)
type LabelX struct {
d0 uint64
d1 uint64
}
func NewLabelX(rand io.Reader) (LabelX, error) {
var buf [16]byte
var result LabelX
_, err := rand.Read(buf[:])
if err != nil {
return result, err
}
result.SetData(&buf)
return result, nil
}
func (l *LabelX) SetData(buf *[16]byte) {
l.d0 = binary.BigEndian.Uint64((*buf)[0:8])
l.d1 = binary.BigEndian.Uint64((*buf)[8:16])
}
func (l *LabelX) GetData(buf *[16]byte) {
binary.BigEndian.PutUint64((*buf)[0:8], l.d0)
binary.BigEndian.PutUint64((*buf)[8:16], l.d1)
}
func (l *LabelX) Xor(o LabelX) {
l.d0 ^= o.d0
l.d1 ^= o.d1
}
func (l *LabelX) Mul2() {
l.d0 <<= 1
l.d0 |= (l.d1 >> 63)
l.d1 <<= 1
}
func (l *LabelX) Mul4() {
l.d0 <<= 2
l.d0 |= (l.d1 >> 62)
l.d1 <<= 2
}
func TestLabelXor(t *testing.T) {
val := uint64(0b0101010101010101010101010101010101010101010101010101010101010101)
a := LabelX{
d0: val,
d1: val << 1,
}
b := LabelX{
d0: 0xffffffffffffffff,
d1: 0xffffffffffffffff,
}
a.Xor(b)
if a.d0 != val<<1 {
t.Errorf("Xor: unexpected d0=%x, epected %x", a.d0, val<<1)
}
if a.d1 != val {
t.Errorf("Xor: unexpected d1=%x, epected %x", a.d1, val)
}
}
func NewTweakX(tweak uint32) LabelX {
return LabelX{
d1: uint64(tweak),
}
}
func encryptX(alg cipher.Block, a, b, c LabelX, t uint32,
buf *[16]byte) LabelX {
k := makeKX(a, b, t)
k.GetData(buf)
alg.Encrypt(buf[:], buf[:])
var pi LabelX
pi.SetData(buf)
pi.Xor(k)
pi.Xor(c)
return pi
}
func makeKX(a, b LabelX, t uint32) LabelX {
a.Mul2()
b.Mul4()
a.Xor(b)
a.Xor(NewTweakX(t))
return a
}
func BenchmarkLabelX(b *testing.B) {
var key [32]byte
cipher, err := aes.NewCipher(key[:])
if err != nil {
b.Fatalf("Failed to create cipher: %s", err)
}
al, err := NewLabelX(rand.Reader)
if err != nil {
b.Fatalf("Failed to create label: %s", err)
}
bl, err := NewLabelX(rand.Reader)
if err != nil {
b.Fatalf("Failed to create label: %s", err)
}
cl, err := NewLabelX(rand.Reader)
if err != nil {
b.Fatalf("Failed to create label: %s", err)
}
b.ResetTimer()
var buf [16]byte
for i := 0; i < b.N; i++ {
encryptX(cipher, al, bl, cl, uint32(i), &buf)
}
}

42
bedlam/circuit/analyze.go Normal file
View File

@ -0,0 +1,42 @@
//
// Copyright (c) 2021 Markku Rossi
//
// All rights reserved.
//
package circuit
import (
"fmt"
)
// Analyze identifies potential optimizations for the circuit.
func (c *Circuit) Analyze() {
fmt.Printf("analyzing circuit %v\n", c)
from := make([][]Gate, c.NumWires)
to := make([][]Gate, c.NumWires)
// Collect wire inputs and outputs.
for _, g := range c.Gates {
switch g.Op {
case XOR, XNOR, AND, OR:
to[g.Input1] = append(to[g.Input1], g)
fallthrough
case INV:
to[g.Input0] = append(to[g.Input0], g)
from[g.Output] = append(to[g.Output], g)
}
}
// INV gates as single output of input gate.
for _, g := range c.Gates {
if g.Op != INV {
continue
}
if len(to[g.Input0]) == 1 && len(from[g.Input0]) == 1 {
fmt.Printf("%v -> %v\n", from[g.Input0][0].Op, g.Op)
}
}
}

270
bedlam/circuit/circuit.go Normal file
View File

@ -0,0 +1,270 @@
//
// Copyright (c) 2019-2023 Markku Rossi
//
// All rights reserved.
//
package circuit
import (
"fmt"
"io"
"math"
"github.com/markkurossi/tabulate"
)
// Operation specifies gate function.
type Operation byte
// Gate functions.
const (
XOR Operation = iota
XNOR
AND
OR
INV
Count
NumLevels
MaxWidth
)
// Known multi-party computation roles.
const (
IDGarbler int = iota
IDEvaluator
)
// Stats holds statistics about circuit operations.
type Stats [MaxWidth + 1]uint64
// Add adds the argument statistics to this statistics object.
func (stats *Stats) Add(o Stats) {
for i := XOR; i < Count; i++ {
stats[i] += o[i]
}
stats[Count]++
for i := NumLevels; i <= MaxWidth; i++ {
if o[i] > stats[i] {
stats[i] = o[i]
}
}
}
// Count returns the number of gates in the statistics object.
func (stats Stats) Count() uint64 {
var result uint64
for i := XOR; i < Count; i++ {
result += stats[i]
}
return result
}
// Cost computes the relative computational cost of the circuit.
func (stats Stats) Cost() uint64 {
return (stats[AND]+stats[INV])*2 + stats[OR]*3
}
func (stats Stats) String() string {
var result string
for i := XOR; i < Count; i++ {
v := stats[i]
if len(result) > 0 {
result += " "
}
result += fmt.Sprintf("%s=%d", i, v)
}
result += fmt.Sprintf(" xor=%d", stats[XOR]+stats[XNOR])
result += fmt.Sprintf(" !xor=%d", stats[AND]+stats[OR]+stats[INV])
result += fmt.Sprintf(" levels=%d", stats[NumLevels])
result += fmt.Sprintf(" width=%d", stats[MaxWidth])
return result
}
func (op Operation) String() string {
switch op {
case XOR:
return "XOR"
case XNOR:
return "XNOR"
case AND:
return "AND"
case OR:
return "OR"
case INV:
return "INV"
case Count:
return "#"
default:
return fmt.Sprintf("{Operation %d}", op)
}
}
// Circuit specifies a boolean circuit.
type Circuit struct {
NumGates int
NumWires int
Inputs IO
Outputs IO
Gates []Gate
Stats Stats
}
func (c *Circuit) String() string {
return fmt.Sprintf("#gates=%d (%s) #w=%d", c.NumGates, c.Stats, c.NumWires)
}
// NumParties returns the number of parties needed for the circuit.
func (c *Circuit) NumParties() int {
return len(c.Inputs)
}
// PrintInputs prints the circuit inputs.
func (c *Circuit) PrintInputs(id int, input []string) {
for i := 0; i < len(c.Inputs); i++ {
if i == id {
fmt.Print(" + ")
} else {
fmt.Print(" - ")
}
fmt.Printf("In%d: %s\n", i, c.Inputs[i])
}
fmt.Printf(" - Out: %s\n", c.Outputs)
fmt.Printf(" - In: %s\n", input)
}
// TabulateStats prints the circuit stats as a table to the specified
// output Writer.
func (c *Circuit) TabulateStats(out io.Writer) {
tab := tabulate.New(tabulate.UnicodeLight)
tab.Header("XOR").SetAlign(tabulate.MR)
tab.Header("XNOR").SetAlign(tabulate.MR)
tab.Header("AND").SetAlign(tabulate.MR)
tab.Header("OR").SetAlign(tabulate.MR)
tab.Header("INV").SetAlign(tabulate.MR)
tab.Header("Gates").SetAlign(tabulate.MR)
tab.Header("XOR").SetAlign(tabulate.MR)
tab.Header("!XOR").SetAlign(tabulate.MR)
tab.Header("Wires").SetAlign(tabulate.MR)
c.TabulateRow(tab.Row())
tab.Print(out)
}
// TabulateRow tabulates circuit statistics to the argument tabulation
// row.
func (c *Circuit) TabulateRow(row *tabulate.Row) {
var sumGates uint64
for op := XOR; op < Count; op++ {
row.Column(fmt.Sprintf("%v", c.Stats[op]))
sumGates += c.Stats[op]
}
row.Column(fmt.Sprintf("%v", sumGates))
row.Column(fmt.Sprintf("%v", c.Stats[XOR]+c.Stats[XNOR]))
row.Column(fmt.Sprintf("%v", c.Stats[AND]+c.Stats[OR]+c.Stats[INV]))
row.Column(fmt.Sprintf("%v", c.NumWires))
}
// Cost computes the relative computational cost of the circuit.
func (c *Circuit) Cost() uint64 {
return c.Stats.Cost()
}
// Dump prints a debug dump of the circuit.
func (c *Circuit) Dump() {
fmt.Printf("circuit %s\n", c)
for id, gate := range c.Gates {
fmt.Printf("%04d\t%s\n", id, gate)
}
}
// AssignLevels assigns levels for gates. The level desribes how many
// steps away the gate is from input wires.
func (c *Circuit) AssignLevels() {
levels := make([]Level, c.NumWires)
countByLevel := make([]uint32, c.NumWires)
var max Level
for idx, gate := range c.Gates {
level := levels[gate.Input0]
if gate.Op != INV {
l1 := levels[gate.Input1]
if l1 > level {
level = l1
}
}
c.Gates[idx].Level = level
countByLevel[level]++
level++
levels[gate.Output] = level
if level > max {
max = level
}
}
c.Stats[NumLevels] = uint64(max)
var maxWidth uint32
for _, count := range countByLevel {
if count > maxWidth {
maxWidth = count
}
}
if false {
for i := 0; i < int(max); i++ {
fmt.Printf("%v,%v\n", i, countByLevel[i])
}
}
c.Stats[MaxWidth] = uint64(maxWidth)
}
// Level defines gate's distance from input wires.
type Level uint32
// Gate specifies a boolean gate.
type Gate struct {
Input0 Wire
Input1 Wire
Output Wire
Op Operation
Level Level
}
func (g Gate) String() string {
return fmt.Sprintf("%v %v %v", g.Inputs(), g.Op, g.Output)
}
// Inputs returns gate input wires.
func (g Gate) Inputs() []Wire {
switch g.Op {
case XOR, XNOR, AND, OR:
return []Wire{g.Input0, g.Input1}
case INV:
return []Wire{g.Input0}
default:
panic(fmt.Sprintf("unsupported gate type %s", g.Op))
}
}
// Wire specifies a wire ID.
type Wire uint32
// InvalidWire specifies an invalid wire ID.
const InvalidWire Wire = math.MaxUint32
// Int returns the wire ID as integer.
func (w Wire) Int() int {
if uint64(w) > math.MaxInt {
panic(w)
}
return int(w)
}
func (w Wire) String() string {
return fmt.Sprintf("w%d", w)
}

View File

@ -0,0 +1,19 @@
//
// Copyright (c) 2022-2023 Markku Rossi
//
// All rights reserved.
//
package circuit
import (
"testing"
"unsafe"
)
func TestSize(t *testing.T) {
var g Gate
if unsafe.Sizeof(g) != 20 {
t.Errorf("unexpected gate size: got %v, expected 20", unsafe.Sizeof(g))
}
}

View File

@ -0,0 +1,91 @@
//
// Copyright (c) 2020-2023 Markku Rossi
//
// All rights reserved.
//
package circuit
import (
"fmt"
"math/big"
)
// Compute evaluates the circuit with the given input values.
func (c *Circuit) Compute(inputs []*big.Int) ([]*big.Int, error) {
// Flatten circuit arguments.
var args IO
for _, io := range c.Inputs {
if len(io.Compound) > 0 {
args = append(args, io.Compound...)
} else {
args = append(args, io)
}
}
if len(inputs) != len(args) {
return nil, fmt.Errorf("invalid inputs: got %d, expected %d",
len(inputs), len(args))
}
// Flatten inputs and arguments.
wires := make([]byte, c.NumWires)
var w int
for idx, io := range args {
for bit := 0; bit < int(io.Type.Bits); bit++ {
wires[w] = byte(inputs[idx].Bit(bit))
w++
}
}
// Evaluate circuit.
for _, gate := range c.Gates {
var result byte
switch gate.Op {
case XOR:
result = wires[gate.Input0] ^ wires[gate.Input1]
case XNOR:
if wires[gate.Input0]^wires[gate.Input1] == 0 {
result = 1
} else {
result = 0
}
case AND:
result = wires[gate.Input0] & wires[gate.Input1]
case OR:
result = wires[gate.Input0] | wires[gate.Input1]
case INV:
if wires[gate.Input0] == 0 {
result = 1
} else {
result = 0
}
default:
return nil, fmt.Errorf("invalid gate %s", gate.Op)
}
wires[gate.Output] = result
}
// Construct outputs
w = c.NumWires - c.Outputs.Size()
var result []*big.Int
for _, io := range c.Outputs {
r := new(big.Int)
for bit := 0; bit < int(io.Type.Bits); bit++ {
if wires[w] != 0 {
r.SetBit(r, bit, 1)
}
w++
}
result = append(result, r)
}
return result, nil
}

View File

@ -0,0 +1,19 @@
<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
<g fill="none" stroke="#000" stroke-width="1">
<path d="M 25 25
h 50
v 25
a 25 25 0 1 1 -50 0
v -25
z" />
<path d="M 35 0
v 25
z" />
<path d="M 65 0
v 25
z" />
<path d="M 50 75
v 25
z" />
</g>
</svg>

After

Width:  |  Height:  |  Size: 429 B

View File

@ -0,0 +1,20 @@
<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
<g fill="none" stroke="#000" stroke-width="1">
<path d="M 25 25
h 50
l -25 43
z" />
<circle cx="50"
cy="73.5"
r="5" />
<path d="M 35 0
v 25
z" />
<path d="M 65 0
v 25
z" />
<path d="M 50 79
v 25
z" />
</g>
</svg>

After

Width:  |  Height:  |  Size: 444 B

View File

@ -0,0 +1,21 @@
<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
<g fill="none" stroke="#000" stroke-width="1">
<path d="M 25 20
c 10 10 40 10 50 0" />
<path d="M 75 20
v 30
s 0 10 -25 25 " />
<path d="M 25 20
v 30
s 0 10 25 25 " />
<path d="M 35 0
v 25
z" />
<path d="M 65 0
v 25
z" />
<path d="M 50 75
v 25
z" />
</g>
</svg>

After

Width:  |  Height:  |  Size: 499 B

View File

@ -0,0 +1,6 @@
<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
<g fill="none" stroke="#000" stroke-width="1">
<path d="M 25 20
C 25 30 35 30 35 40" />
</g>
</svg>

After

Width:  |  Height:  |  Size: 187 B

View File

@ -0,0 +1,26 @@
<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
<g fill="none" stroke="#000" stroke-width="1">
<path d="M 25 20
c 10 10 40 10 50 0" />
<path d="M 25 25
c 10 10 40 10 50 0" />
<path d="M 75 25
v 25
s 0 10 -25 25 " />
<path d="M 25 25
v 25
s 0 10 25 25 " />
<circle cx="50"
cy="80"
r="5" />
<path d="M 35 0
v 25
z" />
<path d="M 65 0
v 25
z" />
<path d="M 50 85
v 25
z" />
</g>
</svg>

After

Width:  |  Height:  |  Size: 617 B

View File

@ -0,0 +1,23 @@
<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
<g fill="none" stroke="#000" stroke-width="1">
<path d="M 25 20
c 10 10 40 10 50 0" />
<path d="M 25 25
c 10 10 40 10 50 0" />
<path d="M 75 25
v 25
s 0 10 -25 25 " />
<path d="M 25 25
v 25
s 0 10 25 25 " />
<path d="M 35 0
v 25
z" />
<path d="M 65 0
v 25
z" />
<path d="M 50 75
v 25
z" />
</g>
</svg>

After

Width:  |  Height:  |  Size: 556 B

56
bedlam/circuit/dot.go Normal file
View File

@ -0,0 +1,56 @@
//
// Copyright (c) 2019-2023 Markku Rossi
//
// All rights reserved.
//
package circuit
import (
"fmt"
"io"
)
// Dot creates graphviz dot output of the circuit.
func (c *Circuit) Dot(out io.Writer) {
fmt.Fprintf(out, "digraph circuit\n{\n")
fmt.Fprintf(out, " overlap=scale;\n")
fmt.Fprintf(out, " node\t[fontname=\"Helvetica\"];\n")
fmt.Fprintf(out, " {\n node [shape=plaintext];\n")
for w := 0; w < c.NumWires; w++ {
fmt.Fprintf(out, " w%d\t[label=\"%d\"];\n", w, w)
}
fmt.Fprintf(out, " }\n")
fmt.Fprintf(out, " {\n node [shape=box];\n")
for idx, gate := range c.Gates {
fmt.Fprintf(out, " g%d\t[label=\"%s\"];\n", idx, gate.Op)
}
fmt.Fprintf(out, " }\n")
if true {
fmt.Fprintf(out, " { rank=same")
var numInputs int
for _, input := range c.Inputs {
numInputs += int(input.Type.Bits)
}
for w := 0; w < numInputs; w++ {
fmt.Fprintf(out, "; w%d", w)
}
fmt.Fprintf(out, ";}\n")
fmt.Fprintf(out, " { rank=same")
for w := 0; w < c.Outputs.Size(); w++ {
fmt.Fprintf(out, "; w%d", c.NumWires-w-1)
}
fmt.Fprintf(out, ";}\n")
}
for idx, gate := range c.Gates {
for _, i := range gate.Inputs() {
fmt.Fprintf(out, " w%d -> g%d;\n", i, idx)
}
fmt.Fprintf(out, " g%d -> w%d;\n", idx, gate.Output)
}
fmt.Fprintf(out, "}\n")
}

View File

@ -0,0 +1,91 @@
//
// enc_test.go
//
// Copyright (c) 2019-2023 Markku Rossi
//
// All rights reserved.
//
package circuit
import (
"crypto/aes"
"crypto/rand"
"testing"
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
)
func TestEnc(t *testing.T) {
a, _ := ot.NewLabel(rand.Reader)
b, _ := ot.NewLabel(rand.Reader)
c, _ := ot.NewLabel(rand.Reader)
tweak := uint32(42)
var key [32]byte
cipher, err := aes.NewCipher(key[:])
if err != nil {
t.Fatalf("Failed to create cipher: %s", err)
}
var data ot.LabelData
encrypted := encrypt(cipher, a, b, c, tweak, &data)
if err != nil {
t.Fatalf("Encrypt failed: %s", err)
}
plain := decrypt(cipher, a, b, tweak, encrypted, &data)
if !c.Equal(plain) {
t.Fatalf("Encrypt-decrypt failed")
}
}
func BenchmarkEnc(b *testing.B) {
var key [32]byte
cipher, err := aes.NewCipher(key[:])
if err != nil {
b.Fatalf("Failed to create cipher: %s", err)
}
al, err := ot.NewLabel(rand.Reader)
if err != nil {
b.Fatalf("Failed to create label: %s", err)
}
bl, err := ot.NewLabel(rand.Reader)
if err != nil {
b.Fatalf("Failed to create label: %s", err)
}
cl, err := ot.NewLabel(rand.Reader)
if err != nil {
b.Fatalf("Failed to create label: %s", err)
}
b.ResetTimer()
var data ot.LabelData
for i := 0; i < b.N; i++ {
encrypt(cipher, al, bl, cl, uint32(i), &data)
}
}
func BenchmarkEncHalf(b *testing.B) {
var key [32]byte
cipher, err := aes.NewCipher(key[:])
if err != nil {
b.Fatalf("Failed to create cipher: %s", err)
}
xl, err := ot.NewLabel(rand.Reader)
if err != nil {
b.Fatalf("Failed to create label: %s", err)
}
b.ResetTimer()
var data ot.LabelData
for i := 0; i < b.N; i++ {
encryptHalf(cipher, xl, uint32(i), &data)
}
}

115
bedlam/circuit/eval.go Normal file
View File

@ -0,0 +1,115 @@
//
// Copyright (c) 2019-2021 Markku Rossi
//
// All rights reserved.
//
package circuit
import (
"crypto/aes"
"fmt"
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
)
// Eval evaluates the circuit.
func (c *Circuit) Eval(key []byte, wires []ot.Label,
garbled [][]ot.Label) error {
alg, err := aes.NewCipher(key)
if err != nil {
return err
}
var data ot.LabelData
var id uint32
for i := 0; i < len(c.Gates); i++ {
gate := &c.Gates[i]
var a, b, c ot.Label
switch gate.Op {
case XOR, XNOR, AND, OR:
a = wires[gate.Input0]
b = wires[gate.Input1]
case INV:
a = wires[gate.Input0]
default:
return fmt.Errorf("invalid operation %s", gate.Op)
}
var output ot.Label
switch gate.Op {
case XOR, XNOR:
a.Xor(b)
output = a
case AND:
row := garbled[i]
if len(row) != 2 {
return fmt.Errorf("corrupted ciruit: AND row length: %d",
len(row))
}
sa := a.S()
sb := b.S()
j0 := id
j1 := id + 1
id += 2
tg := row[0]
te := row[1]
wg := encryptHalf(alg, a, j0, &data)
if sa {
wg.Xor(tg)
}
we := encryptHalf(alg, b, j1, &data)
if sb {
we.Xor(te)
we.Xor(a)
}
output = wg
output.Xor(we)
case OR:
row := garbled[i]
index := idx(a, b)
if index > 0 {
// First row is zero and not transmitted.
index--
if index >= len(row) {
return fmt.Errorf("corrupted circuit: index %d >= row %d",
index, len(row))
}
c = row[index]
}
output = decrypt(alg, a, b, id, c, &data)
id++
case INV:
row := garbled[i]
index := idxUnary(a)
if index > 0 {
// First row is zero and not transmitted.
index--
if index >= len(row) {
return fmt.Errorf("corrupted circuit: index %d >= row %d",
index, len(row))
}
c = row[index]
}
output = decrypt(alg, a, ot.Label{}, id, c, &data)
id++
}
wires[gate.Output] = output
}
return nil
}

159
bedlam/circuit/evaluator.go Normal file
View File

@ -0,0 +1,159 @@
//
// evaluator.go
//
// Copyright (c) 2019-2023 Markku Rossi
//
// All rights reserved.
//
package circuit
import (
"fmt"
"math/big"
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
"source.quilibrium.com/quilibrium/monorepo/bedlam/p2p"
)
var (
debug = false
)
// Evaluator runs the evaluator on the P2P network.
func Evaluator(conn *p2p.Conn, oti ot.OT, circ *Circuit, inputs *big.Int,
verbose bool) ([]*big.Int, error) {
timing := NewTiming()
garbled := make([][]ot.Label, circ.NumGates)
// Receive program info.
if verbose {
fmt.Printf(" - Waiting for circuit info...\n")
}
key, err := conn.ReceiveData()
if err != nil {
return nil, err
}
// Receive garbled tables.
timing.Sample("Wait", nil)
if verbose {
fmt.Printf(" - Receiving garbled circuit...\n")
}
count, err := conn.ReceiveUint32()
if err != nil {
return nil, err
}
if count != circ.NumGates {
return nil, fmt.Errorf("wrong number of gates: got %d, expected %d",
count, circ.NumGates)
}
var label ot.Label
var labelData ot.LabelData
for i := 0; i < circ.NumGates; i++ {
count, err := conn.ReceiveUint32()
if err != nil {
return nil, err
}
values := make([]ot.Label, count)
for j := 0; j < count; j++ {
err := conn.ReceiveLabel(&label, &labelData)
if err != nil {
return nil, err
}
values[j] = label
}
garbled[i] = values
}
wires := make([]ot.Label, circ.NumWires)
// Receive peer inputs.
for i := 0; i < int(circ.Inputs[0].Type.Bits); i++ {
err := conn.ReceiveLabel(&label, &labelData)
if err != nil {
return nil, err
}
wires[Wire(i)] = label
}
// Init oblivious transfer.
err = oti.InitReceiver(conn)
if err != nil {
return nil, err
}
ioStats := conn.Stats.Sum()
timing.Sample("Recv", []string{FileSize(ioStats).String()})
// Query our inputs.
if verbose {
fmt.Printf(" - Querying our inputs...\n")
}
// Wire offset.
if err := conn.SendUint32(int(circ.Inputs[0].Type.Bits)); err != nil {
return nil, err
}
// Wire count.
if err := conn.SendUint32(int(circ.Inputs[1].Type.Bits)); err != nil {
return nil, err
}
if err := conn.Flush(); err != nil {
return nil, err
}
flags := make([]bool, int(circ.Inputs[1].Type.Bits))
for i := 0; i < int(circ.Inputs[1].Type.Bits); i++ {
if inputs.Bit(i) == 1 {
flags[i] = true
}
}
if err := oti.Receive(flags, wires[circ.Inputs[0].Type.Bits:]); err != nil {
return nil, err
}
xfer := conn.Stats.Sum() - ioStats
ioStats = conn.Stats.Sum()
timing.Sample("Inputs", []string{FileSize(xfer).String()})
// Evaluate gates.
if verbose {
fmt.Printf(" - Evaluating circuit...\n")
}
err = circ.Eval(key[:], wires, garbled)
if err != nil {
return nil, err
}
timing.Sample("Eval", nil)
// Resolve result values.
var labels []ot.Label
for i := 0; i < circ.Outputs.Size(); i++ {
r := wires[Wire(circ.NumWires-circ.Outputs.Size()+i)]
labels = append(labels, r)
}
for _, l := range labels {
if err := conn.SendLabel(l, &labelData); err != nil {
return nil, err
}
}
if err := conn.Flush(); err != nil {
return nil, err
}
result, err := conn.ReceiveData()
if err != nil {
return nil, err
}
raw := big.NewInt(0).SetBytes(result)
xfer = conn.Stats.Sum() - ioStats
timing.Sample("Result", []string{FileSize(xfer).String()})
if verbose {
timing.Print(conn.Stats)
}
return circ.Outputs.Split(raw), nil
}

410
bedlam/circuit/garble.go Normal file
View File

@ -0,0 +1,410 @@
//
// Copyright (c) 2019-2023 Markku Rossi
//
// All rights reserved.
//
package circuit
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"fmt"
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
)
var (
verbose = false
)
func idxUnary(l0 ot.Label) int {
if l0.S() {
return 1
}
return 0
}
func idx(l0, l1 ot.Label) int {
var ret int
if l0.S() {
ret |= 0x2
}
if l1.S() {
ret |= 0x1
}
return ret
}
func encrypt(alg cipher.Block, a, b, c ot.Label, t uint32,
data *ot.LabelData) ot.Label {
k := makeK(a, b, t)
k.GetData(data)
alg.Encrypt(data[:], data[:])
var pi ot.Label
pi.SetData(data)
pi.Xor(k)
pi.Xor(c)
return pi
}
func decrypt(alg cipher.Block, a, b ot.Label, t uint32, c ot.Label,
data *ot.LabelData) ot.Label {
k := makeK(a, b, t)
k.GetData(data)
alg.Encrypt(data[:], data[:])
var crypted ot.Label
crypted.SetData(data)
c.Xor(crypted)
c.Xor(k)
return c
}
func makeK(a, b ot.Label, t uint32) ot.Label {
a.Mul2()
b.Mul4()
a.Xor(b)
a.Xor(ot.NewTweak(t))
return a
}
// Hash function for half gates: Hπ(x, i) to be π(K) ⊕ K where K = 2x ⊕ i
func encryptHalfReference(alg cipher.Block, x ot.Label, i uint32,
data *ot.LabelData) ot.Label {
k := makeKHalf(x, i)
k.GetData(data)
alg.Encrypt(data[:], data[:])
var pi ot.Label
pi.SetData(data)
pi.Xor(k)
return pi
}
// Optimized version of encryptHalfReference. Label operations are
// inlined below, producing about 11% performance improvements.
func encryptHalf(alg cipher.Block, x ot.Label, i uint32,
data *ot.LabelData) ot.Label {
// k := makeKHalf(x, i) {
k := x
// k.Mul2()
k.D0 <<= 1
k.D0 |= (k.D1 >> 63)
k.D1 <<= 1
// k.Xor(ot.NewTweak(i))
k.D1 ^= uint64(i)
// }
// k.GetData(data) {
binary.BigEndian.PutUint64(data[0:8], k.D0)
binary.BigEndian.PutUint64(data[8:16], k.D1)
// }
alg.Encrypt(data[:], data[:])
var pi ot.Label
// pi.SetData(data) {
pi.D0 = binary.BigEndian.Uint64((*data)[0:8])
pi.D1 = binary.BigEndian.Uint64((*data)[8:16])
// }
// pi.Xor(k) {
pi.D0 ^= k.D0
pi.D1 ^= k.D1
// }
return pi
}
// K = 2x ⊕ i
func makeKHalf(x ot.Label, i uint32) ot.Label {
x.Mul2()
x.Xor(ot.NewTweak(i))
return x
}
func makeLabels(r ot.Label) (ot.Wire, error) {
l0, err := ot.NewLabel(rand.Reader)
if err != nil {
return ot.Wire{}, err
}
l1 := l0
l1.Xor(r)
return ot.Wire{
L0: l0,
L1: l1,
}, nil
}
// Garbled contains garbled circuit information.
type Garbled struct {
R ot.Label
Wires []ot.Wire
Gates [][]ot.Label
}
// Lambda returns the lambda value of the wire.
func (g *Garbled) Lambda(wire Wire) uint {
if g.Wires[wire].L0.S() {
return 1
}
return 0
}
// SetLambda sets the lambda value of the wire.
func (g *Garbled) SetLambda(wire Wire, val uint) {
w := g.Wires[wire]
if val == 0 {
w.L0.SetS(false)
} else {
w.L0.SetS(true)
}
g.Wires[wire] = w
}
// Garble garbles the circuit.
func (c *Circuit) Garble(key []byte) (*Garbled, error) {
// Create R.
r, err := ot.NewLabel(rand.Reader)
if err != nil {
return nil, err
}
r.SetS(true)
garbled := make([][]ot.Label, c.NumGates)
alg, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
// Wire labels.
wires := make([]ot.Wire, c.NumWires)
// Assing all input wires.
for i := 0; i < c.Inputs.Size(); i++ {
w, err := makeLabels(r)
if err != nil {
return nil, err
}
wires[i] = w
}
// Garble gates.
var data ot.LabelData
var id uint32
for i := 0; i < len(c.Gates); i++ {
gate := &c.Gates[i]
data, err := gate.garble(wires, alg, r, &id, &data)
if err != nil {
return nil, err
}
garbled[i] = data
}
return &Garbled{
R: r,
Wires: wires,
Gates: garbled,
}, nil
}
// Garble garbles the gate and returns it labels.
func (g *Gate) garble(wires []ot.Wire, enc cipher.Block, r ot.Label,
idp *uint32, data *ot.LabelData) ([]ot.Label, error) {
var a, b, c ot.Wire
var table [4]ot.Label
var start, count int
// Inputs.
switch g.Op {
case XOR, XNOR, AND, OR:
b = wires[g.Input1]
fallthrough
case INV:
a = wires[g.Input0]
default:
return nil, fmt.Errorf("invalid gate type %s", g.Op)
}
// Output.
switch g.Op {
case XOR:
l0 := a.L0
l0.Xor(b.L0)
l1 := l0
l1.Xor(r)
c = ot.Wire{
L0: l0,
L1: l1,
}
case XNOR:
l0 := a.L0
l0.Xor(b.L0)
l1 := l0
l1.Xor(r)
c = ot.Wire{
L0: l1,
L1: l0,
}
case AND:
pa := a.L0.S()
pb := b.L0.S()
j0 := *idp
j1 := *idp + 1
*idp = *idp + 2
// First half gate.
tg := encryptHalf(enc, a.L0, j0, data)
tg.Xor(encryptHalf(enc, a.L1, j0, data))
if pb {
tg.Xor(r)
}
wg0 := encryptHalf(enc, a.L0, j0, data)
if pa {
wg0.Xor(tg)
}
// Second half gate.
te := encryptHalf(enc, b.L0, j1, data)
te.Xor(encryptHalf(enc, b.L1, j1, data))
te.Xor(a.L0)
we0 := encryptHalf(enc, b.L0, j1, data)
if pb {
we0.Xor(te)
we0.Xor(a.L0)
}
// Combine halves
l0 := wg0
l0.Xor(we0)
l1 := l0
l1.Xor(r)
c = ot.Wire{
L0: l0,
L1: l1,
}
table[0] = tg
table[1] = te
count = 2
case OR, INV:
// Row reduction creates labels below so that the first row is
// all zero.
default:
panic("invalid gate type")
}
switch g.Op {
case XOR, XNOR:
// Free XOR.
case AND:
// Half AND garbled above.
case OR:
// a b c
// -----
// 0 0 0
// 0 1 1
// 1 0 1
// 1 1 1
id := *idp
*idp = *idp + 1
table[idx(a.L0, b.L0)] = encrypt(enc, a.L0, b.L0, c.L0, id, data)
table[idx(a.L0, b.L1)] = encrypt(enc, a.L0, b.L1, c.L1, id, data)
table[idx(a.L1, b.L0)] = encrypt(enc, a.L1, b.L0, c.L1, id, data)
table[idx(a.L1, b.L1)] = encrypt(enc, a.L1, b.L1, c.L1, id, data)
l0Index := idx(a.L0, b.L0)
c.L0 = table[0]
c.L1 = table[0]
if l0Index == 0 {
c.L1.Xor(r)
} else {
c.L0.Xor(r)
}
for i := 0; i < 4; i++ {
if i == l0Index {
table[i].Xor(c.L0)
} else {
table[i].Xor(c.L1)
}
}
start = 1
count = 3
case INV:
// a b c
// -----
// 0 1
// 1 0
id := *idp
*idp = *idp + 1
table[idxUnary(a.L0)] = encrypt(enc, a.L0, ot.Label{}, c.L1, id, data)
table[idxUnary(a.L1)] = encrypt(enc, a.L1, ot.Label{}, c.L0, id, data)
l0Index := idxUnary(a.L0)
c.L0 = table[0]
c.L1 = table[0]
if l0Index == 0 {
c.L0.Xor(r)
} else {
c.L1.Xor(r)
}
for i := 0; i < 2; i++ {
if i == l0Index {
table[i].Xor(c.L1)
} else {
table[i].Xor(c.L0)
}
}
start = 1
count = 1
default:
return nil, fmt.Errorf("invalid operand %s", g.Op)
}
wires[g.Output] = c
return table[start : start+count], nil
}

189
bedlam/circuit/garbler.go Normal file
View File

@ -0,0 +1,189 @@
//
// garbler.go
//
// Copyright (c) 2019-2023 Markku Rossi
//
// All rights reserved.
//
package circuit
import (
"crypto/rand"
"fmt"
"math/big"
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
"source.quilibrium.com/quilibrium/monorepo/bedlam/p2p"
)
// FileSize specifies a file (or data transfer) size in bytes.
type FileSize uint64
func (s FileSize) String() string {
if s > 1000*1000*1000*1000 {
return fmt.Sprintf("%dTB", s/(1000*1000*1000*1000))
} else if s > 1000*1000*1000 {
return fmt.Sprintf("%dGB", s/(1000*1000*1000))
} else if s > 1000*1000 {
return fmt.Sprintf("%dMB", s/(1000*1000))
} else if s > 1000 {
return fmt.Sprintf("%dkB", s/1000)
} else {
return fmt.Sprintf("%dB", s)
}
}
// Garbler runs the garbler on the P2P network.
func Garbler(conn *p2p.Conn, oti ot.OT, circ *Circuit, inputs *big.Int,
verbose bool) ([]*big.Int, error) {
timing := NewTiming()
if verbose {
fmt.Printf(" - Garbling...\n")
}
var key [32]byte
_, err := rand.Read(key[:])
if err != nil {
return nil, err
}
garbled, err := circ.Garble(key[:])
if err != nil {
return nil, err
}
timing.Sample("Garble", nil)
// Send program info.
if verbose {
fmt.Printf(" - Sending garbled circuit...\n")
}
if err := conn.SendData(key[:]); err != nil {
return nil, err
}
// Send garbled tables.
if err := conn.SendUint32(len(garbled.Gates)); err != nil {
return nil, err
}
var labelData ot.LabelData
for _, data := range garbled.Gates {
if err := conn.SendUint32(len(data)); err != nil {
return nil, err
}
for _, d := range data {
if err := conn.SendLabel(d, &labelData); err != nil {
return nil, err
}
}
}
if err := conn.Flush(); err != nil {
return nil, err
}
// Select our inputs.
var n1 []ot.Label
for i := 0; i < int(circ.Inputs[0].Type.Bits); i++ {
wire := garbled.Wires[i]
var n ot.Label
if inputs.Bit(i) == 1 {
n = wire.L1
} else {
n = wire.L0
}
n1 = append(n1, n)
}
// Send our inputs.
for idx, i := range n1 {
if verbose && false {
fmt.Printf("N1[%d]:\t%s\n", idx, i)
}
if err := conn.SendLabel(i, &labelData); err != nil {
return nil, err
}
}
ioStats := conn.Stats.Sum()
timing.Sample("Xfer", []string{FileSize(ioStats).String()})
if verbose {
fmt.Printf(" - Processing messages...\n")
}
if err := conn.Flush(); err != nil {
return nil, err
}
// Init oblivious transfer.
err = oti.InitSender(conn)
if err != nil {
return nil, err
}
xfer := conn.Stats.Sum() - ioStats
ioStats = conn.Stats.Sum()
timing.Sample("OT Init", []string{FileSize(xfer).String()})
// Peer OTs its inputs.
offset, err := conn.ReceiveUint32()
if err != nil {
return nil, err
}
count, err := conn.ReceiveUint32()
if err != nil {
return nil, err
}
if offset != int(circ.Inputs[0].Type.Bits) ||
count != int(circ.Inputs[1].Type.Bits) {
return nil, fmt.Errorf("peer can't OT wires [%d...%d[",
offset, offset+count)
}
err = oti.Send(garbled.Wires[offset : offset+count])
if err != nil {
return nil, err
}
xfer = conn.Stats.Sum() - ioStats
ioStats = conn.Stats.Sum()
timing.Sample("OT", []string{FileSize(xfer).String()})
// Resolve result values.
result := big.NewInt(0)
var label ot.Label
for i := 0; i < circ.Outputs.Size(); i++ {
err := conn.ReceiveLabel(&label, &labelData)
if err != nil {
return nil, err
}
if i == 0 {
timing.Sample("Eval", nil)
}
wire := garbled.Wires[circ.NumWires-circ.Outputs.Size()+i]
var bit uint
if label.Equal(wire.L0) {
bit = 0
} else if label.Equal(wire.L1) {
bit = 1
} else {
return nil, fmt.Errorf("unknown label %s for result %d", label, i)
}
result = big.NewInt(0).SetBit(result, i, bit)
}
data := result.Bytes()
if err := conn.SendData(data); err != nil {
return nil, err
}
if err := conn.Flush(); err != nil {
return nil, err
}
xfer = conn.Stats.Sum() - ioStats
timing.Sample("Result", []string{FileSize(xfer).String()})
if verbose {
timing.Print(conn.Stats)
}
return circ.Outputs.Split(result), nil
}

213
bedlam/circuit/ioarg.go Normal file
View File

@ -0,0 +1,213 @@
//
// Copyright (c) 2019-2024 Markku Rossi
//
// All rights reserved.
//
package circuit
import (
"fmt"
"math/big"
"strings"
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
)
// IO specifies circuit input and output arguments.
type IO []IOArg
// Size computes the size of the circuit input and output arguments in
// bits.
func (io IO) Size() int {
var sum int
for _, a := range io {
sum += int(a.Type.Bits)
}
return sum
}
func (io IO) String() string {
var str = ""
for i, a := range io {
if i > 0 {
str += ", "
}
if len(a.Name) > 0 {
str += a.Name + ":"
}
str += a.Type.String()
}
return str
}
// Split splits the value into separate I/O arguments.
func (io IO) Split(in *big.Int) []*big.Int {
var result []*big.Int
var bit int
for _, arg := range io {
r := big.NewInt(0)
for i := 0; i < int(arg.Type.Bits); i++ {
if in.Bit(bit) == 1 {
r = big.NewInt(0).SetBit(r, i, 1)
}
bit++
}
result = append(result, r)
}
return result
}
// IOArg describes circuit input argument.
type IOArg struct {
Name string
Type types.Info
Compound IO
}
func (io IOArg) String() string {
if len(io.Compound) > 0 {
return io.Compound.String()
}
if len(io.Name) > 0 {
return io.Name + ":" + io.Type.String()
}
return io.Type.String()
}
// Parse parses the I/O argument from the input string values.
func (io IOArg) Parse(inputs []string) (*big.Int, error) {
result := new(big.Int)
if len(io.Compound) == 0 {
if len(inputs) != 1 {
return nil,
fmt.Errorf("invalid amount of arguments, got %d, expected 1",
len(inputs))
}
switch io.Type.Type {
case types.TInt, types.TUint:
_, ok := result.SetString(inputs[0], 0)
if !ok {
return nil, fmt.Errorf("invalid input '%s' for %s",
inputs[0], io.Type)
}
case types.TBool:
switch inputs[0] {
case "0", "f", "false":
case "1", "t", "true":
result.SetInt64(1)
default:
return nil, fmt.Errorf("invalid bool constant: %s", inputs[0])
}
case types.TArray, types.TSlice:
count := int(io.Type.ArraySize)
elSize := int(io.Type.ElementType.Bits)
if io.Type.Type == types.TArray && count == 0 {
// Handle empty types.TArray arguments.
break
}
val := new(big.Int)
_, ok := val.SetString(inputs[0], 0)
if !ok {
return nil, fmt.Errorf("invalid input '%s' for %s",
inputs[0], io.Type)
}
var bitLen int
if strings.HasPrefix(inputs[0], "0x") {
bitLen = (len(inputs[0]) - 2) * 4
} else {
bitLen = val.BitLen()
}
valElCount := bitLen / elSize
if bitLen%elSize != 0 {
valElCount++
}
if io.Type.Type == types.TSlice {
// Set the count=valElCount for types.TSlice arguments.
count = valElCount
}
if valElCount > count {
return nil, fmt.Errorf("too many values for input: %s",
inputs[0])
}
pad := count - valElCount
val.Lsh(val, uint(pad*elSize))
mask := new(big.Int)
for i := 0; i < elSize; i++ {
mask.SetBit(mask, i, 1)
}
for i := 0; i < count; i++ {
next := new(big.Int).Rsh(val, uint((count-i-1)*elSize))
next = next.And(next, mask)
next.Lsh(next, uint(i*elSize))
result.Or(result, next)
}
default:
return nil, fmt.Errorf("unsupported input type: %s", io.Type)
}
return result, nil
}
if len(inputs) != len(io.Compound) {
return nil,
fmt.Errorf("invalid amount of arguments, got %d, expected %d",
len(inputs), len(io.Compound))
}
var offset int
for idx, arg := range io.Compound {
input, err := arg.Parse(inputs[idx : idx+1])
if err != nil {
return nil, err
}
input.Lsh(input, uint(offset))
result.Or(result, input)
offset += int(arg.Type.Bits)
}
return result, nil
}
// InputSizes computes the bit sizes of the input arguments. This is
// used for parametrized main() when the program is instantiated based
// on input sizes.
func InputSizes(inputs []string) ([]int, error) {
var result []int
for _, input := range inputs {
switch input {
case "_":
result = append(result, 0)
case "0", "f", "false", "1", "t", "true":
result = append(result, 1)
default:
if strings.HasPrefix(input, "0x") {
result = append(result, (len(input)-2)*4)
} else {
val := new(big.Int)
_, ok := val.SetString(input, 0)
if !ok {
return nil, fmt.Errorf("invalid input: %s", input)
}
result = append(result, val.BitLen())
}
}
}
return result, nil
}

View File

@ -0,0 +1,62 @@
//
// Copyright (c) 2023 Markku Rossi
//
// All rights reserved.
//
package circuit
import (
"testing"
)
var inputSizeTests = []struct {
inputs []string
sizes []int
}{
{
inputs: []string{
"0", "f", "false", "1", "t", "true",
},
sizes: []int{
1, 1, 1, 1, 1, 1,
},
},
{
inputs: []string{
"0xdeadbeef", "255",
},
sizes: []int{
32, 8,
},
},
{
inputs: []string{
"0x0", "0x00", "0x000", "0x0000",
},
sizes: []int{
4, 8, 12, 16,
},
},
}
func TestInputSizes(t *testing.T) {
for idx, test := range inputSizeTests {
sizes, err := InputSizes(test.inputs)
if err != nil {
t.Errorf("t%v: InputSizes(%v) failed: %v", idx, test.inputs, err)
continue
}
if len(sizes) != len(test.sizes) {
t.Errorf("t%v: unexpected # of sizes: got %v, expected %v",
idx, len(sizes), len(test.sizes))
continue
}
for i := 0; i < len(sizes); i++ {
if sizes[i] != test.sizes[i] {
t.Errorf("t%v: sizes[%v]=%v, expected %v",
idx, i, sizes[i], test.sizes[i])
}
}
}
}

141
bedlam/circuit/marshal.go Normal file
View File

@ -0,0 +1,141 @@
//
// Copyright (c) 2020-2021, 2023 Markku Rossi
//
// All rights reserved.
//
package circuit
import (
"encoding/binary"
"fmt"
"io"
)
const (
// MAGIC is a magic number for the QCL circuit format version 0.
MAGIC = 0x63726300 // crc0
)
var (
bo = binary.BigEndian
)
// MarshalFormat marshals circuit in the specified format.
func (c *Circuit) MarshalFormat(out io.Writer, format string) error {
switch format {
case "qclc":
return c.Marshal(out)
case "bristol":
return c.MarshalBristol(out)
default:
return fmt.Errorf("unsupported circuit format: %s", format)
}
}
// Marshal marshals circuit in the QCL circuit format.
func (c *Circuit) Marshal(out io.Writer) error {
var data = []interface{}{
uint32(MAGIC),
uint32(c.NumGates),
uint32(c.NumWires),
uint32(len(c.Inputs)),
uint32(len(c.Outputs)),
}
for _, v := range data {
if err := binary.Write(out, bo, v); err != nil {
return err
}
}
for _, input := range c.Inputs {
if err := marshalIOArg(out, input); err != nil {
return err
}
}
for _, output := range c.Outputs {
if err := marshalIOArg(out, output); err != nil {
return err
}
}
for _, g := range c.Gates {
switch g.Op {
case XOR, XNOR, AND, OR:
data = []interface{}{
byte(g.Op),
uint32(g.Input0), uint32(g.Input1), uint32(g.Output),
}
case INV:
data = []interface{}{
byte(g.Op),
uint32(g.Input0), uint32(g.Output),
}
default:
return fmt.Errorf("unsupported gate type %s", g.Op)
}
for _, v := range data {
if err := binary.Write(out, bo, v); err != nil {
return err
}
}
}
return nil
}
func marshalIOArg(out io.Writer, arg IOArg) error {
if err := marshalString(out, arg.Name); err != nil {
return err
}
if err := marshalString(out, arg.Type.String()); err != nil {
return err
}
if err := binary.Write(out, bo, uint32(arg.Type.Bits)); err != nil {
return err
}
if err := binary.Write(out, bo, uint32(len(arg.Compound))); err != nil {
return err
}
for _, c := range arg.Compound {
if err := marshalIOArg(out, c); err != nil {
return err
}
}
return nil
}
func marshalString(out io.Writer, val string) error {
bytes := []byte(val)
if err := binary.Write(out, bo, uint32(len(bytes))); err != nil {
return err
}
_, err := out.Write(bytes)
return err
}
// MarshalBristol marshals the circuit in the Bristol format.
func (c *Circuit) MarshalBristol(out io.Writer) error {
fmt.Fprintf(out, "%d %d\n", c.NumGates, c.NumWires)
fmt.Fprintf(out, "%d", len(c.Inputs))
for _, input := range c.Inputs {
fmt.Fprintf(out, " %d", input.Type.Bits)
}
fmt.Fprintln(out)
fmt.Fprintf(out, "%d", len(c.Outputs))
for _, ret := range c.Outputs {
fmt.Fprintf(out, " %d", ret.Type.Bits)
}
fmt.Fprintln(out)
fmt.Fprintln(out)
for _, g := range c.Gates {
fmt.Fprintf(out, "%d 1", len(g.Inputs()))
for _, w := range g.Inputs() {
fmt.Fprintf(out, " %d", w)
}
fmt.Fprintf(out, " %d", g.Output)
fmt.Fprintf(out, " %s\n", g.Op)
}
return nil
}

511
bedlam/circuit/parser.go Normal file
View File

@ -0,0 +1,511 @@
//
// Copyright (c) 2019-2023 Markku Rossi
//
// All rights reserved.
//
package circuit
import (
"bufio"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
"os"
"regexp"
"strconv"
"strings"
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
)
var reParts = regexp.MustCompilePOSIX("[[:space:]]+")
// Seen describes whether wire has been seen.
type Seen []bool
// Get gets the wire seen flag.
func (s Seen) Get(index Wire) (bool, error) {
if index >= Wire(len(s)) {
return false, fmt.Errorf("invalid wire %d [0...%d[", index, len(s))
}
return s[index], nil
}
// Set marks the wire seen.
func (s Seen) Set(index Wire) error {
if index >= Wire(len(s)) {
return fmt.Errorf("invalid wire %d [0...%d[", index, len(s))
}
s[index] = true
return nil
}
// IsFilename tests if the argument file is a potential circuit
// filename.
func IsFilename(file string) bool {
return strings.HasSuffix(file, ".circ") ||
strings.HasSuffix(file, ".bristol") ||
strings.HasSuffix(file, ".qclc")
}
// Parse parses the circuit file.
func Parse(file string) (*Circuit, error) {
f, err := os.Open(file)
if err != nil {
return nil, err
}
defer f.Close()
if strings.HasSuffix(file, ".circ") || strings.HasSuffix(file, ".bristol") {
return ParseBristol(f)
} else if strings.HasSuffix(file, ".qclc") {
return ParseQCLC(f)
}
return nil, fmt.Errorf("unsupported circuit format")
}
// ParseQCLC parses an QCL circuit file.
func ParseQCLC(in io.Reader) (*Circuit, error) {
r := bufio.NewReader(in)
var header struct {
Magic uint32
NumGates uint32
NumWires uint32
NumInputs uint32
NumOutputs uint32
}
if err := binary.Read(r, bo, &header); err != nil {
return nil, err
}
var inputs, outputs IO
var inputWires, outputWires int
wiresSeen := make(Seen, header.NumWires)
for i := 0; i < int(header.NumInputs); i++ {
arg, err := parseIOArg(r)
if err != nil {
return nil, err
}
inputs = append(inputs, arg)
inputWires += int(arg.Type.Bits)
}
for i := 0; i < int(header.NumOutputs); i++ {
out, err := parseIOArg(r)
if err != nil {
return nil, err
}
outputs = append(outputs, out)
outputWires += int(out.Type.Bits)
}
// Mark input wires seen.
for i := 0; i < inputWires; i++ {
if err := wiresSeen.Set(Wire(i)); err != nil {
return nil, err
}
}
gates := make([]Gate, header.NumGates)
var stats Stats
var gate int
for gate = 0; ; gate++ {
op, err := r.ReadByte()
if err != nil {
if err == io.EOF {
break
}
return nil, err
}
switch Operation(op) {
case XOR, XNOR, AND, OR:
var bin struct {
Input0 uint32
Input1 uint32
Output uint32
}
if err := binary.Read(r, bo, &bin); err != nil {
return nil, err
}
seen, err := wiresSeen.Get(Wire(bin.Input0))
if err != nil {
return nil, err
}
if !seen {
return nil, fmt.Errorf("input %d of gate %d not set",
bin.Input0, gate)
}
seen, err = wiresSeen.Get(Wire(bin.Input1))
if err != nil {
return nil, err
}
if !seen {
return nil, fmt.Errorf("input %d of gate %d not set",
bin.Input1, gate)
}
if err := wiresSeen.Set(Wire(bin.Output)); err != nil {
return nil, err
}
gates[gate] = Gate{
Input0: Wire(bin.Input0),
Input1: Wire(bin.Input1),
Output: Wire(bin.Output),
Op: Operation(op),
}
case INV:
var unary struct {
Input0 uint32
Output uint32
}
if err := binary.Read(r, bo, &unary); err != nil {
return nil, err
}
seen, err := wiresSeen.Get(Wire(unary.Input0))
if err != nil {
return nil, err
}
if !seen {
return nil, fmt.Errorf("input %d of gate %d not set",
unary.Input0, gate)
}
if err := wiresSeen.Set(Wire(unary.Output)); err != nil {
return nil, err
}
gates[gate] = Gate{
Input0: Wire(unary.Input0),
Output: Wire(unary.Output),
Op: Operation(op),
}
default:
return nil, fmt.Errorf("unsupported gate type %s", Operation(op))
}
stats[Operation(op)]++
}
if uint32(gate) != header.NumGates {
return nil, fmt.Errorf("not enough gates: got %d, expected %d",
gate, header.NumGates)
}
// Check that all wires are seen.
for i := 0; i < len(wiresSeen); i++ {
if !wiresSeen[i] {
return nil, fmt.Errorf("wire %d not assigned", i)
}
}
return &Circuit{
NumGates: int(header.NumGates),
NumWires: int(header.NumWires),
Inputs: inputs,
Outputs: outputs,
Gates: gates,
Stats: stats,
}, nil
}
func parseIOArg(r *bufio.Reader) (arg IOArg, err error) {
name, err := parseString(r)
if err != nil {
return arg, err
}
t, err := parseString(r)
if err != nil {
return arg, err
}
var ui32 uint32
if err := binary.Read(r, bo, &ui32); err != nil {
return arg, err
}
arg.Name = name
arg.Type, err = types.Parse(t)
if err != nil {
return arg, err
}
arg.Type.Bits = types.Size(ui32)
// Compound
if err := binary.Read(r, bo, &ui32); err != nil {
return arg, err
}
for i := 0; i < int(ui32); i++ {
c, err := parseIOArg(r)
if err != nil {
return arg, err
}
arg.Compound = append(arg.Compound, c)
}
return
}
func parseString(r *bufio.Reader) (string, error) {
var ui32 uint32
if err := binary.Read(r, bo, &ui32); err != nil {
return "", err
}
if ui32 == 0 {
return "", nil
}
buf := make([]byte, ui32)
_, err := r.Read(buf)
if err != nil {
return "", err
}
return string(buf), nil
}
// ParseBristol parses a Briston circuit file.
func ParseBristol(in io.Reader) (*Circuit, error) {
r := bufio.NewReader(in)
// NumGates NumWires
line, err := readLine(r)
if err != nil {
return nil, err
}
if len(line) != 2 {
return nil, fmt.Errorf("invalid 1st line: '%s'", line)
}
numGates, err := strconv.Atoi(line[0])
if err != nil {
return nil, err
}
if numGates < 0 || numGates > math.MaxInt32 {
return nil, fmt.Errorf("invalid numGates: %d", numGates)
}
numWires, err := strconv.Atoi(line[1])
if err != nil {
return nil, err
}
if numWires < 0 || numWires > math.MaxInt32 {
return nil, fmt.Errorf("invalid numWires: %d", numWires)
}
wiresSeen := make(Seen, numWires)
// Inputs
line, err = readLine(r)
if err != nil {
return nil, err
}
niv, err := strconv.Atoi(line[0])
if err != nil {
return nil, err
}
if 1+niv != len(line) {
return nil, fmt.Errorf("invalid inputs line: niv=%d, len=%d",
niv, len(line))
}
var inputs IO
var inputWires int64
for i := 1; i < len(line); i++ {
bits, err := strconv.ParseInt(line[i], 10, 32)
if err != nil {
return nil, fmt.Errorf("invalid input bits: %s", err)
}
if bits < 0 {
return nil, fmt.Errorf("invalid input bits: %d", bits)
}
inputs = append(inputs, IOArg{
Name: fmt.Sprintf("NI%d", i),
Type: types.Info{
Type: types.TUint,
IsConcrete: true,
Bits: types.Size(bits),
},
})
inputWires += bits
}
if inputWires == 0 {
return nil, fmt.Errorf("no inputs defined")
}
// Mark input wires seen.
for i := int64(0); i < inputWires; i++ {
if err := wiresSeen.Set(Wire(i)); err != nil {
return nil, err
}
}
// Outputs
line, err = readLine(r)
if err != nil {
return nil, err
}
nov, err := strconv.Atoi(line[0])
if err != nil {
return nil, err
}
if 1+nov != len(line) {
return nil, errors.New("invalid outputs line")
}
var outputs IO
for i := 1; i < len(line); i++ {
bits, err := strconv.ParseInt(line[i], 10, 32)
if err != nil {
return nil, fmt.Errorf("invalid output bits: %s", err)
}
if bits < 0 {
return nil, fmt.Errorf("invalid output bits: %d", bits)
}
outputs = append(outputs, IOArg{
Name: fmt.Sprintf("NO%d", i),
Type: types.Info{
Type: types.TUint,
IsConcrete: true,
Bits: types.Size(bits),
},
})
}
gates := make([]Gate, numGates)
var stats Stats
var gate int
for gate = 0; ; gate++ {
line, err = readLine(r)
if err != nil {
if err == io.EOF {
break
}
return nil, err
}
if gate >= numGates {
return nil, errors.New("too many gates")
}
if len(line) < 3 {
return nil, fmt.Errorf("invalid gate: %v", line)
}
n1, err := strconv.Atoi(line[0])
if err != nil {
return nil, err
}
if n1 < 0 {
return nil, fmt.Errorf("invalid n1: %v", n1)
}
n2, err := strconv.Atoi(line[1])
if err != nil {
return nil, err
}
if n2 < 0 {
return nil, fmt.Errorf("invalid n2: %v", n2)
}
if 2+n1+n2+1 != len(line) {
return nil, fmt.Errorf("invalid gate: %v", line)
}
var inputs []Wire
for i := 0; i < n1; i++ {
v, err := strconv.ParseUint(line[2+i], 10, 32)
if err != nil {
return nil, err
}
seen, err := wiresSeen.Get(Wire(v))
if err != nil {
return nil, err
}
if !seen {
return nil, fmt.Errorf("input %d of gate %d not set", v, gate)
}
inputs = append(inputs, Wire(v))
}
var outputs []Wire
for i := 0; i < n2; i++ {
v, err := strconv.ParseUint(line[2+n1+i], 10, 32)
if err != nil {
return nil, err
}
err = wiresSeen.Set(Wire(v))
if err != nil {
return nil, err
}
outputs = append(outputs, Wire(v))
}
var op Operation
var numInputs int
switch line[len(line)-1] {
case "XOR":
op = XOR
numInputs = 2
case "XNOR":
op = XNOR
numInputs = 2
case "AND":
op = AND
numInputs = 2
case "OR":
op = OR
numInputs = 2
case "INV":
op = INV
numInputs = 1
default:
return nil, fmt.Errorf("invalid operation '%s'", line[len(line)-1])
}
if len(inputs) != numInputs {
return nil, fmt.Errorf("invalid number of inputs %d for %s",
len(inputs), op)
}
if len(outputs) != 1 {
return nil, fmt.Errorf("invalid number of outputs %d for %s",
len(outputs), op)
}
var input1 Wire
if len(inputs) > 1 {
input1 = inputs[1]
}
gates[gate] = Gate{
Input0: inputs[0],
Input1: input1,
Output: outputs[0],
Op: op,
}
stats[op]++
}
if gate != numGates {
return nil, fmt.Errorf("not enough gates: got %d, expected %d",
gate, numGates)
}
// Check that all wires are seen.
for i := 0; i < len(wiresSeen); i++ {
if !wiresSeen[i] {
return nil, fmt.Errorf("wire %d not assigned", i)
}
}
return &Circuit{
NumGates: numGates,
NumWires: numWires,
Inputs: inputs,
Outputs: outputs,
Gates: gates,
Stats: stats,
}, nil
}
func readLine(r *bufio.Reader) ([]string, error) {
for {
line, err := r.ReadString('\n')
if err != nil {
return nil, err
}
line = strings.TrimSpace(line)
if len(line) == 0 {
continue
}
parts := reParts.Split(line, -1)
if len(parts) > 0 {
return parts, nil
}
}
}

View File

@ -0,0 +1,28 @@
//
// parser_test.go
//
// Copyright (c) 2019-2023 Markku Rossi
//
// All rights reserved.
//
package circuit
import (
"bytes"
"testing"
)
var data = `1 3
2 1 1
1 1
2 1 0 1 2 AND
`
func TestParse(t *testing.T) {
_, err := ParseBristol(bytes.NewReader([]byte(data)))
if err != nil {
t.Fatalf("Parse failed: %s", err)
}
}

View File

@ -0,0 +1,501 @@
//
// Copyright (c) 2020-2024 Markku Rossi
//
// All rights reserved.
//
package circuit
import (
"crypto/aes"
"crypto/cipher"
"fmt"
"math/big"
"time"
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
"source.quilibrium.com/quilibrium/monorepo/bedlam/p2p"
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
)
// Protocol operation codes.
const (
OpResult = iota
OpCircuit
OpReturn
)
// StreamEval is a streaming garbled circuit evaluator.
type StreamEval struct {
key []byte
alg cipher.Block
wires []ot.Label
tmp []ot.Label
}
// NewStreamEval creates a new streaming garbled circuit evaluator.
func NewStreamEval(key []byte, numInputs, numOutputs int) (*StreamEval, error) {
alg, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return &StreamEval{
key: key,
alg: alg,
wires: make([]ot.Label, numInputs+numOutputs),
}, nil
}
// Get gets the value of the wire.
func (stream *StreamEval) Get(tmp bool, w int) ot.Label {
if tmp {
return stream.tmp[w]
}
return stream.wires[w]
}
// GetInputs gets the specified input wire range.
func (stream *StreamEval) GetInputs(offset, count int) []ot.Label {
return stream.wires[offset : offset+count]
}
// Set sets the value of the wire.
func (stream *StreamEval) Set(tmp bool, w int, label ot.Label) {
if tmp {
stream.tmp[w] = label
} else {
stream.wires[w] = label
}
}
// InitCircuit initializes the stream evaluator with wires.
func (stream *StreamEval) InitCircuit(numWires, numTmpWires int) {
if numWires > len(stream.wires) {
var size int
for size = 1024; size < numWires; size *= 2 {
}
n := make([]ot.Label, size)
copy(n, stream.wires)
stream.wires = n
}
if numTmpWires > len(stream.tmp) {
var size int
for size = 1024; size < numTmpWires; size *= 2 {
}
stream.tmp = make([]ot.Label, size)
}
}
// StreamEvaluator runs the stream evaluator on the connection.
func StreamEvaluator(conn *p2p.Conn, oti ot.OT, inputFlag []string,
verbose bool) (IO, []*big.Int, error) {
timing := NewTiming()
// Receive program info.
if verbose {
fmt.Printf(" - Waiting for program info...\n")
}
key, err := conn.ReceiveData()
if err != nil {
return nil, nil, err
}
alg, err := aes.NewCipher(key)
if err != nil {
return nil, nil, err
}
// Peer input.
in1, err := receiveArgument(conn)
if err != nil {
return nil, nil, err
}
// Our input.
in2, err := receiveArgument(conn)
if err != nil {
return nil, nil, err
}
inputs, err := in2.Parse(inputFlag)
if err != nil {
return nil, nil, err
}
// Program outputs.
numOutputs, err := conn.ReceiveUint32()
if err != nil {
return nil, nil, err
}
var outputs IO
for i := 0; i < numOutputs; i++ {
out, err := receiveArgument(conn)
if err != nil {
return nil, nil, err
}
outputs = append(outputs, out)
}
numSteps, err := conn.ReceiveUint32()
if err != nil {
return nil, nil, err
}
fmt.Printf(" - In1: %s\n", in1)
fmt.Printf(" + In2: %s\n", in2)
fmt.Printf(" - Out: %s\n", outputs)
fmt.Printf(" - In: %s\n", inputFlag)
streaming, err := NewStreamEval(key, int(in1.Type.Bits+in2.Type.Bits),
outputs.Size())
if err != nil {
return nil, nil, err
}
// Receive peer inputs.
var label ot.Label
var labelData ot.LabelData
for w := 0; w < int(in1.Type.Bits); w++ {
err := conn.ReceiveLabel(&label, &labelData)
if err != nil {
return nil, nil, err
}
streaming.Set(false, w, label)
}
// Init oblivious transfer.
err = oti.InitReceiver(conn)
if err != nil {
return nil, nil, err
}
ioStats := conn.Stats.Sum()
timing.Sample("Init", []string{FileSize(ioStats).String()})
// Query our inputs.
if verbose {
fmt.Printf(" - Querying our inputs...\n")
}
flags := make([]bool, in2.Type.Bits)
for i := 0; i < int(in2.Type.Bits); i++ {
if inputs.Bit(i) == 1 {
flags[i] = true
}
}
inputLabels := streaming.GetInputs(int(in1.Type.Bits), int(in2.Type.Bits))
if err := oti.Receive(flags, inputLabels); err != nil {
return nil, nil, err
}
xfer := conn.Stats.Sum() - ioStats
ioStats = conn.Stats.Sum()
timing.Sample("Inputs", []string{FileSize(xfer).String()})
ws := func(i int, tmp bool) string {
if tmp {
return fmt.Sprintf("~%d", i)
}
return fmt.Sprintf("w%d", i)
}
// Evaluate program.
if verbose {
fmt.Printf(" - Evaluating program...\n")
}
var garbled [4]ot.Label
var lastStep int
var rawResult *big.Int
start := time.Now()
lastReport := start
loop:
for {
op, err := conn.ReceiveUint32()
if err != nil {
return nil, nil, err
}
switch op {
case OpCircuit:
step, err := conn.ReceiveUint32()
if err != nil {
return nil, nil, err
}
numGates, err := conn.ReceiveUint32()
if err != nil {
return nil, nil, err
}
numTmpWires, err := conn.ReceiveUint32()
if err != nil {
return nil, nil, err
}
numWires, err := conn.ReceiveUint32()
if err != nil {
return nil, nil, err
}
if step-lastStep >= 10 && verbose {
lastStep = step
now := time.Now()
if now.Sub(lastReport) > time.Second*5 {
lastReport = now
elapsed := time.Now().Sub(start)
done := float64(step) / float64(numSteps)
if done > 0 {
total := time.Duration(float64(elapsed) / done)
progress := fmt.Sprintf("%d/%d", step, numSteps)
remaining := fmt.Sprintf("%24s", total-elapsed)
fmt.Printf("%-14s\t%s remaining\tETA %s\n",
progress, remaining,
start.Add(total).Format(time.Stamp))
} else {
fmt.Printf("%d/%d\n", step, numSteps)
}
}
}
streaming.InitCircuit(numWires, numTmpWires)
var id uint32
for i := 0; i < numGates; i++ {
gop, err := conn.ReceiveByte()
if err != nil {
return nil, nil, err
}
var aTmp, bTmp, cTmp bool
if gop&0b10000000 != 0 {
aTmp = true
}
if gop&0b01000000 != 0 {
bTmp = true
}
if gop&0b00100000 != 0 {
cTmp = true
}
var recvWire func() (int, error)
if gop&0b00010000 != 0 {
recvWire = conn.ReceiveUint16
} else {
recvWire = conn.ReceiveUint32
}
gop &^= 0b11110000
var aIndex, bIndex, cIndex int
var tableCount int
switch Operation(gop) {
case XOR, XNOR, AND, OR:
aIndex, err = recvWire()
if err != nil {
return nil, nil, err
}
bIndex, err = recvWire()
if err != nil {
return nil, nil, err
}
cIndex, err = recvWire()
if err != nil {
return nil, nil, err
}
case INV:
aIndex, err = recvWire()
if err != nil {
return nil, nil, err
}
cIndex, err = recvWire()
if err != nil {
return nil, nil, err
}
default:
return nil, nil, fmt.Errorf("invalid operation %s",
Operation(gop))
}
switch Operation(gop) {
case XOR, XNOR:
tableCount = 0
case INV:
tableCount = 1
case AND:
tableCount = 2
case OR:
tableCount = 3
}
for c := 0; c < tableCount; c++ {
err = conn.ReceiveLabel(&label, &labelData)
if err != nil {
return nil, nil, err
}
garbled[c] = label
}
var a, b, c ot.Label
switch Operation(gop) {
case XOR, XNOR, AND, OR:
if StreamDebug {
fmt.Printf("Gate%d:\t %s %s %s %s\n", i,
ws(aIndex, aTmp), ws(bIndex, bTmp),
Operation(gop), ws(cIndex, cTmp))
}
a = streaming.Get(aTmp, aIndex)
b = streaming.Get(bTmp, bIndex)
case INV:
if StreamDebug {
fmt.Printf("Gate%d:\t %s %s %s\n", i,
ws(aIndex, aTmp), Operation(gop), ws(bIndex, bTmp))
}
a = streaming.Get(aTmp, aIndex)
}
var output ot.Label
switch Operation(gop) {
case XOR, XNOR:
a.Xor(b)
output = a
case AND:
if tableCount != 2 {
return nil, nil,
fmt.Errorf("corrupted ciruit: AND table size: %d",
tableCount)
}
sa := a.S()
sb := b.S()
j0 := id
j1 := id + 1
id += 2
tg := garbled[0]
te := garbled[1]
wg := encryptHalf(alg, a, j0, &labelData)
if sa {
wg.Xor(tg)
}
we := encryptHalf(alg, b, j1, &labelData)
if sb {
we.Xor(te)
we.Xor(a)
}
output = wg
output.Xor(we)
case OR:
index := idx(a, b)
if index > 0 {
// First row is zero and not transmitted.
index--
if index >= tableCount {
return nil, nil,
fmt.Errorf("corrupted circuit: index %d >= %d",
index, tableCount)
}
c = garbled[index]
}
output = decrypt(alg, a, b, id, c, &labelData)
id++
case INV:
index := idxUnary(a)
if index > 0 {
// First row is zero and not transmitted.
index--
if index >= tableCount {
return nil, nil,
fmt.Errorf("corrupted circuit: index %d >= %d",
index, tableCount)
}
c = garbled[index]
}
output = decrypt(alg, a, b, id, c, &labelData)
id++
}
streaming.Set(cTmp, cIndex, output)
}
case OpReturn:
xfer := conn.Stats.Sum() - ioStats
ioStats = conn.Stats.Sum()
timing.Sample("Eval", []string{FileSize(xfer).String()})
var labels []ot.Label
for i := 0; i < outputs.Size(); i++ {
id, err := conn.ReceiveUint32()
if err != nil {
return nil, nil, err
}
label := streaming.Get(false, id)
labels = append(labels, label)
}
// Resolve result values.
if err := conn.SendUint32(OpResult); err != nil {
return nil, nil, err
}
var labelData ot.LabelData
for _, l := range labels {
if err := conn.SendLabel(l, &labelData); err != nil {
return nil, nil, err
}
}
if err := conn.Flush(); err != nil {
return nil, nil, err
}
result, err := conn.ReceiveData()
if err != nil {
return nil, nil, err
}
rawResult = new(big.Int).SetBytes(result)
break loop
default:
return nil, nil, fmt.Errorf("unknown operation %d", op)
}
}
xfer = conn.Stats.Sum() - ioStats
timing.Sample("Result", []string{FileSize(xfer).String()})
if verbose {
timing.Print(conn.Stats)
}
return outputs, outputs.Split(rawResult), nil
}
func receiveArgument(conn *p2p.Conn) (arg IOArg, err error) {
name, err := conn.ReceiveString()
if err != nil {
return arg, err
}
t, err := conn.ReceiveString()
if err != nil {
return arg, err
}
size, err := conn.ReceiveUint32()
if err != nil {
return arg, err
}
arg.Name = name
arg.Type, err = types.Parse(t)
if err != nil {
return arg, err
}
arg.Type.Bits = types.Size(size)
if arg.Type.Type == types.TSlice {
arg.Type.ArraySize = arg.Type.Bits / arg.Type.ElementType.Bits
}
count, err := conn.ReceiveUint32()
if err != nil {
return arg, err
}
for i := 0; i < count; i++ {
a, err := receiveArgument(conn)
if err != nil {
return arg, err
}
arg.Compound = append(arg.Compound, a)
}
return arg, nil
}

View File

@ -0,0 +1,440 @@
//
// Copyright (c) 2020-2021, 2023 Markku Rossi
//
// All rights reserved.
//
package circuit
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"fmt"
"time"
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
"source.quilibrium.com/quilibrium/monorepo/bedlam/p2p"
)
const (
// StreamDebug controls the debugging output of the streaming
// garbling.
StreamDebug = false
)
// Streaming is a streaming garbled circuit garbler.
type Streaming struct {
conn *p2p.Conn
key []byte
alg cipher.Block
r ot.Label
wires []ot.Wire
tmp []ot.Wire
in []Wire
out []Wire
firstTmp Wire
firstOut Wire
}
// NewStreaming creates a new streaming garbled circuit garbler.
func NewStreaming(key []byte, inputs []Wire, conn *p2p.Conn) (
*Streaming, error) {
r, err := ot.NewLabel(rand.Reader)
if err != nil {
return nil, err
}
r.SetS(true)
alg, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
stream := &Streaming{
conn: conn,
key: key,
alg: alg,
r: r,
}
stream.ensureWires(maxWire(0, inputs))
// Assing all input wires.
for i := 0; i < len(inputs); i++ {
w, err := makeLabels(stream.r)
if err != nil {
return nil, err
}
stream.wires[inputs[i]] = w
}
return stream, nil
}
func maxWire(max Wire, wires []Wire) Wire {
for _, w := range wires {
if w > max {
max = w
}
}
return max
}
func (stream *Streaming) ensureWires(max Wire) {
// Verify that wires is big enough.
if len(stream.wires) <= int(max) {
var i int
for i = 65536; i <= int(max); i <<= 1 {
}
n := make([]ot.Wire, i)
copy(n, stream.wires)
stream.wires = n
}
}
func (stream *Streaming) initCircuit(c *Circuit, in, out []Wire) {
stream.ensureWires(maxWire(maxWire(0, in), out))
if len(stream.tmp) < c.NumWires {
stream.tmp = make([]ot.Wire, c.NumWires)
}
stream.in = in
stream.out = out
stream.firstTmp = Wire(len(in))
stream.firstOut = Wire(c.NumWires - len(out))
}
// GetInput gets the value of the input wire.
func (stream *Streaming) GetInput(w Wire) ot.Wire {
return stream.wires[w]
}
// GetInputs gets the specified input wire range.
func (stream *Streaming) GetInputs(offset, count int) []ot.Wire {
return stream.wires[offset : offset+count]
}
// Get gets the value of the wire.
func (stream *Streaming) Get(w Wire) (ot.Wire, Wire, bool) {
if w < stream.firstTmp {
index := stream.in[w]
return stream.wires[index], index, false
} else if w >= stream.firstOut {
index := stream.out[w-stream.firstOut]
return stream.wires[index], index, false
} else {
return stream.tmp[w], w, true
}
}
// Set sets the value of the wire.
func (stream *Streaming) Set(w Wire, val ot.Wire) (index Wire, tmp bool) {
if w < stream.firstTmp {
index = stream.in[w]
stream.wires[index] = val
} else if w >= stream.firstOut {
index = stream.out[w-stream.firstOut]
stream.wires[index] = val
} else {
index = w
tmp = true
stream.tmp[w] = val
}
return index, tmp
}
// Garble garbles the circuit and streams the garbled tables into the
// stream.
func (stream *Streaming) Garble(c *Circuit, in, out []Wire) (
time.Duration, time.Duration, error) {
if StreamDebug {
fmt.Printf(" - Streaming.Garble: in=%v, out=%v\n", in, out)
}
start := time.Now()
stream.initCircuit(c, in, out)
// Garble gates.
var data ot.LabelData
var id uint32
var table [4]ot.Label
mid := time.Now()
for i := 0; i < len(c.Gates); i++ {
gate := &c.Gates[i]
err := stream.conn.NeedSpace(512)
if err != nil {
return 0, 0, err
}
err = stream.garbleGate(gate, &id, table[:], &data,
stream.conn.WriteBuf, &stream.conn.WritePos)
if err != nil {
return 0, 0, err
}
}
return mid.Sub(start), time.Now().Sub(mid), nil
}
// GarbleGate garbles the gate and streams it to the stream.
func (stream *Streaming) garbleGate(g *Gate, idp *uint32,
table []ot.Label, data *ot.LabelData, buf []byte, bufpos *int) error {
var a, b, c ot.Wire
var aIndex, bIndex, cIndex Wire
var aTmp, bTmp, cTmp bool
table = table[0:4]
var tableStart, tableCount, wireCount int
// Inputs.
switch g.Op {
case XOR, XNOR, AND, OR:
b, bIndex, bTmp = stream.Get(g.Input1)
fallthrough
case INV:
a, aIndex, aTmp = stream.Get(g.Input0)
default:
return fmt.Errorf("invalid gate type %s", g.Op)
}
// Output.
switch g.Op {
case XOR:
l0 := a.L0
l0.Xor(b.L0)
l1 := l0
l1.Xor(stream.r)
c = ot.Wire{
L0: l0,
L1: l1,
}
case XNOR:
l0 := a.L0
l0.Xor(b.L0)
l1 := l0
l1.Xor(stream.r)
c = ot.Wire{
L0: l1,
L1: l0,
}
case AND:
pa := a.L0.S()
pb := b.L0.S()
j0 := *idp
j1 := *idp + 1
*idp = *idp + 2
// First half gate.
tg := encryptHalf(stream.alg, a.L0, j0, data)
tg.Xor(encryptHalf(stream.alg, a.L1, j0, data))
if pb {
tg.Xor(stream.r)
}
wg0 := encryptHalf(stream.alg, a.L0, j0, data)
if pa {
wg0.Xor(tg)
}
// Second half gate.
te := encryptHalf(stream.alg, b.L0, j1, data)
te.Xor(encryptHalf(stream.alg, b.L1, j1, data))
te.Xor(a.L0)
we0 := encryptHalf(stream.alg, b.L0, j1, data)
if pb {
we0.Xor(te)
we0.Xor(a.L0)
}
// Combine halves
l0 := wg0
l0.Xor(we0)
l1 := l0
l1.Xor(stream.r)
c = ot.Wire{
L0: l0,
L1: l1,
}
table[0] = tg
table[1] = te
tableCount = 2
case OR, INV:
// Row reduction creates labels below so that the first row is
// all zero.
default:
panic("invalid gate type")
}
switch g.Op {
case XOR, XNOR:
// Free XOR.
wireCount = 3
case AND:
// Half AND garbled above.
wireCount = 3
case OR:
// a b c
// -----
// 0 0 0
// 0 1 1
// 1 0 1
// 1 1 1
id := *idp
*idp = *idp + 1
table[idx(a.L0, b.L0)] = encrypt(stream.alg, a.L0, b.L0, c.L0, id, data)
table[idx(a.L0, b.L1)] = encrypt(stream.alg, a.L0, b.L1, c.L1, id, data)
table[idx(a.L1, b.L0)] = encrypt(stream.alg, a.L1, b.L0, c.L1, id, data)
table[idx(a.L1, b.L1)] = encrypt(stream.alg, a.L1, b.L1, c.L1, id, data)
// Row reduction. Make first table all zero so we don't have
// to transmit it.
l0Index := idx(a.L0, b.L0)
c.L0 = table[0]
c.L1 = table[0]
if l0Index == 0 {
c.L1.Xor(stream.r)
} else {
c.L0.Xor(stream.r)
}
for i := 0; i < 4; i++ {
if i == l0Index {
table[i].Xor(c.L0)
} else {
table[i].Xor(c.L1)
}
}
tableStart = 1
tableCount = 3
wireCount = 3
case INV:
// a b c
// -----
// 0 1
// 1 0
zero := ot.Label{}
id := *idp
*idp = *idp + 1
table[idxUnary(a.L0)] = encrypt(stream.alg, a.L0, zero, c.L1, id, data)
table[idxUnary(a.L1)] = encrypt(stream.alg, a.L1, zero, c.L0, id, data)
l0Index := idxUnary(a.L0)
c.L0 = table[0]
c.L1 = table[0]
if l0Index == 0 {
c.L0.Xor(stream.r)
} else {
c.L1.Xor(stream.r)
}
for i := 0; i < 2; i++ {
if i == l0Index {
table[i].Xor(c.L1)
} else {
table[i].Xor(c.L0)
}
}
tableStart = 1
tableCount = 1
wireCount = 2
default:
return fmt.Errorf("invalid operand %s", g.Op)
}
if g.Output < stream.firstTmp {
cIndex = stream.in[g.Output]
stream.wires[cIndex] = c
} else if g.Output >= stream.firstOut {
cIndex = stream.out[g.Output-stream.firstOut]
stream.wires[cIndex] = c
} else {
cIndex = g.Output
cTmp = true
stream.tmp[g.Output] = c
}
op := byte(g.Op)
if aTmp {
op |= 0b10000000
}
if bTmp {
op |= 0b01000000
}
if cTmp {
op |= 0b00100000
}
if aIndex <= 0xffff && bIndex <= 0xffff && cIndex <= 0xffff {
op |= 0b00010000
buf[*bufpos] = op
*bufpos = *bufpos + 1
switch wireCount {
case 3:
bo.PutUint16(buf[*bufpos+0:], uint16(aIndex))
bo.PutUint16(buf[*bufpos+2:], uint16(bIndex))
bo.PutUint16(buf[*bufpos+4:], uint16(cIndex))
*bufpos = *bufpos + 6
case 2:
bo.PutUint16(buf[*bufpos+0:], uint16(aIndex))
bo.PutUint16(buf[*bufpos+2:], uint16(cIndex))
*bufpos = *bufpos + 4
default:
panic(fmt.Sprintf("invalid wire count: %d", wireCount))
}
} else {
buf[*bufpos] = op
*bufpos = *bufpos + 1
switch wireCount {
case 3:
bo.PutUint32(buf[*bufpos+0:], uint32(aIndex))
bo.PutUint32(buf[*bufpos+4:], uint32(bIndex))
bo.PutUint32(buf[*bufpos+8:], uint32(cIndex))
*bufpos = *bufpos + 12
case 2:
bo.PutUint32(buf[*bufpos+0:], uint32(aIndex))
bo.PutUint32(buf[*bufpos+4:], uint32(cIndex))
*bufpos = *bufpos + 8
default:
panic(fmt.Sprintf("invalid wire count: %d", wireCount))
}
}
for i := 0; i < tableCount; i++ {
bytes := table[tableStart+i].Bytes(data)
copy(buf[*bufpos:], bytes)
*bufpos = *bufpos + len(bytes)
}
return nil
}

View File

@ -0,0 +1,356 @@
//
// Copyright (c) 2020-2023 Markku Rossi
//
// All rights reserved.
//
package circuit
import (
"testing"
"time"
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
)
func BenchmarkGarbleXOR(b *testing.B) {
benchmarkGate(b, newGate(XOR))
}
func BenchmarkGarbleXNOR(b *testing.B) {
benchmarkGate(b, newGate(XNOR))
}
func BenchmarkGarbleAND(b *testing.B) {
benchmarkGate(b, newGate(AND))
}
func BenchmarkGarbleOR(b *testing.B) {
benchmarkGate(b, newGate(OR))
}
func BenchmarkGarbleINV(b *testing.B) {
benchmarkGate(b, newGate(INV))
}
func newGate(op Operation) *Gate {
return &Gate{
Input0: 0,
Input1: 1,
Output: 2,
Op: op,
}
}
func benchmarkGate(b *testing.B, g *Gate) {
var key [16]byte
inputs := []Wire{0, 1}
outputs := []Wire{2}
stream, err := NewStreaming(key[:], inputs, nil)
if err != nil {
b.Fatalf("failed to init streaming: %s", err)
}
stream.wires = []ot.Wire{{}, {}, {}}
stream.in = inputs
stream.out = outputs
stream.firstTmp = 2
stream.firstOut = 2
var id uint32
var data ot.LabelData
var table [4]ot.Label
b.ResetTimer()
for i := 0; i < b.N; i++ {
var buf [128]byte
var bufpos int
err = stream.garbleGate(g, &id, table[:], &data, buf[:], &bufpos)
if err != nil {
b.Fatalf("garble failed: %s", err)
}
}
}
func BenchmarkEncodeWire32(b *testing.B) {
var buf [64]byte
for i := 0; i < b.N; i++ {
var pos int
bufpos := &pos
var aIndex = Wire(i)
var bIndex = Wire(i * 2)
var cIndex = Wire(i * 4)
var op byte
buf[*bufpos] = op
*bufpos = *bufpos + 1
bo.PutUint32(buf[*bufpos+0:], uint32(aIndex))
bo.PutUint32(buf[*bufpos+4:], uint32(bIndex))
bo.PutUint32(buf[*bufpos+8:], uint32(cIndex))
*bufpos = *bufpos + 12
}
}
func BenchmarkEncodeWire16_32(b *testing.B) {
var buf [64]byte
for i := 0; i < b.N; i++ {
var pos int
bufpos := &pos
var aIndex = Wire(i)
var bIndex = Wire(i * 2)
var cIndex = Wire(i * 4)
var op byte
if aIndex <= 0xffff && bIndex <= 0xffff && cIndex <= 0xffff {
op |= 0b00010000
buf[*bufpos] = op
*bufpos = *bufpos + 1
bo.PutUint16(buf[*bufpos+0:], uint16(aIndex))
bo.PutUint16(buf[*bufpos+2:], uint16(bIndex))
bo.PutUint16(buf[*bufpos+4:], uint16(cIndex))
*bufpos = *bufpos + 6
} else {
buf[*bufpos] = op
*bufpos = *bufpos + 1
bo.PutUint32(buf[*bufpos+0:], uint32(aIndex))
bo.PutUint32(buf[*bufpos+4:], uint32(bIndex))
bo.PutUint32(buf[*bufpos+8:], uint32(cIndex))
*bufpos = *bufpos + 12
}
}
}
func BenchmarkEncodeWire8_16_24_32(b *testing.B) {
var buf [64]byte
for i := 0; i < b.N; i++ {
var pos int
bufpos := &pos
var aIndex = Wire(i)
var bIndex = Wire(i * 2)
var cIndex = Wire(i * 4)
var op byte
if aIndex <= 0xff && bIndex <= 0xff && cIndex <= 0xff {
buf[*bufpos] = op
*bufpos = *bufpos + 1
buf[*bufpos+0] = byte(aIndex)
buf[*bufpos+1] = byte(bIndex)
buf[*bufpos+2] = byte(cIndex)
*bufpos = *bufpos + 3
} else if aIndex <= 0xffff && bIndex <= 0xffff && cIndex <= 0xffff {
op |= 0b00010000
buf[*bufpos] = op
*bufpos = *bufpos + 1
bo.PutUint16(buf[*bufpos+0:], uint16(aIndex))
bo.PutUint16(buf[*bufpos+2:], uint16(bIndex))
bo.PutUint16(buf[*bufpos+4:], uint16(cIndex))
*bufpos = *bufpos + 6
} else if aIndex <= 0xffffff && bIndex <= 0xffffff && cIndex <= 0xffffff {
op |= 0b00100000
buf[*bufpos] = op
*bufpos = *bufpos + 1
PutUint24(buf[*bufpos+0:], uint32(aIndex))
PutUint24(buf[*bufpos+3:], uint32(bIndex))
PutUint24(buf[*bufpos+6:], uint32(cIndex))
*bufpos = *bufpos + 9
} else {
op |= 0b00110000
buf[*bufpos] = op
*bufpos = *bufpos + 1
bo.PutUint32(buf[*bufpos+0:], uint32(aIndex))
bo.PutUint32(buf[*bufpos+4:], uint32(bIndex))
bo.PutUint32(buf[*bufpos+8:], uint32(cIndex))
*bufpos = *bufpos + 12
}
}
}
func PutUint24(b []byte, v uint32) {
b[0] = byte(v >> 16)
b[1] = byte(v >> 8)
b[2] = byte(v)
}
func BenchmarkEncodeWire7var(b *testing.B) {
var buf [64]byte
for i := 0; i < b.N; i++ {
var pos int
bufpos := &pos
var aIndex = Wire(i)
var bIndex = Wire(i * 2)
var cIndex = Wire(i * 4)
var op byte
buf[*bufpos] = op
*bufpos = *bufpos + 1
*bufpos += encode7varRev(buf[*bufpos:], uint32(aIndex))
*bufpos += encode7varRev(buf[*bufpos:], uint32(bIndex))
*bufpos += encode7varRev(buf[*bufpos:], uint32(cIndex))
}
}
func encode7var(b []byte, v uint32) int {
if v <= 0b01111111 {
b[0] = byte(v)
return 1
} else if v <= 0b01111111_1111111 {
b[0] = byte(v >> 7)
b[1] = byte(v)
return 2
} else if v <= 0b01111111_1111111_1111111 {
b[0] = byte(v >> 14)
b[1] = byte(v >> 7)
b[2] = byte(v)
return 3
} else if v <= 0b01111111_1111111_1111111_1111111 {
b[0] = byte(v >> 21)
b[1] = byte(v >> 14)
b[2] = byte(v >> 7)
b[3] = byte(v)
return 4
} else {
b[0] = byte(v >> 28)
b[1] = byte(v >> 21)
b[2] = byte(v >> 14)
b[3] = byte(v >> 7)
b[4] = byte(v)
return 5
}
}
func encode7varRev(b []byte, v uint32) int {
if v > 0b01111111_1111111_1111111_1111111 {
b[0] = byte(v >> 28)
b[1] = byte(v >> 21)
b[2] = byte(v >> 14)
b[3] = byte(v >> 7)
b[4] = byte(v)
return 5
} else if v > 0b01111111_1111111_1111111 {
b[0] = byte(v >> 21)
b[1] = byte(v >> 14)
b[2] = byte(v >> 7)
b[3] = byte(v)
return 4
} else if v > 0b01111111_1111111 {
b[0] = byte(v >> 14)
b[1] = byte(v >> 7)
b[2] = byte(v)
return 3
} else if v > 0b01111111 {
b[0] = byte(v >> 7)
b[1] = byte(v)
return 2
} else {
b[0] = byte(v)
return 1
}
}
func BenchmarkEncodeWire7varInline(b *testing.B) {
var buf [64]byte
for i := 0; i < b.N; i++ {
var pos int
bufpos := &pos
var aIndex = Wire(i)
var bIndex = Wire(i * 2)
var cIndex = Wire(i * 4)
var op byte
buf[*bufpos] = op
*bufpos = *bufpos + 1
if aIndex <= 0b01111111 &&
bIndex <= 0b01111111 &&
cIndex <= 0b01111111 {
encode7var1(buf[*bufpos+0:], uint32(aIndex))
encode7var1(buf[*bufpos+1:], uint32(bIndex))
encode7var1(buf[*bufpos+2:], uint32(cIndex))
*bufpos += 3
} else if aIndex <= 0b01111111_1111111 &&
bIndex <= 0b01111111_1111111 &&
cIndex <= 0b01111111_1111111 {
encode7var2(buf[*bufpos+0:], uint32(aIndex))
encode7var2(buf[*bufpos+2:], uint32(bIndex))
encode7var2(buf[*bufpos+4:], uint32(cIndex))
*bufpos += 6
} else if aIndex <= 0b01111111_1111111_1111111 &&
bIndex <= 0b01111111_1111111_1111111 &&
cIndex <= 0b01111111_1111111_1111111 {
encode7var3(buf[*bufpos+0:], uint32(aIndex))
encode7var3(buf[*bufpos+3:], uint32(bIndex))
encode7var3(buf[*bufpos+6:], uint32(cIndex))
*bufpos += 9
} else if aIndex <= 0b01111111_1111111_1111111_1111111 &&
bIndex <= 0b01111111_1111111_1111111_1111111 &&
cIndex <= 0b01111111_1111111_1111111_1111111 {
encode7var4(buf[*bufpos+0:], uint32(aIndex))
encode7var4(buf[*bufpos+4:], uint32(bIndex))
encode7var4(buf[*bufpos+8:], uint32(cIndex))
*bufpos += 12
} else {
encode7var5(buf[*bufpos+0:], uint32(aIndex))
encode7var5(buf[*bufpos+5:], uint32(bIndex))
encode7var5(buf[*bufpos+10:], uint32(cIndex))
*bufpos += 15
}
}
}
func encode7var1(b []byte, v uint32) {
b[0] = byte(v)
}
func encode7var2(b []byte, v uint32) {
b[0] = byte(v >> 7)
b[1] = byte(v)
}
func encode7var3(b []byte, v uint32) {
b[0] = byte(v >> 14)
b[1] = byte(v >> 7)
b[2] = byte(v)
}
func encode7var4(b []byte, v uint32) {
b[0] = byte(v >> 21)
b[1] = byte(v >> 14)
b[2] = byte(v >> 7)
b[3] = byte(v)
}
func encode7var5(b []byte, v uint32) {
b[0] = byte(v >> 28)
b[1] = byte(v >> 21)
b[2] = byte(v >> 14)
b[3] = byte(v >> 7)
b[4] = byte(v)
}
func BenchmarkTimeDuration(b *testing.B) {
var total time.Duration
for i := 0; i < b.N; i++ {
start := time.Now()
total += time.Now().Sub(start)
}
}

511
bedlam/circuit/svg.go Normal file
View File

@ -0,0 +1,511 @@
//
// Copyright (c) 2019-2022 Markku Rossi
//
// All rights reserved.
//
package circuit
import (
"fmt"
"io"
"sort"
)
const (
ioWidth = 32
ioHeight = 32
ioPadX = 16
ioPadY = 32
gateWidth = 32
gateHeight = 32
gatePadX = 16
gatePadY = 64
)
type tile struct {
gate *Gate
avg float64
x float64
y float64
}
type point struct {
x, y float64
}
type wireType int
const (
wireTypeNormal wireType = iota
wireTypeZero
wireTypeOne
)
type wire struct {
t wireType
from point
to point
}
func (w *wire) svg(out io.Writer) {
label := "1"
switch w.t {
case wireTypeNormal:
midY := w.from.y + (w.to.y-w.from.y)/2 - 5
fmt.Fprintf(out, ` <path d="M %v %v
v %v
C %v %v %v %v %v %v
v %v" />
`,
w.from.x, w.from.y,
midY-w.from.y,
w.from.x, midY+10,
w.to.x, midY,
w.to.x, midY+10,
w.to.y-midY-10,
)
case wireTypeZero:
label = "0"
fallthrough
case wireTypeOne:
fmt.Fprintf(out, ` <g fill="#000">
<text x="%v" y="%v" text-anchor="middle">%v</text>
</g>
`,
w.to.x, w.to.y-2, label)
}
}
type svgCtx struct {
wireStarts []point
zero Wire
one Wire
}
func (ctx *svgCtx) setWireType(input Wire, w *wire) {
if input == ctx.zero {
w.t = wireTypeZero
} else if input == ctx.one {
w.t = wireTypeOne
}
}
func (ctx *svgCtx) tileAvgInputX(t *tile) {
var count float64
switch t.gate.Op {
case XOR, XNOR, AND, OR:
if t.gate.Input1 != ctx.zero && t.gate.Input1 != ctx.one {
count++
t.avg += ctx.wireStarts[t.gate.Input1].x
}
fallthrough
case INV:
if count == 0 ||
(t.gate.Input0 != ctx.zero && t.gate.Input0 != ctx.one) {
count++
t.avg += ctx.wireStarts[t.gate.Input0].x
}
}
t.avg /= count
t.avg -= (gateWidth) / 2
}
// Svg creates an SVG output of the circuit.
func (c *Circuit) Svg(out io.Writer) {
c.AssignLevels()
cols := c.Stats[MaxWidth]
rows := c.Stats[NumLevels]
fmt.Printf("")
// Header.
ctx := &svgCtx{
wireStarts: make([]point, c.NumWires),
zero: InvalidWire,
one: InvalidWire,
}
iw := uint64(ioPadX + c.Inputs.Size()*(ioWidth+ioPadX))
ow := uint64(ioPadX + c.Outputs.Size()*(ioWidth+ioPadX))
width := cols * (gateWidth + gatePadX)
if iw > width {
width = iw
}
if ow > width {
width = ow
}
fmt.Fprintf(out,
`<svg xmlns="http://www.w3.org/2000/svg" width="%d" height="%d">
<style><![CDATA[
text {
font: 10px Courier, monospace;
}
]]></style>
<g fill="none" stroke="#000" stroke-width=".5">
`,
width, rows*(gateHeight+gatePadY)+2*(ioHeight+gatePadY))
// Input wires.
leftPad := ioPadX + int((width-iw)/2)
for i := 0; i < c.Inputs.Size(); i++ {
p := point{
x: float64(leftPad + i*(ioWidth+ioPadX)),
y: ioHeight,
}
ctx.wireStarts[i] = p
staticInput(out, p.x, p.y, fmt.Sprintf("i%v", i))
}
// Compute level widths.
widths := make([]uint64, rows)
var lastLevel Level
var x, y int
for _, g := range c.Gates {
if g.Level != lastLevel {
lastLevel = g.Level
y++
}
widths[y]++
}
// Render circuit.
var tiles []*tile
var wires []*wire
x = 0
y = -1
lastLevel = 0
for idx, g := range c.Gates {
if idx == 0 || g.Level != lastLevel {
wires = append(wires, renderRow(out, int(width),
float64(ioHeight+gatePadY+y*(gateHeight+gatePadY)), tiles,
ctx)...)
tiles = nil
lastLevel = g.Level
y++
x = 0
}
tiles = append(tiles, &tile{
gate: &c.Gates[idx],
})
x++
}
wires = append(wires, renderRow(out, int(width),
float64(ioHeight+gatePadY+y*(gateHeight+gatePadY)), tiles, ctx)...)
// Output wires.
y++
numOutputs := c.Outputs.Size()
var oAvg float64
for i := 0; i < numOutputs; i++ {
oAvg += ctx.wireStarts[c.NumWires-numOutputs+i].x
}
oAvg /= float64(numOutputs)
oWidth := float64((numOutputs-1)*(ioWidth+ioPadX) + ioWidth)
leftPad = int(oAvg-oWidth/2) + ioWidth/2
for i := 0; i < numOutputs; i++ {
p := point{
x: float64(leftPad + i*(ioWidth+ioPadX)),
y: float64(ioHeight + gatePadY + y*(gateHeight+gatePadY)),
}
staticOutput(out, p.x, p.y, fmt.Sprintf("o%v", i))
wires = append(wires, &wire{
from: ctx.wireStarts[c.NumWires-numOutputs+i],
to: p,
})
}
for _, w := range wires {
w.svg(out)
}
fmt.Fprintln(out, " </g>\n</svg>")
}
func renderRow(out io.Writer, width int, y float64, tiles []*tile,
ctx *svgCtx) []*wire {
var wires []*wire
for _, t := range tiles {
ctx.tileAvgInputX(t)
}
sort.Slice(tiles, func(i, j int) bool {
return tiles[i].avg < tiles[j].avg
})
// Assign x based on input average and push tiles right.
var next float64
for _, t := range tiles {
if next <= t.avg {
t.x = t.avg
} else {
t.x = next
}
next = t.x + gateWidth + gatePadX
}
// Starting from the right end, shift tiles left until they are on
// screen and not overlapping.
next = float64(width) - gateWidth - gatePadX
for i := len(tiles) - 1; i >= 0; i-- {
if tiles[i].x > next {
tiles[i].x = next
}
next -= gateWidth + gatePadX
}
for _, t := range tiles {
wires = append(wires, t.gate.svg(out, t.x, y, ctx)...)
}
return wires
}
func (g *Gate) svg(out io.Writer, x, y float64, ctx *svgCtx) []*wire {
fmt.Fprintf(out, ` <g transform="translate(%v %v)">
`,
x, y)
tmpl := templates[g.Op]
if tmpl == nil {
tmpl = templates[Count]
}
out.Write([]byte(tmpl.Expand()))
fmt.Fprintln(out, ` </g>`)
ctx.wireStarts[g.Output] = point{
x: x + gateWidth/2,
y: y + gateHeight,
}
var wires []*wire
switch g.Op {
case XOR, XNOR, AND, OR:
x0 := x + intCvt(35)
x1 := x + intCvt(65)
f0 := ctx.wireStarts[g.Input0]
f1 := ctx.wireStarts[g.Input1]
w0 := &wire{
from: f0,
to: point{
x: x0,
y: y,
},
}
w1 := &wire{
from: f1,
to: point{
x: x1,
y: y,
},
}
ctx.setWireType(g.Input0, w0)
ctx.setWireType(g.Input1, w1)
// The input pin order does not matter in the
// visualization. Swap input pins if input wires would cross
// each other.
if f0.x > f1.x {
w0.to, w1.to = w1.to, w0.to
}
wires = append(wires, w0)
wires = append(wires, w1)
case INV:
wire := &wire{
from: ctx.wireStarts[g.Input0],
to: point{
x: x + intCvt(50),
y: y,
},
}
ctx.setWireType(g.Input0, wire)
wires = append(wires, wire)
}
// Constant value gates.
switch g.Op {
case XOR:
if g.Input0 == g.Input1 {
ctx.zero = g.Output
staticOutput(out, x+gateWidth/2, y+gateHeight, "0")
}
case XNOR:
if g.Input0 == g.Input1 {
fmt.Printf("*** one!\n")
ctx.one = g.Output
staticOutput(out, x+gateWidth/2, y+gateHeight, "1")
}
}
return wires
}
func staticInput(out io.Writer, x, y float64, label string) {
fmt.Fprintf(out, ` <g fill="#000">
<text x="%v" y="%v" text-anchor="middle">%v</text>
</g>
`,
x, y-2, label)
}
func staticOutput(out io.Writer, x, y float64, label string) {
fmt.Fprintf(out, ` <g fill="#000">
<text x="%v" y="%v" text-anchor="middle">%v</text>
</g>
`,
x, y+10, label)
}
func scale(in int) float64 {
return float64(in) * gateWidth / 100
}
func path(out io.Writer) {
fmt.Fprintln(out, ` <path fill="none" stroke="#000" stroke-width=".5"`)
}
var templates [Count + 1]*Template
var intCvt IntCvt
var floatCvt FloatCvt
func init() {
templates[XOR] = NewTemplate(` <path d="M {{25}} {{20}}
c {{10}} {{10}} {{40}} {{10}} {{50}} 0" />
<path d="M {{25}} {{25}}
c {{10}} {{10}} {{40}} {{10}} {{50}} 0" />
<path d="M {{75}} {{25}}
v {{25}}
s 0 {{10}} {{-25}} {{25}}" />
<path d="M {{25}} {{25}}
v {{25}}
s 0 {{10}} {{25}} {{25}}" />
<path d="M {{35}} 0
v {{25}}
z" />
<path d="M {{65}} 0
v {{25}}
z" />
<path d="M {{50}} {{75}}
v {{25}}
z" />
`)
templates[XNOR] = NewTemplate(` <path d="M {{25}} {{20}}
c {{10}} {{10}} {{40}} {{10}} {{50}} 0" />
<path d="M {{25}} {{25}}
c {{10}} {{10}} {{40}} {{10}} {{50}} 0" />
<path d="M {{75}} {{25}}
v {{25}}
s 0 {{10}} {{-25}} {{25}}" />
<path d="M {{25}} {{25}}
v {{25}}
s 0 {{10}} {{25}} {{25}}" />
<circle cx="{{50}}" cy="{{80}}" r="{{5}}" />
<path d="M {{35}} 0
v {{25}}
z" />
<path d="M {{65}} 0
v {{25}}
z" />
<path d="M {{50}} {{85}}
v {{15}}
z" />
`)
templates[AND] = NewTemplate(` <path d="M {{25}} {{25}}
h {{50}}
v {{25}}
a {{25}} {{25}} 0 1 1 {{-50}} 0
v {{-25}}
z" />
<path d="M {{35}} 0
v {{25}}
z" />
<path d="M {{65}} 0
v {{25}}
z" />
<path d="M {{50}} {{75}}
v {{25}}
z" />
`)
templates[OR] = NewTemplate(` <path d="M {{25}} {{20}}
c {{10}} {{10}} {{40}} {{10}} {{50}} 0" />
<path d="M {{75}} {{20}}
v {{30}}
s 0 {{10}} {{-25}} {{25}}" />
<path d="M {{25}} {{20}}
v {{30}}
s 0 {{10}} {{25}} {{25}}" />
<path d="M {{35}} 0
v {{25}}
z" />
<path d="M {{65}} 0
v {{25}}
z" />
<path d="M {{50}} {{75}}
v {{25}}
z" />
`)
templates[INV] = NewTemplate(` <path d="M {{25}} {{25}}
h {{50}}
l {{-25}} {{43}}
z" />
<circle cx="{{50}}" cy="{{73.5}}" r="{{5}}" />
<path d="M {{50}} 0
v {{25}}
z" />
<path d="M {{50}} {{79}}
v {{21}}
z" />
`)
templates[Count] = NewTemplate(`<path
d="M {{25}} {{25}} h {{50}} v{{50}} h {{-50}} z" />`)
floatCvt = func(v float64) float64 {
return v * gateWidth / 100
}
intCvt = func(v int) float64 {
return float64(v) * gateWidth / 100
}
for op := XOR; op < Count+1; op++ {
if templates[op] != nil {
templates[op].FloatCvt = floatCvt
templates[op].IntCvt = intCvt
}
}
}

142
bedlam/circuit/template.go Normal file
View File

@ -0,0 +1,142 @@
//
// Copyright (c) 2022 Markku Rossi
//
// All rights reserved.
//
package circuit
import (
"fmt"
"regexp"
"strconv"
"strings"
)
var reVar = regexp.MustCompilePOSIX(`{(.?){([^\}]+)}}`)
// Template defines an expandable text template.
type Template struct {
parts []*part
FloatCvt FloatCvt
IntCvt IntCvt
StringCvt StringCvt
}
// FloatCvt converts a float64 value to float64 value.
type FloatCvt func(v float64) float64
// IntCvt converts an integer value to float64 value.
type IntCvt func(v int) float64
// StringCvt converts a string value to a string value.
type StringCvt func(v string) string
const (
partFloat = iota
partInt
partString
)
type part struct {
t int
fv float64
iv int
sv string
}
func (p part) String() string {
switch p.t {
case partFloat:
return fmt.Sprintf("float64:%v", p.fv)
case partInt:
return fmt.Sprintf("int:%v", p.iv)
case partString:
return p.sv
default:
return fmt.Sprintf("{part %d}", p.t)
}
}
// NewTemplate parses the input string and returns the parsed
// Template.
func NewTemplate(input string) *Template {
t := &Template{
FloatCvt: func(v float64) float64 { return v },
IntCvt: func(v int) float64 { return float64(v) },
StringCvt: func(v string) string { return v },
}
matches := reVar.FindAllStringSubmatchIndex(input, -1)
if matches == nil {
return t
}
var start int
var err error
for _, m := range matches {
if m[0] > start {
t.parts = append(t.parts, &part{
t: partString,
sv: input[start:m[0]],
})
}
content := input[m[4]:m[5]]
part := &part{
t: partFloat,
sv: content,
}
t.parts = append(t.parts, part)
start = m[1]
if m[2] != m[3] {
switch input[m[2]:m[3]] {
default:
panic(fmt.Sprintf("unknown template variable conversion: %s",
input[m[2]:m[3]]))
}
} else {
part.iv, err = strconv.Atoi(content)
if err == nil {
part.t = partInt
} else {
part.fv, err = strconv.ParseFloat(content, 64)
if err == nil {
part.t = partFloat
} else {
part.sv = content
part.t = partString
}
}
}
}
if start < len(input) {
t.parts = append(t.parts, &part{
t: partString,
sv: input[start:],
})
}
return t
}
// Expand expands the template.
func (t *Template) Expand() string {
var b strings.Builder
for _, part := range t.parts {
switch part.t {
case partFloat:
b.WriteString(fmt.Sprintf("%v", t.FloatCvt(part.fv)))
case partInt:
b.WriteString(fmt.Sprintf("%v", t.IntCvt(part.iv)))
case partString:
b.WriteString(part.sv)
default:
panic(fmt.Sprintf("invalid part type: %v", part.t))
}
}
return b.String()
}

View File

@ -0,0 +1,88 @@
//
// Copyright (c) 2022-2023 Markku Rossi
//
// All rights reserved.
//
package circuit
import (
"testing"
)
var tmplXOR = `<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
<path fill="none" stroke="#000" stroke-width="1"
d="M {{25}} {{20}}
c {{10}} {{10}} {{40}} {{10}} {{50}} 0" />
<path fill="none" stroke="#000" stroke-width="1"
d="M 25 25
c 10 10 40 10 50 0" />
<path fill="none" stroke="#000" stroke-width="1"
d="M 75 25
v 25
s 0 10 -25 25 " />
<path fill="none" stroke="#000" stroke-width="1"
d="M 25 25
v 25
s 0 10 25 25 " />
<!-- Wires -->
<path fill="none" stroke="#000" stroke-width="1"
d="M 35 0
v 25
z" />
<path fill="none" stroke="#000" stroke-width="1"
d="M 65 0
v 25
z" />
<path fill="none" stroke="#000" stroke-width="1"
d="M 50 75
v 25
z" />
</svg>
`
var tmplXORExpanded = `<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
<path fill="none" stroke="#000" stroke-width="1"
d="M 6.25 5
c 2.5 2.5 10 2.5 12.5 0" />
<path fill="none" stroke="#000" stroke-width="1"
d="M 25 25
c 10 10 40 10 50 0" />
<path fill="none" stroke="#000" stroke-width="1"
d="M 75 25
v 25
s 0 10 -25 25 " />
<path fill="none" stroke="#000" stroke-width="1"
d="M 25 25
v 25
s 0 10 25 25 " />
<!-- Wires -->
<path fill="none" stroke="#000" stroke-width="1"
d="M 35 0
v 25
z" />
<path fill="none" stroke="#000" stroke-width="1"
d="M 65 0
v 25
z" />
<path fill="none" stroke="#000" stroke-width="1"
d="M 50 75
v 25
z" />
</svg>
`
func TestTemplate(t *testing.T) {
tmpl := NewTemplate(tmplXOR)
tmpl.IntCvt = func(v int) float64 {
return float64(v) * 25 / 100
}
expanded := tmpl.Expand()
if expanded != tmplXORExpanded {
t.Errorf("template expansion failed: got\n%v\n", expanded)
}
}

163
bedlam/circuit/timing.go Normal file
View File

@ -0,0 +1,163 @@
//
// Copyright (c) 2020-2023 Markku Rossi
//
// All rights reserved.
//
package circuit
import (
"fmt"
"os"
"time"
"github.com/markkurossi/tabulate"
"source.quilibrium.com/quilibrium/monorepo/bedlam/p2p"
)
// Timing records timing samples and renders a profiling report.
type Timing struct {
Start time.Time
Samples []*Sample
}
// NewTiming creates a new Timing instance.
func NewTiming() *Timing {
return &Timing{
Start: time.Now(),
}
}
// Sample adds a timing sample with label and data columns.
func (t *Timing) Sample(label string, cols []string) *Sample {
start := t.Start
if len(t.Samples) > 0 {
start = t.Samples[len(t.Samples)-1].End
}
sample := &Sample{
Label: label,
Start: start,
End: time.Now(),
Cols: cols,
}
t.Samples = append(t.Samples, sample)
return sample
}
// Print prints profiling report to standard output.
func (t *Timing) Print(stats p2p.IOStats) {
if len(t.Samples) == 0 {
return
}
sent := stats.Sent.Load()
received := stats.Recvd.Load()
flushed := stats.Flushed.Load()
tab := tabulate.New(tabulate.UnicodeLight)
tab.Header("Op").SetAlign(tabulate.ML)
tab.Header("Time").SetAlign(tabulate.MR)
tab.Header("%").SetAlign(tabulate.MR)
tab.Header("Xfer").SetAlign(tabulate.MR)
total := t.Samples[len(t.Samples)-1].End.Sub(t.Start)
for _, sample := range t.Samples {
row := tab.Row()
row.Column(sample.Label)
duration := sample.End.Sub(sample.Start)
row.Column(fmt.Sprintf("%s", duration.String()))
row.Column(fmt.Sprintf("%.2f%%",
float64(duration)/float64(total)*100))
for _, col := range sample.Cols {
row.Column(col)
}
for idx, sub := range sample.Samples {
row := tab.Row()
var prefix string
if idx+1 >= len(sample.Samples) {
prefix = "\u2570\u2574"
} else {
prefix = "\u251C\u2574"
}
row.Column(prefix + sub.Label).SetFormat(tabulate.FmtItalic)
var d time.Duration
if sub.Abs > 0 {
d = sub.Abs
} else {
d = sub.End.Sub(sub.Start)
}
row.Column(d.String()).SetFormat(tabulate.FmtItalic)
row.Column(
fmt.Sprintf("%.2f%%", float64(d)/float64(duration)*100)).
SetFormat(tabulate.FmtItalic)
}
}
row := tab.Row()
row.Column("Total").SetFormat(tabulate.FmtBold)
row.Column(t.Samples[len(t.Samples)-1].End.Sub(t.Start).String()).
SetFormat(tabulate.FmtBold)
row.Column("").SetFormat(tabulate.FmtBold)
row.Column(FileSize(sent + received).String()).SetFormat(tabulate.FmtBold)
row = tab.Row()
row.Column("\u251C\u2574Sent").SetFormat(tabulate.FmtItalic)
row.Column("")
row.Column(
fmt.Sprintf("%.2f%%", float64(sent)/float64(sent+received)*100)).
SetFormat(tabulate.FmtItalic)
row.Column(FileSize(sent).String()).SetFormat(tabulate.FmtItalic)
row = tab.Row()
row.Column("\u251C\u2574Rcvd").SetFormat(tabulate.FmtItalic)
row.Column("")
row.Column(
fmt.Sprintf("%.2f%%", float64(received)/float64(sent+received)*100)).
SetFormat(tabulate.FmtItalic)
row.Column(FileSize(received).String()).SetFormat(tabulate.FmtItalic)
row = tab.Row()
row.Column("\u2570\u2574Flcd").SetFormat(tabulate.FmtItalic)
row.Column("")
row.Column("")
row.Column(fmt.Sprintf("%v", flushed)).SetFormat(tabulate.FmtItalic)
tab.Print(os.Stdout)
}
// Sample contains information about one timing sample.
type Sample struct {
Label string
Start time.Time
End time.Time
Abs time.Duration
Cols []string
Samples []*Sample
}
// SubSample adds a sub-sample for a timing sample.
func (s *Sample) SubSample(label string, end time.Time) {
start := s.Start
if len(s.Samples) > 0 {
start = s.Samples[len(s.Samples)-1].End
}
s.Samples = append(s.Samples, &Sample{
Label: label,
Start: start,
End: end,
})
}
// AbsSubSample adds an absolute sub-sample for a timing sample.
func (s *Sample) AbsSubSample(label string, duration time.Duration) {
s.Samples = append(s.Samples, &Sample{
Label: label,
Abs: duration,
})
}

1
bedlam/compiler/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
*.circ

View File

@ -0,0 +1,238 @@
//
// Copyright (c) 2019-2023 Markku Rossi
//
// All rights reserved.
//
package compiler
import (
"fmt"
"io"
"math/big"
"testing"
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/utils"
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
"source.quilibrium.com/quilibrium/monorepo/bedlam/p2p"
)
type Test struct {
Name string
Heavy bool
Operand string
Bits int
Eval func(a *big.Int, b *big.Int) *big.Int
Code string
}
var tests = []Test{
{
Name: "Add",
Heavy: true,
Operand: "+",
Bits: 2,
Eval: func(a *big.Int, b *big.Int) *big.Int {
result := big.NewInt(0)
result.Add(a, b)
return result
},
Code: `
package main
func main(a, b int3) int3 {
return a + b
}
`,
},
// 1-bit, 2-bit, and n-bit multipliers have a bit different wiring.
{
Name: "Multiply 1-bit",
Heavy: false,
Operand: "*",
Bits: 1,
Eval: func(a *big.Int, b *big.Int) *big.Int {
result := big.NewInt(0)
result.Mul(a, b)
return result
},
Code: `
package main
func main(a, b int1) int1 {
return a * b
}
`,
},
{
Name: "Multiply 2-bits",
Heavy: true,
Operand: "*",
Bits: 2,
Eval: func(a *big.Int, b *big.Int) *big.Int {
result := big.NewInt(0)
result.Mul(a, b)
return result
},
Code: `
package main
func main(a, b int4) int4 {
return a * b
}
`,
},
{
Name: "Multiply n-bits",
Heavy: true,
Operand: "*",
Bits: 2,
Eval: func(a *big.Int, b *big.Int) *big.Int {
result := big.NewInt(0)
result.Mul(a, b)
return result
},
Code: `
package main
func main(a, b int6) int6 {
return a * b
}
`,
},
}
func TestArithmetics(t *testing.T) {
for _, test := range tests {
if testing.Short() && test.Heavy {
fmt.Printf("Skipping %s\n", test.Name)
continue
}
circ, _, err := New(utils.NewParams()).Compile(test.Code, nil)
if err != nil {
t.Fatalf("Failed to compile test %s: %s", test.Name, err)
}
limit := 1 << test.Bits
for g := 0; g < limit; g++ {
for e := 0; e < limit; e++ {
gr, ew := io.Pipe()
er, gw := io.Pipe()
gio := newReadWriter(gr, gw)
eio := newReadWriter(er, ew)
gInput := big.NewInt(int64(g))
eInput := big.NewInt(int64(e))
gerr := make(chan error)
eerr := make(chan error)
res := make(chan []*big.Int)
go func() {
fmt.Println("start garbler")
_, err := circuit.Garbler(p2p.NewConn(gio), ot.NewFerret(1, ":5555"),
circ, gInput, false)
fmt.Println("end garbler")
gerr <- err
}()
go func() {
fmt.Println("start evaluator")
result, err := circuit.Evaluator(p2p.NewConn(eio),
ot.NewFerret(2, "127.0.0.1:5555"), circ, eInput, false)
fmt.Println("end evaluator")
eerr <- err
res <- result
}()
err = <-gerr
if err != nil {
t.Fatalf("Garbler failed: %s\n", err)
}
err = <-eerr
if err != nil {
t.Fatalf("Evaluator failed: %s\n", err)
}
result := <-res
expected := test.Eval(gInput, eInput)
if expected.Cmp(result[0]) != 0 {
t.Errorf("%s failed: %s %s %s = %s, expected %s\n",
test.Name, gInput, test.Operand, eInput, result,
expected)
}
}
}
}
}
var mult512 = `
package main
func main(a, b int512) int512 {
return a * b
}
`
func BenchmarkMult(b *testing.B) {
circ, _, err := New(utils.NewParams()).Compile(mult512, nil)
if err != nil {
b.Fatalf("failed to compile test: %s", err)
}
gr, ew := io.Pipe()
er, gw := io.Pipe()
gio := newReadWriter(gr, gw)
eio := newReadWriter(er, ew)
gInput := big.NewInt(int64(11))
eInput := big.NewInt(int64(13))
gerr := make(chan error)
eerr := make(chan error)
res := make(chan []*big.Int)
go func() {
_, err := circuit.Garbler(p2p.NewConn(gio), ot.NewFerret(1, ":5555"),
circ, gInput, false)
gerr <- err
}()
go func() {
result, err := circuit.Evaluator(p2p.NewConn(eio),
ot.NewFerret(2, "127.0.0.1:5555"), circ, eInput, false)
eerr <- err
res <- result
}()
err = <-gerr
if err != nil {
b.Fatalf("Garbler failed: %s\n", err)
}
err = <-eerr
if err != nil {
b.Fatalf("Evaluator failed: %s\n", err)
}
<-res
}
func newReadWriter(in io.Reader, out io.Writer) io.ReadWriter {
return &wrap{
in: in,
out: out,
}
}
type wrap struct {
in io.Reader
out io.Writer
}
func (w *wrap) Read(p []byte) (n int, err error) {
return w.in.Read(p)
}
func (w *wrap) Write(p []byte) (n int, err error) {
return w.out.Write(p)
}

918
bedlam/compiler/ast/ast.go Normal file
View File

@ -0,0 +1,918 @@
//
// ast.go
//
// Copyright (c) 2019-2024 Markku Rossi
//
// All rights reserved.
//
package ast
import (
"fmt"
"io"
"strings"
"unicode"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/mpa"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/ssa"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/utils"
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
)
var (
_ AST = &List{}
_ AST = &Func{}
_ AST = &VariableDef{}
_ AST = &Assign{}
_ AST = &If{}
_ AST = &Call{}
_ AST = &ArrayCast{}
_ AST = &Return{}
_ AST = &For{}
_ AST = &ForRange{}
_ AST = &Binary{}
_ AST = &Unary{}
_ AST = &Slice{}
_ AST = &Index{}
_ AST = &VariableRef{}
_ AST = &BasicLit{}
_ AST = &CompositeLit{}
_ AST = &Make{}
_ AST = &Copy{}
)
func indent(w io.Writer, indent int) {
for i := 0; i < indent; i++ {
fmt.Fprint(w, " ")
}
}
// Type specifies AST types.
type Type int
// AST types.
const (
TypeName Type = iota
TypeArray
TypeSlice
TypeStruct
TypePointer
TypeAlias
)
// TypeInfo contains AST type information.
type TypeInfo struct {
utils.Point
Type Type
Name Identifier
ElementType *TypeInfo
ArrayLength AST
TypeName string
StructFields []StructField
AliasType *TypeInfo
Methods map[string]*Func
Annotations Annotations
}
// Equal tests if the argument TypeInfo is equal to this TypeInfo.
func (ti *TypeInfo) Equal(o *TypeInfo) bool {
if ti.Type != o.Type {
return false
}
switch ti.Type {
case TypeName:
return ti.Name.String() == o.Name.String()
case TypeArray:
return ti.ElementType.Equal(o.ElementType) &&
ti.ArrayLength == o.ArrayLength
case TypeSlice, TypePointer:
return ti.ElementType.Equal(o.ElementType)
case TypeStruct:
if len(ti.StructFields) != len(o.StructFields) {
return false
}
for idx, f := range ti.StructFields {
if !f.Type.Equal(o.StructFields[idx].Type) {
return false
}
}
return true
case TypeAlias:
return ti.AliasType.Equal(o.AliasType)
default:
panic("unsupported type")
}
}
// StructField contains AST structure field information.
type StructField struct {
utils.Point
Name string
Type *TypeInfo
}
func (ti *TypeInfo) String() string {
return ti.format(false)
}
// Format print the type definition of the type info.
func (ti *TypeInfo) Format() string {
return ti.format(true)
}
func (ti *TypeInfo) format(pp bool) string {
var str string
if pp {
str = fmt.Sprintf("type %s ", ti.TypeName)
}
switch ti.Type {
case TypeName:
return str + ti.Name.String()
case TypeArray:
return fmt.Sprintf("%s[%s]%s", str, ti.ArrayLength, ti.ElementType)
case TypeSlice:
return fmt.Sprintf("%s[]%s", str, ti.ElementType)
case TypeStruct:
str = fmt.Sprintf("%sstruct {", str)
if pp {
var width int
for _, field := range ti.StructFields {
if len(field.Name) > width {
width = len(field.Name)
}
}
for idx, field := range ti.StructFields {
if idx == 0 {
str += "\n"
}
str += " "
str += field.Name
for i := len(field.Name); i < width; i++ {
str += " "
}
str += fmt.Sprintf(" %s\n", field.Type.String())
}
} else {
for idx, field := range ti.StructFields {
if idx > 0 {
str += ", "
}
str += fmt.Sprintf("%s %s", field.Name, field.Type.String())
}
}
return str + "}"
case TypeAlias:
return fmt.Sprintf("%s= %s", str, ti.AliasType)
case TypePointer:
return fmt.Sprintf("%s*%s", str, ti.ElementType)
default:
return fmt.Sprintf("%s{TypeInfo %d}", str, ti.Type)
}
}
// IsIdentifier returns true if the type info specifies a type name
// without package.
func (ti *TypeInfo) IsIdentifier() bool {
return ti.Type == TypeName && len(ti.Name.Package) == 0
}
// Resolve resolves the type information in the environment.
func (ti *TypeInfo) Resolve(env *Env, ctx *Codegen, gen *ssa.Generator) (
result types.Info, err error) {
if ti == nil {
return
}
switch ti.Type {
case TypeName:
result, err = types.Parse(ti.Name.Name)
if err == nil {
return
}
// Check dynamic types from the env.
var b ssa.Binding
var pkg *Package
var ok bool
if len(ti.Name.Package) == 0 {
// Plain indentifiers.
b, ok = env.Get(ti.Name.Name)
if !ok {
// Check dynamic types from the pkg.
b, ok = ctx.Package.Bindings.Get(ti.Name.Name)
}
}
if !ok {
// Qualified names and package-local names.
var pkgName string
if len(ti.Name.Package) > 0 {
pkgName = ti.Name.Package
} else if ti.Name.Defined != ctx.Package.Name {
pkgName = ti.Name.Defined
}
if len(pkgName) > 0 {
pkg, ok = ctx.Packages[pkgName]
if !ok {
return result, ctx.Errorf(ti, "unknown package: %s",
pkgName)
}
b, ok = pkg.Bindings.Get(ti.Name.Name)
}
}
if ok {
val, ok := b.Bound.(*ssa.Value)
if ok && val.TypeRef {
return val.Type, nil
}
}
return result, ctx.Errorf(ti, "undefined name: %s", ti)
case TypeArray:
// Array length.
constLength, ok, err := ti.ArrayLength.Eval(env, ctx, gen)
if err != nil {
return result, err
}
if !ok {
return result, ctx.Errorf(ti.ArrayLength,
"array length is not constant: %s", ti.ArrayLength)
}
length, err := constLength.ConstInt()
if err != nil {
return result, ctx.Errorf(ti.ArrayLength,
"invalid array length: %s", err)
}
// Element type.
elInfo, err := ti.ElementType.Resolve(env, ctx, gen)
if err != nil {
return result, err
}
return types.Info{
Type: types.TArray,
IsConcrete: true,
Bits: length * elInfo.Bits,
MinBits: length * elInfo.MinBits,
ElementType: &elInfo,
ArraySize: length,
}, nil
case TypeSlice:
// Element type.
elInfo, err := ti.ElementType.Resolve(env, ctx, gen)
if err != nil {
return result, err
}
// Bits and ArraySize are left uninitialized and they must be
// defined when type is instantiated.
return types.Info{
Type: types.TSlice,
ElementType: &elInfo,
}, nil
case TypePointer:
// Element type.
elInfo, err := ti.ElementType.Resolve(env, ctx, gen)
if err != nil {
return result, err
}
return types.Info{
Type: types.TPtr,
IsConcrete: true,
Bits: elInfo.Bits,
MinBits: elInfo.Bits,
ElementType: &elInfo,
}, nil
default:
return result, ctx.Errorf(ti, "can't resolve type %s", ti)
}
}
// ConstantDef implements a constant definition.
type ConstantDef struct {
utils.Point
Name string
Type *TypeInfo
Init AST
Annotations Annotations
}
// Exported describes if the constant is exported from the package.
func (ast *ConstantDef) Exported() bool {
return IsExported(ast.Name)
}
// IsExported describes if the name is exported from the package.
func IsExported(name string) bool {
if len(name) == 0 {
return false
}
return unicode.IsUpper([]rune(name)[0])
}
func (ast *ConstantDef) String() string {
result := fmt.Sprintf("const %s", ast.Name)
if ast.Type != nil {
result += fmt.Sprintf(" %s", ast.Type)
}
if ast.Init != nil {
result += fmt.Sprintf(" = %s", ast.Init)
}
return result
}
// Identifier implements an AST identifier.
type Identifier struct {
Defined string
Package string
Name string
}
func (i Identifier) String() string {
if len(i.Package) == 0 {
return i.Name
}
return fmt.Sprintf("%s.%s", i.Package, i.Name)
}
// Qualified tells if the identifier has package part.
func (i Identifier) Qualified() bool {
return len(i.Package) > 0
}
// AST implements abstract syntax tree nodes.
type AST interface {
utils.Locator
String() string
// SSA generates SSA code from the AST node. The code is appended
// into the basic block `block'. The function returns the next
// sequential basic block. The `ssa.Dead' is set to `true' if the
// code terminates i.e. all following AST nodes are dead code.
SSA(block *ssa.Block, ctx *Codegen, gen *ssa.Generator) (
*ssa.Block, []ssa.Value, error)
// Eval evaluates the AST node during constant propagation.
Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
value ssa.Value, isConstant bool, err error)
}
// NewEnv creates a new environment based on the current environment
// bindings in the block.
func NewEnv(block *ssa.Block) *Env {
return &Env{
Bindings: block.Bindings.Clone(),
}
}
// Env implements a value bindings environment.
type Env struct {
Bindings *ssa.Bindings
}
// Debug print the environment.
func (env *Env) Debug() {
fmt.Print("Env: ")
env.Bindings.Debug()
}
// Get gets the value binding from the environment.
func (env *Env) Get(name string) (ssa.Binding, bool) {
return env.Bindings.Get(name)
}
// Set sets the value binding to the environment.
func (env *Env) Set(v ssa.Value, val *ssa.Value) {
env.Bindings.Define(v, val)
}
// List implements an AST list statement.
type List []AST
func (ast List) String() string {
result := "{\n"
for _, a := range ast {
result += fmt.Sprintf("\t%s\n", a)
}
return result + "}\n"
}
// Location implements the compiler.utils.Locator interface.
func (ast List) Location() utils.Point {
if len(ast) > 0 {
return ast[0].Location()
}
return utils.Point{}
}
// Variable implements an AST variable.
type Variable struct {
utils.Point
Name string
Type *TypeInfo
}
func (v Variable) String() string {
if v.Type == nil {
return v.Name
}
return fmt.Sprintf("%s %s", v.Name, v.Type)
}
// Func implements an AST function.
type Func struct {
utils.Point
Name string
This *Variable
Args []*Variable
Return []*Variable
Returns []*ReturnInfo
NamedReturn bool
Body List
End utils.Point
NumInstances int
Annotations Annotations
}
// ReturnInfo provide information about function return values.
type ReturnInfo struct {
Return *Return
Types []types.Info
}
// Annotations specify function annotations.
type Annotations []string
// FirstSentence returns the first sentence from the annotations or an
// empty string it if annotations are empty.
func (ann Annotations) FirstSentence() string {
str := strings.Join(ann, "\n")
idx := strings.IndexRune(str, '.')
if idx > 0 {
return str[:idx+1]
}
return ""
}
// NewFunc creates a new function definition.
func NewFunc(loc utils.Point, name string, args []*Variable, ret []*Variable,
namedReturn bool, body List, end utils.Point,
annotations Annotations) *Func {
// Skip empty lines from the beginning and end of annotations.
for i := 0; i < len(annotations); i++ {
if len(strings.TrimSpace(annotations[i])) > 0 {
annotations = annotations[i:]
break
}
}
for i := len(annotations) - 1; i >= 0; i-- {
if len(strings.TrimSpace(annotations[i])) > 0 {
annotations = annotations[0 : i+1]
break
}
}
return &Func{
Point: loc,
Name: name,
Args: args,
Return: ret,
NamedReturn: namedReturn,
Body: body,
End: end,
Annotations: annotations,
}
}
func (ast *Func) String() string {
var str string
if ast.This != nil {
str = fmt.Sprintf("func (%s %s) %s(",
ast.This.Name, ast.This.Type, ast.Name)
} else {
str = fmt.Sprintf("func %s(", ast.Name)
}
for idx, arg := range ast.Args {
if idx > 0 {
str += ", "
}
if idx+1 < len(ast.Args) && arg.Type.Equal(ast.Args[idx+1].Type) {
str += arg.Name
} else {
str += fmt.Sprintf("%s %s", arg.Name, arg.Type)
}
}
str += ")"
if len(ast.Return) > 0 {
if ast.NamedReturn {
str += " ("
for idx, ret := range ast.Return {
if idx > 0 {
str += ", "
}
if idx+1 < len(ast.Return) &&
ret.Type.Equal(ast.Return[idx+1].Type) {
str += ret.Name
} else {
str += fmt.Sprintf("%s %s", ret.Name, ret.Type)
}
}
str += ")"
} else if len(ast.Return) > 1 {
str += " ("
for idx, ret := range ast.Return {
if idx > 0 {
str += ", "
}
str += fmt.Sprintf("%s", ret.Type)
}
str += ")"
} else {
str += fmt.Sprintf(" %s", ast.Return[0].Type)
}
}
return str
}
// VariableDef implements an AST variable definition.
type VariableDef struct {
utils.Point
Names []string
Type *TypeInfo
Init AST
Annotations Annotations
}
func (ast *VariableDef) String() string {
result := fmt.Sprintf("var %v", strings.Join(ast.Names, ", "))
if ast.Type != nil {
result += fmt.Sprintf(" %s", ast.Type)
}
if ast.Init != nil {
result += fmt.Sprintf(" = %s", ast.Init)
}
return result
}
// Assign implements an AST assignment expression.
type Assign struct {
utils.Point
LValues []AST
Exprs []AST
Define bool
}
func (ast *Assign) String() string {
var op string
if ast.Define {
op = ":="
} else {
op = "="
}
return fmt.Sprintf("%v %s %v", ast.LValues, op, ast.Exprs)
}
// If implements an AST if statement.
type If struct {
utils.Point
Expr AST
True AST
False AST
}
func (ast *If) String() string {
return fmt.Sprintf("if %s", ast.Expr)
}
// Call implements an AST call expression.
type Call struct {
utils.Point
Ref *VariableRef
Exprs []AST
}
func (ast *Call) String() string {
str := fmt.Sprintf("%s(", ast.Ref)
for idx, expr := range ast.Exprs {
if idx > 0 {
str += ", "
}
str += fmt.Sprintf("%v", expr)
}
return str + ")"
}
// ArrayCast implements array cast expressions.
type ArrayCast struct {
utils.Point
TypeInfo *TypeInfo
Expr AST
}
func (ast *ArrayCast) String() string {
return fmt.Sprintf("%v(%v)", ast.TypeInfo, ast.Expr)
}
// Return implements an AST return statement.
type Return struct {
utils.Point
Exprs []AST
AutoGenerated bool
}
func (ast *Return) String() string {
return fmt.Sprintf("return %v", ast.Exprs)
}
// For implements an AST for statement.
type For struct {
utils.Point
Init AST
Cond AST
Inc AST
Body List
}
func (ast *For) String() string {
return fmt.Sprintf("for %s; %s; %s %s",
ast.Init, ast.Cond, ast.Inc, ast.Body)
}
// ForRange implements an AST for for-range statement.
type ForRange struct {
utils.Point
ExprList []AST
Def bool
Expr AST
Body List
}
func (ast *ForRange) String() string {
result := "for "
for idx, expr := range ast.ExprList {
if idx > 0 {
result += ", "
}
result += expr.String()
}
if ast.Def {
result += " := range "
} else {
result += " = range "
}
result += ast.Expr.String()
return result
}
// BinaryType defines binary expression types.
type BinaryType int
// Binary expression types.
const (
BinaryMul BinaryType = iota
BinaryDiv
BinaryMod
BinaryLshift
BinaryRshift
BinaryBand
BinaryBclear
BinaryBor
BinaryBxor
BinaryAdd
BinarySub
BinaryEq
BinaryNeq
BinaryLt
BinaryLe
BinaryGt
BinaryGe
BinaryAnd
BinaryOr
)
var binaryTypes = map[BinaryType]string{
BinaryMul: "*",
BinaryDiv: "/",
BinaryMod: "%",
BinaryLshift: "<<",
BinaryRshift: ">>",
BinaryBand: "&",
BinaryBclear: "&^",
BinaryBor: "|",
BinaryBxor: "^",
BinaryAdd: "+",
BinarySub: "-",
BinaryEq: "==",
BinaryNeq: "!=",
BinaryLt: "<",
BinaryLe: "<=",
BinaryGt: ">",
BinaryGe: ">=",
BinaryAnd: "&&",
BinaryOr: "||",
}
func (t BinaryType) String() string {
name, ok := binaryTypes[t]
if ok {
return name
}
return fmt.Sprintf("{BinaryType %d}", t)
}
// Binary implements an AST binary expression.
type Binary struct {
utils.Point
Left AST
Op BinaryType
Right AST
}
func (ast *Binary) String() string {
return fmt.Sprintf("%s %s %s", ast.Left, ast.Op, ast.Right)
}
// UnaryType defines unary expression types.
type UnaryType int
// Unary expression types.
const (
UnaryPlus UnaryType = iota
UnaryMinus
UnaryNot
UnaryXor
UnaryPtr
UnaryAddr
UnarySend
)
var unaryTypes = map[UnaryType]string{
UnaryPlus: "+",
UnaryMinus: "-",
UnaryNot: "!",
UnaryXor: "^",
UnaryPtr: "*",
UnaryAddr: "&",
UnarySend: "<-",
}
func (t UnaryType) String() string {
name, ok := unaryTypes[t]
if ok {
return name
}
return fmt.Sprintf("{UnaryType %d}", t)
}
// Unary implements an AST unary expression.
type Unary struct {
utils.Point
Type UnaryType
Expr AST
}
func (ast *Unary) String() string {
return fmt.Sprintf("%s%s", ast.Type, ast.Expr)
}
// Slice implements an AST slice expression.
type Slice struct {
utils.Point
Expr AST
From AST
To AST
}
func (ast *Slice) String() string {
var fromStr, toStr string
if ast.From != nil {
fromStr = ast.From.String()
}
if ast.To != nil {
toStr = ast.To.String()
}
return fmt.Sprintf("%s[%s:%s]", ast.Expr, fromStr, toStr)
}
// Index implements an AST array index expression.
type Index struct {
utils.Point
Expr AST
Index AST
}
func (ast *Index) String() string {
return fmt.Sprintf("%s[%s]", ast.Expr, ast.Index)
}
// VariableRef implements an AST variable reference.
type VariableRef struct {
utils.Point
Name Identifier
}
func (ast *VariableRef) String() string {
return ast.Name.String()
}
// BasicLit implements an AST basic literal value.
type BasicLit struct {
utils.Point
Value interface{}
}
func (ast *BasicLit) String() string {
return ConstantName(ast.Value)
}
// ConstantName returns the name of the constant value.
func ConstantName(value interface{}) string {
switch val := value.(type) {
case int, uint, int32, uint32, int64, uint64:
return fmt.Sprintf("%d", val)
case *mpa.Int:
return fmt.Sprintf("%s", val)
case bool:
return fmt.Sprintf("%v", val)
case string:
return fmt.Sprintf("%q", val)
default:
return fmt.Sprintf("{undefined constant %v (%T)}", val, val)
}
}
// CompositeLit implements an AST composite literal value.
type CompositeLit struct {
utils.Point
Type *TypeInfo
Value []KeyedElement
}
func (ast *CompositeLit) String() string {
str := ast.Type.String()
str += "{"
for idx, e := range ast.Value {
if idx > 0 {
str += ","
}
if e.Key != nil {
str += fmt.Sprintf("%s: %s", e.Key, e.Element)
} else {
str += e.Element.String()
}
}
return str + "}"
}
// KeyedElement implements a keyed element of composite literal.
type KeyedElement struct {
Key AST
Element AST
}
// Make implements the builtin function make.
type Make struct {
utils.Point
Type *TypeInfo
Exprs []AST
}
func (ast *Make) String() string {
str := fmt.Sprintf("make(%s", ast.Type)
for _, expr := range ast.Exprs {
str += ", "
str += expr.String()
}
return str + ")"
}
// Copy implements the builtin function copy.
type Copy struct {
utils.Point
Dst AST
Src AST
}
func (ast *Copy) String() string {
return fmt.Sprintf("copy(%v, %v)", ast.Dst, ast.Src)
}

View File

@ -0,0 +1,409 @@
//
// Copyright (c) 2019-2024 Markku Rossi
//
// All rights reserved.
//
package ast
import (
"fmt"
"path"
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/circuits"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/ssa"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/utils"
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
)
// Builtin implements QCL builtin functions.
type Builtin struct {
SSA SSA
Eval Eval
}
// SSA implements the builtin SSA generation.
type SSA func(block *ssa.Block, ctx *Codegen, gen *ssa.Generator,
args []ssa.Value, loc utils.Point) (*ssa.Block, []ssa.Value, error)
// Eval implements the builtin evaluation in constant folding.
type Eval func(args []AST, env *Env, ctx *Codegen, gen *ssa.Generator,
loc utils.Point) (ssa.Value, bool, error)
// Predeclared identifiers.
var builtins = map[string]Builtin{
"floorPow2": {
SSA: floorPow2SSA,
Eval: floorPow2Eval,
},
"len": {
SSA: lenSSA,
Eval: lenEval,
},
"native": {
SSA: nativeSSA,
},
"panic": {
SSA: panicSSA,
Eval: panicEval,
},
"size": {
SSA: sizeSSA,
Eval: sizeEval,
},
}
func floorPow2SSA(block *ssa.Block, ctx *Codegen, gen *ssa.Generator,
args []ssa.Value, loc utils.Point) (*ssa.Block, []ssa.Value, error) {
return nil, nil, ctx.Errorf(loc, "floorPow2SSA not implemented")
}
func floorPow2Eval(args []AST, env *Env, ctx *Codegen, gen *ssa.Generator,
loc utils.Point) (ssa.Value, bool, error) {
if len(args) != 1 {
return ssa.Undefined, false, ctx.Errorf(loc,
"invalid amount of arguments in call to floorPow2")
}
constVal, _, err := args[0].Eval(env, ctx, gen)
if err != nil {
return ssa.Undefined, false, ctx.Errorf(loc, "%s", err)
}
val, err := constVal.ConstInt()
if err != nil {
return ssa.Undefined, false, ctx.Errorf(loc,
"non-integer (%T) argument in %s: %s", constVal, args[0], err)
}
var i types.Size
for i = 1; i <= val; i <<= 1 {
}
i >>= 1
return gen.Constant(int64(i), types.Undefined), true, nil
}
func lenSSA(block *ssa.Block, ctx *Codegen, gen *ssa.Generator,
args []ssa.Value, loc utils.Point) (*ssa.Block, []ssa.Value, error) {
if len(args) != 1 {
return nil, nil, ctx.Errorf(loc,
"invalid amount of arguments in call to len")
}
var val types.Size
switch args[0].Type.Type {
case types.TString:
val = args[0].Type.Bits / types.ByteBits
case types.TArray, types.TSlice:
val = args[0].Type.ArraySize
case types.TNil:
val = 0
default:
return nil, nil, ctx.Errorf(loc, "invalid argument 1 (type %s) for len",
args[0].Type)
}
v := gen.Constant(int64(val), types.Undefined)
gen.AddConstant(v)
return block, []ssa.Value{v}, nil
}
func lenEval(args []AST, env *Env, ctx *Codegen, gen *ssa.Generator,
loc utils.Point) (ssa.Value, bool, error) {
if len(args) != 1 {
return ssa.Undefined, false, ctx.Errorf(loc,
"invalid amount of arguments in call to len")
}
switch arg := args[0].(type) {
case *VariableRef:
var typeInfo types.Info
if len(arg.Name.Package) > 0 {
// Check if the package name is bound to a value.
b, ok := env.Get(arg.Name.Package)
if ok {
if b.Type.Type != types.TStruct {
return ssa.Undefined, false, ctx.Errorf(loc,
"%s undefined", arg.Name)
}
ok = false
for _, f := range b.Type.Struct {
if f.Name == arg.Name.Name {
typeInfo = f.Type
ok = true
break
}
}
if !ok {
return ssa.Undefined, false, ctx.Errorf(loc,
"undefined variable '%s'", arg.Name)
}
} else {
// Resolve name from the package.
pkg, ok := ctx.Packages[arg.Name.Package]
if !ok {
return ssa.Undefined, false, ctx.Errorf(loc,
"package '%s' not found", arg.Name.Package)
}
b, ok := pkg.Bindings.Get(arg.Name.Name)
if !ok {
return ssa.Undefined, false, ctx.Errorf(loc,
"undefined variable '%s'", arg.Name)
}
typeInfo = b.Type
}
} else {
b, ok := env.Get(arg.Name.Name)
if !ok {
return ssa.Undefined, false, ctx.Errorf(loc,
"undefined variable '%s'", arg.Name)
}
typeInfo = b.Type
}
if typeInfo.Type == types.TPtr {
typeInfo = *typeInfo.ElementType
}
switch typeInfo.Type {
case types.TString:
return gen.Constant(int64(typeInfo.Bits/types.ByteBits),
types.Undefined), true, nil
case types.TArray, types.TSlice:
return gen.Constant(int64(typeInfo.ArraySize), types.Undefined),
true, nil
default:
return ssa.Undefined, false, ctx.Errorf(loc,
"invalid argument 1 (type %s) for len", typeInfo)
}
default:
return ssa.Undefined, false, ctx.Errorf(loc,
"len(%v/%T) is not constant", arg, arg)
}
}
func nativeSSA(block *ssa.Block, ctx *Codegen, gen *ssa.Generator,
args []ssa.Value, loc utils.Point) (*ssa.Block, []ssa.Value, error) {
if len(args) < 1 {
return nil, nil, ctx.Errorf(loc,
"not enough arguments in call to native")
}
name, ok := args[0].ConstValue.(string)
if !args[0].Const || !ok {
return nil, nil, ctx.Errorf(loc,
"not enough arguments in call to native")
}
// Our native name constant is not needed in the implementation.
gen.RemoveConstant(args[0])
args = args[1:]
switch name {
case "hamming":
if len(args) != 2 {
return nil, nil, ctx.Errorf(loc,
"invalid amount of arguments in call to '%s'", name)
}
var typeInfo types.Info
for _, arg := range args {
if arg.Type.Bits > typeInfo.Bits {
typeInfo = arg.Type
}
}
v := gen.AnonVal(typeInfo)
block.AddInstr(ssa.NewBuiltinInstr(circuits.Hamming, args[0], args[1],
v))
return block, []ssa.Value{v}, nil
default:
if circuit.IsFilename(name) {
return nativeCircuit(name, block, ctx, gen, args, loc)
}
return nil, nil, ctx.Errorf(loc, "unknown native '%s'", name)
}
}
func nativeCircuit(name string, block *ssa.Block, ctx *Codegen,
gen *ssa.Generator, args []ssa.Value, loc utils.Point) (
*ssa.Block, []ssa.Value, error) {
dir := path.Dir(loc.Source)
fp := path.Join(dir, name)
var err error
circ, ok := ctx.Native[fp]
if !ok {
circ, err = circuit.Parse(fp)
if err != nil {
return nil, nil, ctx.Errorf(loc, "failed to parse circuit: %s", err)
}
circ.AssignLevels()
ctx.Native[fp] = circ
if ctx.Verbose {
fmt.Printf(" - native %s: %v\n", name, circ)
}
} else if ctx.Verbose {
fmt.Printf(" - native %s: cached\n", name)
}
if len(circ.Inputs) > len(args) {
return nil, nil, ctx.Errorf(loc,
"not enough arguments in call to native")
} else if len(circ.Inputs) < len(args) {
return nil, nil, ctx.Errorf(loc, "too many argument in call to native")
}
// Check that the argument types match.
for idx, io := range circ.Inputs {
arg := args[idx]
if io.Type.Bits < arg.Type.Bits || io.Type.Bits > arg.Type.Bits &&
!arg.Const {
return nil, nil, ctx.Errorf(loc,
"invalid argument %d for native circuit: got %s, need %d",
idx, arg.Type, io.Type.Bits)
}
}
var result []ssa.Value
for _, io := range circ.Outputs {
result = append(result, gen.AnonVal(types.Info{
Type: types.TUndefined,
IsConcrete: true,
Bits: io.Type.Bits,
}))
}
block.AddInstr(ssa.NewCircInstr(args, circ, result))
return block, result, nil
}
func panicSSA(block *ssa.Block, ctx *Codegen, gen *ssa.Generator,
args []ssa.Value, loc utils.Point) (*ssa.Block, []ssa.Value, error) {
var arr []string
for _, arg := range args {
var str string
if arg.Const {
str = fmt.Sprintf("%v", arg.ConstValue)
} else {
str = arg.String()
}
arr = append(arr, str)
}
return nil, nil, ctx.Errorf(loc, "panic: %v", panicMessage(arr))
}
func panicEval(args []AST, env *Env, ctx *Codegen, gen *ssa.Generator,
loc utils.Point) (ssa.Value, bool, error) {
var arr []string
for _, arg := range args {
arr = append(arr, arg.String())
}
return ssa.Undefined, false, ctx.Errorf(loc, "panic: %v", panicMessage(arr))
}
func panicMessage(args []string) string {
if len(args) == 0 {
return ""
}
format := args[0]
args = args[1:]
var result string
for i := 0; i < len(format); i++ {
if format[i] != '%' || i+1 >= len(format) {
result += string(format[i])
continue
}
i++
if len(args) == 0 {
result += fmt.Sprintf("%%!%c(MISSING)", format[i])
continue
}
switch format[i] {
case 'v':
result += args[0]
default:
result += fmt.Sprintf("%%!%c(%v)", format[i], args[0])
}
args = args[1:]
}
for _, arg := range args {
result += " " + arg
}
return result
}
func sizeSSA(block *ssa.Block, ctx *Codegen, gen *ssa.Generator,
args []ssa.Value, loc utils.Point) (*ssa.Block, []ssa.Value, error) {
if len(args) != 1 {
return nil, nil, ctx.Errorf(loc,
"invalid amount of arguments in call to size")
}
v := gen.Constant(int64(args[0].Type.Bits), types.Undefined)
gen.AddConstant(v)
return block, []ssa.Value{v}, nil
}
func sizeEval(args []AST, env *Env, ctx *Codegen, gen *ssa.Generator,
loc utils.Point) (ssa.Value, bool, error) {
if len(args) != 1 {
return ssa.Undefined, false, ctx.Errorf(loc,
"invalid amount of arguments in call to size")
}
switch arg := args[0].(type) {
case *VariableRef:
var b ssa.Binding
var ok bool
if len(arg.Name.Package) > 0 {
var pkg *Package
pkg, ok = ctx.Packages[arg.Name.Package]
if !ok {
return ssa.Undefined, false, ctx.Errorf(loc,
"package '%s' not found", arg.Name.Package)
}
b, ok = pkg.Bindings.Get(arg.Name.Name)
} else {
b, ok = env.Get(arg.Name.Name)
}
if !ok {
return ssa.Undefined, false, ctx.Errorf(loc,
"undefined variable '%s'", arg.Name.String())
}
return gen.Constant(int64(b.Type.Bits), types.Undefined), true, nil
default:
return ssa.Undefined, false, ctx.Errorf(loc,
"size(%v/%T) is not constant", arg, arg)
}
}

View File

@ -0,0 +1,219 @@
//
// ast.go
//
// Copyright (c) 2019-2024 Markku Rossi
//
// All rights reserved.
//
package ast
import (
"fmt"
"runtime"
"strings"
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/ssa"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/utils"
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
)
// Codegen implements compilation stack.
type Codegen struct {
logger *utils.Logger
Params *utils.Params
Verbose bool
Package *Package
Packages map[string]*Package
MainInputSizes [][]int
Stack []Compilation
Types map[types.ID]*TypeInfo
Native map[string]*circuit.Circuit
HeapID int
}
// NewCodegen creates a new compilation.
func NewCodegen(logger *utils.Logger, pkg *Package,
packages map[string]*Package, params *utils.Params,
mainInputSizes [][]int) *Codegen {
return &Codegen{
logger: logger,
Params: params,
Verbose: params.Verbose,
Package: pkg,
Packages: packages,
MainInputSizes: mainInputSizes,
Types: make(map[types.ID]*TypeInfo),
Native: make(map[string]*circuit.Circuit),
}
}
func (ctx *Codegen) errorLoc(err error) error {
if !ctx.Params.QCLCErrorLoc {
return err
}
for skip := 2; ; skip++ {
pc, file, line, ok := runtime.Caller(skip)
if !ok {
return err
}
f := runtime.FuncForPC(pc)
if f != nil && strings.HasSuffix(f.Name(), ".errf") {
continue
}
fmt.Printf("%s:%d: MCPLC error:\n\u2514\u2574%s\n", file, line, err)
return err
}
}
// Error logs an error message.
func (ctx *Codegen) Error(locator utils.Locator, msg string) error {
return ctx.errorLoc(ctx.logger.Errorf(locator.Location(), "%s", msg))
}
// Errorf logs an error message.
func (ctx *Codegen) Errorf(locator utils.Locator, format string,
a ...interface{}) error {
return ctx.errorLoc(ctx.logger.Errorf(locator.Location(), format, a...))
}
// Warningf logs a warning message
func (ctx *Codegen) Warningf(locator utils.Locator, format string,
a ...interface{}) {
ctx.logger.Warningf(locator.Location(), format, a...)
}
// DefineType defines the argument type and assigns it an unique type
// ID.
func (ctx *Codegen) DefineType(t *TypeInfo) types.ID {
id := types.ID(len(ctx.Types) + 0x80000000)
ctx.Types[id] = t
return id
}
// LookupFunc resolves the named function from the context.
func (ctx *Codegen) LookupFunc(block *ssa.Block, ref *VariableRef) (
*Func, error) {
// First, check method calls.
if len(ref.Name.Package) > 0 {
// Check if package name is bound to a value.
var b ssa.Binding
var ok bool
b, ok = block.Bindings.Get(ref.Name.Package)
if !ok {
// Check names in the current package.
b, ok = ctx.Package.Bindings.Get(ref.Name.Package)
}
if ok {
var typeInfo types.Info
if b.Type.Type == types.TPtr {
typeInfo = *b.Type.ElementType
} else {
typeInfo = b.Type
}
info, ok := ctx.Types[typeInfo.ID]
if !ok {
return nil, ctx.Errorf(ref, "%s undefined", ref)
}
method, ok := info.Methods[ref.Name.Name]
if !ok {
return nil, ctx.Errorf(ref, "%s undefined", ref)
}
return method, nil
}
}
// Next, check function calls.
var pkgName string
if len(ref.Name.Package) > 0 {
pkgName = ref.Name.Package
} else {
pkgName = ref.Name.Defined
}
pkg, ok := ctx.Packages[pkgName]
if !ok {
return nil, ctx.Errorf(ref, "package '%s' not found", pkgName)
}
called, ok := pkg.Functions[ref.Name.Name]
if !ok {
return nil, nil
}
return called, nil
}
// Func returns the current function in the current compilation.
func (ctx *Codegen) Func() *Func {
if len(ctx.Stack) == 0 {
return nil
}
return ctx.Stack[len(ctx.Stack)-1].Called
}
// Scope returns the value scope in the current compilation.
func (ctx *Codegen) Scope() ssa.Scope {
if ctx.Func() != nil {
return 1
}
return 0
}
// PushCompilation pushes a new compilation to the compilation stack.
func (ctx *Codegen) PushCompilation(start, ret, caller *ssa.Block,
called *Func) {
ctx.Stack = append(ctx.Stack, Compilation{
Start: start,
Return: ret,
Caller: caller,
Called: called,
})
}
// PopCompilation pops the topmost compilation from the compilation
// stack.
func (ctx *Codegen) PopCompilation() {
if len(ctx.Stack) == 0 {
panic("compilation stack underflow")
}
ctx.Stack = ctx.Stack[:len(ctx.Stack)-1]
}
// Start returns the start block of the current compilation.
func (ctx *Codegen) Start() *ssa.Block {
return ctx.Stack[len(ctx.Stack)-1].Start
}
// Return returns the return block of the current compilation.
func (ctx *Codegen) Return() *ssa.Block {
return ctx.Stack[len(ctx.Stack)-1].Return
}
// Caller returns the caller block of the current compilation.
func (ctx *Codegen) Caller() *ssa.Block {
return ctx.Stack[len(ctx.Stack)-1].Caller
}
// HeapVar returns the name of the next global heap variable.
func (ctx *Codegen) HeapVar() string {
name := fmt.Sprintf("$heap%v", ctx.HeapID)
ctx.HeapID++
return name
}
// Compilation contains information about a compilation
// scope. Toplevel, each function call, and each nested block specify
// their own scope with their own variable bindings.
type Compilation struct {
Start *ssa.Block
Return *ssa.Block
Caller *ssa.Block
Called *Func
// XXX Bindings
// XXX Parent scope.
}

684
bedlam/compiler/ast/eval.go Normal file
View File

@ -0,0 +1,684 @@
//
// Copyright (c) 2019-2024 Markku Rossi
//
// All rights reserved.
//
package ast
import (
"fmt"
"math"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/mpa"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/ssa"
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
)
const (
debugEval = false
)
// Eval implements the compiler.ast.AST.Eval for list statements.
func (ast List) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
ssa.Value, bool, error) {
return ssa.Undefined, false, ctx.Errorf(ast, "List.Eval not implemented")
}
// Eval implements the compiler.ast.AST.Eval for function definitions.
func (ast *Func) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
ssa.Value, bool, error) {
return ssa.Undefined, false, nil
}
// Eval implements the compiler.ast.AST.Eval for variable definitions.
func (ast *VariableDef) Eval(env *Env, ctx *Codegen,
gen *ssa.Generator) (ssa.Value, bool, error) {
return ssa.Undefined, false, nil
}
// Eval implements the compiler.ast.AST.Eval for assignment expressions.
func (ast *Assign) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
ssa.Value, bool, error) {
var values []interface{}
for _, expr := range ast.Exprs {
val, ok, err := expr.Eval(env, ctx, gen)
if err != nil || !ok {
return ssa.Undefined, ok, err
}
// XXX multiple return values.
values = append(values, val)
}
if len(ast.LValues) != len(values) {
return ssa.Undefined, false, ctx.Errorf(ast,
"assignment mismatch: %d variables but %d values",
len(ast.LValues), len(values))
}
arrType := types.Info{
Type: types.TArray,
IsConcrete: true,
ArraySize: types.Size(len(values)),
}
if ast.Define {
for idx, lv := range ast.LValues {
constVal := gen.Constant(values[idx], types.Undefined)
gen.AddConstant(constVal)
arrType.ElementType = &constVal.Type
ref, ok := lv.(*VariableRef)
if !ok {
return ssa.Undefined, false,
ctx.Errorf(ast, "cannot assign to %s", lv)
}
// XXX package.name below
lValue := gen.NewVal(ref.Name.Name, constVal.Type, ctx.Scope())
env.Set(lValue, &constVal)
}
} else {
for idx, lv := range ast.LValues {
ref, ok := lv.(*VariableRef)
if !ok {
return ssa.Undefined, false,
ctx.Errorf(ast, "cannot assign to %s", lv)
}
// XXX package.name below
b, ok := env.Get(ref.Name.Name)
if !ok {
return ssa.Undefined, false,
ctx.Errorf(ast, "undefined variable '%s'", ref.Name)
}
lValue := gen.NewVal(b.Name, b.Type, ctx.Scope())
constVal := gen.Constant(values[idx], b.Type)
gen.AddConstant(constVal)
arrType.ElementType = &constVal.Type
env.Set(lValue, &constVal)
}
}
return gen.Constant(values, arrType), true, nil
}
// Eval implements the compiler.ast.AST.Eval for if statements.
func (ast *If) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
ssa.Value, bool, error) {
return ssa.Undefined, false, nil
}
// Eval implements the compiler.ast.AST.Eval for call expressions.
func (ast *Call) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
ssa.Value, bool, error) {
if debugEval {
fmt.Printf("Call.Eval: %s(", ast.Ref)
for idx, expr := range ast.Exprs {
if idx > 0 {
fmt.Print(", ")
}
fmt.Printf("%v", expr)
}
fmt.Println(")")
}
// Resolve called.
var pkgName string
if len(ast.Ref.Name.Package) > 0 {
pkgName = ast.Ref.Name.Package
} else {
pkgName = ast.Ref.Name.Defined
}
pkg, ok := ctx.Packages[pkgName]
if !ok {
return ssa.Undefined, false,
ctx.Errorf(ast, "package '%s' not found", pkgName)
}
_, ok = pkg.Functions[ast.Ref.Name.Name]
if ok {
return ssa.Undefined, false, nil
}
// Check builtin functions.
bi, ok := builtins[ast.Ref.Name.Name]
if ok && bi.Eval != nil {
return bi.Eval(ast.Exprs, env, ctx, gen, ast.Location())
}
// Resolve name as type.
typeName := &TypeInfo{
Point: ast.Point,
Type: TypeName,
Name: ast.Ref.Name,
}
typeInfo, err := typeName.Resolve(env, ctx, gen)
if err != nil {
return ssa.Undefined, false, err
}
if len(ast.Exprs) != 1 {
return ssa.Undefined, false, nil
}
constVal, ok, err := ast.Exprs[0].Eval(env, ctx, gen)
if err != nil {
return ssa.Undefined, false, err
}
if !ok {
return ssa.Undefined, false, nil
}
switch typeInfo.Type {
case types.TInt, types.TUint:
switch constVal.Type.Type {
case types.TInt, types.TUint:
if !typeInfo.Concrete() {
typeInfo.Bits = constVal.Type.Bits
typeInfo.SetConcrete(true)
}
if constVal.Type.MinBits > typeInfo.Bits {
typeInfo.MinBits = typeInfo.Bits
} else {
typeInfo.MinBits = constVal.Type.MinBits
}
cast := constVal
cast.Type = typeInfo
if constVal.HashCode() != cast.HashCode() {
panic("const cast changes value HashCode")
}
if !constVal.Equal(&cast) {
panic("const cast changes value equality")
}
return cast, true, nil
default:
return ssa.Undefined, false,
ctx.Errorf(ast.Ref, "casting %T not supported", constVal.Type)
}
}
return ssa.Undefined, false, nil
}
// Eval implements the compiler.ast.AST.Eval.
func (ast *ArrayCast) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
ssa.Value, bool, error) {
typeInfo, err := ast.TypeInfo.Resolve(env, ctx, gen)
if err != nil {
return ssa.Undefined, false, err
}
if !typeInfo.Type.Array() {
return ssa.Undefined, false,
ctx.Errorf(ast.Expr, "array cast to non-array type %v", typeInfo)
}
cv, ok, err := ast.Expr.Eval(env, ctx, gen)
if err != nil {
return ssa.Undefined, false, err
}
if !ok {
return ssa.Undefined, false, nil
}
switch cv.Type.Type {
case types.TString:
if cv.Type.Bits%8 != 0 {
return ssa.Undefined, false,
ctx.Errorf(ast.Expr, "invalid string length %v", cv.Type.Bits)
}
chars := cv.Type.Bits / 8
et := typeInfo.ElementType
if et.Bits != 8 || et.Type != types.TUint {
return ssa.Undefined, false,
ctx.Errorf(ast.Expr, "cast from %v to %v",
cv.Type, ast.TypeInfo)
}
if typeInfo.Concrete() {
if typeInfo.ArraySize != chars || typeInfo.Bits != cv.Type.Bits {
return ssa.Undefined, false,
ctx.Errorf(ast.Expr, "cast from %v to %v",
cv.Type, ast.TypeInfo)
}
} else {
typeInfo.Bits = cv.Type.Bits
typeInfo.ArraySize = chars
typeInfo.SetConcrete(true)
}
cast := cv
cast.Type = typeInfo
if cv.HashCode() != cast.HashCode() {
panic("const array cast changes value HashCode")
}
if !cv.Equal(&cast) {
panic("const array cast changes value equality")
}
return cast, true, nil
}
return ssa.Undefined, false, nil
}
// Eval implements the compiler.ast.AST.Eval for return statements.
func (ast *Return) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
ssa.Value, bool, error) {
return ssa.Undefined, false, nil
}
// Eval implements the compiler.ast.AST.Eval for for statements.
func (ast *For) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
ssa.Value, bool, error) {
return ssa.Undefined, false, nil
}
// Eval implements the compiler.ast.AST.Eval.
func (ast *ForRange) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
ssa.Value, bool, error) {
return ssa.Undefined, false, nil
}
// Eval implements the compiler.ast.AST.Eval for binary expressions.
func (ast *Binary) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
ssa.Value, bool, error) {
l, ok, err := ast.Left.Eval(env, ctx, gen)
if err != nil || !ok {
return ssa.Undefined, ok, err
}
r, ok, err := ast.Right.Eval(env, ctx, gen)
if err != nil || !ok {
return ssa.Undefined, ok, err
}
if debugEval {
fmt.Printf("%s: Binary.Eval: %v[%v:%v] %v %v[%v:%v]\n",
ast.Location().ShortString(),
ast.Left, l, l.Type, ast.Op, ast.Right, r, r.Type)
}
// Resolve result type.
rt, err := ast.resultType(ctx, l, r)
if err != nil {
return ssa.Undefined, false, err
}
switch lval := l.ConstValue.(type) {
case bool:
rval, ok := r.ConstValue.(bool)
if !ok {
return ssa.Undefined, false, ctx.Errorf(ast.Right,
"%s %v %s: invalid r-value %v (%T)", l, ast.Op, r, rval, rval)
}
switch ast.Op {
case BinaryEq:
return gen.Constant(lval == rval, types.Bool), true, nil
case BinaryNeq:
return gen.Constant(lval != rval, types.Bool), true, nil
case BinaryAnd:
return gen.Constant(lval && rval, types.Bool), true, nil
case BinaryOr:
return gen.Constant(lval || rval, types.Bool), true, nil
}
case *mpa.Int:
rval, ok := r.ConstValue.(*mpa.Int)
if !ok {
return ssa.Undefined, false, ctx.Errorf(ast.Right,
"%s %v %s: invalid r-value %v (%T)", l, ast.Op, r, rval, rval)
}
switch ast.Op {
case BinaryMul:
return gen.Constant(mpa.New(rt.Bits).Mul(lval, rval), rt),
true, nil
case BinaryDiv:
return gen.Constant(mpa.New(rt.Bits).Div(lval, rval), rt),
true, nil
case BinaryMod:
return gen.Constant(mpa.New(rt.Bits).Mod(lval, rval), rt),
true, nil
case BinaryLshift:
return gen.Constant(mpa.New(rt.Bits).Lsh(lval, uint(rval.Int64())),
rt), true, nil
case BinaryRshift:
return gen.Constant(mpa.New(rt.Bits).Rsh(lval, uint(rval.Int64())),
rt), true, nil
case BinaryBand:
return gen.Constant(mpa.New(rt.Bits).And(lval, rval), rt),
true, nil
case BinaryBclear:
return gen.Constant(mpa.New(rt.Bits).AndNot(lval, rval), rt),
true, nil
case BinaryBor:
return gen.Constant(mpa.New(rt.Bits).Or(lval, rval), rt),
true, nil
case BinaryBxor:
return gen.Constant(mpa.New(rt.Bits).Xor(lval, rval), rt),
true, nil
case BinaryAdd:
return gen.Constant(mpa.New(rt.Bits).Add(lval, rval), rt),
true, nil
case BinarySub:
return gen.Constant(mpa.New(rt.Bits).Sub(lval, rval), rt),
true, nil
case BinaryEq:
return gen.Constant(lval.Cmp(rval) == 0, types.Bool), true, nil
case BinaryNeq:
return gen.Constant(lval.Cmp(rval) != 0, types.Bool), true, nil
case BinaryLt:
return gen.Constant(lval.Cmp(rval) == -1, types.Bool), true, nil
case BinaryLe:
return gen.Constant(lval.Cmp(rval) != 1, types.Bool), true, nil
case BinaryGt:
return gen.Constant(lval.Cmp(rval) == 1, types.Bool), true, nil
case BinaryGe:
return gen.Constant(lval.Cmp(rval) != -1, types.Bool), true, nil
}
case string:
rval, ok := r.ConstValue.(string)
if !ok {
return ssa.Undefined, false, ctx.Errorf(ast.Right,
"%s %v %s: invalid r-value %v (%T)", l, ast.Op, r, rval, rval)
}
switch ast.Op {
case BinaryAdd:
return gen.Constant(lval+rval, types.Undefined), true, nil
}
}
return ssa.Undefined, false, ctx.Errorf(ast.Right,
"invalid operation: operator %s not defined on %v (%v)",
ast.Op, l, l.Type)
}
// Eval implements the compiler.ast.AST.Eval for unary expressions.
func (ast *Unary) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
ssa.Value, bool, error) {
expr, ok, err := ast.Expr.Eval(env, ctx, gen)
if err != nil || !ok {
return ssa.Undefined, ok, err
}
switch val := expr.ConstValue.(type) {
case bool:
switch ast.Type {
case UnaryNot:
return gen.Constant(!val, types.Bool), true, nil
}
case *mpa.Int:
switch ast.Type {
case UnaryMinus:
r := mpa.NewInt(0, expr.Type.Bits)
return gen.Constant(r.Sub(r, val), expr.Type), true, nil
}
}
return ssa.Undefined, false, ctx.Errorf(ast.Expr,
"invalid unary expression: %s%T", ast.Type, ast.Expr)
}
// Eval implements the compiler.ast.AST.Eval for slice expressions.
func (ast *Slice) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
ssa.Value, bool, error) {
expr, ok, err := ast.Expr.Eval(env, ctx, gen)
if err != nil || !ok {
return ssa.Undefined, ok, err
}
from := 0
to := math.MaxInt32
if ast.From != nil {
val, ok, err := ast.From.Eval(env, ctx, gen)
if err != nil || !ok {
return ssa.Undefined, ok, err
}
from, err = intVal(val)
if err != nil {
return ssa.Undefined, false, ctx.Errorf(ast.From, err.Error())
}
}
if ast.To != nil {
val, ok, err := ast.To.Eval(env, ctx, gen)
if err != nil || !ok {
return ssa.Undefined, ok, err
}
to, err = intVal(val)
if err != nil {
return ssa.Undefined, false, ctx.Errorf(ast.To, err.Error())
}
}
if to < from {
return ssa.Undefined, false, ctx.Errorf(ast.Expr,
"invalid slice range %d:%d", from, to)
}
if !expr.Type.Type.Array() {
return ssa.Undefined, false, ctx.Errorf(ast.Expr,
"invalid operation: cannot slice %v (%v)", expr, expr.Type)
}
arr, err := expr.ConstArray()
if err != nil {
return ssa.Undefined, false, err
}
if to == math.MaxInt32 {
to = int(expr.Type.ArraySize)
}
if to > int(expr.Type.ArraySize) || from > to {
return ssa.Undefined, false, ctx.Errorf(ast.From,
"slice bounds out of range [%d:%d] in slice of length %v",
from, to, expr.Type.ArraySize)
}
numElements := to - from
switch val := arr.(type) {
case []interface{}:
ti := expr.Type
ti.ArraySize = types.Size(numElements)
// The gen.Constant will set the bit sizes.
return gen.Constant(val[from:to], ti), true, nil
case []byte:
constVal := make([]interface{}, numElements)
for i := 0; i < numElements; i++ {
constVal[i] = int64(val[from+i])
}
ti := expr.Type
ti.ArraySize = types.Size(numElements)
// The gen.Constant will set the bit sizes.
return gen.Constant(constVal, ti), true, nil
default:
return ssa.Undefined, false, ctx.Errorf(ast.Expr,
"invalid operation: cannot slice %T array", arr)
}
}
func intVal(val interface{}) (int, error) {
switch v := val.(type) {
case *mpa.Int:
return int(v.Int64()), nil
case ssa.Value:
if !v.Const {
return 0, fmt.Errorf("non-const slice index: %v", v)
}
return intVal(v.ConstValue)
default:
return 0, fmt.Errorf("invalid slice index: %T", v)
}
}
// Eval implements the compiler.ast.AST.Eval() for index expressions.
func (ast *Index) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
ssa.Value, bool, error) {
expr, ok, err := ast.Expr.Eval(env, ctx, gen)
if err != nil || !ok {
return ssa.Undefined, ok, err
}
val, ok, err := ast.Index.Eval(env, ctx, gen)
if err != nil || !ok {
return ssa.Undefined, ok, err
}
index, err := intVal(val)
if err != nil {
return ssa.Undefined, false, ctx.Errorf(ast.Index, err.Error())
}
switch expr.Type.Type {
case types.TArray, types.TSlice:
if index < 0 || index >= int(expr.Type.ArraySize) {
return ssa.Undefined, false, ctx.Errorf(ast.Index,
"invalid array index %d (out of bounds for %d-element array)",
index, expr.Type.ArraySize)
}
arr, err := expr.ConstArray()
if err != nil {
return ssa.Undefined, false, err
}
switch val := arr.(type) {
case []interface{}:
return gen.Constant(val[index], *expr.Type.ElementType), true, nil
case []byte:
return gen.Constant(int64(val[index]), *expr.Type.ElementType),
true, nil
}
case types.TString:
numBytes := expr.Type.Bits / types.ByteBits
if index < 0 || index >= int(numBytes) {
return ssa.Undefined, false, ctx.Errorf(ast.Index,
"invalid array index %d (out of bounds for %d-element string)",
index, numBytes)
}
str, err := expr.ConstString()
if err != nil {
return ssa.Undefined, false, err
}
bytes := []byte(str)
return gen.Constant(int64(bytes[index]), types.Byte), true, nil
}
return ssa.Undefined, false, ctx.Errorf(ast.Expr,
"invalid operation: cannot index %v (%v)", expr, expr.Type)
}
// Eval implements the compiler.ast.AST.Eval for variable references.
func (ast *VariableRef) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
ssa.Value, bool, error) {
lrv, ok, _, err := ctx.LookupVar(nil, gen, env.Bindings, ast)
if err != nil {
return ssa.Undefined, false, ctx.Error(ast, err.Error())
}
if !ok {
return ssa.Undefined, ok, nil
}
return lrv.ConstValue()
}
// Eval implements the compiler.ast.AST.Eval for constant values.
func (ast *BasicLit) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
ssa.Value, bool, error) {
return gen.Constant(ast.Value, types.Undefined), true, nil
}
// Eval implements the compiler.ast.AST.Eval for constant values.
func (ast *CompositeLit) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
ssa.Value, bool, error) {
// XXX the init values might be short so we must pad them with
// zero values so that we create correctly sized values.
typeInfo, err := ast.Type.Resolve(env, ctx, gen)
if err != nil {
return ssa.Undefined, false, err
}
switch typeInfo.Type {
case types.TStruct:
// Check if all elements are constants.
var values []interface{}
for _, el := range ast.Value {
// XXX check if el.Key is specified
v, ok, err := el.Element.Eval(env, ctx, gen)
if err != nil || !ok {
return ssa.Undefined, ok, err
}
// XXX check that v is assignment compatible with typeInfo.Struct[i]
values = append(values, v)
}
return gen.Constant(values, typeInfo), true, nil
case types.TArray, types.TSlice:
// Check if all elements are constants.
var values []interface{}
for _, el := range ast.Value {
// XXX check if el.Key is specified
v, ok, err := el.Element.Eval(env, ctx, gen)
if err != nil || !ok {
return ssa.Undefined, ok, err
}
// XXX check that v is assignment compatible with array.
values = append(values, v)
}
typeInfo.ArraySize = types.Size(len(values))
typeInfo.Bits = typeInfo.ArraySize * typeInfo.ElementType.Bits
typeInfo.MinBits = typeInfo.Bits
return gen.Constant(values, typeInfo), true, nil
default:
fmt.Printf("CompositeLit.Eval: not implemented yet: %v, Value: %v\n",
typeInfo, ast.Value)
return ssa.Undefined, false, nil
}
}
// Eval implements the compiler.ast.AST.Eval for the builtin function make.
func (ast *Make) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
ssa.Value, bool, error) {
if len(ast.Exprs) != 1 {
return ssa.Undefined, false, ctx.Errorf(ast,
"invalid amount of argument in call to make")
}
typeInfo, err := ast.Type.Resolve(env, ctx, gen)
if err != nil {
return ssa.Undefined, false, ctx.Errorf(ast.Type, "%s is not a type",
ast.Type)
}
if typeInfo.Type.Array() {
// Arrays are made in Make.SSA.
return ssa.Undefined, false, nil
}
if typeInfo.Bits != 0 {
return ssa.Undefined, false, ctx.Errorf(ast.Type,
"can't make specified type %s", typeInfo)
}
constVal, _, err := ast.Exprs[0].Eval(env, ctx, gen)
if err != nil {
return ssa.Undefined, false, ctx.Error(ast.Exprs[0], err.Error())
}
length, err := constVal.ConstInt()
if err != nil {
return ssa.Undefined, false, ctx.Errorf(ast.Exprs[0],
"non-integer (%T) len argument in %s: %s", constVal, ast, err)
}
typeInfo.IsConcrete = true
typeInfo.Bits = length
// Create typeref constant.
return gen.Constant(typeInfo, types.Undefined), true, nil
}
// Eval implements the compiler.ast.AST.Eval for the builtin function copy.
func (ast *Copy) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
ssa.Value, bool, error) {
return ssa.Undefined, false, nil
}

View File

@ -0,0 +1,367 @@
//
// ast.go
//
// Copyright (c) 2019-2024 Markku Rossi
//
// All rights reserved.
//
package ast
import (
"fmt"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/ssa"
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
)
// LRValue implements value as l-value or r-value. The LRValues have
// two types:
//
// 1. base type that specifies the base memory location containing
// the value
// 2. value type that specifies the wires of the value
//
// Both types can be the same. The base and value types are set as
// follows for different names:
//
// 1. Struct.Field:
// - baseInfo points to the containing variable
// - baseValue is the Struct
// - structField defines the structure field
// - valueType is the type of the Struct.Field
// - value is nil
//
// 2. Package.Name:
// - baseInfo points to the containing variable in Package
// - baseValue is the value of Package.Name
// - structField is nil
// - valueType is the type of Package.Name
// - value is the value of Package.Name
//
// 3. Name:
// - baseInfo points to the containing local variable
// - baseValue is the value of Name
// - structField is nil
// - valueType is the type of Name
// - value is the value of Name
type LRValue struct {
ctx *Codegen
ast AST
block *ssa.Block
gen *ssa.Generator
baseInfo *ssa.PtrInfo
baseValue ssa.Value
valueType types.Info
value ssa.Value
structField *types.StructField
}
func (lrv LRValue) String() string {
offset := lrv.baseInfo.Offset + lrv.valueType.Offset
return fmt.Sprintf("%s[%d-%d]@%s{%d}%s/%v",
lrv.valueType, offset, offset+lrv.valueType.Bits,
lrv.baseInfo.Name, lrv.baseInfo.Scope, lrv.baseInfo.ContainerType,
lrv.baseInfo.ContainerType.Bits)
}
// BaseType returns the base type of the LRValue.
func (lrv *LRValue) BaseType() types.Info {
return lrv.baseInfo.ContainerType
}
// BaseValue returns the base value of the LRValue
func (lrv *LRValue) BaseValue() ssa.Value {
return lrv.baseValue
}
// BasePtrInfo returns the base value as PtrInfo.
func (lrv *LRValue) BasePtrInfo() *ssa.PtrInfo {
return lrv.baseInfo
}
// Indirect returns LRValue for the value that lrv points to. If lrv
// is not a pointer, Indirect returns lrv.
func (lrv *LRValue) Indirect() *LRValue {
v := lrv.RValue()
if v.Type.Type != types.TPtr {
return lrv
}
ret := *lrv
ret.valueType = *lrv.valueType.ElementType
ret.value.PtrInfo = nil
if lrv.baseInfo.ContainerType.Type == types.TStruct {
// Set value to undefined so RValue() can regenerate it.
ret.value.Type = types.Undefined
// Lookup struct field.
ret.structField = nil
for idx, f := range lrv.baseValue.Type.Struct {
if f.Type.Offset == lrv.baseInfo.Offset {
ret.structField = &lrv.baseValue.Type.Struct[idx]
break
}
}
if ret.structField == nil {
panic("LRValue.Indirect: could not find struct field")
}
} else {
ret.value.Type = *lrv.value.Type.ElementType
ret.value = lrv.baseValue
}
return &ret
}
// Set sets the l-value to rv.
func (lrv LRValue) Set(rv ssa.Value) error {
if !ssa.CanAssign(lrv.valueType, rv) {
return fmt.Errorf("cannot assing %v to variable of type %v",
rv.Type, lrv.valueType)
}
lValue := lrv.LValue()
if lrv.structField != nil {
fromConst := lrv.gen.Constant(int64(lrv.structField.Type.Offset),
types.Undefined)
toConst := lrv.gen.Constant(int64(lrv.structField.Type.Offset+
lrv.structField.Type.Bits), types.Undefined)
lrv.block.AddInstr(ssa.NewAmovInstr(rv, lrv.baseValue,
fromConst, toConst, lValue))
return lrv.baseInfo.Bindings.Set(lValue, nil)
}
if rv.Const && rv.IntegerLike() {
// Type coersions rules for const int r-values.
if lValue.Type.Concrete() {
rv.Type = lValue.Type
} else if rv.Type.Concrete() {
lValue.Type = rv.Type
} else {
return fmt.Errorf("unspecified size for type %v", rv.Type)
}
} else if rv.Type.Concrete() {
// Specifying the value of an unspecified variable, or
// specializing it (assining arrays with values of different
// size).
lValue.Type = rv.Type
} else if lValue.Type.Concrete() {
// Specializing r-value.
rv.Type = lValue.Type
} else {
return fmt.Errorf("unspecified size for type %v", rv.Type)
}
lrv.block.AddInstr(ssa.NewMovInstr(rv, lValue))
// The l-value and r-value types are now resolved. Let's define
// the variable with correct type and value information,
// overriding any old values.
lrv.baseInfo.Bindings.Define(lValue, &rv)
return nil
}
// LValue returns the l-value of the LRValue.
func (lrv *LRValue) LValue() ssa.Value {
return lrv.gen.NewVal(lrv.baseInfo.Name, lrv.baseInfo.ContainerType,
lrv.baseInfo.Scope)
}
// RValue returns the r-value of the LRValue.
func (lrv *LRValue) RValue() ssa.Value {
if lrv.value.Type.Undefined() && lrv.structField != nil {
fieldType := lrv.valueType
fieldType.Offset = 0
lrv.value = lrv.gen.AnonVal(fieldType)
from := int64(lrv.valueType.Offset)
to := int64(lrv.valueType.Offset + lrv.valueType.Bits)
if to > from {
fromConst := lrv.gen.Constant(from, types.Undefined)
toConst := lrv.gen.Constant(to, types.Undefined)
lrv.block.AddInstr(ssa.NewSliceInstr(lrv.baseValue, fromConst,
toConst, lrv.value))
}
}
return lrv.value
}
// ValueType returns the value type of the LRValue.
func (lrv *LRValue) ValueType() types.Info {
return lrv.valueType
}
func (lrv *LRValue) ptrBaseValue() (ssa.Value, error) {
b, ok := lrv.baseInfo.Bindings.Get(lrv.baseInfo.Name)
if !ok {
return ssa.Undefined, fmt.Errorf("undefined: %s", lrv.baseInfo.Name)
}
return b.Value(lrv.block, lrv.gen), nil
}
// ConstValue returns the constant value of the LRValue if available.
func (lrv *LRValue) ConstValue() (ssa.Value, bool, error) {
switch lrv.value.Type.Type {
case types.TUndefined:
return lrv.value, false, nil
case types.TBool, types.TInt, types.TUint, types.TFloat, types.TString,
types.TStruct, types.TArray, types.TSlice, types.TNil:
return lrv.value, true, nil
default:
return ssa.Undefined, false, lrv.ctx.Errorf(lrv.ast,
"LRValue.ConstValue: %s not supported yet: %v",
lrv.value.Type, lrv.value)
}
}
// LookupVar resolves the named variable from the context.
func (ctx *Codegen) LookupVar(block *ssa.Block, gen *ssa.Generator,
bindings *ssa.Bindings, ref *VariableRef) (
lrv *LRValue, cf, df bool, err error) {
lrv = &LRValue{
ctx: ctx,
ast: ref,
block: block,
gen: gen,
}
var env *ssa.Bindings
var b ssa.Binding
var ok bool
if len(ref.Name.Package) > 0 {
// Check if package name is bound to a value.
b, ok = bindings.Get(ref.Name.Package)
if ok {
env = bindings
} else {
// Check names in the current package.
b, ok = ctx.Package.Bindings.Get(ref.Name.Package)
if ok {
env = ctx.Package.Bindings
}
}
if ok {
if block != nil {
lrv.baseValue = b.Value(block, gen)
} else {
// Evaluating a const value.
v, ok := b.Bound.(*ssa.Value)
if !ok || !v.Const {
// Value is not const
return nil, false, false, nil
}
lrv.baseValue = *v
}
if lrv.baseValue.Type.Type == types.TPtr {
lrv.baseInfo = lrv.baseValue.PtrInfo
lrv.baseValue, err = lrv.ptrBaseValue()
if err != nil {
return nil, false, false, err
}
} else {
lrv.baseInfo = &ssa.PtrInfo{
Name: ref.Name.Package,
Bindings: env,
Scope: b.Scope,
ContainerType: b.Type,
}
}
if lrv.baseValue.Type.Type != types.TStruct {
return nil, false, false, fmt.Errorf("%s undefined", ref.Name)
}
for _, f := range lrv.baseValue.Type.Struct {
if f.Name == ref.Name.Name {
lrv.structField = &f
break
}
}
if lrv.structField == nil {
return nil, false, false, fmt.Errorf(
"%s undefined (type %s has no field or method %s)",
ref.Name, lrv.baseValue.Type, ref.Name.Name)
}
lrv.valueType = lrv.structField.Type
return lrv, true, false, nil
}
}
// Explicit package references.
var pkg *Package
if len(ref.Name.Package) > 0 {
pkg, ok = ctx.Packages[ref.Name.Package]
if !ok {
return nil, false, false, fmt.Errorf("package '%s' not found",
ref.Name.Package)
}
env = pkg.Bindings
b, ok = env.Get(ref.Name.Name)
if !ok {
return nil, false, false, fmt.Errorf("undefined variable '%s'",
ref.Name)
}
} else {
// Check block bindings.
env = bindings
b, ok = env.Get(ref.Name.Name)
if !ok {
// Check names in the name's package.
if len(ref.Name.Defined) > 0 {
pkg, ok = ctx.Packages[ref.Name.Defined]
if !ok {
return nil, false, false,
fmt.Errorf("package '%s' not found", ref.Name.Defined)
}
env = pkg.Bindings
b, ok = env.Get(ref.Name.Name)
}
}
}
if !ok {
return nil, false, true, fmt.Errorf("undefined variable '%s'",
ref.Name)
}
if block != nil {
lrv.value = b.Value(block, gen)
} else {
// Evaluating const value.
v, ok := b.Bound.(*ssa.Value)
if !ok || !v.Const {
// Value is not const
return nil, false, false, nil
}
lrv.value = *v
}
lrv.valueType = lrv.value.Type
if lrv.value.Type.Type == types.TPtr {
lrv.baseInfo = lrv.value.PtrInfo
lrv.baseValue, err = lrv.ptrBaseValue()
if err != nil {
return nil, false, false, err
}
} else {
lrv.baseInfo = &ssa.PtrInfo{
Name: ref.Name.Name,
Bindings: env,
Scope: b.Scope,
ContainerType: b.Type,
}
lrv.baseValue = lrv.value
}
return lrv, true, false, nil
}

View File

@ -0,0 +1,376 @@
//
// Copyright (c) 2020-2024 Markku Rossi
//
// All rights reserved.
//
package ast
import (
"errors"
"fmt"
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/ssa"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/utils"
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
)
// Package implements a QCL package.
type Package struct {
Name string
Source string
Annotations Annotations
Initialized bool
Imports map[string]string
Bindings *ssa.Bindings
Types []*TypeInfo
Constants []*ConstantDef
Variables []*VariableDef
Functions map[string]*Func
}
// NewPackage creates a new package.
func NewPackage(name, source string, annotations Annotations) *Package {
return &Package{
Name: name,
Source: source,
Annotations: annotations,
Imports: make(map[string]string),
Bindings: new(ssa.Bindings),
Functions: make(map[string]*Func),
}
}
// Compile compiles the package.
func (pkg *Package) Compile(ctx *Codegen) (*ssa.Program, Annotations, error) {
main, err := pkg.Main()
if err != nil {
return nil, nil, ctx.Error(utils.Point{
Source: pkg.Source,
}, err.Error())
}
gen := ssa.NewGenerator(ctx.Params)
// Init is the program start point.
init := gen.Block()
// Init package.
block, err := pkg.Init(ctx.Packages, init, ctx, gen)
if err != nil {
return nil, nil, err
}
// Main block derives package's bindings from block with NextBlock().
ctx.PushCompilation(gen.NextBlock(block), gen.Block(), nil, main)
// Arguments.
var inputs circuit.IO
for idx, arg := range main.Args {
typeInfo, err := arg.Type.Resolve(NewEnv(ctx.Start()), ctx, gen)
if err != nil {
return nil, nil, ctx.Errorf(arg, "invalid argument type: %s", err)
}
if !typeInfo.Concrete() {
if ctx.MainInputSizes == nil {
return nil, nil,
ctx.Errorf(arg, "argument %s of %s has unspecified type",
arg.Name, main)
}
// Specify unspecified argument type.
if idx >= len(ctx.MainInputSizes) {
return nil, nil, ctx.Errorf(arg,
"not enough values for argument %s of %s",
arg.Name, main)
}
err = typeInfo.InstantiateWithSizes(ctx.MainInputSizes[idx])
if err != nil {
return nil, nil, ctx.Errorf(arg,
"can't specify unspecified argument %s of %s: %s",
arg.Name, main, err)
}
}
// Define argument in block.
a := gen.NewVal(arg.Name, typeInfo, ctx.Scope())
ctx.Start().Bindings.Define(a, nil)
input := circuit.IOArg{
Name: arg.Name,
Type: a.Type,
}
if typeInfo.Type == types.TStruct {
input.Compound = flattenStruct(typeInfo)
}
inputs = append(inputs, input)
}
// Compile main.
_, returnVars, err := main.SSA(ctx.Start(), ctx, gen)
if err != nil {
return nil, nil, err
}
// Return values
var outputs circuit.IO
for idx, rt := range main.Return {
if idx >= len(returnVars) {
return nil, nil, fmt.Errorf("too few values for %s", main)
}
typeInfo, err := rt.Type.Resolve(NewEnv(ctx.Start()), ctx, gen)
if err != nil {
return nil, nil, ctx.Errorf(rt, "invalid return type: %s", err)
}
// Instantiate result values for template functions.
if !typeInfo.Concrete() && !typeInfo.Instantiate(returnVars[idx].Type) {
return nil, nil, ctx.Errorf(main,
"invalid value %v for return value %d of %s",
returnVars[idx].Type, idx, main)
}
// The native() returns undefined values.
if returnVars[idx].Type.Undefined() {
returnVars[idx].Type.Type = typeInfo.Type
}
if !ssa.CanAssign(typeInfo, returnVars[idx]) {
return nil, nil,
ctx.Errorf(main, "invalid value %v for return value %d of %s",
returnVars[idx].Type, idx, main)
}
v := returnVars[idx]
outputs = append(outputs, circuit.IOArg{
Name: v.String(),
Type: v.Type,
})
}
steps := init.Serialize()
program, err := ssa.NewProgram(ctx.Params, inputs, outputs, gen.Constants(),
steps)
if err != nil {
return nil, nil, err
}
if false { // XXX Peephole liveness analysis is broken.
err = program.Peephole()
if err != nil {
return nil, nil, err
}
}
program.GC()
if ctx.Params.SSAOut != nil {
program.PP(ctx.Params.SSAOut)
}
if ctx.Params.SSADotOut != nil {
ssa.Dot(ctx.Params.SSADotOut, init)
}
return program, main.Annotations, nil
}
// Main returns package's main function.
func (pkg *Package) Main() (*Func, error) {
main, ok := pkg.Functions["main"]
if !ok {
return nil, errors.New("no main function defined")
}
return main, nil
}
func flattenStruct(t types.Info) circuit.IO {
var result circuit.IO
if t.Type != types.TStruct {
return result
}
for _, f := range t.Struct {
if f.Type.Type == types.TStruct {
ios := flattenStruct(f.Type)
result = append(result, ios...)
} else {
result = append(result, circuit.IOArg{
Name: f.Name,
Type: f.Type,
})
}
}
return result
}
// Init initializes the package.
func (pkg *Package) Init(packages map[string]*Package, block *ssa.Block,
ctx *Codegen, gen *ssa.Generator) (*ssa.Block, error) {
if pkg.Initialized {
return block, nil
}
pkg.Initialized = true
if ctx.Verbose {
fmt.Printf("Initializing %s\n", pkg.Name)
}
// Imported packages.
for alias, name := range pkg.Imports {
p, ok := packages[alias]
if !ok {
return nil, fmt.Errorf("imported and not used: \"%s\"", name)
}
var err error
block, err = p.Init(packages, block, ctx, gen)
if err != nil {
return nil, err
}
}
// Define constants.
for _, def := range pkg.Constants {
err := pkg.defineConstant(def, ctx, gen)
if err != nil {
return nil, err
}
}
// Define types.
for _, typeDef := range pkg.Types {
err := pkg.defineType(typeDef, ctx, gen)
if err != nil {
return nil, err
}
}
// Package initializer block.
block = gen.NextBlock(block)
block.Name = fmt.Sprintf(".%s", pkg.Name)
// Package sees only its bindings.
block.Bindings = pkg.Bindings.Clone()
var err error
// Define variables.
for _, def := range pkg.Variables {
block, _, err = def.SSA(block, ctx, gen)
if err != nil {
return nil, err
}
}
pkg.Bindings = block.Bindings
return block, nil
}
func (pkg *Package) defineConstant(def *ConstantDef, ctx *Codegen,
gen *ssa.Generator) error {
env := &Env{
Bindings: pkg.Bindings,
}
typeInfo, err := def.Type.Resolve(env, ctx, gen)
if err != nil {
return err
}
constVal, ok, err := def.Init.Eval(env, ctx, gen)
if err != nil {
return err
}
if !ok {
return ctx.Errorf(def.Init, "init value is not constant")
}
constVar := gen.Constant(constVal, typeInfo)
if typeInfo.Undefined() {
typeInfo.Type = constVar.Type.Type
}
if !typeInfo.Concrete() {
typeInfo.Bits = constVar.Type.Bits
}
if !typeInfo.CanAssignConst(constVar.Type) {
return ctx.Errorf(def.Init,
"invalid init value %s for type %s", constVar.Type, typeInfo)
}
_, ok = pkg.Bindings.Get(def.Name)
if ok {
return ctx.Errorf(def, "constant %s already defined", def.Name)
}
lValue := constVar
lValue.Name = def.Name
pkg.Bindings.Define(lValue, &constVar)
gen.AddConstant(constVal)
return nil
}
func (pkg *Package) defineType(def *TypeInfo, ctx *Codegen,
gen *ssa.Generator) error {
_, ok := pkg.Bindings.Get(def.TypeName)
if ok {
return ctx.Errorf(def, "type %s already defined", def.TypeName)
}
env := &Env{
Bindings: pkg.Bindings,
}
var info types.Info
var err error
switch def.Type {
case TypeStruct:
// Construct compound type.
var fields []types.StructField
var bits types.Size
var minBits types.Size
var offset types.Size
for _, field := range def.StructFields {
info, err := field.Type.Resolve(env, ctx, gen)
if err != nil {
return err
}
field := types.StructField{
Name: field.Name,
Type: info,
}
field.Type.Offset = offset
fields = append(fields, field)
bits += info.Bits
minBits += info.MinBits
offset += info.Bits
}
info = types.Info{
Type: types.TStruct,
IsConcrete: true,
Bits: bits,
MinBits: minBits,
Struct: fields,
}
case TypeArray:
info, err = def.Resolve(env, ctx, gen)
if err != nil {
return err
}
case TypeAlias:
info, err = def.AliasType.Resolve(env, ctx, gen)
if err != nil {
return err
}
default:
return ctx.Errorf(def, "invalid type definition: %s", def)
}
info.ID = ctx.DefineType(def)
v := gen.Constant(info, types.Undefined)
lval := gen.NewVal(def.TypeName, info, ctx.Scope())
pkg.Bindings.Define(lval, &v)
return nil
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,101 @@
//
// Copyright (c) 2023 Markku Rossi
//
// All rights reserved.
//
package circuits
import (
"fmt"
"unsafe"
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
)
var (
sizeofWire = uint64(unsafe.Sizeof(Wire{}))
sizeofGate = uint64(unsafe.Sizeof(Gate{}))
)
// Allocator implements circuit wire and gate allocation.
type Allocator struct {
numWire uint64
numWires uint64
numGates uint64
}
// NewAllocator creates a new circuit allocator.
func NewAllocator() *Allocator {
return new(Allocator)
}
// Wire allocates a new Wire.
func (alloc *Allocator) Wire() *Wire {
alloc.numWire++
w := new(Wire)
w.Reset(UnassignedID)
return w
}
// Wires allocate an array of Wires.
func (alloc *Allocator) Wires(bits types.Size) []*Wire {
alloc.numWires += uint64(bits)
wires := make([]Wire, bits)
result := make([]*Wire, bits)
for i := 0; i < int(bits); i++ {
w := &wires[i]
w.id = UnassignedID
result[i] = w
}
return result
}
// BinaryGate creates a new binary gate.
func (alloc *Allocator) BinaryGate(op circuit.Operation, a, b, o *Wire) *Gate {
alloc.numGates++
gate := &Gate{
Op: op,
A: a,
B: b,
O: o,
}
a.AddOutput(gate)
b.AddOutput(gate)
o.SetInput(gate)
return gate
}
// INVGate creates a new INV gate.
func (alloc *Allocator) INVGate(i, o *Wire) *Gate {
alloc.numGates++
gate := &Gate{
Op: circuit.INV,
A: i,
O: o,
}
i.AddOutput(gate)
o.SetInput(gate)
return gate
}
// Debug print debugging information about the circuit allocator.
func (alloc *Allocator) Debug() {
wireSize := circuit.FileSize(alloc.numWire * sizeofWire)
wiresSize := circuit.FileSize(alloc.numWires * sizeofWire)
gatesSize := circuit.FileSize(alloc.numGates * sizeofGate)
total := float64(wireSize + wiresSize + gatesSize)
fmt.Println("circuits.Allocator:")
fmt.Printf(" wire : %9v %5s %5.2f%%\n",
alloc.numWire, wireSize, float64(wireSize)/total*100.0)
fmt.Printf(" wires: %9v %5s %5.2f%%\n",
alloc.numWires, wiresSize, float64(wiresSize)/total*100.0)
fmt.Printf(" gates: %9v %5s %5.2f%%\n",
alloc.numGates, gatesSize, float64(gatesSize)/total*100.0)
}

View File

@ -0,0 +1,96 @@
//
// circ_adder.go
//
// Copyright (c) 2019-2023 Markku Rossi
//
// All rights reserved.
//
package circuits
import (
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
)
// NewHalfAdder creates a half adder circuit.
func NewHalfAdder(cc *Compiler, a, b, s, c *Wire) {
// S = XOR(A, B)
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, a, b, s))
if c != nil {
// C = AND(A, B)
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, a, b, c))
}
}
// NewFullAdder creates a full adder circuit
func NewFullAdder(cc *Compiler, a, b, cin, s, cout *Wire) {
w1 := cc.Calloc.Wire()
w2 := cc.Calloc.Wire()
w3 := cc.Calloc.Wire()
// s = a XOR b XOR cin
// cout = cin XOR ((a XOR cin) AND (b XOR cin)).
// w1 = XOR(b, cin)
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, b, cin, w1))
// s = XOR(a, w1)
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, a, w1, s))
if cout != nil {
// w2 = XOR(a, cin)
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, a, cin, w2))
// w3 = AND(w1, w2)
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, w1, w2, w3))
// cout = XOR(cin, w3)
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, cin, w3, cout))
}
}
// NewAdder creates a new adder circuit implementing z=x+y.
func NewAdder(cc *Compiler, x, y, z []*Wire) error {
x, y = cc.ZeroPad(x, y)
if len(x) > len(z) {
x = x[0:len(z)]
y = y[0:len(z)]
}
if len(x) == 1 {
var cin *Wire
if len(z) > 1 {
cin = z[1]
}
NewHalfAdder(cc, x[0], y[0], z[0], cin)
} else {
cin := cc.Calloc.Wire()
NewHalfAdder(cc, x[0], y[0], z[0], cin)
for i := 1; i < len(x); i++ {
var cout *Wire
if i+1 >= len(x) {
if i+1 >= len(z) {
// N+N=N, overflow, drop carry bit.
cout = nil
} else {
cout = z[len(x)]
}
} else {
cout = cc.Calloc.Wire()
}
NewFullAdder(cc, x[i], y[i], cin, z[i], cout)
cin = cout
}
}
// Set all leftover bits to zero.
for i := len(x) + 1; i < len(z); i++ {
z[i] = cc.ZeroWire()
}
return nil
}

View File

@ -0,0 +1,65 @@
//
// Copyright (c) 2020-2021, 2023 Markku Rossi
//
// All rights reserved.
//
package circuits
import (
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
)
// NewBinaryAND creates a new binary AND circuit implementing r=x&y
func NewBinaryAND(cc *Compiler, x, y, r []*Wire) error {
x, y = cc.ZeroPad(x, y)
if len(r) < len(x) {
x = x[0:len(r)]
y = y[0:len(r)]
}
for i := 0; i < len(x); i++ {
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, x[i], y[i], r[i]))
}
return nil
}
// NewBinaryClear creates a new binary clear circuit implementing r=x&^y.
func NewBinaryClear(cc *Compiler, x, y, r []*Wire) error {
x, y = cc.ZeroPad(x, y)
if len(r) < len(x) {
x = x[0:len(r)]
y = y[0:len(r)]
}
for i := 0; i < len(x); i++ {
w := cc.Calloc.Wire()
cc.INV(y[i], w)
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, x[i], w, r[i]))
}
return nil
}
// NewBinaryOR creates a new binary OR circuit implementing r=x|y.
func NewBinaryOR(cc *Compiler, x, y, r []*Wire) error {
x, y = cc.ZeroPad(x, y)
if len(r) < len(x) {
x = x[0:len(r)]
y = y[0:len(r)]
}
for i := 0; i < len(x); i++ {
cc.AddGate(cc.Calloc.BinaryGate(circuit.OR, x[i], y[i], r[i]))
}
return nil
}
// NewBinaryXOR creates a new binary XOR circuit implementing r=x^y.
func NewBinaryXOR(cc *Compiler, x, y, r []*Wire) error {
x, y = cc.ZeroPad(x, y)
if len(r) < len(x) {
x = x[0:len(r)]
y = y[0:len(r)]
}
for i := 0; i < len(x); i++ {
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, x[i], y[i], r[i]))
}
return nil
}

View File

@ -0,0 +1,161 @@
//
// Copyright (c) 2020-2023 Markku Rossi
//
// All rights reserved.
//
package circuits
import (
"fmt"
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
)
// comparator tests if x>y if cin=0, and x>=y if cin=1.
func comparator(cc *Compiler, cin *Wire, x, y, r []*Wire) error {
x, y = cc.ZeroPad(x, y)
if len(r) != 1 {
return fmt.Errorf("invalid lt comparator arguments: r=%d", len(r))
}
for i := 0; i < len(x); i++ {
w1 := cc.Calloc.Wire()
cc.AddGate(cc.Calloc.BinaryGate(circuit.XNOR, cin, y[i], w1))
w2 := cc.Calloc.Wire()
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, cin, x[i], w2))
w3 := cc.Calloc.Wire()
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, w1, w2, w3))
var cout *Wire
if i+1 < len(x) {
cout = cc.Calloc.Wire()
} else {
cout = r[0]
}
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, cin, w3, cout))
cin = cout
}
return nil
}
// NewGtComparator tests if x>y.
func NewGtComparator(cc *Compiler, x, y, r []*Wire) error {
return comparator(cc, cc.ZeroWire(), x, y, r)
}
// NewGeComparator tests if x>=y.
func NewGeComparator(cc *Compiler, x, y, r []*Wire) error {
return comparator(cc, cc.OneWire(), x, y, r)
}
// NewLtComparator tests if x<y.
func NewLtComparator(cc *Compiler, x, y, r []*Wire) error {
return comparator(cc, cc.ZeroWire(), y, x, r)
}
// NewLeComparator tests if x<=y.
func NewLeComparator(cc *Compiler, x, y, r []*Wire) error {
return comparator(cc, cc.OneWire(), y, x, r)
}
// NewNeqComparator tewsts if x!=y.
func NewNeqComparator(cc *Compiler, x, y, r []*Wire) error {
x, y = cc.ZeroPad(x, y)
if len(r) != 1 {
return fmt.Errorf("invalid neq comparator arguments: r=%d", len(r))
}
if len(x) == 1 {
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, x[0], y[0], r[0]))
return nil
}
c := cc.Calloc.Wire()
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, x[0], y[0], c))
for i := 1; i < len(x); i++ {
xor := cc.Calloc.Wire()
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, x[i], y[i], xor))
var out *Wire
if i+1 >= len(x) {
out = r[0]
} else {
out = cc.Calloc.Wire()
}
cc.AddGate(cc.Calloc.BinaryGate(circuit.OR, c, xor, out))
c = out
}
return nil
}
// NewEqComparator tests if x==y.
func NewEqComparator(cc *Compiler, x, y, r []*Wire) error {
if len(r) != 1 {
return fmt.Errorf("invalid eq comparator arguments: r=%d", len(r))
}
// w = x == y
w := cc.Calloc.Wire()
err := NewNeqComparator(cc, x, y, []*Wire{w})
if err != nil {
return err
}
// r = !w
cc.INV(w, r[0])
return nil
}
// NewLogicalAND implements logical AND implementing r=x&y. The input
// and output wires must be 1 bit wide.
func NewLogicalAND(cc *Compiler, x, y, r []*Wire) error {
if len(x) != 1 || len(y) != 1 || len(r) != 1 {
return fmt.Errorf("invalid logical and arguments: x=%d, y=%d, r=%d",
len(x), len(y), len(r))
}
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, x[0], y[0], r[0]))
return nil
}
// NewLogicalOR implements logical OR implementing r=x|y. The input
// and output wires must be 1 bit wide.
func NewLogicalOR(cc *Compiler, x, y, r []*Wire) error {
if len(x) != 1 || len(y) != 1 || len(r) != 1 {
return fmt.Errorf("invalid logical or arguments: x=%d, y=%d, r=%d",
len(x), len(y), len(r))
}
cc.AddGate(cc.Calloc.BinaryGate(circuit.OR, x[0], y[0], r[0]))
return nil
}
// NewBitSetTest tests if the index'th bit of x is set.
func NewBitSetTest(cc *Compiler, x []*Wire, index types.Size, r []*Wire) error {
if len(r) != 1 {
return fmt.Errorf("invalid bit set test arguments: x=%d, r=%d",
len(x), len(r))
}
if index < types.Size(len(x)) {
w := cc.ZeroWire()
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, x[index], w, r[0]))
} else {
r[0] = cc.ZeroWire()
}
return nil
}
// NewBitClrTest tests if the index'th bit of x is unset.
func NewBitClrTest(cc *Compiler, x []*Wire, index types.Size, r []*Wire) error {
if len(r) != 1 {
return fmt.Errorf("invalid bit clear test arguments: x=%d, r=%d",
len(x), len(r))
}
if index < types.Size(len(x)) {
w := cc.OneWire()
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, x[index], w, r[0]))
} else {
r[0] = cc.OneWire()
}
return nil
}

View File

@ -0,0 +1,309 @@
//
// Copyright (c) 2019-2024 Markku Rossi
//
// All rights reserved.
//
package circuits
// NewUDividerLong creates an unsigned integer division circuit
// computing r=a/b, q=a%b. This function uses Long Division algorithm.
func NewUDividerLong(cc *Compiler, a, b, q, rret []*Wire) error {
a, b = cc.ZeroPad(a, b)
r := make([]*Wire, len(a))
for i := 0; i < len(r); i++ {
r[i] = cc.ZeroWire()
}
for i := len(a) - 1; i >= 0; i-- {
// r << 1
for j := len(r) - 1; j > 0; j-- {
r[j] = r[j-1]
}
r[0] = a[i]
// r-d, overlow: r < d
diff := make([]*Wire, len(r)+1)
for j := 0; j < len(diff); j++ {
diff[j] = cc.Calloc.Wire()
}
err := NewSubtractor(cc, r, b, diff)
if err != nil {
return err
}
if i < len(q) {
err = NewMUX(cc, diff[len(diff)-1:], []*Wire{cc.ZeroWire()},
[]*Wire{cc.OneWire()}, q[i:i+1])
if err != nil {
return err
}
}
nr := make([]*Wire, len(r))
for j := 0; j < len(nr); j++ {
if i == 0 && j < len(rret) {
nr[j] = rret[j]
} else {
nr[j] = cc.Calloc.Wire()
}
}
err = NewMUX(cc, diff[len(diff)-1:], r, diff[:len(diff)-1], nr)
if err != nil {
return err
}
r = nr
}
return nil
}
// NewUDividerRestoring creates an unsigned integer division circuit
// computing r=a/b, q=a%b. This function uses Restoring Division
// algorithm.
func NewUDividerRestoring(cc *Compiler, a, b, q, rret []*Wire) error {
a, b = cc.ZeroPad(a, b)
r := make([]*Wire, len(a)*2)
for i := 0; i < len(r); i++ {
if i < len(a) {
r[i] = a[i]
} else {
r[i] = cc.ZeroWire()
}
}
d := make([]*Wire, len(b)*2)
for i := 0; i < len(d); i++ {
if i < len(b) {
d[i] = cc.ZeroWire()
} else {
d[i] = b[i-len(b)]
}
}
for i := len(a) - 1; i >= 0; i-- {
// r << 1
for j := len(r) - 1; j > 0; j-- {
r[j] = r[j-1]
}
r[0] = cc.ZeroWire()
// r-d, overlow: r < d
diff := make([]*Wire, len(r)+1)
for j := 0; j < len(diff); j++ {
diff[j] = cc.Calloc.Wire()
}
err := NewSubtractor(cc, r, d, diff)
if err != nil {
return err
}
if i < len(q) {
err = NewMUX(cc, diff[len(diff)-1:], []*Wire{cc.ZeroWire()},
[]*Wire{cc.OneWire()}, q[i:i+1])
if err != nil {
return err
}
}
nr := make([]*Wire, len(r))
for j := 0; j < len(nr); j++ {
if i == 0 && j >= len(a) && j-len(a) < len(rret) {
nr[j] = rret[j-len(a)]
} else {
nr[j] = cc.Calloc.Wire()
}
}
err = NewMUX(cc, diff[len(diff)-1:], r, diff[:len(diff)-1], nr)
if err != nil {
return err
}
r = nr
}
return nil
}
// NewUDividerArray creates an unsigned integer division circuit
// computing r=a/b, q=a%b. This function uses Array Divider algorithm.
func NewUDividerArray(cc *Compiler, a, b, q, r []*Wire) error {
a, b = cc.ZeroPad(a, b)
rIn := make([]*Wire, len(b)+1)
rOut := make([]*Wire, len(b)+1)
// Init bINV.
bINV := make([]*Wire, len(b))
for i := 0; i < len(b); i++ {
bINV[i] = cc.Calloc.Wire()
cc.INV(b[i], bINV[i])
}
// Init for the first row.
for i := 0; i < len(b); i++ {
rOut[i] = cc.ZeroWire()
}
// Generate matrix.
for y := 0; y < len(a); y++ {
// Init rIn.
rIn[0] = a[len(a)-1-y]
copy(rIn[1:], rOut)
// Adders from b{0} to b{n-1}, 0
cIn := cc.OneWire()
for x := 0; x < len(b)+1; x++ {
var bw *Wire
if x < len(b) {
bw = bINV[x]
} else {
bw = cc.OneWire() // INV(0)
}
co := cc.Calloc.Wire()
ro := cc.Calloc.Wire()
NewFullAdder(cc, rIn[x], bw, cIn, ro, co)
rOut[x] = ro
cIn = co
}
// Quotient y.
if len(a)-1-y < len(q) {
w := cc.Calloc.Wire()
cc.INV(cIn, w)
cc.INV(w, q[len(a)-1-y])
}
// MUXes from high to low bit.
for x := len(b); x >= 0; x-- {
var ro *Wire
if y+1 >= len(a) && x < len(r) {
ro = r[x]
} else {
ro = cc.Calloc.Wire()
}
err := NewMUX(cc, []*Wire{cIn}, rOut[x:x+1], rIn[x:x+1],
[]*Wire{ro})
if err != nil {
return err
}
rOut[x] = ro
}
}
// Set extra quotient bits to zero.
for y := len(a); y < len(q); y++ {
q[y] = cc.ZeroWire()
}
// Set extra remainder bits to zero.
for x := len(b); x < len(r); x++ {
r[x] = cc.ZeroWire()
}
return nil
}
// NewUDivider creates an unsigned integer division circuit computing
// r=a/b, q=a%b.
func NewUDivider(cc *Compiler, a, b, q, r []*Wire) error {
return NewUDividerLong(cc, a, b, q, r)
}
// NewIDivider creates a signed integer division circuit computing
// r=a/b, q=a%b. The function converts negative a and b to positive
// values before doing the divmod with NewUDivider. If a or b was
// negative, the funtion converts the result quotient to negative
// value.
func NewIDivider(cc *Compiler, a, b, q, r []*Wire) error {
a, b = cc.ZeroPad(a, b)
zero := []*Wire{cc.ZeroWire()}
neg0 := cc.ZeroWire()
// If a is negative, set neg=!neg, a=-a.
neg1 := cc.Calloc.Wire()
cc.INV(neg0, neg1)
a1 := make([]*Wire, len(a))
for i := 0; i < len(a1); i++ {
a1[i] = cc.Calloc.Wire()
}
err := NewSubtractor(cc, zero, a, a1)
if err != nil {
return err
}
neg2 := cc.Calloc.Wire()
err = NewMUX(cc, a[len(a)-1:], []*Wire{neg1}, []*Wire{neg0}, []*Wire{neg2})
if err != nil {
return err
}
a2 := make([]*Wire, len(a))
for i := 0; i < len(a2); i++ {
a2[i] = cc.Calloc.Wire()
}
err = NewMUX(cc, a[len(a)-1:], a1, a, a2)
if err != nil {
return err
}
// If b is negative, set neg=!neg, b=-b.
neg3 := cc.Calloc.Wire()
cc.INV(neg2, neg3)
b1 := make([]*Wire, len(b))
for i := 0; i < len(b1); i++ {
b1[i] = cc.Calloc.Wire()
}
err = NewSubtractor(cc, zero, b, b1)
if err != nil {
return err
}
neg4 := cc.Calloc.Wire()
err = NewMUX(cc, b[len(b)-1:], []*Wire{neg3}, []*Wire{neg2}, []*Wire{neg4})
if err != nil {
return err
}
b2 := make([]*Wire, len(b))
for i := 0; i < len(a2); i++ {
b2[i] = cc.Calloc.Wire()
}
err = NewMUX(cc, b[len(b)-1:], b1, b, b2)
if err != nil {
return err
}
if len(q) == 0 {
// Modulo operation.
return NewUDivider(cc, a2, b2, q, r)
}
// If neg is set, set q=-q
q0 := make([]*Wire, len(q))
for i := 0; i < len(q0); i++ {
q0[i] = cc.Calloc.Wire()
}
err = NewUDivider(cc, a2, b2, q0, r)
if err != nil {
return err
}
q1 := make([]*Wire, len(q))
for i := 0; i < len(q1); i++ {
q1[i] = cc.Calloc.Wire()
}
err = NewSubtractor(cc, zero, q0, q1)
if err != nil {
return err
}
return NewMUX(cc, []*Wire{neg4}, q1, q0, q)
}

View File

@ -0,0 +1,44 @@
//
// Copyright (c) 2020-2023 Markku Rossi
//
// All rights reserved.
//
package circuits
import (
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
)
// Hamming creates a hamming distance circuit computing the hamming
// distance between a and b and returning the distance in r.
func Hamming(cc *Compiler, a, b, r []*Wire) error {
a, b = cc.ZeroPad(a, b)
var arr [][]*Wire
for i := 0; i < len(a); i++ {
w := cc.Calloc.Wire()
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, a[i], b[i], w))
arr = append(arr, []*Wire{w})
}
for len(arr) > 2 {
var n [][]*Wire
for i := 0; i < len(arr); i += 2 {
if i+1 < len(arr) {
result := cc.Calloc.Wires(types.Size(len(arr[i]) + 1))
err := NewAdder(cc, arr[i], arr[i+1], result)
if err != nil {
return err
}
n = append(n, result)
} else {
n = append(n, arr[i])
}
}
arr = n
}
return NewAdder(cc, arr[0], arr[1], r)
}

View File

@ -0,0 +1,101 @@
//
// Copyright (c) 2021-2023 Markku Rossi
//
// All rights reserved.
//
package circuits
import (
"fmt"
)
// NewIndex creates a new array element selection (index) circuit.
func NewIndex(cc *Compiler, size int, array, index, out []*Wire) error {
if len(array)%size != 0 {
return fmt.Errorf("array width %d must be multiple of element size %d",
len(array), size)
}
if len(out) < size {
return fmt.Errorf("out %d too small for element size %d",
len(out), size)
}
n := len(array) / size
if n == 0 {
for i := 0; i < len(out); i++ {
out[i] = cc.ZeroWire()
}
return nil
}
bits := 1
var length int
for length = 2; length < n; length *= 2 {
bits++
}
return newIndex(cc, bits-1, length, size, array, index, out)
}
func newIndex(cc *Compiler, bit, length, size int,
array, index, out []*Wire) error {
// Default "not found" value.
def := make([]*Wire, size)
for i := 0; i < size; i++ {
def[i] = cc.ZeroWire()
}
n := len(array) / size
if bit == 0 {
fVal := array[:size]
var tVal []*Wire
if n > 1 {
tVal = array[size : 2*size]
} else {
tVal = def
}
return NewMUX(cc, index[0:1], tVal, fVal, out)
}
length /= 2
fArray := array
if n > length {
fArray = fArray[:length*size]
}
if bit >= len(index) {
// Not enough bits to select upper half so just select from
// the lower half.
return newIndex(cc, bit-1, length, size, fArray, index, out)
}
fVal := make([]*Wire, size)
for i := 0; i < size; i++ {
fVal[i] = cc.Calloc.Wire()
}
err := newIndex(cc, bit-1, length, size, fArray, index, fVal)
if err != nil {
return err
}
var tVal []*Wire
if n > length {
tVal = make([]*Wire, size)
for i := 0; i < size; i++ {
tVal[i] = cc.Calloc.Wire()
}
err = newIndex(cc, bit-1, length, size,
array[length*size:], index, tVal)
if err != nil {
return err
}
} else {
tVal = def
}
return NewMUX(cc, index[bit:bit+1], tVal, fVal, out)
}

View File

@ -0,0 +1,227 @@
//
// circ_multiplier.go
//
// Copyright (c) 2019-2023 Markku Rossi
//
// All rights reserved.
//
package circuits
import (
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
)
// NewMultiplier creates a multiplier circuit implementing x*y=z.
func NewMultiplier(c *Compiler, arrayTreshold int, x, y, z []*Wire) error {
if false {
return NewArrayMultiplier(c, x, y, z)
}
if arrayTreshold < 8 {
var ok bool
arrayTreshold, ok = multiplierArrayTresholds[len(x)]
if !ok {
arrayTreshold = 21
}
}
return NewKaratsubaMultiplier(c, arrayTreshold, x, y, z)
}
// NewArrayMultiplier creates a multiplier circuit implementing
// x*y=z. This function implements Array Multiplier Circuit.
func NewArrayMultiplier(cc *Compiler, x, y, z []*Wire) error {
x, y = cc.ZeroPad(x, y)
if len(x) > len(z) {
x = x[0:len(z)]
y = y[0:len(z)]
}
// One bit multiplication is AND.
if len(x) == 1 {
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, x[0], y[0], z[0]))
if len(z) > 1 {
z[1] = cc.ZeroWire()
}
return nil
}
var sums []*Wire
// Construct Y0 sums
for i, xn := range x {
var s *Wire
if i == 0 {
s = z[0]
} else {
s = cc.Calloc.Wire()
sums = append(sums, s)
}
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, xn, y[0], s))
}
// Construct len(y)-2 intermediate layers
var j int
for j = 1; j+1 < len(y); j++ {
// ANDs for y(j)
var ands []*Wire
for _, xn := range x {
wire := cc.Calloc.Wire()
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, xn, y[j], wire))
ands = append(ands, wire)
}
// Compute next sums.
var nsums []*Wire
var c *Wire
for i := 0; i < len(ands); i++ {
cout := cc.Calloc.Wire()
var s *Wire
if i == 0 {
s = z[j]
} else {
s = cc.Calloc.Wire()
nsums = append(nsums, s)
}
if i == 0 {
NewHalfAdder(cc, ands[i], sums[i], s, cout)
} else if i >= len(sums) {
NewHalfAdder(cc, ands[i], c, s, cout)
} else {
NewFullAdder(cc, ands[i], sums[i], c, s, cout)
}
c = cout
}
// New sums with carry as the highest bit.
sums = append(nsums, c)
}
// Construct final layer.
var c *Wire
for i, xn := range x {
and := cc.Calloc.Wire()
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, xn, y[j], and))
var cout *Wire
if i+1 >= len(x) && j+i+1 < len(z) {
cout = z[j+i+1]
} else {
cout = cc.Calloc.Wire()
}
if j+i < len(z) {
if i == 0 {
NewHalfAdder(cc, and, sums[i], z[j+i], cout)
} else if i >= len(sums) {
NewHalfAdder(cc, and, c, z[j+i], cout)
} else {
NewFullAdder(cc, and, sums[i], c, z[j+i], cout)
}
}
c = cout
}
for i := j + len(x) + 1; i < len(z); i++ {
z[1] = cc.ZeroWire()
}
return nil
}
// NewKaratsubaMultiplier creates a multiplier circuit implementing
// the Karatsuba algorithm
// (https://en.wikipedia.org/wiki/Karatsuba_algorithm). The Karatsuba
// algorithm is should produce faster circuits on inputs of about 128
// bits (the number of non-XOR gates is smaller). On input sizes of
// 256 bits also the overall circuits are smaller than with the array
// multiplier algorithm.
//
// Bits Array a-xor a-and Karatsu K-xor K-and
// 8 301 172 129 993 724 269
// 16 1365 852 513 3573 2660 913
// 32 5797 3748 2049 11937 8980 2957
// 64 23877 15684 8193 38277 28964 9313
// 128 96901 64132 32769 119793 90964 28829
// 256 390405 259332 131073 369333 281060 88273
// 512 1567237 1042948 524289 1127937 859540 268397
// 1024 6280197 4183044 2097153 3423717 2611364 812353
// 2048 25143301 16754692 8388609 10350993 7899604 2451389
func NewKaratsubaMultiplier(cc *Compiler, limit int, a, b, r []*Wire) error {
a, b = cc.ZeroPad(a, b)
if len(a) > len(r) {
a = a[0:len(r)]
b = b[0:len(r)]
}
// Compute smaller multiplications with array multiplier.
if len(a) <= limit {
return NewArrayMultiplier(cc, a, b, r)
}
mid := len(a) / 2
aLow := a[:mid]
aHigh := a[mid:]
bLow := b[:mid]
bHigh := b[mid:]
z0 := cc.Calloc.Wires(types.Size(min(max(len(aLow), len(bLow))*2, len(r))))
if err := NewKaratsubaMultiplier(cc, limit, aLow, bLow, z0); err != nil {
return err
}
aSumLen := max(len(aLow), len(aHigh)) + 1
aSum := cc.Calloc.Wires(types.Size(aSumLen))
if err := NewAdder(cc, aLow, aHigh, aSum); err != nil {
return err
}
bSumLen := max(len(bLow), len(bHigh)) + 1
bSum := cc.Calloc.Wires(types.Size(bSumLen))
if err := NewAdder(cc, bLow, bHigh, bSum); err != nil {
return err
}
z1 := cc.Calloc.Wires(types.Size(min(max(aSumLen, bSumLen)*2, len(r))))
if err := NewKaratsubaMultiplier(cc, limit, aSum, bSum, z1); err != nil {
return err
}
z2 := cc.Calloc.Wires(types.Size(min(max(len(aHigh), len(bHigh))*2, len(r))))
if err := NewKaratsubaMultiplier(cc, limit, aHigh, bHigh, z2); err != nil {
return err
}
sub1 := cc.Calloc.Wires(types.Size(len(r)))
if err := NewSubtractor(cc, z1, z2, sub1); err != nil {
return err
}
sub2 := cc.Calloc.Wires(types.Size(len(r)))
if err := NewSubtractor(cc, sub1, z0, sub2); err != nil {
return err
}
shift1 := cc.ShiftLeft(z2, len(r), mid*2)
shift2 := cc.ShiftLeft(sub2, len(r), mid)
add1 := cc.Calloc.Wires(types.Size(len(r)))
if err := NewAdder(cc, shift1, shift2, add1); err != nil {
return err
}
return NewAdder(cc, add1, z0, r)
}
func max(a, b int) int {
if a > b {
return a
}
return b
}
func min(a, b int) int {
if a < b {
return a
}
return b
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,39 @@
//
// Copyright (c) 2020-2023 Markku Rossi
//
// All rights reserved.
//
package circuits
import (
"fmt"
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
)
// NewMUX creates a multiplexer circuit that selects the input t or f
// to output, based on the value of the condition cond.
func NewMUX(cc *Compiler, cond, t, f, out []*Wire) error {
t, f = cc.ZeroPad(t, f)
if len(cond) != 1 || len(t) != len(f) || len(t) != len(out) {
return fmt.Errorf("invalid mux arguments: cond=%d, l=%d, r=%d, out=%d",
len(cond), len(t), len(f), len(out))
}
for i := 0; i < len(t); i++ {
w1 := cc.Calloc.Wire()
w2 := cc.Calloc.Wire()
// w1 = XOR(f[i], t[i])
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, f[i], t[i], w1))
// w2 = AND(w1, cond)
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, w1, cond[0], w2))
// out[i] = XOR(w2, f[i])
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, w2, f[i], out[i]))
}
return nil
}

View File

@ -0,0 +1,63 @@
//
// circ_subtractor.go
//
// Copyright (c) 2019-2023 Markku Rossi
//
// All rights reserved.
//
package circuits
import (
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
)
// NewFullSubtractor creates a full subtractor circuit.
func NewFullSubtractor(cc *Compiler, x, y, cin, d, cout *Wire) {
w1 := cc.Calloc.Wire()
cc.AddGate(cc.Calloc.BinaryGate(circuit.XNOR, y, cin, w1))
cc.AddGate(cc.Calloc.BinaryGate(circuit.XNOR, x, w1, d))
if cout != nil {
w2 := cc.Calloc.Wire()
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, x, cin, w2))
w3 := cc.Calloc.Wire()
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, w1, w2, w3))
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, w3, cin, cout))
}
}
// NewSubtractor creates a new subtractor circuit implementing z=x-y.
func NewSubtractor(cc *Compiler, x, y, z []*Wire) error {
x, y = cc.ZeroPad(x, y)
if len(x) > len(z) {
x = x[0:len(z)]
y = y[0:len(z)]
}
cin := cc.ZeroWire()
for i := 0; i < len(x); i++ {
var cout *Wire
if i+1 >= len(x) {
if i+1 >= len(z) {
// N-N=N, overflow, drop carry bit.
cout = nil
} else {
cout = z[i+1]
}
} else {
cout = cc.Calloc.Wire()
}
// Note y-x here.
NewFullSubtractor(cc, y[i], x[i], cin, z[i], cout)
cin = cout
}
for i := len(x) + 1; i < len(z); i++ {
z[i] = cc.ZeroWire()
}
return nil
}

View File

@ -0,0 +1,144 @@
//
// circuits_test.go
//
// Copyright (c) 2019-2023 Markku Rossi
//
// All rights reserved.
//
package circuits
import (
"fmt"
"os"
"testing"
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/utils"
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
)
const (
verbose = false
)
var (
params = utils.NewParams()
calloc = NewAllocator()
)
func makeWires(count int, output bool) []*Wire {
var result []*Wire
for i := 0; i < count; i++ {
w := calloc.Wire()
w.SetOutput(output)
result = append(result, w)
}
return result
}
func NewIO(size int, name string) circuit.IO {
return circuit.IO{
circuit.IOArg{
Name: name,
Type: types.Info{
Type: types.TUint,
IsConcrete: true,
Bits: types.Size(size),
},
},
}
}
func TestAdd4(t *testing.T) {
bits := 4
// 2xbits inputs, bits+1 outputs
inputs := makeWires(bits*2, false)
outputs := makeWires(bits+1, true)
c, err := NewCompiler(params, calloc, NewIO(bits*2, "in"),
NewIO(bits+1, "out"), inputs, outputs)
if err != nil {
t.Fatalf("NewCompiler: %s", err)
}
cin := calloc.Wire()
NewHalfAdder(c, inputs[0], inputs[bits], outputs[0], cin)
for i := 1; i < bits; i++ {
var cout *Wire
if i+1 >= bits {
cout = outputs[bits]
} else {
cout = calloc.Wire()
}
NewFullAdder(c, inputs[i], inputs[bits+i], cin, outputs[i], cout)
cin = cout
}
result := c.Compile()
if verbose {
fmt.Printf("Result: %s\n", result)
result.Marshal(os.Stdout)
}
}
func TestFullSubtractor(t *testing.T) {
inputs := makeWires(1+2, false)
outputs := makeWires(2, true)
c, err := NewCompiler(params, calloc, NewIO(1+2, "in"), NewIO(2, "out"),
inputs, outputs)
if err != nil {
t.Fatalf("NewCompiler: %s", err)
}
NewFullSubtractor(c, inputs[0], inputs[1], inputs[2],
outputs[0], outputs[1])
result := c.Compile()
if verbose {
fmt.Printf("Result: %s\n", result)
result.Marshal(os.Stdout)
}
}
func TestMultiply1(t *testing.T) {
inputs := makeWires(2, false)
outputs := makeWires(2, true)
c, err := NewCompiler(params, calloc, NewIO(2, "in"), NewIO(2, "out"),
inputs, outputs)
if err != nil {
t.Fatalf("NewCompiler: %s", err)
}
err = NewMultiplier(c, 0, inputs[0:1], inputs[1:2], outputs)
if err != nil {
t.Error(err)
}
}
func TestMultiply(t *testing.T) {
bits := 64
inputs := makeWires(bits*2, false)
outputs := makeWires(bits*2, true)
c, err := NewCompiler(params, calloc, NewIO(bits*2, "in"),
NewIO(bits*2, "out"), inputs, outputs)
if err != nil {
t.Fatalf("NewCompiler: %s", err)
}
err = NewMultiplier(c, 0, inputs[0:bits], inputs[bits:2*bits], outputs)
if err != nil {
t.Error(err)
}
result := c.Compile()
if verbose {
fmt.Printf("Result: %s\n", result)
result.Marshal(os.Stdout)
}
}

View File

@ -0,0 +1,399 @@
//
// Copyright (c) 2019-2023 Markku Rossi
//
// All rights reserved.
//
package circuits
import (
"fmt"
"time"
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/utils"
)
// Builtin implements a buitin circuit that uses input wires a and b
// and returns the circuit result in r.
type Builtin func(cc *Compiler, a, b, r []*Wire) error
// Compiler implements binary circuit compiler.
type Compiler struct {
Params *utils.Params
Calloc *Allocator
OutputsAssigned bool
Inputs circuit.IO
Outputs circuit.IO
InputWires []*Wire
OutputWires []*Wire
Gates []*Gate
nextWireID circuit.Wire
pending []*Gate
assigned []*Gate
compiled []circuit.Gate
invI0Wire *Wire
zeroWire *Wire
oneWire *Wire
}
// NewCompiler creates a new circuit compiler for the specified
// circuit input and output values.
func NewCompiler(params *utils.Params, calloc *Allocator,
inputs, outputs circuit.IO, inputWires, outputWires []*Wire) (
*Compiler, error) {
if len(inputWires) == 0 {
return nil, fmt.Errorf("no inputs defined")
}
return &Compiler{
Params: params,
Calloc: calloc,
Inputs: inputs,
Outputs: outputs,
InputWires: inputWires,
OutputWires: outputWires,
Gates: make([]*Gate, 0, 65536),
}, nil
}
// InvI0Wire returns a wire holding value INV(input[0]).
func (cc *Compiler) InvI0Wire() *Wire {
if cc.invI0Wire == nil {
cc.invI0Wire = cc.Calloc.Wire()
cc.AddGate(cc.Calloc.INVGate(cc.InputWires[0], cc.invI0Wire))
}
return cc.invI0Wire
}
// ZeroWire returns a wire holding value 0.
func (cc *Compiler) ZeroWire() *Wire {
if cc.zeroWire == nil {
cc.zeroWire = cc.Calloc.Wire()
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, cc.InputWires[0],
cc.InvI0Wire(), cc.zeroWire))
cc.zeroWire.SetValue(Zero)
}
return cc.zeroWire
}
// OneWire returns a wire holding value 1.
func (cc *Compiler) OneWire() *Wire {
if cc.oneWire == nil {
cc.oneWire = cc.Calloc.Wire()
cc.AddGate(cc.Calloc.BinaryGate(circuit.OR, cc.InputWires[0],
cc.InvI0Wire(), cc.oneWire))
cc.oneWire.SetValue(One)
}
return cc.oneWire
}
// ZeroPad pads the argument wires x and y with zero values so that
// the resulting wires have the same number of bits.
func (cc *Compiler) ZeroPad(x, y []*Wire) ([]*Wire, []*Wire) {
if len(x) == len(y) {
return x, y
}
max := len(x)
if len(y) > max {
max = len(y)
}
rx := make([]*Wire, max)
for i := 0; i < max; i++ {
if i < len(x) {
rx[i] = x[i]
} else {
rx[i] = cc.ZeroWire()
}
}
ry := make([]*Wire, max)
for i := 0; i < max; i++ {
if i < len(y) {
ry[i] = y[i]
} else {
ry[i] = cc.ZeroWire()
}
}
return rx, ry
}
// ShiftLeft shifts the size number of bits of the input wires w,
// count bits left.
func (cc *Compiler) ShiftLeft(w []*Wire, size, count int) []*Wire {
result := make([]*Wire, size)
if count < size {
copy(result[count:], w)
}
for i := 0; i < count; i++ {
result[i] = cc.ZeroWire()
}
for i := count + len(w); i < size; i++ {
result[i] = cc.ZeroWire()
}
return result
}
// INV creates an inverse wire inverting the input wire i's value to
// the output wire o.
func (cc *Compiler) INV(i, o *Wire) {
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, i, cc.OneWire(), o))
}
// ID creates an identity wire passing the input wire i's value to the
// output wire o.
func (cc *Compiler) ID(i, o *Wire) {
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, i, cc.ZeroWire(), o))
}
// AddGate adds a get into the circuit.
func (cc *Compiler) AddGate(gate *Gate) {
cc.Gates = append(cc.Gates, gate)
}
// SetNextWireID sets the next unique wire ID to use.
func (cc *Compiler) SetNextWireID(next circuit.Wire) {
cc.nextWireID = next
}
// NextWireID returns the next unique wire ID.
func (cc *Compiler) NextWireID() circuit.Wire {
ret := cc.nextWireID
cc.nextWireID++
return ret
}
// ConstPropagate propagates constant wire values in the circuit and
// short circuits gates if their output does not depend on the gate's
// logical operation.
func (cc *Compiler) ConstPropagate() {
var stats circuit.Stats
start := time.Now()
for _, g := range cc.Gates {
switch g.Op {
case circuit.XOR:
if (g.A.Value() == Zero && g.B.Value() == Zero) ||
(g.A.Value() == One && g.B.Value() == One) {
g.O.SetValue(Zero)
stats[g.Op]++
} else if (g.A.Value() == Zero && g.B.Value() == One) ||
(g.A.Value() == One && g.B.Value() == Zero) {
g.O.SetValue(One)
stats[g.Op]++
} else if g.A.Value() == Zero {
// O = B
stats[g.Op]++
g.ShortCircuit(g.B)
} else if g.B.Value() == Zero {
// O = A
stats[g.Op]++
g.ShortCircuit(g.A)
}
case circuit.XNOR:
if (g.A.Value() == Zero && g.B.Value() == Zero) ||
(g.A.Value() == One && g.B.Value() == One) {
g.O.SetValue(One)
stats[g.Op]++
} else if (g.A.Value() == Zero && g.B.Value() == One) ||
(g.A.Value() == One && g.B.Value() == Zero) {
g.O.SetValue(Zero)
stats[g.Op]++
}
case circuit.AND:
if g.A.Value() == Zero || g.B.Value() == Zero {
g.O.SetValue(Zero)
stats[g.Op]++
} else if g.A.Value() == One && g.B.Value() == One {
g.O.SetValue(One)
stats[g.Op]++
} else if g.A.Value() == One {
// O = B
stats[g.Op]++
g.ShortCircuit(g.B)
} else if g.B.Value() == One {
// O = A
stats[g.Op]++
g.ShortCircuit(g.A)
}
case circuit.OR:
if g.A.Value() == One || g.B.Value() == One {
g.O.SetValue(One)
stats[g.Op]++
} else if g.A.Value() == Zero && g.B.Value() == Zero {
g.O.SetValue(Zero)
stats[g.Op]++
} else if g.A.Value() == Zero {
// O = B
stats[g.Op]++
g.ShortCircuit(g.B)
} else if g.B.Value() == Zero {
// O = A
stats[g.Op]++
g.ShortCircuit(g.A)
}
case circuit.INV:
if g.A.Value() == One {
g.O.SetValue(Zero)
stats[g.Op]++
} else if g.A.Value() == Zero {
g.O.SetValue(One)
stats[g.Op]++
}
}
if g.A.Value() == Zero {
g.A.RemoveOutput(g)
g.A = cc.ZeroWire()
g.A.AddOutput(g)
} else if g.A.Value() == One {
g.A.RemoveOutput(g)
g.A = cc.OneWire()
g.A.AddOutput(g)
}
if g.B != nil {
if g.B.Value() == Zero {
g.B.RemoveOutput(g)
g.B = cc.ZeroWire()
g.B.AddOutput(g)
} else if g.B.Value() == One {
g.B.RemoveOutput(g)
g.B = cc.OneWire()
g.B.AddOutput(g)
}
}
}
elapsed := time.Since(start)
if cc.Params.Diagnostics && stats.Count() > 0 {
fmt.Printf(" - ConstPropagate: %12s: %d/%d (%.2f%%)\n",
elapsed, stats.Count(), len(cc.Gates),
float64(stats.Count())/float64(len(cc.Gates))*100)
}
}
// ShortCircuitXORZero short circuits input to output where input is
// XOR'ed to zero.
func (cc *Compiler) ShortCircuitXORZero() {
var stats circuit.Stats
start := time.Now()
for _, g := range cc.Gates {
if g.Op != circuit.XOR {
continue
}
if g.A.Value() == Zero && !g.B.IsInput() &&
g.B.Input().O.NumOutputs() == 1 {
g.B.Input().ResetOutput(g.O)
// Disconnect gate's output wire.
g.O = cc.Calloc.Wire()
stats[g.Op]++
}
if g.B.Value() == Zero && !g.A.IsInput() &&
g.A.Input().O.NumOutputs() == 1 {
g.A.Input().ResetOutput(g.O)
// Disconnect gate's output wire.
g.O = cc.Calloc.Wire()
stats[g.Op]++
}
}
elapsed := time.Since(start)
if cc.Params.Diagnostics && stats.Count() > 0 {
fmt.Printf(" - ShortCircuitXORZero: %12s: %d/%d (%.2f%%)\n",
elapsed, stats.Count(), len(cc.Gates),
float64(stats.Count())/float64(len(cc.Gates))*100)
}
}
// Prune removes all gates whose output wires are unused.
func (cc *Compiler) Prune() int {
n := make([]*Gate, len(cc.Gates))
nPos := len(n)
for i := len(cc.Gates) - 1; i >= 0; i-- {
g := cc.Gates[i]
if !g.Prune() {
nPos--
n[nPos] = g
}
}
cc.Gates = n[nPos:]
return nPos
}
// Compile compiles the circuit.
func (cc *Compiler) Compile() *circuit.Circuit {
if len(cc.pending) != 0 {
panic("Compile: pending set")
}
cc.pending = make([]*Gate, 0, len(cc.Gates))
if len(cc.assigned) != 0 {
panic("Compile: assigned set")
}
cc.assigned = make([]*Gate, 0, len(cc.Gates))
if len(cc.compiled) != 0 {
panic("Compile: compiled set")
}
cc.compiled = make([]circuit.Gate, 0, len(cc.Gates))
for _, w := range cc.InputWires {
w.Assign(cc)
}
for len(cc.pending) > 0 {
gate := cc.pending[0]
cc.pending = cc.pending[1:]
gate.Assign(cc)
}
// Assign outputs.
for _, w := range cc.OutputWires {
if w.Assigned() {
if !cc.OutputsAssigned {
panic("Output already assigned")
}
} else {
w.SetID(cc.NextWireID())
}
}
// Compile circuit.
for _, gate := range cc.assigned {
gate.Compile(cc)
}
var stats circuit.Stats
for _, g := range cc.compiled {
stats[g.Op]++
}
result := &circuit.Circuit{
NumGates: len(cc.compiled),
NumWires: int(cc.nextWireID),
Inputs: cc.Inputs,
Outputs: cc.Outputs,
Gates: cc.compiled,
Stats: stats,
}
return result
}

Some files were not shown because too many files have changed in this diff Show More