mirror of
https://github.com/QuilibriumNetwork/ceremonyclient.git
synced 2026-02-21 10:27:26 +08:00
rollup of non-core (e.g. not node consensus, execution, client, or protobufs so nobody can get a head start)
This commit is contained in:
parent
c3ebffc519
commit
c1b4a86072
6
.gitignore
vendored
6
.gitignore
vendored
@ -1,4 +1,6 @@
|
||||
.idea/
|
||||
node/node
|
||||
node/lunchtime-*
|
||||
vouchers/
|
||||
ceremony-client
|
||||
.config*
|
||||
@ -9,6 +11,10 @@ ceremony-client
|
||||
.task
|
||||
node-tmp-*
|
||||
build
|
||||
cover.out
|
||||
|
||||
# Rust
|
||||
target
|
||||
|
||||
# Build outputs
|
||||
vdf-test*
|
||||
|
||||
678
Cargo.lock
generated
678
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@ -16,11 +16,16 @@ members = [
|
||||
"crates/vdf",
|
||||
"crates/channel",
|
||||
"crates/channel-wasm",
|
||||
"crates/bulletproofs-wasm",
|
||||
"crates/bls48581-wasm",
|
||||
"crates/verenc-wasm",
|
||||
"crates/classgroup",
|
||||
"crates/bls48581",
|
||||
"crates/ed448-rust",
|
||||
"crates/rpm",
|
||||
"crates/bulletproofs",
|
||||
"crates/verenc",
|
||||
"crates/ferret"
|
||||
]
|
||||
|
||||
[profile.release]
|
||||
|
||||
28
bedlam/.gitignore
vendored
Normal file
28
bedlam/.gitignore
vendored
Normal file
@ -0,0 +1,28 @@
|
||||
# Binaries for programs and plugins
|
||||
*.exe
|
||||
*.exe~
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
|
||||
# Test binary, build with `go test -c`
|
||||
*.test
|
||||
|
||||
# Output of the go coverage tool, specifically when used with LiteIDE
|
||||
*.out
|
||||
|
||||
.DS_Store
|
||||
,*
|
||||
*~
|
||||
*.dot
|
||||
*.pdf
|
||||
*.ssa
|
||||
*.prof
|
||||
*.bristol
|
||||
apps/circuit/circuit
|
||||
apps/garbled/garbled
|
||||
apps/iotest/iotest
|
||||
apps/iter/iter
|
||||
apps/qcldoc/qcldoc
|
||||
apps/objdump/objdump
|
||||
apps/ot/ot
|
||||
21
bedlam/LICENSE
Normal file
21
bedlam/LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2019 Markku Rossi
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
3
bedlam/README.md
Normal file
3
bedlam/README.md
Normal file
@ -0,0 +1,3 @@
|
||||
# Bedlam
|
||||
|
||||
This is a fork of Markku Rossi's MPC engine, retooled for use in Quilibrium and fitted to use FERRET for OT. There are additional divergences, but those are the most critical notes.
|
||||
1
bedlam/apps/.gitignore
vendored
Normal file
1
bedlam/apps/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
*.circ
|
||||
65
bedlam/apps/circuit/main.go
Normal file
65
bedlam/apps/circuit/main.go
Normal file
@ -0,0 +1,65 @@
|
||||
//
|
||||
// main.go
|
||||
//
|
||||
// Copyright (c) 2019-2022 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
for _, file := range flag.Args() {
|
||||
c, err := circuit.Parse(file)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("digraph circuit\n{\n")
|
||||
fmt.Printf(" overlap=scale;\n")
|
||||
fmt.Printf(" node\t[fontname=\"Helvetica\"];\n")
|
||||
fmt.Printf(" {\n node [shape=plaintext];\n")
|
||||
for w := 0; w < c.NumWires; w++ {
|
||||
fmt.Printf(" w%d\t[label=\"%d\"];\n", w, w)
|
||||
}
|
||||
fmt.Printf(" }\n")
|
||||
|
||||
fmt.Printf(" {\n node [shape=box];\n")
|
||||
for idx, gate := range c.Gates {
|
||||
fmt.Printf(" g%d\t[label=\"%s\"];\n", idx, gate.Op)
|
||||
}
|
||||
fmt.Printf(" }\n")
|
||||
|
||||
if true {
|
||||
fmt.Printf(" { rank=same")
|
||||
for w := 0; w < c.Inputs.Size(); w++ {
|
||||
fmt.Printf("; w%d", w)
|
||||
}
|
||||
fmt.Printf(";}\n")
|
||||
|
||||
fmt.Printf(" { rank=same")
|
||||
for w := 0; w < c.Outputs.Size(); w++ {
|
||||
fmt.Printf("; w%d", c.NumWires-w-1)
|
||||
}
|
||||
fmt.Printf(";}\n")
|
||||
}
|
||||
|
||||
for idx, gate := range c.Gates {
|
||||
for _, i := range gate.Inputs() {
|
||||
fmt.Printf(" w%d -> g%d;\n", i, idx)
|
||||
}
|
||||
fmt.Printf(" g%d -> w%d;\n", idx, gate.Output)
|
||||
}
|
||||
fmt.Printf("}\n")
|
||||
}
|
||||
}
|
||||
122
bedlam/apps/garbled/compile.go
Normal file
122
bedlam/apps/garbled/compile.go
Normal file
@ -0,0 +1,122 @@
|
||||
//
|
||||
// main.go
|
||||
//
|
||||
// Copyright (c) 2019-2024 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/utils"
|
||||
)
|
||||
|
||||
func compileFiles(files []string, params *utils.Params, inputSizes [][]int,
|
||||
compile, ssa, dot, svg bool, circFormat string) error {
|
||||
|
||||
var circ *circuit.Circuit
|
||||
var err error
|
||||
|
||||
for _, file := range files {
|
||||
if compile {
|
||||
params.CircOut, err = makeOutput(file, circFormat)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
params.CircFormat = circFormat
|
||||
if dot {
|
||||
params.CircDotOut, err = makeOutput(file, "circ.dot")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if svg {
|
||||
params.CircSvgOut, err = makeOutput(file, "circ.svg")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
if circuit.IsFilename(file) {
|
||||
circ, err = circuit.Parse(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if params.CircOut != nil {
|
||||
if params.Verbose {
|
||||
fmt.Printf("Serializing circuit...\n")
|
||||
}
|
||||
err = circ.MarshalFormat(params.CircOut, params.CircFormat)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else if strings.HasSuffix(file, ".qcl") {
|
||||
if ssa {
|
||||
params.SSAOut, err = makeOutput(file, "ssa")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if dot {
|
||||
params.SSADotOut, err = makeOutput(file, "ssa.dot")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
circ, _, err = compiler.New(params).CompileFile(file, inputSizes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("unknown file type '%s'", file)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func makeOutput(base, suffix string) (io.WriteCloser, error) {
|
||||
var path string
|
||||
|
||||
idx := strings.LastIndexByte(base, '.')
|
||||
if idx < 0 {
|
||||
path = base + "." + suffix
|
||||
} else {
|
||||
path = base[:idx+1] + suffix
|
||||
}
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &OutputFile{
|
||||
File: f,
|
||||
Buffered: bufio.NewWriter(f),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// OutputFile implements a buffered output file.
|
||||
type OutputFile struct {
|
||||
File *os.File
|
||||
Buffered *bufio.Writer
|
||||
}
|
||||
|
||||
func (out *OutputFile) Write(p []byte) (nn int, err error) {
|
||||
return out.Buffered.Write(p)
|
||||
}
|
||||
|
||||
// Close implements io.Closer.Close for the buffered output file.
|
||||
func (out *OutputFile) Close() error {
|
||||
if err := out.Buffered.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
return out.File.Close()
|
||||
}
|
||||
1
bedlam/apps/garbled/data/.gitignore
vendored
Normal file
1
bedlam/apps/garbled/data/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
pre*
|
||||
BIN
bedlam/apps/garbled/default.pgo
Normal file
BIN
bedlam/apps/garbled/default.pgo
Normal file
Binary file not shown.
2
bedlam/apps/garbled/examples/.gitignore
vendored
Normal file
2
bedlam/apps/garbled/examples/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
*.svg
|
||||
*.qclc
|
||||
9
bedlam/apps/garbled/examples/3party.qcl
Normal file
9
bedlam/apps/garbled/examples/3party.qcl
Normal file
@ -0,0 +1,9 @@
|
||||
// -*- go -*-
|
||||
|
||||
// Sample 3-party circuit where each party provides their input bit
|
||||
// and the result is bitwise AND of the inputs.
|
||||
package main
|
||||
|
||||
func main(a, b, e uint1) uint {
|
||||
return a & b & e
|
||||
}
|
||||
26
bedlam/apps/garbled/examples/add.qcl
Normal file
26
bedlam/apps/garbled/examples/add.qcl
Normal file
@ -0,0 +1,26 @@
|
||||
// -*- go -*-
|
||||
//
|
||||
|
||||
// Example of how to add two uint256 values. You can use any data type for the input as long as it is divisible by 8,
|
||||
// such as uint512.
|
||||
//
|
||||
// To run the Evaluator:
|
||||
//
|
||||
// ./garbled -e -i 0xb33d6a91b4ca8ac31c639c6742cba5a74c661a63311548af191c298a945d4891 examples/add.qcl
|
||||
//
|
||||
// To run the Garbler:
|
||||
//
|
||||
// ./garbled -i 0x5bf6db5927d799cf225f165e9508238edc5a1200fcad08c6411648733eb3100f examples/add.qcl
|
||||
//
|
||||
// The expected result should be:
|
||||
// 6877051328478326342308659403308568813546041258230645271156059935820089415840 (0x0f3445eadca224923ec2b2c5d7d3c93628c02c642dc251755a3271fdd31058a0)
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"math"
|
||||
)
|
||||
|
||||
func main(a, b uint256) uint {
|
||||
return a + b
|
||||
}
|
||||
11
bedlam/apps/garbled/examples/aesblock.qcl
Normal file
11
bedlam/apps/garbled/examples/aesblock.qcl
Normal file
@ -0,0 +1,11 @@
|
||||
// -*- go -*-
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
)
|
||||
|
||||
func main(key, data [16]byte) []byte {
|
||||
return aes.EncryptBlock(key, data)
|
||||
}
|
||||
11
bedlam/apps/garbled/examples/aesblock2.qcl
Normal file
11
bedlam/apps/garbled/examples/aesblock2.qcl
Normal file
@ -0,0 +1,11 @@
|
||||
// -*- go -*-
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
)
|
||||
|
||||
func main(key, data [16]byte) []byte {
|
||||
return aes.Block128(key, data)
|
||||
}
|
||||
54
bedlam/apps/garbled/examples/aescbc.qcl
Normal file
54
bedlam/apps/garbled/examples/aescbc.qcl
Normal file
@ -0,0 +1,54 @@
|
||||
// -*- go -*-
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/cipher/cbc"
|
||||
)
|
||||
|
||||
// Case #1: Encrypting 16 bytes (1 block) using AES-CBC with 128-bit key
|
||||
//
|
||||
// Key : 0x06a9214036b8a15b512e03d534120006
|
||||
// IV : 0x3dafba429d9eb430b422da802c9fac41
|
||||
// Plaintext : "Single block msg"
|
||||
// Ciphertext: 0xe353779c1079aeb82708942dbe77181a
|
||||
//
|
||||
// Case #2: Encrypting 32 bytes (2 blocks) using AES-CBC with 128-bit key
|
||||
//
|
||||
// Key : 0xc286696d887c9aa0611bbb3e2025a45a
|
||||
// IV : 0x562e17996d093d28ddb3ba695a2e6f58
|
||||
// Plaintext : 0x000102030405060708090a0b0c0d0e0f
|
||||
// 101112131415161718191a1b1c1d1e1f
|
||||
// Ciphertext: 0xd296cd94c2cccf8a3a863028b5e1dc0a
|
||||
//
|
||||
// 7586602d253cfff91b8266bea6d61ab1
|
||||
func main(g, e [16]byte) []byte {
|
||||
key := []byte{
|
||||
0x06, 0xa9, 0x21, 0x40, 0x36, 0xb8, 0xa1, 0x5b,
|
||||
0x51, 0x2e, 0x03, 0xd5, 0x34, 0x12, 0x00, 0x06,
|
||||
}
|
||||
iv := []byte{
|
||||
0x3d, 0xaf, 0xba, 0x42, 0x9d, 0x9e, 0xb4, 0x30,
|
||||
0xb4, 0x22, 0xda, 0x80, 0x2c, 0x9f, 0xac, 0x41,
|
||||
}
|
||||
plain := []byte("Single block msg")
|
||||
|
||||
key2 := []byte{
|
||||
0xc2, 0x86, 0x69, 0x6d, 0x88, 0x7c, 0x9a, 0xa0,
|
||||
0x61, 0x1b, 0xbb, 0x3e, 0x20, 0x25, 0xa4, 0x5a,
|
||||
}
|
||||
iv2 := []byte{
|
||||
0x56, 0x2e, 0x17, 0x99, 0x6d, 0x09, 0x3d, 0x28,
|
||||
0xdd, 0xb3, 0xba, 0x69, 0x5a, 0x2e, 0x6f, 0x58,
|
||||
}
|
||||
plain2 := []byte{
|
||||
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
|
||||
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
|
||||
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
|
||||
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
|
||||
}
|
||||
|
||||
//return cbc.EncryptAES128(key, iv, plain)
|
||||
//return cbc.EncryptAES128(key2, iv2, plain2)
|
||||
return cbc.EncryptAES128(g, iv, e)
|
||||
}
|
||||
11
bedlam/apps/garbled/examples/aesexpand.qcl
Normal file
11
bedlam/apps/garbled/examples/aesexpand.qcl
Normal file
@ -0,0 +1,11 @@
|
||||
// -*- go -*-
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
)
|
||||
|
||||
func main(key, data [16]byte) []uint {
|
||||
return aes.ExpandEncryptionKey(key)
|
||||
}
|
||||
28
bedlam/apps/garbled/examples/aesgcm.qcl
Normal file
28
bedlam/apps/garbled/examples/aesgcm.qcl
Normal file
@ -0,0 +1,28 @@
|
||||
// -*- go -*-
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/cipher/gcm"
|
||||
)
|
||||
|
||||
func main(a, b byte) (string, int) {
|
||||
key := []byte{
|
||||
0x06, 0xa9, 0x21, 0x40, 0x36, 0xb8, 0xa1, 0x5b,
|
||||
0x51, 0x2e, 0x03, 0xd5, 0x34, 0x12, 0x00, 0x06,
|
||||
}
|
||||
nonce := []byte{
|
||||
0x3d, 0xaf, 0xba, 0x42, 0x9d, 0x9e, 0xb4, 0x30,
|
||||
0xb4, 0x22, 0xda, 0x80,
|
||||
}
|
||||
plain := []byte("Single block msgSingle block msg")
|
||||
additional := []byte("additional data to be authenticated")
|
||||
|
||||
c := gcm.EncryptAES128(key, nonce, plain, additional)
|
||||
p, ok := gcm.DecryptAES128(key, nonce, c, additional)
|
||||
if !ok {
|
||||
return "Open failed", 0
|
||||
}
|
||||
return string(p), bytes.Compare(plain, p)
|
||||
}
|
||||
7
bedlam/apps/garbled/examples/and.qcl
Normal file
7
bedlam/apps/garbled/examples/and.qcl
Normal file
@ -0,0 +1,7 @@
|
||||
// -*- go -*-
|
||||
|
||||
package main
|
||||
|
||||
func main(a, b uint1) uint1 {
|
||||
return a & b
|
||||
}
|
||||
31
bedlam/apps/garbled/examples/credit.qcl
Normal file
31
bedlam/apps/garbled/examples/credit.qcl
Normal file
@ -0,0 +1,31 @@
|
||||
// -*- go -*-
|
||||
|
||||
package main
|
||||
|
||||
type Size = uint32
|
||||
|
||||
type Applicant struct {
|
||||
male bool
|
||||
age Size
|
||||
income Size
|
||||
}
|
||||
|
||||
type Bank struct {
|
||||
maxAge Size
|
||||
femaleIncome Size
|
||||
maleIncome Size
|
||||
}
|
||||
|
||||
func main(applicant Applicant, bank Bank) bool {
|
||||
// Bank sets the maximum age limit.
|
||||
if applicant.age > bank.maxAge {
|
||||
return false
|
||||
}
|
||||
if applicant.male {
|
||||
// Credit criteria for males.
|
||||
return applicant.age >= 21 && applicant.income >= bank.maleIncome
|
||||
} else {
|
||||
// Credit criteria for females.
|
||||
return applicant.age >= 18 && applicant.income >= bank.femaleIncome
|
||||
}
|
||||
}
|
||||
14
bedlam/apps/garbled/examples/div.qcl
Normal file
14
bedlam/apps/garbled/examples/div.qcl
Normal file
@ -0,0 +1,14 @@
|
||||
// -*- go -*-
|
||||
//
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"math"
|
||||
)
|
||||
|
||||
func main(a, b uint64) uint {
|
||||
return a / b
|
||||
|
||||
//return int64(math.DivUint64(uint64(a), uint64(b)))
|
||||
}
|
||||
36
bedlam/apps/garbled/examples/ecdh/keygen.qcl
Normal file
36
bedlam/apps/garbled/examples/ecdh/keygen.qcl
Normal file
@ -0,0 +1,36 @@
|
||||
// -*- go -*-
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/curve25519"
|
||||
)
|
||||
|
||||
func main(g, e [32]byte) ([]byte, []byte, []byte) {
|
||||
// var s [32]byte
|
||||
// curve25519.ScalarMult(&s, &g, &e)
|
||||
// return g, e, s
|
||||
|
||||
var privateKey [32]byte
|
||||
for i := 0; i < len(privateKey); i++ {
|
||||
//privateKey[i] = g[i] ^ e[i]
|
||||
privateKey[i] = (i % 8) + 1
|
||||
}
|
||||
|
||||
var publicKey [32]byte
|
||||
curve25519.ScalarBaseMult(&publicKey, &privateKey)
|
||||
|
||||
var privateKey2 [32]byte
|
||||
for i := 0; i < len(privateKey2); i++ {
|
||||
//privateKey[i] = g[i] ^ e[i]
|
||||
privateKey2[i] = (i % 16) + 1
|
||||
}
|
||||
|
||||
var publicKey2 [32]byte
|
||||
curve25519.ScalarBaseMult(&publicKey2, &privateKey2)
|
||||
|
||||
var secret [32]byte
|
||||
curve25519.ScalarMult(&secret, &privateKey, &publicKey2)
|
||||
|
||||
return publicKey, publicKey2, secret
|
||||
}
|
||||
105
bedlam/apps/garbled/examples/ed25519/keygen.qcl
Normal file
105
bedlam/apps/garbled/examples/ed25519/keygen.qcl
Normal file
@ -0,0 +1,105 @@
|
||||
// -*- go -*-
|
||||
|
||||
// This example implements Ed25519 key generation. Both parties
|
||||
// provide 3 random arguments:
|
||||
//
|
||||
// Seed [32]byte
|
||||
// Split [64]byte
|
||||
// Mask [64]byte
|
||||
//
|
||||
// The key generation seed is g.Seed^e.Seed. The generated private key
|
||||
// is split into two random shares:
|
||||
//
|
||||
// Garbler's share : privG = g.Split^e.Split
|
||||
// Evaluator's share: privE = privG ^ priv
|
||||
//
|
||||
// Garbler's and evaluator's shares are masked with respective masks:
|
||||
//
|
||||
// privGM = privG^g.Mask
|
||||
// privEM = privE^e.Mask
|
||||
//
|
||||
// The key generation function returns pub, privGM, and privEM to both
|
||||
// parties. After the multiparty protocol completes, both Garbler and
|
||||
// Evaluator reveal their private key share:
|
||||
//
|
||||
// privG = privGM^g.Mask
|
||||
// privE = privEM^e.Mask
|
||||
//
|
||||
// The following commands run the key generation algorithm. Both
|
||||
// Evaluator and Garbler are started with 3 random arguments:
|
||||
//
|
||||
// ./garbled -stream -e -v -i 0x784db0ec4ca0cf5338249e6a09139109366dca1fac2838e5f0e5a46f0e191bae,0xd0da45d3c99e756da831d1e7d696eae3fa9fe39d3b1b2618c7ff997d17777989b5cf415b114298c8b10bed0f0eff118e43ab606ab01143151dff89171307dffa,0x44bf09357e19b1f96f9cf6d9e7d25a0e8dd62d6e0d4bba2bec4c59983c7dc84d1486677b6d8837746cd948c881913c36faeaee08e8309afac58be4757a1c544e
|
||||
// ./garbled -stream -v -i 0x57c0e59c20ac7d75ef7e3188fdd7f5876abee1cab394af8125acaca9760bb54c,0x76b42e6292f4a3dc339d208481abeb9a24e08127c7cd8dbde62abcddc0c0e6f7a0f740e756b44dae137f0e7ff8eae0ceb1a962c130fdcbe8cbee3e31ab55b8dc,0xeb83eb1f5203f5b752c96264a21ff4a27fa60cf2313f5f53c3fa96e0b52a2814b786e43a3af64b66291b5b29f432cb8d5a930e31f4e6f072a6d33b861b5b5f13 examples/ed25519/keygen.qcl
|
||||
//
|
||||
// The example values return the following results:
|
||||
//
|
||||
// Result[0]: 8ae64963506002e267a59665e9a2e6f9348cc159be53747894478e182ece9fcb
|
||||
// Result[1]: 4ded80ae09692306c9659307f522f5dba1d96e48cde9f4f6e22fb340629db76aa2bee5867d009e008b6fb85902273acda8910c9a740a788f70c28ca0a3093835
|
||||
// Result[2]: cd5c37f4497fd56e236aa858442b3ff90f7a6401ee2186ea18d074fe93d8f9d18b582fa47a1ee0f0a9083ddd9e262b8f3c642dfad68f667f87dddd4bec80aca3
|
||||
//
|
||||
// The Garbler reveals its private key share by computing Result[1]^g.Mask:
|
||||
//
|
||||
// 4ded80ae09692306c9659307f522f5dba1d96e48cde9f4f6e22fb340629db76aa2bee5867d009e008b6fb85902273acda8910c9a740a788f70c28ca0a3093835
|
||||
// ^ eb83eb1f5203f5b752c96264a21ff4a27fa60cf2313f5f53c3fa96e0b52a2814b786e43a3af64b66291b5b29f432cb8d5a930e31f4e6f072a6d33b861b5b5f13
|
||||
// = a66e6bb15b6ad6b19bacf163573d0179de7f62bafcd6aba521d525a0d7b79f7e153801bc47f6d566a274e370f615f140f20202ab80ec88fdd611b726b8526726
|
||||
//
|
||||
// And Evaluator does the same computation for its values: Result[2]^e.Mask:
|
||||
//
|
||||
// cd5c37f4497fd56e236aa858442b3ff90f7a6401ee2186ea18d074fe93d8f9d18b582fa47a1ee0f0a9083ddd9e262b8f3c642dfad68f667f87dddd4bec80aca3
|
||||
// ^ 44bf09357e19b1f96f9cf6d9e7d25a0e8dd62d6e0d4bba2bec4c59983c7dc84d1486677b6d8837746cd948c881913c36faeaee08e8309afac58be4757a1c544e
|
||||
// = 89e33ec1376664974cf65e81a3f965f782ac496fe36a3cc1f49c2d66afa5319c9fde48df1796d784c5d175151fb717b9c68ec3f23ebffc854256393e969cf8ed
|
||||
//
|
||||
// Now Garbler and Evaluator have their private key shares and public
|
||||
// key, but they do not know the full combined Ed25519 private
|
||||
// key. The private key shares can be used in the `sign.qcl' example
|
||||
// which computes the signature with the combined private key:
|
||||
//
|
||||
// privG ^ privE = priv:
|
||||
// priv: 2f8d55706c0cb226d75aafe2f4c4648e5cd32bd51fbc9764d54908c67812aee2
|
||||
// 8ae64963506002e267a59665e9a2e6f9348cc159be53747894478e182ece9fcb
|
||||
// pub : 8ae64963506002e267a59665e9a2e6f9348cc159be53747894478e182ece9fcb
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
)
|
||||
|
||||
type Arguments struct {
|
||||
Seed [32]byte
|
||||
Split [64]byte
|
||||
Mask [64]byte
|
||||
}
|
||||
|
||||
func main(g, e Arguments) ([]byte, []byte, []byte) {
|
||||
var seed [32]byte
|
||||
|
||||
// Construct seed from peer seeds.
|
||||
for i := 0; i < len(seed); i++ {
|
||||
seed[i] = g.Seed[i] ^ e.Seed[i]
|
||||
}
|
||||
pub, priv := ed25519.NewKeyFromSeed(seed)
|
||||
|
||||
// Garbler's private key share is random value, constructed from
|
||||
// peer split values.
|
||||
var privG [64]byte
|
||||
for i := 0; i < len(privG); i++ {
|
||||
privG[i] = g.Split[i] ^ e.Split[i]
|
||||
}
|
||||
|
||||
// Evaluator's private key share is real private key, xor'ed with
|
||||
// Garbler's private key share.
|
||||
var privE [64]byte
|
||||
for i := 0; i < len(privE); i++ {
|
||||
privE[i] = priv[i] ^ privG[i]
|
||||
}
|
||||
|
||||
// Mask private key shares.
|
||||
for i := 0; i < len(privG); i++ {
|
||||
privG[i] ^= g.Mask[i]
|
||||
}
|
||||
for i := 0; i < len(privE); i++ {
|
||||
privE[i] ^= e.Mask[i]
|
||||
}
|
||||
|
||||
return pub, privG, privE
|
||||
}
|
||||
55
bedlam/apps/garbled/examples/ed25519/sign.qcl
Normal file
55
bedlam/apps/garbled/examples/ed25519/sign.qcl
Normal file
@ -0,0 +1,55 @@
|
||||
// -*- go -*-
|
||||
|
||||
// This example implements Ed25519 signature computation. The Ed25519
|
||||
// keypair is:
|
||||
//
|
||||
// pub : 8ae64963506002e267a59665e9a2e6f9348cc159be53747894478e182ece9fcb
|
||||
//
|
||||
// priv : 2f8d55706c0cb226d75aafe2f4c4648e5cd32bd51fbc9764d54908c67812aee2
|
||||
// 8ae64963506002e267a59665e9a2e6f9348cc159be53747894478e182ece9fcb
|
||||
//
|
||||
// The Garbler and Evaluator share the private key as two random
|
||||
// shares. The private key is contructed during the signature
|
||||
// computation by XOR:ing the random shares together:
|
||||
//
|
||||
// privG: a66e6bb15b6ad6b19bacf163573d0179de7f62bafcd6aba521d525a0d7b79f7e
|
||||
// 153801bc47f6d566a274e370f615f140f20202ab80ec88fdd611b726b8526726
|
||||
//
|
||||
//
|
||||
// privE: 89e33ec1376664974cf65e81a3f965f782ac496fe36a3cc1f49c2d66afa5319c
|
||||
// 9fde48df1796d784c5d175151fb717b9c68ec3f23ebffc854256393e969cf8ed
|
||||
//
|
||||
// priv = privG ^ privE
|
||||
//
|
||||
// Run the Evaluator with one input: the Evaluator's private key share:
|
||||
//
|
||||
// ./garbled -e -v -stream -i 0x89e33ec1376664974cf65e81a3f965f782ac496fe36a3cc1f49c2d66afa5319c9fde48df1796d784c5d175151fb717b9c68ec3f23ebffc854256393e969cf8ed
|
||||
//
|
||||
// The Garbler takes two inputs: the message to sign, and the
|
||||
// Garbler's private key share:
|
||||
//
|
||||
// ./garbled -stream -v -i 0x4d61726b6b7520526f737369203c6d747240696b692e66693e2068747470733a2f2f7777772e6d61726b6b75726f7373692e636f6d2f,0xa66e6bb15b6ad6b19bacf163573d0179de7f62bafcd6aba521d525a0d7b79f7e153801bc47f6d566a274e370f615f140f20202ab80ec88fdd611b726b8526726 examples/ed25519/sign.qcl
|
||||
//
|
||||
// The result signature is:
|
||||
//
|
||||
// Result[0]: f43c1cf1755345852211942af0838414334eec9cbf36e26a9f9e8d4bb720deb145ffbeec82249c875116757441206bcdc56b501e750f1f590917d772dfee980f
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
)
|
||||
|
||||
type Garbler struct {
|
||||
msg [64]byte
|
||||
privShare [64]byte
|
||||
}
|
||||
|
||||
func main(g Garbler, privShare [64]byte) []byte {
|
||||
var priv [64]byte
|
||||
|
||||
for i := 0; i < len(priv); i++ {
|
||||
priv[i] = g.privShare[i] ^ privShare[i]
|
||||
}
|
||||
|
||||
return ed25519.Sign(priv, g.msg)
|
||||
}
|
||||
44
bedlam/apps/garbled/examples/encrypt.qcl
Normal file
44
bedlam/apps/garbled/examples/encrypt.qcl
Normal file
@ -0,0 +1,44 @@
|
||||
// -*- go -*-
|
||||
|
||||
// Example how to encrypt fixed sized data with AES-128-GCM.
|
||||
//
|
||||
// Run the Evaluator with two inputs: evaluator's key and nonce shares:
|
||||
//
|
||||
// ./garbled -e -i 0x8cd98b88adab08d6d60fe57c8b8a33f3,0xfd5e0f8f155e7102aa526ad0 examples/encrypt.qcl
|
||||
//
|
||||
// The Garbler takes three arguments: the message to encrypt, and its
|
||||
// key and nonce shares:
|
||||
//
|
||||
// ./garbled -i 0x48656c6c6f2c20776f726c6421,0xed800b17b0c9d2334b249332155ddef5,0xa300751458c775a08762c2cd examples/encrypt.qcl
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/cipher/gcm"
|
||||
)
|
||||
|
||||
type Garbler struct {
|
||||
msg [64]byte
|
||||
keyShare [16]byte
|
||||
nonceShare [12]byte
|
||||
}
|
||||
|
||||
type Evaluator struct {
|
||||
keyShare [16]byte
|
||||
nonceShare [12]byte
|
||||
}
|
||||
|
||||
func main(g Garbler, e Evaluator) []byte {
|
||||
var key [16]byte
|
||||
|
||||
for i := 0; i < len(key); i++ {
|
||||
key[i] = g.keyShare[i] ^ e.keyShare[i]
|
||||
}
|
||||
|
||||
var nonce [12]byte
|
||||
|
||||
for i := 0; i < len(nonce); i++ {
|
||||
nonce[i] = g.nonceShare[i] ^ e.nonceShare[i]
|
||||
}
|
||||
|
||||
return gcm.EncryptAES128(key, nonce, g.msg, []byte("unused"))
|
||||
}
|
||||
11
bedlam/apps/garbled/examples/hamming.qcl
Normal file
11
bedlam/apps/garbled/examples/hamming.qcl
Normal file
@ -0,0 +1,11 @@
|
||||
// -*- go -*-
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
func main(a, b uint1024) uint {
|
||||
return binary.HammingDistance(a, b)
|
||||
}
|
||||
61
bedlam/apps/garbled/examples/hmac-sha256.qcl
Normal file
61
bedlam/apps/garbled/examples/hmac-sha256.qcl
Normal file
@ -0,0 +1,61 @@
|
||||
// -*- go -*-
|
||||
|
||||
// This example computes HMAC-SHA256 where the HMAC key is shared as
|
||||
// two random shares between garbler and evaluator. The garbler's key
|
||||
// share is:
|
||||
//
|
||||
// keyG: 4de216d2fdc9301e5b9c78486f7109a05670d200d9e2f275ec0aad08ec42af47
|
||||
// fcb59bf460d50b01333a748f3a9efb13e08036d49a26c21ba2e33a5f8a2cf0e7
|
||||
//
|
||||
// The evaluator's key share is:
|
||||
//
|
||||
// keyE: f87a00ef89c2396de32f6ac0748f6fa1b641013d46f74ce25cc625904215a675
|
||||
// 01c0c7196a2602f6516527958a82271847933c35d170d98bfdb04d2ddf3bb197
|
||||
//
|
||||
// The final HMAC key is keyG ^ keyE:
|
||||
//
|
||||
// key : b598163d740b0973b8b312881bfe6601e031d33d9f15be97b0cc8898ae570932
|
||||
// fd755ced0af309f7625f531ab01cdc0ba7130ae14b561b905f53777255174170
|
||||
//
|
||||
// The example uses 32-byte messages (Garbler.msg) so with the message:
|
||||
//
|
||||
// msg : Hello, world!...................
|
||||
// hex : 48656c6c6f2c20776f726c64212e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e
|
||||
//
|
||||
// We expect to get the following HMAC-SHA256 output:
|
||||
//
|
||||
// sum : 60d27dbd14f1e351f20069171fead00ef557d17ac9a41d02baa488ca4b90171a
|
||||
//
|
||||
// Now we can run the MPC computation as follows. First, run the
|
||||
// evaluator with one input: the evaluator's key share:
|
||||
//
|
||||
// ./garbled -e -v -i 0xf87a00ef89c2396de32f6ac0748f6fa1b641013d46f74ce25cc625904215a67501c0c7196a2602f6516527958a82271847933c35d170d98bfdb04d2ddf3bb197 examples/hmac-sha256.qcl
|
||||
//
|
||||
// The garbler takes two inputs: the message and the garbler's key
|
||||
// share:
|
||||
//
|
||||
// ./garbled -v -i 0x48656c6c6f2c20776f726c64212e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e,0x4de216d2fdc9301e5b9c78486f7109a05670d200d9e2f275ec0aad08ec42af47fcb59bf460d50b01333a748f3a9efb13e08036d49a26c21ba2e33a5f8a2cf0e7 examples/hmac-sha256.qcl
|
||||
//
|
||||
// The MCP computation providers the expected HMAC result:
|
||||
//
|
||||
// Result[0]: 60d27dbd14f1e351f20069171fead00ef557d17ac9a41d02baa488ca4b90171a
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
)
|
||||
|
||||
type Garbler struct {
|
||||
msg []byte
|
||||
keyShare [64]byte
|
||||
}
|
||||
|
||||
func main(g Garbler, eKeyShare [64]byte) []byte {
|
||||
var key [64]byte
|
||||
|
||||
for i := 0; i < len(key); i++ {
|
||||
key[i] = g.keyShare[i] ^ eKeyShare[i]
|
||||
}
|
||||
|
||||
return hmac.SumSHA256(g.msg, key)
|
||||
}
|
||||
62
bedlam/apps/garbled/examples/hmac-sha512.qcl
Normal file
62
bedlam/apps/garbled/examples/hmac-sha512.qcl
Normal file
@ -0,0 +1,62 @@
|
||||
// -*- go -*-
|
||||
|
||||
// This example computes HMAC-SHA256 where the HMAC key is shared as
|
||||
// two random shares between garbler and evaluator. The garbler's key
|
||||
// share is:
|
||||
//
|
||||
// keyG: 4de216d2fdc9301e5b9c78486f7109a05670d200d9e2f275ec0aad08ec42af47
|
||||
// fcb59bf460d50b01333a748f3a9efb13e08036d49a26c21ba2e33a5f8a2cf0e7
|
||||
//
|
||||
// The evaluator's key share is:
|
||||
//
|
||||
// keyE: f87a00ef89c2396de32f6ac0748f6fa1b641013d46f74ce25cc625904215a675
|
||||
// 01c0c7196a2602f6516527958a82271847933c35d170d98bfdb04d2ddf3bb197
|
||||
//
|
||||
// The final HMAC key is keyG ^ keyE:
|
||||
//
|
||||
// key : b598163d740b0973b8b312881bfe6601e031d33d9f15be97b0cc8898ae570932
|
||||
// fd755ced0af309f7625f531ab01cdc0ba7130ae14b561b905f53777255174170
|
||||
//
|
||||
// The example uses 32-byte messages (Garbler.msg) so with the message:
|
||||
//
|
||||
// msg : Hello, world!...................
|
||||
// hex : 48656c6c6f2c20776f726c64212e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e
|
||||
//
|
||||
// We expect to get the following HMAC-SHA256 output:
|
||||
//
|
||||
// sum : 89c648c4d7b4220f6767706dec64f69bbdb2725d062a09a447b9cf32af8636ee8853f92ca59c0e81712b72c79f3503f7f2131d20c7a5dfae87b79f839cecf2c4
|
||||
//
|
||||
// Now we can run the MPC computation as follows. First, run the
|
||||
// evaluator with one input: the evaluator's key share:
|
||||
//
|
||||
// ./garbled -e -v -i 0xf87a00ef89c2396de32f6ac0748f6fa1b641013d46f74ce25cc625904215a67501c0c7196a2602f6516527958a82271847933c35d170d98bfdb04d2ddf3bb197 examples/hmac-sha512.qcl
|
||||
//
|
||||
// The garbler takes two inputs: the message and the garbler's key
|
||||
// share:
|
||||
//
|
||||
// ./garbled -v -i 0x48656c6c6f2c20776f726c64212e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e2e,0x4de216d2fdc9301e5b9c78486f7109a05670d200d9e2f275ec0aad08ec42af47fcb59bf460d50b01333a748f3a9efb13e08036d49a26c21ba2e33a5f8a2cf0e7 examples/hmac-sha512.qcl
|
||||
//
|
||||
// The MCP computation providers the expected HMAC result:
|
||||
//
|
||||
// Result[0]: 89c648c4d7b4220f6767706dec64f69bbdb2725d062a09a447b9cf32af8636ee8853f92ca59c0e81712b72c79f3503f7f2131d20c7a5dfae87b79f839cecf2c4
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
)
|
||||
|
||||
type Garbler struct {
|
||||
msg []byte
|
||||
keyShare [64]byte
|
||||
}
|
||||
|
||||
func main(g Garbler, eKeyShare [64]byte) []byte {
|
||||
var key [64]byte
|
||||
|
||||
for i := 0; i < len(key); i++ {
|
||||
key[i] = g.keyShare[i] ^ eKeyShare[i]
|
||||
}
|
||||
|
||||
return hmac.SumSHA512(g.msg, key)
|
||||
}
|
||||
102
bedlam/apps/garbled/examples/key-import.qcl
Normal file
102
bedlam/apps/garbled/examples/key-import.qcl
Normal file
@ -0,0 +1,102 @@
|
||||
// -*- go -*-
|
||||
|
||||
// This example shows how a key can be imported to MCP peers garbler
|
||||
// and evaluator so that after the import operation both peers hold a
|
||||
// random share of the key, and shareG^shareE=key.
|
||||
//
|
||||
// The garbler provides three arguments:
|
||||
//
|
||||
// key [64]byte: the key to import
|
||||
// split [64]byte: random input for the key split value
|
||||
// mask [64]byte: random mask to mask garbler's key share result
|
||||
//
|
||||
// The evaluator provides two arguments: the split and the mask values.
|
||||
//
|
||||
// The MPC program splits the key into two random shares:
|
||||
//
|
||||
// garbler's share : keyG = g.Split ^ e.Split
|
||||
// evaluator's share: keyE = keyG ^ g.Key
|
||||
//
|
||||
// The result shares are masked with respective masks and returned as
|
||||
// the result of the computation:
|
||||
//
|
||||
// keyGM = keyG ^ g.Mask
|
||||
// keyEM = keyE ^ e.Mask
|
||||
//
|
||||
// Finally, both parties can extract their key shares because the know
|
||||
// their mask values:
|
||||
//
|
||||
// keyG = keyGM ^ g.Mask
|
||||
// keyE = keyEM ^ e.Mask
|
||||
//
|
||||
// Using the example values:
|
||||
//
|
||||
// key: a968f050ebd5c4ed2ddf9717f0f0fd9325b07c68ff5d62094800f5b69464bab9
|
||||
// 8dd886a7c49460503fafa75f5f7430f2cdda7bd5cb60c1cbd471e35d67432d58
|
||||
// splitG: 7e1d9bb27838f5c8481b7194f07b5f3059f9471ae8e69ea3fe79c629a92588d9
|
||||
// 524a6e4364e77d222210135f6c5435a8be52fc99ad8fc8280e8207cac91fc7b3
|
||||
// maskG: bed1bc2a3e6089bd016ff0175c62346438a9eb7b741f41787e5f7aad1720ee08
|
||||
// 233a89e81e3bbd5eef26d158750a0fdd47471ded518d781f23de6d4346ea68ad
|
||||
//
|
||||
// splitE: b2146ed7385d63a76f599b27f03e83971149208b0c41604eea010806460a3266
|
||||
// 93820075bc25c485b2bfcb9226488ba961eeb07980f8ab374b38f793e41e5247
|
||||
// maskE: a0698ff8e72f51bf3bff3895c80a8ba8a527abaa5a7603391545ed5dcebb22b5
|
||||
// a2f191bcb3ac3a543cfdba99bded67a3ac6f5f254ff7e5c34520312c9b91f672
|
||||
//
|
||||
// we run the evaluator with two arguments:
|
||||
//
|
||||
// ./garbled -e -v -i 0xb2146ed7385d63a76f599b27f03e83971149208b0c41604eea010806460a326693820075bc25c485b2bfcb9226488ba961eeb07980f8ab374b38f793e41e5247,0xa0698ff8e72f51bf3bff3895c80a8ba8a527abaa5a7603391545ed5dcebb22b5a2f191bcb3ac3a543cfdba99bded67a3ac6f5f254ff7e5c34520312c9b91f672 examples/key-import.qcl
|
||||
//
|
||||
// and the garbler with three arguments:
|
||||
//
|
||||
// ./garbled -v -i 0xa968f050ebd5c4ed2ddf9717f0f0fd9325b07c68ff5d62094800f5b69464bab98dd886a7c49460503fafa75f5f7430f2cdda7bd5cb60c1cbd471e35d67432d58,0x7e1d9bb27838f5c8481b7194f07b5f3059f9471ae8e69ea3fe79c629a92588d9524a6e4364e77d222210135f6c5435a8be52fc99ad8fc8280e8207cac91fc7b3,0xbed1bc2a3e6089bd016ff0175c62346438a9eb7b741f41787e5f7aad1720ee08233a89e81e3bbd5eef26d158750a0fdd47471ded518d781f23de6d4346ea68ad examples/key-import.qcl
|
||||
//
|
||||
// The program returns two values: garbler's and evaluator's masked
|
||||
// key shares:
|
||||
//
|
||||
// Result[0]: 72d8494f7e051fd2262d1aa45c27e8c370198cea90b8bf956a27b482f80f54b7e2f2e7dec6f904f97f8909953f16b1dc98fb510d7cfa1b0066649d1a6bebfd59
|
||||
// Result[1]: c5088acd4c9f033d3162453138bfaa9cc827b053418c9fdd493dd6c4b5f022b3eee1792daffae3a393fdc50ba885e950be096810a9e04717d4eb2228d1d34ede
|
||||
//
|
||||
// and both peers can extract their key shares by XOR:ing their result
|
||||
// with their mask value:
|
||||
//
|
||||
// shareG: cc09f5654065966f2742eab30045dca748b06791e4a7feed1478ce2fef2fbabf
|
||||
// c1c86e36d8c2b9a790afd8cd4a1cbe01dfbc4ce02d77631f45baf0592d0195f4
|
||||
// shareE: 65610535abb052820a9d7da4f0b521346d001bf91bfa9ce45c783b997b4b0006
|
||||
// 4c10e8911c56d9f7af007f9215688ef312663735e617a2d491cb13044a42b8ac
|
||||
//
|
||||
// and we see that shareG^shareE = key
|
||||
package main
|
||||
|
||||
type Garbler struct {
|
||||
Key [64]byte
|
||||
Split [64]byte
|
||||
Mask [64]byte
|
||||
}
|
||||
|
||||
type Evaluator struct {
|
||||
Split [64]byte
|
||||
Mask [64]byte
|
||||
}
|
||||
|
||||
func main(g Garbler, e Evaluator) ([]byte, []byte) {
|
||||
var keyG [64]byte
|
||||
var keyE [64]byte
|
||||
var keyGM [64]byte
|
||||
var keyEM [64]byte
|
||||
|
||||
for i := 0; i < len(keyG); i++ {
|
||||
keyG[i] = g.Split[i] ^ e.Split[i]
|
||||
}
|
||||
for i := 0; i < len(keyE); i++ {
|
||||
keyE[i] = keyG[i] ^ g.Key[i]
|
||||
}
|
||||
|
||||
for i := 0; i < len(keyGM); i++ {
|
||||
keyGM[i] = keyG[i] ^ g.Mask[i]
|
||||
}
|
||||
for i := 0; i < len(keyEM); i++ {
|
||||
keyEM[i] = keyE[i] ^ e.Mask[i]
|
||||
}
|
||||
return keyGM, keyEM
|
||||
}
|
||||
13
bedlam/apps/garbled/examples/millionaire.qcl
Normal file
13
bedlam/apps/garbled/examples/millionaire.qcl
Normal file
@ -0,0 +1,13 @@
|
||||
// -*- go -*-
|
||||
//
|
||||
|
||||
// Yao's Millionaires' problem with int64 values.
|
||||
package main
|
||||
|
||||
func main(a, b int64) bool {
|
||||
if a > b {
|
||||
return true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
39
bedlam/apps/garbled/examples/montgomery.qcl
Normal file
39
bedlam/apps/garbled/examples/montgomery.qcl
Normal file
@ -0,0 +1,39 @@
|
||||
// -*- go -*-
|
||||
//
|
||||
|
||||
// RSA encryption with Montgomery modular multiplication.
|
||||
//
|
||||
// ./garbled -e -v -i 0x321af130 examples/montgomery.qcl
|
||||
// ./garbled -v -i 0x6d7472,9,0xd60b2b09,0x10001 examples/montgomery.qcl
|
||||
package main
|
||||
|
||||
import (
|
||||
"math"
|
||||
)
|
||||
|
||||
type Size = uint64
|
||||
|
||||
type Garbler struct {
|
||||
msg Size
|
||||
privShare Size
|
||||
pubN Size
|
||||
pubE Size
|
||||
}
|
||||
|
||||
func main(g Garbler, privShare Size) (uint, uint) {
|
||||
|
||||
priv := g.privShare + privShare
|
||||
|
||||
cipher := Encrypt(g.msg, g.pubE, g.pubN)
|
||||
plain := Decrypt(cipher, priv, g.pubN)
|
||||
|
||||
return cipher, plain
|
||||
}
|
||||
|
||||
func Encrypt(msg, e, n uint) uint {
|
||||
return math.ExpMontgomery(msg, e, n)
|
||||
}
|
||||
|
||||
func Decrypt(cipher, d, n uint) uint {
|
||||
return math.ExpMontgomery(cipher, d, n)
|
||||
}
|
||||
8
bedlam/apps/garbled/examples/mult.qcl
Normal file
8
bedlam/apps/garbled/examples/mult.qcl
Normal file
@ -0,0 +1,8 @@
|
||||
// -*- go -*-
|
||||
//
|
||||
|
||||
package main
|
||||
|
||||
func main(a, b uint64) uint {
|
||||
return a * b
|
||||
}
|
||||
8
bedlam/apps/garbled/examples/mult1024.qcl
Normal file
8
bedlam/apps/garbled/examples/mult1024.qcl
Normal file
@ -0,0 +1,8 @@
|
||||
// -*- go -*-
|
||||
//
|
||||
|
||||
package main
|
||||
|
||||
func main(a, b int1024) int1024 {
|
||||
return a * b
|
||||
}
|
||||
15
bedlam/apps/garbled/examples/rps.qcl
Normal file
15
bedlam/apps/garbled/examples/rps.qcl
Normal file
@ -0,0 +1,15 @@
|
||||
// -*- go -*-
|
||||
//
|
||||
// Rock, Paper, Scissors
|
||||
|
||||
package main
|
||||
|
||||
func main(g, e int8) string {
|
||||
if g == e {
|
||||
return "-"
|
||||
}
|
||||
if (g+1)%3 == e {
|
||||
return "evaluator"
|
||||
}
|
||||
return "garbler"
|
||||
}
|
||||
45
bedlam/apps/garbled/examples/rsa.qcl
Normal file
45
bedlam/apps/garbled/examples/rsa.qcl
Normal file
@ -0,0 +1,45 @@
|
||||
// -*- go -*-
|
||||
//
|
||||
|
||||
// 32-bit RSA encryption and decryption.
|
||||
//
|
||||
// The key parameters are:
|
||||
//
|
||||
// d: 0x321af139
|
||||
// n: 0xd60b2b09
|
||||
// e: 0x10001
|
||||
//
|
||||
// private: d, n
|
||||
// public: e, n
|
||||
//
|
||||
// msg: 0x6d7472
|
||||
// cipher: 0x61f9ef88
|
||||
//
|
||||
// Run garbler and evaluator as follows:
|
||||
//
|
||||
// ./garbled -e -v -i 9 examples/rsa.qcl
|
||||
// ./garbled -v -i 0x6d7472,0x321af130,0xd60b2b09,0x10001 examples/rsa.qcl
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
)
|
||||
|
||||
type Size = uint32
|
||||
|
||||
type Garbler struct {
|
||||
msg Size
|
||||
privShare Size
|
||||
pubN Size
|
||||
pubE Size
|
||||
}
|
||||
|
||||
func main(g Garbler, privShare Size) (uint, uint) {
|
||||
|
||||
priv := g.privShare + privShare
|
||||
|
||||
cipher := rsa.Encrypt(g.msg, g.pubE, g.pubN)
|
||||
plain := rsa.Decrypt(cipher, priv, g.pubN)
|
||||
|
||||
return cipher, plain
|
||||
}
|
||||
42
bedlam/apps/garbled/examples/rsasign.qcl
Normal file
42
bedlam/apps/garbled/examples/rsasign.qcl
Normal file
@ -0,0 +1,42 @@
|
||||
// -*- go -*-
|
||||
//
|
||||
|
||||
// RSA signature with Size bits.
|
||||
//
|
||||
// The key parameters are:
|
||||
//
|
||||
// d: 0x321af139
|
||||
// n: 0xd60b2b09
|
||||
// e: 0x10001
|
||||
//
|
||||
// private: d, n
|
||||
// public: e, n
|
||||
//
|
||||
// msg: 0x6d7472
|
||||
// signature: 0x55a83b79
|
||||
//
|
||||
// Run garbler and evaluator as follows:
|
||||
//
|
||||
// ./garbled -e -v -i 9 examples/rsasign.qcl
|
||||
// ./garbled -v -i 0x6d7472,0x321af130,0xd60b2b09,0x10001 examples/rsasign.qcl
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
)
|
||||
|
||||
type Size = uint512
|
||||
|
||||
type Garbler struct {
|
||||
msg Size
|
||||
privShare Size
|
||||
pubN Size
|
||||
pubE Size
|
||||
}
|
||||
|
||||
func main(g Garbler, privShare Size) uint {
|
||||
|
||||
priv := g.privShare + privShare
|
||||
|
||||
return rsa.Decrypt(g.msg, priv, g.pubN)
|
||||
}
|
||||
33
bedlam/apps/garbled/examples/sort.qcl
Normal file
33
bedlam/apps/garbled/examples/sort.qcl
Normal file
@ -0,0 +1,33 @@
|
||||
// -*- go -*-
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"sort"
|
||||
)
|
||||
|
||||
var input = []int9{
|
||||
136, 142, 146, 165, 183, 189, 220, 223, 232, 235, 67, 73, 77, 88,
|
||||
91, 93, 95, 97, 98, 132, 5, 6, 7, 10, 14, 18, 18, 37, 50, 64, 245,
|
||||
249, 252, 136, 142, 146, 165, 183, 189, 220, 223, 232, 235, 67,
|
||||
73, 77, 88, 91, 93, 95, 97, 98, 132, 136, 142, 146, 165, 183, 189,
|
||||
220, 223, 232, 235, 67, 73, 77, 88, 91, 93, 95, 97, 98, 132, 5, 6,
|
||||
7, 10, 14, 18, 18, 37, 50, 64, 245, 249, 252, 136, 142, 146, 165,
|
||||
183, 189, 220, 223, 232, 235, 67, 73, 77, 88, 91, 93, 95, 97, 98,
|
||||
132, 136, 142, 146, 165, 183, 189, 220, 223, 232, 235, 67, 73, 77,
|
||||
88, 91, 93, 95, 97, 98, 132, 5, 6, 7, 10, 14, 18, 18, 37, 50, 64,
|
||||
245, 249, 252, 136, 142, 146, 165, 183, 189, 220, 223, 232, 235,
|
||||
67, 73, 77, 88, 91, 93, 95, 97, 98, 132, 5, 6, 7, 10, 14, 18, 18,
|
||||
37, 50, 64, 245, 249, 252, 136, 142, 146, 165, 183, 189, 220, 223,
|
||||
232, 235, 67, 73, 77, 88, 91, 93, 95, 97, 98, 132, 136, 142, 146,
|
||||
165, 183, 189, 220, 223, 232, 235, 67, 73, 77, 88, 91, 93, 95, 97,
|
||||
98, 132, 5, 6, 7, 10, 14, 18, 18, 37, 50, 64, 245, 249, 252, 136,
|
||||
142, 146, 165, 183, 189, 220, 223, 232, 235, 67, 73, 77, 88, 91,
|
||||
93, 95, 97, 98, 132, 136, 142, 146, 165, 183, 189, 220, 223, 232,
|
||||
235, 67, 73, 77, 88, 91, 93, 95, 97, 98, 132, 5, 6, 7, 10, 14, 18,
|
||||
18, 37, 50, 64, 245, 249, 252,
|
||||
}
|
||||
|
||||
func main(g, e byte) []int {
|
||||
return sort.Reverse(sort.Slice(input))
|
||||
}
|
||||
364
bedlam/apps/garbled/main.go
Normal file
364
bedlam/apps/garbled/main.go
Normal file
@ -0,0 +1,364 @@
|
||||
//
|
||||
// main.go
|
||||
//
|
||||
// Copyright (c) 2019-2024 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/utils"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/p2p"
|
||||
)
|
||||
|
||||
var (
|
||||
verbose = false
|
||||
)
|
||||
|
||||
type input []string
|
||||
|
||||
func (i *input) String() string {
|
||||
return fmt.Sprint(*i)
|
||||
}
|
||||
|
||||
func (i *input) Set(value string) error {
|
||||
for _, v := range strings.Split(value, ",") {
|
||||
*i = append(*i, v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var inputFlag, peerFlag input
|
||||
|
||||
func init() {
|
||||
flag.Var(&inputFlag, "i", "comma-separated list of circuit inputs")
|
||||
flag.Var(&peerFlag, "pi", "comma-separated list of peer's circuit inputs")
|
||||
}
|
||||
|
||||
func main() {
|
||||
evaluator := flag.Bool("e", false, "evaluator / garbler mode")
|
||||
stream := flag.Bool("stream", false, "streaming mode")
|
||||
compile := flag.Bool("circ", false, "compile QCL to circuit")
|
||||
circFormat := flag.String("format", "qclc",
|
||||
"circuit format: qclc, bristol")
|
||||
ssa := flag.Bool("ssa", false, "compile QCL to SSA assembly")
|
||||
dot := flag.Bool("dot", false, "create Graphviz DOT output")
|
||||
svg := flag.Bool("svg", false, "create SVG output")
|
||||
optimize := flag.Int("O", 1, "optimization level")
|
||||
address := flag.String("address", "127.0.0.1:8080", "address and port")
|
||||
otAddress := flag.String("ot-address", "127.0.0.1:5555", "address and port")
|
||||
fVerbose := flag.Bool("v", false, "verbose output")
|
||||
fDiagnostics := flag.Bool("d", false, "diagnostics output")
|
||||
cpuprofile := flag.String("cpuprofile", "", "write cpu profile to `file`")
|
||||
memprofile := flag.String("memprofile", "",
|
||||
"write memory profile to `file`")
|
||||
qclcErrLoc := flag.Bool("qclc-err-loc", false,
|
||||
"print QCLC error locations")
|
||||
benchmarkCompile := flag.Bool("benchmark-compile", false,
|
||||
"benchmark QCL compilation")
|
||||
flag.Parse()
|
||||
|
||||
log.SetFlags(0)
|
||||
|
||||
verbose = *fVerbose
|
||||
|
||||
if len(*cpuprofile) > 0 {
|
||||
f, err := os.Create(*cpuprofile)
|
||||
if err != nil {
|
||||
log.Fatal("could not create CPU profile: ", err)
|
||||
}
|
||||
defer f.Close()
|
||||
if err := pprof.StartCPUProfile(f); err != nil {
|
||||
log.Fatal("could not start CPU profile: ", err)
|
||||
}
|
||||
defer pprof.StopCPUProfile()
|
||||
}
|
||||
|
||||
params := utils.NewParams()
|
||||
defer params.Close()
|
||||
|
||||
params.Verbose = *fVerbose
|
||||
params.Diagnostics = *fDiagnostics
|
||||
params.QCLCErrorLoc = *qclcErrLoc
|
||||
params.BenchmarkCompile = *benchmarkCompile
|
||||
|
||||
if *optimize > 0 {
|
||||
params.OptPruneGates = true
|
||||
}
|
||||
if *ssa && !*compile {
|
||||
params.NoCircCompile = true
|
||||
}
|
||||
|
||||
if *compile || *ssa {
|
||||
inputSizes := make([][]int, 2)
|
||||
iSizes, err := circuit.InputSizes(inputFlag)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
pSizes, err := circuit.InputSizes(peerFlag)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if *evaluator {
|
||||
inputSizes[0] = pSizes
|
||||
inputSizes[1] = iSizes
|
||||
} else {
|
||||
inputSizes[0] = iSizes
|
||||
inputSizes[1] = pSizes
|
||||
}
|
||||
|
||||
err = compileFiles(flag.Args(), params, inputSizes,
|
||||
*compile, *ssa, *dot, *svg, *circFormat)
|
||||
if err != nil {
|
||||
log.Fatalf("compile failed: %s", err)
|
||||
}
|
||||
memProfile(*memprofile)
|
||||
return
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
party := uint8(1)
|
||||
if *evaluator {
|
||||
party = 2
|
||||
}
|
||||
|
||||
oti := ot.NewFerret(party, *otAddress)
|
||||
|
||||
if *stream {
|
||||
if *evaluator {
|
||||
err = streamEvaluatorMode(oti, inputFlag, len(*cpuprofile) > 0, *address)
|
||||
} else {
|
||||
err = streamGarblerMode(params, oti, inputFlag, flag.Args(), *address)
|
||||
}
|
||||
memProfile(*memprofile)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if len(flag.Args()) != 1 {
|
||||
log.Fatalf("expected one input file, got %v\n", len(flag.Args()))
|
||||
}
|
||||
file := flag.Args()[0]
|
||||
|
||||
if *evaluator {
|
||||
err = evaluatorMode(oti, file, params, len(*cpuprofile) > 0, *address)
|
||||
} else {
|
||||
err = garblerMode(oti, file, params, *address)
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func loadCircuit(file string, params *utils.Params, inputSizes [][]int) (
|
||||
*circuit.Circuit, error) {
|
||||
|
||||
var circ *circuit.Circuit
|
||||
var err error
|
||||
|
||||
if circuit.IsFilename(file) {
|
||||
circ, err = circuit.Parse(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if strings.HasSuffix(file, ".qcl") {
|
||||
circ, _, err = compiler.New(params).CompileFile(file, inputSizes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("unknown file type '%s'", file)
|
||||
}
|
||||
|
||||
if circ != nil {
|
||||
circ.AssignLevels()
|
||||
if verbose {
|
||||
fmt.Printf("circuit: %v\n", circ)
|
||||
}
|
||||
}
|
||||
return circ, err
|
||||
}
|
||||
|
||||
func memProfile(file string) {
|
||||
if len(file) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
f, err := os.Create(file)
|
||||
if err != nil {
|
||||
log.Fatal("could not create memory profile: ", err)
|
||||
}
|
||||
defer f.Close()
|
||||
if false {
|
||||
runtime.GC()
|
||||
}
|
||||
if err := pprof.WriteHeapProfile(f); err != nil {
|
||||
log.Fatal("could not write memory profile: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
func evaluatorMode(oti ot.OT, file string, params *utils.Params,
|
||||
once bool, address string) error {
|
||||
|
||||
inputSizes := make([][]int, 2)
|
||||
myInputSizes, err := circuit.InputSizes(inputFlag)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
inputSizes[1] = myInputSizes
|
||||
|
||||
ln, err := net.Listen("tcp", address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("Listening for connections at %s\n", address)
|
||||
|
||||
var oPeerInputSizes []int
|
||||
var circ *circuit.Circuit
|
||||
|
||||
for {
|
||||
nc, err := ln.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("New connection from %s\n", nc.RemoteAddr())
|
||||
|
||||
conn := p2p.NewConn(nc)
|
||||
|
||||
err = conn.SendInputSizes(myInputSizes)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return err
|
||||
}
|
||||
err = conn.Flush()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return err
|
||||
}
|
||||
peerInputSizes, err := conn.ReceiveInputSizes()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return err
|
||||
}
|
||||
inputSizes[0] = peerInputSizes
|
||||
|
||||
if circ == nil || slices.Compare(peerInputSizes, oPeerInputSizes) != 0 {
|
||||
circ, err = loadCircuit(file, params, inputSizes)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return err
|
||||
}
|
||||
oPeerInputSizes = peerInputSizes
|
||||
}
|
||||
circ.PrintInputs(circuit.IDEvaluator, inputFlag)
|
||||
if len(circ.Inputs) != 2 {
|
||||
return fmt.Errorf("invalid circuit for 2-party MPC: %d parties",
|
||||
len(circ.Inputs))
|
||||
}
|
||||
|
||||
input, err := circ.Inputs[1].Parse(inputFlag)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return fmt.Errorf("%s: %v", file, err)
|
||||
}
|
||||
result, err := circuit.Evaluator(conn, oti, circ, input, verbose)
|
||||
conn.Close()
|
||||
if err != nil && err != io.EOF {
|
||||
return err
|
||||
}
|
||||
bedlam.PrintResults(result, circ.Outputs)
|
||||
if once {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func garblerMode(oti ot.OT, file string, params *utils.Params, address string) error {
|
||||
inputSizes := make([][]int, 2)
|
||||
myInputSizes, err := circuit.InputSizes(inputFlag)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
inputSizes[0] = myInputSizes
|
||||
|
||||
nc, err := net.Dial("tcp", address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn := p2p.NewConn(nc)
|
||||
defer conn.Close()
|
||||
|
||||
if params.Verbose {
|
||||
fmt.Println(" - Receiving input sizes")
|
||||
}
|
||||
peerInputSizes, err := conn.ReceiveInputSizes()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
if params.Verbose {
|
||||
fmt.Println(" - Sending input sizes")
|
||||
}
|
||||
inputSizes[1] = peerInputSizes
|
||||
err = conn.SendInputSizes(myInputSizes)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
if params.Verbose {
|
||||
fmt.Println(" - Sent input sizes")
|
||||
}
|
||||
err = conn.Flush()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
circ, err := loadCircuit(file, params, inputSizes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
circ.PrintInputs(circuit.IDGarbler, inputFlag)
|
||||
if len(circ.Inputs) != 2 {
|
||||
return fmt.Errorf("invalid circuit for 2-party MPC: %d parties",
|
||||
len(circ.Inputs))
|
||||
}
|
||||
|
||||
input, err := circ.Inputs[0].Parse(inputFlag)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %v", file, err)
|
||||
}
|
||||
|
||||
if params.Verbose {
|
||||
fmt.Println(" - Initiating garbler")
|
||||
}
|
||||
result, err := circuit.Garbler(conn, oti, circ, input, verbose)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
bedlam.PrintResults(result, circ.Outputs)
|
||||
|
||||
return nil
|
||||
}
|
||||
104
bedlam/apps/garbled/streaming.go
Normal file
104
bedlam/apps/garbled/streaming.go
Normal file
@ -0,0 +1,104 @@
|
||||
//
|
||||
// Copyright (c) 2020-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/utils"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/p2p"
|
||||
)
|
||||
|
||||
func streamEvaluatorMode(oti ot.OT, input input, once bool, address string) error {
|
||||
inputSizes, err := circuit.InputSizes(input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("Listening for connections at %s\n", address)
|
||||
|
||||
for {
|
||||
nc, err := ln.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("New connection from %s\n", nc.RemoteAddr())
|
||||
|
||||
conn := p2p.NewConn(nc)
|
||||
|
||||
err = conn.SendInputSizes(inputSizes)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return err
|
||||
}
|
||||
err = conn.Flush()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
outputs, result, err := circuit.StreamEvaluator(conn, oti, input,
|
||||
verbose)
|
||||
conn.Close()
|
||||
|
||||
if err != nil && err != io.EOF {
|
||||
return fmt.Errorf("%s: %v", nc.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
bedlam.PrintResults(result, outputs)
|
||||
if once {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func streamGarblerMode(params *utils.Params, oti ot.OT, input input,
|
||||
args []string, address string) error {
|
||||
|
||||
inputSizes := make([][]int, 2)
|
||||
|
||||
sizes, err := circuit.InputSizes(input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
inputSizes[0] = sizes
|
||||
|
||||
if len(args) != 1 || !strings.HasSuffix(args[0], ".qcl") {
|
||||
return fmt.Errorf("streaming mode takes single QCL file")
|
||||
}
|
||||
nc, err := net.Dial("tcp", address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn := p2p.NewConn(nc)
|
||||
defer conn.Close()
|
||||
|
||||
sizes, err = conn.ReceiveInputSizes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
inputSizes[1] = sizes
|
||||
|
||||
outputs, result, err := compiler.New(params).StreamFile(
|
||||
conn, oti, args[0], input, inputSizes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
bedlam.PrintResults(result, outputs)
|
||||
return nil
|
||||
}
|
||||
83
bedlam/apps/iotest/iotest.go
Normal file
83
bedlam/apps/iotest/iotest.go
Normal file
@ -0,0 +1,83 @@
|
||||
//
|
||||
// main.go
|
||||
//
|
||||
// Copyright (c) 2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/p2p"
|
||||
)
|
||||
|
||||
func evaluatorTestIO(size int64, once bool) error {
|
||||
ln, err := net.Listen("tcp", port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("Listening for connections at %s\n", port)
|
||||
|
||||
for {
|
||||
nc, err := ln.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("New connection from %s\n", nc.RemoteAddr())
|
||||
|
||||
conn := p2p.NewConn(nc)
|
||||
for {
|
||||
var label ot.Label
|
||||
var labelData ot.LabelData
|
||||
err = conn.ReceiveLabel(&label, &labelData)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
fmt.Printf("Received: %v\n",
|
||||
circuit.FileSize(conn.Stats.Sum()).String())
|
||||
|
||||
if once {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func garblerTestIO(size int64) error {
|
||||
nc, err := net.Dial("tcp", port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn := p2p.NewConn(nc)
|
||||
|
||||
var sent int64
|
||||
var label ot.Label
|
||||
var labelData ot.LabelData
|
||||
|
||||
for sent < size {
|
||||
err = conn.SendLabel(label, &labelData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sent += int64(len(labelData))
|
||||
}
|
||||
if err := conn.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := conn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("Sent: %v\n", circuit.FileSize(conn.Stats.Sum()).String())
|
||||
return nil
|
||||
}
|
||||
56
bedlam/apps/iotest/main.go
Normal file
56
bedlam/apps/iotest/main.go
Normal file
@ -0,0 +1,56 @@
|
||||
//
|
||||
// main.go
|
||||
//
|
||||
// Copyright (c) 2019-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"os"
|
||||
"runtime/pprof"
|
||||
)
|
||||
|
||||
var (
|
||||
port = ":8080"
|
||||
)
|
||||
|
||||
func main() {
|
||||
evaluator := flag.Bool("e", false, "evaluator / garbler mode")
|
||||
cpuprofile := flag.String("cpuprofile", "", "write cpu profile to `file`")
|
||||
testIO := flag.Int64("test-io", 0, "test I/O performance")
|
||||
flag.Parse()
|
||||
|
||||
log.SetFlags(0)
|
||||
|
||||
if len(*cpuprofile) > 0 {
|
||||
f, err := os.Create(*cpuprofile)
|
||||
if err != nil {
|
||||
log.Fatal("could not create CPU profile: ", err)
|
||||
}
|
||||
defer f.Close()
|
||||
if err := pprof.StartCPUProfile(f); err != nil {
|
||||
log.Fatal("could not start CPU profile: ", err)
|
||||
}
|
||||
defer pprof.StopCPUProfile()
|
||||
}
|
||||
|
||||
if *testIO > 0 {
|
||||
if *evaluator {
|
||||
err := evaluatorTestIO(*testIO, len(*cpuprofile) > 0)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
err := garblerTestIO(*testIO)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
175
bedlam/apps/iter/main.go
Normal file
175
bedlam/apps/iter/main.go
Normal file
@ -0,0 +1,175 @@
|
||||
//
|
||||
// main.go
|
||||
//
|
||||
// Copyright (c) 2019-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
"os"
|
||||
"runtime/pprof"
|
||||
"strings"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/utils"
|
||||
)
|
||||
|
||||
var template = `
|
||||
package main
|
||||
|
||||
func main(a, b int%d) int {
|
||||
return a * b
|
||||
}
|
||||
`
|
||||
|
||||
type result struct {
|
||||
bits int
|
||||
bestLimit int
|
||||
bestCost uint64
|
||||
worstLimit int
|
||||
worstCost uint64
|
||||
costs []uint64
|
||||
}
|
||||
|
||||
func main() {
|
||||
numWorkers := flag.Int("workers", 8, "number of workers")
|
||||
startBits := flag.Int("start", 8, "start bit count")
|
||||
endBits := flag.Int("end", 0xffffffff, "end bit count")
|
||||
minLimit := flag.Int("min", 8, "treshold minimum limit")
|
||||
maxLimit := flag.Int("max", 22, "treshold maximum limit")
|
||||
cpuprofile := flag.String("cpuprofile", "", "write cpu profile to `file`")
|
||||
flag.Parse()
|
||||
|
||||
if len(*cpuprofile) > 0 {
|
||||
f, err := os.Create(*cpuprofile)
|
||||
if err != nil {
|
||||
log.Fatal("could not create CPU profile: ", err)
|
||||
}
|
||||
defer f.Close()
|
||||
if err := pprof.StartCPUProfile(f); err != nil {
|
||||
log.Fatal("could not start CPU profile: ", err)
|
||||
}
|
||||
defer pprof.StopCPUProfile()
|
||||
}
|
||||
|
||||
results := make(map[int]*result)
|
||||
ch := make(chan *result)
|
||||
|
||||
for i := 0; i < *numWorkers; i++ {
|
||||
go func(bits int) {
|
||||
for ; bits <= *endBits; bits += *numWorkers {
|
||||
code := fmt.Sprintf(template, bits)
|
||||
var bestLimit int
|
||||
var bestCost uint64
|
||||
var worstLimit int
|
||||
var worstCost uint64
|
||||
var costs []uint64
|
||||
|
||||
params := utils.NewParams()
|
||||
|
||||
for limit := *minLimit; limit <= *maxLimit; limit++ {
|
||||
params.CircMultArrayTreshold = limit
|
||||
circ, _, err := compiler.New(params).Compile(code, nil)
|
||||
if err != nil {
|
||||
log.Fatalf("Compilation %d:%d failed: %s\n%s",
|
||||
bits, limit, err, code)
|
||||
}
|
||||
cost := circ.Cost()
|
||||
costs = append(costs, cost)
|
||||
|
||||
if bestCost == 0 || cost < bestCost ||
|
||||
(limit == 21 && cost <= bestCost) {
|
||||
bestCost = cost
|
||||
bestLimit = limit
|
||||
}
|
||||
if cost > worstCost {
|
||||
worstCost = cost
|
||||
worstLimit = limit
|
||||
}
|
||||
}
|
||||
ch <- &result{
|
||||
bits: bits,
|
||||
bestLimit: bestLimit,
|
||||
bestCost: bestCost,
|
||||
worstLimit: worstLimit,
|
||||
worstCost: worstCost,
|
||||
costs: costs,
|
||||
}
|
||||
}
|
||||
}(*startBits + i)
|
||||
}
|
||||
|
||||
next := *startBits
|
||||
|
||||
outer:
|
||||
for result := range ch {
|
||||
results[result.bits] = result
|
||||
for {
|
||||
r, ok := results[next]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
if r.bestLimit == 21 {
|
||||
fmt.Printf("\t// %d: %d, %10d\t%.4f\t%s\n",
|
||||
r.bits, r.bestLimit,
|
||||
r.bestCost, float64(r.bestCost)/float64(r.worstCost),
|
||||
Sparkline(r.costs))
|
||||
} else {
|
||||
fmt.Printf("\t%d: %d, // %10d\t%.4f\t%s\n",
|
||||
r.bits, r.bestLimit,
|
||||
r.bestCost, float64(r.bestCost)/float64(r.worstCost),
|
||||
Sparkline(r.costs))
|
||||
}
|
||||
if next >= *endBits {
|
||||
break outer
|
||||
}
|
||||
next++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sparkline creates a histogram chart of values. The chart is scaled
|
||||
// to [min...max] containing differences between values.
|
||||
func Sparkline(values []uint64) string {
|
||||
if len(values) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var min uint64 = math.MaxUint64
|
||||
var max uint64
|
||||
|
||||
for _, v := range values {
|
||||
if v < min {
|
||||
min = v
|
||||
}
|
||||
if v > max {
|
||||
max = v
|
||||
}
|
||||
}
|
||||
delta := max - min
|
||||
|
||||
var sb strings.Builder
|
||||
for _, v := range values {
|
||||
var tick uint64
|
||||
if delta == 0 {
|
||||
tick = 4
|
||||
} else {
|
||||
tick = (v - min) * 7 / delta
|
||||
}
|
||||
if v == min && false {
|
||||
sb.WriteString("\x1b[92m")
|
||||
sb.WriteRune(rune(0x2581 + tick))
|
||||
sb.WriteString("\x1b[0m")
|
||||
} else {
|
||||
sb.WriteRune(rune(0x2581 + tick))
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
30
bedlam/apps/objdump/main.go
Normal file
30
bedlam/apps/objdump/main.go
Normal file
@ -0,0 +1,30 @@
|
||||
//
|
||||
// main.go
|
||||
//
|
||||
// Copyright (c) 2019-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
log.SetFlags(0)
|
||||
|
||||
if len(flag.Args()) == 0 {
|
||||
fmt.Printf("no files specified\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := dumpObjects(flag.Args()); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
61
bedlam/apps/objdump/objdump.go
Normal file
61
bedlam/apps/objdump/objdump.go
Normal file
@ -0,0 +1,61 @@
|
||||
//
|
||||
// main.go
|
||||
//
|
||||
// Copyright (c) 2019-2022 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/markkurossi/tabulate"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
|
||||
)
|
||||
|
||||
func dumpObjects(files []string) error {
|
||||
type oCircuit struct {
|
||||
name string
|
||||
circuit *circuit.Circuit
|
||||
}
|
||||
var circuits []oCircuit
|
||||
|
||||
for _, file := range files {
|
||||
if circuit.IsFilename(file) {
|
||||
c, err := circuit.Parse(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
circuits = append(circuits, oCircuit{
|
||||
name: file,
|
||||
circuit: c,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(circuits) > 0 {
|
||||
tab := tabulate.New(tabulate.Github)
|
||||
tab.Header("File")
|
||||
tab.Header("XOR").SetAlign(tabulate.MR)
|
||||
tab.Header("XNOR").SetAlign(tabulate.MR)
|
||||
tab.Header("AND").SetAlign(tabulate.MR)
|
||||
tab.Header("OR").SetAlign(tabulate.MR)
|
||||
tab.Header("INV").SetAlign(tabulate.MR)
|
||||
tab.Header("Gates").SetAlign(tabulate.MR)
|
||||
tab.Header("xor").SetAlign(tabulate.MR)
|
||||
tab.Header("!xor").SetAlign(tabulate.MR)
|
||||
tab.Header("Wires").SetAlign(tabulate.MR)
|
||||
|
||||
for _, c := range circuits {
|
||||
row := tab.Row()
|
||||
row.Column(c.name)
|
||||
c.circuit.TabulateRow(row)
|
||||
}
|
||||
|
||||
tab.Print(os.Stdout)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
36
bedlam/build.sh
Executable file
36
bedlam/build.sh
Executable file
@ -0,0 +1,36 @@
|
||||
#!/bin/bash
|
||||
set -euxo pipefail
|
||||
|
||||
# This script builds the node binary for the current platform and statically links it with VDF static lib.
|
||||
# Assumes that the VDF library has been built by running the generate.sh script in the `../vdf` directory.
|
||||
|
||||
ROOT_DIR="${ROOT_DIR:-$( cd "$(dirname "$(realpath "$( dirname "${BASH_SOURCE[0]}" )")")" >/dev/null 2>&1 && pwd )}"
|
||||
|
||||
NODE_DIR="$ROOT_DIR/bedlam"
|
||||
BINARIES_DIR="$ROOT_DIR/target/release"
|
||||
|
||||
pushd "$NODE_DIR" > /dev/null
|
||||
|
||||
export CGO_ENABLED=1
|
||||
|
||||
os_type="$(uname)"
|
||||
case "$os_type" in
|
||||
"Darwin")
|
||||
# Check if the architecture is ARM
|
||||
if [[ "$(uname -m)" == "arm64" ]]; then
|
||||
# MacOS ld doesn't support -Bstatic and -Bdynamic, so it's important that there is only a static version of the library
|
||||
go build -ldflags "-linkmode 'external' -extldflags '-L$BINARIES_DIR -L/usr/local/lib/ -L/opt/homebrew/Cellar/openssl@3/3.4.1/lib -lstdc++ -lferret -ldl -lm -lcrypto -lssl'" "$@"
|
||||
else
|
||||
echo "Unsupported platform"
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
"Linux")
|
||||
export CGO_LDFLAGS="-L/usr/local/lib -ldl -lm -L$BINARIES_DIR -lstdc++ -lcrypto -lssl -lferret -static"
|
||||
go build -ldflags "-linkmode 'external'" "$@"
|
||||
;;
|
||||
*)
|
||||
echo "Unsupported platform"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
1
bedlam/circuit/aesni/.gitignore
vendored
Normal file
1
bedlam/circuit/aesni/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
c/aesni
|
||||
9
bedlam/circuit/aesni/GNUmakefile
Normal file
9
bedlam/circuit/aesni/GNUmakefile
Normal file
@ -0,0 +1,9 @@
|
||||
|
||||
all: benchmarks
|
||||
|
||||
benchmarks: c/aesni
|
||||
go test -bench=.
|
||||
c/aesni
|
||||
|
||||
c/aesni: c/aesni.c
|
||||
gcc -Wall -maes -march=native -o $@ $+
|
||||
197
bedlam/circuit/aesni/c/aesni.c
Normal file
197
bedlam/circuit/aesni/c/aesni.c
Normal file
@ -0,0 +1,197 @@
|
||||
/*
|
||||
* Circuit garbling using AES-NI instructions.
|
||||
*/
|
||||
|
||||
#include <wmmintrin.h>
|
||||
#include <emmintrin.h>
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
#include <string.h>
|
||||
#include <sys/time.h>
|
||||
|
||||
#define AES_BLOCK_SIZE 16
|
||||
|
||||
struct cipher
|
||||
{
|
||||
__m128i key[15];
|
||||
};
|
||||
|
||||
typedef struct cipher cipher;
|
||||
|
||||
struct label
|
||||
{
|
||||
uint64_t d0;
|
||||
uint64_t d1;
|
||||
};
|
||||
|
||||
typedef struct label label;
|
||||
|
||||
static inline void
|
||||
label_xor(label *l, label *o)
|
||||
{
|
||||
l->d0 ^= o->d0;
|
||||
l->d1 ^= o->d1;
|
||||
}
|
||||
|
||||
static inline void
|
||||
label_mul2(label *l)
|
||||
{
|
||||
l->d0 <<= 1;
|
||||
l->d0 |= (l->d1 >> 63);
|
||||
l->d1 <<= 1;
|
||||
}
|
||||
|
||||
static inline void
|
||||
label_mul4(label *l)
|
||||
{
|
||||
l->d0 <<= 2;
|
||||
l->d0 |= (l->d1 >> 62);
|
||||
l->d1 <<= 2;
|
||||
}
|
||||
|
||||
static inline label *
|
||||
make_k(label *a, label *b, uint32_t t)
|
||||
{
|
||||
label tweak = {(uint64_t) t, 0};
|
||||
|
||||
label_mul2(a);
|
||||
label_mul4(b);
|
||||
label_xor(a, b);
|
||||
label_xor(a, &tweak);
|
||||
|
||||
return a;
|
||||
}
|
||||
|
||||
void
|
||||
garble2(cipher *cipher, label *a, label *b, label *c, uint32_t t, label *ret)
|
||||
{
|
||||
label *label_k;
|
||||
__m128i k;
|
||||
__m128i pi;
|
||||
|
||||
label_k = make_k(a, b, t);
|
||||
k = _mm_set_epi64x(label_k->d0, label_k->d1);
|
||||
|
||||
/* Perform the AES encryption rounds. */
|
||||
pi = _mm_xor_si128(k, cipher->key[0]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[1]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[2]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[3]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[4]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[5]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[6]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[7]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[8]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[9]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[10]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[11]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[12]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[13]);
|
||||
pi = _mm_aesenclast_si128(pi, cipher->key[14]);
|
||||
|
||||
ret->d0 = ((uint64_t *) &pi)[0];
|
||||
ret->d1 = ((uint64_t *) &pi)[1];
|
||||
|
||||
label_xor(ret, label_k);
|
||||
label_xor(ret, c);
|
||||
}
|
||||
|
||||
void
|
||||
make_cipher(cipher *cipher, const unsigned char key[AES_BLOCK_SIZE])
|
||||
{
|
||||
cipher->key[0] = _mm_loadu_si128((__m128i *)key);
|
||||
cipher->key[1] = _mm_aeskeygenassist_si128(cipher->key[0], 0x01);
|
||||
cipher->key[2] = _mm_aeskeygenassist_si128(cipher->key[1], 0x02);
|
||||
cipher->key[3] = _mm_aeskeygenassist_si128(cipher->key[2], 0x04);
|
||||
cipher->key[4] = _mm_aeskeygenassist_si128(cipher->key[3], 0x08);
|
||||
cipher->key[5] = _mm_aeskeygenassist_si128(cipher->key[4], 0x10);
|
||||
cipher->key[6] = _mm_aeskeygenassist_si128(cipher->key[5], 0x20);
|
||||
cipher->key[7] = _mm_aeskeygenassist_si128(cipher->key[6], 0x40);
|
||||
cipher->key[8] = _mm_aeskeygenassist_si128(cipher->key[7], 0x80);
|
||||
cipher->key[9] = _mm_aeskeygenassist_si128(cipher->key[8], 0x1B);
|
||||
cipher->key[10] = _mm_aeskeygenassist_si128(cipher->key[9], 0x36);
|
||||
cipher->key[11] = _mm_aeskeygenassist_si128(cipher->key[10], 0x6C);
|
||||
cipher->key[12] = _mm_aeskeygenassist_si128(cipher->key[11], 0xD8);
|
||||
cipher->key[13] = _mm_aeskeygenassist_si128(cipher->key[12], 0xAB);
|
||||
cipher->key[14] = _mm_aeskeygenassist_si128(cipher->key[13], 0x4D);
|
||||
}
|
||||
|
||||
static inline __m128i
|
||||
garble(cipher *cipher, __m128i a, __m128i b, __m128i c, uint32_t t)
|
||||
{
|
||||
__m128i k;
|
||||
__m128i pi;
|
||||
|
||||
k = _mm_xor_si128(_mm_xor_si128(_mm_slli_epi32(a, 1),
|
||||
_mm_slli_epi32(b, 2)),
|
||||
_mm_set1_epi32(t));
|
||||
|
||||
/* Perform the AES encryption rounds. */
|
||||
pi = _mm_xor_si128(k, cipher->key[0]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[1]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[2]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[3]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[4]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[5]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[6]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[7]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[8]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[9]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[10]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[11]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[12]);
|
||||
pi = _mm_aesenc_si128(pi, cipher->key[13]);
|
||||
pi = _mm_aesenclast_si128(pi, cipher->key[14]);
|
||||
|
||||
return _mm_xor_si128(_mm_xor_si128(pi, k), c);
|
||||
}
|
||||
|
||||
void
|
||||
report(char *l, struct timeval *begin, struct timeval *end, int rounds)
|
||||
{
|
||||
int64_t d;
|
||||
d = (int64_t) end->tv_sec - (int64_t) begin->tv_sec;
|
||||
d *= 1000000;
|
||||
d += (int64_t) end->tv_usec - (int64_t) begin->tv_usec;
|
||||
d *= 1000;
|
||||
d *= 100;
|
||||
|
||||
d /= rounds;
|
||||
|
||||
printf("%-20s\t%d\t\t%.2f ns/op\n", l, rounds, (double) d / 100);
|
||||
}
|
||||
|
||||
int
|
||||
main()
|
||||
{
|
||||
const unsigned char key[AES_BLOCK_SIZE] = "0123456789ABCDEF";
|
||||
cipher cipher = {0};
|
||||
struct timeval begin, end;
|
||||
int rounds = 23802664;
|
||||
__m128i a = _mm_set1_epi32(42);
|
||||
__m128i b = _mm_set1_epi32(43);
|
||||
__m128i c = _mm_set1_epi32(44);
|
||||
label la = {42, 0};
|
||||
label lb = {43, 0};
|
||||
label lc = {44, 0};
|
||||
label lr;
|
||||
|
||||
make_cipher(&cipher, key);
|
||||
|
||||
gettimeofday(&begin, NULL);
|
||||
for (int i = 0; i < rounds; i++)
|
||||
garble(&cipher, a, b, c, (uint32_t) i);
|
||||
gettimeofday(&end, NULL);
|
||||
|
||||
report("AES-NI", &begin, &end, rounds);
|
||||
|
||||
gettimeofday(&begin, NULL);
|
||||
for (int i = 0; i < rounds; i++)
|
||||
garble2(&cipher, &la, &lb, &lc, (uint32_t) i, &lr);
|
||||
gettimeofday(&end, NULL);
|
||||
|
||||
report("AES-NI+C", &begin, &end, rounds);
|
||||
|
||||
|
||||
return 0;
|
||||
}
|
||||
144
bedlam/circuit/aesni/go_test.go
Normal file
144
bedlam/circuit/aesni/go_test.go
Normal file
@ -0,0 +1,144 @@
|
||||
//
|
||||
// enc_test.go
|
||||
//
|
||||
// Copyright (c) 2019-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package aesni
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type LabelX struct {
|
||||
d0 uint64
|
||||
d1 uint64
|
||||
}
|
||||
|
||||
func NewLabelX(rand io.Reader) (LabelX, error) {
|
||||
var buf [16]byte
|
||||
var result LabelX
|
||||
|
||||
_, err := rand.Read(buf[:])
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
result.SetData(&buf)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (l *LabelX) SetData(buf *[16]byte) {
|
||||
l.d0 = binary.BigEndian.Uint64((*buf)[0:8])
|
||||
l.d1 = binary.BigEndian.Uint64((*buf)[8:16])
|
||||
}
|
||||
|
||||
func (l *LabelX) GetData(buf *[16]byte) {
|
||||
binary.BigEndian.PutUint64((*buf)[0:8], l.d0)
|
||||
binary.BigEndian.PutUint64((*buf)[8:16], l.d1)
|
||||
}
|
||||
|
||||
func (l *LabelX) Xor(o LabelX) {
|
||||
l.d0 ^= o.d0
|
||||
l.d1 ^= o.d1
|
||||
}
|
||||
|
||||
func (l *LabelX) Mul2() {
|
||||
l.d0 <<= 1
|
||||
l.d0 |= (l.d1 >> 63)
|
||||
l.d1 <<= 1
|
||||
}
|
||||
|
||||
func (l *LabelX) Mul4() {
|
||||
l.d0 <<= 2
|
||||
l.d0 |= (l.d1 >> 62)
|
||||
l.d1 <<= 2
|
||||
}
|
||||
|
||||
func TestLabelXor(t *testing.T) {
|
||||
val := uint64(0b0101010101010101010101010101010101010101010101010101010101010101)
|
||||
a := LabelX{
|
||||
d0: val,
|
||||
d1: val << 1,
|
||||
}
|
||||
b := LabelX{
|
||||
d0: 0xffffffffffffffff,
|
||||
d1: 0xffffffffffffffff,
|
||||
}
|
||||
a.Xor(b)
|
||||
if a.d0 != val<<1 {
|
||||
t.Errorf("Xor: unexpected d0=%x, epected %x", a.d0, val<<1)
|
||||
}
|
||||
if a.d1 != val {
|
||||
t.Errorf("Xor: unexpected d1=%x, epected %x", a.d1, val)
|
||||
}
|
||||
}
|
||||
|
||||
func NewTweakX(tweak uint32) LabelX {
|
||||
return LabelX{
|
||||
d1: uint64(tweak),
|
||||
}
|
||||
}
|
||||
|
||||
func encryptX(alg cipher.Block, a, b, c LabelX, t uint32,
|
||||
buf *[16]byte) LabelX {
|
||||
|
||||
k := makeKX(a, b, t)
|
||||
|
||||
k.GetData(buf)
|
||||
alg.Encrypt(buf[:], buf[:])
|
||||
|
||||
var pi LabelX
|
||||
pi.SetData(buf)
|
||||
|
||||
pi.Xor(k)
|
||||
pi.Xor(c)
|
||||
|
||||
return pi
|
||||
}
|
||||
|
||||
func makeKX(a, b LabelX, t uint32) LabelX {
|
||||
a.Mul2()
|
||||
|
||||
b.Mul4()
|
||||
a.Xor(b)
|
||||
|
||||
a.Xor(NewTweakX(t))
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
func BenchmarkLabelX(b *testing.B) {
|
||||
var key [32]byte
|
||||
|
||||
cipher, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to create cipher: %s", err)
|
||||
}
|
||||
|
||||
al, err := NewLabelX(rand.Reader)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to create label: %s", err)
|
||||
}
|
||||
bl, err := NewLabelX(rand.Reader)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to create label: %s", err)
|
||||
}
|
||||
cl, err := NewLabelX(rand.Reader)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to create label: %s", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
var buf [16]byte
|
||||
for i := 0; i < b.N; i++ {
|
||||
encryptX(cipher, al, bl, cl, uint32(i), &buf)
|
||||
}
|
||||
}
|
||||
42
bedlam/circuit/analyze.go
Normal file
42
bedlam/circuit/analyze.go
Normal file
@ -0,0 +1,42 @@
|
||||
//
|
||||
// Copyright (c) 2021 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Analyze identifies potential optimizations for the circuit.
|
||||
func (c *Circuit) Analyze() {
|
||||
fmt.Printf("analyzing circuit %v\n", c)
|
||||
|
||||
from := make([][]Gate, c.NumWires)
|
||||
to := make([][]Gate, c.NumWires)
|
||||
|
||||
// Collect wire inputs and outputs.
|
||||
for _, g := range c.Gates {
|
||||
switch g.Op {
|
||||
case XOR, XNOR, AND, OR:
|
||||
to[g.Input1] = append(to[g.Input1], g)
|
||||
fallthrough
|
||||
|
||||
case INV:
|
||||
to[g.Input0] = append(to[g.Input0], g)
|
||||
from[g.Output] = append(to[g.Output], g)
|
||||
}
|
||||
}
|
||||
|
||||
// INV gates as single output of input gate.
|
||||
for _, g := range c.Gates {
|
||||
if g.Op != INV {
|
||||
continue
|
||||
}
|
||||
if len(to[g.Input0]) == 1 && len(from[g.Input0]) == 1 {
|
||||
fmt.Printf("%v -> %v\n", from[g.Input0][0].Op, g.Op)
|
||||
}
|
||||
}
|
||||
}
|
||||
270
bedlam/circuit/circuit.go
Normal file
270
bedlam/circuit/circuit.go
Normal file
@ -0,0 +1,270 @@
|
||||
//
|
||||
// Copyright (c) 2019-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
|
||||
"github.com/markkurossi/tabulate"
|
||||
)
|
||||
|
||||
// Operation specifies gate function.
|
||||
type Operation byte
|
||||
|
||||
// Gate functions.
|
||||
const (
|
||||
XOR Operation = iota
|
||||
XNOR
|
||||
AND
|
||||
OR
|
||||
INV
|
||||
Count
|
||||
NumLevels
|
||||
MaxWidth
|
||||
)
|
||||
|
||||
// Known multi-party computation roles.
|
||||
const (
|
||||
IDGarbler int = iota
|
||||
IDEvaluator
|
||||
)
|
||||
|
||||
// Stats holds statistics about circuit operations.
|
||||
type Stats [MaxWidth + 1]uint64
|
||||
|
||||
// Add adds the argument statistics to this statistics object.
|
||||
func (stats *Stats) Add(o Stats) {
|
||||
for i := XOR; i < Count; i++ {
|
||||
stats[i] += o[i]
|
||||
}
|
||||
stats[Count]++
|
||||
|
||||
for i := NumLevels; i <= MaxWidth; i++ {
|
||||
if o[i] > stats[i] {
|
||||
stats[i] = o[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Count returns the number of gates in the statistics object.
|
||||
func (stats Stats) Count() uint64 {
|
||||
var result uint64
|
||||
for i := XOR; i < Count; i++ {
|
||||
result += stats[i]
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Cost computes the relative computational cost of the circuit.
|
||||
func (stats Stats) Cost() uint64 {
|
||||
return (stats[AND]+stats[INV])*2 + stats[OR]*3
|
||||
}
|
||||
|
||||
func (stats Stats) String() string {
|
||||
var result string
|
||||
|
||||
for i := XOR; i < Count; i++ {
|
||||
v := stats[i]
|
||||
if len(result) > 0 {
|
||||
result += " "
|
||||
}
|
||||
result += fmt.Sprintf("%s=%d", i, v)
|
||||
}
|
||||
result += fmt.Sprintf(" xor=%d", stats[XOR]+stats[XNOR])
|
||||
result += fmt.Sprintf(" !xor=%d", stats[AND]+stats[OR]+stats[INV])
|
||||
result += fmt.Sprintf(" levels=%d", stats[NumLevels])
|
||||
result += fmt.Sprintf(" width=%d", stats[MaxWidth])
|
||||
return result
|
||||
}
|
||||
|
||||
func (op Operation) String() string {
|
||||
switch op {
|
||||
case XOR:
|
||||
return "XOR"
|
||||
case XNOR:
|
||||
return "XNOR"
|
||||
case AND:
|
||||
return "AND"
|
||||
case OR:
|
||||
return "OR"
|
||||
case INV:
|
||||
return "INV"
|
||||
case Count:
|
||||
return "#"
|
||||
default:
|
||||
return fmt.Sprintf("{Operation %d}", op)
|
||||
}
|
||||
}
|
||||
|
||||
// Circuit specifies a boolean circuit.
|
||||
type Circuit struct {
|
||||
NumGates int
|
||||
NumWires int
|
||||
Inputs IO
|
||||
Outputs IO
|
||||
Gates []Gate
|
||||
Stats Stats
|
||||
}
|
||||
|
||||
func (c *Circuit) String() string {
|
||||
return fmt.Sprintf("#gates=%d (%s) #w=%d", c.NumGates, c.Stats, c.NumWires)
|
||||
}
|
||||
|
||||
// NumParties returns the number of parties needed for the circuit.
|
||||
func (c *Circuit) NumParties() int {
|
||||
return len(c.Inputs)
|
||||
}
|
||||
|
||||
// PrintInputs prints the circuit inputs.
|
||||
func (c *Circuit) PrintInputs(id int, input []string) {
|
||||
for i := 0; i < len(c.Inputs); i++ {
|
||||
if i == id {
|
||||
fmt.Print(" + ")
|
||||
} else {
|
||||
fmt.Print(" - ")
|
||||
}
|
||||
fmt.Printf("In%d: %s\n", i, c.Inputs[i])
|
||||
}
|
||||
fmt.Printf(" - Out: %s\n", c.Outputs)
|
||||
fmt.Printf(" - In: %s\n", input)
|
||||
}
|
||||
|
||||
// TabulateStats prints the circuit stats as a table to the specified
|
||||
// output Writer.
|
||||
func (c *Circuit) TabulateStats(out io.Writer) {
|
||||
tab := tabulate.New(tabulate.UnicodeLight)
|
||||
tab.Header("XOR").SetAlign(tabulate.MR)
|
||||
tab.Header("XNOR").SetAlign(tabulate.MR)
|
||||
tab.Header("AND").SetAlign(tabulate.MR)
|
||||
tab.Header("OR").SetAlign(tabulate.MR)
|
||||
tab.Header("INV").SetAlign(tabulate.MR)
|
||||
tab.Header("Gates").SetAlign(tabulate.MR)
|
||||
tab.Header("XOR").SetAlign(tabulate.MR)
|
||||
tab.Header("!XOR").SetAlign(tabulate.MR)
|
||||
tab.Header("Wires").SetAlign(tabulate.MR)
|
||||
|
||||
c.TabulateRow(tab.Row())
|
||||
tab.Print(out)
|
||||
}
|
||||
|
||||
// TabulateRow tabulates circuit statistics to the argument tabulation
|
||||
// row.
|
||||
func (c *Circuit) TabulateRow(row *tabulate.Row) {
|
||||
var sumGates uint64
|
||||
for op := XOR; op < Count; op++ {
|
||||
row.Column(fmt.Sprintf("%v", c.Stats[op]))
|
||||
sumGates += c.Stats[op]
|
||||
}
|
||||
row.Column(fmt.Sprintf("%v", sumGates))
|
||||
row.Column(fmt.Sprintf("%v", c.Stats[XOR]+c.Stats[XNOR]))
|
||||
row.Column(fmt.Sprintf("%v", c.Stats[AND]+c.Stats[OR]+c.Stats[INV]))
|
||||
row.Column(fmt.Sprintf("%v", c.NumWires))
|
||||
}
|
||||
|
||||
// Cost computes the relative computational cost of the circuit.
|
||||
func (c *Circuit) Cost() uint64 {
|
||||
return c.Stats.Cost()
|
||||
}
|
||||
|
||||
// Dump prints a debug dump of the circuit.
|
||||
func (c *Circuit) Dump() {
|
||||
fmt.Printf("circuit %s\n", c)
|
||||
for id, gate := range c.Gates {
|
||||
fmt.Printf("%04d\t%s\n", id, gate)
|
||||
}
|
||||
}
|
||||
|
||||
// AssignLevels assigns levels for gates. The level desribes how many
|
||||
// steps away the gate is from input wires.
|
||||
func (c *Circuit) AssignLevels() {
|
||||
levels := make([]Level, c.NumWires)
|
||||
countByLevel := make([]uint32, c.NumWires)
|
||||
|
||||
var max Level
|
||||
|
||||
for idx, gate := range c.Gates {
|
||||
level := levels[gate.Input0]
|
||||
if gate.Op != INV {
|
||||
l1 := levels[gate.Input1]
|
||||
if l1 > level {
|
||||
level = l1
|
||||
}
|
||||
}
|
||||
c.Gates[idx].Level = level
|
||||
countByLevel[level]++
|
||||
|
||||
level++
|
||||
|
||||
levels[gate.Output] = level
|
||||
if level > max {
|
||||
max = level
|
||||
}
|
||||
}
|
||||
c.Stats[NumLevels] = uint64(max)
|
||||
|
||||
var maxWidth uint32
|
||||
for _, count := range countByLevel {
|
||||
if count > maxWidth {
|
||||
maxWidth = count
|
||||
}
|
||||
}
|
||||
if false {
|
||||
for i := 0; i < int(max); i++ {
|
||||
fmt.Printf("%v,%v\n", i, countByLevel[i])
|
||||
}
|
||||
}
|
||||
|
||||
c.Stats[MaxWidth] = uint64(maxWidth)
|
||||
}
|
||||
|
||||
// Level defines gate's distance from input wires.
|
||||
type Level uint32
|
||||
|
||||
// Gate specifies a boolean gate.
|
||||
type Gate struct {
|
||||
Input0 Wire
|
||||
Input1 Wire
|
||||
Output Wire
|
||||
Op Operation
|
||||
Level Level
|
||||
}
|
||||
|
||||
func (g Gate) String() string {
|
||||
return fmt.Sprintf("%v %v %v", g.Inputs(), g.Op, g.Output)
|
||||
}
|
||||
|
||||
// Inputs returns gate input wires.
|
||||
func (g Gate) Inputs() []Wire {
|
||||
switch g.Op {
|
||||
case XOR, XNOR, AND, OR:
|
||||
return []Wire{g.Input0, g.Input1}
|
||||
case INV:
|
||||
return []Wire{g.Input0}
|
||||
default:
|
||||
panic(fmt.Sprintf("unsupported gate type %s", g.Op))
|
||||
}
|
||||
}
|
||||
|
||||
// Wire specifies a wire ID.
|
||||
type Wire uint32
|
||||
|
||||
// InvalidWire specifies an invalid wire ID.
|
||||
const InvalidWire Wire = math.MaxUint32
|
||||
|
||||
// Int returns the wire ID as integer.
|
||||
func (w Wire) Int() int {
|
||||
if uint64(w) > math.MaxInt {
|
||||
panic(w)
|
||||
}
|
||||
return int(w)
|
||||
}
|
||||
|
||||
func (w Wire) String() string {
|
||||
return fmt.Sprintf("w%d", w)
|
||||
}
|
||||
19
bedlam/circuit/circuit_test.go
Normal file
19
bedlam/circuit/circuit_test.go
Normal file
@ -0,0 +1,19 @@
|
||||
//
|
||||
// Copyright (c) 2022-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuit
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func TestSize(t *testing.T) {
|
||||
var g Gate
|
||||
if unsafe.Sizeof(g) != 20 {
|
||||
t.Errorf("unexpected gate size: got %v, expected 20", unsafe.Sizeof(g))
|
||||
}
|
||||
}
|
||||
91
bedlam/circuit/computer.go
Normal file
91
bedlam/circuit/computer.go
Normal file
@ -0,0 +1,91 @@
|
||||
//
|
||||
// Copyright (c) 2020-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
// Compute evaluates the circuit with the given input values.
|
||||
func (c *Circuit) Compute(inputs []*big.Int) ([]*big.Int, error) {
|
||||
// Flatten circuit arguments.
|
||||
var args IO
|
||||
for _, io := range c.Inputs {
|
||||
if len(io.Compound) > 0 {
|
||||
args = append(args, io.Compound...)
|
||||
} else {
|
||||
args = append(args, io)
|
||||
}
|
||||
}
|
||||
if len(inputs) != len(args) {
|
||||
return nil, fmt.Errorf("invalid inputs: got %d, expected %d",
|
||||
len(inputs), len(args))
|
||||
}
|
||||
|
||||
// Flatten inputs and arguments.
|
||||
wires := make([]byte, c.NumWires)
|
||||
|
||||
var w int
|
||||
for idx, io := range args {
|
||||
for bit := 0; bit < int(io.Type.Bits); bit++ {
|
||||
wires[w] = byte(inputs[idx].Bit(bit))
|
||||
w++
|
||||
}
|
||||
}
|
||||
|
||||
// Evaluate circuit.
|
||||
for _, gate := range c.Gates {
|
||||
var result byte
|
||||
|
||||
switch gate.Op {
|
||||
case XOR:
|
||||
result = wires[gate.Input0] ^ wires[gate.Input1]
|
||||
|
||||
case XNOR:
|
||||
if wires[gate.Input0]^wires[gate.Input1] == 0 {
|
||||
result = 1
|
||||
} else {
|
||||
result = 0
|
||||
}
|
||||
|
||||
case AND:
|
||||
result = wires[gate.Input0] & wires[gate.Input1]
|
||||
|
||||
case OR:
|
||||
result = wires[gate.Input0] | wires[gate.Input1]
|
||||
|
||||
case INV:
|
||||
if wires[gate.Input0] == 0 {
|
||||
result = 1
|
||||
} else {
|
||||
result = 0
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid gate %s", gate.Op)
|
||||
}
|
||||
|
||||
wires[gate.Output] = result
|
||||
}
|
||||
|
||||
// Construct outputs
|
||||
w = c.NumWires - c.Outputs.Size()
|
||||
var result []*big.Int
|
||||
for _, io := range c.Outputs {
|
||||
r := new(big.Int)
|
||||
for bit := 0; bit < int(io.Type.Bits); bit++ {
|
||||
if wires[w] != 0 {
|
||||
r.SetBit(r, bit, 1)
|
||||
}
|
||||
w++
|
||||
}
|
||||
result = append(result, r)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
19
bedlam/circuit/docs/and.svg
Normal file
19
bedlam/circuit/docs/and.svg
Normal file
@ -0,0 +1,19 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
|
||||
<g fill="none" stroke="#000" stroke-width="1">
|
||||
<path d="M 25 25
|
||||
h 50
|
||||
v 25
|
||||
a 25 25 0 1 1 -50 0
|
||||
v -25
|
||||
z" />
|
||||
<path d="M 35 0
|
||||
v 25
|
||||
z" />
|
||||
<path d="M 65 0
|
||||
v 25
|
||||
z" />
|
||||
<path d="M 50 75
|
||||
v 25
|
||||
z" />
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 429 B |
20
bedlam/circuit/docs/not.svg
Normal file
20
bedlam/circuit/docs/not.svg
Normal file
@ -0,0 +1,20 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
|
||||
<g fill="none" stroke="#000" stroke-width="1">
|
||||
<path d="M 25 25
|
||||
h 50
|
||||
l -25 43
|
||||
z" />
|
||||
<circle cx="50"
|
||||
cy="73.5"
|
||||
r="5" />
|
||||
<path d="M 35 0
|
||||
v 25
|
||||
z" />
|
||||
<path d="M 65 0
|
||||
v 25
|
||||
z" />
|
||||
<path d="M 50 79
|
||||
v 25
|
||||
z" />
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 444 B |
21
bedlam/circuit/docs/or.svg
Normal file
21
bedlam/circuit/docs/or.svg
Normal file
@ -0,0 +1,21 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
|
||||
<g fill="none" stroke="#000" stroke-width="1">
|
||||
<path d="M 25 20
|
||||
c 10 10 40 10 50 0" />
|
||||
<path d="M 75 20
|
||||
v 30
|
||||
s 0 10 -25 25 " />
|
||||
<path d="M 25 20
|
||||
v 30
|
||||
s 0 10 25 25 " />
|
||||
<path d="M 35 0
|
||||
v 25
|
||||
z" />
|
||||
<path d="M 65 0
|
||||
v 25
|
||||
z" />
|
||||
<path d="M 50 75
|
||||
v 25
|
||||
z" />
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 499 B |
6
bedlam/circuit/docs/wire.svg
Normal file
6
bedlam/circuit/docs/wire.svg
Normal file
@ -0,0 +1,6 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
|
||||
<g fill="none" stroke="#000" stroke-width="1">
|
||||
<path d="M 25 20
|
||||
C 25 30 35 30 35 40" />
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 187 B |
26
bedlam/circuit/docs/xnor.svg
Normal file
26
bedlam/circuit/docs/xnor.svg
Normal file
@ -0,0 +1,26 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
|
||||
<g fill="none" stroke="#000" stroke-width="1">
|
||||
<path d="M 25 20
|
||||
c 10 10 40 10 50 0" />
|
||||
<path d="M 25 25
|
||||
c 10 10 40 10 50 0" />
|
||||
<path d="M 75 25
|
||||
v 25
|
||||
s 0 10 -25 25 " />
|
||||
<path d="M 25 25
|
||||
v 25
|
||||
s 0 10 25 25 " />
|
||||
<circle cx="50"
|
||||
cy="80"
|
||||
r="5" />
|
||||
<path d="M 35 0
|
||||
v 25
|
||||
z" />
|
||||
<path d="M 65 0
|
||||
v 25
|
||||
z" />
|
||||
<path d="M 50 85
|
||||
v 25
|
||||
z" />
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 617 B |
23
bedlam/circuit/docs/xor.svg
Normal file
23
bedlam/circuit/docs/xor.svg
Normal file
@ -0,0 +1,23 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
|
||||
<g fill="none" stroke="#000" stroke-width="1">
|
||||
<path d="M 25 20
|
||||
c 10 10 40 10 50 0" />
|
||||
<path d="M 25 25
|
||||
c 10 10 40 10 50 0" />
|
||||
<path d="M 75 25
|
||||
v 25
|
||||
s 0 10 -25 25 " />
|
||||
<path d="M 25 25
|
||||
v 25
|
||||
s 0 10 25 25 " />
|
||||
<path d="M 35 0
|
||||
v 25
|
||||
z" />
|
||||
<path d="M 65 0
|
||||
v 25
|
||||
z" />
|
||||
<path d="M 50 75
|
||||
v 25
|
||||
z" />
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 556 B |
56
bedlam/circuit/dot.go
Normal file
56
bedlam/circuit/dot.go
Normal file
@ -0,0 +1,56 @@
|
||||
//
|
||||
// Copyright (c) 2019-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Dot creates graphviz dot output of the circuit.
|
||||
func (c *Circuit) Dot(out io.Writer) {
|
||||
fmt.Fprintf(out, "digraph circuit\n{\n")
|
||||
fmt.Fprintf(out, " overlap=scale;\n")
|
||||
fmt.Fprintf(out, " node\t[fontname=\"Helvetica\"];\n")
|
||||
fmt.Fprintf(out, " {\n node [shape=plaintext];\n")
|
||||
for w := 0; w < c.NumWires; w++ {
|
||||
fmt.Fprintf(out, " w%d\t[label=\"%d\"];\n", w, w)
|
||||
}
|
||||
fmt.Fprintf(out, " }\n")
|
||||
|
||||
fmt.Fprintf(out, " {\n node [shape=box];\n")
|
||||
for idx, gate := range c.Gates {
|
||||
fmt.Fprintf(out, " g%d\t[label=\"%s\"];\n", idx, gate.Op)
|
||||
}
|
||||
fmt.Fprintf(out, " }\n")
|
||||
|
||||
if true {
|
||||
fmt.Fprintf(out, " { rank=same")
|
||||
var numInputs int
|
||||
for _, input := range c.Inputs {
|
||||
numInputs += int(input.Type.Bits)
|
||||
}
|
||||
for w := 0; w < numInputs; w++ {
|
||||
fmt.Fprintf(out, "; w%d", w)
|
||||
}
|
||||
fmt.Fprintf(out, ";}\n")
|
||||
|
||||
fmt.Fprintf(out, " { rank=same")
|
||||
for w := 0; w < c.Outputs.Size(); w++ {
|
||||
fmt.Fprintf(out, "; w%d", c.NumWires-w-1)
|
||||
}
|
||||
fmt.Fprintf(out, ";}\n")
|
||||
}
|
||||
|
||||
for idx, gate := range c.Gates {
|
||||
for _, i := range gate.Inputs() {
|
||||
fmt.Fprintf(out, " w%d -> g%d;\n", i, idx)
|
||||
}
|
||||
fmt.Fprintf(out, " g%d -> w%d;\n", idx, gate.Output)
|
||||
}
|
||||
fmt.Fprintf(out, "}\n")
|
||||
}
|
||||
91
bedlam/circuit/enc_test.go
Normal file
91
bedlam/circuit/enc_test.go
Normal file
@ -0,0 +1,91 @@
|
||||
//
|
||||
// enc_test.go
|
||||
//
|
||||
// Copyright (c) 2019-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuit
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
|
||||
)
|
||||
|
||||
func TestEnc(t *testing.T) {
|
||||
a, _ := ot.NewLabel(rand.Reader)
|
||||
b, _ := ot.NewLabel(rand.Reader)
|
||||
c, _ := ot.NewLabel(rand.Reader)
|
||||
tweak := uint32(42)
|
||||
var key [32]byte
|
||||
|
||||
cipher, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create cipher: %s", err)
|
||||
}
|
||||
|
||||
var data ot.LabelData
|
||||
|
||||
encrypted := encrypt(cipher, a, b, c, tweak, &data)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt failed: %s", err)
|
||||
}
|
||||
|
||||
plain := decrypt(cipher, a, b, tweak, encrypted, &data)
|
||||
|
||||
if !c.Equal(plain) {
|
||||
t.Fatalf("Encrypt-decrypt failed")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEnc(b *testing.B) {
|
||||
var key [32]byte
|
||||
|
||||
cipher, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to create cipher: %s", err)
|
||||
}
|
||||
|
||||
al, err := ot.NewLabel(rand.Reader)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to create label: %s", err)
|
||||
}
|
||||
bl, err := ot.NewLabel(rand.Reader)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to create label: %s", err)
|
||||
}
|
||||
cl, err := ot.NewLabel(rand.Reader)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to create label: %s", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
var data ot.LabelData
|
||||
for i := 0; i < b.N; i++ {
|
||||
encrypt(cipher, al, bl, cl, uint32(i), &data)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncHalf(b *testing.B) {
|
||||
var key [32]byte
|
||||
|
||||
cipher, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to create cipher: %s", err)
|
||||
}
|
||||
|
||||
xl, err := ot.NewLabel(rand.Reader)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to create label: %s", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
var data ot.LabelData
|
||||
for i := 0; i < b.N; i++ {
|
||||
encryptHalf(cipher, xl, uint32(i), &data)
|
||||
}
|
||||
}
|
||||
115
bedlam/circuit/eval.go
Normal file
115
bedlam/circuit/eval.go
Normal file
@ -0,0 +1,115 @@
|
||||
//
|
||||
// Copyright (c) 2019-2021 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuit
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"fmt"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
|
||||
)
|
||||
|
||||
// Eval evaluates the circuit.
|
||||
func (c *Circuit) Eval(key []byte, wires []ot.Label,
|
||||
garbled [][]ot.Label) error {
|
||||
|
||||
alg, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var data ot.LabelData
|
||||
var id uint32
|
||||
|
||||
for i := 0; i < len(c.Gates); i++ {
|
||||
gate := &c.Gates[i]
|
||||
|
||||
var a, b, c ot.Label
|
||||
|
||||
switch gate.Op {
|
||||
case XOR, XNOR, AND, OR:
|
||||
a = wires[gate.Input0]
|
||||
b = wires[gate.Input1]
|
||||
|
||||
case INV:
|
||||
a = wires[gate.Input0]
|
||||
|
||||
default:
|
||||
return fmt.Errorf("invalid operation %s", gate.Op)
|
||||
}
|
||||
|
||||
var output ot.Label
|
||||
|
||||
switch gate.Op {
|
||||
case XOR, XNOR:
|
||||
a.Xor(b)
|
||||
output = a
|
||||
|
||||
case AND:
|
||||
row := garbled[i]
|
||||
if len(row) != 2 {
|
||||
return fmt.Errorf("corrupted ciruit: AND row length: %d",
|
||||
len(row))
|
||||
}
|
||||
sa := a.S()
|
||||
sb := b.S()
|
||||
|
||||
j0 := id
|
||||
j1 := id + 1
|
||||
id += 2
|
||||
|
||||
tg := row[0]
|
||||
te := row[1]
|
||||
|
||||
wg := encryptHalf(alg, a, j0, &data)
|
||||
if sa {
|
||||
wg.Xor(tg)
|
||||
}
|
||||
we := encryptHalf(alg, b, j1, &data)
|
||||
if sb {
|
||||
we.Xor(te)
|
||||
we.Xor(a)
|
||||
}
|
||||
output = wg
|
||||
output.Xor(we)
|
||||
|
||||
case OR:
|
||||
row := garbled[i]
|
||||
index := idx(a, b)
|
||||
if index > 0 {
|
||||
// First row is zero and not transmitted.
|
||||
index--
|
||||
if index >= len(row) {
|
||||
return fmt.Errorf("corrupted circuit: index %d >= row %d",
|
||||
index, len(row))
|
||||
}
|
||||
c = row[index]
|
||||
}
|
||||
|
||||
output = decrypt(alg, a, b, id, c, &data)
|
||||
id++
|
||||
|
||||
case INV:
|
||||
row := garbled[i]
|
||||
index := idxUnary(a)
|
||||
if index > 0 {
|
||||
// First row is zero and not transmitted.
|
||||
index--
|
||||
if index >= len(row) {
|
||||
return fmt.Errorf("corrupted circuit: index %d >= row %d",
|
||||
index, len(row))
|
||||
}
|
||||
c = row[index]
|
||||
}
|
||||
output = decrypt(alg, a, ot.Label{}, id, c, &data)
|
||||
id++
|
||||
}
|
||||
wires[gate.Output] = output
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
159
bedlam/circuit/evaluator.go
Normal file
159
bedlam/circuit/evaluator.go
Normal file
@ -0,0 +1,159 @@
|
||||
//
|
||||
// evaluator.go
|
||||
//
|
||||
// Copyright (c) 2019-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/p2p"
|
||||
)
|
||||
|
||||
var (
|
||||
debug = false
|
||||
)
|
||||
|
||||
// Evaluator runs the evaluator on the P2P network.
|
||||
func Evaluator(conn *p2p.Conn, oti ot.OT, circ *Circuit, inputs *big.Int,
|
||||
verbose bool) ([]*big.Int, error) {
|
||||
|
||||
timing := NewTiming()
|
||||
|
||||
garbled := make([][]ot.Label, circ.NumGates)
|
||||
|
||||
// Receive program info.
|
||||
if verbose {
|
||||
fmt.Printf(" - Waiting for circuit info...\n")
|
||||
}
|
||||
key, err := conn.ReceiveData()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Receive garbled tables.
|
||||
timing.Sample("Wait", nil)
|
||||
if verbose {
|
||||
fmt.Printf(" - Receiving garbled circuit...\n")
|
||||
}
|
||||
count, err := conn.ReceiveUint32()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if count != circ.NumGates {
|
||||
return nil, fmt.Errorf("wrong number of gates: got %d, expected %d",
|
||||
count, circ.NumGates)
|
||||
}
|
||||
var label ot.Label
|
||||
var labelData ot.LabelData
|
||||
for i := 0; i < circ.NumGates; i++ {
|
||||
count, err := conn.ReceiveUint32()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
values := make([]ot.Label, count)
|
||||
for j := 0; j < count; j++ {
|
||||
err := conn.ReceiveLabel(&label, &labelData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
values[j] = label
|
||||
}
|
||||
garbled[i] = values
|
||||
}
|
||||
|
||||
wires := make([]ot.Label, circ.NumWires)
|
||||
|
||||
// Receive peer inputs.
|
||||
for i := 0; i < int(circ.Inputs[0].Type.Bits); i++ {
|
||||
err := conn.ReceiveLabel(&label, &labelData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wires[Wire(i)] = label
|
||||
}
|
||||
|
||||
// Init oblivious transfer.
|
||||
err = oti.InitReceiver(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ioStats := conn.Stats.Sum()
|
||||
timing.Sample("Recv", []string{FileSize(ioStats).String()})
|
||||
|
||||
// Query our inputs.
|
||||
if verbose {
|
||||
fmt.Printf(" - Querying our inputs...\n")
|
||||
}
|
||||
// Wire offset.
|
||||
if err := conn.SendUint32(int(circ.Inputs[0].Type.Bits)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Wire count.
|
||||
if err := conn.SendUint32(int(circ.Inputs[1].Type.Bits)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := conn.Flush(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
flags := make([]bool, int(circ.Inputs[1].Type.Bits))
|
||||
for i := 0; i < int(circ.Inputs[1].Type.Bits); i++ {
|
||||
if inputs.Bit(i) == 1 {
|
||||
flags[i] = true
|
||||
}
|
||||
}
|
||||
if err := oti.Receive(flags, wires[circ.Inputs[0].Type.Bits:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
xfer := conn.Stats.Sum() - ioStats
|
||||
ioStats = conn.Stats.Sum()
|
||||
timing.Sample("Inputs", []string{FileSize(xfer).String()})
|
||||
|
||||
// Evaluate gates.
|
||||
if verbose {
|
||||
fmt.Printf(" - Evaluating circuit...\n")
|
||||
}
|
||||
err = circ.Eval(key[:], wires, garbled)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
timing.Sample("Eval", nil)
|
||||
|
||||
// Resolve result values.
|
||||
|
||||
var labels []ot.Label
|
||||
|
||||
for i := 0; i < circ.Outputs.Size(); i++ {
|
||||
r := wires[Wire(circ.NumWires-circ.Outputs.Size()+i)]
|
||||
labels = append(labels, r)
|
||||
}
|
||||
for _, l := range labels {
|
||||
if err := conn.SendLabel(l, &labelData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := conn.Flush(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, err := conn.ReceiveData()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
raw := big.NewInt(0).SetBytes(result)
|
||||
|
||||
xfer = conn.Stats.Sum() - ioStats
|
||||
timing.Sample("Result", []string{FileSize(xfer).String()})
|
||||
if verbose {
|
||||
timing.Print(conn.Stats)
|
||||
}
|
||||
|
||||
return circ.Outputs.Split(raw), nil
|
||||
}
|
||||
410
bedlam/circuit/garble.go
Normal file
410
bedlam/circuit/garble.go
Normal file
@ -0,0 +1,410 @@
|
||||
//
|
||||
// Copyright (c) 2019-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuit
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
|
||||
)
|
||||
|
||||
var (
|
||||
verbose = false
|
||||
)
|
||||
|
||||
func idxUnary(l0 ot.Label) int {
|
||||
if l0.S() {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func idx(l0, l1 ot.Label) int {
|
||||
var ret int
|
||||
|
||||
if l0.S() {
|
||||
ret |= 0x2
|
||||
}
|
||||
if l1.S() {
|
||||
ret |= 0x1
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func encrypt(alg cipher.Block, a, b, c ot.Label, t uint32,
|
||||
data *ot.LabelData) ot.Label {
|
||||
|
||||
k := makeK(a, b, t)
|
||||
|
||||
k.GetData(data)
|
||||
alg.Encrypt(data[:], data[:])
|
||||
|
||||
var pi ot.Label
|
||||
pi.SetData(data)
|
||||
|
||||
pi.Xor(k)
|
||||
pi.Xor(c)
|
||||
|
||||
return pi
|
||||
}
|
||||
|
||||
func decrypt(alg cipher.Block, a, b ot.Label, t uint32, c ot.Label,
|
||||
data *ot.LabelData) ot.Label {
|
||||
|
||||
k := makeK(a, b, t)
|
||||
|
||||
k.GetData(data)
|
||||
alg.Encrypt(data[:], data[:])
|
||||
|
||||
var crypted ot.Label
|
||||
crypted.SetData(data)
|
||||
|
||||
c.Xor(crypted)
|
||||
c.Xor(k)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func makeK(a, b ot.Label, t uint32) ot.Label {
|
||||
a.Mul2()
|
||||
|
||||
b.Mul4()
|
||||
a.Xor(b)
|
||||
|
||||
a.Xor(ot.NewTweak(t))
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
// Hash function for half gates: Hπ(x, i) to be π(K) ⊕ K where K = 2x ⊕ i
|
||||
func encryptHalfReference(alg cipher.Block, x ot.Label, i uint32,
|
||||
data *ot.LabelData) ot.Label {
|
||||
|
||||
k := makeKHalf(x, i)
|
||||
|
||||
k.GetData(data)
|
||||
alg.Encrypt(data[:], data[:])
|
||||
|
||||
var pi ot.Label
|
||||
pi.SetData(data)
|
||||
|
||||
pi.Xor(k)
|
||||
|
||||
return pi
|
||||
}
|
||||
|
||||
// Optimized version of encryptHalfReference. Label operations are
|
||||
// inlined below, producing about 11% performance improvements.
|
||||
func encryptHalf(alg cipher.Block, x ot.Label, i uint32,
|
||||
data *ot.LabelData) ot.Label {
|
||||
|
||||
// k := makeKHalf(x, i) {
|
||||
k := x
|
||||
// k.Mul2()
|
||||
k.D0 <<= 1
|
||||
k.D0 |= (k.D1 >> 63)
|
||||
k.D1 <<= 1
|
||||
// k.Xor(ot.NewTweak(i))
|
||||
k.D1 ^= uint64(i)
|
||||
// }
|
||||
|
||||
// k.GetData(data) {
|
||||
binary.BigEndian.PutUint64(data[0:8], k.D0)
|
||||
binary.BigEndian.PutUint64(data[8:16], k.D1)
|
||||
// }
|
||||
|
||||
alg.Encrypt(data[:], data[:])
|
||||
|
||||
var pi ot.Label
|
||||
// pi.SetData(data) {
|
||||
pi.D0 = binary.BigEndian.Uint64((*data)[0:8])
|
||||
pi.D1 = binary.BigEndian.Uint64((*data)[8:16])
|
||||
// }
|
||||
|
||||
// pi.Xor(k) {
|
||||
pi.D0 ^= k.D0
|
||||
pi.D1 ^= k.D1
|
||||
// }
|
||||
|
||||
return pi
|
||||
}
|
||||
|
||||
// K = 2x ⊕ i
|
||||
func makeKHalf(x ot.Label, i uint32) ot.Label {
|
||||
x.Mul2()
|
||||
x.Xor(ot.NewTweak(i))
|
||||
return x
|
||||
}
|
||||
|
||||
func makeLabels(r ot.Label) (ot.Wire, error) {
|
||||
l0, err := ot.NewLabel(rand.Reader)
|
||||
if err != nil {
|
||||
return ot.Wire{}, err
|
||||
}
|
||||
l1 := l0
|
||||
l1.Xor(r)
|
||||
|
||||
return ot.Wire{
|
||||
L0: l0,
|
||||
L1: l1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Garbled contains garbled circuit information.
|
||||
type Garbled struct {
|
||||
R ot.Label
|
||||
Wires []ot.Wire
|
||||
Gates [][]ot.Label
|
||||
}
|
||||
|
||||
// Lambda returns the lambda value of the wire.
|
||||
func (g *Garbled) Lambda(wire Wire) uint {
|
||||
if g.Wires[wire].L0.S() {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// SetLambda sets the lambda value of the wire.
|
||||
func (g *Garbled) SetLambda(wire Wire, val uint) {
|
||||
w := g.Wires[wire]
|
||||
if val == 0 {
|
||||
w.L0.SetS(false)
|
||||
} else {
|
||||
w.L0.SetS(true)
|
||||
}
|
||||
g.Wires[wire] = w
|
||||
}
|
||||
|
||||
// Garble garbles the circuit.
|
||||
func (c *Circuit) Garble(key []byte) (*Garbled, error) {
|
||||
// Create R.
|
||||
r, err := ot.NewLabel(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.SetS(true)
|
||||
|
||||
garbled := make([][]ot.Label, c.NumGates)
|
||||
|
||||
alg, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Wire labels.
|
||||
wires := make([]ot.Wire, c.NumWires)
|
||||
|
||||
// Assing all input wires.
|
||||
for i := 0; i < c.Inputs.Size(); i++ {
|
||||
w, err := makeLabels(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wires[i] = w
|
||||
}
|
||||
|
||||
// Garble gates.
|
||||
var data ot.LabelData
|
||||
var id uint32
|
||||
for i := 0; i < len(c.Gates); i++ {
|
||||
gate := &c.Gates[i]
|
||||
data, err := gate.garble(wires, alg, r, &id, &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
garbled[i] = data
|
||||
}
|
||||
|
||||
return &Garbled{
|
||||
R: r,
|
||||
Wires: wires,
|
||||
Gates: garbled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Garble garbles the gate and returns it labels.
|
||||
func (g *Gate) garble(wires []ot.Wire, enc cipher.Block, r ot.Label,
|
||||
idp *uint32, data *ot.LabelData) ([]ot.Label, error) {
|
||||
|
||||
var a, b, c ot.Wire
|
||||
|
||||
var table [4]ot.Label
|
||||
var start, count int
|
||||
|
||||
// Inputs.
|
||||
switch g.Op {
|
||||
case XOR, XNOR, AND, OR:
|
||||
b = wires[g.Input1]
|
||||
fallthrough
|
||||
|
||||
case INV:
|
||||
a = wires[g.Input0]
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid gate type %s", g.Op)
|
||||
}
|
||||
|
||||
// Output.
|
||||
switch g.Op {
|
||||
case XOR:
|
||||
l0 := a.L0
|
||||
l0.Xor(b.L0)
|
||||
|
||||
l1 := l0
|
||||
l1.Xor(r)
|
||||
c = ot.Wire{
|
||||
L0: l0,
|
||||
L1: l1,
|
||||
}
|
||||
|
||||
case XNOR:
|
||||
l0 := a.L0
|
||||
l0.Xor(b.L0)
|
||||
|
||||
l1 := l0
|
||||
l1.Xor(r)
|
||||
c = ot.Wire{
|
||||
L0: l1,
|
||||
L1: l0,
|
||||
}
|
||||
|
||||
case AND:
|
||||
pa := a.L0.S()
|
||||
pb := b.L0.S()
|
||||
|
||||
j0 := *idp
|
||||
j1 := *idp + 1
|
||||
*idp = *idp + 2
|
||||
|
||||
// First half gate.
|
||||
tg := encryptHalf(enc, a.L0, j0, data)
|
||||
tg.Xor(encryptHalf(enc, a.L1, j0, data))
|
||||
if pb {
|
||||
tg.Xor(r)
|
||||
}
|
||||
wg0 := encryptHalf(enc, a.L0, j0, data)
|
||||
if pa {
|
||||
wg0.Xor(tg)
|
||||
}
|
||||
|
||||
// Second half gate.
|
||||
te := encryptHalf(enc, b.L0, j1, data)
|
||||
te.Xor(encryptHalf(enc, b.L1, j1, data))
|
||||
te.Xor(a.L0)
|
||||
we0 := encryptHalf(enc, b.L0, j1, data)
|
||||
if pb {
|
||||
we0.Xor(te)
|
||||
we0.Xor(a.L0)
|
||||
}
|
||||
|
||||
// Combine halves
|
||||
l0 := wg0
|
||||
l0.Xor(we0)
|
||||
|
||||
l1 := l0
|
||||
l1.Xor(r)
|
||||
|
||||
c = ot.Wire{
|
||||
L0: l0,
|
||||
L1: l1,
|
||||
}
|
||||
table[0] = tg
|
||||
table[1] = te
|
||||
count = 2
|
||||
|
||||
case OR, INV:
|
||||
// Row reduction creates labels below so that the first row is
|
||||
// all zero.
|
||||
|
||||
default:
|
||||
panic("invalid gate type")
|
||||
}
|
||||
|
||||
switch g.Op {
|
||||
case XOR, XNOR:
|
||||
// Free XOR.
|
||||
|
||||
case AND:
|
||||
// Half AND garbled above.
|
||||
|
||||
case OR:
|
||||
// a b c
|
||||
// -----
|
||||
// 0 0 0
|
||||
// 0 1 1
|
||||
// 1 0 1
|
||||
// 1 1 1
|
||||
id := *idp
|
||||
*idp = *idp + 1
|
||||
table[idx(a.L0, b.L0)] = encrypt(enc, a.L0, b.L0, c.L0, id, data)
|
||||
table[idx(a.L0, b.L1)] = encrypt(enc, a.L0, b.L1, c.L1, id, data)
|
||||
table[idx(a.L1, b.L0)] = encrypt(enc, a.L1, b.L0, c.L1, id, data)
|
||||
table[idx(a.L1, b.L1)] = encrypt(enc, a.L1, b.L1, c.L1, id, data)
|
||||
|
||||
l0Index := idx(a.L0, b.L0)
|
||||
|
||||
c.L0 = table[0]
|
||||
c.L1 = table[0]
|
||||
|
||||
if l0Index == 0 {
|
||||
c.L1.Xor(r)
|
||||
} else {
|
||||
c.L0.Xor(r)
|
||||
}
|
||||
for i := 0; i < 4; i++ {
|
||||
if i == l0Index {
|
||||
table[i].Xor(c.L0)
|
||||
} else {
|
||||
table[i].Xor(c.L1)
|
||||
}
|
||||
}
|
||||
start = 1
|
||||
count = 3
|
||||
|
||||
case INV:
|
||||
// a b c
|
||||
// -----
|
||||
// 0 1
|
||||
// 1 0
|
||||
id := *idp
|
||||
*idp = *idp + 1
|
||||
table[idxUnary(a.L0)] = encrypt(enc, a.L0, ot.Label{}, c.L1, id, data)
|
||||
table[idxUnary(a.L1)] = encrypt(enc, a.L1, ot.Label{}, c.L0, id, data)
|
||||
|
||||
l0Index := idxUnary(a.L0)
|
||||
|
||||
c.L0 = table[0]
|
||||
c.L1 = table[0]
|
||||
|
||||
if l0Index == 0 {
|
||||
c.L0.Xor(r)
|
||||
} else {
|
||||
c.L1.Xor(r)
|
||||
}
|
||||
for i := 0; i < 2; i++ {
|
||||
if i == l0Index {
|
||||
table[i].Xor(c.L1)
|
||||
} else {
|
||||
table[i].Xor(c.L0)
|
||||
}
|
||||
}
|
||||
start = 1
|
||||
count = 1
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid operand %s", g.Op)
|
||||
}
|
||||
wires[g.Output] = c
|
||||
|
||||
return table[start : start+count], nil
|
||||
}
|
||||
189
bedlam/circuit/garbler.go
Normal file
189
bedlam/circuit/garbler.go
Normal file
@ -0,0 +1,189 @@
|
||||
//
|
||||
// garbler.go
|
||||
//
|
||||
// Copyright (c) 2019-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuit
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/p2p"
|
||||
)
|
||||
|
||||
// FileSize specifies a file (or data transfer) size in bytes.
|
||||
type FileSize uint64
|
||||
|
||||
func (s FileSize) String() string {
|
||||
if s > 1000*1000*1000*1000 {
|
||||
return fmt.Sprintf("%dTB", s/(1000*1000*1000*1000))
|
||||
} else if s > 1000*1000*1000 {
|
||||
return fmt.Sprintf("%dGB", s/(1000*1000*1000))
|
||||
} else if s > 1000*1000 {
|
||||
return fmt.Sprintf("%dMB", s/(1000*1000))
|
||||
} else if s > 1000 {
|
||||
return fmt.Sprintf("%dkB", s/1000)
|
||||
} else {
|
||||
return fmt.Sprintf("%dB", s)
|
||||
}
|
||||
}
|
||||
|
||||
// Garbler runs the garbler on the P2P network.
|
||||
func Garbler(conn *p2p.Conn, oti ot.OT, circ *Circuit, inputs *big.Int,
|
||||
verbose bool) ([]*big.Int, error) {
|
||||
|
||||
timing := NewTiming()
|
||||
if verbose {
|
||||
fmt.Printf(" - Garbling...\n")
|
||||
}
|
||||
|
||||
var key [32]byte
|
||||
_, err := rand.Read(key[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
garbled, err := circ.Garble(key[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
timing.Sample("Garble", nil)
|
||||
|
||||
// Send program info.
|
||||
if verbose {
|
||||
fmt.Printf(" - Sending garbled circuit...\n")
|
||||
}
|
||||
if err := conn.SendData(key[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send garbled tables.
|
||||
if err := conn.SendUint32(len(garbled.Gates)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var labelData ot.LabelData
|
||||
for _, data := range garbled.Gates {
|
||||
if err := conn.SendUint32(len(data)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, d := range data {
|
||||
if err := conn.SendLabel(d, &labelData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := conn.Flush(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Select our inputs.
|
||||
var n1 []ot.Label
|
||||
for i := 0; i < int(circ.Inputs[0].Type.Bits); i++ {
|
||||
wire := garbled.Wires[i]
|
||||
|
||||
var n ot.Label
|
||||
|
||||
if inputs.Bit(i) == 1 {
|
||||
n = wire.L1
|
||||
} else {
|
||||
n = wire.L0
|
||||
}
|
||||
n1 = append(n1, n)
|
||||
}
|
||||
|
||||
// Send our inputs.
|
||||
for idx, i := range n1 {
|
||||
if verbose && false {
|
||||
fmt.Printf("N1[%d]:\t%s\n", idx, i)
|
||||
}
|
||||
if err := conn.SendLabel(i, &labelData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
ioStats := conn.Stats.Sum()
|
||||
timing.Sample("Xfer", []string{FileSize(ioStats).String()})
|
||||
if verbose {
|
||||
fmt.Printf(" - Processing messages...\n")
|
||||
}
|
||||
|
||||
if err := conn.Flush(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Init oblivious transfer.
|
||||
err = oti.InitSender(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
xfer := conn.Stats.Sum() - ioStats
|
||||
ioStats = conn.Stats.Sum()
|
||||
timing.Sample("OT Init", []string{FileSize(xfer).String()})
|
||||
|
||||
// Peer OTs its inputs.
|
||||
offset, err := conn.ReceiveUint32()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
count, err := conn.ReceiveUint32()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if offset != int(circ.Inputs[0].Type.Bits) ||
|
||||
count != int(circ.Inputs[1].Type.Bits) {
|
||||
return nil, fmt.Errorf("peer can't OT wires [%d...%d[",
|
||||
offset, offset+count)
|
||||
}
|
||||
err = oti.Send(garbled.Wires[offset : offset+count])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
xfer = conn.Stats.Sum() - ioStats
|
||||
ioStats = conn.Stats.Sum()
|
||||
timing.Sample("OT", []string{FileSize(xfer).String()})
|
||||
|
||||
// Resolve result values.
|
||||
|
||||
result := big.NewInt(0)
|
||||
var label ot.Label
|
||||
|
||||
for i := 0; i < circ.Outputs.Size(); i++ {
|
||||
err := conn.ReceiveLabel(&label, &labelData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if i == 0 {
|
||||
timing.Sample("Eval", nil)
|
||||
}
|
||||
wire := garbled.Wires[circ.NumWires-circ.Outputs.Size()+i]
|
||||
|
||||
var bit uint
|
||||
if label.Equal(wire.L0) {
|
||||
bit = 0
|
||||
} else if label.Equal(wire.L1) {
|
||||
bit = 1
|
||||
} else {
|
||||
return nil, fmt.Errorf("unknown label %s for result %d", label, i)
|
||||
}
|
||||
result = big.NewInt(0).SetBit(result, i, bit)
|
||||
}
|
||||
data := result.Bytes()
|
||||
if err := conn.SendData(data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := conn.Flush(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
xfer = conn.Stats.Sum() - ioStats
|
||||
timing.Sample("Result", []string{FileSize(xfer).String()})
|
||||
if verbose {
|
||||
timing.Print(conn.Stats)
|
||||
}
|
||||
|
||||
return circ.Outputs.Split(result), nil
|
||||
}
|
||||
213
bedlam/circuit/ioarg.go
Normal file
213
bedlam/circuit/ioarg.go
Normal file
@ -0,0 +1,213 @@
|
||||
//
|
||||
// Copyright (c) 2019-2024 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
|
||||
)
|
||||
|
||||
// IO specifies circuit input and output arguments.
|
||||
type IO []IOArg
|
||||
|
||||
// Size computes the size of the circuit input and output arguments in
|
||||
// bits.
|
||||
func (io IO) Size() int {
|
||||
var sum int
|
||||
for _, a := range io {
|
||||
sum += int(a.Type.Bits)
|
||||
}
|
||||
return sum
|
||||
}
|
||||
|
||||
func (io IO) String() string {
|
||||
var str = ""
|
||||
for i, a := range io {
|
||||
if i > 0 {
|
||||
str += ", "
|
||||
}
|
||||
if len(a.Name) > 0 {
|
||||
str += a.Name + ":"
|
||||
}
|
||||
str += a.Type.String()
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
// Split splits the value into separate I/O arguments.
|
||||
func (io IO) Split(in *big.Int) []*big.Int {
|
||||
var result []*big.Int
|
||||
var bit int
|
||||
for _, arg := range io {
|
||||
r := big.NewInt(0)
|
||||
for i := 0; i < int(arg.Type.Bits); i++ {
|
||||
if in.Bit(bit) == 1 {
|
||||
r = big.NewInt(0).SetBit(r, i, 1)
|
||||
}
|
||||
bit++
|
||||
}
|
||||
result = append(result, r)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// IOArg describes circuit input argument.
|
||||
type IOArg struct {
|
||||
Name string
|
||||
Type types.Info
|
||||
Compound IO
|
||||
}
|
||||
|
||||
func (io IOArg) String() string {
|
||||
if len(io.Compound) > 0 {
|
||||
return io.Compound.String()
|
||||
}
|
||||
|
||||
if len(io.Name) > 0 {
|
||||
return io.Name + ":" + io.Type.String()
|
||||
}
|
||||
return io.Type.String()
|
||||
}
|
||||
|
||||
// Parse parses the I/O argument from the input string values.
|
||||
func (io IOArg) Parse(inputs []string) (*big.Int, error) {
|
||||
result := new(big.Int)
|
||||
|
||||
if len(io.Compound) == 0 {
|
||||
if len(inputs) != 1 {
|
||||
return nil,
|
||||
fmt.Errorf("invalid amount of arguments, got %d, expected 1",
|
||||
len(inputs))
|
||||
}
|
||||
|
||||
switch io.Type.Type {
|
||||
case types.TInt, types.TUint:
|
||||
_, ok := result.SetString(inputs[0], 0)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid input '%s' for %s",
|
||||
inputs[0], io.Type)
|
||||
}
|
||||
|
||||
case types.TBool:
|
||||
switch inputs[0] {
|
||||
case "0", "f", "false":
|
||||
case "1", "t", "true":
|
||||
result.SetInt64(1)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid bool constant: %s", inputs[0])
|
||||
}
|
||||
|
||||
case types.TArray, types.TSlice:
|
||||
count := int(io.Type.ArraySize)
|
||||
elSize := int(io.Type.ElementType.Bits)
|
||||
if io.Type.Type == types.TArray && count == 0 {
|
||||
// Handle empty types.TArray arguments.
|
||||
break
|
||||
}
|
||||
|
||||
val := new(big.Int)
|
||||
_, ok := val.SetString(inputs[0], 0)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid input '%s' for %s",
|
||||
inputs[0], io.Type)
|
||||
}
|
||||
var bitLen int
|
||||
if strings.HasPrefix(inputs[0], "0x") {
|
||||
bitLen = (len(inputs[0]) - 2) * 4
|
||||
} else {
|
||||
bitLen = val.BitLen()
|
||||
}
|
||||
|
||||
valElCount := bitLen / elSize
|
||||
if bitLen%elSize != 0 {
|
||||
valElCount++
|
||||
}
|
||||
if io.Type.Type == types.TSlice {
|
||||
// Set the count=valElCount for types.TSlice arguments.
|
||||
count = valElCount
|
||||
}
|
||||
if valElCount > count {
|
||||
return nil, fmt.Errorf("too many values for input: %s",
|
||||
inputs[0])
|
||||
}
|
||||
pad := count - valElCount
|
||||
val.Lsh(val, uint(pad*elSize))
|
||||
|
||||
mask := new(big.Int)
|
||||
for i := 0; i < elSize; i++ {
|
||||
mask.SetBit(mask, i, 1)
|
||||
}
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
next := new(big.Int).Rsh(val, uint((count-i-1)*elSize))
|
||||
next = next.And(next, mask)
|
||||
|
||||
next.Lsh(next, uint(i*elSize))
|
||||
result.Or(result, next)
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported input type: %s", io.Type)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
if len(inputs) != len(io.Compound) {
|
||||
return nil,
|
||||
fmt.Errorf("invalid amount of arguments, got %d, expected %d",
|
||||
len(inputs), len(io.Compound))
|
||||
}
|
||||
|
||||
var offset int
|
||||
|
||||
for idx, arg := range io.Compound {
|
||||
input, err := arg.Parse(inputs[idx : idx+1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
input.Lsh(input, uint(offset))
|
||||
result.Or(result, input)
|
||||
|
||||
offset += int(arg.Type.Bits)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// InputSizes computes the bit sizes of the input arguments. This is
|
||||
// used for parametrized main() when the program is instantiated based
|
||||
// on input sizes.
|
||||
func InputSizes(inputs []string) ([]int, error) {
|
||||
var result []int
|
||||
|
||||
for _, input := range inputs {
|
||||
switch input {
|
||||
case "_":
|
||||
result = append(result, 0)
|
||||
|
||||
case "0", "f", "false", "1", "t", "true":
|
||||
result = append(result, 1)
|
||||
|
||||
default:
|
||||
if strings.HasPrefix(input, "0x") {
|
||||
result = append(result, (len(input)-2)*4)
|
||||
} else {
|
||||
val := new(big.Int)
|
||||
_, ok := val.SetString(input, 0)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid input: %s", input)
|
||||
}
|
||||
result = append(result, val.BitLen())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
62
bedlam/circuit/ioarg_test.go
Normal file
62
bedlam/circuit/ioarg_test.go
Normal file
@ -0,0 +1,62 @@
|
||||
//
|
||||
// Copyright (c) 2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuit
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
var inputSizeTests = []struct {
|
||||
inputs []string
|
||||
sizes []int
|
||||
}{
|
||||
{
|
||||
inputs: []string{
|
||||
"0", "f", "false", "1", "t", "true",
|
||||
},
|
||||
sizes: []int{
|
||||
1, 1, 1, 1, 1, 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
inputs: []string{
|
||||
"0xdeadbeef", "255",
|
||||
},
|
||||
sizes: []int{
|
||||
32, 8,
|
||||
},
|
||||
},
|
||||
{
|
||||
inputs: []string{
|
||||
"0x0", "0x00", "0x000", "0x0000",
|
||||
},
|
||||
sizes: []int{
|
||||
4, 8, 12, 16,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func TestInputSizes(t *testing.T) {
|
||||
for idx, test := range inputSizeTests {
|
||||
sizes, err := InputSizes(test.inputs)
|
||||
if err != nil {
|
||||
t.Errorf("t%v: InputSizes(%v) failed: %v", idx, test.inputs, err)
|
||||
continue
|
||||
}
|
||||
if len(sizes) != len(test.sizes) {
|
||||
t.Errorf("t%v: unexpected # of sizes: got %v, expected %v",
|
||||
idx, len(sizes), len(test.sizes))
|
||||
continue
|
||||
}
|
||||
for i := 0; i < len(sizes); i++ {
|
||||
if sizes[i] != test.sizes[i] {
|
||||
t.Errorf("t%v: sizes[%v]=%v, expected %v",
|
||||
idx, i, sizes[i], test.sizes[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
141
bedlam/circuit/marshal.go
Normal file
141
bedlam/circuit/marshal.go
Normal file
@ -0,0 +1,141 @@
|
||||
//
|
||||
// Copyright (c) 2020-2021, 2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuit
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
const (
|
||||
// MAGIC is a magic number for the QCL circuit format version 0.
|
||||
MAGIC = 0x63726300 // crc0
|
||||
)
|
||||
|
||||
var (
|
||||
bo = binary.BigEndian
|
||||
)
|
||||
|
||||
// MarshalFormat marshals circuit in the specified format.
|
||||
func (c *Circuit) MarshalFormat(out io.Writer, format string) error {
|
||||
switch format {
|
||||
case "qclc":
|
||||
return c.Marshal(out)
|
||||
case "bristol":
|
||||
return c.MarshalBristol(out)
|
||||
default:
|
||||
return fmt.Errorf("unsupported circuit format: %s", format)
|
||||
}
|
||||
}
|
||||
|
||||
// Marshal marshals circuit in the QCL circuit format.
|
||||
func (c *Circuit) Marshal(out io.Writer) error {
|
||||
var data = []interface{}{
|
||||
uint32(MAGIC),
|
||||
uint32(c.NumGates),
|
||||
uint32(c.NumWires),
|
||||
uint32(len(c.Inputs)),
|
||||
uint32(len(c.Outputs)),
|
||||
}
|
||||
for _, v := range data {
|
||||
if err := binary.Write(out, bo, v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, input := range c.Inputs {
|
||||
if err := marshalIOArg(out, input); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, output := range c.Outputs {
|
||||
if err := marshalIOArg(out, output); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, g := range c.Gates {
|
||||
switch g.Op {
|
||||
case XOR, XNOR, AND, OR:
|
||||
data = []interface{}{
|
||||
byte(g.Op),
|
||||
uint32(g.Input0), uint32(g.Input1), uint32(g.Output),
|
||||
}
|
||||
|
||||
case INV:
|
||||
data = []interface{}{
|
||||
byte(g.Op),
|
||||
uint32(g.Input0), uint32(g.Output),
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported gate type %s", g.Op)
|
||||
}
|
||||
for _, v := range data {
|
||||
if err := binary.Write(out, bo, v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func marshalIOArg(out io.Writer, arg IOArg) error {
|
||||
if err := marshalString(out, arg.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := marshalString(out, arg.Type.String()); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(out, bo, uint32(arg.Type.Bits)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(out, bo, uint32(len(arg.Compound))); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, c := range arg.Compound {
|
||||
if err := marshalIOArg(out, c); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func marshalString(out io.Writer, val string) error {
|
||||
bytes := []byte(val)
|
||||
if err := binary.Write(out, bo, uint32(len(bytes))); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := out.Write(bytes)
|
||||
return err
|
||||
}
|
||||
|
||||
// MarshalBristol marshals the circuit in the Bristol format.
|
||||
func (c *Circuit) MarshalBristol(out io.Writer) error {
|
||||
fmt.Fprintf(out, "%d %d\n", c.NumGates, c.NumWires)
|
||||
fmt.Fprintf(out, "%d", len(c.Inputs))
|
||||
for _, input := range c.Inputs {
|
||||
fmt.Fprintf(out, " %d", input.Type.Bits)
|
||||
}
|
||||
fmt.Fprintln(out)
|
||||
fmt.Fprintf(out, "%d", len(c.Outputs))
|
||||
for _, ret := range c.Outputs {
|
||||
fmt.Fprintf(out, " %d", ret.Type.Bits)
|
||||
}
|
||||
fmt.Fprintln(out)
|
||||
fmt.Fprintln(out)
|
||||
|
||||
for _, g := range c.Gates {
|
||||
fmt.Fprintf(out, "%d 1", len(g.Inputs()))
|
||||
for _, w := range g.Inputs() {
|
||||
fmt.Fprintf(out, " %d", w)
|
||||
}
|
||||
fmt.Fprintf(out, " %d", g.Output)
|
||||
fmt.Fprintf(out, " %s\n", g.Op)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
511
bedlam/circuit/parser.go
Normal file
511
bedlam/circuit/parser.go
Normal file
@ -0,0 +1,511 @@
|
||||
//
|
||||
// Copyright (c) 2019-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuit
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
|
||||
)
|
||||
|
||||
var reParts = regexp.MustCompilePOSIX("[[:space:]]+")
|
||||
|
||||
// Seen describes whether wire has been seen.
|
||||
type Seen []bool
|
||||
|
||||
// Get gets the wire seen flag.
|
||||
func (s Seen) Get(index Wire) (bool, error) {
|
||||
if index >= Wire(len(s)) {
|
||||
return false, fmt.Errorf("invalid wire %d [0...%d[", index, len(s))
|
||||
}
|
||||
return s[index], nil
|
||||
}
|
||||
|
||||
// Set marks the wire seen.
|
||||
func (s Seen) Set(index Wire) error {
|
||||
if index >= Wire(len(s)) {
|
||||
return fmt.Errorf("invalid wire %d [0...%d[", index, len(s))
|
||||
}
|
||||
s[index] = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsFilename tests if the argument file is a potential circuit
|
||||
// filename.
|
||||
func IsFilename(file string) bool {
|
||||
return strings.HasSuffix(file, ".circ") ||
|
||||
strings.HasSuffix(file, ".bristol") ||
|
||||
strings.HasSuffix(file, ".qclc")
|
||||
}
|
||||
|
||||
// Parse parses the circuit file.
|
||||
func Parse(file string) (*Circuit, error) {
|
||||
f, err := os.Open(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if strings.HasSuffix(file, ".circ") || strings.HasSuffix(file, ".bristol") {
|
||||
return ParseBristol(f)
|
||||
} else if strings.HasSuffix(file, ".qclc") {
|
||||
return ParseQCLC(f)
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported circuit format")
|
||||
}
|
||||
|
||||
// ParseQCLC parses an QCL circuit file.
|
||||
func ParseQCLC(in io.Reader) (*Circuit, error) {
|
||||
r := bufio.NewReader(in)
|
||||
|
||||
var header struct {
|
||||
Magic uint32
|
||||
NumGates uint32
|
||||
NumWires uint32
|
||||
NumInputs uint32
|
||||
NumOutputs uint32
|
||||
}
|
||||
if err := binary.Read(r, bo, &header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var inputs, outputs IO
|
||||
var inputWires, outputWires int
|
||||
|
||||
wiresSeen := make(Seen, header.NumWires)
|
||||
|
||||
for i := 0; i < int(header.NumInputs); i++ {
|
||||
arg, err := parseIOArg(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inputs = append(inputs, arg)
|
||||
inputWires += int(arg.Type.Bits)
|
||||
}
|
||||
for i := 0; i < int(header.NumOutputs); i++ {
|
||||
out, err := parseIOArg(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outputs = append(outputs, out)
|
||||
outputWires += int(out.Type.Bits)
|
||||
}
|
||||
|
||||
// Mark input wires seen.
|
||||
for i := 0; i < inputWires; i++ {
|
||||
if err := wiresSeen.Set(Wire(i)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
gates := make([]Gate, header.NumGates)
|
||||
var stats Stats
|
||||
var gate int
|
||||
for gate = 0; ; gate++ {
|
||||
op, err := r.ReadByte()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
switch Operation(op) {
|
||||
case XOR, XNOR, AND, OR:
|
||||
var bin struct {
|
||||
Input0 uint32
|
||||
Input1 uint32
|
||||
Output uint32
|
||||
}
|
||||
if err := binary.Read(r, bo, &bin); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
seen, err := wiresSeen.Get(Wire(bin.Input0))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !seen {
|
||||
return nil, fmt.Errorf("input %d of gate %d not set",
|
||||
bin.Input0, gate)
|
||||
}
|
||||
seen, err = wiresSeen.Get(Wire(bin.Input1))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !seen {
|
||||
return nil, fmt.Errorf("input %d of gate %d not set",
|
||||
bin.Input1, gate)
|
||||
}
|
||||
if err := wiresSeen.Set(Wire(bin.Output)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gates[gate] = Gate{
|
||||
Input0: Wire(bin.Input0),
|
||||
Input1: Wire(bin.Input1),
|
||||
Output: Wire(bin.Output),
|
||||
Op: Operation(op),
|
||||
}
|
||||
|
||||
case INV:
|
||||
var unary struct {
|
||||
Input0 uint32
|
||||
Output uint32
|
||||
}
|
||||
if err := binary.Read(r, bo, &unary); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
seen, err := wiresSeen.Get(Wire(unary.Input0))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !seen {
|
||||
return nil, fmt.Errorf("input %d of gate %d not set",
|
||||
unary.Input0, gate)
|
||||
}
|
||||
if err := wiresSeen.Set(Wire(unary.Output)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gates[gate] = Gate{
|
||||
Input0: Wire(unary.Input0),
|
||||
Output: Wire(unary.Output),
|
||||
Op: Operation(op),
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported gate type %s", Operation(op))
|
||||
}
|
||||
stats[Operation(op)]++
|
||||
}
|
||||
|
||||
if uint32(gate) != header.NumGates {
|
||||
return nil, fmt.Errorf("not enough gates: got %d, expected %d",
|
||||
gate, header.NumGates)
|
||||
}
|
||||
|
||||
// Check that all wires are seen.
|
||||
for i := 0; i < len(wiresSeen); i++ {
|
||||
if !wiresSeen[i] {
|
||||
return nil, fmt.Errorf("wire %d not assigned", i)
|
||||
}
|
||||
}
|
||||
|
||||
return &Circuit{
|
||||
NumGates: int(header.NumGates),
|
||||
NumWires: int(header.NumWires),
|
||||
Inputs: inputs,
|
||||
Outputs: outputs,
|
||||
Gates: gates,
|
||||
Stats: stats,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseIOArg(r *bufio.Reader) (arg IOArg, err error) {
|
||||
name, err := parseString(r)
|
||||
if err != nil {
|
||||
return arg, err
|
||||
}
|
||||
t, err := parseString(r)
|
||||
if err != nil {
|
||||
return arg, err
|
||||
}
|
||||
var ui32 uint32
|
||||
if err := binary.Read(r, bo, &ui32); err != nil {
|
||||
return arg, err
|
||||
}
|
||||
arg.Name = name
|
||||
arg.Type, err = types.Parse(t)
|
||||
if err != nil {
|
||||
return arg, err
|
||||
}
|
||||
arg.Type.Bits = types.Size(ui32)
|
||||
|
||||
// Compound
|
||||
if err := binary.Read(r, bo, &ui32); err != nil {
|
||||
return arg, err
|
||||
}
|
||||
for i := 0; i < int(ui32); i++ {
|
||||
c, err := parseIOArg(r)
|
||||
if err != nil {
|
||||
return arg, err
|
||||
}
|
||||
arg.Compound = append(arg.Compound, c)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func parseString(r *bufio.Reader) (string, error) {
|
||||
var ui32 uint32
|
||||
if err := binary.Read(r, bo, &ui32); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if ui32 == 0 {
|
||||
return "", nil
|
||||
}
|
||||
buf := make([]byte, ui32)
|
||||
_, err := r.Read(buf)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(buf), nil
|
||||
}
|
||||
|
||||
// ParseBristol parses a Briston circuit file.
|
||||
func ParseBristol(in io.Reader) (*Circuit, error) {
|
||||
r := bufio.NewReader(in)
|
||||
|
||||
// NumGates NumWires
|
||||
line, err := readLine(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(line) != 2 {
|
||||
return nil, fmt.Errorf("invalid 1st line: '%s'", line)
|
||||
}
|
||||
numGates, err := strconv.Atoi(line[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if numGates < 0 || numGates > math.MaxInt32 {
|
||||
return nil, fmt.Errorf("invalid numGates: %d", numGates)
|
||||
}
|
||||
numWires, err := strconv.Atoi(line[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if numWires < 0 || numWires > math.MaxInt32 {
|
||||
return nil, fmt.Errorf("invalid numWires: %d", numWires)
|
||||
}
|
||||
wiresSeen := make(Seen, numWires)
|
||||
|
||||
// Inputs
|
||||
line, err = readLine(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
niv, err := strconv.Atoi(line[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if 1+niv != len(line) {
|
||||
return nil, fmt.Errorf("invalid inputs line: niv=%d, len=%d",
|
||||
niv, len(line))
|
||||
}
|
||||
var inputs IO
|
||||
var inputWires int64
|
||||
for i := 1; i < len(line); i++ {
|
||||
bits, err := strconv.ParseInt(line[i], 10, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid input bits: %s", err)
|
||||
}
|
||||
if bits < 0 {
|
||||
return nil, fmt.Errorf("invalid input bits: %d", bits)
|
||||
}
|
||||
inputs = append(inputs, IOArg{
|
||||
Name: fmt.Sprintf("NI%d", i),
|
||||
Type: types.Info{
|
||||
Type: types.TUint,
|
||||
IsConcrete: true,
|
||||
Bits: types.Size(bits),
|
||||
},
|
||||
})
|
||||
inputWires += bits
|
||||
}
|
||||
if inputWires == 0 {
|
||||
return nil, fmt.Errorf("no inputs defined")
|
||||
}
|
||||
|
||||
// Mark input wires seen.
|
||||
for i := int64(0); i < inputWires; i++ {
|
||||
if err := wiresSeen.Set(Wire(i)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Outputs
|
||||
line, err = readLine(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nov, err := strconv.Atoi(line[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if 1+nov != len(line) {
|
||||
return nil, errors.New("invalid outputs line")
|
||||
}
|
||||
var outputs IO
|
||||
for i := 1; i < len(line); i++ {
|
||||
bits, err := strconv.ParseInt(line[i], 10, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid output bits: %s", err)
|
||||
}
|
||||
if bits < 0 {
|
||||
return nil, fmt.Errorf("invalid output bits: %d", bits)
|
||||
}
|
||||
outputs = append(outputs, IOArg{
|
||||
Name: fmt.Sprintf("NO%d", i),
|
||||
Type: types.Info{
|
||||
Type: types.TUint,
|
||||
IsConcrete: true,
|
||||
Bits: types.Size(bits),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
gates := make([]Gate, numGates)
|
||||
var stats Stats
|
||||
var gate int
|
||||
for gate = 0; ; gate++ {
|
||||
line, err = readLine(r)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if gate >= numGates {
|
||||
return nil, errors.New("too many gates")
|
||||
}
|
||||
if len(line) < 3 {
|
||||
return nil, fmt.Errorf("invalid gate: %v", line)
|
||||
}
|
||||
n1, err := strconv.Atoi(line[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if n1 < 0 {
|
||||
return nil, fmt.Errorf("invalid n1: %v", n1)
|
||||
}
|
||||
n2, err := strconv.Atoi(line[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if n2 < 0 {
|
||||
return nil, fmt.Errorf("invalid n2: %v", n2)
|
||||
}
|
||||
if 2+n1+n2+1 != len(line) {
|
||||
return nil, fmt.Errorf("invalid gate: %v", line)
|
||||
}
|
||||
|
||||
var inputs []Wire
|
||||
for i := 0; i < n1; i++ {
|
||||
v, err := strconv.ParseUint(line[2+i], 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
seen, err := wiresSeen.Get(Wire(v))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !seen {
|
||||
return nil, fmt.Errorf("input %d of gate %d not set", v, gate)
|
||||
}
|
||||
inputs = append(inputs, Wire(v))
|
||||
}
|
||||
|
||||
var outputs []Wire
|
||||
for i := 0; i < n2; i++ {
|
||||
v, err := strconv.ParseUint(line[2+n1+i], 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = wiresSeen.Set(Wire(v))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outputs = append(outputs, Wire(v))
|
||||
}
|
||||
var op Operation
|
||||
var numInputs int
|
||||
switch line[len(line)-1] {
|
||||
case "XOR":
|
||||
op = XOR
|
||||
numInputs = 2
|
||||
case "XNOR":
|
||||
op = XNOR
|
||||
numInputs = 2
|
||||
case "AND":
|
||||
op = AND
|
||||
numInputs = 2
|
||||
case "OR":
|
||||
op = OR
|
||||
numInputs = 2
|
||||
case "INV":
|
||||
op = INV
|
||||
numInputs = 1
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid operation '%s'", line[len(line)-1])
|
||||
}
|
||||
|
||||
if len(inputs) != numInputs {
|
||||
return nil, fmt.Errorf("invalid number of inputs %d for %s",
|
||||
len(inputs), op)
|
||||
}
|
||||
if len(outputs) != 1 {
|
||||
return nil, fmt.Errorf("invalid number of outputs %d for %s",
|
||||
len(outputs), op)
|
||||
}
|
||||
|
||||
var input1 Wire
|
||||
if len(inputs) > 1 {
|
||||
input1 = inputs[1]
|
||||
}
|
||||
|
||||
gates[gate] = Gate{
|
||||
Input0: inputs[0],
|
||||
Input1: input1,
|
||||
Output: outputs[0],
|
||||
Op: op,
|
||||
}
|
||||
stats[op]++
|
||||
}
|
||||
if gate != numGates {
|
||||
return nil, fmt.Errorf("not enough gates: got %d, expected %d",
|
||||
gate, numGates)
|
||||
}
|
||||
|
||||
// Check that all wires are seen.
|
||||
for i := 0; i < len(wiresSeen); i++ {
|
||||
if !wiresSeen[i] {
|
||||
return nil, fmt.Errorf("wire %d not assigned", i)
|
||||
}
|
||||
}
|
||||
|
||||
return &Circuit{
|
||||
NumGates: numGates,
|
||||
NumWires: numWires,
|
||||
Inputs: inputs,
|
||||
Outputs: outputs,
|
||||
Gates: gates,
|
||||
Stats: stats,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func readLine(r *bufio.Reader) ([]string, error) {
|
||||
for {
|
||||
line, err := r.ReadString('\n')
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
line = strings.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
parts := reParts.Split(line, -1)
|
||||
if len(parts) > 0 {
|
||||
return parts, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
28
bedlam/circuit/parser_test.go
Normal file
28
bedlam/circuit/parser_test.go
Normal file
@ -0,0 +1,28 @@
|
||||
//
|
||||
// parser_test.go
|
||||
//
|
||||
// Copyright (c) 2019-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuit
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var data = `1 3
|
||||
2 1 1
|
||||
1 1
|
||||
|
||||
2 1 0 1 2 AND
|
||||
`
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
_, err := ParseBristol(bytes.NewReader([]byte(data)))
|
||||
if err != nil {
|
||||
t.Fatalf("Parse failed: %s", err)
|
||||
}
|
||||
}
|
||||
501
bedlam/circuit/stream_evaluator.go
Normal file
501
bedlam/circuit/stream_evaluator.go
Normal file
@ -0,0 +1,501 @@
|
||||
//
|
||||
// Copyright (c) 2020-2024 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuit
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/p2p"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
|
||||
)
|
||||
|
||||
// Protocol operation codes.
|
||||
const (
|
||||
OpResult = iota
|
||||
OpCircuit
|
||||
OpReturn
|
||||
)
|
||||
|
||||
// StreamEval is a streaming garbled circuit evaluator.
|
||||
type StreamEval struct {
|
||||
key []byte
|
||||
alg cipher.Block
|
||||
wires []ot.Label
|
||||
tmp []ot.Label
|
||||
}
|
||||
|
||||
// NewStreamEval creates a new streaming garbled circuit evaluator.
|
||||
func NewStreamEval(key []byte, numInputs, numOutputs int) (*StreamEval, error) {
|
||||
alg, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &StreamEval{
|
||||
key: key,
|
||||
alg: alg,
|
||||
wires: make([]ot.Label, numInputs+numOutputs),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Get gets the value of the wire.
|
||||
func (stream *StreamEval) Get(tmp bool, w int) ot.Label {
|
||||
if tmp {
|
||||
return stream.tmp[w]
|
||||
}
|
||||
return stream.wires[w]
|
||||
}
|
||||
|
||||
// GetInputs gets the specified input wire range.
|
||||
func (stream *StreamEval) GetInputs(offset, count int) []ot.Label {
|
||||
return stream.wires[offset : offset+count]
|
||||
}
|
||||
|
||||
// Set sets the value of the wire.
|
||||
func (stream *StreamEval) Set(tmp bool, w int, label ot.Label) {
|
||||
if tmp {
|
||||
stream.tmp[w] = label
|
||||
} else {
|
||||
stream.wires[w] = label
|
||||
}
|
||||
}
|
||||
|
||||
// InitCircuit initializes the stream evaluator with wires.
|
||||
func (stream *StreamEval) InitCircuit(numWires, numTmpWires int) {
|
||||
if numWires > len(stream.wires) {
|
||||
var size int
|
||||
for size = 1024; size < numWires; size *= 2 {
|
||||
}
|
||||
n := make([]ot.Label, size)
|
||||
copy(n, stream.wires)
|
||||
stream.wires = n
|
||||
}
|
||||
if numTmpWires > len(stream.tmp) {
|
||||
var size int
|
||||
for size = 1024; size < numTmpWires; size *= 2 {
|
||||
}
|
||||
stream.tmp = make([]ot.Label, size)
|
||||
}
|
||||
}
|
||||
|
||||
// StreamEvaluator runs the stream evaluator on the connection.
|
||||
func StreamEvaluator(conn *p2p.Conn, oti ot.OT, inputFlag []string,
|
||||
verbose bool) (IO, []*big.Int, error) {
|
||||
|
||||
timing := NewTiming()
|
||||
|
||||
// Receive program info.
|
||||
if verbose {
|
||||
fmt.Printf(" - Waiting for program info...\n")
|
||||
}
|
||||
key, err := conn.ReceiveData()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
alg, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
// Peer input.
|
||||
in1, err := receiveArgument(conn)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
// Our input.
|
||||
in2, err := receiveArgument(conn)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
inputs, err := in2.Parse(inputFlag)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
// Program outputs.
|
||||
numOutputs, err := conn.ReceiveUint32()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
var outputs IO
|
||||
for i := 0; i < numOutputs; i++ {
|
||||
out, err := receiveArgument(conn)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
outputs = append(outputs, out)
|
||||
}
|
||||
|
||||
numSteps, err := conn.ReceiveUint32()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
fmt.Printf(" - In1: %s\n", in1)
|
||||
fmt.Printf(" + In2: %s\n", in2)
|
||||
fmt.Printf(" - Out: %s\n", outputs)
|
||||
fmt.Printf(" - In: %s\n", inputFlag)
|
||||
|
||||
streaming, err := NewStreamEval(key, int(in1.Type.Bits+in2.Type.Bits),
|
||||
outputs.Size())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Receive peer inputs.
|
||||
var label ot.Label
|
||||
var labelData ot.LabelData
|
||||
for w := 0; w < int(in1.Type.Bits); w++ {
|
||||
err := conn.ReceiveLabel(&label, &labelData)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
streaming.Set(false, w, label)
|
||||
}
|
||||
|
||||
// Init oblivious transfer.
|
||||
err = oti.InitReceiver(conn)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
ioStats := conn.Stats.Sum()
|
||||
timing.Sample("Init", []string{FileSize(ioStats).String()})
|
||||
|
||||
// Query our inputs.
|
||||
if verbose {
|
||||
fmt.Printf(" - Querying our inputs...\n")
|
||||
}
|
||||
flags := make([]bool, in2.Type.Bits)
|
||||
for i := 0; i < int(in2.Type.Bits); i++ {
|
||||
if inputs.Bit(i) == 1 {
|
||||
flags[i] = true
|
||||
}
|
||||
}
|
||||
inputLabels := streaming.GetInputs(int(in1.Type.Bits), int(in2.Type.Bits))
|
||||
if err := oti.Receive(flags, inputLabels); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
xfer := conn.Stats.Sum() - ioStats
|
||||
ioStats = conn.Stats.Sum()
|
||||
timing.Sample("Inputs", []string{FileSize(xfer).String()})
|
||||
|
||||
ws := func(i int, tmp bool) string {
|
||||
if tmp {
|
||||
return fmt.Sprintf("~%d", i)
|
||||
}
|
||||
return fmt.Sprintf("w%d", i)
|
||||
}
|
||||
|
||||
// Evaluate program.
|
||||
if verbose {
|
||||
fmt.Printf(" - Evaluating program...\n")
|
||||
}
|
||||
var garbled [4]ot.Label
|
||||
var lastStep int
|
||||
|
||||
var rawResult *big.Int
|
||||
|
||||
start := time.Now()
|
||||
lastReport := start
|
||||
loop:
|
||||
for {
|
||||
op, err := conn.ReceiveUint32()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
switch op {
|
||||
case OpCircuit:
|
||||
step, err := conn.ReceiveUint32()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
numGates, err := conn.ReceiveUint32()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
numTmpWires, err := conn.ReceiveUint32()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
numWires, err := conn.ReceiveUint32()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if step-lastStep >= 10 && verbose {
|
||||
lastStep = step
|
||||
now := time.Now()
|
||||
if now.Sub(lastReport) > time.Second*5 {
|
||||
lastReport = now
|
||||
elapsed := time.Now().Sub(start)
|
||||
done := float64(step) / float64(numSteps)
|
||||
if done > 0 {
|
||||
total := time.Duration(float64(elapsed) / done)
|
||||
progress := fmt.Sprintf("%d/%d", step, numSteps)
|
||||
remaining := fmt.Sprintf("%24s", total-elapsed)
|
||||
fmt.Printf("%-14s\t%s remaining\tETA %s\n",
|
||||
progress, remaining,
|
||||
start.Add(total).Format(time.Stamp))
|
||||
} else {
|
||||
fmt.Printf("%d/%d\n", step, numSteps)
|
||||
}
|
||||
}
|
||||
}
|
||||
streaming.InitCircuit(numWires, numTmpWires)
|
||||
var id uint32
|
||||
for i := 0; i < numGates; i++ {
|
||||
gop, err := conn.ReceiveByte()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
var aTmp, bTmp, cTmp bool
|
||||
if gop&0b10000000 != 0 {
|
||||
aTmp = true
|
||||
}
|
||||
if gop&0b01000000 != 0 {
|
||||
bTmp = true
|
||||
}
|
||||
if gop&0b00100000 != 0 {
|
||||
cTmp = true
|
||||
}
|
||||
var recvWire func() (int, error)
|
||||
if gop&0b00010000 != 0 {
|
||||
recvWire = conn.ReceiveUint16
|
||||
} else {
|
||||
recvWire = conn.ReceiveUint32
|
||||
}
|
||||
|
||||
gop &^= 0b11110000
|
||||
|
||||
var aIndex, bIndex, cIndex int
|
||||
var tableCount int
|
||||
|
||||
switch Operation(gop) {
|
||||
case XOR, XNOR, AND, OR:
|
||||
aIndex, err = recvWire()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
bIndex, err = recvWire()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
cIndex, err = recvWire()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
case INV:
|
||||
aIndex, err = recvWire()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
cIndex, err = recvWire()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("invalid operation %s",
|
||||
Operation(gop))
|
||||
}
|
||||
switch Operation(gop) {
|
||||
case XOR, XNOR:
|
||||
tableCount = 0
|
||||
case INV:
|
||||
tableCount = 1
|
||||
case AND:
|
||||
tableCount = 2
|
||||
case OR:
|
||||
tableCount = 3
|
||||
}
|
||||
|
||||
for c := 0; c < tableCount; c++ {
|
||||
err = conn.ReceiveLabel(&label, &labelData)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
garbled[c] = label
|
||||
}
|
||||
|
||||
var a, b, c ot.Label
|
||||
|
||||
switch Operation(gop) {
|
||||
case XOR, XNOR, AND, OR:
|
||||
if StreamDebug {
|
||||
fmt.Printf("Gate%d:\t %s %s %s %s\n", i,
|
||||
ws(aIndex, aTmp), ws(bIndex, bTmp),
|
||||
Operation(gop), ws(cIndex, cTmp))
|
||||
}
|
||||
a = streaming.Get(aTmp, aIndex)
|
||||
b = streaming.Get(bTmp, bIndex)
|
||||
|
||||
case INV:
|
||||
if StreamDebug {
|
||||
fmt.Printf("Gate%d:\t %s %s %s\n", i,
|
||||
ws(aIndex, aTmp), Operation(gop), ws(bIndex, bTmp))
|
||||
}
|
||||
a = streaming.Get(aTmp, aIndex)
|
||||
}
|
||||
|
||||
var output ot.Label
|
||||
|
||||
switch Operation(gop) {
|
||||
case XOR, XNOR:
|
||||
a.Xor(b)
|
||||
output = a
|
||||
|
||||
case AND:
|
||||
if tableCount != 2 {
|
||||
return nil, nil,
|
||||
fmt.Errorf("corrupted ciruit: AND table size: %d",
|
||||
tableCount)
|
||||
}
|
||||
sa := a.S()
|
||||
sb := b.S()
|
||||
|
||||
j0 := id
|
||||
j1 := id + 1
|
||||
id += 2
|
||||
|
||||
tg := garbled[0]
|
||||
te := garbled[1]
|
||||
|
||||
wg := encryptHalf(alg, a, j0, &labelData)
|
||||
if sa {
|
||||
wg.Xor(tg)
|
||||
}
|
||||
we := encryptHalf(alg, b, j1, &labelData)
|
||||
if sb {
|
||||
we.Xor(te)
|
||||
we.Xor(a)
|
||||
}
|
||||
output = wg
|
||||
output.Xor(we)
|
||||
|
||||
case OR:
|
||||
index := idx(a, b)
|
||||
if index > 0 {
|
||||
// First row is zero and not transmitted.
|
||||
index--
|
||||
if index >= tableCount {
|
||||
return nil, nil,
|
||||
fmt.Errorf("corrupted circuit: index %d >= %d",
|
||||
index, tableCount)
|
||||
}
|
||||
c = garbled[index]
|
||||
}
|
||||
output = decrypt(alg, a, b, id, c, &labelData)
|
||||
id++
|
||||
|
||||
case INV:
|
||||
index := idxUnary(a)
|
||||
if index > 0 {
|
||||
// First row is zero and not transmitted.
|
||||
index--
|
||||
if index >= tableCount {
|
||||
return nil, nil,
|
||||
fmt.Errorf("corrupted circuit: index %d >= %d",
|
||||
index, tableCount)
|
||||
}
|
||||
c = garbled[index]
|
||||
}
|
||||
|
||||
output = decrypt(alg, a, b, id, c, &labelData)
|
||||
id++
|
||||
}
|
||||
streaming.Set(cTmp, cIndex, output)
|
||||
}
|
||||
|
||||
case OpReturn:
|
||||
xfer := conn.Stats.Sum() - ioStats
|
||||
ioStats = conn.Stats.Sum()
|
||||
timing.Sample("Eval", []string{FileSize(xfer).String()})
|
||||
|
||||
var labels []ot.Label
|
||||
for i := 0; i < outputs.Size(); i++ {
|
||||
id, err := conn.ReceiveUint32()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
label := streaming.Get(false, id)
|
||||
labels = append(labels, label)
|
||||
}
|
||||
|
||||
// Resolve result values.
|
||||
if err := conn.SendUint32(OpResult); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
var labelData ot.LabelData
|
||||
for _, l := range labels {
|
||||
if err := conn.SendLabel(l, &labelData); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
if err := conn.Flush(); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
result, err := conn.ReceiveData()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
rawResult = new(big.Int).SetBytes(result)
|
||||
break loop
|
||||
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unknown operation %d", op)
|
||||
}
|
||||
}
|
||||
|
||||
xfer = conn.Stats.Sum() - ioStats
|
||||
timing.Sample("Result", []string{FileSize(xfer).String()})
|
||||
|
||||
if verbose {
|
||||
timing.Print(conn.Stats)
|
||||
}
|
||||
|
||||
return outputs, outputs.Split(rawResult), nil
|
||||
}
|
||||
|
||||
func receiveArgument(conn *p2p.Conn) (arg IOArg, err error) {
|
||||
name, err := conn.ReceiveString()
|
||||
if err != nil {
|
||||
return arg, err
|
||||
}
|
||||
t, err := conn.ReceiveString()
|
||||
if err != nil {
|
||||
return arg, err
|
||||
}
|
||||
size, err := conn.ReceiveUint32()
|
||||
if err != nil {
|
||||
return arg, err
|
||||
}
|
||||
arg.Name = name
|
||||
arg.Type, err = types.Parse(t)
|
||||
if err != nil {
|
||||
return arg, err
|
||||
}
|
||||
arg.Type.Bits = types.Size(size)
|
||||
|
||||
if arg.Type.Type == types.TSlice {
|
||||
arg.Type.ArraySize = arg.Type.Bits / arg.Type.ElementType.Bits
|
||||
}
|
||||
|
||||
count, err := conn.ReceiveUint32()
|
||||
if err != nil {
|
||||
return arg, err
|
||||
}
|
||||
for i := 0; i < count; i++ {
|
||||
a, err := receiveArgument(conn)
|
||||
if err != nil {
|
||||
return arg, err
|
||||
}
|
||||
arg.Compound = append(arg.Compound, a)
|
||||
}
|
||||
return arg, nil
|
||||
}
|
||||
440
bedlam/circuit/stream_garble.go
Normal file
440
bedlam/circuit/stream_garble.go
Normal file
@ -0,0 +1,440 @@
|
||||
//
|
||||
// Copyright (c) 2020-2021, 2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuit
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/p2p"
|
||||
)
|
||||
|
||||
const (
|
||||
// StreamDebug controls the debugging output of the streaming
|
||||
// garbling.
|
||||
StreamDebug = false
|
||||
)
|
||||
|
||||
// Streaming is a streaming garbled circuit garbler.
|
||||
type Streaming struct {
|
||||
conn *p2p.Conn
|
||||
key []byte
|
||||
alg cipher.Block
|
||||
r ot.Label
|
||||
wires []ot.Wire
|
||||
tmp []ot.Wire
|
||||
in []Wire
|
||||
out []Wire
|
||||
firstTmp Wire
|
||||
firstOut Wire
|
||||
}
|
||||
|
||||
// NewStreaming creates a new streaming garbled circuit garbler.
|
||||
func NewStreaming(key []byte, inputs []Wire, conn *p2p.Conn) (
|
||||
*Streaming, error) {
|
||||
|
||||
r, err := ot.NewLabel(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.SetS(true)
|
||||
|
||||
alg, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stream := &Streaming{
|
||||
conn: conn,
|
||||
key: key,
|
||||
alg: alg,
|
||||
r: r,
|
||||
}
|
||||
|
||||
stream.ensureWires(maxWire(0, inputs))
|
||||
|
||||
// Assing all input wires.
|
||||
for i := 0; i < len(inputs); i++ {
|
||||
w, err := makeLabels(stream.r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stream.wires[inputs[i]] = w
|
||||
}
|
||||
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func maxWire(max Wire, wires []Wire) Wire {
|
||||
for _, w := range wires {
|
||||
if w > max {
|
||||
max = w
|
||||
}
|
||||
}
|
||||
return max
|
||||
}
|
||||
|
||||
func (stream *Streaming) ensureWires(max Wire) {
|
||||
// Verify that wires is big enough.
|
||||
if len(stream.wires) <= int(max) {
|
||||
var i int
|
||||
for i = 65536; i <= int(max); i <<= 1 {
|
||||
}
|
||||
n := make([]ot.Wire, i)
|
||||
copy(n, stream.wires)
|
||||
stream.wires = n
|
||||
}
|
||||
}
|
||||
|
||||
func (stream *Streaming) initCircuit(c *Circuit, in, out []Wire) {
|
||||
stream.ensureWires(maxWire(maxWire(0, in), out))
|
||||
|
||||
if len(stream.tmp) < c.NumWires {
|
||||
stream.tmp = make([]ot.Wire, c.NumWires)
|
||||
}
|
||||
|
||||
stream.in = in
|
||||
stream.out = out
|
||||
|
||||
stream.firstTmp = Wire(len(in))
|
||||
stream.firstOut = Wire(c.NumWires - len(out))
|
||||
}
|
||||
|
||||
// GetInput gets the value of the input wire.
|
||||
func (stream *Streaming) GetInput(w Wire) ot.Wire {
|
||||
return stream.wires[w]
|
||||
}
|
||||
|
||||
// GetInputs gets the specified input wire range.
|
||||
func (stream *Streaming) GetInputs(offset, count int) []ot.Wire {
|
||||
return stream.wires[offset : offset+count]
|
||||
}
|
||||
|
||||
// Get gets the value of the wire.
|
||||
func (stream *Streaming) Get(w Wire) (ot.Wire, Wire, bool) {
|
||||
if w < stream.firstTmp {
|
||||
index := stream.in[w]
|
||||
return stream.wires[index], index, false
|
||||
} else if w >= stream.firstOut {
|
||||
index := stream.out[w-stream.firstOut]
|
||||
return stream.wires[index], index, false
|
||||
} else {
|
||||
return stream.tmp[w], w, true
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets the value of the wire.
|
||||
func (stream *Streaming) Set(w Wire, val ot.Wire) (index Wire, tmp bool) {
|
||||
if w < stream.firstTmp {
|
||||
index = stream.in[w]
|
||||
stream.wires[index] = val
|
||||
} else if w >= stream.firstOut {
|
||||
index = stream.out[w-stream.firstOut]
|
||||
stream.wires[index] = val
|
||||
} else {
|
||||
index = w
|
||||
tmp = true
|
||||
stream.tmp[w] = val
|
||||
}
|
||||
return index, tmp
|
||||
}
|
||||
|
||||
// Garble garbles the circuit and streams the garbled tables into the
|
||||
// stream.
|
||||
func (stream *Streaming) Garble(c *Circuit, in, out []Wire) (
|
||||
time.Duration, time.Duration, error) {
|
||||
if StreamDebug {
|
||||
fmt.Printf(" - Streaming.Garble: in=%v, out=%v\n", in, out)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
stream.initCircuit(c, in, out)
|
||||
|
||||
// Garble gates.
|
||||
|
||||
var data ot.LabelData
|
||||
var id uint32
|
||||
var table [4]ot.Label
|
||||
|
||||
mid := time.Now()
|
||||
|
||||
for i := 0; i < len(c.Gates); i++ {
|
||||
gate := &c.Gates[i]
|
||||
err := stream.conn.NeedSpace(512)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
err = stream.garbleGate(gate, &id, table[:], &data,
|
||||
stream.conn.WriteBuf, &stream.conn.WritePos)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
}
|
||||
return mid.Sub(start), time.Now().Sub(mid), nil
|
||||
}
|
||||
|
||||
// GarbleGate garbles the gate and streams it to the stream.
|
||||
func (stream *Streaming) garbleGate(g *Gate, idp *uint32,
|
||||
table []ot.Label, data *ot.LabelData, buf []byte, bufpos *int) error {
|
||||
|
||||
var a, b, c ot.Wire
|
||||
var aIndex, bIndex, cIndex Wire
|
||||
var aTmp, bTmp, cTmp bool
|
||||
|
||||
table = table[0:4]
|
||||
var tableStart, tableCount, wireCount int
|
||||
|
||||
// Inputs.
|
||||
switch g.Op {
|
||||
case XOR, XNOR, AND, OR:
|
||||
b, bIndex, bTmp = stream.Get(g.Input1)
|
||||
fallthrough
|
||||
|
||||
case INV:
|
||||
a, aIndex, aTmp = stream.Get(g.Input0)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("invalid gate type %s", g.Op)
|
||||
}
|
||||
|
||||
// Output.
|
||||
switch g.Op {
|
||||
case XOR:
|
||||
l0 := a.L0
|
||||
l0.Xor(b.L0)
|
||||
|
||||
l1 := l0
|
||||
l1.Xor(stream.r)
|
||||
c = ot.Wire{
|
||||
L0: l0,
|
||||
L1: l1,
|
||||
}
|
||||
|
||||
case XNOR:
|
||||
l0 := a.L0
|
||||
l0.Xor(b.L0)
|
||||
|
||||
l1 := l0
|
||||
l1.Xor(stream.r)
|
||||
c = ot.Wire{
|
||||
L0: l1,
|
||||
L1: l0,
|
||||
}
|
||||
|
||||
case AND:
|
||||
pa := a.L0.S()
|
||||
pb := b.L0.S()
|
||||
|
||||
j0 := *idp
|
||||
j1 := *idp + 1
|
||||
*idp = *idp + 2
|
||||
|
||||
// First half gate.
|
||||
tg := encryptHalf(stream.alg, a.L0, j0, data)
|
||||
tg.Xor(encryptHalf(stream.alg, a.L1, j0, data))
|
||||
if pb {
|
||||
tg.Xor(stream.r)
|
||||
}
|
||||
wg0 := encryptHalf(stream.alg, a.L0, j0, data)
|
||||
if pa {
|
||||
wg0.Xor(tg)
|
||||
}
|
||||
|
||||
// Second half gate.
|
||||
te := encryptHalf(stream.alg, b.L0, j1, data)
|
||||
te.Xor(encryptHalf(stream.alg, b.L1, j1, data))
|
||||
te.Xor(a.L0)
|
||||
we0 := encryptHalf(stream.alg, b.L0, j1, data)
|
||||
if pb {
|
||||
we0.Xor(te)
|
||||
we0.Xor(a.L0)
|
||||
}
|
||||
|
||||
// Combine halves
|
||||
l0 := wg0
|
||||
l0.Xor(we0)
|
||||
|
||||
l1 := l0
|
||||
l1.Xor(stream.r)
|
||||
|
||||
c = ot.Wire{
|
||||
L0: l0,
|
||||
L1: l1,
|
||||
}
|
||||
table[0] = tg
|
||||
table[1] = te
|
||||
tableCount = 2
|
||||
|
||||
case OR, INV:
|
||||
// Row reduction creates labels below so that the first row is
|
||||
// all zero.
|
||||
|
||||
default:
|
||||
panic("invalid gate type")
|
||||
}
|
||||
|
||||
switch g.Op {
|
||||
case XOR, XNOR:
|
||||
// Free XOR.
|
||||
wireCount = 3
|
||||
|
||||
case AND:
|
||||
// Half AND garbled above.
|
||||
wireCount = 3
|
||||
|
||||
case OR:
|
||||
// a b c
|
||||
// -----
|
||||
// 0 0 0
|
||||
// 0 1 1
|
||||
// 1 0 1
|
||||
// 1 1 1
|
||||
id := *idp
|
||||
*idp = *idp + 1
|
||||
table[idx(a.L0, b.L0)] = encrypt(stream.alg, a.L0, b.L0, c.L0, id, data)
|
||||
table[idx(a.L0, b.L1)] = encrypt(stream.alg, a.L0, b.L1, c.L1, id, data)
|
||||
table[idx(a.L1, b.L0)] = encrypt(stream.alg, a.L1, b.L0, c.L1, id, data)
|
||||
table[idx(a.L1, b.L1)] = encrypt(stream.alg, a.L1, b.L1, c.L1, id, data)
|
||||
|
||||
// Row reduction. Make first table all zero so we don't have
|
||||
// to transmit it.
|
||||
|
||||
l0Index := idx(a.L0, b.L0)
|
||||
|
||||
c.L0 = table[0]
|
||||
c.L1 = table[0]
|
||||
|
||||
if l0Index == 0 {
|
||||
c.L1.Xor(stream.r)
|
||||
} else {
|
||||
c.L0.Xor(stream.r)
|
||||
}
|
||||
for i := 0; i < 4; i++ {
|
||||
if i == l0Index {
|
||||
table[i].Xor(c.L0)
|
||||
} else {
|
||||
table[i].Xor(c.L1)
|
||||
}
|
||||
}
|
||||
|
||||
tableStart = 1
|
||||
tableCount = 3
|
||||
wireCount = 3
|
||||
|
||||
case INV:
|
||||
// a b c
|
||||
// -----
|
||||
// 0 1
|
||||
// 1 0
|
||||
zero := ot.Label{}
|
||||
id := *idp
|
||||
*idp = *idp + 1
|
||||
table[idxUnary(a.L0)] = encrypt(stream.alg, a.L0, zero, c.L1, id, data)
|
||||
table[idxUnary(a.L1)] = encrypt(stream.alg, a.L1, zero, c.L0, id, data)
|
||||
|
||||
l0Index := idxUnary(a.L0)
|
||||
|
||||
c.L0 = table[0]
|
||||
c.L1 = table[0]
|
||||
|
||||
if l0Index == 0 {
|
||||
c.L0.Xor(stream.r)
|
||||
} else {
|
||||
c.L1.Xor(stream.r)
|
||||
}
|
||||
for i := 0; i < 2; i++ {
|
||||
if i == l0Index {
|
||||
table[i].Xor(c.L1)
|
||||
} else {
|
||||
table[i].Xor(c.L0)
|
||||
}
|
||||
}
|
||||
|
||||
tableStart = 1
|
||||
tableCount = 1
|
||||
wireCount = 2
|
||||
|
||||
default:
|
||||
return fmt.Errorf("invalid operand %s", g.Op)
|
||||
}
|
||||
|
||||
if g.Output < stream.firstTmp {
|
||||
cIndex = stream.in[g.Output]
|
||||
stream.wires[cIndex] = c
|
||||
} else if g.Output >= stream.firstOut {
|
||||
cIndex = stream.out[g.Output-stream.firstOut]
|
||||
stream.wires[cIndex] = c
|
||||
} else {
|
||||
cIndex = g.Output
|
||||
cTmp = true
|
||||
stream.tmp[g.Output] = c
|
||||
}
|
||||
|
||||
op := byte(g.Op)
|
||||
if aTmp {
|
||||
op |= 0b10000000
|
||||
}
|
||||
if bTmp {
|
||||
op |= 0b01000000
|
||||
}
|
||||
if cTmp {
|
||||
op |= 0b00100000
|
||||
}
|
||||
if aIndex <= 0xffff && bIndex <= 0xffff && cIndex <= 0xffff {
|
||||
op |= 0b00010000
|
||||
buf[*bufpos] = op
|
||||
*bufpos = *bufpos + 1
|
||||
|
||||
switch wireCount {
|
||||
case 3:
|
||||
bo.PutUint16(buf[*bufpos+0:], uint16(aIndex))
|
||||
bo.PutUint16(buf[*bufpos+2:], uint16(bIndex))
|
||||
bo.PutUint16(buf[*bufpos+4:], uint16(cIndex))
|
||||
*bufpos = *bufpos + 6
|
||||
|
||||
case 2:
|
||||
bo.PutUint16(buf[*bufpos+0:], uint16(aIndex))
|
||||
bo.PutUint16(buf[*bufpos+2:], uint16(cIndex))
|
||||
*bufpos = *bufpos + 4
|
||||
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid wire count: %d", wireCount))
|
||||
}
|
||||
} else {
|
||||
buf[*bufpos] = op
|
||||
*bufpos = *bufpos + 1
|
||||
|
||||
switch wireCount {
|
||||
case 3:
|
||||
bo.PutUint32(buf[*bufpos+0:], uint32(aIndex))
|
||||
bo.PutUint32(buf[*bufpos+4:], uint32(bIndex))
|
||||
bo.PutUint32(buf[*bufpos+8:], uint32(cIndex))
|
||||
*bufpos = *bufpos + 12
|
||||
|
||||
case 2:
|
||||
bo.PutUint32(buf[*bufpos+0:], uint32(aIndex))
|
||||
bo.PutUint32(buf[*bufpos+4:], uint32(cIndex))
|
||||
*bufpos = *bufpos + 8
|
||||
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid wire count: %d", wireCount))
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < tableCount; i++ {
|
||||
bytes := table[tableStart+i].Bytes(data)
|
||||
copy(buf[*bufpos:], bytes)
|
||||
*bufpos = *bufpos + len(bytes)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
356
bedlam/circuit/stream_garble_test.go
Normal file
356
bedlam/circuit/stream_garble_test.go
Normal file
@ -0,0 +1,356 @@
|
||||
//
|
||||
// Copyright (c) 2020-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuit
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
|
||||
)
|
||||
|
||||
func BenchmarkGarbleXOR(b *testing.B) {
|
||||
benchmarkGate(b, newGate(XOR))
|
||||
}
|
||||
|
||||
func BenchmarkGarbleXNOR(b *testing.B) {
|
||||
benchmarkGate(b, newGate(XNOR))
|
||||
}
|
||||
|
||||
func BenchmarkGarbleAND(b *testing.B) {
|
||||
benchmarkGate(b, newGate(AND))
|
||||
}
|
||||
|
||||
func BenchmarkGarbleOR(b *testing.B) {
|
||||
benchmarkGate(b, newGate(OR))
|
||||
}
|
||||
|
||||
func BenchmarkGarbleINV(b *testing.B) {
|
||||
benchmarkGate(b, newGate(INV))
|
||||
}
|
||||
|
||||
func newGate(op Operation) *Gate {
|
||||
return &Gate{
|
||||
Input0: 0,
|
||||
Input1: 1,
|
||||
Output: 2,
|
||||
Op: op,
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkGate(b *testing.B, g *Gate) {
|
||||
var key [16]byte
|
||||
inputs := []Wire{0, 1}
|
||||
outputs := []Wire{2}
|
||||
|
||||
stream, err := NewStreaming(key[:], inputs, nil)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to init streaming: %s", err)
|
||||
}
|
||||
stream.wires = []ot.Wire{{}, {}, {}}
|
||||
stream.in = inputs
|
||||
stream.out = outputs
|
||||
stream.firstTmp = 2
|
||||
stream.firstOut = 2
|
||||
|
||||
var id uint32
|
||||
var data ot.LabelData
|
||||
var table [4]ot.Label
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var buf [128]byte
|
||||
var bufpos int
|
||||
|
||||
err = stream.garbleGate(g, &id, table[:], &data, buf[:], &bufpos)
|
||||
if err != nil {
|
||||
b.Fatalf("garble failed: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncodeWire32(b *testing.B) {
|
||||
var buf [64]byte
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
var pos int
|
||||
bufpos := &pos
|
||||
|
||||
var aIndex = Wire(i)
|
||||
var bIndex = Wire(i * 2)
|
||||
var cIndex = Wire(i * 4)
|
||||
var op byte
|
||||
|
||||
buf[*bufpos] = op
|
||||
*bufpos = *bufpos + 1
|
||||
|
||||
bo.PutUint32(buf[*bufpos+0:], uint32(aIndex))
|
||||
bo.PutUint32(buf[*bufpos+4:], uint32(bIndex))
|
||||
bo.PutUint32(buf[*bufpos+8:], uint32(cIndex))
|
||||
*bufpos = *bufpos + 12
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncodeWire16_32(b *testing.B) {
|
||||
var buf [64]byte
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
var pos int
|
||||
bufpos := &pos
|
||||
|
||||
var aIndex = Wire(i)
|
||||
var bIndex = Wire(i * 2)
|
||||
var cIndex = Wire(i * 4)
|
||||
var op byte
|
||||
|
||||
if aIndex <= 0xffff && bIndex <= 0xffff && cIndex <= 0xffff {
|
||||
op |= 0b00010000
|
||||
buf[*bufpos] = op
|
||||
*bufpos = *bufpos + 1
|
||||
|
||||
bo.PutUint16(buf[*bufpos+0:], uint16(aIndex))
|
||||
bo.PutUint16(buf[*bufpos+2:], uint16(bIndex))
|
||||
bo.PutUint16(buf[*bufpos+4:], uint16(cIndex))
|
||||
*bufpos = *bufpos + 6
|
||||
} else {
|
||||
buf[*bufpos] = op
|
||||
*bufpos = *bufpos + 1
|
||||
|
||||
bo.PutUint32(buf[*bufpos+0:], uint32(aIndex))
|
||||
bo.PutUint32(buf[*bufpos+4:], uint32(bIndex))
|
||||
bo.PutUint32(buf[*bufpos+8:], uint32(cIndex))
|
||||
*bufpos = *bufpos + 12
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncodeWire8_16_24_32(b *testing.B) {
|
||||
var buf [64]byte
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
var pos int
|
||||
bufpos := &pos
|
||||
|
||||
var aIndex = Wire(i)
|
||||
var bIndex = Wire(i * 2)
|
||||
var cIndex = Wire(i * 4)
|
||||
var op byte
|
||||
|
||||
if aIndex <= 0xff && bIndex <= 0xff && cIndex <= 0xff {
|
||||
buf[*bufpos] = op
|
||||
*bufpos = *bufpos + 1
|
||||
|
||||
buf[*bufpos+0] = byte(aIndex)
|
||||
buf[*bufpos+1] = byte(bIndex)
|
||||
buf[*bufpos+2] = byte(cIndex)
|
||||
*bufpos = *bufpos + 3
|
||||
} else if aIndex <= 0xffff && bIndex <= 0xffff && cIndex <= 0xffff {
|
||||
op |= 0b00010000
|
||||
buf[*bufpos] = op
|
||||
*bufpos = *bufpos + 1
|
||||
|
||||
bo.PutUint16(buf[*bufpos+0:], uint16(aIndex))
|
||||
bo.PutUint16(buf[*bufpos+2:], uint16(bIndex))
|
||||
bo.PutUint16(buf[*bufpos+4:], uint16(cIndex))
|
||||
*bufpos = *bufpos + 6
|
||||
} else if aIndex <= 0xffffff && bIndex <= 0xffffff && cIndex <= 0xffffff {
|
||||
op |= 0b00100000
|
||||
buf[*bufpos] = op
|
||||
*bufpos = *bufpos + 1
|
||||
|
||||
PutUint24(buf[*bufpos+0:], uint32(aIndex))
|
||||
PutUint24(buf[*bufpos+3:], uint32(bIndex))
|
||||
PutUint24(buf[*bufpos+6:], uint32(cIndex))
|
||||
*bufpos = *bufpos + 9
|
||||
} else {
|
||||
op |= 0b00110000
|
||||
buf[*bufpos] = op
|
||||
*bufpos = *bufpos + 1
|
||||
|
||||
bo.PutUint32(buf[*bufpos+0:], uint32(aIndex))
|
||||
bo.PutUint32(buf[*bufpos+4:], uint32(bIndex))
|
||||
bo.PutUint32(buf[*bufpos+8:], uint32(cIndex))
|
||||
*bufpos = *bufpos + 12
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func PutUint24(b []byte, v uint32) {
|
||||
b[0] = byte(v >> 16)
|
||||
b[1] = byte(v >> 8)
|
||||
b[2] = byte(v)
|
||||
}
|
||||
|
||||
func BenchmarkEncodeWire7var(b *testing.B) {
|
||||
var buf [64]byte
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
var pos int
|
||||
bufpos := &pos
|
||||
|
||||
var aIndex = Wire(i)
|
||||
var bIndex = Wire(i * 2)
|
||||
var cIndex = Wire(i * 4)
|
||||
var op byte
|
||||
|
||||
buf[*bufpos] = op
|
||||
*bufpos = *bufpos + 1
|
||||
|
||||
*bufpos += encode7varRev(buf[*bufpos:], uint32(aIndex))
|
||||
*bufpos += encode7varRev(buf[*bufpos:], uint32(bIndex))
|
||||
*bufpos += encode7varRev(buf[*bufpos:], uint32(cIndex))
|
||||
}
|
||||
}
|
||||
|
||||
func encode7var(b []byte, v uint32) int {
|
||||
if v <= 0b01111111 {
|
||||
b[0] = byte(v)
|
||||
return 1
|
||||
} else if v <= 0b01111111_1111111 {
|
||||
b[0] = byte(v >> 7)
|
||||
b[1] = byte(v)
|
||||
return 2
|
||||
} else if v <= 0b01111111_1111111_1111111 {
|
||||
b[0] = byte(v >> 14)
|
||||
b[1] = byte(v >> 7)
|
||||
b[2] = byte(v)
|
||||
return 3
|
||||
} else if v <= 0b01111111_1111111_1111111_1111111 {
|
||||
b[0] = byte(v >> 21)
|
||||
b[1] = byte(v >> 14)
|
||||
b[2] = byte(v >> 7)
|
||||
b[3] = byte(v)
|
||||
return 4
|
||||
} else {
|
||||
b[0] = byte(v >> 28)
|
||||
b[1] = byte(v >> 21)
|
||||
b[2] = byte(v >> 14)
|
||||
b[3] = byte(v >> 7)
|
||||
b[4] = byte(v)
|
||||
return 5
|
||||
}
|
||||
}
|
||||
|
||||
func encode7varRev(b []byte, v uint32) int {
|
||||
if v > 0b01111111_1111111_1111111_1111111 {
|
||||
b[0] = byte(v >> 28)
|
||||
b[1] = byte(v >> 21)
|
||||
b[2] = byte(v >> 14)
|
||||
b[3] = byte(v >> 7)
|
||||
b[4] = byte(v)
|
||||
return 5
|
||||
} else if v > 0b01111111_1111111_1111111 {
|
||||
b[0] = byte(v >> 21)
|
||||
b[1] = byte(v >> 14)
|
||||
b[2] = byte(v >> 7)
|
||||
b[3] = byte(v)
|
||||
return 4
|
||||
} else if v > 0b01111111_1111111 {
|
||||
b[0] = byte(v >> 14)
|
||||
b[1] = byte(v >> 7)
|
||||
b[2] = byte(v)
|
||||
return 3
|
||||
} else if v > 0b01111111 {
|
||||
b[0] = byte(v >> 7)
|
||||
b[1] = byte(v)
|
||||
return 2
|
||||
} else {
|
||||
b[0] = byte(v)
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncodeWire7varInline(b *testing.B) {
|
||||
var buf [64]byte
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
var pos int
|
||||
bufpos := &pos
|
||||
|
||||
var aIndex = Wire(i)
|
||||
var bIndex = Wire(i * 2)
|
||||
var cIndex = Wire(i * 4)
|
||||
var op byte
|
||||
|
||||
buf[*bufpos] = op
|
||||
*bufpos = *bufpos + 1
|
||||
|
||||
if aIndex <= 0b01111111 &&
|
||||
bIndex <= 0b01111111 &&
|
||||
cIndex <= 0b01111111 {
|
||||
encode7var1(buf[*bufpos+0:], uint32(aIndex))
|
||||
encode7var1(buf[*bufpos+1:], uint32(bIndex))
|
||||
encode7var1(buf[*bufpos+2:], uint32(cIndex))
|
||||
*bufpos += 3
|
||||
} else if aIndex <= 0b01111111_1111111 &&
|
||||
bIndex <= 0b01111111_1111111 &&
|
||||
cIndex <= 0b01111111_1111111 {
|
||||
encode7var2(buf[*bufpos+0:], uint32(aIndex))
|
||||
encode7var2(buf[*bufpos+2:], uint32(bIndex))
|
||||
encode7var2(buf[*bufpos+4:], uint32(cIndex))
|
||||
*bufpos += 6
|
||||
} else if aIndex <= 0b01111111_1111111_1111111 &&
|
||||
bIndex <= 0b01111111_1111111_1111111 &&
|
||||
cIndex <= 0b01111111_1111111_1111111 {
|
||||
encode7var3(buf[*bufpos+0:], uint32(aIndex))
|
||||
encode7var3(buf[*bufpos+3:], uint32(bIndex))
|
||||
encode7var3(buf[*bufpos+6:], uint32(cIndex))
|
||||
*bufpos += 9
|
||||
} else if aIndex <= 0b01111111_1111111_1111111_1111111 &&
|
||||
bIndex <= 0b01111111_1111111_1111111_1111111 &&
|
||||
cIndex <= 0b01111111_1111111_1111111_1111111 {
|
||||
encode7var4(buf[*bufpos+0:], uint32(aIndex))
|
||||
encode7var4(buf[*bufpos+4:], uint32(bIndex))
|
||||
encode7var4(buf[*bufpos+8:], uint32(cIndex))
|
||||
*bufpos += 12
|
||||
} else {
|
||||
encode7var5(buf[*bufpos+0:], uint32(aIndex))
|
||||
encode7var5(buf[*bufpos+5:], uint32(bIndex))
|
||||
encode7var5(buf[*bufpos+10:], uint32(cIndex))
|
||||
*bufpos += 15
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func encode7var1(b []byte, v uint32) {
|
||||
b[0] = byte(v)
|
||||
}
|
||||
|
||||
func encode7var2(b []byte, v uint32) {
|
||||
b[0] = byte(v >> 7)
|
||||
b[1] = byte(v)
|
||||
}
|
||||
|
||||
func encode7var3(b []byte, v uint32) {
|
||||
b[0] = byte(v >> 14)
|
||||
b[1] = byte(v >> 7)
|
||||
b[2] = byte(v)
|
||||
}
|
||||
|
||||
func encode7var4(b []byte, v uint32) {
|
||||
b[0] = byte(v >> 21)
|
||||
b[1] = byte(v >> 14)
|
||||
b[2] = byte(v >> 7)
|
||||
b[3] = byte(v)
|
||||
}
|
||||
|
||||
func encode7var5(b []byte, v uint32) {
|
||||
b[0] = byte(v >> 28)
|
||||
b[1] = byte(v >> 21)
|
||||
b[2] = byte(v >> 14)
|
||||
b[3] = byte(v >> 7)
|
||||
b[4] = byte(v)
|
||||
}
|
||||
|
||||
func BenchmarkTimeDuration(b *testing.B) {
|
||||
var total time.Duration
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
start := time.Now()
|
||||
total += time.Now().Sub(start)
|
||||
}
|
||||
}
|
||||
511
bedlam/circuit/svg.go
Normal file
511
bedlam/circuit/svg.go
Normal file
@ -0,0 +1,511 @@
|
||||
//
|
||||
// Copyright (c) 2019-2022 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
)
|
||||
|
||||
const (
|
||||
ioWidth = 32
|
||||
ioHeight = 32
|
||||
ioPadX = 16
|
||||
ioPadY = 32
|
||||
|
||||
gateWidth = 32
|
||||
gateHeight = 32
|
||||
gatePadX = 16
|
||||
gatePadY = 64
|
||||
)
|
||||
|
||||
type tile struct {
|
||||
gate *Gate
|
||||
avg float64
|
||||
x float64
|
||||
y float64
|
||||
}
|
||||
|
||||
type point struct {
|
||||
x, y float64
|
||||
}
|
||||
|
||||
type wireType int
|
||||
|
||||
const (
|
||||
wireTypeNormal wireType = iota
|
||||
wireTypeZero
|
||||
wireTypeOne
|
||||
)
|
||||
|
||||
type wire struct {
|
||||
t wireType
|
||||
from point
|
||||
to point
|
||||
}
|
||||
|
||||
func (w *wire) svg(out io.Writer) {
|
||||
label := "1"
|
||||
|
||||
switch w.t {
|
||||
case wireTypeNormal:
|
||||
midY := w.from.y + (w.to.y-w.from.y)/2 - 5
|
||||
fmt.Fprintf(out, ` <path d="M %v %v
|
||||
v %v
|
||||
C %v %v %v %v %v %v
|
||||
v %v" />
|
||||
`,
|
||||
w.from.x, w.from.y,
|
||||
|
||||
midY-w.from.y,
|
||||
|
||||
w.from.x, midY+10,
|
||||
w.to.x, midY,
|
||||
w.to.x, midY+10,
|
||||
|
||||
w.to.y-midY-10,
|
||||
)
|
||||
|
||||
case wireTypeZero:
|
||||
label = "0"
|
||||
fallthrough
|
||||
|
||||
case wireTypeOne:
|
||||
fmt.Fprintf(out, ` <g fill="#000">
|
||||
<text x="%v" y="%v" text-anchor="middle">%v</text>
|
||||
</g>
|
||||
`,
|
||||
w.to.x, w.to.y-2, label)
|
||||
}
|
||||
}
|
||||
|
||||
type svgCtx struct {
|
||||
wireStarts []point
|
||||
zero Wire
|
||||
one Wire
|
||||
}
|
||||
|
||||
func (ctx *svgCtx) setWireType(input Wire, w *wire) {
|
||||
if input == ctx.zero {
|
||||
w.t = wireTypeZero
|
||||
} else if input == ctx.one {
|
||||
w.t = wireTypeOne
|
||||
}
|
||||
}
|
||||
|
||||
func (ctx *svgCtx) tileAvgInputX(t *tile) {
|
||||
var count float64
|
||||
switch t.gate.Op {
|
||||
case XOR, XNOR, AND, OR:
|
||||
if t.gate.Input1 != ctx.zero && t.gate.Input1 != ctx.one {
|
||||
count++
|
||||
t.avg += ctx.wireStarts[t.gate.Input1].x
|
||||
}
|
||||
fallthrough
|
||||
|
||||
case INV:
|
||||
if count == 0 ||
|
||||
(t.gate.Input0 != ctx.zero && t.gate.Input0 != ctx.one) {
|
||||
count++
|
||||
t.avg += ctx.wireStarts[t.gate.Input0].x
|
||||
}
|
||||
}
|
||||
t.avg /= count
|
||||
t.avg -= (gateWidth) / 2
|
||||
}
|
||||
|
||||
// Svg creates an SVG output of the circuit.
|
||||
func (c *Circuit) Svg(out io.Writer) {
|
||||
c.AssignLevels()
|
||||
|
||||
cols := c.Stats[MaxWidth]
|
||||
rows := c.Stats[NumLevels]
|
||||
|
||||
fmt.Printf("")
|
||||
|
||||
// Header.
|
||||
|
||||
ctx := &svgCtx{
|
||||
wireStarts: make([]point, c.NumWires),
|
||||
zero: InvalidWire,
|
||||
one: InvalidWire,
|
||||
}
|
||||
|
||||
iw := uint64(ioPadX + c.Inputs.Size()*(ioWidth+ioPadX))
|
||||
ow := uint64(ioPadX + c.Outputs.Size()*(ioWidth+ioPadX))
|
||||
|
||||
width := cols * (gateWidth + gatePadX)
|
||||
|
||||
if iw > width {
|
||||
width = iw
|
||||
}
|
||||
if ow > width {
|
||||
width = ow
|
||||
}
|
||||
|
||||
fmt.Fprintf(out,
|
||||
`<svg xmlns="http://www.w3.org/2000/svg" width="%d" height="%d">
|
||||
<style><![CDATA[
|
||||
text {
|
||||
font: 10px Courier, monospace;
|
||||
}
|
||||
]]></style>
|
||||
<g fill="none" stroke="#000" stroke-width=".5">
|
||||
`,
|
||||
width, rows*(gateHeight+gatePadY)+2*(ioHeight+gatePadY))
|
||||
|
||||
// Input wires.
|
||||
leftPad := ioPadX + int((width-iw)/2)
|
||||
for i := 0; i < c.Inputs.Size(); i++ {
|
||||
p := point{
|
||||
x: float64(leftPad + i*(ioWidth+ioPadX)),
|
||||
y: ioHeight,
|
||||
}
|
||||
ctx.wireStarts[i] = p
|
||||
|
||||
staticInput(out, p.x, p.y, fmt.Sprintf("i%v", i))
|
||||
}
|
||||
|
||||
// Compute level widths.
|
||||
|
||||
widths := make([]uint64, rows)
|
||||
|
||||
var lastLevel Level
|
||||
var x, y int
|
||||
|
||||
for _, g := range c.Gates {
|
||||
if g.Level != lastLevel {
|
||||
lastLevel = g.Level
|
||||
y++
|
||||
}
|
||||
widths[y]++
|
||||
}
|
||||
|
||||
// Render circuit.
|
||||
|
||||
var tiles []*tile
|
||||
var wires []*wire
|
||||
|
||||
x = 0
|
||||
y = -1
|
||||
lastLevel = 0
|
||||
|
||||
for idx, g := range c.Gates {
|
||||
if idx == 0 || g.Level != lastLevel {
|
||||
wires = append(wires, renderRow(out, int(width),
|
||||
float64(ioHeight+gatePadY+y*(gateHeight+gatePadY)), tiles,
|
||||
ctx)...)
|
||||
tiles = nil
|
||||
|
||||
lastLevel = g.Level
|
||||
y++
|
||||
x = 0
|
||||
}
|
||||
tiles = append(tiles, &tile{
|
||||
gate: &c.Gates[idx],
|
||||
})
|
||||
x++
|
||||
}
|
||||
wires = append(wires, renderRow(out, int(width),
|
||||
float64(ioHeight+gatePadY+y*(gateHeight+gatePadY)), tiles, ctx)...)
|
||||
|
||||
// Output wires.
|
||||
y++
|
||||
numOutputs := c.Outputs.Size()
|
||||
|
||||
var oAvg float64
|
||||
for i := 0; i < numOutputs; i++ {
|
||||
oAvg += ctx.wireStarts[c.NumWires-numOutputs+i].x
|
||||
}
|
||||
oAvg /= float64(numOutputs)
|
||||
|
||||
oWidth := float64((numOutputs-1)*(ioWidth+ioPadX) + ioWidth)
|
||||
|
||||
leftPad = int(oAvg-oWidth/2) + ioWidth/2
|
||||
|
||||
for i := 0; i < numOutputs; i++ {
|
||||
p := point{
|
||||
x: float64(leftPad + i*(ioWidth+ioPadX)),
|
||||
y: float64(ioHeight + gatePadY + y*(gateHeight+gatePadY)),
|
||||
}
|
||||
staticOutput(out, p.x, p.y, fmt.Sprintf("o%v", i))
|
||||
wires = append(wires, &wire{
|
||||
from: ctx.wireStarts[c.NumWires-numOutputs+i],
|
||||
to: p,
|
||||
})
|
||||
}
|
||||
|
||||
for _, w := range wires {
|
||||
w.svg(out)
|
||||
}
|
||||
|
||||
fmt.Fprintln(out, " </g>\n</svg>")
|
||||
}
|
||||
|
||||
func renderRow(out io.Writer, width int, y float64, tiles []*tile,
|
||||
ctx *svgCtx) []*wire {
|
||||
|
||||
var wires []*wire
|
||||
|
||||
for _, t := range tiles {
|
||||
ctx.tileAvgInputX(t)
|
||||
}
|
||||
sort.Slice(tiles, func(i, j int) bool {
|
||||
return tiles[i].avg < tiles[j].avg
|
||||
})
|
||||
|
||||
// Assign x based on input average and push tiles right.
|
||||
var next float64
|
||||
for _, t := range tiles {
|
||||
if next <= t.avg {
|
||||
t.x = t.avg
|
||||
} else {
|
||||
t.x = next
|
||||
}
|
||||
next = t.x + gateWidth + gatePadX
|
||||
}
|
||||
|
||||
// Starting from the right end, shift tiles left until they are on
|
||||
// screen and not overlapping.
|
||||
next = float64(width) - gateWidth - gatePadX
|
||||
for i := len(tiles) - 1; i >= 0; i-- {
|
||||
if tiles[i].x > next {
|
||||
tiles[i].x = next
|
||||
}
|
||||
next -= gateWidth + gatePadX
|
||||
}
|
||||
|
||||
for _, t := range tiles {
|
||||
wires = append(wires, t.gate.svg(out, t.x, y, ctx)...)
|
||||
}
|
||||
|
||||
return wires
|
||||
}
|
||||
|
||||
func (g *Gate) svg(out io.Writer, x, y float64, ctx *svgCtx) []*wire {
|
||||
fmt.Fprintf(out, ` <g transform="translate(%v %v)">
|
||||
`,
|
||||
x, y)
|
||||
|
||||
tmpl := templates[g.Op]
|
||||
if tmpl == nil {
|
||||
tmpl = templates[Count]
|
||||
}
|
||||
out.Write([]byte(tmpl.Expand()))
|
||||
fmt.Fprintln(out, ` </g>`)
|
||||
|
||||
ctx.wireStarts[g.Output] = point{
|
||||
x: x + gateWidth/2,
|
||||
y: y + gateHeight,
|
||||
}
|
||||
|
||||
var wires []*wire
|
||||
|
||||
switch g.Op {
|
||||
case XOR, XNOR, AND, OR:
|
||||
x0 := x + intCvt(35)
|
||||
x1 := x + intCvt(65)
|
||||
|
||||
f0 := ctx.wireStarts[g.Input0]
|
||||
f1 := ctx.wireStarts[g.Input1]
|
||||
|
||||
w0 := &wire{
|
||||
from: f0,
|
||||
to: point{
|
||||
x: x0,
|
||||
y: y,
|
||||
},
|
||||
}
|
||||
w1 := &wire{
|
||||
from: f1,
|
||||
to: point{
|
||||
x: x1,
|
||||
y: y,
|
||||
},
|
||||
}
|
||||
ctx.setWireType(g.Input0, w0)
|
||||
ctx.setWireType(g.Input1, w1)
|
||||
|
||||
// The input pin order does not matter in the
|
||||
// visualization. Swap input pins if input wires would cross
|
||||
// each other.
|
||||
if f0.x > f1.x {
|
||||
w0.to, w1.to = w1.to, w0.to
|
||||
}
|
||||
|
||||
wires = append(wires, w0)
|
||||
wires = append(wires, w1)
|
||||
|
||||
case INV:
|
||||
wire := &wire{
|
||||
from: ctx.wireStarts[g.Input0],
|
||||
to: point{
|
||||
x: x + intCvt(50),
|
||||
y: y,
|
||||
},
|
||||
}
|
||||
ctx.setWireType(g.Input0, wire)
|
||||
wires = append(wires, wire)
|
||||
}
|
||||
|
||||
// Constant value gates.
|
||||
switch g.Op {
|
||||
case XOR:
|
||||
if g.Input0 == g.Input1 {
|
||||
ctx.zero = g.Output
|
||||
staticOutput(out, x+gateWidth/2, y+gateHeight, "0")
|
||||
}
|
||||
|
||||
case XNOR:
|
||||
if g.Input0 == g.Input1 {
|
||||
fmt.Printf("*** one!\n")
|
||||
ctx.one = g.Output
|
||||
staticOutput(out, x+gateWidth/2, y+gateHeight, "1")
|
||||
}
|
||||
}
|
||||
|
||||
return wires
|
||||
}
|
||||
|
||||
func staticInput(out io.Writer, x, y float64, label string) {
|
||||
fmt.Fprintf(out, ` <g fill="#000">
|
||||
<text x="%v" y="%v" text-anchor="middle">%v</text>
|
||||
</g>
|
||||
`,
|
||||
x, y-2, label)
|
||||
}
|
||||
|
||||
func staticOutput(out io.Writer, x, y float64, label string) {
|
||||
fmt.Fprintf(out, ` <g fill="#000">
|
||||
<text x="%v" y="%v" text-anchor="middle">%v</text>
|
||||
</g>
|
||||
`,
|
||||
x, y+10, label)
|
||||
}
|
||||
|
||||
func scale(in int) float64 {
|
||||
return float64(in) * gateWidth / 100
|
||||
}
|
||||
|
||||
func path(out io.Writer) {
|
||||
fmt.Fprintln(out, ` <path fill="none" stroke="#000" stroke-width=".5"`)
|
||||
}
|
||||
|
||||
var templates [Count + 1]*Template
|
||||
|
||||
var intCvt IntCvt
|
||||
var floatCvt FloatCvt
|
||||
|
||||
func init() {
|
||||
templates[XOR] = NewTemplate(` <path d="M {{25}} {{20}}
|
||||
c {{10}} {{10}} {{40}} {{10}} {{50}} 0" />
|
||||
<path d="M {{25}} {{25}}
|
||||
c {{10}} {{10}} {{40}} {{10}} {{50}} 0" />
|
||||
<path d="M {{75}} {{25}}
|
||||
v {{25}}
|
||||
s 0 {{10}} {{-25}} {{25}}" />
|
||||
<path d="M {{25}} {{25}}
|
||||
v {{25}}
|
||||
s 0 {{10}} {{25}} {{25}}" />
|
||||
<path d="M {{35}} 0
|
||||
v {{25}}
|
||||
z" />
|
||||
<path d="M {{65}} 0
|
||||
v {{25}}
|
||||
z" />
|
||||
<path d="M {{50}} {{75}}
|
||||
v {{25}}
|
||||
z" />
|
||||
`)
|
||||
templates[XNOR] = NewTemplate(` <path d="M {{25}} {{20}}
|
||||
c {{10}} {{10}} {{40}} {{10}} {{50}} 0" />
|
||||
<path d="M {{25}} {{25}}
|
||||
c {{10}} {{10}} {{40}} {{10}} {{50}} 0" />
|
||||
<path d="M {{75}} {{25}}
|
||||
v {{25}}
|
||||
s 0 {{10}} {{-25}} {{25}}" />
|
||||
<path d="M {{25}} {{25}}
|
||||
v {{25}}
|
||||
s 0 {{10}} {{25}} {{25}}" />
|
||||
<circle cx="{{50}}" cy="{{80}}" r="{{5}}" />
|
||||
<path d="M {{35}} 0
|
||||
v {{25}}
|
||||
z" />
|
||||
<path d="M {{65}} 0
|
||||
v {{25}}
|
||||
z" />
|
||||
<path d="M {{50}} {{85}}
|
||||
v {{15}}
|
||||
z" />
|
||||
`)
|
||||
|
||||
templates[AND] = NewTemplate(` <path d="M {{25}} {{25}}
|
||||
h {{50}}
|
||||
v {{25}}
|
||||
a {{25}} {{25}} 0 1 1 {{-50}} 0
|
||||
v {{-25}}
|
||||
z" />
|
||||
<path d="M {{35}} 0
|
||||
v {{25}}
|
||||
z" />
|
||||
<path d="M {{65}} 0
|
||||
v {{25}}
|
||||
z" />
|
||||
<path d="M {{50}} {{75}}
|
||||
v {{25}}
|
||||
z" />
|
||||
`)
|
||||
|
||||
templates[OR] = NewTemplate(` <path d="M {{25}} {{20}}
|
||||
c {{10}} {{10}} {{40}} {{10}} {{50}} 0" />
|
||||
<path d="M {{75}} {{20}}
|
||||
v {{30}}
|
||||
s 0 {{10}} {{-25}} {{25}}" />
|
||||
<path d="M {{25}} {{20}}
|
||||
v {{30}}
|
||||
s 0 {{10}} {{25}} {{25}}" />
|
||||
<path d="M {{35}} 0
|
||||
v {{25}}
|
||||
z" />
|
||||
<path d="M {{65}} 0
|
||||
v {{25}}
|
||||
z" />
|
||||
<path d="M {{50}} {{75}}
|
||||
v {{25}}
|
||||
z" />
|
||||
`)
|
||||
|
||||
templates[INV] = NewTemplate(` <path d="M {{25}} {{25}}
|
||||
h {{50}}
|
||||
l {{-25}} {{43}}
|
||||
z" />
|
||||
<circle cx="{{50}}" cy="{{73.5}}" r="{{5}}" />
|
||||
<path d="M {{50}} 0
|
||||
v {{25}}
|
||||
z" />
|
||||
<path d="M {{50}} {{79}}
|
||||
v {{21}}
|
||||
z" />
|
||||
`)
|
||||
|
||||
templates[Count] = NewTemplate(`<path
|
||||
d="M {{25}} {{25}} h {{50}} v{{50}} h {{-50}} z" />`)
|
||||
|
||||
floatCvt = func(v float64) float64 {
|
||||
return v * gateWidth / 100
|
||||
}
|
||||
intCvt = func(v int) float64 {
|
||||
return float64(v) * gateWidth / 100
|
||||
}
|
||||
for op := XOR; op < Count+1; op++ {
|
||||
if templates[op] != nil {
|
||||
templates[op].FloatCvt = floatCvt
|
||||
templates[op].IntCvt = intCvt
|
||||
}
|
||||
}
|
||||
}
|
||||
142
bedlam/circuit/template.go
Normal file
142
bedlam/circuit/template.go
Normal file
@ -0,0 +1,142 @@
|
||||
//
|
||||
// Copyright (c) 2022 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var reVar = regexp.MustCompilePOSIX(`{(.?){([^\}]+)}}`)
|
||||
|
||||
// Template defines an expandable text template.
|
||||
type Template struct {
|
||||
parts []*part
|
||||
FloatCvt FloatCvt
|
||||
IntCvt IntCvt
|
||||
StringCvt StringCvt
|
||||
}
|
||||
|
||||
// FloatCvt converts a float64 value to float64 value.
|
||||
type FloatCvt func(v float64) float64
|
||||
|
||||
// IntCvt converts an integer value to float64 value.
|
||||
type IntCvt func(v int) float64
|
||||
|
||||
// StringCvt converts a string value to a string value.
|
||||
type StringCvt func(v string) string
|
||||
|
||||
const (
|
||||
partFloat = iota
|
||||
partInt
|
||||
partString
|
||||
)
|
||||
|
||||
type part struct {
|
||||
t int
|
||||
fv float64
|
||||
iv int
|
||||
sv string
|
||||
}
|
||||
|
||||
func (p part) String() string {
|
||||
switch p.t {
|
||||
case partFloat:
|
||||
return fmt.Sprintf("float64:%v", p.fv)
|
||||
case partInt:
|
||||
return fmt.Sprintf("int:%v", p.iv)
|
||||
case partString:
|
||||
return p.sv
|
||||
default:
|
||||
return fmt.Sprintf("{part %d}", p.t)
|
||||
}
|
||||
}
|
||||
|
||||
// NewTemplate parses the input string and returns the parsed
|
||||
// Template.
|
||||
func NewTemplate(input string) *Template {
|
||||
t := &Template{
|
||||
FloatCvt: func(v float64) float64 { return v },
|
||||
IntCvt: func(v int) float64 { return float64(v) },
|
||||
StringCvt: func(v string) string { return v },
|
||||
}
|
||||
matches := reVar.FindAllStringSubmatchIndex(input, -1)
|
||||
if matches == nil {
|
||||
return t
|
||||
}
|
||||
|
||||
var start int
|
||||
var err error
|
||||
|
||||
for _, m := range matches {
|
||||
if m[0] > start {
|
||||
t.parts = append(t.parts, &part{
|
||||
t: partString,
|
||||
sv: input[start:m[0]],
|
||||
})
|
||||
}
|
||||
content := input[m[4]:m[5]]
|
||||
part := &part{
|
||||
t: partFloat,
|
||||
sv: content,
|
||||
}
|
||||
t.parts = append(t.parts, part)
|
||||
start = m[1]
|
||||
|
||||
if m[2] != m[3] {
|
||||
switch input[m[2]:m[3]] {
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown template variable conversion: %s",
|
||||
input[m[2]:m[3]]))
|
||||
}
|
||||
} else {
|
||||
part.iv, err = strconv.Atoi(content)
|
||||
if err == nil {
|
||||
part.t = partInt
|
||||
} else {
|
||||
part.fv, err = strconv.ParseFloat(content, 64)
|
||||
if err == nil {
|
||||
part.t = partFloat
|
||||
} else {
|
||||
part.sv = content
|
||||
part.t = partString
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
if start < len(input) {
|
||||
t.parts = append(t.parts, &part{
|
||||
t: partString,
|
||||
sv: input[start:],
|
||||
})
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
// Expand expands the template.
|
||||
func (t *Template) Expand() string {
|
||||
var b strings.Builder
|
||||
|
||||
for _, part := range t.parts {
|
||||
switch part.t {
|
||||
case partFloat:
|
||||
b.WriteString(fmt.Sprintf("%v", t.FloatCvt(part.fv)))
|
||||
case partInt:
|
||||
b.WriteString(fmt.Sprintf("%v", t.IntCvt(part.iv)))
|
||||
case partString:
|
||||
b.WriteString(part.sv)
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid part type: %v", part.t))
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
88
bedlam/circuit/template_test.go
Normal file
88
bedlam/circuit/template_test.go
Normal file
@ -0,0 +1,88 @@
|
||||
//
|
||||
// Copyright (c) 2022-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuit
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
var tmplXOR = `<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
|
||||
<path fill="none" stroke="#000" stroke-width="1"
|
||||
d="M {{25}} {{20}}
|
||||
c {{10}} {{10}} {{40}} {{10}} {{50}} 0" />
|
||||
<path fill="none" stroke="#000" stroke-width="1"
|
||||
d="M 25 25
|
||||
c 10 10 40 10 50 0" />
|
||||
|
||||
<path fill="none" stroke="#000" stroke-width="1"
|
||||
d="M 75 25
|
||||
v 25
|
||||
s 0 10 -25 25 " />
|
||||
<path fill="none" stroke="#000" stroke-width="1"
|
||||
d="M 25 25
|
||||
v 25
|
||||
s 0 10 25 25 " />
|
||||
|
||||
<!-- Wires -->
|
||||
<path fill="none" stroke="#000" stroke-width="1"
|
||||
d="M 35 0
|
||||
v 25
|
||||
z" />
|
||||
<path fill="none" stroke="#000" stroke-width="1"
|
||||
d="M 65 0
|
||||
v 25
|
||||
z" />
|
||||
<path fill="none" stroke="#000" stroke-width="1"
|
||||
d="M 50 75
|
||||
v 25
|
||||
z" />
|
||||
</svg>
|
||||
`
|
||||
|
||||
var tmplXORExpanded = `<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
|
||||
<path fill="none" stroke="#000" stroke-width="1"
|
||||
d="M 6.25 5
|
||||
c 2.5 2.5 10 2.5 12.5 0" />
|
||||
<path fill="none" stroke="#000" stroke-width="1"
|
||||
d="M 25 25
|
||||
c 10 10 40 10 50 0" />
|
||||
|
||||
<path fill="none" stroke="#000" stroke-width="1"
|
||||
d="M 75 25
|
||||
v 25
|
||||
s 0 10 -25 25 " />
|
||||
<path fill="none" stroke="#000" stroke-width="1"
|
||||
d="M 25 25
|
||||
v 25
|
||||
s 0 10 25 25 " />
|
||||
|
||||
<!-- Wires -->
|
||||
<path fill="none" stroke="#000" stroke-width="1"
|
||||
d="M 35 0
|
||||
v 25
|
||||
z" />
|
||||
<path fill="none" stroke="#000" stroke-width="1"
|
||||
d="M 65 0
|
||||
v 25
|
||||
z" />
|
||||
<path fill="none" stroke="#000" stroke-width="1"
|
||||
d="M 50 75
|
||||
v 25
|
||||
z" />
|
||||
</svg>
|
||||
`
|
||||
|
||||
func TestTemplate(t *testing.T) {
|
||||
tmpl := NewTemplate(tmplXOR)
|
||||
tmpl.IntCvt = func(v int) float64 {
|
||||
return float64(v) * 25 / 100
|
||||
}
|
||||
expanded := tmpl.Expand()
|
||||
if expanded != tmplXORExpanded {
|
||||
t.Errorf("template expansion failed: got\n%v\n", expanded)
|
||||
}
|
||||
}
|
||||
163
bedlam/circuit/timing.go
Normal file
163
bedlam/circuit/timing.go
Normal file
@ -0,0 +1,163 @@
|
||||
//
|
||||
// Copyright (c) 2020-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/markkurossi/tabulate"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/p2p"
|
||||
)
|
||||
|
||||
// Timing records timing samples and renders a profiling report.
|
||||
type Timing struct {
|
||||
Start time.Time
|
||||
Samples []*Sample
|
||||
}
|
||||
|
||||
// NewTiming creates a new Timing instance.
|
||||
func NewTiming() *Timing {
|
||||
return &Timing{
|
||||
Start: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Sample adds a timing sample with label and data columns.
|
||||
func (t *Timing) Sample(label string, cols []string) *Sample {
|
||||
start := t.Start
|
||||
if len(t.Samples) > 0 {
|
||||
start = t.Samples[len(t.Samples)-1].End
|
||||
}
|
||||
sample := &Sample{
|
||||
Label: label,
|
||||
Start: start,
|
||||
End: time.Now(),
|
||||
Cols: cols,
|
||||
}
|
||||
t.Samples = append(t.Samples, sample)
|
||||
return sample
|
||||
}
|
||||
|
||||
// Print prints profiling report to standard output.
|
||||
func (t *Timing) Print(stats p2p.IOStats) {
|
||||
if len(t.Samples) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
sent := stats.Sent.Load()
|
||||
received := stats.Recvd.Load()
|
||||
flushed := stats.Flushed.Load()
|
||||
|
||||
tab := tabulate.New(tabulate.UnicodeLight)
|
||||
tab.Header("Op").SetAlign(tabulate.ML)
|
||||
tab.Header("Time").SetAlign(tabulate.MR)
|
||||
tab.Header("%").SetAlign(tabulate.MR)
|
||||
tab.Header("Xfer").SetAlign(tabulate.MR)
|
||||
|
||||
total := t.Samples[len(t.Samples)-1].End.Sub(t.Start)
|
||||
for _, sample := range t.Samples {
|
||||
row := tab.Row()
|
||||
row.Column(sample.Label)
|
||||
|
||||
duration := sample.End.Sub(sample.Start)
|
||||
row.Column(fmt.Sprintf("%s", duration.String()))
|
||||
row.Column(fmt.Sprintf("%.2f%%",
|
||||
float64(duration)/float64(total)*100))
|
||||
|
||||
for _, col := range sample.Cols {
|
||||
row.Column(col)
|
||||
}
|
||||
|
||||
for idx, sub := range sample.Samples {
|
||||
row := tab.Row()
|
||||
|
||||
var prefix string
|
||||
if idx+1 >= len(sample.Samples) {
|
||||
prefix = "\u2570\u2574"
|
||||
} else {
|
||||
prefix = "\u251C\u2574"
|
||||
}
|
||||
|
||||
row.Column(prefix + sub.Label).SetFormat(tabulate.FmtItalic)
|
||||
|
||||
var d time.Duration
|
||||
if sub.Abs > 0 {
|
||||
d = sub.Abs
|
||||
} else {
|
||||
d = sub.End.Sub(sub.Start)
|
||||
}
|
||||
row.Column(d.String()).SetFormat(tabulate.FmtItalic)
|
||||
|
||||
row.Column(
|
||||
fmt.Sprintf("%.2f%%", float64(d)/float64(duration)*100)).
|
||||
SetFormat(tabulate.FmtItalic)
|
||||
}
|
||||
}
|
||||
row := tab.Row()
|
||||
row.Column("Total").SetFormat(tabulate.FmtBold)
|
||||
row.Column(t.Samples[len(t.Samples)-1].End.Sub(t.Start).String()).
|
||||
SetFormat(tabulate.FmtBold)
|
||||
row.Column("").SetFormat(tabulate.FmtBold)
|
||||
row.Column(FileSize(sent + received).String()).SetFormat(tabulate.FmtBold)
|
||||
|
||||
row = tab.Row()
|
||||
row.Column("\u251C\u2574Sent").SetFormat(tabulate.FmtItalic)
|
||||
row.Column("")
|
||||
row.Column(
|
||||
fmt.Sprintf("%.2f%%", float64(sent)/float64(sent+received)*100)).
|
||||
SetFormat(tabulate.FmtItalic)
|
||||
row.Column(FileSize(sent).String()).SetFormat(tabulate.FmtItalic)
|
||||
|
||||
row = tab.Row()
|
||||
row.Column("\u251C\u2574Rcvd").SetFormat(tabulate.FmtItalic)
|
||||
row.Column("")
|
||||
row.Column(
|
||||
fmt.Sprintf("%.2f%%", float64(received)/float64(sent+received)*100)).
|
||||
SetFormat(tabulate.FmtItalic)
|
||||
row.Column(FileSize(received).String()).SetFormat(tabulate.FmtItalic)
|
||||
|
||||
row = tab.Row()
|
||||
row.Column("\u2570\u2574Flcd").SetFormat(tabulate.FmtItalic)
|
||||
row.Column("")
|
||||
row.Column("")
|
||||
row.Column(fmt.Sprintf("%v", flushed)).SetFormat(tabulate.FmtItalic)
|
||||
|
||||
tab.Print(os.Stdout)
|
||||
}
|
||||
|
||||
// Sample contains information about one timing sample.
|
||||
type Sample struct {
|
||||
Label string
|
||||
Start time.Time
|
||||
End time.Time
|
||||
Abs time.Duration
|
||||
Cols []string
|
||||
Samples []*Sample
|
||||
}
|
||||
|
||||
// SubSample adds a sub-sample for a timing sample.
|
||||
func (s *Sample) SubSample(label string, end time.Time) {
|
||||
start := s.Start
|
||||
if len(s.Samples) > 0 {
|
||||
start = s.Samples[len(s.Samples)-1].End
|
||||
}
|
||||
s.Samples = append(s.Samples, &Sample{
|
||||
Label: label,
|
||||
Start: start,
|
||||
End: end,
|
||||
})
|
||||
}
|
||||
|
||||
// AbsSubSample adds an absolute sub-sample for a timing sample.
|
||||
func (s *Sample) AbsSubSample(label string, duration time.Duration) {
|
||||
s.Samples = append(s.Samples, &Sample{
|
||||
Label: label,
|
||||
Abs: duration,
|
||||
})
|
||||
}
|
||||
1
bedlam/compiler/.gitignore
vendored
Normal file
1
bedlam/compiler/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
*.circ
|
||||
238
bedlam/compiler/arithmetic_test.go
Normal file
238
bedlam/compiler/arithmetic_test.go
Normal file
@ -0,0 +1,238 @@
|
||||
//
|
||||
// Copyright (c) 2019-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package compiler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"testing"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/utils"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/p2p"
|
||||
)
|
||||
|
||||
type Test struct {
|
||||
Name string
|
||||
Heavy bool
|
||||
Operand string
|
||||
Bits int
|
||||
Eval func(a *big.Int, b *big.Int) *big.Int
|
||||
Code string
|
||||
}
|
||||
|
||||
var tests = []Test{
|
||||
{
|
||||
Name: "Add",
|
||||
Heavy: true,
|
||||
Operand: "+",
|
||||
Bits: 2,
|
||||
Eval: func(a *big.Int, b *big.Int) *big.Int {
|
||||
result := big.NewInt(0)
|
||||
result.Add(a, b)
|
||||
return result
|
||||
},
|
||||
Code: `
|
||||
package main
|
||||
func main(a, b int3) int3 {
|
||||
return a + b
|
||||
}
|
||||
`,
|
||||
},
|
||||
// 1-bit, 2-bit, and n-bit multipliers have a bit different wiring.
|
||||
{
|
||||
Name: "Multiply 1-bit",
|
||||
Heavy: false,
|
||||
Operand: "*",
|
||||
Bits: 1,
|
||||
Eval: func(a *big.Int, b *big.Int) *big.Int {
|
||||
result := big.NewInt(0)
|
||||
result.Mul(a, b)
|
||||
return result
|
||||
},
|
||||
Code: `
|
||||
package main
|
||||
func main(a, b int1) int1 {
|
||||
return a * b
|
||||
}
|
||||
`,
|
||||
},
|
||||
{
|
||||
Name: "Multiply 2-bits",
|
||||
Heavy: true,
|
||||
Operand: "*",
|
||||
Bits: 2,
|
||||
Eval: func(a *big.Int, b *big.Int) *big.Int {
|
||||
result := big.NewInt(0)
|
||||
result.Mul(a, b)
|
||||
return result
|
||||
},
|
||||
Code: `
|
||||
package main
|
||||
func main(a, b int4) int4 {
|
||||
return a * b
|
||||
}
|
||||
`,
|
||||
},
|
||||
{
|
||||
Name: "Multiply n-bits",
|
||||
Heavy: true,
|
||||
Operand: "*",
|
||||
Bits: 2,
|
||||
Eval: func(a *big.Int, b *big.Int) *big.Int {
|
||||
result := big.NewInt(0)
|
||||
result.Mul(a, b)
|
||||
return result
|
||||
},
|
||||
Code: `
|
||||
package main
|
||||
func main(a, b int6) int6 {
|
||||
return a * b
|
||||
}
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
||||
func TestArithmetics(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
if testing.Short() && test.Heavy {
|
||||
fmt.Printf("Skipping %s\n", test.Name)
|
||||
continue
|
||||
}
|
||||
circ, _, err := New(utils.NewParams()).Compile(test.Code, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to compile test %s: %s", test.Name, err)
|
||||
}
|
||||
|
||||
limit := 1 << test.Bits
|
||||
|
||||
for g := 0; g < limit; g++ {
|
||||
for e := 0; e < limit; e++ {
|
||||
gr, ew := io.Pipe()
|
||||
er, gw := io.Pipe()
|
||||
|
||||
gio := newReadWriter(gr, gw)
|
||||
eio := newReadWriter(er, ew)
|
||||
|
||||
gInput := big.NewInt(int64(g))
|
||||
eInput := big.NewInt(int64(e))
|
||||
|
||||
gerr := make(chan error)
|
||||
eerr := make(chan error)
|
||||
res := make(chan []*big.Int)
|
||||
|
||||
go func() {
|
||||
fmt.Println("start garbler")
|
||||
_, err := circuit.Garbler(p2p.NewConn(gio), ot.NewFerret(1, ":5555"),
|
||||
circ, gInput, false)
|
||||
fmt.Println("end garbler")
|
||||
gerr <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
fmt.Println("start evaluator")
|
||||
result, err := circuit.Evaluator(p2p.NewConn(eio),
|
||||
ot.NewFerret(2, "127.0.0.1:5555"), circ, eInput, false)
|
||||
fmt.Println("end evaluator")
|
||||
eerr <- err
|
||||
res <- result
|
||||
}()
|
||||
|
||||
err = <-gerr
|
||||
if err != nil {
|
||||
t.Fatalf("Garbler failed: %s\n", err)
|
||||
}
|
||||
err = <-eerr
|
||||
if err != nil {
|
||||
t.Fatalf("Evaluator failed: %s\n", err)
|
||||
}
|
||||
|
||||
result := <-res
|
||||
expected := test.Eval(gInput, eInput)
|
||||
|
||||
if expected.Cmp(result[0]) != 0 {
|
||||
t.Errorf("%s failed: %s %s %s = %s, expected %s\n",
|
||||
test.Name, gInput, test.Operand, eInput, result,
|
||||
expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var mult512 = `
|
||||
package main
|
||||
func main(a, b int512) int512 {
|
||||
return a * b
|
||||
}
|
||||
`
|
||||
|
||||
func BenchmarkMult(b *testing.B) {
|
||||
circ, _, err := New(utils.NewParams()).Compile(mult512, nil)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to compile test: %s", err)
|
||||
}
|
||||
|
||||
gr, ew := io.Pipe()
|
||||
er, gw := io.Pipe()
|
||||
|
||||
gio := newReadWriter(gr, gw)
|
||||
eio := newReadWriter(er, ew)
|
||||
|
||||
gInput := big.NewInt(int64(11))
|
||||
eInput := big.NewInt(int64(13))
|
||||
|
||||
gerr := make(chan error)
|
||||
eerr := make(chan error)
|
||||
res := make(chan []*big.Int)
|
||||
|
||||
go func() {
|
||||
_, err := circuit.Garbler(p2p.NewConn(gio), ot.NewFerret(1, ":5555"),
|
||||
circ, gInput, false)
|
||||
gerr <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
result, err := circuit.Evaluator(p2p.NewConn(eio),
|
||||
ot.NewFerret(2, "127.0.0.1:5555"), circ, eInput, false)
|
||||
eerr <- err
|
||||
res <- result
|
||||
}()
|
||||
|
||||
err = <-gerr
|
||||
if err != nil {
|
||||
b.Fatalf("Garbler failed: %s\n", err)
|
||||
}
|
||||
err = <-eerr
|
||||
if err != nil {
|
||||
b.Fatalf("Evaluator failed: %s\n", err)
|
||||
}
|
||||
|
||||
<-res
|
||||
}
|
||||
|
||||
func newReadWriter(in io.Reader, out io.Writer) io.ReadWriter {
|
||||
return &wrap{
|
||||
in: in,
|
||||
out: out,
|
||||
}
|
||||
}
|
||||
|
||||
type wrap struct {
|
||||
in io.Reader
|
||||
out io.Writer
|
||||
}
|
||||
|
||||
func (w *wrap) Read(p []byte) (n int, err error) {
|
||||
return w.in.Read(p)
|
||||
}
|
||||
|
||||
func (w *wrap) Write(p []byte) (n int, err error) {
|
||||
return w.out.Write(p)
|
||||
}
|
||||
918
bedlam/compiler/ast/ast.go
Normal file
918
bedlam/compiler/ast/ast.go
Normal file
@ -0,0 +1,918 @@
|
||||
//
|
||||
// ast.go
|
||||
//
|
||||
// Copyright (c) 2019-2024 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package ast
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/mpa"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/ssa"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/utils"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
|
||||
)
|
||||
|
||||
var (
|
||||
_ AST = &List{}
|
||||
_ AST = &Func{}
|
||||
_ AST = &VariableDef{}
|
||||
_ AST = &Assign{}
|
||||
_ AST = &If{}
|
||||
_ AST = &Call{}
|
||||
_ AST = &ArrayCast{}
|
||||
_ AST = &Return{}
|
||||
_ AST = &For{}
|
||||
_ AST = &ForRange{}
|
||||
_ AST = &Binary{}
|
||||
_ AST = &Unary{}
|
||||
_ AST = &Slice{}
|
||||
_ AST = &Index{}
|
||||
_ AST = &VariableRef{}
|
||||
_ AST = &BasicLit{}
|
||||
_ AST = &CompositeLit{}
|
||||
_ AST = &Make{}
|
||||
_ AST = &Copy{}
|
||||
)
|
||||
|
||||
func indent(w io.Writer, indent int) {
|
||||
for i := 0; i < indent; i++ {
|
||||
fmt.Fprint(w, " ")
|
||||
}
|
||||
}
|
||||
|
||||
// Type specifies AST types.
|
||||
type Type int
|
||||
|
||||
// AST types.
|
||||
const (
|
||||
TypeName Type = iota
|
||||
TypeArray
|
||||
TypeSlice
|
||||
TypeStruct
|
||||
TypePointer
|
||||
TypeAlias
|
||||
)
|
||||
|
||||
// TypeInfo contains AST type information.
|
||||
type TypeInfo struct {
|
||||
utils.Point
|
||||
Type Type
|
||||
Name Identifier
|
||||
ElementType *TypeInfo
|
||||
ArrayLength AST
|
||||
TypeName string
|
||||
StructFields []StructField
|
||||
AliasType *TypeInfo
|
||||
Methods map[string]*Func
|
||||
Annotations Annotations
|
||||
}
|
||||
|
||||
// Equal tests if the argument TypeInfo is equal to this TypeInfo.
|
||||
func (ti *TypeInfo) Equal(o *TypeInfo) bool {
|
||||
if ti.Type != o.Type {
|
||||
return false
|
||||
}
|
||||
switch ti.Type {
|
||||
case TypeName:
|
||||
return ti.Name.String() == o.Name.String()
|
||||
|
||||
case TypeArray:
|
||||
return ti.ElementType.Equal(o.ElementType) &&
|
||||
ti.ArrayLength == o.ArrayLength
|
||||
|
||||
case TypeSlice, TypePointer:
|
||||
return ti.ElementType.Equal(o.ElementType)
|
||||
|
||||
case TypeStruct:
|
||||
if len(ti.StructFields) != len(o.StructFields) {
|
||||
return false
|
||||
}
|
||||
for idx, f := range ti.StructFields {
|
||||
if !f.Type.Equal(o.StructFields[idx].Type) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
|
||||
case TypeAlias:
|
||||
return ti.AliasType.Equal(o.AliasType)
|
||||
|
||||
default:
|
||||
panic("unsupported type")
|
||||
}
|
||||
}
|
||||
|
||||
// StructField contains AST structure field information.
|
||||
type StructField struct {
|
||||
utils.Point
|
||||
Name string
|
||||
Type *TypeInfo
|
||||
}
|
||||
|
||||
func (ti *TypeInfo) String() string {
|
||||
return ti.format(false)
|
||||
}
|
||||
|
||||
// Format print the type definition of the type info.
|
||||
func (ti *TypeInfo) Format() string {
|
||||
return ti.format(true)
|
||||
}
|
||||
|
||||
func (ti *TypeInfo) format(pp bool) string {
|
||||
var str string
|
||||
|
||||
if pp {
|
||||
str = fmt.Sprintf("type %s ", ti.TypeName)
|
||||
}
|
||||
|
||||
switch ti.Type {
|
||||
case TypeName:
|
||||
return str + ti.Name.String()
|
||||
|
||||
case TypeArray:
|
||||
return fmt.Sprintf("%s[%s]%s", str, ti.ArrayLength, ti.ElementType)
|
||||
|
||||
case TypeSlice:
|
||||
return fmt.Sprintf("%s[]%s", str, ti.ElementType)
|
||||
|
||||
case TypeStruct:
|
||||
str = fmt.Sprintf("%sstruct {", str)
|
||||
if pp {
|
||||
var width int
|
||||
for _, field := range ti.StructFields {
|
||||
if len(field.Name) > width {
|
||||
width = len(field.Name)
|
||||
}
|
||||
}
|
||||
for idx, field := range ti.StructFields {
|
||||
if idx == 0 {
|
||||
str += "\n"
|
||||
}
|
||||
str += " "
|
||||
str += field.Name
|
||||
for i := len(field.Name); i < width; i++ {
|
||||
str += " "
|
||||
}
|
||||
str += fmt.Sprintf(" %s\n", field.Type.String())
|
||||
}
|
||||
} else {
|
||||
for idx, field := range ti.StructFields {
|
||||
if idx > 0 {
|
||||
str += ", "
|
||||
}
|
||||
str += fmt.Sprintf("%s %s", field.Name, field.Type.String())
|
||||
}
|
||||
}
|
||||
return str + "}"
|
||||
|
||||
case TypeAlias:
|
||||
return fmt.Sprintf("%s= %s", str, ti.AliasType)
|
||||
|
||||
case TypePointer:
|
||||
return fmt.Sprintf("%s*%s", str, ti.ElementType)
|
||||
|
||||
default:
|
||||
return fmt.Sprintf("%s{TypeInfo %d}", str, ti.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// IsIdentifier returns true if the type info specifies a type name
|
||||
// without package.
|
||||
func (ti *TypeInfo) IsIdentifier() bool {
|
||||
return ti.Type == TypeName && len(ti.Name.Package) == 0
|
||||
}
|
||||
|
||||
// Resolve resolves the type information in the environment.
|
||||
func (ti *TypeInfo) Resolve(env *Env, ctx *Codegen, gen *ssa.Generator) (
|
||||
result types.Info, err error) {
|
||||
|
||||
if ti == nil {
|
||||
return
|
||||
}
|
||||
switch ti.Type {
|
||||
case TypeName:
|
||||
result, err = types.Parse(ti.Name.Name)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
// Check dynamic types from the env.
|
||||
var b ssa.Binding
|
||||
var pkg *Package
|
||||
var ok bool
|
||||
|
||||
if len(ti.Name.Package) == 0 {
|
||||
// Plain indentifiers.
|
||||
b, ok = env.Get(ti.Name.Name)
|
||||
if !ok {
|
||||
// Check dynamic types from the pkg.
|
||||
b, ok = ctx.Package.Bindings.Get(ti.Name.Name)
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
// Qualified names and package-local names.
|
||||
var pkgName string
|
||||
if len(ti.Name.Package) > 0 {
|
||||
pkgName = ti.Name.Package
|
||||
} else if ti.Name.Defined != ctx.Package.Name {
|
||||
pkgName = ti.Name.Defined
|
||||
}
|
||||
|
||||
if len(pkgName) > 0 {
|
||||
pkg, ok = ctx.Packages[pkgName]
|
||||
if !ok {
|
||||
return result, ctx.Errorf(ti, "unknown package: %s",
|
||||
pkgName)
|
||||
}
|
||||
b, ok = pkg.Bindings.Get(ti.Name.Name)
|
||||
}
|
||||
}
|
||||
if ok {
|
||||
val, ok := b.Bound.(*ssa.Value)
|
||||
if ok && val.TypeRef {
|
||||
return val.Type, nil
|
||||
}
|
||||
}
|
||||
return result, ctx.Errorf(ti, "undefined name: %s", ti)
|
||||
|
||||
case TypeArray:
|
||||
// Array length.
|
||||
constLength, ok, err := ti.ArrayLength.Eval(env, ctx, gen)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
if !ok {
|
||||
return result, ctx.Errorf(ti.ArrayLength,
|
||||
"array length is not constant: %s", ti.ArrayLength)
|
||||
}
|
||||
length, err := constLength.ConstInt()
|
||||
if err != nil {
|
||||
return result, ctx.Errorf(ti.ArrayLength,
|
||||
"invalid array length: %s", err)
|
||||
}
|
||||
|
||||
// Element type.
|
||||
elInfo, err := ti.ElementType.Resolve(env, ctx, gen)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
return types.Info{
|
||||
Type: types.TArray,
|
||||
IsConcrete: true,
|
||||
Bits: length * elInfo.Bits,
|
||||
MinBits: length * elInfo.MinBits,
|
||||
ElementType: &elInfo,
|
||||
ArraySize: length,
|
||||
}, nil
|
||||
|
||||
case TypeSlice:
|
||||
// Element type.
|
||||
elInfo, err := ti.ElementType.Resolve(env, ctx, gen)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
// Bits and ArraySize are left uninitialized and they must be
|
||||
// defined when type is instantiated.
|
||||
return types.Info{
|
||||
Type: types.TSlice,
|
||||
ElementType: &elInfo,
|
||||
}, nil
|
||||
|
||||
case TypePointer:
|
||||
// Element type.
|
||||
elInfo, err := ti.ElementType.Resolve(env, ctx, gen)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
return types.Info{
|
||||
Type: types.TPtr,
|
||||
IsConcrete: true,
|
||||
Bits: elInfo.Bits,
|
||||
MinBits: elInfo.Bits,
|
||||
ElementType: &elInfo,
|
||||
}, nil
|
||||
|
||||
default:
|
||||
return result, ctx.Errorf(ti, "can't resolve type %s", ti)
|
||||
}
|
||||
}
|
||||
|
||||
// ConstantDef implements a constant definition.
|
||||
type ConstantDef struct {
|
||||
utils.Point
|
||||
Name string
|
||||
Type *TypeInfo
|
||||
Init AST
|
||||
Annotations Annotations
|
||||
}
|
||||
|
||||
// Exported describes if the constant is exported from the package.
|
||||
func (ast *ConstantDef) Exported() bool {
|
||||
return IsExported(ast.Name)
|
||||
}
|
||||
|
||||
// IsExported describes if the name is exported from the package.
|
||||
func IsExported(name string) bool {
|
||||
if len(name) == 0 {
|
||||
return false
|
||||
}
|
||||
return unicode.IsUpper([]rune(name)[0])
|
||||
}
|
||||
|
||||
func (ast *ConstantDef) String() string {
|
||||
result := fmt.Sprintf("const %s", ast.Name)
|
||||
if ast.Type != nil {
|
||||
result += fmt.Sprintf(" %s", ast.Type)
|
||||
}
|
||||
if ast.Init != nil {
|
||||
result += fmt.Sprintf(" = %s", ast.Init)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Identifier implements an AST identifier.
|
||||
type Identifier struct {
|
||||
Defined string
|
||||
Package string
|
||||
Name string
|
||||
}
|
||||
|
||||
func (i Identifier) String() string {
|
||||
if len(i.Package) == 0 {
|
||||
return i.Name
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", i.Package, i.Name)
|
||||
}
|
||||
|
||||
// Qualified tells if the identifier has package part.
|
||||
func (i Identifier) Qualified() bool {
|
||||
return len(i.Package) > 0
|
||||
}
|
||||
|
||||
// AST implements abstract syntax tree nodes.
|
||||
type AST interface {
|
||||
utils.Locator
|
||||
|
||||
String() string
|
||||
// SSA generates SSA code from the AST node. The code is appended
|
||||
// into the basic block `block'. The function returns the next
|
||||
// sequential basic block. The `ssa.Dead' is set to `true' if the
|
||||
// code terminates i.e. all following AST nodes are dead code.
|
||||
SSA(block *ssa.Block, ctx *Codegen, gen *ssa.Generator) (
|
||||
*ssa.Block, []ssa.Value, error)
|
||||
// Eval evaluates the AST node during constant propagation.
|
||||
Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
|
||||
value ssa.Value, isConstant bool, err error)
|
||||
}
|
||||
|
||||
// NewEnv creates a new environment based on the current environment
|
||||
// bindings in the block.
|
||||
func NewEnv(block *ssa.Block) *Env {
|
||||
return &Env{
|
||||
Bindings: block.Bindings.Clone(),
|
||||
}
|
||||
}
|
||||
|
||||
// Env implements a value bindings environment.
|
||||
type Env struct {
|
||||
Bindings *ssa.Bindings
|
||||
}
|
||||
|
||||
// Debug print the environment.
|
||||
func (env *Env) Debug() {
|
||||
fmt.Print("Env: ")
|
||||
env.Bindings.Debug()
|
||||
}
|
||||
|
||||
// Get gets the value binding from the environment.
|
||||
func (env *Env) Get(name string) (ssa.Binding, bool) {
|
||||
return env.Bindings.Get(name)
|
||||
}
|
||||
|
||||
// Set sets the value binding to the environment.
|
||||
func (env *Env) Set(v ssa.Value, val *ssa.Value) {
|
||||
env.Bindings.Define(v, val)
|
||||
}
|
||||
|
||||
// List implements an AST list statement.
|
||||
type List []AST
|
||||
|
||||
func (ast List) String() string {
|
||||
result := "{\n"
|
||||
for _, a := range ast {
|
||||
result += fmt.Sprintf("\t%s\n", a)
|
||||
}
|
||||
return result + "}\n"
|
||||
}
|
||||
|
||||
// Location implements the compiler.utils.Locator interface.
|
||||
func (ast List) Location() utils.Point {
|
||||
if len(ast) > 0 {
|
||||
return ast[0].Location()
|
||||
}
|
||||
return utils.Point{}
|
||||
}
|
||||
|
||||
// Variable implements an AST variable.
|
||||
type Variable struct {
|
||||
utils.Point
|
||||
Name string
|
||||
Type *TypeInfo
|
||||
}
|
||||
|
||||
func (v Variable) String() string {
|
||||
if v.Type == nil {
|
||||
return v.Name
|
||||
}
|
||||
return fmt.Sprintf("%s %s", v.Name, v.Type)
|
||||
}
|
||||
|
||||
// Func implements an AST function.
|
||||
type Func struct {
|
||||
utils.Point
|
||||
Name string
|
||||
This *Variable
|
||||
Args []*Variable
|
||||
Return []*Variable
|
||||
Returns []*ReturnInfo
|
||||
NamedReturn bool
|
||||
Body List
|
||||
End utils.Point
|
||||
NumInstances int
|
||||
Annotations Annotations
|
||||
}
|
||||
|
||||
// ReturnInfo provide information about function return values.
|
||||
type ReturnInfo struct {
|
||||
Return *Return
|
||||
Types []types.Info
|
||||
}
|
||||
|
||||
// Annotations specify function annotations.
|
||||
type Annotations []string
|
||||
|
||||
// FirstSentence returns the first sentence from the annotations or an
|
||||
// empty string it if annotations are empty.
|
||||
func (ann Annotations) FirstSentence() string {
|
||||
str := strings.Join(ann, "\n")
|
||||
idx := strings.IndexRune(str, '.')
|
||||
if idx > 0 {
|
||||
return str[:idx+1]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// NewFunc creates a new function definition.
|
||||
func NewFunc(loc utils.Point, name string, args []*Variable, ret []*Variable,
|
||||
namedReturn bool, body List, end utils.Point,
|
||||
annotations Annotations) *Func {
|
||||
|
||||
// Skip empty lines from the beginning and end of annotations.
|
||||
for i := 0; i < len(annotations); i++ {
|
||||
if len(strings.TrimSpace(annotations[i])) > 0 {
|
||||
annotations = annotations[i:]
|
||||
break
|
||||
}
|
||||
}
|
||||
for i := len(annotations) - 1; i >= 0; i-- {
|
||||
if len(strings.TrimSpace(annotations[i])) > 0 {
|
||||
annotations = annotations[0 : i+1]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return &Func{
|
||||
Point: loc,
|
||||
Name: name,
|
||||
Args: args,
|
||||
Return: ret,
|
||||
NamedReturn: namedReturn,
|
||||
Body: body,
|
||||
End: end,
|
||||
Annotations: annotations,
|
||||
}
|
||||
}
|
||||
|
||||
func (ast *Func) String() string {
|
||||
var str string
|
||||
if ast.This != nil {
|
||||
str = fmt.Sprintf("func (%s %s) %s(",
|
||||
ast.This.Name, ast.This.Type, ast.Name)
|
||||
} else {
|
||||
str = fmt.Sprintf("func %s(", ast.Name)
|
||||
}
|
||||
for idx, arg := range ast.Args {
|
||||
if idx > 0 {
|
||||
str += ", "
|
||||
}
|
||||
if idx+1 < len(ast.Args) && arg.Type.Equal(ast.Args[idx+1].Type) {
|
||||
str += arg.Name
|
||||
} else {
|
||||
str += fmt.Sprintf("%s %s", arg.Name, arg.Type)
|
||||
}
|
||||
}
|
||||
str += ")"
|
||||
|
||||
if len(ast.Return) > 0 {
|
||||
if ast.NamedReturn {
|
||||
str += " ("
|
||||
for idx, ret := range ast.Return {
|
||||
if idx > 0 {
|
||||
str += ", "
|
||||
}
|
||||
if idx+1 < len(ast.Return) &&
|
||||
ret.Type.Equal(ast.Return[idx+1].Type) {
|
||||
str += ret.Name
|
||||
} else {
|
||||
str += fmt.Sprintf("%s %s", ret.Name, ret.Type)
|
||||
}
|
||||
}
|
||||
str += ")"
|
||||
} else if len(ast.Return) > 1 {
|
||||
str += " ("
|
||||
for idx, ret := range ast.Return {
|
||||
if idx > 0 {
|
||||
str += ", "
|
||||
}
|
||||
str += fmt.Sprintf("%s", ret.Type)
|
||||
}
|
||||
str += ")"
|
||||
} else {
|
||||
str += fmt.Sprintf(" %s", ast.Return[0].Type)
|
||||
}
|
||||
}
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
// VariableDef implements an AST variable definition.
|
||||
type VariableDef struct {
|
||||
utils.Point
|
||||
Names []string
|
||||
Type *TypeInfo
|
||||
Init AST
|
||||
Annotations Annotations
|
||||
}
|
||||
|
||||
func (ast *VariableDef) String() string {
|
||||
result := fmt.Sprintf("var %v", strings.Join(ast.Names, ", "))
|
||||
if ast.Type != nil {
|
||||
result += fmt.Sprintf(" %s", ast.Type)
|
||||
}
|
||||
if ast.Init != nil {
|
||||
result += fmt.Sprintf(" = %s", ast.Init)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Assign implements an AST assignment expression.
|
||||
type Assign struct {
|
||||
utils.Point
|
||||
LValues []AST
|
||||
Exprs []AST
|
||||
Define bool
|
||||
}
|
||||
|
||||
func (ast *Assign) String() string {
|
||||
var op string
|
||||
if ast.Define {
|
||||
op = ":="
|
||||
} else {
|
||||
op = "="
|
||||
}
|
||||
return fmt.Sprintf("%v %s %v", ast.LValues, op, ast.Exprs)
|
||||
}
|
||||
|
||||
// If implements an AST if statement.
|
||||
type If struct {
|
||||
utils.Point
|
||||
Expr AST
|
||||
True AST
|
||||
False AST
|
||||
}
|
||||
|
||||
func (ast *If) String() string {
|
||||
return fmt.Sprintf("if %s", ast.Expr)
|
||||
}
|
||||
|
||||
// Call implements an AST call expression.
|
||||
type Call struct {
|
||||
utils.Point
|
||||
Ref *VariableRef
|
||||
Exprs []AST
|
||||
}
|
||||
|
||||
func (ast *Call) String() string {
|
||||
str := fmt.Sprintf("%s(", ast.Ref)
|
||||
for idx, expr := range ast.Exprs {
|
||||
if idx > 0 {
|
||||
str += ", "
|
||||
}
|
||||
str += fmt.Sprintf("%v", expr)
|
||||
}
|
||||
return str + ")"
|
||||
}
|
||||
|
||||
// ArrayCast implements array cast expressions.
|
||||
type ArrayCast struct {
|
||||
utils.Point
|
||||
TypeInfo *TypeInfo
|
||||
Expr AST
|
||||
}
|
||||
|
||||
func (ast *ArrayCast) String() string {
|
||||
return fmt.Sprintf("%v(%v)", ast.TypeInfo, ast.Expr)
|
||||
}
|
||||
|
||||
// Return implements an AST return statement.
|
||||
type Return struct {
|
||||
utils.Point
|
||||
Exprs []AST
|
||||
AutoGenerated bool
|
||||
}
|
||||
|
||||
func (ast *Return) String() string {
|
||||
return fmt.Sprintf("return %v", ast.Exprs)
|
||||
}
|
||||
|
||||
// For implements an AST for statement.
|
||||
type For struct {
|
||||
utils.Point
|
||||
Init AST
|
||||
Cond AST
|
||||
Inc AST
|
||||
Body List
|
||||
}
|
||||
|
||||
func (ast *For) String() string {
|
||||
return fmt.Sprintf("for %s; %s; %s %s",
|
||||
ast.Init, ast.Cond, ast.Inc, ast.Body)
|
||||
}
|
||||
|
||||
// ForRange implements an AST for for-range statement.
|
||||
type ForRange struct {
|
||||
utils.Point
|
||||
ExprList []AST
|
||||
Def bool
|
||||
Expr AST
|
||||
Body List
|
||||
}
|
||||
|
||||
func (ast *ForRange) String() string {
|
||||
result := "for "
|
||||
for idx, expr := range ast.ExprList {
|
||||
if idx > 0 {
|
||||
result += ", "
|
||||
}
|
||||
result += expr.String()
|
||||
}
|
||||
if ast.Def {
|
||||
result += " := range "
|
||||
} else {
|
||||
result += " = range "
|
||||
}
|
||||
result += ast.Expr.String()
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// BinaryType defines binary expression types.
|
||||
type BinaryType int
|
||||
|
||||
// Binary expression types.
|
||||
const (
|
||||
BinaryMul BinaryType = iota
|
||||
BinaryDiv
|
||||
BinaryMod
|
||||
BinaryLshift
|
||||
BinaryRshift
|
||||
BinaryBand
|
||||
BinaryBclear
|
||||
BinaryBor
|
||||
BinaryBxor
|
||||
BinaryAdd
|
||||
BinarySub
|
||||
BinaryEq
|
||||
BinaryNeq
|
||||
BinaryLt
|
||||
BinaryLe
|
||||
BinaryGt
|
||||
BinaryGe
|
||||
BinaryAnd
|
||||
BinaryOr
|
||||
)
|
||||
|
||||
var binaryTypes = map[BinaryType]string{
|
||||
BinaryMul: "*",
|
||||
BinaryDiv: "/",
|
||||
BinaryMod: "%",
|
||||
BinaryLshift: "<<",
|
||||
BinaryRshift: ">>",
|
||||
BinaryBand: "&",
|
||||
BinaryBclear: "&^",
|
||||
BinaryBor: "|",
|
||||
BinaryBxor: "^",
|
||||
BinaryAdd: "+",
|
||||
BinarySub: "-",
|
||||
BinaryEq: "==",
|
||||
BinaryNeq: "!=",
|
||||
BinaryLt: "<",
|
||||
BinaryLe: "<=",
|
||||
BinaryGt: ">",
|
||||
BinaryGe: ">=",
|
||||
BinaryAnd: "&&",
|
||||
BinaryOr: "||",
|
||||
}
|
||||
|
||||
func (t BinaryType) String() string {
|
||||
name, ok := binaryTypes[t]
|
||||
if ok {
|
||||
return name
|
||||
}
|
||||
return fmt.Sprintf("{BinaryType %d}", t)
|
||||
}
|
||||
|
||||
// Binary implements an AST binary expression.
|
||||
type Binary struct {
|
||||
utils.Point
|
||||
Left AST
|
||||
Op BinaryType
|
||||
Right AST
|
||||
}
|
||||
|
||||
func (ast *Binary) String() string {
|
||||
return fmt.Sprintf("%s %s %s", ast.Left, ast.Op, ast.Right)
|
||||
}
|
||||
|
||||
// UnaryType defines unary expression types.
|
||||
type UnaryType int
|
||||
|
||||
// Unary expression types.
|
||||
const (
|
||||
UnaryPlus UnaryType = iota
|
||||
UnaryMinus
|
||||
UnaryNot
|
||||
UnaryXor
|
||||
UnaryPtr
|
||||
UnaryAddr
|
||||
UnarySend
|
||||
)
|
||||
|
||||
var unaryTypes = map[UnaryType]string{
|
||||
UnaryPlus: "+",
|
||||
UnaryMinus: "-",
|
||||
UnaryNot: "!",
|
||||
UnaryXor: "^",
|
||||
UnaryPtr: "*",
|
||||
UnaryAddr: "&",
|
||||
UnarySend: "<-",
|
||||
}
|
||||
|
||||
func (t UnaryType) String() string {
|
||||
name, ok := unaryTypes[t]
|
||||
if ok {
|
||||
return name
|
||||
}
|
||||
return fmt.Sprintf("{UnaryType %d}", t)
|
||||
}
|
||||
|
||||
// Unary implements an AST unary expression.
|
||||
type Unary struct {
|
||||
utils.Point
|
||||
Type UnaryType
|
||||
Expr AST
|
||||
}
|
||||
|
||||
func (ast *Unary) String() string {
|
||||
return fmt.Sprintf("%s%s", ast.Type, ast.Expr)
|
||||
}
|
||||
|
||||
// Slice implements an AST slice expression.
|
||||
type Slice struct {
|
||||
utils.Point
|
||||
Expr AST
|
||||
From AST
|
||||
To AST
|
||||
}
|
||||
|
||||
func (ast *Slice) String() string {
|
||||
var fromStr, toStr string
|
||||
if ast.From != nil {
|
||||
fromStr = ast.From.String()
|
||||
}
|
||||
if ast.To != nil {
|
||||
toStr = ast.To.String()
|
||||
}
|
||||
return fmt.Sprintf("%s[%s:%s]", ast.Expr, fromStr, toStr)
|
||||
}
|
||||
|
||||
// Index implements an AST array index expression.
|
||||
type Index struct {
|
||||
utils.Point
|
||||
Expr AST
|
||||
Index AST
|
||||
}
|
||||
|
||||
func (ast *Index) String() string {
|
||||
return fmt.Sprintf("%s[%s]", ast.Expr, ast.Index)
|
||||
}
|
||||
|
||||
// VariableRef implements an AST variable reference.
|
||||
type VariableRef struct {
|
||||
utils.Point
|
||||
Name Identifier
|
||||
}
|
||||
|
||||
func (ast *VariableRef) String() string {
|
||||
return ast.Name.String()
|
||||
}
|
||||
|
||||
// BasicLit implements an AST basic literal value.
|
||||
type BasicLit struct {
|
||||
utils.Point
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
func (ast *BasicLit) String() string {
|
||||
return ConstantName(ast.Value)
|
||||
}
|
||||
|
||||
// ConstantName returns the name of the constant value.
|
||||
func ConstantName(value interface{}) string {
|
||||
switch val := value.(type) {
|
||||
case int, uint, int32, uint32, int64, uint64:
|
||||
return fmt.Sprintf("%d", val)
|
||||
case *mpa.Int:
|
||||
return fmt.Sprintf("%s", val)
|
||||
case bool:
|
||||
return fmt.Sprintf("%v", val)
|
||||
case string:
|
||||
return fmt.Sprintf("%q", val)
|
||||
default:
|
||||
return fmt.Sprintf("{undefined constant %v (%T)}", val, val)
|
||||
}
|
||||
}
|
||||
|
||||
// CompositeLit implements an AST composite literal value.
|
||||
type CompositeLit struct {
|
||||
utils.Point
|
||||
Type *TypeInfo
|
||||
Value []KeyedElement
|
||||
}
|
||||
|
||||
func (ast *CompositeLit) String() string {
|
||||
str := ast.Type.String()
|
||||
str += "{"
|
||||
|
||||
for idx, e := range ast.Value {
|
||||
if idx > 0 {
|
||||
str += ","
|
||||
}
|
||||
if e.Key != nil {
|
||||
str += fmt.Sprintf("%s: %s", e.Key, e.Element)
|
||||
} else {
|
||||
str += e.Element.String()
|
||||
}
|
||||
}
|
||||
return str + "}"
|
||||
}
|
||||
|
||||
// KeyedElement implements a keyed element of composite literal.
|
||||
type KeyedElement struct {
|
||||
Key AST
|
||||
Element AST
|
||||
}
|
||||
|
||||
// Make implements the builtin function make.
|
||||
type Make struct {
|
||||
utils.Point
|
||||
Type *TypeInfo
|
||||
Exprs []AST
|
||||
}
|
||||
|
||||
func (ast *Make) String() string {
|
||||
str := fmt.Sprintf("make(%s", ast.Type)
|
||||
for _, expr := range ast.Exprs {
|
||||
str += ", "
|
||||
str += expr.String()
|
||||
}
|
||||
return str + ")"
|
||||
}
|
||||
|
||||
// Copy implements the builtin function copy.
|
||||
type Copy struct {
|
||||
utils.Point
|
||||
Dst AST
|
||||
Src AST
|
||||
}
|
||||
|
||||
func (ast *Copy) String() string {
|
||||
return fmt.Sprintf("copy(%v, %v)", ast.Dst, ast.Src)
|
||||
}
|
||||
409
bedlam/compiler/ast/builtin.go
Normal file
409
bedlam/compiler/ast/builtin.go
Normal file
@ -0,0 +1,409 @@
|
||||
//
|
||||
// Copyright (c) 2019-2024 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package ast
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/circuits"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/ssa"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/utils"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
|
||||
)
|
||||
|
||||
// Builtin implements QCL builtin functions.
|
||||
type Builtin struct {
|
||||
SSA SSA
|
||||
Eval Eval
|
||||
}
|
||||
|
||||
// SSA implements the builtin SSA generation.
|
||||
type SSA func(block *ssa.Block, ctx *Codegen, gen *ssa.Generator,
|
||||
args []ssa.Value, loc utils.Point) (*ssa.Block, []ssa.Value, error)
|
||||
|
||||
// Eval implements the builtin evaluation in constant folding.
|
||||
type Eval func(args []AST, env *Env, ctx *Codegen, gen *ssa.Generator,
|
||||
loc utils.Point) (ssa.Value, bool, error)
|
||||
|
||||
// Predeclared identifiers.
|
||||
var builtins = map[string]Builtin{
|
||||
"floorPow2": {
|
||||
SSA: floorPow2SSA,
|
||||
Eval: floorPow2Eval,
|
||||
},
|
||||
"len": {
|
||||
SSA: lenSSA,
|
||||
Eval: lenEval,
|
||||
},
|
||||
"native": {
|
||||
SSA: nativeSSA,
|
||||
},
|
||||
"panic": {
|
||||
SSA: panicSSA,
|
||||
Eval: panicEval,
|
||||
},
|
||||
"size": {
|
||||
SSA: sizeSSA,
|
||||
Eval: sizeEval,
|
||||
},
|
||||
}
|
||||
|
||||
func floorPow2SSA(block *ssa.Block, ctx *Codegen, gen *ssa.Generator,
|
||||
args []ssa.Value, loc utils.Point) (*ssa.Block, []ssa.Value, error) {
|
||||
return nil, nil, ctx.Errorf(loc, "floorPow2SSA not implemented")
|
||||
}
|
||||
|
||||
func floorPow2Eval(args []AST, env *Env, ctx *Codegen, gen *ssa.Generator,
|
||||
loc utils.Point) (ssa.Value, bool, error) {
|
||||
|
||||
if len(args) != 1 {
|
||||
return ssa.Undefined, false, ctx.Errorf(loc,
|
||||
"invalid amount of arguments in call to floorPow2")
|
||||
}
|
||||
|
||||
constVal, _, err := args[0].Eval(env, ctx, gen)
|
||||
if err != nil {
|
||||
return ssa.Undefined, false, ctx.Errorf(loc, "%s", err)
|
||||
}
|
||||
|
||||
val, err := constVal.ConstInt()
|
||||
if err != nil {
|
||||
return ssa.Undefined, false, ctx.Errorf(loc,
|
||||
"non-integer (%T) argument in %s: %s", constVal, args[0], err)
|
||||
}
|
||||
|
||||
var i types.Size
|
||||
for i = 1; i <= val; i <<= 1 {
|
||||
}
|
||||
i >>= 1
|
||||
|
||||
return gen.Constant(int64(i), types.Undefined), true, nil
|
||||
}
|
||||
|
||||
func lenSSA(block *ssa.Block, ctx *Codegen, gen *ssa.Generator,
|
||||
args []ssa.Value, loc utils.Point) (*ssa.Block, []ssa.Value, error) {
|
||||
|
||||
if len(args) != 1 {
|
||||
return nil, nil, ctx.Errorf(loc,
|
||||
"invalid amount of arguments in call to len")
|
||||
}
|
||||
|
||||
var val types.Size
|
||||
switch args[0].Type.Type {
|
||||
case types.TString:
|
||||
val = args[0].Type.Bits / types.ByteBits
|
||||
|
||||
case types.TArray, types.TSlice:
|
||||
val = args[0].Type.ArraySize
|
||||
|
||||
case types.TNil:
|
||||
val = 0
|
||||
|
||||
default:
|
||||
return nil, nil, ctx.Errorf(loc, "invalid argument 1 (type %s) for len",
|
||||
args[0].Type)
|
||||
}
|
||||
|
||||
v := gen.Constant(int64(val), types.Undefined)
|
||||
gen.AddConstant(v)
|
||||
|
||||
return block, []ssa.Value{v}, nil
|
||||
}
|
||||
|
||||
func lenEval(args []AST, env *Env, ctx *Codegen, gen *ssa.Generator,
|
||||
loc utils.Point) (ssa.Value, bool, error) {
|
||||
|
||||
if len(args) != 1 {
|
||||
return ssa.Undefined, false, ctx.Errorf(loc,
|
||||
"invalid amount of arguments in call to len")
|
||||
}
|
||||
|
||||
switch arg := args[0].(type) {
|
||||
case *VariableRef:
|
||||
var typeInfo types.Info
|
||||
|
||||
if len(arg.Name.Package) > 0 {
|
||||
// Check if the package name is bound to a value.
|
||||
b, ok := env.Get(arg.Name.Package)
|
||||
if ok {
|
||||
if b.Type.Type != types.TStruct {
|
||||
return ssa.Undefined, false, ctx.Errorf(loc,
|
||||
"%s undefined", arg.Name)
|
||||
}
|
||||
ok = false
|
||||
for _, f := range b.Type.Struct {
|
||||
if f.Name == arg.Name.Name {
|
||||
typeInfo = f.Type
|
||||
ok = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return ssa.Undefined, false, ctx.Errorf(loc,
|
||||
"undefined variable '%s'", arg.Name)
|
||||
}
|
||||
} else {
|
||||
// Resolve name from the package.
|
||||
pkg, ok := ctx.Packages[arg.Name.Package]
|
||||
if !ok {
|
||||
return ssa.Undefined, false, ctx.Errorf(loc,
|
||||
"package '%s' not found", arg.Name.Package)
|
||||
}
|
||||
b, ok := pkg.Bindings.Get(arg.Name.Name)
|
||||
if !ok {
|
||||
return ssa.Undefined, false, ctx.Errorf(loc,
|
||||
"undefined variable '%s'", arg.Name)
|
||||
}
|
||||
typeInfo = b.Type
|
||||
}
|
||||
} else {
|
||||
b, ok := env.Get(arg.Name.Name)
|
||||
if !ok {
|
||||
return ssa.Undefined, false, ctx.Errorf(loc,
|
||||
"undefined variable '%s'", arg.Name)
|
||||
}
|
||||
typeInfo = b.Type
|
||||
}
|
||||
|
||||
if typeInfo.Type == types.TPtr {
|
||||
typeInfo = *typeInfo.ElementType
|
||||
}
|
||||
|
||||
switch typeInfo.Type {
|
||||
case types.TString:
|
||||
return gen.Constant(int64(typeInfo.Bits/types.ByteBits),
|
||||
types.Undefined), true, nil
|
||||
|
||||
case types.TArray, types.TSlice:
|
||||
return gen.Constant(int64(typeInfo.ArraySize), types.Undefined),
|
||||
true, nil
|
||||
|
||||
default:
|
||||
return ssa.Undefined, false, ctx.Errorf(loc,
|
||||
"invalid argument 1 (type %s) for len", typeInfo)
|
||||
}
|
||||
|
||||
default:
|
||||
return ssa.Undefined, false, ctx.Errorf(loc,
|
||||
"len(%v/%T) is not constant", arg, arg)
|
||||
}
|
||||
}
|
||||
|
||||
func nativeSSA(block *ssa.Block, ctx *Codegen, gen *ssa.Generator,
|
||||
args []ssa.Value, loc utils.Point) (*ssa.Block, []ssa.Value, error) {
|
||||
|
||||
if len(args) < 1 {
|
||||
return nil, nil, ctx.Errorf(loc,
|
||||
"not enough arguments in call to native")
|
||||
}
|
||||
name, ok := args[0].ConstValue.(string)
|
||||
if !args[0].Const || !ok {
|
||||
return nil, nil, ctx.Errorf(loc,
|
||||
"not enough arguments in call to native")
|
||||
}
|
||||
// Our native name constant is not needed in the implementation.
|
||||
gen.RemoveConstant(args[0])
|
||||
args = args[1:]
|
||||
|
||||
switch name {
|
||||
case "hamming":
|
||||
if len(args) != 2 {
|
||||
return nil, nil, ctx.Errorf(loc,
|
||||
"invalid amount of arguments in call to '%s'", name)
|
||||
}
|
||||
|
||||
var typeInfo types.Info
|
||||
for _, arg := range args {
|
||||
if arg.Type.Bits > typeInfo.Bits {
|
||||
typeInfo = arg.Type
|
||||
}
|
||||
}
|
||||
|
||||
v := gen.AnonVal(typeInfo)
|
||||
block.AddInstr(ssa.NewBuiltinInstr(circuits.Hamming, args[0], args[1],
|
||||
v))
|
||||
|
||||
return block, []ssa.Value{v}, nil
|
||||
|
||||
default:
|
||||
if circuit.IsFilename(name) {
|
||||
return nativeCircuit(name, block, ctx, gen, args, loc)
|
||||
}
|
||||
return nil, nil, ctx.Errorf(loc, "unknown native '%s'", name)
|
||||
}
|
||||
}
|
||||
|
||||
func nativeCircuit(name string, block *ssa.Block, ctx *Codegen,
|
||||
gen *ssa.Generator, args []ssa.Value, loc utils.Point) (
|
||||
*ssa.Block, []ssa.Value, error) {
|
||||
|
||||
dir := path.Dir(loc.Source)
|
||||
fp := path.Join(dir, name)
|
||||
|
||||
var err error
|
||||
|
||||
circ, ok := ctx.Native[fp]
|
||||
if !ok {
|
||||
circ, err = circuit.Parse(fp)
|
||||
if err != nil {
|
||||
return nil, nil, ctx.Errorf(loc, "failed to parse circuit: %s", err)
|
||||
}
|
||||
circ.AssignLevels()
|
||||
ctx.Native[fp] = circ
|
||||
if ctx.Verbose {
|
||||
fmt.Printf(" - native %s: %v\n", name, circ)
|
||||
}
|
||||
} else if ctx.Verbose {
|
||||
fmt.Printf(" - native %s: cached\n", name)
|
||||
}
|
||||
|
||||
if len(circ.Inputs) > len(args) {
|
||||
return nil, nil, ctx.Errorf(loc,
|
||||
"not enough arguments in call to native")
|
||||
} else if len(circ.Inputs) < len(args) {
|
||||
return nil, nil, ctx.Errorf(loc, "too many argument in call to native")
|
||||
}
|
||||
// Check that the argument types match.
|
||||
for idx, io := range circ.Inputs {
|
||||
arg := args[idx]
|
||||
if io.Type.Bits < arg.Type.Bits || io.Type.Bits > arg.Type.Bits &&
|
||||
!arg.Const {
|
||||
return nil, nil, ctx.Errorf(loc,
|
||||
"invalid argument %d for native circuit: got %s, need %d",
|
||||
idx, arg.Type, io.Type.Bits)
|
||||
}
|
||||
}
|
||||
|
||||
var result []ssa.Value
|
||||
|
||||
for _, io := range circ.Outputs {
|
||||
result = append(result, gen.AnonVal(types.Info{
|
||||
Type: types.TUndefined,
|
||||
IsConcrete: true,
|
||||
Bits: io.Type.Bits,
|
||||
}))
|
||||
}
|
||||
|
||||
block.AddInstr(ssa.NewCircInstr(args, circ, result))
|
||||
|
||||
return block, result, nil
|
||||
}
|
||||
|
||||
func panicSSA(block *ssa.Block, ctx *Codegen, gen *ssa.Generator,
|
||||
args []ssa.Value, loc utils.Point) (*ssa.Block, []ssa.Value, error) {
|
||||
|
||||
var arr []string
|
||||
for _, arg := range args {
|
||||
var str string
|
||||
if arg.Const {
|
||||
str = fmt.Sprintf("%v", arg.ConstValue)
|
||||
} else {
|
||||
str = arg.String()
|
||||
}
|
||||
arr = append(arr, str)
|
||||
}
|
||||
|
||||
return nil, nil, ctx.Errorf(loc, "panic: %v", panicMessage(arr))
|
||||
}
|
||||
|
||||
func panicEval(args []AST, env *Env, ctx *Codegen, gen *ssa.Generator,
|
||||
loc utils.Point) (ssa.Value, bool, error) {
|
||||
|
||||
var arr []string
|
||||
for _, arg := range args {
|
||||
arr = append(arr, arg.String())
|
||||
}
|
||||
|
||||
return ssa.Undefined, false, ctx.Errorf(loc, "panic: %v", panicMessage(arr))
|
||||
}
|
||||
|
||||
func panicMessage(args []string) string {
|
||||
if len(args) == 0 {
|
||||
return ""
|
||||
}
|
||||
format := args[0]
|
||||
args = args[1:]
|
||||
|
||||
var result string
|
||||
|
||||
for i := 0; i < len(format); i++ {
|
||||
if format[i] != '%' || i+1 >= len(format) {
|
||||
result += string(format[i])
|
||||
continue
|
||||
}
|
||||
i++
|
||||
if len(args) == 0 {
|
||||
result += fmt.Sprintf("%%!%c(MISSING)", format[i])
|
||||
continue
|
||||
}
|
||||
switch format[i] {
|
||||
case 'v':
|
||||
result += args[0]
|
||||
default:
|
||||
result += fmt.Sprintf("%%!%c(%v)", format[i], args[0])
|
||||
}
|
||||
args = args[1:]
|
||||
}
|
||||
|
||||
for _, arg := range args {
|
||||
result += " " + arg
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func sizeSSA(block *ssa.Block, ctx *Codegen, gen *ssa.Generator,
|
||||
args []ssa.Value, loc utils.Point) (*ssa.Block, []ssa.Value, error) {
|
||||
|
||||
if len(args) != 1 {
|
||||
return nil, nil, ctx.Errorf(loc,
|
||||
"invalid amount of arguments in call to size")
|
||||
}
|
||||
|
||||
v := gen.Constant(int64(args[0].Type.Bits), types.Undefined)
|
||||
gen.AddConstant(v)
|
||||
|
||||
return block, []ssa.Value{v}, nil
|
||||
}
|
||||
|
||||
func sizeEval(args []AST, env *Env, ctx *Codegen, gen *ssa.Generator,
|
||||
loc utils.Point) (ssa.Value, bool, error) {
|
||||
|
||||
if len(args) != 1 {
|
||||
return ssa.Undefined, false, ctx.Errorf(loc,
|
||||
"invalid amount of arguments in call to size")
|
||||
}
|
||||
|
||||
switch arg := args[0].(type) {
|
||||
case *VariableRef:
|
||||
var b ssa.Binding
|
||||
var ok bool
|
||||
|
||||
if len(arg.Name.Package) > 0 {
|
||||
var pkg *Package
|
||||
pkg, ok = ctx.Packages[arg.Name.Package]
|
||||
if !ok {
|
||||
return ssa.Undefined, false, ctx.Errorf(loc,
|
||||
"package '%s' not found", arg.Name.Package)
|
||||
}
|
||||
b, ok = pkg.Bindings.Get(arg.Name.Name)
|
||||
} else {
|
||||
b, ok = env.Get(arg.Name.Name)
|
||||
}
|
||||
if !ok {
|
||||
return ssa.Undefined, false, ctx.Errorf(loc,
|
||||
"undefined variable '%s'", arg.Name.String())
|
||||
}
|
||||
return gen.Constant(int64(b.Type.Bits), types.Undefined), true, nil
|
||||
|
||||
default:
|
||||
return ssa.Undefined, false, ctx.Errorf(loc,
|
||||
"size(%v/%T) is not constant", arg, arg)
|
||||
}
|
||||
}
|
||||
219
bedlam/compiler/ast/codegen.go
Normal file
219
bedlam/compiler/ast/codegen.go
Normal file
@ -0,0 +1,219 @@
|
||||
//
|
||||
// ast.go
|
||||
//
|
||||
// Copyright (c) 2019-2024 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package ast
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/ssa"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/utils"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
|
||||
)
|
||||
|
||||
// Codegen implements compilation stack.
|
||||
type Codegen struct {
|
||||
logger *utils.Logger
|
||||
Params *utils.Params
|
||||
Verbose bool
|
||||
Package *Package
|
||||
Packages map[string]*Package
|
||||
MainInputSizes [][]int
|
||||
Stack []Compilation
|
||||
Types map[types.ID]*TypeInfo
|
||||
Native map[string]*circuit.Circuit
|
||||
HeapID int
|
||||
}
|
||||
|
||||
// NewCodegen creates a new compilation.
|
||||
func NewCodegen(logger *utils.Logger, pkg *Package,
|
||||
packages map[string]*Package, params *utils.Params,
|
||||
mainInputSizes [][]int) *Codegen {
|
||||
|
||||
return &Codegen{
|
||||
logger: logger,
|
||||
Params: params,
|
||||
Verbose: params.Verbose,
|
||||
Package: pkg,
|
||||
Packages: packages,
|
||||
MainInputSizes: mainInputSizes,
|
||||
Types: make(map[types.ID]*TypeInfo),
|
||||
Native: make(map[string]*circuit.Circuit),
|
||||
}
|
||||
}
|
||||
|
||||
func (ctx *Codegen) errorLoc(err error) error {
|
||||
if !ctx.Params.QCLCErrorLoc {
|
||||
return err
|
||||
}
|
||||
|
||||
for skip := 2; ; skip++ {
|
||||
pc, file, line, ok := runtime.Caller(skip)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
f := runtime.FuncForPC(pc)
|
||||
if f != nil && strings.HasSuffix(f.Name(), ".errf") {
|
||||
continue
|
||||
}
|
||||
fmt.Printf("%s:%d: MCPLC error:\n\u2514\u2574%s\n", file, line, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Error logs an error message.
|
||||
func (ctx *Codegen) Error(locator utils.Locator, msg string) error {
|
||||
return ctx.errorLoc(ctx.logger.Errorf(locator.Location(), "%s", msg))
|
||||
}
|
||||
|
||||
// Errorf logs an error message.
|
||||
func (ctx *Codegen) Errorf(locator utils.Locator, format string,
|
||||
a ...interface{}) error {
|
||||
return ctx.errorLoc(ctx.logger.Errorf(locator.Location(), format, a...))
|
||||
}
|
||||
|
||||
// Warningf logs a warning message
|
||||
func (ctx *Codegen) Warningf(locator utils.Locator, format string,
|
||||
a ...interface{}) {
|
||||
ctx.logger.Warningf(locator.Location(), format, a...)
|
||||
}
|
||||
|
||||
// DefineType defines the argument type and assigns it an unique type
|
||||
// ID.
|
||||
func (ctx *Codegen) DefineType(t *TypeInfo) types.ID {
|
||||
id := types.ID(len(ctx.Types) + 0x80000000)
|
||||
ctx.Types[id] = t
|
||||
return id
|
||||
}
|
||||
|
||||
// LookupFunc resolves the named function from the context.
|
||||
func (ctx *Codegen) LookupFunc(block *ssa.Block, ref *VariableRef) (
|
||||
*Func, error) {
|
||||
|
||||
// First, check method calls.
|
||||
if len(ref.Name.Package) > 0 {
|
||||
// Check if package name is bound to a value.
|
||||
var b ssa.Binding
|
||||
var ok bool
|
||||
|
||||
b, ok = block.Bindings.Get(ref.Name.Package)
|
||||
if !ok {
|
||||
// Check names in the current package.
|
||||
b, ok = ctx.Package.Bindings.Get(ref.Name.Package)
|
||||
}
|
||||
if ok {
|
||||
var typeInfo types.Info
|
||||
if b.Type.Type == types.TPtr {
|
||||
typeInfo = *b.Type.ElementType
|
||||
} else {
|
||||
typeInfo = b.Type
|
||||
}
|
||||
|
||||
info, ok := ctx.Types[typeInfo.ID]
|
||||
if !ok {
|
||||
return nil, ctx.Errorf(ref, "%s undefined", ref)
|
||||
}
|
||||
method, ok := info.Methods[ref.Name.Name]
|
||||
if !ok {
|
||||
return nil, ctx.Errorf(ref, "%s undefined", ref)
|
||||
}
|
||||
return method, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Next, check function calls.
|
||||
var pkgName string
|
||||
if len(ref.Name.Package) > 0 {
|
||||
pkgName = ref.Name.Package
|
||||
} else {
|
||||
pkgName = ref.Name.Defined
|
||||
}
|
||||
pkg, ok := ctx.Packages[pkgName]
|
||||
if !ok {
|
||||
return nil, ctx.Errorf(ref, "package '%s' not found", pkgName)
|
||||
}
|
||||
called, ok := pkg.Functions[ref.Name.Name]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
return called, nil
|
||||
}
|
||||
|
||||
// Func returns the current function in the current compilation.
|
||||
func (ctx *Codegen) Func() *Func {
|
||||
if len(ctx.Stack) == 0 {
|
||||
return nil
|
||||
}
|
||||
return ctx.Stack[len(ctx.Stack)-1].Called
|
||||
}
|
||||
|
||||
// Scope returns the value scope in the current compilation.
|
||||
func (ctx *Codegen) Scope() ssa.Scope {
|
||||
if ctx.Func() != nil {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// PushCompilation pushes a new compilation to the compilation stack.
|
||||
func (ctx *Codegen) PushCompilation(start, ret, caller *ssa.Block,
|
||||
called *Func) {
|
||||
|
||||
ctx.Stack = append(ctx.Stack, Compilation{
|
||||
Start: start,
|
||||
Return: ret,
|
||||
Caller: caller,
|
||||
Called: called,
|
||||
})
|
||||
}
|
||||
|
||||
// PopCompilation pops the topmost compilation from the compilation
|
||||
// stack.
|
||||
func (ctx *Codegen) PopCompilation() {
|
||||
if len(ctx.Stack) == 0 {
|
||||
panic("compilation stack underflow")
|
||||
}
|
||||
ctx.Stack = ctx.Stack[:len(ctx.Stack)-1]
|
||||
}
|
||||
|
||||
// Start returns the start block of the current compilation.
|
||||
func (ctx *Codegen) Start() *ssa.Block {
|
||||
return ctx.Stack[len(ctx.Stack)-1].Start
|
||||
}
|
||||
|
||||
// Return returns the return block of the current compilation.
|
||||
func (ctx *Codegen) Return() *ssa.Block {
|
||||
return ctx.Stack[len(ctx.Stack)-1].Return
|
||||
}
|
||||
|
||||
// Caller returns the caller block of the current compilation.
|
||||
func (ctx *Codegen) Caller() *ssa.Block {
|
||||
return ctx.Stack[len(ctx.Stack)-1].Caller
|
||||
}
|
||||
|
||||
// HeapVar returns the name of the next global heap variable.
|
||||
func (ctx *Codegen) HeapVar() string {
|
||||
name := fmt.Sprintf("$heap%v", ctx.HeapID)
|
||||
ctx.HeapID++
|
||||
return name
|
||||
}
|
||||
|
||||
// Compilation contains information about a compilation
|
||||
// scope. Toplevel, each function call, and each nested block specify
|
||||
// their own scope with their own variable bindings.
|
||||
type Compilation struct {
|
||||
Start *ssa.Block
|
||||
Return *ssa.Block
|
||||
Caller *ssa.Block
|
||||
Called *Func
|
||||
// XXX Bindings
|
||||
// XXX Parent scope.
|
||||
}
|
||||
684
bedlam/compiler/ast/eval.go
Normal file
684
bedlam/compiler/ast/eval.go
Normal file
@ -0,0 +1,684 @@
|
||||
//
|
||||
// Copyright (c) 2019-2024 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package ast
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/mpa"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/ssa"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
|
||||
)
|
||||
|
||||
const (
|
||||
debugEval = false
|
||||
)
|
||||
|
||||
// Eval implements the compiler.ast.AST.Eval for list statements.
|
||||
func (ast List) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
|
||||
ssa.Value, bool, error) {
|
||||
return ssa.Undefined, false, ctx.Errorf(ast, "List.Eval not implemented")
|
||||
}
|
||||
|
||||
// Eval implements the compiler.ast.AST.Eval for function definitions.
|
||||
func (ast *Func) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
|
||||
ssa.Value, bool, error) {
|
||||
return ssa.Undefined, false, nil
|
||||
}
|
||||
|
||||
// Eval implements the compiler.ast.AST.Eval for variable definitions.
|
||||
func (ast *VariableDef) Eval(env *Env, ctx *Codegen,
|
||||
gen *ssa.Generator) (ssa.Value, bool, error) {
|
||||
return ssa.Undefined, false, nil
|
||||
}
|
||||
|
||||
// Eval implements the compiler.ast.AST.Eval for assignment expressions.
|
||||
func (ast *Assign) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
|
||||
ssa.Value, bool, error) {
|
||||
|
||||
var values []interface{}
|
||||
for _, expr := range ast.Exprs {
|
||||
val, ok, err := expr.Eval(env, ctx, gen)
|
||||
if err != nil || !ok {
|
||||
return ssa.Undefined, ok, err
|
||||
}
|
||||
// XXX multiple return values.
|
||||
values = append(values, val)
|
||||
}
|
||||
|
||||
if len(ast.LValues) != len(values) {
|
||||
return ssa.Undefined, false, ctx.Errorf(ast,
|
||||
"assignment mismatch: %d variables but %d values",
|
||||
len(ast.LValues), len(values))
|
||||
}
|
||||
|
||||
arrType := types.Info{
|
||||
Type: types.TArray,
|
||||
IsConcrete: true,
|
||||
ArraySize: types.Size(len(values)),
|
||||
}
|
||||
|
||||
if ast.Define {
|
||||
for idx, lv := range ast.LValues {
|
||||
constVal := gen.Constant(values[idx], types.Undefined)
|
||||
gen.AddConstant(constVal)
|
||||
arrType.ElementType = &constVal.Type
|
||||
|
||||
ref, ok := lv.(*VariableRef)
|
||||
if !ok {
|
||||
return ssa.Undefined, false,
|
||||
ctx.Errorf(ast, "cannot assign to %s", lv)
|
||||
}
|
||||
// XXX package.name below
|
||||
|
||||
lValue := gen.NewVal(ref.Name.Name, constVal.Type, ctx.Scope())
|
||||
env.Set(lValue, &constVal)
|
||||
}
|
||||
} else {
|
||||
for idx, lv := range ast.LValues {
|
||||
ref, ok := lv.(*VariableRef)
|
||||
if !ok {
|
||||
return ssa.Undefined, false,
|
||||
ctx.Errorf(ast, "cannot assign to %s", lv)
|
||||
}
|
||||
// XXX package.name below
|
||||
|
||||
b, ok := env.Get(ref.Name.Name)
|
||||
if !ok {
|
||||
return ssa.Undefined, false,
|
||||
ctx.Errorf(ast, "undefined variable '%s'", ref.Name)
|
||||
}
|
||||
lValue := gen.NewVal(b.Name, b.Type, ctx.Scope())
|
||||
|
||||
constVal := gen.Constant(values[idx], b.Type)
|
||||
gen.AddConstant(constVal)
|
||||
arrType.ElementType = &constVal.Type
|
||||
|
||||
env.Set(lValue, &constVal)
|
||||
}
|
||||
}
|
||||
|
||||
return gen.Constant(values, arrType), true, nil
|
||||
}
|
||||
|
||||
// Eval implements the compiler.ast.AST.Eval for if statements.
|
||||
func (ast *If) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
|
||||
ssa.Value, bool, error) {
|
||||
return ssa.Undefined, false, nil
|
||||
}
|
||||
|
||||
// Eval implements the compiler.ast.AST.Eval for call expressions.
|
||||
func (ast *Call) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
|
||||
ssa.Value, bool, error) {
|
||||
|
||||
if debugEval {
|
||||
fmt.Printf("Call.Eval: %s(", ast.Ref)
|
||||
for idx, expr := range ast.Exprs {
|
||||
if idx > 0 {
|
||||
fmt.Print(", ")
|
||||
}
|
||||
fmt.Printf("%v", expr)
|
||||
}
|
||||
fmt.Println(")")
|
||||
}
|
||||
|
||||
// Resolve called.
|
||||
var pkgName string
|
||||
if len(ast.Ref.Name.Package) > 0 {
|
||||
pkgName = ast.Ref.Name.Package
|
||||
} else {
|
||||
pkgName = ast.Ref.Name.Defined
|
||||
}
|
||||
pkg, ok := ctx.Packages[pkgName]
|
||||
if !ok {
|
||||
return ssa.Undefined, false,
|
||||
ctx.Errorf(ast, "package '%s' not found", pkgName)
|
||||
}
|
||||
_, ok = pkg.Functions[ast.Ref.Name.Name]
|
||||
if ok {
|
||||
return ssa.Undefined, false, nil
|
||||
}
|
||||
// Check builtin functions.
|
||||
bi, ok := builtins[ast.Ref.Name.Name]
|
||||
if ok && bi.Eval != nil {
|
||||
return bi.Eval(ast.Exprs, env, ctx, gen, ast.Location())
|
||||
}
|
||||
|
||||
// Resolve name as type.
|
||||
typeName := &TypeInfo{
|
||||
Point: ast.Point,
|
||||
Type: TypeName,
|
||||
Name: ast.Ref.Name,
|
||||
}
|
||||
typeInfo, err := typeName.Resolve(env, ctx, gen)
|
||||
if err != nil {
|
||||
return ssa.Undefined, false, err
|
||||
}
|
||||
if len(ast.Exprs) != 1 {
|
||||
return ssa.Undefined, false, nil
|
||||
}
|
||||
constVal, ok, err := ast.Exprs[0].Eval(env, ctx, gen)
|
||||
if err != nil {
|
||||
return ssa.Undefined, false, err
|
||||
}
|
||||
if !ok {
|
||||
return ssa.Undefined, false, nil
|
||||
}
|
||||
|
||||
switch typeInfo.Type {
|
||||
case types.TInt, types.TUint:
|
||||
switch constVal.Type.Type {
|
||||
case types.TInt, types.TUint:
|
||||
if !typeInfo.Concrete() {
|
||||
typeInfo.Bits = constVal.Type.Bits
|
||||
typeInfo.SetConcrete(true)
|
||||
}
|
||||
if constVal.Type.MinBits > typeInfo.Bits {
|
||||
typeInfo.MinBits = typeInfo.Bits
|
||||
} else {
|
||||
typeInfo.MinBits = constVal.Type.MinBits
|
||||
}
|
||||
cast := constVal
|
||||
cast.Type = typeInfo
|
||||
if constVal.HashCode() != cast.HashCode() {
|
||||
panic("const cast changes value HashCode")
|
||||
}
|
||||
if !constVal.Equal(&cast) {
|
||||
panic("const cast changes value equality")
|
||||
}
|
||||
return cast, true, nil
|
||||
|
||||
default:
|
||||
return ssa.Undefined, false,
|
||||
ctx.Errorf(ast.Ref, "casting %T not supported", constVal.Type)
|
||||
}
|
||||
}
|
||||
|
||||
return ssa.Undefined, false, nil
|
||||
}
|
||||
|
||||
// Eval implements the compiler.ast.AST.Eval.
|
||||
func (ast *ArrayCast) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
|
||||
ssa.Value, bool, error) {
|
||||
|
||||
typeInfo, err := ast.TypeInfo.Resolve(env, ctx, gen)
|
||||
if err != nil {
|
||||
return ssa.Undefined, false, err
|
||||
}
|
||||
if !typeInfo.Type.Array() {
|
||||
return ssa.Undefined, false,
|
||||
ctx.Errorf(ast.Expr, "array cast to non-array type %v", typeInfo)
|
||||
}
|
||||
|
||||
cv, ok, err := ast.Expr.Eval(env, ctx, gen)
|
||||
if err != nil {
|
||||
return ssa.Undefined, false, err
|
||||
}
|
||||
if !ok {
|
||||
return ssa.Undefined, false, nil
|
||||
}
|
||||
|
||||
switch cv.Type.Type {
|
||||
case types.TString:
|
||||
if cv.Type.Bits%8 != 0 {
|
||||
return ssa.Undefined, false,
|
||||
ctx.Errorf(ast.Expr, "invalid string length %v", cv.Type.Bits)
|
||||
}
|
||||
chars := cv.Type.Bits / 8
|
||||
et := typeInfo.ElementType
|
||||
if et.Bits != 8 || et.Type != types.TUint {
|
||||
return ssa.Undefined, false,
|
||||
ctx.Errorf(ast.Expr, "cast from %v to %v",
|
||||
cv.Type, ast.TypeInfo)
|
||||
}
|
||||
|
||||
if typeInfo.Concrete() {
|
||||
if typeInfo.ArraySize != chars || typeInfo.Bits != cv.Type.Bits {
|
||||
return ssa.Undefined, false,
|
||||
ctx.Errorf(ast.Expr, "cast from %v to %v",
|
||||
cv.Type, ast.TypeInfo)
|
||||
}
|
||||
} else {
|
||||
typeInfo.Bits = cv.Type.Bits
|
||||
typeInfo.ArraySize = chars
|
||||
typeInfo.SetConcrete(true)
|
||||
}
|
||||
cast := cv
|
||||
cast.Type = typeInfo
|
||||
if cv.HashCode() != cast.HashCode() {
|
||||
panic("const array cast changes value HashCode")
|
||||
}
|
||||
if !cv.Equal(&cast) {
|
||||
panic("const array cast changes value equality")
|
||||
}
|
||||
return cast, true, nil
|
||||
}
|
||||
|
||||
return ssa.Undefined, false, nil
|
||||
}
|
||||
|
||||
// Eval implements the compiler.ast.AST.Eval for return statements.
|
||||
func (ast *Return) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
|
||||
ssa.Value, bool, error) {
|
||||
return ssa.Undefined, false, nil
|
||||
}
|
||||
|
||||
// Eval implements the compiler.ast.AST.Eval for for statements.
|
||||
func (ast *For) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
|
||||
ssa.Value, bool, error) {
|
||||
return ssa.Undefined, false, nil
|
||||
}
|
||||
|
||||
// Eval implements the compiler.ast.AST.Eval.
|
||||
func (ast *ForRange) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
|
||||
ssa.Value, bool, error) {
|
||||
return ssa.Undefined, false, nil
|
||||
}
|
||||
|
||||
// Eval implements the compiler.ast.AST.Eval for binary expressions.
|
||||
func (ast *Binary) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
|
||||
ssa.Value, bool, error) {
|
||||
l, ok, err := ast.Left.Eval(env, ctx, gen)
|
||||
if err != nil || !ok {
|
||||
return ssa.Undefined, ok, err
|
||||
}
|
||||
r, ok, err := ast.Right.Eval(env, ctx, gen)
|
||||
if err != nil || !ok {
|
||||
return ssa.Undefined, ok, err
|
||||
}
|
||||
|
||||
if debugEval {
|
||||
fmt.Printf("%s: Binary.Eval: %v[%v:%v] %v %v[%v:%v]\n",
|
||||
ast.Location().ShortString(),
|
||||
ast.Left, l, l.Type, ast.Op, ast.Right, r, r.Type)
|
||||
}
|
||||
|
||||
// Resolve result type.
|
||||
rt, err := ast.resultType(ctx, l, r)
|
||||
if err != nil {
|
||||
return ssa.Undefined, false, err
|
||||
}
|
||||
|
||||
switch lval := l.ConstValue.(type) {
|
||||
case bool:
|
||||
rval, ok := r.ConstValue.(bool)
|
||||
if !ok {
|
||||
return ssa.Undefined, false, ctx.Errorf(ast.Right,
|
||||
"%s %v %s: invalid r-value %v (%T)", l, ast.Op, r, rval, rval)
|
||||
}
|
||||
switch ast.Op {
|
||||
case BinaryEq:
|
||||
return gen.Constant(lval == rval, types.Bool), true, nil
|
||||
case BinaryNeq:
|
||||
return gen.Constant(lval != rval, types.Bool), true, nil
|
||||
case BinaryAnd:
|
||||
return gen.Constant(lval && rval, types.Bool), true, nil
|
||||
case BinaryOr:
|
||||
return gen.Constant(lval || rval, types.Bool), true, nil
|
||||
}
|
||||
|
||||
case *mpa.Int:
|
||||
rval, ok := r.ConstValue.(*mpa.Int)
|
||||
if !ok {
|
||||
return ssa.Undefined, false, ctx.Errorf(ast.Right,
|
||||
"%s %v %s: invalid r-value %v (%T)", l, ast.Op, r, rval, rval)
|
||||
}
|
||||
switch ast.Op {
|
||||
case BinaryMul:
|
||||
return gen.Constant(mpa.New(rt.Bits).Mul(lval, rval), rt),
|
||||
true, nil
|
||||
case BinaryDiv:
|
||||
return gen.Constant(mpa.New(rt.Bits).Div(lval, rval), rt),
|
||||
true, nil
|
||||
case BinaryMod:
|
||||
return gen.Constant(mpa.New(rt.Bits).Mod(lval, rval), rt),
|
||||
true, nil
|
||||
case BinaryLshift:
|
||||
return gen.Constant(mpa.New(rt.Bits).Lsh(lval, uint(rval.Int64())),
|
||||
rt), true, nil
|
||||
case BinaryRshift:
|
||||
return gen.Constant(mpa.New(rt.Bits).Rsh(lval, uint(rval.Int64())),
|
||||
rt), true, nil
|
||||
case BinaryBand:
|
||||
return gen.Constant(mpa.New(rt.Bits).And(lval, rval), rt),
|
||||
true, nil
|
||||
case BinaryBclear:
|
||||
return gen.Constant(mpa.New(rt.Bits).AndNot(lval, rval), rt),
|
||||
true, nil
|
||||
case BinaryBor:
|
||||
return gen.Constant(mpa.New(rt.Bits).Or(lval, rval), rt),
|
||||
true, nil
|
||||
case BinaryBxor:
|
||||
return gen.Constant(mpa.New(rt.Bits).Xor(lval, rval), rt),
|
||||
true, nil
|
||||
case BinaryAdd:
|
||||
return gen.Constant(mpa.New(rt.Bits).Add(lval, rval), rt),
|
||||
true, nil
|
||||
case BinarySub:
|
||||
return gen.Constant(mpa.New(rt.Bits).Sub(lval, rval), rt),
|
||||
true, nil
|
||||
case BinaryEq:
|
||||
return gen.Constant(lval.Cmp(rval) == 0, types.Bool), true, nil
|
||||
case BinaryNeq:
|
||||
return gen.Constant(lval.Cmp(rval) != 0, types.Bool), true, nil
|
||||
case BinaryLt:
|
||||
return gen.Constant(lval.Cmp(rval) == -1, types.Bool), true, nil
|
||||
case BinaryLe:
|
||||
return gen.Constant(lval.Cmp(rval) != 1, types.Bool), true, nil
|
||||
case BinaryGt:
|
||||
return gen.Constant(lval.Cmp(rval) == 1, types.Bool), true, nil
|
||||
case BinaryGe:
|
||||
return gen.Constant(lval.Cmp(rval) != -1, types.Bool), true, nil
|
||||
}
|
||||
|
||||
case string:
|
||||
rval, ok := r.ConstValue.(string)
|
||||
if !ok {
|
||||
return ssa.Undefined, false, ctx.Errorf(ast.Right,
|
||||
"%s %v %s: invalid r-value %v (%T)", l, ast.Op, r, rval, rval)
|
||||
}
|
||||
switch ast.Op {
|
||||
case BinaryAdd:
|
||||
return gen.Constant(lval+rval, types.Undefined), true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return ssa.Undefined, false, ctx.Errorf(ast.Right,
|
||||
"invalid operation: operator %s not defined on %v (%v)",
|
||||
ast.Op, l, l.Type)
|
||||
}
|
||||
|
||||
// Eval implements the compiler.ast.AST.Eval for unary expressions.
|
||||
func (ast *Unary) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
|
||||
ssa.Value, bool, error) {
|
||||
expr, ok, err := ast.Expr.Eval(env, ctx, gen)
|
||||
if err != nil || !ok {
|
||||
return ssa.Undefined, ok, err
|
||||
}
|
||||
switch val := expr.ConstValue.(type) {
|
||||
case bool:
|
||||
switch ast.Type {
|
||||
case UnaryNot:
|
||||
return gen.Constant(!val, types.Bool), true, nil
|
||||
}
|
||||
case *mpa.Int:
|
||||
switch ast.Type {
|
||||
case UnaryMinus:
|
||||
r := mpa.NewInt(0, expr.Type.Bits)
|
||||
return gen.Constant(r.Sub(r, val), expr.Type), true, nil
|
||||
}
|
||||
}
|
||||
return ssa.Undefined, false, ctx.Errorf(ast.Expr,
|
||||
"invalid unary expression: %s%T", ast.Type, ast.Expr)
|
||||
}
|
||||
|
||||
// Eval implements the compiler.ast.AST.Eval for slice expressions.
|
||||
func (ast *Slice) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
|
||||
ssa.Value, bool, error) {
|
||||
|
||||
expr, ok, err := ast.Expr.Eval(env, ctx, gen)
|
||||
if err != nil || !ok {
|
||||
return ssa.Undefined, ok, err
|
||||
}
|
||||
|
||||
from := 0
|
||||
to := math.MaxInt32
|
||||
|
||||
if ast.From != nil {
|
||||
val, ok, err := ast.From.Eval(env, ctx, gen)
|
||||
if err != nil || !ok {
|
||||
return ssa.Undefined, ok, err
|
||||
}
|
||||
from, err = intVal(val)
|
||||
if err != nil {
|
||||
return ssa.Undefined, false, ctx.Errorf(ast.From, err.Error())
|
||||
}
|
||||
}
|
||||
if ast.To != nil {
|
||||
val, ok, err := ast.To.Eval(env, ctx, gen)
|
||||
if err != nil || !ok {
|
||||
return ssa.Undefined, ok, err
|
||||
}
|
||||
to, err = intVal(val)
|
||||
if err != nil {
|
||||
return ssa.Undefined, false, ctx.Errorf(ast.To, err.Error())
|
||||
}
|
||||
}
|
||||
if to < from {
|
||||
return ssa.Undefined, false, ctx.Errorf(ast.Expr,
|
||||
"invalid slice range %d:%d", from, to)
|
||||
}
|
||||
if !expr.Type.Type.Array() {
|
||||
return ssa.Undefined, false, ctx.Errorf(ast.Expr,
|
||||
"invalid operation: cannot slice %v (%v)", expr, expr.Type)
|
||||
}
|
||||
arr, err := expr.ConstArray()
|
||||
if err != nil {
|
||||
return ssa.Undefined, false, err
|
||||
}
|
||||
if to == math.MaxInt32 {
|
||||
to = int(expr.Type.ArraySize)
|
||||
}
|
||||
if to > int(expr.Type.ArraySize) || from > to {
|
||||
return ssa.Undefined, false, ctx.Errorf(ast.From,
|
||||
"slice bounds out of range [%d:%d] in slice of length %v",
|
||||
from, to, expr.Type.ArraySize)
|
||||
}
|
||||
numElements := to - from
|
||||
|
||||
switch val := arr.(type) {
|
||||
case []interface{}:
|
||||
ti := expr.Type
|
||||
ti.ArraySize = types.Size(numElements)
|
||||
// The gen.Constant will set the bit sizes.
|
||||
return gen.Constant(val[from:to], ti), true, nil
|
||||
|
||||
case []byte:
|
||||
constVal := make([]interface{}, numElements)
|
||||
for i := 0; i < numElements; i++ {
|
||||
constVal[i] = int64(val[from+i])
|
||||
}
|
||||
ti := expr.Type
|
||||
ti.ArraySize = types.Size(numElements)
|
||||
// The gen.Constant will set the bit sizes.
|
||||
return gen.Constant(constVal, ti), true, nil
|
||||
|
||||
default:
|
||||
return ssa.Undefined, false, ctx.Errorf(ast.Expr,
|
||||
"invalid operation: cannot slice %T array", arr)
|
||||
}
|
||||
}
|
||||
|
||||
func intVal(val interface{}) (int, error) {
|
||||
switch v := val.(type) {
|
||||
case *mpa.Int:
|
||||
return int(v.Int64()), nil
|
||||
|
||||
case ssa.Value:
|
||||
if !v.Const {
|
||||
return 0, fmt.Errorf("non-const slice index: %v", v)
|
||||
}
|
||||
return intVal(v.ConstValue)
|
||||
|
||||
default:
|
||||
return 0, fmt.Errorf("invalid slice index: %T", v)
|
||||
}
|
||||
}
|
||||
|
||||
// Eval implements the compiler.ast.AST.Eval() for index expressions.
|
||||
func (ast *Index) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
|
||||
ssa.Value, bool, error) {
|
||||
|
||||
expr, ok, err := ast.Expr.Eval(env, ctx, gen)
|
||||
if err != nil || !ok {
|
||||
return ssa.Undefined, ok, err
|
||||
}
|
||||
|
||||
val, ok, err := ast.Index.Eval(env, ctx, gen)
|
||||
if err != nil || !ok {
|
||||
return ssa.Undefined, ok, err
|
||||
}
|
||||
index, err := intVal(val)
|
||||
if err != nil {
|
||||
return ssa.Undefined, false, ctx.Errorf(ast.Index, err.Error())
|
||||
}
|
||||
|
||||
switch expr.Type.Type {
|
||||
case types.TArray, types.TSlice:
|
||||
if index < 0 || index >= int(expr.Type.ArraySize) {
|
||||
return ssa.Undefined, false, ctx.Errorf(ast.Index,
|
||||
"invalid array index %d (out of bounds for %d-element array)",
|
||||
index, expr.Type.ArraySize)
|
||||
}
|
||||
arr, err := expr.ConstArray()
|
||||
if err != nil {
|
||||
return ssa.Undefined, false, err
|
||||
}
|
||||
switch val := arr.(type) {
|
||||
case []interface{}:
|
||||
return gen.Constant(val[index], *expr.Type.ElementType), true, nil
|
||||
|
||||
case []byte:
|
||||
return gen.Constant(int64(val[index]), *expr.Type.ElementType),
|
||||
true, nil
|
||||
}
|
||||
|
||||
case types.TString:
|
||||
numBytes := expr.Type.Bits / types.ByteBits
|
||||
if index < 0 || index >= int(numBytes) {
|
||||
return ssa.Undefined, false, ctx.Errorf(ast.Index,
|
||||
"invalid array index %d (out of bounds for %d-element string)",
|
||||
index, numBytes)
|
||||
}
|
||||
str, err := expr.ConstString()
|
||||
if err != nil {
|
||||
return ssa.Undefined, false, err
|
||||
}
|
||||
bytes := []byte(str)
|
||||
return gen.Constant(int64(bytes[index]), types.Byte), true, nil
|
||||
}
|
||||
|
||||
return ssa.Undefined, false, ctx.Errorf(ast.Expr,
|
||||
"invalid operation: cannot index %v (%v)", expr, expr.Type)
|
||||
}
|
||||
|
||||
// Eval implements the compiler.ast.AST.Eval for variable references.
|
||||
func (ast *VariableRef) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
|
||||
ssa.Value, bool, error) {
|
||||
|
||||
lrv, ok, _, err := ctx.LookupVar(nil, gen, env.Bindings, ast)
|
||||
if err != nil {
|
||||
return ssa.Undefined, false, ctx.Error(ast, err.Error())
|
||||
}
|
||||
if !ok {
|
||||
return ssa.Undefined, ok, nil
|
||||
}
|
||||
|
||||
return lrv.ConstValue()
|
||||
}
|
||||
|
||||
// Eval implements the compiler.ast.AST.Eval for constant values.
|
||||
func (ast *BasicLit) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
|
||||
ssa.Value, bool, error) {
|
||||
return gen.Constant(ast.Value, types.Undefined), true, nil
|
||||
}
|
||||
|
||||
// Eval implements the compiler.ast.AST.Eval for constant values.
|
||||
func (ast *CompositeLit) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
|
||||
ssa.Value, bool, error) {
|
||||
|
||||
// XXX the init values might be short so we must pad them with
|
||||
// zero values so that we create correctly sized values.
|
||||
|
||||
typeInfo, err := ast.Type.Resolve(env, ctx, gen)
|
||||
if err != nil {
|
||||
return ssa.Undefined, false, err
|
||||
}
|
||||
switch typeInfo.Type {
|
||||
case types.TStruct:
|
||||
// Check if all elements are constants.
|
||||
var values []interface{}
|
||||
for _, el := range ast.Value {
|
||||
// XXX check if el.Key is specified
|
||||
|
||||
v, ok, err := el.Element.Eval(env, ctx, gen)
|
||||
if err != nil || !ok {
|
||||
return ssa.Undefined, ok, err
|
||||
}
|
||||
// XXX check that v is assignment compatible with typeInfo.Struct[i]
|
||||
values = append(values, v)
|
||||
}
|
||||
return gen.Constant(values, typeInfo), true, nil
|
||||
|
||||
case types.TArray, types.TSlice:
|
||||
// Check if all elements are constants.
|
||||
var values []interface{}
|
||||
for _, el := range ast.Value {
|
||||
// XXX check if el.Key is specified
|
||||
|
||||
v, ok, err := el.Element.Eval(env, ctx, gen)
|
||||
if err != nil || !ok {
|
||||
return ssa.Undefined, ok, err
|
||||
}
|
||||
// XXX check that v is assignment compatible with array.
|
||||
values = append(values, v)
|
||||
}
|
||||
typeInfo.ArraySize = types.Size(len(values))
|
||||
typeInfo.Bits = typeInfo.ArraySize * typeInfo.ElementType.Bits
|
||||
typeInfo.MinBits = typeInfo.Bits
|
||||
return gen.Constant(values, typeInfo), true, nil
|
||||
|
||||
default:
|
||||
fmt.Printf("CompositeLit.Eval: not implemented yet: %v, Value: %v\n",
|
||||
typeInfo, ast.Value)
|
||||
return ssa.Undefined, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Eval implements the compiler.ast.AST.Eval for the builtin function make.
|
||||
func (ast *Make) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
|
||||
ssa.Value, bool, error) {
|
||||
if len(ast.Exprs) != 1 {
|
||||
return ssa.Undefined, false, ctx.Errorf(ast,
|
||||
"invalid amount of argument in call to make")
|
||||
}
|
||||
typeInfo, err := ast.Type.Resolve(env, ctx, gen)
|
||||
if err != nil {
|
||||
return ssa.Undefined, false, ctx.Errorf(ast.Type, "%s is not a type",
|
||||
ast.Type)
|
||||
}
|
||||
if typeInfo.Type.Array() {
|
||||
// Arrays are made in Make.SSA.
|
||||
return ssa.Undefined, false, nil
|
||||
}
|
||||
if typeInfo.Bits != 0 {
|
||||
return ssa.Undefined, false, ctx.Errorf(ast.Type,
|
||||
"can't make specified type %s", typeInfo)
|
||||
}
|
||||
constVal, _, err := ast.Exprs[0].Eval(env, ctx, gen)
|
||||
if err != nil {
|
||||
return ssa.Undefined, false, ctx.Error(ast.Exprs[0], err.Error())
|
||||
}
|
||||
length, err := constVal.ConstInt()
|
||||
if err != nil {
|
||||
return ssa.Undefined, false, ctx.Errorf(ast.Exprs[0],
|
||||
"non-integer (%T) len argument in %s: %s", constVal, ast, err)
|
||||
}
|
||||
|
||||
typeInfo.IsConcrete = true
|
||||
typeInfo.Bits = length
|
||||
|
||||
// Create typeref constant.
|
||||
return gen.Constant(typeInfo, types.Undefined), true, nil
|
||||
}
|
||||
|
||||
// Eval implements the compiler.ast.AST.Eval for the builtin function copy.
|
||||
func (ast *Copy) Eval(env *Env, ctx *Codegen, gen *ssa.Generator) (
|
||||
ssa.Value, bool, error) {
|
||||
return ssa.Undefined, false, nil
|
||||
}
|
||||
367
bedlam/compiler/ast/lrvalue.go
Normal file
367
bedlam/compiler/ast/lrvalue.go
Normal file
@ -0,0 +1,367 @@
|
||||
//
|
||||
// ast.go
|
||||
//
|
||||
// Copyright (c) 2019-2024 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package ast
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/ssa"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
|
||||
)
|
||||
|
||||
// LRValue implements value as l-value or r-value. The LRValues have
|
||||
// two types:
|
||||
//
|
||||
// 1. base type that specifies the base memory location containing
|
||||
// the value
|
||||
// 2. value type that specifies the wires of the value
|
||||
//
|
||||
// Both types can be the same. The base and value types are set as
|
||||
// follows for different names:
|
||||
//
|
||||
// 1. Struct.Field:
|
||||
// - baseInfo points to the containing variable
|
||||
// - baseValue is the Struct
|
||||
// - structField defines the structure field
|
||||
// - valueType is the type of the Struct.Field
|
||||
// - value is nil
|
||||
//
|
||||
// 2. Package.Name:
|
||||
// - baseInfo points to the containing variable in Package
|
||||
// - baseValue is the value of Package.Name
|
||||
// - structField is nil
|
||||
// - valueType is the type of Package.Name
|
||||
// - value is the value of Package.Name
|
||||
//
|
||||
// 3. Name:
|
||||
// - baseInfo points to the containing local variable
|
||||
// - baseValue is the value of Name
|
||||
// - structField is nil
|
||||
// - valueType is the type of Name
|
||||
// - value is the value of Name
|
||||
type LRValue struct {
|
||||
ctx *Codegen
|
||||
ast AST
|
||||
block *ssa.Block
|
||||
gen *ssa.Generator
|
||||
baseInfo *ssa.PtrInfo
|
||||
baseValue ssa.Value
|
||||
valueType types.Info
|
||||
value ssa.Value
|
||||
structField *types.StructField
|
||||
}
|
||||
|
||||
func (lrv LRValue) String() string {
|
||||
offset := lrv.baseInfo.Offset + lrv.valueType.Offset
|
||||
return fmt.Sprintf("%s[%d-%d]@%s{%d}%s/%v",
|
||||
lrv.valueType, offset, offset+lrv.valueType.Bits,
|
||||
lrv.baseInfo.Name, lrv.baseInfo.Scope, lrv.baseInfo.ContainerType,
|
||||
lrv.baseInfo.ContainerType.Bits)
|
||||
}
|
||||
|
||||
// BaseType returns the base type of the LRValue.
|
||||
func (lrv *LRValue) BaseType() types.Info {
|
||||
return lrv.baseInfo.ContainerType
|
||||
}
|
||||
|
||||
// BaseValue returns the base value of the LRValue
|
||||
func (lrv *LRValue) BaseValue() ssa.Value {
|
||||
return lrv.baseValue
|
||||
}
|
||||
|
||||
// BasePtrInfo returns the base value as PtrInfo.
|
||||
func (lrv *LRValue) BasePtrInfo() *ssa.PtrInfo {
|
||||
return lrv.baseInfo
|
||||
}
|
||||
|
||||
// Indirect returns LRValue for the value that lrv points to. If lrv
|
||||
// is not a pointer, Indirect returns lrv.
|
||||
func (lrv *LRValue) Indirect() *LRValue {
|
||||
v := lrv.RValue()
|
||||
if v.Type.Type != types.TPtr {
|
||||
return lrv
|
||||
}
|
||||
|
||||
ret := *lrv
|
||||
ret.valueType = *lrv.valueType.ElementType
|
||||
ret.value.PtrInfo = nil
|
||||
|
||||
if lrv.baseInfo.ContainerType.Type == types.TStruct {
|
||||
// Set value to undefined so RValue() can regenerate it.
|
||||
ret.value.Type = types.Undefined
|
||||
|
||||
// Lookup struct field.
|
||||
ret.structField = nil
|
||||
for idx, f := range lrv.baseValue.Type.Struct {
|
||||
if f.Type.Offset == lrv.baseInfo.Offset {
|
||||
ret.structField = &lrv.baseValue.Type.Struct[idx]
|
||||
break
|
||||
}
|
||||
}
|
||||
if ret.structField == nil {
|
||||
panic("LRValue.Indirect: could not find struct field")
|
||||
}
|
||||
} else {
|
||||
ret.value.Type = *lrv.value.Type.ElementType
|
||||
ret.value = lrv.baseValue
|
||||
}
|
||||
|
||||
return &ret
|
||||
}
|
||||
|
||||
// Set sets the l-value to rv.
|
||||
func (lrv LRValue) Set(rv ssa.Value) error {
|
||||
if !ssa.CanAssign(lrv.valueType, rv) {
|
||||
return fmt.Errorf("cannot assing %v to variable of type %v",
|
||||
rv.Type, lrv.valueType)
|
||||
}
|
||||
lValue := lrv.LValue()
|
||||
|
||||
if lrv.structField != nil {
|
||||
fromConst := lrv.gen.Constant(int64(lrv.structField.Type.Offset),
|
||||
types.Undefined)
|
||||
toConst := lrv.gen.Constant(int64(lrv.structField.Type.Offset+
|
||||
lrv.structField.Type.Bits), types.Undefined)
|
||||
|
||||
lrv.block.AddInstr(ssa.NewAmovInstr(rv, lrv.baseValue,
|
||||
fromConst, toConst, lValue))
|
||||
return lrv.baseInfo.Bindings.Set(lValue, nil)
|
||||
}
|
||||
|
||||
if rv.Const && rv.IntegerLike() {
|
||||
// Type coersions rules for const int r-values.
|
||||
if lValue.Type.Concrete() {
|
||||
rv.Type = lValue.Type
|
||||
} else if rv.Type.Concrete() {
|
||||
lValue.Type = rv.Type
|
||||
} else {
|
||||
return fmt.Errorf("unspecified size for type %v", rv.Type)
|
||||
}
|
||||
} else if rv.Type.Concrete() {
|
||||
// Specifying the value of an unspecified variable, or
|
||||
// specializing it (assining arrays with values of different
|
||||
// size).
|
||||
lValue.Type = rv.Type
|
||||
} else if lValue.Type.Concrete() {
|
||||
// Specializing r-value.
|
||||
rv.Type = lValue.Type
|
||||
} else {
|
||||
return fmt.Errorf("unspecified size for type %v", rv.Type)
|
||||
}
|
||||
lrv.block.AddInstr(ssa.NewMovInstr(rv, lValue))
|
||||
|
||||
// The l-value and r-value types are now resolved. Let's define
|
||||
// the variable with correct type and value information,
|
||||
// overriding any old values.
|
||||
lrv.baseInfo.Bindings.Define(lValue, &rv)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LValue returns the l-value of the LRValue.
|
||||
func (lrv *LRValue) LValue() ssa.Value {
|
||||
return lrv.gen.NewVal(lrv.baseInfo.Name, lrv.baseInfo.ContainerType,
|
||||
lrv.baseInfo.Scope)
|
||||
}
|
||||
|
||||
// RValue returns the r-value of the LRValue.
|
||||
func (lrv *LRValue) RValue() ssa.Value {
|
||||
if lrv.value.Type.Undefined() && lrv.structField != nil {
|
||||
fieldType := lrv.valueType
|
||||
fieldType.Offset = 0
|
||||
|
||||
lrv.value = lrv.gen.AnonVal(fieldType)
|
||||
|
||||
from := int64(lrv.valueType.Offset)
|
||||
to := int64(lrv.valueType.Offset + lrv.valueType.Bits)
|
||||
|
||||
if to > from {
|
||||
fromConst := lrv.gen.Constant(from, types.Undefined)
|
||||
toConst := lrv.gen.Constant(to, types.Undefined)
|
||||
lrv.block.AddInstr(ssa.NewSliceInstr(lrv.baseValue, fromConst,
|
||||
toConst, lrv.value))
|
||||
}
|
||||
}
|
||||
return lrv.value
|
||||
}
|
||||
|
||||
// ValueType returns the value type of the LRValue.
|
||||
func (lrv *LRValue) ValueType() types.Info {
|
||||
return lrv.valueType
|
||||
}
|
||||
|
||||
func (lrv *LRValue) ptrBaseValue() (ssa.Value, error) {
|
||||
b, ok := lrv.baseInfo.Bindings.Get(lrv.baseInfo.Name)
|
||||
if !ok {
|
||||
return ssa.Undefined, fmt.Errorf("undefined: %s", lrv.baseInfo.Name)
|
||||
}
|
||||
return b.Value(lrv.block, lrv.gen), nil
|
||||
}
|
||||
|
||||
// ConstValue returns the constant value of the LRValue if available.
|
||||
func (lrv *LRValue) ConstValue() (ssa.Value, bool, error) {
|
||||
switch lrv.value.Type.Type {
|
||||
case types.TUndefined:
|
||||
return lrv.value, false, nil
|
||||
|
||||
case types.TBool, types.TInt, types.TUint, types.TFloat, types.TString,
|
||||
types.TStruct, types.TArray, types.TSlice, types.TNil:
|
||||
return lrv.value, true, nil
|
||||
|
||||
default:
|
||||
return ssa.Undefined, false, lrv.ctx.Errorf(lrv.ast,
|
||||
"LRValue.ConstValue: %s not supported yet: %v",
|
||||
lrv.value.Type, lrv.value)
|
||||
}
|
||||
}
|
||||
|
||||
// LookupVar resolves the named variable from the context.
|
||||
func (ctx *Codegen) LookupVar(block *ssa.Block, gen *ssa.Generator,
|
||||
bindings *ssa.Bindings, ref *VariableRef) (
|
||||
lrv *LRValue, cf, df bool, err error) {
|
||||
|
||||
lrv = &LRValue{
|
||||
ctx: ctx,
|
||||
ast: ref,
|
||||
block: block,
|
||||
gen: gen,
|
||||
}
|
||||
|
||||
var env *ssa.Bindings
|
||||
var b ssa.Binding
|
||||
var ok bool
|
||||
|
||||
if len(ref.Name.Package) > 0 {
|
||||
// Check if package name is bound to a value.
|
||||
b, ok = bindings.Get(ref.Name.Package)
|
||||
if ok {
|
||||
env = bindings
|
||||
} else {
|
||||
// Check names in the current package.
|
||||
b, ok = ctx.Package.Bindings.Get(ref.Name.Package)
|
||||
if ok {
|
||||
env = ctx.Package.Bindings
|
||||
}
|
||||
}
|
||||
if ok {
|
||||
if block != nil {
|
||||
lrv.baseValue = b.Value(block, gen)
|
||||
} else {
|
||||
// Evaluating a const value.
|
||||
v, ok := b.Bound.(*ssa.Value)
|
||||
if !ok || !v.Const {
|
||||
// Value is not const
|
||||
return nil, false, false, nil
|
||||
}
|
||||
lrv.baseValue = *v
|
||||
}
|
||||
|
||||
if lrv.baseValue.Type.Type == types.TPtr {
|
||||
lrv.baseInfo = lrv.baseValue.PtrInfo
|
||||
lrv.baseValue, err = lrv.ptrBaseValue()
|
||||
if err != nil {
|
||||
return nil, false, false, err
|
||||
}
|
||||
} else {
|
||||
lrv.baseInfo = &ssa.PtrInfo{
|
||||
Name: ref.Name.Package,
|
||||
Bindings: env,
|
||||
Scope: b.Scope,
|
||||
ContainerType: b.Type,
|
||||
}
|
||||
}
|
||||
|
||||
if lrv.baseValue.Type.Type != types.TStruct {
|
||||
return nil, false, false, fmt.Errorf("%s undefined", ref.Name)
|
||||
}
|
||||
|
||||
for _, f := range lrv.baseValue.Type.Struct {
|
||||
if f.Name == ref.Name.Name {
|
||||
lrv.structField = &f
|
||||
break
|
||||
}
|
||||
}
|
||||
if lrv.structField == nil {
|
||||
return nil, false, false, fmt.Errorf(
|
||||
"%s undefined (type %s has no field or method %s)",
|
||||
ref.Name, lrv.baseValue.Type, ref.Name.Name)
|
||||
}
|
||||
lrv.valueType = lrv.structField.Type
|
||||
|
||||
return lrv, true, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Explicit package references.
|
||||
var pkg *Package
|
||||
if len(ref.Name.Package) > 0 {
|
||||
pkg, ok = ctx.Packages[ref.Name.Package]
|
||||
if !ok {
|
||||
return nil, false, false, fmt.Errorf("package '%s' not found",
|
||||
ref.Name.Package)
|
||||
}
|
||||
env = pkg.Bindings
|
||||
b, ok = env.Get(ref.Name.Name)
|
||||
if !ok {
|
||||
return nil, false, false, fmt.Errorf("undefined variable '%s'",
|
||||
ref.Name)
|
||||
}
|
||||
} else {
|
||||
// Check block bindings.
|
||||
env = bindings
|
||||
b, ok = env.Get(ref.Name.Name)
|
||||
if !ok {
|
||||
// Check names in the name's package.
|
||||
if len(ref.Name.Defined) > 0 {
|
||||
pkg, ok = ctx.Packages[ref.Name.Defined]
|
||||
if !ok {
|
||||
return nil, false, false,
|
||||
fmt.Errorf("package '%s' not found", ref.Name.Defined)
|
||||
}
|
||||
env = pkg.Bindings
|
||||
b, ok = env.Get(ref.Name.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return nil, false, true, fmt.Errorf("undefined variable '%s'",
|
||||
ref.Name)
|
||||
}
|
||||
|
||||
if block != nil {
|
||||
lrv.value = b.Value(block, gen)
|
||||
} else {
|
||||
// Evaluating const value.
|
||||
v, ok := b.Bound.(*ssa.Value)
|
||||
if !ok || !v.Const {
|
||||
// Value is not const
|
||||
return nil, false, false, nil
|
||||
}
|
||||
lrv.value = *v
|
||||
}
|
||||
lrv.valueType = lrv.value.Type
|
||||
|
||||
if lrv.value.Type.Type == types.TPtr {
|
||||
lrv.baseInfo = lrv.value.PtrInfo
|
||||
lrv.baseValue, err = lrv.ptrBaseValue()
|
||||
if err != nil {
|
||||
return nil, false, false, err
|
||||
}
|
||||
} else {
|
||||
lrv.baseInfo = &ssa.PtrInfo{
|
||||
Name: ref.Name.Name,
|
||||
Bindings: env,
|
||||
Scope: b.Scope,
|
||||
ContainerType: b.Type,
|
||||
}
|
||||
lrv.baseValue = lrv.value
|
||||
}
|
||||
|
||||
return lrv, true, false, nil
|
||||
}
|
||||
376
bedlam/compiler/ast/package.go
Normal file
376
bedlam/compiler/ast/package.go
Normal file
@ -0,0 +1,376 @@
|
||||
//
|
||||
// Copyright (c) 2020-2024 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package ast
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/ssa"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/utils"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
|
||||
)
|
||||
|
||||
// Package implements a QCL package.
|
||||
type Package struct {
|
||||
Name string
|
||||
Source string
|
||||
Annotations Annotations
|
||||
Initialized bool
|
||||
Imports map[string]string
|
||||
Bindings *ssa.Bindings
|
||||
Types []*TypeInfo
|
||||
Constants []*ConstantDef
|
||||
Variables []*VariableDef
|
||||
Functions map[string]*Func
|
||||
}
|
||||
|
||||
// NewPackage creates a new package.
|
||||
func NewPackage(name, source string, annotations Annotations) *Package {
|
||||
return &Package{
|
||||
Name: name,
|
||||
Source: source,
|
||||
Annotations: annotations,
|
||||
Imports: make(map[string]string),
|
||||
Bindings: new(ssa.Bindings),
|
||||
Functions: make(map[string]*Func),
|
||||
}
|
||||
}
|
||||
|
||||
// Compile compiles the package.
|
||||
func (pkg *Package) Compile(ctx *Codegen) (*ssa.Program, Annotations, error) {
|
||||
|
||||
main, err := pkg.Main()
|
||||
if err != nil {
|
||||
return nil, nil, ctx.Error(utils.Point{
|
||||
Source: pkg.Source,
|
||||
}, err.Error())
|
||||
}
|
||||
|
||||
gen := ssa.NewGenerator(ctx.Params)
|
||||
|
||||
// Init is the program start point.
|
||||
init := gen.Block()
|
||||
|
||||
// Init package.
|
||||
block, err := pkg.Init(ctx.Packages, init, ctx, gen)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Main block derives package's bindings from block with NextBlock().
|
||||
ctx.PushCompilation(gen.NextBlock(block), gen.Block(), nil, main)
|
||||
|
||||
// Arguments.
|
||||
var inputs circuit.IO
|
||||
for idx, arg := range main.Args {
|
||||
typeInfo, err := arg.Type.Resolve(NewEnv(ctx.Start()), ctx, gen)
|
||||
if err != nil {
|
||||
return nil, nil, ctx.Errorf(arg, "invalid argument type: %s", err)
|
||||
}
|
||||
if !typeInfo.Concrete() {
|
||||
if ctx.MainInputSizes == nil {
|
||||
return nil, nil,
|
||||
ctx.Errorf(arg, "argument %s of %s has unspecified type",
|
||||
arg.Name, main)
|
||||
}
|
||||
// Specify unspecified argument type.
|
||||
if idx >= len(ctx.MainInputSizes) {
|
||||
return nil, nil, ctx.Errorf(arg,
|
||||
"not enough values for argument %s of %s",
|
||||
arg.Name, main)
|
||||
}
|
||||
err = typeInfo.InstantiateWithSizes(ctx.MainInputSizes[idx])
|
||||
if err != nil {
|
||||
return nil, nil, ctx.Errorf(arg,
|
||||
"can't specify unspecified argument %s of %s: %s",
|
||||
arg.Name, main, err)
|
||||
}
|
||||
}
|
||||
// Define argument in block.
|
||||
a := gen.NewVal(arg.Name, typeInfo, ctx.Scope())
|
||||
ctx.Start().Bindings.Define(a, nil)
|
||||
|
||||
input := circuit.IOArg{
|
||||
Name: arg.Name,
|
||||
Type: a.Type,
|
||||
}
|
||||
if typeInfo.Type == types.TStruct {
|
||||
input.Compound = flattenStruct(typeInfo)
|
||||
}
|
||||
|
||||
inputs = append(inputs, input)
|
||||
}
|
||||
|
||||
// Compile main.
|
||||
_, returnVars, err := main.SSA(ctx.Start(), ctx, gen)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Return values
|
||||
var outputs circuit.IO
|
||||
for idx, rt := range main.Return {
|
||||
if idx >= len(returnVars) {
|
||||
return nil, nil, fmt.Errorf("too few values for %s", main)
|
||||
}
|
||||
typeInfo, err := rt.Type.Resolve(NewEnv(ctx.Start()), ctx, gen)
|
||||
if err != nil {
|
||||
return nil, nil, ctx.Errorf(rt, "invalid return type: %s", err)
|
||||
}
|
||||
// Instantiate result values for template functions.
|
||||
if !typeInfo.Concrete() && !typeInfo.Instantiate(returnVars[idx].Type) {
|
||||
return nil, nil, ctx.Errorf(main,
|
||||
"invalid value %v for return value %d of %s",
|
||||
returnVars[idx].Type, idx, main)
|
||||
}
|
||||
// The native() returns undefined values.
|
||||
if returnVars[idx].Type.Undefined() {
|
||||
returnVars[idx].Type.Type = typeInfo.Type
|
||||
}
|
||||
if !ssa.CanAssign(typeInfo, returnVars[idx]) {
|
||||
return nil, nil,
|
||||
ctx.Errorf(main, "invalid value %v for return value %d of %s",
|
||||
returnVars[idx].Type, idx, main)
|
||||
}
|
||||
|
||||
v := returnVars[idx]
|
||||
outputs = append(outputs, circuit.IOArg{
|
||||
Name: v.String(),
|
||||
Type: v.Type,
|
||||
})
|
||||
}
|
||||
|
||||
steps := init.Serialize()
|
||||
|
||||
program, err := ssa.NewProgram(ctx.Params, inputs, outputs, gen.Constants(),
|
||||
steps)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if false { // XXX Peephole liveness analysis is broken.
|
||||
err = program.Peephole()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
program.GC()
|
||||
|
||||
if ctx.Params.SSAOut != nil {
|
||||
program.PP(ctx.Params.SSAOut)
|
||||
}
|
||||
if ctx.Params.SSADotOut != nil {
|
||||
ssa.Dot(ctx.Params.SSADotOut, init)
|
||||
}
|
||||
|
||||
return program, main.Annotations, nil
|
||||
}
|
||||
|
||||
// Main returns package's main function.
|
||||
func (pkg *Package) Main() (*Func, error) {
|
||||
main, ok := pkg.Functions["main"]
|
||||
if !ok {
|
||||
return nil, errors.New("no main function defined")
|
||||
}
|
||||
return main, nil
|
||||
}
|
||||
|
||||
func flattenStruct(t types.Info) circuit.IO {
|
||||
var result circuit.IO
|
||||
if t.Type != types.TStruct {
|
||||
return result
|
||||
}
|
||||
|
||||
for _, f := range t.Struct {
|
||||
if f.Type.Type == types.TStruct {
|
||||
ios := flattenStruct(f.Type)
|
||||
result = append(result, ios...)
|
||||
} else {
|
||||
result = append(result, circuit.IOArg{
|
||||
Name: f.Name,
|
||||
Type: f.Type,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Init initializes the package.
|
||||
func (pkg *Package) Init(packages map[string]*Package, block *ssa.Block,
|
||||
ctx *Codegen, gen *ssa.Generator) (*ssa.Block, error) {
|
||||
|
||||
if pkg.Initialized {
|
||||
return block, nil
|
||||
}
|
||||
pkg.Initialized = true
|
||||
if ctx.Verbose {
|
||||
fmt.Printf("Initializing %s\n", pkg.Name)
|
||||
}
|
||||
|
||||
// Imported packages.
|
||||
for alias, name := range pkg.Imports {
|
||||
p, ok := packages[alias]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("imported and not used: \"%s\"", name)
|
||||
}
|
||||
var err error
|
||||
block, err = p.Init(packages, block, ctx, gen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Define constants.
|
||||
for _, def := range pkg.Constants {
|
||||
err := pkg.defineConstant(def, ctx, gen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Define types.
|
||||
for _, typeDef := range pkg.Types {
|
||||
err := pkg.defineType(typeDef, ctx, gen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Package initializer block.
|
||||
block = gen.NextBlock(block)
|
||||
block.Name = fmt.Sprintf(".%s", pkg.Name)
|
||||
|
||||
// Package sees only its bindings.
|
||||
block.Bindings = pkg.Bindings.Clone()
|
||||
|
||||
var err error
|
||||
|
||||
// Define variables.
|
||||
for _, def := range pkg.Variables {
|
||||
block, _, err = def.SSA(block, ctx, gen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
pkg.Bindings = block.Bindings
|
||||
|
||||
return block, nil
|
||||
}
|
||||
|
||||
func (pkg *Package) defineConstant(def *ConstantDef, ctx *Codegen,
|
||||
gen *ssa.Generator) error {
|
||||
|
||||
env := &Env{
|
||||
Bindings: pkg.Bindings,
|
||||
}
|
||||
|
||||
typeInfo, err := def.Type.Resolve(env, ctx, gen)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
constVal, ok, err := def.Init.Eval(env, ctx, gen)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return ctx.Errorf(def.Init, "init value is not constant")
|
||||
}
|
||||
constVar := gen.Constant(constVal, typeInfo)
|
||||
if typeInfo.Undefined() {
|
||||
typeInfo.Type = constVar.Type.Type
|
||||
}
|
||||
if !typeInfo.Concrete() {
|
||||
typeInfo.Bits = constVar.Type.Bits
|
||||
}
|
||||
if !typeInfo.CanAssignConst(constVar.Type) {
|
||||
return ctx.Errorf(def.Init,
|
||||
"invalid init value %s for type %s", constVar.Type, typeInfo)
|
||||
}
|
||||
|
||||
_, ok = pkg.Bindings.Get(def.Name)
|
||||
if ok {
|
||||
return ctx.Errorf(def, "constant %s already defined", def.Name)
|
||||
}
|
||||
lValue := constVar
|
||||
lValue.Name = def.Name
|
||||
pkg.Bindings.Define(lValue, &constVar)
|
||||
gen.AddConstant(constVal)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pkg *Package) defineType(def *TypeInfo, ctx *Codegen,
|
||||
gen *ssa.Generator) error {
|
||||
|
||||
_, ok := pkg.Bindings.Get(def.TypeName)
|
||||
if ok {
|
||||
return ctx.Errorf(def, "type %s already defined", def.TypeName)
|
||||
}
|
||||
env := &Env{
|
||||
Bindings: pkg.Bindings,
|
||||
}
|
||||
var info types.Info
|
||||
var err error
|
||||
|
||||
switch def.Type {
|
||||
case TypeStruct:
|
||||
// Construct compound type.
|
||||
var fields []types.StructField
|
||||
var bits types.Size
|
||||
var minBits types.Size
|
||||
var offset types.Size
|
||||
for _, field := range def.StructFields {
|
||||
info, err := field.Type.Resolve(env, ctx, gen)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
field := types.StructField{
|
||||
Name: field.Name,
|
||||
Type: info,
|
||||
}
|
||||
field.Type.Offset = offset
|
||||
fields = append(fields, field)
|
||||
|
||||
bits += info.Bits
|
||||
minBits += info.MinBits
|
||||
offset += info.Bits
|
||||
}
|
||||
info = types.Info{
|
||||
Type: types.TStruct,
|
||||
IsConcrete: true,
|
||||
Bits: bits,
|
||||
MinBits: minBits,
|
||||
Struct: fields,
|
||||
}
|
||||
|
||||
case TypeArray:
|
||||
info, err = def.Resolve(env, ctx, gen)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case TypeAlias:
|
||||
info, err = def.AliasType.Resolve(env, ctx, gen)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
default:
|
||||
return ctx.Errorf(def, "invalid type definition: %s", def)
|
||||
}
|
||||
|
||||
info.ID = ctx.DefineType(def)
|
||||
|
||||
v := gen.Constant(info, types.Undefined)
|
||||
lval := gen.NewVal(def.TypeName, info, ctx.Scope())
|
||||
pkg.Bindings.Define(lval, &v)
|
||||
|
||||
return nil
|
||||
}
|
||||
2302
bedlam/compiler/ast/ssagen.go
Normal file
2302
bedlam/compiler/ast/ssagen.go
Normal file
File diff suppressed because it is too large
Load Diff
101
bedlam/compiler/circuits/allocator.go
Normal file
101
bedlam/compiler/circuits/allocator.go
Normal file
@ -0,0 +1,101 @@
|
||||
//
|
||||
// Copyright (c) 2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuits
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
|
||||
)
|
||||
|
||||
var (
|
||||
sizeofWire = uint64(unsafe.Sizeof(Wire{}))
|
||||
sizeofGate = uint64(unsafe.Sizeof(Gate{}))
|
||||
)
|
||||
|
||||
// Allocator implements circuit wire and gate allocation.
|
||||
type Allocator struct {
|
||||
numWire uint64
|
||||
numWires uint64
|
||||
numGates uint64
|
||||
}
|
||||
|
||||
// NewAllocator creates a new circuit allocator.
|
||||
func NewAllocator() *Allocator {
|
||||
return new(Allocator)
|
||||
}
|
||||
|
||||
// Wire allocates a new Wire.
|
||||
func (alloc *Allocator) Wire() *Wire {
|
||||
alloc.numWire++
|
||||
w := new(Wire)
|
||||
w.Reset(UnassignedID)
|
||||
return w
|
||||
}
|
||||
|
||||
// Wires allocate an array of Wires.
|
||||
func (alloc *Allocator) Wires(bits types.Size) []*Wire {
|
||||
alloc.numWires += uint64(bits)
|
||||
|
||||
wires := make([]Wire, bits)
|
||||
result := make([]*Wire, bits)
|
||||
for i := 0; i < int(bits); i++ {
|
||||
w := &wires[i]
|
||||
w.id = UnassignedID
|
||||
result[i] = w
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// BinaryGate creates a new binary gate.
|
||||
func (alloc *Allocator) BinaryGate(op circuit.Operation, a, b, o *Wire) *Gate {
|
||||
alloc.numGates++
|
||||
gate := &Gate{
|
||||
Op: op,
|
||||
A: a,
|
||||
B: b,
|
||||
O: o,
|
||||
}
|
||||
a.AddOutput(gate)
|
||||
b.AddOutput(gate)
|
||||
o.SetInput(gate)
|
||||
|
||||
return gate
|
||||
}
|
||||
|
||||
// INVGate creates a new INV gate.
|
||||
func (alloc *Allocator) INVGate(i, o *Wire) *Gate {
|
||||
alloc.numGates++
|
||||
gate := &Gate{
|
||||
Op: circuit.INV,
|
||||
A: i,
|
||||
O: o,
|
||||
}
|
||||
i.AddOutput(gate)
|
||||
o.SetInput(gate)
|
||||
|
||||
return gate
|
||||
}
|
||||
|
||||
// Debug print debugging information about the circuit allocator.
|
||||
func (alloc *Allocator) Debug() {
|
||||
wireSize := circuit.FileSize(alloc.numWire * sizeofWire)
|
||||
wiresSize := circuit.FileSize(alloc.numWires * sizeofWire)
|
||||
gatesSize := circuit.FileSize(alloc.numGates * sizeofGate)
|
||||
|
||||
total := float64(wireSize + wiresSize + gatesSize)
|
||||
|
||||
fmt.Println("circuits.Allocator:")
|
||||
fmt.Printf(" wire : %9v %5s %5.2f%%\n",
|
||||
alloc.numWire, wireSize, float64(wireSize)/total*100.0)
|
||||
fmt.Printf(" wires: %9v %5s %5.2f%%\n",
|
||||
alloc.numWires, wiresSize, float64(wiresSize)/total*100.0)
|
||||
fmt.Printf(" gates: %9v %5s %5.2f%%\n",
|
||||
alloc.numGates, gatesSize, float64(gatesSize)/total*100.0)
|
||||
}
|
||||
96
bedlam/compiler/circuits/circ_adder.go
Normal file
96
bedlam/compiler/circuits/circ_adder.go
Normal file
@ -0,0 +1,96 @@
|
||||
//
|
||||
// circ_adder.go
|
||||
//
|
||||
// Copyright (c) 2019-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuits
|
||||
|
||||
import (
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
|
||||
)
|
||||
|
||||
// NewHalfAdder creates a half adder circuit.
|
||||
func NewHalfAdder(cc *Compiler, a, b, s, c *Wire) {
|
||||
// S = XOR(A, B)
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, a, b, s))
|
||||
|
||||
if c != nil {
|
||||
// C = AND(A, B)
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, a, b, c))
|
||||
}
|
||||
}
|
||||
|
||||
// NewFullAdder creates a full adder circuit
|
||||
func NewFullAdder(cc *Compiler, a, b, cin, s, cout *Wire) {
|
||||
w1 := cc.Calloc.Wire()
|
||||
w2 := cc.Calloc.Wire()
|
||||
w3 := cc.Calloc.Wire()
|
||||
|
||||
// s = a XOR b XOR cin
|
||||
// cout = cin XOR ((a XOR cin) AND (b XOR cin)).
|
||||
|
||||
// w1 = XOR(b, cin)
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, b, cin, w1))
|
||||
|
||||
// s = XOR(a, w1)
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, a, w1, s))
|
||||
|
||||
if cout != nil {
|
||||
// w2 = XOR(a, cin)
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, a, cin, w2))
|
||||
|
||||
// w3 = AND(w1, w2)
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, w1, w2, w3))
|
||||
|
||||
// cout = XOR(cin, w3)
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, cin, w3, cout))
|
||||
}
|
||||
}
|
||||
|
||||
// NewAdder creates a new adder circuit implementing z=x+y.
|
||||
func NewAdder(cc *Compiler, x, y, z []*Wire) error {
|
||||
x, y = cc.ZeroPad(x, y)
|
||||
if len(x) > len(z) {
|
||||
x = x[0:len(z)]
|
||||
y = y[0:len(z)]
|
||||
}
|
||||
|
||||
if len(x) == 1 {
|
||||
var cin *Wire
|
||||
if len(z) > 1 {
|
||||
cin = z[1]
|
||||
}
|
||||
NewHalfAdder(cc, x[0], y[0], z[0], cin)
|
||||
} else {
|
||||
cin := cc.Calloc.Wire()
|
||||
NewHalfAdder(cc, x[0], y[0], z[0], cin)
|
||||
|
||||
for i := 1; i < len(x); i++ {
|
||||
var cout *Wire
|
||||
if i+1 >= len(x) {
|
||||
if i+1 >= len(z) {
|
||||
// N+N=N, overflow, drop carry bit.
|
||||
cout = nil
|
||||
} else {
|
||||
cout = z[len(x)]
|
||||
}
|
||||
} else {
|
||||
cout = cc.Calloc.Wire()
|
||||
}
|
||||
|
||||
NewFullAdder(cc, x[i], y[i], cin, z[i], cout)
|
||||
|
||||
cin = cout
|
||||
}
|
||||
}
|
||||
|
||||
// Set all leftover bits to zero.
|
||||
for i := len(x) + 1; i < len(z); i++ {
|
||||
z[i] = cc.ZeroWire()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
65
bedlam/compiler/circuits/circ_binary.go
Normal file
65
bedlam/compiler/circuits/circ_binary.go
Normal file
@ -0,0 +1,65 @@
|
||||
//
|
||||
// Copyright (c) 2020-2021, 2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuits
|
||||
|
||||
import (
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
|
||||
)
|
||||
|
||||
// NewBinaryAND creates a new binary AND circuit implementing r=x&y
|
||||
func NewBinaryAND(cc *Compiler, x, y, r []*Wire) error {
|
||||
x, y = cc.ZeroPad(x, y)
|
||||
if len(r) < len(x) {
|
||||
x = x[0:len(r)]
|
||||
y = y[0:len(r)]
|
||||
}
|
||||
for i := 0; i < len(x); i++ {
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, x[i], y[i], r[i]))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewBinaryClear creates a new binary clear circuit implementing r=x&^y.
|
||||
func NewBinaryClear(cc *Compiler, x, y, r []*Wire) error {
|
||||
x, y = cc.ZeroPad(x, y)
|
||||
if len(r) < len(x) {
|
||||
x = x[0:len(r)]
|
||||
y = y[0:len(r)]
|
||||
}
|
||||
for i := 0; i < len(x); i++ {
|
||||
w := cc.Calloc.Wire()
|
||||
cc.INV(y[i], w)
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, x[i], w, r[i]))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewBinaryOR creates a new binary OR circuit implementing r=x|y.
|
||||
func NewBinaryOR(cc *Compiler, x, y, r []*Wire) error {
|
||||
x, y = cc.ZeroPad(x, y)
|
||||
if len(r) < len(x) {
|
||||
x = x[0:len(r)]
|
||||
y = y[0:len(r)]
|
||||
}
|
||||
for i := 0; i < len(x); i++ {
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.OR, x[i], y[i], r[i]))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewBinaryXOR creates a new binary XOR circuit implementing r=x^y.
|
||||
func NewBinaryXOR(cc *Compiler, x, y, r []*Wire) error {
|
||||
x, y = cc.ZeroPad(x, y)
|
||||
if len(r) < len(x) {
|
||||
x = x[0:len(r)]
|
||||
y = y[0:len(r)]
|
||||
}
|
||||
for i := 0; i < len(x); i++ {
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, x[i], y[i], r[i]))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
161
bedlam/compiler/circuits/circ_comparators.go
Normal file
161
bedlam/compiler/circuits/circ_comparators.go
Normal file
@ -0,0 +1,161 @@
|
||||
//
|
||||
// Copyright (c) 2020-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuits
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
|
||||
)
|
||||
|
||||
// comparator tests if x>y if cin=0, and x>=y if cin=1.
|
||||
func comparator(cc *Compiler, cin *Wire, x, y, r []*Wire) error {
|
||||
x, y = cc.ZeroPad(x, y)
|
||||
if len(r) != 1 {
|
||||
return fmt.Errorf("invalid lt comparator arguments: r=%d", len(r))
|
||||
}
|
||||
|
||||
for i := 0; i < len(x); i++ {
|
||||
w1 := cc.Calloc.Wire()
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.XNOR, cin, y[i], w1))
|
||||
w2 := cc.Calloc.Wire()
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, cin, x[i], w2))
|
||||
w3 := cc.Calloc.Wire()
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, w1, w2, w3))
|
||||
|
||||
var cout *Wire
|
||||
if i+1 < len(x) {
|
||||
cout = cc.Calloc.Wire()
|
||||
} else {
|
||||
cout = r[0]
|
||||
}
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, cin, w3, cout))
|
||||
cin = cout
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewGtComparator tests if x>y.
|
||||
func NewGtComparator(cc *Compiler, x, y, r []*Wire) error {
|
||||
return comparator(cc, cc.ZeroWire(), x, y, r)
|
||||
}
|
||||
|
||||
// NewGeComparator tests if x>=y.
|
||||
func NewGeComparator(cc *Compiler, x, y, r []*Wire) error {
|
||||
return comparator(cc, cc.OneWire(), x, y, r)
|
||||
}
|
||||
|
||||
// NewLtComparator tests if x<y.
|
||||
func NewLtComparator(cc *Compiler, x, y, r []*Wire) error {
|
||||
return comparator(cc, cc.ZeroWire(), y, x, r)
|
||||
}
|
||||
|
||||
// NewLeComparator tests if x<=y.
|
||||
func NewLeComparator(cc *Compiler, x, y, r []*Wire) error {
|
||||
return comparator(cc, cc.OneWire(), y, x, r)
|
||||
}
|
||||
|
||||
// NewNeqComparator tewsts if x!=y.
|
||||
func NewNeqComparator(cc *Compiler, x, y, r []*Wire) error {
|
||||
x, y = cc.ZeroPad(x, y)
|
||||
if len(r) != 1 {
|
||||
return fmt.Errorf("invalid neq comparator arguments: r=%d", len(r))
|
||||
}
|
||||
|
||||
if len(x) == 1 {
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, x[0], y[0], r[0]))
|
||||
return nil
|
||||
}
|
||||
|
||||
c := cc.Calloc.Wire()
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, x[0], y[0], c))
|
||||
|
||||
for i := 1; i < len(x); i++ {
|
||||
xor := cc.Calloc.Wire()
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, x[i], y[i], xor))
|
||||
|
||||
var out *Wire
|
||||
if i+1 >= len(x) {
|
||||
out = r[0]
|
||||
} else {
|
||||
out = cc.Calloc.Wire()
|
||||
}
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.OR, c, xor, out))
|
||||
c = out
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewEqComparator tests if x==y.
|
||||
func NewEqComparator(cc *Compiler, x, y, r []*Wire) error {
|
||||
if len(r) != 1 {
|
||||
return fmt.Errorf("invalid eq comparator arguments: r=%d", len(r))
|
||||
}
|
||||
|
||||
// w = x == y
|
||||
w := cc.Calloc.Wire()
|
||||
err := NewNeqComparator(cc, x, y, []*Wire{w})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// r = !w
|
||||
cc.INV(w, r[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewLogicalAND implements logical AND implementing r=x&y. The input
|
||||
// and output wires must be 1 bit wide.
|
||||
func NewLogicalAND(cc *Compiler, x, y, r []*Wire) error {
|
||||
if len(x) != 1 || len(y) != 1 || len(r) != 1 {
|
||||
return fmt.Errorf("invalid logical and arguments: x=%d, y=%d, r=%d",
|
||||
len(x), len(y), len(r))
|
||||
}
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, x[0], y[0], r[0]))
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewLogicalOR implements logical OR implementing r=x|y. The input
|
||||
// and output wires must be 1 bit wide.
|
||||
func NewLogicalOR(cc *Compiler, x, y, r []*Wire) error {
|
||||
if len(x) != 1 || len(y) != 1 || len(r) != 1 {
|
||||
return fmt.Errorf("invalid logical or arguments: x=%d, y=%d, r=%d",
|
||||
len(x), len(y), len(r))
|
||||
}
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.OR, x[0], y[0], r[0]))
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewBitSetTest tests if the index'th bit of x is set.
|
||||
func NewBitSetTest(cc *Compiler, x []*Wire, index types.Size, r []*Wire) error {
|
||||
if len(r) != 1 {
|
||||
return fmt.Errorf("invalid bit set test arguments: x=%d, r=%d",
|
||||
len(x), len(r))
|
||||
}
|
||||
if index < types.Size(len(x)) {
|
||||
w := cc.ZeroWire()
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, x[index], w, r[0]))
|
||||
} else {
|
||||
r[0] = cc.ZeroWire()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewBitClrTest tests if the index'th bit of x is unset.
|
||||
func NewBitClrTest(cc *Compiler, x []*Wire, index types.Size, r []*Wire) error {
|
||||
if len(r) != 1 {
|
||||
return fmt.Errorf("invalid bit clear test arguments: x=%d, r=%d",
|
||||
len(x), len(r))
|
||||
}
|
||||
if index < types.Size(len(x)) {
|
||||
w := cc.OneWire()
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, x[index], w, r[0]))
|
||||
} else {
|
||||
r[0] = cc.OneWire()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
309
bedlam/compiler/circuits/circ_divider.go
Normal file
309
bedlam/compiler/circuits/circ_divider.go
Normal file
@ -0,0 +1,309 @@
|
||||
//
|
||||
// Copyright (c) 2019-2024 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuits
|
||||
|
||||
// NewUDividerLong creates an unsigned integer division circuit
|
||||
// computing r=a/b, q=a%b. This function uses Long Division algorithm.
|
||||
func NewUDividerLong(cc *Compiler, a, b, q, rret []*Wire) error {
|
||||
a, b = cc.ZeroPad(a, b)
|
||||
|
||||
r := make([]*Wire, len(a))
|
||||
for i := 0; i < len(r); i++ {
|
||||
r[i] = cc.ZeroWire()
|
||||
}
|
||||
|
||||
for i := len(a) - 1; i >= 0; i-- {
|
||||
// r << 1
|
||||
for j := len(r) - 1; j > 0; j-- {
|
||||
r[j] = r[j-1]
|
||||
}
|
||||
r[0] = a[i]
|
||||
|
||||
// r-d, overlow: r < d
|
||||
diff := make([]*Wire, len(r)+1)
|
||||
for j := 0; j < len(diff); j++ {
|
||||
diff[j] = cc.Calloc.Wire()
|
||||
}
|
||||
err := NewSubtractor(cc, r, b, diff)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if i < len(q) {
|
||||
err = NewMUX(cc, diff[len(diff)-1:], []*Wire{cc.ZeroWire()},
|
||||
[]*Wire{cc.OneWire()}, q[i:i+1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
nr := make([]*Wire, len(r))
|
||||
for j := 0; j < len(nr); j++ {
|
||||
if i == 0 && j < len(rret) {
|
||||
nr[j] = rret[j]
|
||||
} else {
|
||||
nr[j] = cc.Calloc.Wire()
|
||||
}
|
||||
}
|
||||
|
||||
err = NewMUX(cc, diff[len(diff)-1:], r, diff[:len(diff)-1], nr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r = nr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewUDividerRestoring creates an unsigned integer division circuit
|
||||
// computing r=a/b, q=a%b. This function uses Restoring Division
|
||||
// algorithm.
|
||||
func NewUDividerRestoring(cc *Compiler, a, b, q, rret []*Wire) error {
|
||||
a, b = cc.ZeroPad(a, b)
|
||||
|
||||
r := make([]*Wire, len(a)*2)
|
||||
for i := 0; i < len(r); i++ {
|
||||
if i < len(a) {
|
||||
r[i] = a[i]
|
||||
} else {
|
||||
r[i] = cc.ZeroWire()
|
||||
}
|
||||
}
|
||||
d := make([]*Wire, len(b)*2)
|
||||
for i := 0; i < len(d); i++ {
|
||||
if i < len(b) {
|
||||
d[i] = cc.ZeroWire()
|
||||
} else {
|
||||
d[i] = b[i-len(b)]
|
||||
}
|
||||
}
|
||||
|
||||
for i := len(a) - 1; i >= 0; i-- {
|
||||
// r << 1
|
||||
for j := len(r) - 1; j > 0; j-- {
|
||||
r[j] = r[j-1]
|
||||
}
|
||||
r[0] = cc.ZeroWire()
|
||||
|
||||
// r-d, overlow: r < d
|
||||
diff := make([]*Wire, len(r)+1)
|
||||
for j := 0; j < len(diff); j++ {
|
||||
diff[j] = cc.Calloc.Wire()
|
||||
}
|
||||
err := NewSubtractor(cc, r, d, diff)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if i < len(q) {
|
||||
err = NewMUX(cc, diff[len(diff)-1:], []*Wire{cc.ZeroWire()},
|
||||
[]*Wire{cc.OneWire()}, q[i:i+1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
nr := make([]*Wire, len(r))
|
||||
for j := 0; j < len(nr); j++ {
|
||||
if i == 0 && j >= len(a) && j-len(a) < len(rret) {
|
||||
nr[j] = rret[j-len(a)]
|
||||
} else {
|
||||
nr[j] = cc.Calloc.Wire()
|
||||
}
|
||||
}
|
||||
|
||||
err = NewMUX(cc, diff[len(diff)-1:], r, diff[:len(diff)-1], nr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r = nr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewUDividerArray creates an unsigned integer division circuit
|
||||
// computing r=a/b, q=a%b. This function uses Array Divider algorithm.
|
||||
func NewUDividerArray(cc *Compiler, a, b, q, r []*Wire) error {
|
||||
a, b = cc.ZeroPad(a, b)
|
||||
|
||||
rIn := make([]*Wire, len(b)+1)
|
||||
rOut := make([]*Wire, len(b)+1)
|
||||
|
||||
// Init bINV.
|
||||
bINV := make([]*Wire, len(b))
|
||||
for i := 0; i < len(b); i++ {
|
||||
bINV[i] = cc.Calloc.Wire()
|
||||
cc.INV(b[i], bINV[i])
|
||||
}
|
||||
|
||||
// Init for the first row.
|
||||
for i := 0; i < len(b); i++ {
|
||||
rOut[i] = cc.ZeroWire()
|
||||
}
|
||||
|
||||
// Generate matrix.
|
||||
for y := 0; y < len(a); y++ {
|
||||
// Init rIn.
|
||||
rIn[0] = a[len(a)-1-y]
|
||||
copy(rIn[1:], rOut)
|
||||
|
||||
// Adders from b{0} to b{n-1}, 0
|
||||
cIn := cc.OneWire()
|
||||
for x := 0; x < len(b)+1; x++ {
|
||||
var bw *Wire
|
||||
if x < len(b) {
|
||||
bw = bINV[x]
|
||||
} else {
|
||||
bw = cc.OneWire() // INV(0)
|
||||
}
|
||||
co := cc.Calloc.Wire()
|
||||
ro := cc.Calloc.Wire()
|
||||
NewFullAdder(cc, rIn[x], bw, cIn, ro, co)
|
||||
rOut[x] = ro
|
||||
cIn = co
|
||||
}
|
||||
|
||||
// Quotient y.
|
||||
if len(a)-1-y < len(q) {
|
||||
w := cc.Calloc.Wire()
|
||||
cc.INV(cIn, w)
|
||||
cc.INV(w, q[len(a)-1-y])
|
||||
}
|
||||
|
||||
// MUXes from high to low bit.
|
||||
for x := len(b); x >= 0; x-- {
|
||||
var ro *Wire
|
||||
if y+1 >= len(a) && x < len(r) {
|
||||
ro = r[x]
|
||||
} else {
|
||||
ro = cc.Calloc.Wire()
|
||||
}
|
||||
|
||||
err := NewMUX(cc, []*Wire{cIn}, rOut[x:x+1], rIn[x:x+1],
|
||||
[]*Wire{ro})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rOut[x] = ro
|
||||
}
|
||||
}
|
||||
|
||||
// Set extra quotient bits to zero.
|
||||
for y := len(a); y < len(q); y++ {
|
||||
q[y] = cc.ZeroWire()
|
||||
}
|
||||
|
||||
// Set extra remainder bits to zero.
|
||||
for x := len(b); x < len(r); x++ {
|
||||
r[x] = cc.ZeroWire()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewUDivider creates an unsigned integer division circuit computing
|
||||
// r=a/b, q=a%b.
|
||||
func NewUDivider(cc *Compiler, a, b, q, r []*Wire) error {
|
||||
return NewUDividerLong(cc, a, b, q, r)
|
||||
}
|
||||
|
||||
// NewIDivider creates a signed integer division circuit computing
|
||||
// r=a/b, q=a%b. The function converts negative a and b to positive
|
||||
// values before doing the divmod with NewUDivider. If a or b was
|
||||
// negative, the funtion converts the result quotient to negative
|
||||
// value.
|
||||
func NewIDivider(cc *Compiler, a, b, q, r []*Wire) error {
|
||||
a, b = cc.ZeroPad(a, b)
|
||||
|
||||
zero := []*Wire{cc.ZeroWire()}
|
||||
neg0 := cc.ZeroWire()
|
||||
|
||||
// If a is negative, set neg=!neg, a=-a.
|
||||
|
||||
neg1 := cc.Calloc.Wire()
|
||||
cc.INV(neg0, neg1)
|
||||
|
||||
a1 := make([]*Wire, len(a))
|
||||
for i := 0; i < len(a1); i++ {
|
||||
a1[i] = cc.Calloc.Wire()
|
||||
}
|
||||
err := NewSubtractor(cc, zero, a, a1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
neg2 := cc.Calloc.Wire()
|
||||
err = NewMUX(cc, a[len(a)-1:], []*Wire{neg1}, []*Wire{neg0}, []*Wire{neg2})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
a2 := make([]*Wire, len(a))
|
||||
for i := 0; i < len(a2); i++ {
|
||||
a2[i] = cc.Calloc.Wire()
|
||||
}
|
||||
|
||||
err = NewMUX(cc, a[len(a)-1:], a1, a, a2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If b is negative, set neg=!neg, b=-b.
|
||||
|
||||
neg3 := cc.Calloc.Wire()
|
||||
cc.INV(neg2, neg3)
|
||||
|
||||
b1 := make([]*Wire, len(b))
|
||||
for i := 0; i < len(b1); i++ {
|
||||
b1[i] = cc.Calloc.Wire()
|
||||
}
|
||||
err = NewSubtractor(cc, zero, b, b1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
neg4 := cc.Calloc.Wire()
|
||||
err = NewMUX(cc, b[len(b)-1:], []*Wire{neg3}, []*Wire{neg2}, []*Wire{neg4})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b2 := make([]*Wire, len(b))
|
||||
for i := 0; i < len(a2); i++ {
|
||||
b2[i] = cc.Calloc.Wire()
|
||||
}
|
||||
|
||||
err = NewMUX(cc, b[len(b)-1:], b1, b, b2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(q) == 0 {
|
||||
// Modulo operation.
|
||||
return NewUDivider(cc, a2, b2, q, r)
|
||||
}
|
||||
|
||||
// If neg is set, set q=-q
|
||||
|
||||
q0 := make([]*Wire, len(q))
|
||||
for i := 0; i < len(q0); i++ {
|
||||
q0[i] = cc.Calloc.Wire()
|
||||
}
|
||||
err = NewUDivider(cc, a2, b2, q0, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q1 := make([]*Wire, len(q))
|
||||
for i := 0; i < len(q1); i++ {
|
||||
q1[i] = cc.Calloc.Wire()
|
||||
}
|
||||
err = NewSubtractor(cc, zero, q0, q1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return NewMUX(cc, []*Wire{neg4}, q1, q0, q)
|
||||
}
|
||||
44
bedlam/compiler/circuits/circ_hamming.go
Normal file
44
bedlam/compiler/circuits/circ_hamming.go
Normal file
@ -0,0 +1,44 @@
|
||||
//
|
||||
// Copyright (c) 2020-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuits
|
||||
|
||||
import (
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
|
||||
)
|
||||
|
||||
// Hamming creates a hamming distance circuit computing the hamming
|
||||
// distance between a and b and returning the distance in r.
|
||||
func Hamming(cc *Compiler, a, b, r []*Wire) error {
|
||||
a, b = cc.ZeroPad(a, b)
|
||||
|
||||
var arr [][]*Wire
|
||||
for i := 0; i < len(a); i++ {
|
||||
w := cc.Calloc.Wire()
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, a[i], b[i], w))
|
||||
arr = append(arr, []*Wire{w})
|
||||
}
|
||||
|
||||
for len(arr) > 2 {
|
||||
var n [][]*Wire
|
||||
for i := 0; i < len(arr); i += 2 {
|
||||
if i+1 < len(arr) {
|
||||
result := cc.Calloc.Wires(types.Size(len(arr[i]) + 1))
|
||||
err := NewAdder(cc, arr[i], arr[i+1], result)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n = append(n, result)
|
||||
} else {
|
||||
n = append(n, arr[i])
|
||||
}
|
||||
}
|
||||
arr = n
|
||||
}
|
||||
|
||||
return NewAdder(cc, arr[0], arr[1], r)
|
||||
}
|
||||
101
bedlam/compiler/circuits/circ_index.go
Normal file
101
bedlam/compiler/circuits/circ_index.go
Normal file
@ -0,0 +1,101 @@
|
||||
//
|
||||
// Copyright (c) 2021-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuits
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// NewIndex creates a new array element selection (index) circuit.
|
||||
func NewIndex(cc *Compiler, size int, array, index, out []*Wire) error {
|
||||
if len(array)%size != 0 {
|
||||
return fmt.Errorf("array width %d must be multiple of element size %d",
|
||||
len(array), size)
|
||||
}
|
||||
if len(out) < size {
|
||||
return fmt.Errorf("out %d too small for element size %d",
|
||||
len(out), size)
|
||||
}
|
||||
n := len(array) / size
|
||||
if n == 0 {
|
||||
for i := 0; i < len(out); i++ {
|
||||
out[i] = cc.ZeroWire()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
bits := 1
|
||||
var length int
|
||||
|
||||
for length = 2; length < n; length *= 2 {
|
||||
bits++
|
||||
}
|
||||
|
||||
return newIndex(cc, bits-1, length, size, array, index, out)
|
||||
}
|
||||
|
||||
func newIndex(cc *Compiler, bit, length, size int,
|
||||
array, index, out []*Wire) error {
|
||||
|
||||
// Default "not found" value.
|
||||
def := make([]*Wire, size)
|
||||
for i := 0; i < size; i++ {
|
||||
def[i] = cc.ZeroWire()
|
||||
}
|
||||
|
||||
n := len(array) / size
|
||||
|
||||
if bit == 0 {
|
||||
fVal := array[:size]
|
||||
|
||||
var tVal []*Wire
|
||||
if n > 1 {
|
||||
tVal = array[size : 2*size]
|
||||
} else {
|
||||
tVal = def
|
||||
}
|
||||
return NewMUX(cc, index[0:1], tVal, fVal, out)
|
||||
}
|
||||
|
||||
length /= 2
|
||||
fArray := array
|
||||
if n > length {
|
||||
fArray = fArray[:length*size]
|
||||
}
|
||||
|
||||
if bit >= len(index) {
|
||||
// Not enough bits to select upper half so just select from
|
||||
// the lower half.
|
||||
return newIndex(cc, bit-1, length, size, fArray, index, out)
|
||||
}
|
||||
|
||||
fVal := make([]*Wire, size)
|
||||
for i := 0; i < size; i++ {
|
||||
fVal[i] = cc.Calloc.Wire()
|
||||
}
|
||||
err := newIndex(cc, bit-1, length, size, fArray, index, fVal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var tVal []*Wire
|
||||
if n > length {
|
||||
tVal = make([]*Wire, size)
|
||||
for i := 0; i < size; i++ {
|
||||
tVal[i] = cc.Calloc.Wire()
|
||||
}
|
||||
err = newIndex(cc, bit-1, length, size,
|
||||
array[length*size:], index, tVal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
tVal = def
|
||||
}
|
||||
|
||||
return NewMUX(cc, index[bit:bit+1], tVal, fVal, out)
|
||||
}
|
||||
227
bedlam/compiler/circuits/circ_multiplier.go
Normal file
227
bedlam/compiler/circuits/circ_multiplier.go
Normal file
@ -0,0 +1,227 @@
|
||||
//
|
||||
// circ_multiplier.go
|
||||
//
|
||||
// Copyright (c) 2019-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuits
|
||||
|
||||
import (
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
|
||||
)
|
||||
|
||||
// NewMultiplier creates a multiplier circuit implementing x*y=z.
|
||||
func NewMultiplier(c *Compiler, arrayTreshold int, x, y, z []*Wire) error {
|
||||
if false {
|
||||
return NewArrayMultiplier(c, x, y, z)
|
||||
}
|
||||
if arrayTreshold < 8 {
|
||||
var ok bool
|
||||
|
||||
arrayTreshold, ok = multiplierArrayTresholds[len(x)]
|
||||
if !ok {
|
||||
arrayTreshold = 21
|
||||
}
|
||||
}
|
||||
return NewKaratsubaMultiplier(c, arrayTreshold, x, y, z)
|
||||
}
|
||||
|
||||
// NewArrayMultiplier creates a multiplier circuit implementing
|
||||
// x*y=z. This function implements Array Multiplier Circuit.
|
||||
func NewArrayMultiplier(cc *Compiler, x, y, z []*Wire) error {
|
||||
x, y = cc.ZeroPad(x, y)
|
||||
if len(x) > len(z) {
|
||||
x = x[0:len(z)]
|
||||
y = y[0:len(z)]
|
||||
}
|
||||
|
||||
// One bit multiplication is AND.
|
||||
if len(x) == 1 {
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, x[0], y[0], z[0]))
|
||||
if len(z) > 1 {
|
||||
z[1] = cc.ZeroWire()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var sums []*Wire
|
||||
|
||||
// Construct Y0 sums
|
||||
for i, xn := range x {
|
||||
var s *Wire
|
||||
if i == 0 {
|
||||
s = z[0]
|
||||
} else {
|
||||
s = cc.Calloc.Wire()
|
||||
sums = append(sums, s)
|
||||
}
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, xn, y[0], s))
|
||||
}
|
||||
|
||||
// Construct len(y)-2 intermediate layers
|
||||
var j int
|
||||
for j = 1; j+1 < len(y); j++ {
|
||||
// ANDs for y(j)
|
||||
var ands []*Wire
|
||||
for _, xn := range x {
|
||||
wire := cc.Calloc.Wire()
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, xn, y[j], wire))
|
||||
ands = append(ands, wire)
|
||||
}
|
||||
|
||||
// Compute next sums.
|
||||
var nsums []*Wire
|
||||
var c *Wire
|
||||
for i := 0; i < len(ands); i++ {
|
||||
cout := cc.Calloc.Wire()
|
||||
|
||||
var s *Wire
|
||||
if i == 0 {
|
||||
s = z[j]
|
||||
} else {
|
||||
s = cc.Calloc.Wire()
|
||||
nsums = append(nsums, s)
|
||||
}
|
||||
|
||||
if i == 0 {
|
||||
NewHalfAdder(cc, ands[i], sums[i], s, cout)
|
||||
} else if i >= len(sums) {
|
||||
NewHalfAdder(cc, ands[i], c, s, cout)
|
||||
} else {
|
||||
NewFullAdder(cc, ands[i], sums[i], c, s, cout)
|
||||
}
|
||||
c = cout
|
||||
}
|
||||
// New sums with carry as the highest bit.
|
||||
sums = append(nsums, c)
|
||||
}
|
||||
|
||||
// Construct final layer.
|
||||
var c *Wire
|
||||
for i, xn := range x {
|
||||
and := cc.Calloc.Wire()
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, xn, y[j], and))
|
||||
|
||||
var cout *Wire
|
||||
if i+1 >= len(x) && j+i+1 < len(z) {
|
||||
cout = z[j+i+1]
|
||||
} else {
|
||||
cout = cc.Calloc.Wire()
|
||||
}
|
||||
|
||||
if j+i < len(z) {
|
||||
if i == 0 {
|
||||
NewHalfAdder(cc, and, sums[i], z[j+i], cout)
|
||||
} else if i >= len(sums) {
|
||||
NewHalfAdder(cc, and, c, z[j+i], cout)
|
||||
} else {
|
||||
NewFullAdder(cc, and, sums[i], c, z[j+i], cout)
|
||||
}
|
||||
}
|
||||
c = cout
|
||||
}
|
||||
for i := j + len(x) + 1; i < len(z); i++ {
|
||||
z[1] = cc.ZeroWire()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewKaratsubaMultiplier creates a multiplier circuit implementing
|
||||
// the Karatsuba algorithm
|
||||
// (https://en.wikipedia.org/wiki/Karatsuba_algorithm). The Karatsuba
|
||||
// algorithm is should produce faster circuits on inputs of about 128
|
||||
// bits (the number of non-XOR gates is smaller). On input sizes of
|
||||
// 256 bits also the overall circuits are smaller than with the array
|
||||
// multiplier algorithm.
|
||||
//
|
||||
// Bits Array a-xor a-and Karatsu K-xor K-and
|
||||
// 8 301 172 129 993 724 269
|
||||
// 16 1365 852 513 3573 2660 913
|
||||
// 32 5797 3748 2049 11937 8980 2957
|
||||
// 64 23877 15684 8193 38277 28964 9313
|
||||
// 128 96901 64132 32769 119793 90964 28829
|
||||
// 256 390405 259332 131073 369333 281060 88273
|
||||
// 512 1567237 1042948 524289 1127937 859540 268397
|
||||
// 1024 6280197 4183044 2097153 3423717 2611364 812353
|
||||
// 2048 25143301 16754692 8388609 10350993 7899604 2451389
|
||||
func NewKaratsubaMultiplier(cc *Compiler, limit int, a, b, r []*Wire) error {
|
||||
|
||||
a, b = cc.ZeroPad(a, b)
|
||||
if len(a) > len(r) {
|
||||
a = a[0:len(r)]
|
||||
b = b[0:len(r)]
|
||||
}
|
||||
|
||||
// Compute smaller multiplications with array multiplier.
|
||||
if len(a) <= limit {
|
||||
return NewArrayMultiplier(cc, a, b, r)
|
||||
}
|
||||
|
||||
mid := len(a) / 2
|
||||
|
||||
aLow := a[:mid]
|
||||
aHigh := a[mid:]
|
||||
|
||||
bLow := b[:mid]
|
||||
bHigh := b[mid:]
|
||||
|
||||
z0 := cc.Calloc.Wires(types.Size(min(max(len(aLow), len(bLow))*2, len(r))))
|
||||
if err := NewKaratsubaMultiplier(cc, limit, aLow, bLow, z0); err != nil {
|
||||
return err
|
||||
}
|
||||
aSumLen := max(len(aLow), len(aHigh)) + 1
|
||||
aSum := cc.Calloc.Wires(types.Size(aSumLen))
|
||||
if err := NewAdder(cc, aLow, aHigh, aSum); err != nil {
|
||||
return err
|
||||
}
|
||||
bSumLen := max(len(bLow), len(bHigh)) + 1
|
||||
bSum := cc.Calloc.Wires(types.Size(bSumLen))
|
||||
if err := NewAdder(cc, bLow, bHigh, bSum); err != nil {
|
||||
return err
|
||||
}
|
||||
z1 := cc.Calloc.Wires(types.Size(min(max(aSumLen, bSumLen)*2, len(r))))
|
||||
if err := NewKaratsubaMultiplier(cc, limit, aSum, bSum, z1); err != nil {
|
||||
return err
|
||||
}
|
||||
z2 := cc.Calloc.Wires(types.Size(min(max(len(aHigh), len(bHigh))*2, len(r))))
|
||||
if err := NewKaratsubaMultiplier(cc, limit, aHigh, bHigh, z2); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sub1 := cc.Calloc.Wires(types.Size(len(r)))
|
||||
if err := NewSubtractor(cc, z1, z2, sub1); err != nil {
|
||||
return err
|
||||
}
|
||||
sub2 := cc.Calloc.Wires(types.Size(len(r)))
|
||||
if err := NewSubtractor(cc, sub1, z0, sub2); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
shift1 := cc.ShiftLeft(z2, len(r), mid*2)
|
||||
shift2 := cc.ShiftLeft(sub2, len(r), mid)
|
||||
|
||||
add1 := cc.Calloc.Wires(types.Size(len(r)))
|
||||
if err := NewAdder(cc, shift1, shift2, add1); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return NewAdder(cc, add1, z0, r)
|
||||
}
|
||||
|
||||
func max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
1112
bedlam/compiler/circuits/circ_multiplier_params.go
Normal file
1112
bedlam/compiler/circuits/circ_multiplier_params.go
Normal file
File diff suppressed because it is too large
Load Diff
39
bedlam/compiler/circuits/circ_mux.go
Normal file
39
bedlam/compiler/circuits/circ_mux.go
Normal file
@ -0,0 +1,39 @@
|
||||
//
|
||||
// Copyright (c) 2020-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuits
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
|
||||
)
|
||||
|
||||
// NewMUX creates a multiplexer circuit that selects the input t or f
|
||||
// to output, based on the value of the condition cond.
|
||||
func NewMUX(cc *Compiler, cond, t, f, out []*Wire) error {
|
||||
t, f = cc.ZeroPad(t, f)
|
||||
if len(cond) != 1 || len(t) != len(f) || len(t) != len(out) {
|
||||
return fmt.Errorf("invalid mux arguments: cond=%d, l=%d, r=%d, out=%d",
|
||||
len(cond), len(t), len(f), len(out))
|
||||
}
|
||||
|
||||
for i := 0; i < len(t); i++ {
|
||||
w1 := cc.Calloc.Wire()
|
||||
w2 := cc.Calloc.Wire()
|
||||
|
||||
// w1 = XOR(f[i], t[i])
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, f[i], t[i], w1))
|
||||
|
||||
// w2 = AND(w1, cond)
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, w1, cond[0], w2))
|
||||
|
||||
// out[i] = XOR(w2, f[i])
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, w2, f[i], out[i]))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
63
bedlam/compiler/circuits/circ_subtractor.go
Normal file
63
bedlam/compiler/circuits/circ_subtractor.go
Normal file
@ -0,0 +1,63 @@
|
||||
//
|
||||
// circ_subtractor.go
|
||||
//
|
||||
// Copyright (c) 2019-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuits
|
||||
|
||||
import (
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
|
||||
)
|
||||
|
||||
// NewFullSubtractor creates a full subtractor circuit.
|
||||
func NewFullSubtractor(cc *Compiler, x, y, cin, d, cout *Wire) {
|
||||
w1 := cc.Calloc.Wire()
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.XNOR, y, cin, w1))
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.XNOR, x, w1, d))
|
||||
|
||||
if cout != nil {
|
||||
w2 := cc.Calloc.Wire()
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, x, cin, w2))
|
||||
|
||||
w3 := cc.Calloc.Wire()
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, w1, w2, w3))
|
||||
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, w3, cin, cout))
|
||||
}
|
||||
}
|
||||
|
||||
// NewSubtractor creates a new subtractor circuit implementing z=x-y.
|
||||
func NewSubtractor(cc *Compiler, x, y, z []*Wire) error {
|
||||
x, y = cc.ZeroPad(x, y)
|
||||
if len(x) > len(z) {
|
||||
x = x[0:len(z)]
|
||||
y = y[0:len(z)]
|
||||
}
|
||||
cin := cc.ZeroWire()
|
||||
|
||||
for i := 0; i < len(x); i++ {
|
||||
var cout *Wire
|
||||
if i+1 >= len(x) {
|
||||
if i+1 >= len(z) {
|
||||
// N-N=N, overflow, drop carry bit.
|
||||
cout = nil
|
||||
} else {
|
||||
cout = z[i+1]
|
||||
}
|
||||
} else {
|
||||
cout = cc.Calloc.Wire()
|
||||
}
|
||||
|
||||
// Note y-x here.
|
||||
NewFullSubtractor(cc, y[i], x[i], cin, z[i], cout)
|
||||
|
||||
cin = cout
|
||||
}
|
||||
for i := len(x) + 1; i < len(z); i++ {
|
||||
z[i] = cc.ZeroWire()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
144
bedlam/compiler/circuits/circuits_test.go
Normal file
144
bedlam/compiler/circuits/circuits_test.go
Normal file
@ -0,0 +1,144 @@
|
||||
//
|
||||
// circuits_test.go
|
||||
//
|
||||
// Copyright (c) 2019-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuits
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/utils"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/types"
|
||||
)
|
||||
|
||||
const (
|
||||
verbose = false
|
||||
)
|
||||
|
||||
var (
|
||||
params = utils.NewParams()
|
||||
calloc = NewAllocator()
|
||||
)
|
||||
|
||||
func makeWires(count int, output bool) []*Wire {
|
||||
var result []*Wire
|
||||
for i := 0; i < count; i++ {
|
||||
w := calloc.Wire()
|
||||
w.SetOutput(output)
|
||||
result = append(result, w)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func NewIO(size int, name string) circuit.IO {
|
||||
return circuit.IO{
|
||||
circuit.IOArg{
|
||||
Name: name,
|
||||
Type: types.Info{
|
||||
Type: types.TUint,
|
||||
IsConcrete: true,
|
||||
Bits: types.Size(size),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdd4(t *testing.T) {
|
||||
bits := 4
|
||||
|
||||
// 2xbits inputs, bits+1 outputs
|
||||
inputs := makeWires(bits*2, false)
|
||||
outputs := makeWires(bits+1, true)
|
||||
c, err := NewCompiler(params, calloc, NewIO(bits*2, "in"),
|
||||
NewIO(bits+1, "out"), inputs, outputs)
|
||||
if err != nil {
|
||||
t.Fatalf("NewCompiler: %s", err)
|
||||
}
|
||||
|
||||
cin := calloc.Wire()
|
||||
NewHalfAdder(c, inputs[0], inputs[bits], outputs[0], cin)
|
||||
|
||||
for i := 1; i < bits; i++ {
|
||||
var cout *Wire
|
||||
if i+1 >= bits {
|
||||
cout = outputs[bits]
|
||||
} else {
|
||||
cout = calloc.Wire()
|
||||
}
|
||||
|
||||
NewFullAdder(c, inputs[i], inputs[bits+i], cin, outputs[i], cout)
|
||||
|
||||
cin = cout
|
||||
}
|
||||
|
||||
result := c.Compile()
|
||||
if verbose {
|
||||
fmt.Printf("Result: %s\n", result)
|
||||
result.Marshal(os.Stdout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFullSubtractor(t *testing.T) {
|
||||
inputs := makeWires(1+2, false)
|
||||
outputs := makeWires(2, true)
|
||||
c, err := NewCompiler(params, calloc, NewIO(1+2, "in"), NewIO(2, "out"),
|
||||
inputs, outputs)
|
||||
if err != nil {
|
||||
t.Fatalf("NewCompiler: %s", err)
|
||||
}
|
||||
|
||||
NewFullSubtractor(c, inputs[0], inputs[1], inputs[2],
|
||||
outputs[0], outputs[1])
|
||||
|
||||
result := c.Compile()
|
||||
if verbose {
|
||||
fmt.Printf("Result: %s\n", result)
|
||||
result.Marshal(os.Stdout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiply1(t *testing.T) {
|
||||
inputs := makeWires(2, false)
|
||||
outputs := makeWires(2, true)
|
||||
c, err := NewCompiler(params, calloc, NewIO(2, "in"), NewIO(2, "out"),
|
||||
inputs, outputs)
|
||||
if err != nil {
|
||||
t.Fatalf("NewCompiler: %s", err)
|
||||
}
|
||||
|
||||
err = NewMultiplier(c, 0, inputs[0:1], inputs[1:2], outputs)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiply(t *testing.T) {
|
||||
bits := 64
|
||||
|
||||
inputs := makeWires(bits*2, false)
|
||||
outputs := makeWires(bits*2, true)
|
||||
|
||||
c, err := NewCompiler(params, calloc, NewIO(bits*2, "in"),
|
||||
NewIO(bits*2, "out"), inputs, outputs)
|
||||
if err != nil {
|
||||
t.Fatalf("NewCompiler: %s", err)
|
||||
}
|
||||
|
||||
err = NewMultiplier(c, 0, inputs[0:bits], inputs[bits:2*bits], outputs)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
result := c.Compile()
|
||||
if verbose {
|
||||
fmt.Printf("Result: %s\n", result)
|
||||
result.Marshal(os.Stdout)
|
||||
}
|
||||
}
|
||||
399
bedlam/compiler/circuits/compiler.go
Normal file
399
bedlam/compiler/circuits/compiler.go
Normal file
@ -0,0 +1,399 @@
|
||||
//
|
||||
// Copyright (c) 2019-2023 Markku Rossi
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
|
||||
package circuits
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/circuit"
|
||||
"source.quilibrium.com/quilibrium/monorepo/bedlam/compiler/utils"
|
||||
)
|
||||
|
||||
// Builtin implements a buitin circuit that uses input wires a and b
|
||||
// and returns the circuit result in r.
|
||||
type Builtin func(cc *Compiler, a, b, r []*Wire) error
|
||||
|
||||
// Compiler implements binary circuit compiler.
|
||||
type Compiler struct {
|
||||
Params *utils.Params
|
||||
Calloc *Allocator
|
||||
OutputsAssigned bool
|
||||
Inputs circuit.IO
|
||||
Outputs circuit.IO
|
||||
InputWires []*Wire
|
||||
OutputWires []*Wire
|
||||
Gates []*Gate
|
||||
nextWireID circuit.Wire
|
||||
pending []*Gate
|
||||
assigned []*Gate
|
||||
compiled []circuit.Gate
|
||||
invI0Wire *Wire
|
||||
zeroWire *Wire
|
||||
oneWire *Wire
|
||||
}
|
||||
|
||||
// NewCompiler creates a new circuit compiler for the specified
|
||||
// circuit input and output values.
|
||||
func NewCompiler(params *utils.Params, calloc *Allocator,
|
||||
inputs, outputs circuit.IO, inputWires, outputWires []*Wire) (
|
||||
*Compiler, error) {
|
||||
|
||||
if len(inputWires) == 0 {
|
||||
return nil, fmt.Errorf("no inputs defined")
|
||||
}
|
||||
return &Compiler{
|
||||
Params: params,
|
||||
Calloc: calloc,
|
||||
Inputs: inputs,
|
||||
Outputs: outputs,
|
||||
InputWires: inputWires,
|
||||
OutputWires: outputWires,
|
||||
Gates: make([]*Gate, 0, 65536),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// InvI0Wire returns a wire holding value INV(input[0]).
|
||||
func (cc *Compiler) InvI0Wire() *Wire {
|
||||
if cc.invI0Wire == nil {
|
||||
cc.invI0Wire = cc.Calloc.Wire()
|
||||
cc.AddGate(cc.Calloc.INVGate(cc.InputWires[0], cc.invI0Wire))
|
||||
}
|
||||
return cc.invI0Wire
|
||||
}
|
||||
|
||||
// ZeroWire returns a wire holding value 0.
|
||||
func (cc *Compiler) ZeroWire() *Wire {
|
||||
if cc.zeroWire == nil {
|
||||
cc.zeroWire = cc.Calloc.Wire()
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.AND, cc.InputWires[0],
|
||||
cc.InvI0Wire(), cc.zeroWire))
|
||||
cc.zeroWire.SetValue(Zero)
|
||||
}
|
||||
return cc.zeroWire
|
||||
}
|
||||
|
||||
// OneWire returns a wire holding value 1.
|
||||
func (cc *Compiler) OneWire() *Wire {
|
||||
if cc.oneWire == nil {
|
||||
cc.oneWire = cc.Calloc.Wire()
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.OR, cc.InputWires[0],
|
||||
cc.InvI0Wire(), cc.oneWire))
|
||||
cc.oneWire.SetValue(One)
|
||||
}
|
||||
return cc.oneWire
|
||||
}
|
||||
|
||||
// ZeroPad pads the argument wires x and y with zero values so that
|
||||
// the resulting wires have the same number of bits.
|
||||
func (cc *Compiler) ZeroPad(x, y []*Wire) ([]*Wire, []*Wire) {
|
||||
if len(x) == len(y) {
|
||||
return x, y
|
||||
}
|
||||
|
||||
max := len(x)
|
||||
if len(y) > max {
|
||||
max = len(y)
|
||||
}
|
||||
|
||||
rx := make([]*Wire, max)
|
||||
for i := 0; i < max; i++ {
|
||||
if i < len(x) {
|
||||
rx[i] = x[i]
|
||||
} else {
|
||||
rx[i] = cc.ZeroWire()
|
||||
}
|
||||
}
|
||||
|
||||
ry := make([]*Wire, max)
|
||||
for i := 0; i < max; i++ {
|
||||
if i < len(y) {
|
||||
ry[i] = y[i]
|
||||
} else {
|
||||
ry[i] = cc.ZeroWire()
|
||||
}
|
||||
}
|
||||
|
||||
return rx, ry
|
||||
}
|
||||
|
||||
// ShiftLeft shifts the size number of bits of the input wires w,
|
||||
// count bits left.
|
||||
func (cc *Compiler) ShiftLeft(w []*Wire, size, count int) []*Wire {
|
||||
result := make([]*Wire, size)
|
||||
|
||||
if count < size {
|
||||
copy(result[count:], w)
|
||||
}
|
||||
for i := 0; i < count; i++ {
|
||||
result[i] = cc.ZeroWire()
|
||||
}
|
||||
for i := count + len(w); i < size; i++ {
|
||||
result[i] = cc.ZeroWire()
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// INV creates an inverse wire inverting the input wire i's value to
|
||||
// the output wire o.
|
||||
func (cc *Compiler) INV(i, o *Wire) {
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, i, cc.OneWire(), o))
|
||||
}
|
||||
|
||||
// ID creates an identity wire passing the input wire i's value to the
|
||||
// output wire o.
|
||||
func (cc *Compiler) ID(i, o *Wire) {
|
||||
cc.AddGate(cc.Calloc.BinaryGate(circuit.XOR, i, cc.ZeroWire(), o))
|
||||
}
|
||||
|
||||
// AddGate adds a get into the circuit.
|
||||
func (cc *Compiler) AddGate(gate *Gate) {
|
||||
cc.Gates = append(cc.Gates, gate)
|
||||
}
|
||||
|
||||
// SetNextWireID sets the next unique wire ID to use.
|
||||
func (cc *Compiler) SetNextWireID(next circuit.Wire) {
|
||||
cc.nextWireID = next
|
||||
}
|
||||
|
||||
// NextWireID returns the next unique wire ID.
|
||||
func (cc *Compiler) NextWireID() circuit.Wire {
|
||||
ret := cc.nextWireID
|
||||
cc.nextWireID++
|
||||
return ret
|
||||
}
|
||||
|
||||
// ConstPropagate propagates constant wire values in the circuit and
|
||||
// short circuits gates if their output does not depend on the gate's
|
||||
// logical operation.
|
||||
func (cc *Compiler) ConstPropagate() {
|
||||
var stats circuit.Stats
|
||||
|
||||
start := time.Now()
|
||||
|
||||
for _, g := range cc.Gates {
|
||||
switch g.Op {
|
||||
case circuit.XOR:
|
||||
if (g.A.Value() == Zero && g.B.Value() == Zero) ||
|
||||
(g.A.Value() == One && g.B.Value() == One) {
|
||||
g.O.SetValue(Zero)
|
||||
stats[g.Op]++
|
||||
} else if (g.A.Value() == Zero && g.B.Value() == One) ||
|
||||
(g.A.Value() == One && g.B.Value() == Zero) {
|
||||
g.O.SetValue(One)
|
||||
stats[g.Op]++
|
||||
} else if g.A.Value() == Zero {
|
||||
// O = B
|
||||
stats[g.Op]++
|
||||
g.ShortCircuit(g.B)
|
||||
} else if g.B.Value() == Zero {
|
||||
// O = A
|
||||
stats[g.Op]++
|
||||
g.ShortCircuit(g.A)
|
||||
}
|
||||
|
||||
case circuit.XNOR:
|
||||
if (g.A.Value() == Zero && g.B.Value() == Zero) ||
|
||||
(g.A.Value() == One && g.B.Value() == One) {
|
||||
g.O.SetValue(One)
|
||||
stats[g.Op]++
|
||||
} else if (g.A.Value() == Zero && g.B.Value() == One) ||
|
||||
(g.A.Value() == One && g.B.Value() == Zero) {
|
||||
g.O.SetValue(Zero)
|
||||
stats[g.Op]++
|
||||
}
|
||||
|
||||
case circuit.AND:
|
||||
if g.A.Value() == Zero || g.B.Value() == Zero {
|
||||
g.O.SetValue(Zero)
|
||||
stats[g.Op]++
|
||||
} else if g.A.Value() == One && g.B.Value() == One {
|
||||
g.O.SetValue(One)
|
||||
stats[g.Op]++
|
||||
} else if g.A.Value() == One {
|
||||
// O = B
|
||||
stats[g.Op]++
|
||||
g.ShortCircuit(g.B)
|
||||
} else if g.B.Value() == One {
|
||||
// O = A
|
||||
stats[g.Op]++
|
||||
g.ShortCircuit(g.A)
|
||||
}
|
||||
|
||||
case circuit.OR:
|
||||
if g.A.Value() == One || g.B.Value() == One {
|
||||
g.O.SetValue(One)
|
||||
stats[g.Op]++
|
||||
} else if g.A.Value() == Zero && g.B.Value() == Zero {
|
||||
g.O.SetValue(Zero)
|
||||
stats[g.Op]++
|
||||
} else if g.A.Value() == Zero {
|
||||
// O = B
|
||||
stats[g.Op]++
|
||||
g.ShortCircuit(g.B)
|
||||
} else if g.B.Value() == Zero {
|
||||
// O = A
|
||||
stats[g.Op]++
|
||||
g.ShortCircuit(g.A)
|
||||
}
|
||||
|
||||
case circuit.INV:
|
||||
if g.A.Value() == One {
|
||||
g.O.SetValue(Zero)
|
||||
stats[g.Op]++
|
||||
} else if g.A.Value() == Zero {
|
||||
g.O.SetValue(One)
|
||||
stats[g.Op]++
|
||||
}
|
||||
}
|
||||
|
||||
if g.A.Value() == Zero {
|
||||
g.A.RemoveOutput(g)
|
||||
g.A = cc.ZeroWire()
|
||||
g.A.AddOutput(g)
|
||||
} else if g.A.Value() == One {
|
||||
g.A.RemoveOutput(g)
|
||||
g.A = cc.OneWire()
|
||||
g.A.AddOutput(g)
|
||||
}
|
||||
if g.B != nil {
|
||||
if g.B.Value() == Zero {
|
||||
g.B.RemoveOutput(g)
|
||||
g.B = cc.ZeroWire()
|
||||
g.B.AddOutput(g)
|
||||
} else if g.B.Value() == One {
|
||||
g.B.RemoveOutput(g)
|
||||
g.B = cc.OneWire()
|
||||
g.B.AddOutput(g)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if cc.Params.Diagnostics && stats.Count() > 0 {
|
||||
fmt.Printf(" - ConstPropagate: %12s: %d/%d (%.2f%%)\n",
|
||||
elapsed, stats.Count(), len(cc.Gates),
|
||||
float64(stats.Count())/float64(len(cc.Gates))*100)
|
||||
}
|
||||
}
|
||||
|
||||
// ShortCircuitXORZero short circuits input to output where input is
|
||||
// XOR'ed to zero.
|
||||
func (cc *Compiler) ShortCircuitXORZero() {
|
||||
var stats circuit.Stats
|
||||
|
||||
start := time.Now()
|
||||
|
||||
for _, g := range cc.Gates {
|
||||
if g.Op != circuit.XOR {
|
||||
continue
|
||||
}
|
||||
if g.A.Value() == Zero && !g.B.IsInput() &&
|
||||
g.B.Input().O.NumOutputs() == 1 {
|
||||
|
||||
g.B.Input().ResetOutput(g.O)
|
||||
|
||||
// Disconnect gate's output wire.
|
||||
g.O = cc.Calloc.Wire()
|
||||
|
||||
stats[g.Op]++
|
||||
}
|
||||
if g.B.Value() == Zero && !g.A.IsInput() &&
|
||||
g.A.Input().O.NumOutputs() == 1 {
|
||||
|
||||
g.A.Input().ResetOutput(g.O)
|
||||
|
||||
// Disconnect gate's output wire.
|
||||
g.O = cc.Calloc.Wire()
|
||||
|
||||
stats[g.Op]++
|
||||
}
|
||||
}
|
||||
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if cc.Params.Diagnostics && stats.Count() > 0 {
|
||||
fmt.Printf(" - ShortCircuitXORZero: %12s: %d/%d (%.2f%%)\n",
|
||||
elapsed, stats.Count(), len(cc.Gates),
|
||||
float64(stats.Count())/float64(len(cc.Gates))*100)
|
||||
}
|
||||
}
|
||||
|
||||
// Prune removes all gates whose output wires are unused.
|
||||
func (cc *Compiler) Prune() int {
|
||||
|
||||
n := make([]*Gate, len(cc.Gates))
|
||||
nPos := len(n)
|
||||
|
||||
for i := len(cc.Gates) - 1; i >= 0; i-- {
|
||||
g := cc.Gates[i]
|
||||
if !g.Prune() {
|
||||
nPos--
|
||||
n[nPos] = g
|
||||
}
|
||||
}
|
||||
cc.Gates = n[nPos:]
|
||||
|
||||
return nPos
|
||||
}
|
||||
|
||||
// Compile compiles the circuit.
|
||||
func (cc *Compiler) Compile() *circuit.Circuit {
|
||||
if len(cc.pending) != 0 {
|
||||
panic("Compile: pending set")
|
||||
}
|
||||
cc.pending = make([]*Gate, 0, len(cc.Gates))
|
||||
if len(cc.assigned) != 0 {
|
||||
panic("Compile: assigned set")
|
||||
}
|
||||
cc.assigned = make([]*Gate, 0, len(cc.Gates))
|
||||
if len(cc.compiled) != 0 {
|
||||
panic("Compile: compiled set")
|
||||
}
|
||||
cc.compiled = make([]circuit.Gate, 0, len(cc.Gates))
|
||||
|
||||
for _, w := range cc.InputWires {
|
||||
w.Assign(cc)
|
||||
}
|
||||
for len(cc.pending) > 0 {
|
||||
gate := cc.pending[0]
|
||||
cc.pending = cc.pending[1:]
|
||||
gate.Assign(cc)
|
||||
}
|
||||
// Assign outputs.
|
||||
for _, w := range cc.OutputWires {
|
||||
if w.Assigned() {
|
||||
if !cc.OutputsAssigned {
|
||||
panic("Output already assigned")
|
||||
}
|
||||
} else {
|
||||
w.SetID(cc.NextWireID())
|
||||
}
|
||||
}
|
||||
|
||||
// Compile circuit.
|
||||
for _, gate := range cc.assigned {
|
||||
gate.Compile(cc)
|
||||
}
|
||||
|
||||
var stats circuit.Stats
|
||||
for _, g := range cc.compiled {
|
||||
stats[g.Op]++
|
||||
}
|
||||
|
||||
result := &circuit.Circuit{
|
||||
NumGates: len(cc.compiled),
|
||||
NumWires: int(cc.nextWireID),
|
||||
Inputs: cc.Inputs,
|
||||
Outputs: cc.Outputs,
|
||||
Gates: cc.compiled,
|
||||
Stats: stats,
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user