ceremonyclient/bedlam/p2p/protocol.go

372 lines
7.6 KiB
Go

//
// Copyright (c) 2019-2023 Markku Rossi
//
// All rights reserved.
//
package p2p
import (
"io"
"sync/atomic"
"source.quilibrium.com/quilibrium/monorepo/bedlam/ot"
)
// Compile-time check that *Conn can be assigned to an ot.IO.
var (
_ ot.IO = &Conn{}
)
// Buffer configuration for Conn.
const (
// numBuffers is the number of write buffers the writer goroutine
// allocates, and the capacity of the channels that pass those
// buffers between Conn and the writer goroutine.
numBuffers = 3
// writeBufSize is the size, in bytes, of each write buffer.
writeBufSize = 64 * 1024
// readBufSize is the size, in bytes, of the read buffer allocated
// by NewConn.
readBufSize = 1024 * 1024
)
// Conn implements a protocol connection. Outgoing data is collected
// in WriteBuf and handed to a background writer goroutine on Flush;
// incoming data is buffered in ReadBuf.
type Conn struct {
// conn is the underlying transport.
conn io.ReadWriter
// WriteBuf is the write buffer currently being filled.
WriteBuf []byte
// WritePos is the offset in WriteBuf where the next byte is written.
WritePos int
// ReadBuf holds data read from conn.
ReadBuf []byte
// ReadStart is the offset of the first unconsumed byte in ReadBuf.
ReadStart int
// ReadEnd is the offset one past the last valid byte in ReadBuf.
ReadEnd int
// Stats holds the I/O counters updated by Flush and Fill.
Stats IOStats
// fromWriter carries reusable buffers from the writer goroutine
// back to the connection.
fromWriter chan []byte
// toWriter carries filled buffers to the writer goroutine.
toWriter chan []byte
// writerErr is set by the writer goroutine when a write to conn
// fails; Flush and Close report it.
writerErr error
}
// IOStats holds I/O statistics counters. The counters are pointers,
// so copies of an IOStats value share the same underlying counters.
type IOStats struct {
	Sent    *atomic.Uint64
	Recvd   *atomic.Uint64
	Flushed *atomic.Uint64
}

// NewIOStats returns an IOStats whose counters are all allocated and
// start at zero.
func NewIOStats() IOStats {
	var stats IOStats
	stats.Sent = new(atomic.Uint64)
	stats.Recvd = new(atomic.Uint64)
	stats.Flushed = new(atomic.Uint64)
	return stats
}

// Add returns a new IOStats whose counters are the field-wise sums of
// stats and o. Neither operand is modified.
func (stats IOStats) Add(o IOStats) IOStats {
	total := NewIOStats()
	total.Sent.Store(stats.Sent.Load() + o.Sent.Load())
	total.Recvd.Store(stats.Recvd.Load() + o.Recvd.Load())
	total.Flushed.Store(stats.Flushed.Load() + o.Flushed.Load())
	return total
}

