diff --git a/go-libp2p/p2p/net/gostream/addr.go b/go-libp2p/p2p/net/gostream/addr.go index 49d844f..2ca2bb7 100644 --- a/go-libp2p/p2p/net/gostream/addr.go +++ b/go-libp2p/p2p/net/gostream/addr.go @@ -1,6 +1,10 @@ package gostream -import "github.com/libp2p/go-libp2p/core/peer" +import ( + "net" + + "github.com/libp2p/go-libp2p/core/peer" +) // addr implements net.Addr and holds a libp2p peer ID. type addr struct{ id peer.ID } @@ -12,3 +16,12 @@ func (a *addr) Network() string { return Network } // String returns the peer ID of this address in string form // (B58-encoded). func (a *addr) String() string { return a.id.String() } + +// PeerIDFromAddr extracts a peer ID from a net.Addr. +func PeerIDFromAddr(a net.Addr) (peer.ID, bool) { + addr, ok := a.(*addr) + if !ok { + return "", false + } + return addr.id, true +} diff --git a/node/consensus/data/grpc_worker_rate_limiter.go b/node/consensus/data/grpc_worker_rate_limiter.go index 05f4786..824de53 100644 --- a/node/consensus/data/grpc_worker_rate_limiter.go +++ b/node/consensus/data/grpc_worker_rate_limiter.go @@ -4,13 +4,14 @@ import ( "sync" "time" + "github.com/libp2p/go-libp2p/core/peer" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) type RateLimiter struct { mu sync.RWMutex - clients map[string]*bucket + clients map[peer.ID]*bucket maxTokens int refillTokens int refillTime time.Duration @@ -31,7 +32,7 @@ func NewRateLimiter( refillDuration time.Duration, ) *RateLimiter { return &RateLimiter{ - clients: make(map[string]*bucket), + clients: make(map[peer.ID]*bucket), maxTokens: maxTokens, refillTokens: refillTokens, refillTime: refillDuration, @@ -41,7 +42,7 @@ func NewRateLimiter( } } -func (rl *RateLimiter) Allow(peerId string) error { +func (rl *RateLimiter) Allow(peerId peer.ID) error { rl.mu.Lock() defer rl.mu.Unlock() diff --git a/node/consensus/data/peer_messaging.go b/node/consensus/data/peer_messaging.go index c27dea6..59d3c3a 100644 --- a/node/consensus/data/peer_messaging.go +++ b/node/consensus/data/peer_messaging.go @@ -15,8 +15,11 @@ import ( "go.uber.org/zap" "golang.org/x/crypto/sha3" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "source.quilibrium.com/quilibrium/monorepo/node/crypto" "source.quilibrium.com/quilibrium/monorepo/node/execution/intrinsics/token/application" + grpc_internal "source.quilibrium.com/quilibrium/monorepo/node/internal/grpc" "source.quilibrium.com/quilibrium/monorepo/node/p2p" "source.quilibrium.com/quilibrium/monorepo/node/protobufs" "source.quilibrium.com/quilibrium/monorepo/node/store" @@ -28,11 +31,12 @@ func (e *DataClockConsensusEngine) GetDataFrame( ctx context.Context, request *protobufs.GetDataFrameRequest, ) (*protobufs.DataFrameResponse, error) { - if request.PeerId == "" || len(request.PeerId) > 64 { - return nil, errors.Wrap(errors.New("invalid request"), "get data frame") + peerID, ok := grpc_internal.PeerIDFromContext(ctx) + if !ok { + return nil, status.Error(codes.Internal, "remote peer ID not found") } - if err := e.grpcRateLimiter.Allow(request.PeerId); err != nil { + if err := e.grpcRateLimiter.Allow(peerID); err != nil { return nil, errors.Wrap(err, "get data frame") } diff --git a/node/internal/grpc/peer_id.go b/node/internal/grpc/peer_id.go new file mode 100644 index 0000000..83c4c88 --- /dev/null +++ b/node/internal/grpc/peer_id.go @@ -0,0 +1,33 @@ +package grpc + +import ( + "context" + + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/p2p/net/gostream" + grpc_peer "google.golang.org/grpc/peer" +) + +type peerIDKeyType struct{} + +var peerIDKey peerIDKeyType + +// PeerIDFromContext returns the peer.ID of the remote peer from the given context. +// It assumes that the context is a gRPC request context, and the connection was established +// by gostream.Listen. +func PeerIDFromContext(ctx context.Context) (peer.ID, bool) { + if peerID, ok := ctx.Value(peerIDKey).(peer.ID); ok { + return peerID, true + } + remotePeer, ok := grpc_peer.FromContext(ctx) + if !ok { + return "", false + } + return gostream.PeerIDFromAddr(remotePeer.Addr) +} + +// NewContextWithPeerID returns a new context with the given peer.ID. +// This method is meant to be used only in unit testing contexts. +func NewContextWithPeerID(ctx context.Context, peerID peer.ID) context.Context { + return context.WithValue(ctx, peerIDKey, peerID) +}