ceremonyclient/go-libp2p/p2p/protocol/autonatv2/server.go
Cassandra Heart dbd95bd9e9
v2.1.0 (#439)
* v2.1.0 [omit consensus and adjacent] - this commit will be amended with the full release after the file copy is complete

* 2.1.0 main node rollup
2025-09-30 02:48:15 -05:00

563 lines
16 KiB
Go

package autonatv2
import (
"context"
"errors"
"fmt"
"io"
"os"
"runtime/debug"
"sync"
"time"
pool "github.com/libp2p/go-buffer-pool"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb"
"github.com/libp2p/go-msgio/pbio"
"math/rand"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
// Sentinel errors reported through EventDialRequestCompleted.Error by serveDialRequest.
var (
// errResourceLimitExceeded is reported when memory for the stream cannot be reserved.
errResourceLimitExceeded = errors.New("resource limit exceeded")
// errBadRequest is reported when the message read from the stream is not a DialRequest.
errBadRequest = errors.New("bad request")
// errDialDataRefused is reported when requesting or reading dial data from the client fails.
errDialDataRefused = errors.New("dial data refused")
)
// dataRequestPolicyFunc reports whether the server must receive dial data from
// the client before dialing dialAddr. observedAddr is the remote address of the
// connection the request arrived on.
type dataRequestPolicyFunc = func(observedAddr, dialAddr ma.Multiaddr) bool
// EventDialRequestCompleted describes the outcome of one served dial request.
// It is returned by serveDialRequest and handed to the metrics tracer, when one
// is configured, by handleDialRequest.
type EventDialRequestCompleted struct {
// Error is non-nil when the request failed; the stream is reset on those paths.
Error error
// ResponseStatus is the status of the DialResponse the server wrote or tried to write.
ResponseStatus pb.DialResponse_ResponseStatus
// DialStatus is the result of the dial back. It keeps its zero value when no
// dial back was attempted.
DialStatus pb.DialStatus
// DialDataRequired reports whether the dial data policy required dial data
// for the selected address.
DialDataRequired bool
// DialedAddr is the address chosen for the dial back. It is left nil on paths
// that return before dial data is requested or a dial is attempted.
DialedAddr ma.Multiaddr
}
// server implements the AutoNATv2 server.
// It can ask client to provide dial data before attempting the requested dial.
// It rate limits requests on a global level, per peer level and on whether the request requires dial data.
type server struct {
// host is the host the dial-request stream handler is registered on. It is
// set by Start and is nil before that.
host host.Host
// dialerHost is the separate host used to perform dial backs.
dialerHost host.Host
// limiter decides whether an incoming request is accepted.
limiter *rateLimiter
// dialDataRequestPolicy is used to determine whether dialing the address requires receiving
// dial data. It is set to amplification attack prevention by default.
dialDataRequestPolicy dataRequestPolicyFunc
// amplificatonAttackPreventionDialWait is the upper bound of the random delay
// applied before dialing back when dial data was required.
amplificatonAttackPreventionDialWait time.Duration
// metricsTracer receives an event for every completed request; it may be nil.
metricsTracer MetricsTracer
// for tests
now func() time.Time
// allowPrivateAddrs, when true, lets non-public addresses be selected for dial back.
allowPrivateAddrs bool
}
// newServer builds an AutoNATv2 server that dials back through dialer and takes
// its policy, rate limits, clock and metrics tracer from s. The stream handler
// is not registered until Start is called.
func newServer(dialer host.Host, s *autoNATSettings) *server {
	limiter := &rateLimiter{
		RPM:                          s.serverRPM,
		PerPeerRPM:                   s.serverPerPeerRPM,
		DialDataRPM:                  s.serverDialDataRPM,
		MaxConcurrentRequestsPerPeer: s.maxConcurrentRequestsPerPeer,
		now:                          s.now,
	}

	srv := &server{
		dialerHost:                           dialer,
		limiter:                              limiter,
		dialDataRequestPolicy:                s.dataRequestPolicy,
		amplificatonAttackPreventionDialWait: s.amplificatonAttackPreventionDialWait,
		metricsTracer:                        s.metricsTracer,
		now:                                  s.now,
		allowPrivateAddrs:                    s.allowPrivateAddrs,
	}
	return srv
}
// Start records h as the serving host and registers handleDialRequest as the
// stream handler for DialProtocol on it.
func (as *server) Start(h host.Host) {
	as.host = h
	h.SetStreamHandler(DialProtocol, as.handleDialRequest)
}
// Close shuts the server down: it unregisters the dial-request stream handler,
// closes the dialer host and closes the rate limiter, after which the limiter
// rejects every new request.
//
// Close may be called on a server that was never started; in that case there is
// no stream handler to remove and that step is skipped.
func (as *server) Close() {
	// host is only assigned by Start. Guard against a nil host so that closing
	// a server that was never started does not panic.
	if as.host != nil {
		as.host.RemoveStreamHandler(DialProtocol)
	}
	// Closing is best effort, but do not drop the error silently.
	if err := as.dialerHost.Close(); err != nil {
		log.Debug("failed to close dialer host",
			"error", err)
	}
	as.limiter.Close()
}
// handleDialRequest is the stream handler registered for DialProtocol. It
// serves the request, logs the outcome and reports it to the metrics tracer
// when one is configured. A panic raised while serving is recovered, printed to
// stderr together with the stack, and the stream is reset.
func (as *server) handleDialRequest(s network.Stream) {
	defer func() {
		rerr := recover()
		if rerr == nil {
			return
		}
		fmt.Fprintf(os.Stderr, "caught panic: %s\n%s\n", rerr, debug.Stack())
		s.Reset()
	}()

	remotePeer := s.Conn().RemotePeer()
	log.Debug("received dial-request",
		"remote_peer", remotePeer,
		"remote_multiaddr", s.Conn().RemoteMultiaddr())

	evt := as.serveDialRequest(s)

	log.Debug("completed dial-request",
		"remote_peer", remotePeer,
		"response_status", evt.ResponseStatus,
		"dial_status", evt.DialStatus,
		"error", evt.Error)

	if tracer := as.metricsTracer; tracer != nil {
		tracer.CompletedRequest(evt)
	}
}
// serveDialRequest handles a single dial-request stream and returns an event
// describing the outcome.
//
// The steps, in order, are:
//  1. attach the stream to ServiceName and reserve maxMsgSize bytes of memory,
//  2. ask the rate limiter whether the remote peer's request is accepted,
//  3. read the DialRequest and select the first acceptable address,
//  4. if the dial data policy requires it, request dial data from the client
//     and then wait a random amount of time,
//  5. dial the peer back and write the DialResponse.
//
// Every path that sets Error in the returned event resets the stream first.
// Once the memory reservation has succeeded the stream is closed on return.
func (as *server) serveDialRequest(s network.Stream) EventDialRequestCompleted {
if err := s.Scope().SetService(ServiceName); err != nil {
s.Reset()
log.Debug("failed to attach stream to service",
"service_name", ServiceName,
"error", err)
return EventDialRequestCompleted{
Error: errors.New("failed to attach stream to autonat-v2"),
}
}
if err := s.Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways); err != nil {
s.Reset()
log.Debug("failed to reserve memory for stream",
"protocol", DialProtocol,
"error", err)
return EventDialRequestCompleted{Error: errResourceLimitExceeded}
}
// Give the reserved memory back when the request is done.
defer s.Scope().ReleaseMemory(maxMsgSize)
// ctx carries the same streamTimeout budget as the stream deadline; it is
// used for the pre-dial wait and passed to dialBack.
deadline := as.now().Add(streamTimeout)
ctx, cancel := context.WithDeadline(context.Background(), deadline)
defer cancel()
s.SetDeadline(as.now().Add(streamTimeout))
defer s.Close()
p := s.Conn().RemotePeer()
var msg pb.Message
w := pbio.NewDelimitedWriter(s)
// Check for rate limit before parsing the request
if !as.limiter.Accept(p) {
msg = pb.Message{
Msg: &pb.Message_DialResponse{
DialResponse: &pb.DialResponse{
Status: pb.DialResponse_E_REQUEST_REJECTED,
},
},
}
if err := w.WriteMsg(&msg); err != nil {
s.Reset()
log.Debug("failed to write request rejected response",
"remote_peer", p,
"error", err)
return EventDialRequestCompleted{
ResponseStatus: pb.DialResponse_E_REQUEST_REJECTED,
Error: fmt.Errorf("write failed: %w", err),
}
}
log.Debug("rejected request",
"remote_peer", p,
"reason", "rate limit exceeded")
return EventDialRequestCompleted{ResponseStatus: pb.DialResponse_E_REQUEST_REJECTED}
}
// The request was accepted: decrement the peer's in-progress count on return.
defer as.limiter.CompleteRequest(p)
r := pbio.NewDelimitedReader(s, maxMsgSize)
if err := r.ReadMsg(&msg); err != nil {
s.Reset()
log.Debug("failed to read request",
"remote_peer", p,
"error", err)
return EventDialRequestCompleted{Error: fmt.Errorf("read failed: %w", err)}
}
if msg.GetDialRequest() == nil {
s.Reset()
log.Debug("invalid message type",
"remote_peer", p,
"actual_type", fmt.Sprintf("%T", msg.Msg),
"expected_type", "DialRequest")
return EventDialRequestCompleted{Error: errBadRequest}
}
// parse peer's addresses
var dialAddr ma.Multiaddr
var addrIdx int
for i, ab := range msg.GetDialRequest().GetAddrs() {
// Only the first maxPeerAddresses entries are considered.
if i >= maxPeerAddresses {
break
}
// Skip entries that are not valid multiaddrs.
a, err := ma.NewMultiaddrBytes(ab)
if err != nil {
continue
}
// Skip non-public addresses unless allowPrivateAddrs is set, and any
// address for which the public check itself fails.
isPubAddr, err := manet.IsPublicAddr(a)
if (!as.allowPrivateAddrs && !isPubAddr) || err != nil {
continue
}
// Skip addresses the dialer host reports it cannot dial.
if !as.dialerHost.Network().CanDial(p, a) {
continue
}
// The first address that passes every check is the one dialed.
dialAddr = a
addrIdx = i
break
}
// No dialable address
if dialAddr == nil {
msg = pb.Message{
Msg: &pb.Message_DialResponse{
DialResponse: &pb.DialResponse{
Status: pb.DialResponse_E_DIAL_REFUSED,
},
},
}
if err := w.WriteMsg(&msg); err != nil {
s.Reset()
log.Debug("failed to write dial refused response",
"remote_peer", p,
"error", err)
return EventDialRequestCompleted{
ResponseStatus: pb.DialResponse_E_DIAL_REFUSED,
Error: fmt.Errorf("write failed: %w", err),
}
}
return EventDialRequestCompleted{
ResponseStatus: pb.DialResponse_E_DIAL_REFUSED,
}
}
// Capture the nonce now: msg is overwritten by the messages written below.
nonce := msg.GetDialRequest().Nonce
isDialDataRequired := as.dialDataRequestPolicy(s.Conn().RemoteMultiaddr(), dialAddr)
// Requests that need dial data are subject to an additional rate limit.
if isDialDataRequired && !as.limiter.AcceptDialDataRequest() {
msg = pb.Message{
Msg: &pb.Message_DialResponse{
DialResponse: &pb.DialResponse{
Status: pb.DialResponse_E_REQUEST_REJECTED,
},
},
}
if err := w.WriteMsg(&msg); err != nil {
s.Reset()
log.Debug("failed to write request rejected response",
"remote_peer", p,
"error", err)
return EventDialRequestCompleted{
ResponseStatus: pb.DialResponse_E_REQUEST_REJECTED,
Error: fmt.Errorf("write failed: %w", err),
DialDataRequired: true,
}
}
log.Debug("rejected request",
"remote_peer", p,
"reason", "rate limit exceeded")
return EventDialRequestCompleted{
ResponseStatus: pb.DialResponse_E_REQUEST_REJECTED,
DialDataRequired: true,
}
}
if isDialDataRequired {
if err := getDialData(w, s, &msg, addrIdx); err != nil {
s.Reset()
log.Debug("dial data request refused",
"remote_peer", p,
"error", err)
return EventDialRequestCompleted{
Error: errDialDataRefused,
DialDataRequired: true,
DialedAddr: dialAddr,
}
}
// wait for a bit to prevent thundering herd style attacks on a victim
waitTime := time.Duration(rand.Intn(int(as.amplificatonAttackPreventionDialWait) + 1)) // rand.Intn(n+1) yields a value in [0, n]
t := time.NewTimer(waitTime)
defer t.Stop()
select {
case <-ctx.Done():
s.Reset()
log.Debug("rejecting request without dialing",
"remote_peer", p,
"error", ctx.Err())
return EventDialRequestCompleted{Error: ctx.Err(), DialDataRequired: true, DialedAddr: dialAddr}
case <-t.C:
}
}
dialStatus := as.dialBack(ctx, s.Conn().RemotePeer(), dialAddr, nonce)
// The response status is OK whenever a dial back was attempted; the result
// of the dial itself is carried in DialStatus.
msg = pb.Message{
Msg: &pb.Message_DialResponse{
DialResponse: &pb.DialResponse{
Status: pb.DialResponse_OK,
DialStatus: dialStatus,
AddrIdx: uint32(addrIdx),
},
},
}
if err := w.WriteMsg(&msg); err != nil {
s.Reset()
log.Debug("failed to write response",
"remote_peer", p,
"error", err)
return EventDialRequestCompleted{
ResponseStatus: pb.DialResponse_OK,
DialStatus: dialStatus,
Error: fmt.Errorf("write failed: %w", err),
DialDataRequired: isDialDataRequired,
DialedAddr: dialAddr,
}
}
return EventDialRequestCompleted{
ResponseStatus: pb.DialResponse_OK,
DialStatus: dialStatus,
Error: nil,
DialDataRequired: isDialDataRequired,
DialedAddr: dialAddr,
}
}
// getDialData asks the client for dial data for the address at addrIdx and
// consumes it from the stream.
//
// It picks a random byte count in
// [minHandshakeSizeBytes, maxHandshakeSizeBytes), overwrites *msg with the
// corresponding DialDataRequest, writes it through w, and then reads that many
// bytes of dial data from s. It returns an error if the write fails or the
// client does not deliver the data.
func getDialData(w pbio.Writer, s network.Stream, msg *pb.Message, addrIdx int) error {
	numBytes := minHandshakeSizeBytes + rand.Intn(maxHandshakeSizeBytes-minHandshakeSizeBytes)

	request := &pb.DialDataRequest{
		AddrIdx:  uint32(addrIdx),
		NumBytes: uint64(numBytes),
	}
	*msg = pb.Message{Msg: &pb.Message_DialDataRequest{DialDataRequest: request}}

	err := w.WriteMsg(msg)
	if err != nil {
		return fmt.Errorf("dial data write: %w", err)
	}

	// The pbio.Reader used earlier on this stream is buffered, but nothing is
	// left unread at this point, so reading from the raw stream is safe and
	// avoids extra allocations.
	return readDialData(numBytes, s)
}
func readDialData(numBytes int, r io.Reader) error {
mr := &msgReader{R: r, Buf: pool.Get(maxMsgSize)}
defer pool.Put(mr.Buf)
for remain := numBytes; remain > 0; {
msg, err := mr.ReadMsg()
if err != nil {
return fmt.Errorf("dial data read: %w", err)
}
// protobuf format is:
// (oneof dialDataResponse:<fieldTag><len varint>)(dial data:<fieldTag><len varint><bytes>)
bytesLen := len(msg)
bytesLen -= 2 // fieldTag + varint first byte
if bytesLen > 127 {
bytesLen -= 1 // varint second byte
}
bytesLen -= 2 // second fieldTag + varint first byte
if bytesLen > 127 {
bytesLen -= 1 // varint second byte
}
if bytesLen > 0 {
remain -= bytesLen
}
// Check if the peer is not sending too little data forcing us to just do a lot of compute
if bytesLen < 100 && remain > 0 {
return fmt.Errorf("dial data msg too small: %d", bytesLen)
}
}
return nil
}
// dialBack connects to p at addr using the dialer host and sends a DialBack
// message carrying nonce on a DialBackProtocol stream.
//
// It returns E_DIAL_ERROR when the connection cannot be established,
// E_DIAL_BACK_ERROR when the stream cannot be opened or the message cannot be
// written, and OK otherwise. On return the connection to p is closed and p is
// removed from the dialer host's peerstore.
func (as *server) dialBack(ctx context.Context, p peer.ID, addr ma.Multiaddr, nonce uint64) pb.DialStatus {
	dialCtx, cancel := context.WithTimeout(ctx, dialBackDialTimeout)
	dialCtx = network.WithForceDirectDial(dialCtx, "autonatv2")

	dialer := as.dialerHost
	dialer.Peerstore().AddAddr(p, addr, peerstore.TempAddrTTL)
	defer func() {
		cancel()
		dialer.Network().ClosePeer(p)
		dialer.Peerstore().ClearAddrs(p)
		dialer.Peerstore().RemovePeer(p)
	}()

	if err := dialer.Connect(dialCtx, peer.AddrInfo{ID: p}); err != nil {
		return pb.DialStatus_E_DIAL_ERROR
	}

	stream, err := dialer.NewStream(dialCtx, p, DialBackProtocol)
	if err != nil {
		return pb.DialStatus_E_DIAL_BACK_ERROR
	}
	defer stream.Close()

	stream.SetDeadline(as.now().Add(dialBackStreamTimeout))
	writer := pbio.NewDelimitedWriter(stream)
	if err := writer.WriteMsg(&pb.DialBack{Nonce: nonce}); err != nil {
		stream.Reset()
		return pb.DialStatus_E_DIAL_BACK_ERROR
	}

	// The connection lives on a separate dialer and is closed as soon as this
	// function returns, which drops any queued writes. To make sure the
	// DialBack message is delivered, half-close the stream and wait for a byte
	// from the peer. The peer replies with a DialBackResponse, but only the
	// fact that it received our message matters, so its content is ignored.
	stream.CloseWrite()
	stream.SetDeadline(as.now().Add(5 * time.Second)) // 5 is a magic number
	// Read one byte: zero length reads may return (0, nil) immediately.
	var ack [1]byte
	stream.Read(ack[:])

	return pb.DialStatus_OK
}
// rateLimiter implements a sliding window rate limit of requests per minute. It allows at most
// MaxConcurrentRequestsPerPeer concurrent requests per peer. It rate limits requests globally,
// at a peer level and depending on whether it requires dial data.
type rateLimiter struct {
// PerPeerRPM is the rate limit per peer
PerPeerRPM int
// RPM is the global rate limit
RPM int
// DialDataRPM is the rate limit for requests that require dial data
DialDataRPM int
// MaxConcurrentRequestsPerPeer is the maximum number of concurrent requests per peer
MaxConcurrentRequestsPerPeer int
// mu is held by Accept, AcceptDialDataRequest, CompleteRequest and Close
// while they access the fields below.
mu sync.Mutex
// closed is set by Close; once set, Accept and AcceptDialDataRequest return false.
closed bool
// reqs records every accepted request, oldest first.
reqs []entry
// peerReqs records the acceptance times of requests per peer, oldest first.
peerReqs map[peer.ID][]time.Time
// dialDataReqs records the acceptance times of requests that required dial data.
dialDataReqs []time.Time
// inProgressReqs tracks in progress requests. This is used to limit multiple
// concurrent requests by the same peer.
inProgressReqs map[peer.ID]int
now func() time.Time // for tests
}
// entry is one accepted request in the rate limiter's global request log.
type entry struct {
// PeerID is the peer that made the request.
PeerID peer.ID
// Time is when the request was accepted.
Time time.Time
}
// init allocates the per-peer maps on first use. It does nothing once peerReqs
// exists. Accept and AcceptDialDataRequest call it while holding r.mu.
func (r *rateLimiter) init() {
	if r.peerReqs != nil {
		return
	}
	r.peerReqs = make(map[peer.ID][]time.Time)
	r.inProgressReqs = make(map[peer.ID]int)
}
// Accept reports whether a new request from p is admitted.
//
// A request is refused when the limiter is closed, when p already has
// MaxConcurrentRequestsPerPeer requests in progress, or when the global or
// per-peer request count for the last minute has reached its limit. An admitted
// request is recorded and counted as in progress until CompleteRequest is
// called for p.
func (r *rateLimiter) Accept(p peer.ID) bool {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.closed {
		return false
	}
	r.init()

	ts := r.now()
	r.cleanup(ts)

	switch {
	case r.inProgressReqs[p] >= r.MaxConcurrentRequestsPerPeer:
		return false
	case len(r.reqs) >= r.RPM:
		return false
	case len(r.peerReqs[p]) >= r.PerPeerRPM:
		return false
	}

	r.inProgressReqs[p]++
	r.reqs = append(r.reqs, entry{PeerID: p, Time: ts})
	r.peerReqs[p] = append(r.peerReqs[p], ts)
	return true
}
// AcceptDialDataRequest reports whether another request that requires dial data
// is admitted. It refuses when the limiter is closed or when DialDataRPM such
// requests were already admitted within the last minute; otherwise it records
// the request and returns true.
func (r *rateLimiter) AcceptDialDataRequest() bool {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.closed {
		return false
	}
	r.init()

	ts := r.now()
	r.cleanup(ts)

	if len(r.dialDataReqs) < r.DialDataRPM {
		r.dialDataReqs = append(r.dialDataReqs, ts)
		return true
	}
	return false
}
// cleanup drops every recorded request that is one minute old or older,
// relative to now, from the global, per-peer and dial data logs.
//
// All logs are kept oldest first, so only a stale prefix is removed from each.
// This is fast enough in rate limited cases, and the state is small enough to
// clean up quickly while requests are blocked on the lock.
func (r *rateLimiter) cleanup(now time.Time) {
	// isFresh reports whether t is still inside the one minute window.
	isFresh := func(t time.Time) bool {
		return now.Sub(t) < time.Minute
	}
	// firstFresh returns the index of the first fresh timestamp in ts, or
	// len(ts) when every timestamp is stale.
	firstFresh := func(ts []time.Time) int {
		for i, t := range ts {
			if isFresh(t) {
				return i
			}
		}
		return len(ts)
	}

	stale := 0
	for stale < len(r.reqs) && !isFresh(r.reqs[stale].Time) {
		id := r.reqs[stale].PeerID
		kept := r.peerReqs[id]
		kept = kept[firstFresh(kept):]
		r.peerReqs[id] = kept
		if len(kept) == 0 {
			delete(r.peerReqs, id)
		}
		stale++
	}
	r.reqs = r.reqs[stale:]

	r.dialDataReqs = r.dialDataReqs[firstFresh(r.dialDataReqs):]
}
// CompleteRequest marks one in-progress request from p as finished. It must be
// called once for every request that Accept admitted.
//
// It is a no-op after Close: Close discards the in-progress map, and requests
// that were still being served at that point call CompleteRequest afterwards.
func (r *rateLimiter) CompleteRequest(p peer.ID) {
	r.mu.Lock()
	defer r.mu.Unlock()

	// Writing to a nil map panics. The map is nil after Close and before the
	// first Accept, so there is nothing to account for in either case.
	if r.closed || r.inProgressReqs == nil {
		return
	}

	remaining := r.inProgressReqs[p] - 1
	if remaining > 0 {
		r.inProgressReqs[p] = remaining
		return
	}

	// Drop the entry so the map does not grow with idle peers.
	delete(r.inProgressReqs, p)
	// Check the computed value, not the map: after delete the map lookup
	// always yields zero and would hide the bug.
	if remaining < 0 {
		log.Error("BUG: negative in progress requests",
			"remote_peer", p)
	}
}
// Close marks the limiter as closed and releases all request tracking state.
// After Close, Accept and AcceptDialDataRequest return false. Calling Close
// more than once is harmless.
func (r *rateLimiter) Close() {
	r.mu.Lock()
	defer r.mu.Unlock()

	r.closed = true
	// Nothing reads these once closed is set, so drop them all, including the
	// global request log, to let the memory be reclaimed.
	r.reqs = nil
	r.peerReqs = nil
	r.inProgressReqs = nil
	r.dialDataReqs = nil
}
// amplificationAttackPrevention is a dialDataRequestPolicy that requires dial
// data whenever the IP address the peer was observed at differs from the IP
// address it asked to be dialed on. Dial data is also required when an IP
// cannot be extracted from either address.
func amplificationAttackPrevention(observedAddr, dialAddr ma.Multiaddr) bool {
	srcIP, srcErr := manet.ToIP(observedAddr)
	if srcErr != nil {
		return true
	}

	// The dial address can be a dns address, which has no IP to extract.
	dstIP, dstErr := manet.ToIP(dialAddr)
	if dstErr != nil {
		return true
	}

	sameIP := srcIP.Equal(dstIP)
	return !sameIP
}