diff --git a/core/commands/p2p.go b/core/commands/p2p.go index f8063040c..ca2b7d73a 100644 --- a/core/commands/p2p.go +++ b/core/commands/p2p.go @@ -511,7 +511,7 @@ var p2pStreamCloseCmd = &cmds.Command{ n.P2P.Streams.Unlock() for _, s := range toClose { - s.Reset() + n.P2P.Streams.Reset(s) } }, } diff --git a/p2p/local.go b/p2p/local.go index 70fd80d12..42740d8ef 100644 --- a/p2p/local.go +++ b/p2p/local.go @@ -77,23 +77,17 @@ func (l *localListener) setupStream(local manet.Conn) { return } - cmgr := l.p2p.peerHost.ConnManager() - cmgr.TagPeer(l.peer, CMGR_TAG, 20) - stream := &Stream{ Protocol: l.proto, OriginAddr: local.RemoteMultiaddr(), TargetAddr: l.TargetAddress(), + peer: l.peer, Local: local, Remote: remote, Registry: l.p2p.Streams, - - cleanup: func() { - cmgr.UntagPeer(l.peer, CMGR_TAG) - }, } l.p2p.Streams.Register(stream) diff --git a/p2p/p2p.go b/p2p/p2p.go index a20844258..eb773303e 100644 --- a/p2p/p2p.go +++ b/p2p/p2p.go @@ -31,7 +31,9 @@ func NewP2P(identity peer.ID, peerHost p2phost.Host, peerstore pstore.Peerstore) ListenersP2P: newListenerP2PRegistry(identity, peerHost), Streams: &StreamRegistry{ - Streams: map[uint64]*Stream{}, + Streams: map[uint64]*Stream{}, + ConnManager: peerHost.ConnManager(), + conns: map[peer.ID]int{}, }, } } diff --git a/p2p/remote.go b/p2p/remote.go index 00971481b..633497a0c 100644 --- a/p2p/remote.go +++ b/p2p/remote.go @@ -57,23 +57,17 @@ func (l *remoteListener) handleStream(remote net.Stream) { return } - cmgr := l.p2p.peerHost.ConnManager() - cmgr.TagPeer(peer, CMGR_TAG, 20) - stream := &Stream{ Protocol: l.proto, OriginAddr: peerMa, TargetAddr: l.addr, + peer: peer, Local: local, Remote: remote, Registry: l.p2p.Streams, - - cleanup: func() { - cmgr.UntagPeer(peer, CMGR_TAG) - }, } l.p2p.Streams.Register(stream) diff --git a/p2p/stream.go b/p2p/stream.go index 0443194be..5dc4a95a6 100644 --- a/p2p/stream.go +++ b/p2p/stream.go @@ -5,9 +5,11 @@ import ( "sync" net "gx/ipfs/QmQSbtGXCyNrj34LWL8EgXyNNYDZ8r3SwQcpW5pPxVhLnM/go-libp2p-net" + peer "gx/ipfs/QmQsErDt8Qgw1XrsXf2BpEzDgGWtB1YLsTAARBup5b6B9W/go-libp2p-peer" manet "gx/ipfs/QmV6FjemM1K8oXjrvuq3wuVWWoU2TLDPmNnKrxHzY3v6Ai/go-multiaddr-net" + ifconnmgr "gx/ipfs/QmVz2p8ZVZ5GcWPNWGs2HZHiZyHumZcJpQdMRpxkMDhc2C/go-libp2p-interface-connmgr" ma "gx/ipfs/QmYmsdtJ3HsodkePE3eU3TsCaP2YvPZJ4LoXnNkDE5Tpt7/go-multiaddr" - "gx/ipfs/QmZNkThpqfVXs9GNbexPrfBbXSLNYeKrE7jwFM2oqHbyqN/go-libp2p-protocol" + protocol "gx/ipfs/QmZNkThpqfVXs9GNbexPrfBbXSLNYeKrE7jwFM2oqHbyqN/go-libp2p-protocol" ) const CMGR_TAG = "stream-fwd" @@ -20,29 +22,23 @@ type Stream struct { OriginAddr ma.Multiaddr TargetAddr ma.Multiaddr + peer peer.ID Local manet.Conn Remote net.Stream Registry *StreamRegistry - - cleanup func() } -// Close closes stream endpoints and deregisters it -func (s *Stream) Close() error { - s.Local.Close() - s.Remote.Close() - s.cleanup() - s.Registry.Deregister(s.id) +// close closes stream endpoints and deregisters it +func (s *Stream) close() error { + s.Registry.Close(s) return nil } -// Reset closes stream endpoints and deregisters it -func (s *Stream) Reset() error { - s.Local.Close() - s.Remote.Reset() - s.Registry.Deregister(s.id) +// reset closes stream endpoints and deregisters it +func (s *Stream) reset() error { + s.Registry.Reset(s) return nil } @@ -50,18 +46,18 @@ func (s *Stream) startStreaming() { go func() { _, err := io.Copy(s.Local, s.Remote) if err != nil { - s.Reset() + s.reset() } else { - s.Close() + s.close() } }() go func() { _, err := io.Copy(s.Remote, s.Local) if err != nil { - s.Reset() + s.reset() } else { - s.Close() + s.close() } }() } @@ -71,7 +67,10 @@ type StreamRegistry struct { sync.Mutex Streams map[uint64]*Stream + conns map[peer.ID]int nextID uint64 + + ifconnmgr.ConnManager } // Register registers a stream to the registry @@ -79,6 +78,9 @@ func (r *StreamRegistry) Register(streamInfo *Stream) { r.Lock() defer r.Unlock() + r.ConnManager.TagPeer(streamInfo.peer, CMGR_TAG, 20) + r.conns[streamInfo.peer]++ + streamInfo.id = r.nextID r.Streams[r.nextID] = streamInfo r.nextID++ @@ -89,5 +91,32 @@ func (r *StreamRegistry) Deregister(streamID uint64) { r.Lock() defer r.Unlock() + s, ok := r.Streams[streamID] + if !ok { + return + } + p := s.peer + r.conns[p]-- + if r.conns[p] < 1 { + delete(r.conns, p) + r.ConnManager.UntagPeer(p, CMGR_TAG) + } + delete(r.Streams, streamID) } + +// close closes stream endpoints and deregisters it +func (r *StreamRegistry) Close(s *Stream) error { + s.Local.Close() + s.Remote.Close() + s.Registry.Deregister(s.id) + return nil +} + +// reset closes stream endpoints and deregisters it +func (r *StreamRegistry) Reset(s *Stream) error { + s.Local.Close() + s.Remote.Reset() + s.Registry.Deregister(s.id) + return nil +}