From adfbecf3f78ca72709fbed881452c6cd77ac98ac Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Thu, 14 Sep 2017 11:52:14 -0700 Subject: [PATCH] use stream.Reset where appropriate License: MIT Signed-off-by: Steven Allen --- exchange/bitswap/network/interface.go | 1 + exchange/bitswap/network/ipfs_impl.go | 15 ++++++++++++-- exchange/bitswap/testnet/virtual.go | 4 ++++ exchange/bitswap/wantmanager.go | 13 ++++++------ p2p/p2p.go | 6 +++--- p2p/registry.go | 30 +++++++++++++++++++++------ routing/supernode/proxy/loopback.go | 5 ++++- routing/supernode/proxy/standard.go | 15 +++++++++++--- 8 files changed, 68 insertions(+), 21 deletions(-) diff --git a/exchange/bitswap/network/interface.go b/exchange/bitswap/network/interface.go index 92d27676c..2ec1c639b 100644 --- a/exchange/bitswap/network/interface.go +++ b/exchange/bitswap/network/interface.go @@ -40,6 +40,7 @@ type BitSwapNetwork interface { type MessageSender interface { SendMsg(context.Context, bsmsg.BitSwapMessage) error Close() error + Reset() error } // Implement Receiver to receive messages from the BitSwapNetwork diff --git a/exchange/bitswap/network/ipfs_impl.go b/exchange/bitswap/network/ipfs_impl.go index 505ea4d2e..8e18527aa 100644 --- a/exchange/bitswap/network/ipfs_impl.go +++ b/exchange/bitswap/network/ipfs_impl.go @@ -56,6 +56,10 @@ func (s *streamMessageSender) Close() error { return s.s.Close() } +func (s *streamMessageSender) Reset() error { + return s.s.Reset() +} + func (s *streamMessageSender) SendMsg(ctx context.Context, msg bsmsg.BitSwapMessage) error { return msgToStream(ctx, s.s, msg) } @@ -121,9 +125,14 @@ func (bsnet *impl) SendMessage( if err != nil { return err } - defer s.Close() - return msgToStream(ctx, s, outgoing) + err = msgToStream(ctx, s, outgoing) + if err != nil { + s.Reset() + } else { + s.Close() + } + return err } func (bsnet *impl) SetDelegate(r Receiver) { @@ -180,6 +189,7 @@ func (bsnet *impl) handleNewStream(s inet.Stream) { defer s.Close() if bsnet.receiver == nil { + s.Reset() return } @@ -188,6 +198,7 @@ func (bsnet *impl) handleNewStream(s inet.Stream) { received, err := bsmsg.FromPBReader(reader) if err != nil { if err != io.EOF { + s.Reset() go bsnet.receiver.ReceiveError(err) log.Debugf("bitswap net handleNewStream from %s error: %s", s.Conn().RemotePeer(), err) } diff --git a/exchange/bitswap/testnet/virtual.go b/exchange/bitswap/testnet/virtual.go index a01d4165f..37ae23b54 100644 --- a/exchange/bitswap/testnet/virtual.go +++ b/exchange/bitswap/testnet/virtual.go @@ -133,6 +133,10 @@ func (mp *messagePasser) Close() error { return nil } +func (mp *messagePasser) Reset() error { + return nil +} + func (n *networkClient) NewMessageSender(ctx context.Context, p peer.ID) (bsnet.MessageSender, error) { return &messagePasser{ net: n.network, diff --git a/exchange/bitswap/wantmanager.go b/exchange/bitswap/wantmanager.go index cdc8da868..e2859a292 100644 --- a/exchange/bitswap/wantmanager.go +++ b/exchange/bitswap/wantmanager.go @@ -172,18 +172,19 @@ func (pm *WantManager) stopPeerHandler(p peer.ID) { } func (mq *msgQueue) runQueue(ctx context.Context) { - defer func() { - if mq.sender != nil { - mq.sender.Close() - } - }() for { select { case <-mq.work: // there is work to be done mq.doWork(ctx) case <-mq.done: + if mq.sender != nil { + mq.sender.Close() + } return case <-ctx.Done(): + if mq.sender != nil { + mq.sender.Reset() + } return } } @@ -218,7 +219,7 @@ func (mq *msgQueue) doWork(ctx context.Context) { } log.Infof("bitswap send error: %s", err) - mq.sender.Close() + mq.sender.Reset() mq.sender = nil select { diff --git a/p2p/p2p.go b/p2p/p2p.go index 7cbc62002..ea9eab68d 100644 --- a/p2p/p2p.go +++ b/p2p/p2p.go @@ -64,7 +64,7 @@ func (p2p *P2P) Dial(ctx context.Context, addr ma.Multiaddr, peer peer.ID, proto case "tcp", "tcp4", "tcp6": listener, err := manet.Listen(bindAddr) if err != nil { - if err2 := remote.Close(); err2 != nil { + if err2 := remote.Reset(); err2 != nil { return nil, err2 } return nil, err @@ -158,7 +158,7 @@ func (p2p *P2P) registerStreamHandler(ctx2 context.Context, protocol string) (*P select { case list.conCh <- s: case <-ctx.Done(): - s.Close() + s.Reset() } }) @@ -198,7 +198,7 @@ func (p2p *P2P) acceptStreams(listenerInfo *ListenerInfo, listener Listener) { local, err := manet.Dial(listenerInfo.Address) if err != nil { - remote.Close() + remote.Reset() continue } diff --git a/p2p/registry.go b/p2p/registry.go index aa6c21706..be865c470 100644 --- a/p2p/registry.go +++ b/p2p/registry.go @@ -4,6 +4,8 @@ import ( "fmt" "io" + net "gx/ipfs/QmNa31VPzC561NWwRsJLE7nGYZYuuD2QfpK2b1q9BK54J1/go-libp2p-net" + manet "gx/ipfs/QmX3U3YXCQ6UYBxq2LVWF8dARS1hPUTEYLrSx654Qyxyw6/go-multiaddr-net" ma "gx/ipfs/QmXY77cVe7rVRQXZZQRioukUM7aRW3BTcAgJe12MCtb3Ji/go-multiaddr" peer "gx/ipfs/QmXYjuNuxVzXKJCfWasQk1RqkhVLDM9jtUKhqc2WPQmFSB/go-libp2p-peer" ) @@ -76,8 +78,8 @@ type StreamInfo struct { RemotePeer peer.ID RemoteAddr ma.Multiaddr - Local io.ReadWriteCloser - Remote io.ReadWriteCloser + Local manet.Conn + Remote net.Stream Registry *StreamRegistry } @@ -90,15 +92,31 @@ func (s *StreamInfo) Close() error { return nil } +// Reset closes stream endpoints and deregisters it +func (s *StreamInfo) Reset() error { + s.Local.Close() + s.Remote.Reset() + s.Registry.Deregister(s.HandlerID) + return nil +} + func (s *StreamInfo) startStreaming() { go func() { - io.Copy(s.Local, s.Remote) - s.Close() + _, err := io.Copy(s.Local, s.Remote) + if err != nil { + s.Reset() + } else { + s.Close() + } }() go func() { - io.Copy(s.Remote, s.Local) - s.Close() + _, err := io.Copy(s.Remote, s.Local) + if err != nil { + s.Reset() + } else { + s.Close() + } }() } diff --git a/routing/supernode/proxy/loopback.go b/routing/supernode/proxy/loopback.go index e8b77b322..f6d9c0bb7 100644 --- a/routing/supernode/proxy/loopback.go +++ b/routing/supernode/proxy/loopback.go @@ -42,6 +42,7 @@ func (lb *Loopback) HandleStream(s inet.Stream) { pbr := ggio.NewDelimitedReader(s, inet.MessageSizeMax) var incoming dhtpb.Message if err := pbr.ReadMsg(&incoming); err != nil { + s.Reset() log.Debug(err) return } @@ -51,6 +52,8 @@ func (lb *Loopback) HandleStream(s inet.Stream) { pbw := ggio.NewDelimitedWriter(s) if err := pbw.WriteMsg(outgoing); err != nil { - return // TODO logerr + s.Reset() + log.Debug(err) + return } } diff --git a/routing/supernode/proxy/standard.go b/routing/supernode/proxy/standard.go index d5a9b51bf..eddd1e84f 100644 --- a/routing/supernode/proxy/standard.go +++ b/routing/supernode/proxy/standard.go @@ -60,7 +60,7 @@ func (px *standard) Bootstrap(ctx context.Context) error { func (p *standard) HandleStream(s inet.Stream) { // TODO(brian): Should clients be able to satisfy requests? log.Error("supernode client received (dropped) a routing message from", s.Conn().RemotePeer()) - s.Close() + s.Reset() } const replicationFactor = 2 @@ -102,9 +102,15 @@ func (px *standard) sendMessage(ctx context.Context, m *dhtpb.Message, remote pe if err != nil { return err } - defer s.Close() pbw := ggio.NewDelimitedWriter(s) - return pbw.WriteMsg(m) + + err = pbw.WriteMsg(m) + if err == nil { + s.Close() + } else { + s.Reset() + } + return err } // SendRequest sends the request to each remote sequentially (randomized order), @@ -139,17 +145,20 @@ func (px *standard) sendRequest(ctx context.Context, m *dhtpb.Message, remote pe r := ggio.NewDelimitedReader(s, inet.MessageSizeMax) w := ggio.NewDelimitedWriter(s) if err = w.WriteMsg(m); err != nil { + s.Reset() e.SetError(err) return nil, err } response := &dhtpb.Message{} if err = r.ReadMsg(response); err != nil { + s.Reset() e.SetError(err) return nil, err } // need ctx expiration? if response == nil { + s.Reset() err := errors.New("no response to request") e.SetError(err) return nil, err