// Sum returns the total of the sent and received counters.
func (stats IOStats) Sum() uint64 {
	sent := stats.Sent.Load()
	recvd := stats.Recvd.Load()
	return sent + recvd
}
// NewConn wraps conn in a buffered protocol connection. It allocates
// the read buffer, starts the background writer goroutine, and takes
// the first write buffer from that goroutine before returning.
func NewConn(conn io.ReadWriter) *Conn {
	c := new(Conn)
	c.conn = conn
	c.Stats = NewIOStats()
	c.ReadBuf = make([]byte, readBufSize)
	c.toWriter = make(chan []byte, numBuffers)
	c.fromWriter = make(chan []byte, numBuffers)

	go c.writer()
	c.WriteBuf = <-c.fromWriter

	return c
}
// writer is the background goroutine that performs the actual writes
// to the underlying connection. It first publishes numBuffers empty
// buffers on fromWriter, then writes every buffer received on toWriter
// and returns it, restored to full capacity, on fromWriter. A write
// error is recorded in writerErr. When toWriter is closed the
// goroutine closes fromWriter and exits.
func (c *Conn) writer() {
	for n := numBuffers; n > 0; n-- {
		c.fromWriter <- make([]byte, writeBufSize)
	}
	for {
		pending, ok := <-c.toWriter
		if !ok {
			break
		}
		if _, err := c.conn.Write(pending); err != nil {
			c.writerErr = err
		}
		c.fromWriter <- pending[:cap(pending)]
	}
	close(c.fromWriter)
}
// NeedSpace checks whether count more bytes fit into the current
// write buffer and flushes the buffer when they do not. It returns
// the error from Flush, if any.
func (c *Conn) NeedSpace(count int) error {
	if c.WritePos+count <= len(c.WriteBuf) {
		return nil
	}
	return c.Flush()
}
// Flush hands any pending output to the writer goroutine and switches
// to a fresh write buffer. It does nothing when the write buffer is
// empty. If the writer goroutine has recorded an error, that error is
// returned and the write state is left unchanged.
func (c *Conn) Flush() error {
	if c.WritePos <= 0 {
		return nil
	}
	pending := c.WriteBuf[:c.WritePos]
	c.Stats.Sent.Add(uint64(len(pending)))
	c.toWriter <- pending
	fresh := <-c.fromWriter
	if err := c.writerErr; err != nil {
		return err
	}
	c.WriteBuf, c.WritePos = fresh, 0
	c.Stats.Flushed.Add(1)
	return nil
}
// Fill reads from the connection until at least n unconsumed bytes
// are available in the read buffer. Any unconsumed data is first
// moved to the beginning of the buffer, so ReadStart is 0 on return.
//
// If n is larger than the read buffer, the buffer is grown to n bytes;
// otherwise the request could never be satisfied and the read loop
// would spin on an empty slice.
//
// Bytes returned by a Read call are always accounted for, even when
// that call also returns an error. The error is reported only if the
// buffer still holds fewer than n bytes.
func (c *Conn) Fill(n int) error {
	// Compact: move the unread tail to the front of the buffer.
	if c.ReadStart < c.ReadEnd {
		copy(c.ReadBuf, c.ReadBuf[c.ReadStart:c.ReadEnd])
		c.ReadEnd -= c.ReadStart
	} else {
		c.ReadEnd = 0
	}
	c.ReadStart = 0

	// Make sure the request fits into the buffer at all.
	if n > len(c.ReadBuf) {
		grown := make([]byte, n)
		copy(grown, c.ReadBuf[:c.ReadEnd])
		c.ReadBuf = grown
	}

	for c.ReadEnd < n {
		got, err := c.conn.Read(c.ReadBuf[c.ReadEnd:])
		// A reader may return data together with an error (for
		// example the final bytes along with io.EOF), so consume the
		// data before looking at the error.
		c.Stats.Recvd.Add(uint64(got))
		c.ReadEnd += got
		if err != nil && c.ReadEnd < n {
			return err
		}
	}
	return nil
}
// Close flushes any pending data, waits for the writer goroutine to
// finish, and closes the underlying connection if it implements
// io.Closer. It returns the first error encountered: a flush error, an
// error recorded by the writer goroutine, or the error from closing
// the underlying connection. Close must be called at most once.
func (c *Conn) Close() error {
	if err := c.Flush(); err != nil {
		return err
	}
	// Tell the writer goroutine that no more data is coming, then
	// drain fromWriter until the writer closes it. The channel itself
	// must be ranged over: receiving a single buffer would return
	// immediately with an idle buffer while a write could still be in
	// progress.
	close(c.toWriter)
	for range c.fromWriter {
	}
	// The writer goroutine has exited, so writerErr is final.
	if c.writerErr != nil {
		return c.writerErr
	}
	if closer, ok := c.conn.(io.Closer); ok {
		return closer.Close()
	}
	return nil
}
// SendByte appends a single byte to the write buffer, flushing the
// buffer first if it is full.
func (c *Conn) SendByte(val byte) error {
	if err := c.NeedSpace(1); err != nil {
		return err
	}
	pos := c.WritePos
	c.WriteBuf[pos] = val
	c.WritePos = pos + 1
	return nil
}
// SendUint16 appends the low 16 bits of val to the write buffer in
// big-endian byte order, flushing the buffer first if there is no
// room for two more bytes.
func (c *Conn) SendUint16(val int) error {
	if err := c.NeedSpace(2); err != nil {
		return err
	}
	pos := c.WritePos
	c.WriteBuf[pos] = byte(val >> 8)
	c.WriteBuf[pos+1] = byte(val)
	c.WritePos = pos + 2
	return nil
}
// SendUint32 appends the low 32 bits of val to the write buffer in
// big-endian byte order, flushing the buffer first if there is no
// room for four more bytes.
func (c *Conn) SendUint32(val int) error {
	if err := c.NeedSpace(4); err != nil {
		return err
	}
	word := uint32(val)
	pos := c.WritePos
	c.WriteBuf[pos] = byte(word >> 24)
	c.WriteBuf[pos+1] = byte(word >> 16)
	c.WriteBuf[pos+2] = byte(word >> 8)
	c.WriteBuf[pos+3] = byte(word)
	c.WritePos = pos + 4
	return nil
}
// SendData sends binary data as a 32-bit big-endian length followed
// by the bytes of val.
//
// When the header and payload fit into one write buffer, the current
// buffer is flushed first if needed so that they are written together.
// Payloads larger than the write buffer are copied in chunks, flushing
// each time the buffer fills up, so they are neither truncated nor
// allowed to push WritePos past the end of the buffer.
func (c *Conn) SendData(val []byte) error {
	if c.WritePos+4+len(val) > len(c.WriteBuf) {
		if err := c.Flush(); err != nil {
			return err
		}
	}
	if err := c.SendUint32(len(val)); err != nil {
		return err
	}
	for len(val) > 0 {
		if c.WritePos >= len(c.WriteBuf) {
			if err := c.Flush(); err != nil {
				return err
			}
		}
		n := copy(c.WriteBuf[c.WritePos:], val)
		if n == 0 {
			// No room even after flushing: the write buffer has no
			// capacity, so no progress can ever be made.
			return io.ErrShortBuffer
		}
		c.WritePos += n
		val = val[n:]
	}
	return nil
}
// SendLabel appends the bytes returned by val.Bytes(data) to the
// write buffer, flushing the buffer first if they do not fit.
func (c *Conn) SendLabel(val ot.Label, data *ot.LabelData) error {
	encoded := val.Bytes(data)
	if err := c.NeedSpace(len(encoded)); err != nil {
		return err
	}
	copy(c.WriteBuf[c.WritePos:], encoded)
	c.WritePos += len(encoded)
	return nil
}
// SendString sends val as binary data: its bytes are passed to
// SendData.
func (c *Conn) SendString(val string) error {
	payload := []byte(val)
	return c.SendData(payload)
}
// SendInputSizes sends the number of entries in sizes followed by
// each entry, all encoded with SendUint32.
func (c *Conn) SendInputSizes(sizes []int) error {
	if err := c.SendUint32(len(sizes)); err != nil {
		return err
	}
	for _, size := range sizes {
		if err := c.SendUint32(size); err != nil {
			return err
		}
	}
	return nil
}
// ReceiveByte returns the next byte from the read buffer, filling the
// buffer from the connection first if it is empty.
func (c *Conn) ReceiveByte() (byte, error) {
	if c.ReadEnd-c.ReadStart < 1 {
		if err := c.Fill(1); err != nil {
			return 0, err
		}
	}
	b := c.ReadBuf[c.ReadStart]
	c.ReadStart++
	return b, nil
}
// ReceiveUint16 reads two bytes and returns them as a big-endian
// unsigned 16-bit value, filling the read buffer first if fewer than
// two bytes are available.
func (c *Conn) ReceiveUint16() (int, error) {
	if c.ReadEnd-c.ReadStart < 2 {
		if err := c.Fill(2); err != nil {
			return 0, err
		}
	}
	hi := int(c.ReadBuf[c.ReadStart])
	lo := int(c.ReadBuf[c.ReadStart+1])
	c.ReadStart += 2
	return hi<<8 | lo, nil
}
// ReceiveUint32 reads four bytes and returns them as a big-endian
// unsigned 32-bit value converted to int, filling the read buffer
// first if fewer than four bytes are available.
func (c *Conn) ReceiveUint32() (int, error) {
	if c.ReadEnd-c.ReadStart < 4 {
		if err := c.Fill(4); err != nil {
			return 0, err
		}
	}
	var word uint32
	for _, b := range c.ReadBuf[c.ReadStart : c.ReadStart+4] {
		word = word<<8 | uint32(b)
	}
	c.ReadStart += 4
	return int(word), nil
}
// ReceiveData receives binary data: a 32-bit big-endian length
// followed by that many bytes. It returns a newly allocated slice that
// does not alias the read buffer.
//
// The payload is copied out of the read buffer in chunks, refilling
// the buffer whenever it runs empty. Payloads larger than the read
// buffer are therefore received correctly, instead of requiring the
// whole payload to be buffered at once.
func (c *Conn) ReceiveData() ([]byte, error) {
	size, err := c.ReceiveUint32()
	if err != nil {
		return nil, err
	}
	result := make([]byte, size)
	for filled := 0; filled < size; {
		if c.ReadStart >= c.ReadEnd {
			if err := c.Fill(1); err != nil {
				return nil, err
			}
		}
		n := copy(result[filled:], c.ReadBuf[c.ReadStart:c.ReadEnd])
		c.ReadStart += n
		filled += n
	}
	return result, nil
}
// ReceiveLabel reads len(data) bytes into data, filling the read
// buffer first if needed, and then calls val.SetData(data).
func (c *Conn) ReceiveLabel(val *ot.Label, data *ot.LabelData) error {
	size := len(data)
	if c.ReadEnd-c.ReadStart < size {
		if err := c.Fill(size); err != nil {
			return err
		}
	}
	next := c.ReadStart + size
	copy(data[:], c.ReadBuf[c.ReadStart:next])
	c.ReadStart = next
	val.SetData(data)
	return nil
}
// ReceiveString receives binary data with ReceiveData and returns it
// as a string.
func (c *Conn) ReceiveString() (string, error) {
	raw, err := c.ReceiveData()
	if err == nil {
		return string(raw), nil
	}
	return "", err
}
// ReceiveInputSizes receives a count followed by that many sizes, each
// read with ReceiveUint32, and returns the sizes in the order received.
func (c *Conn) ReceiveInputSizes() ([]int, error) {
	count, err := c.ReceiveUint32()
	if err != nil {
		return nil, err
	}
	sizes := make([]int, count)
	for i := range sizes {
		if sizes[i], err = c.ReceiveUint32(); err != nil {
			return nil, err
		}
	}
	return sizes, nil
}