use stream.Reset where appropriate

License: MIT
Signed-off-by: Steven Allen <steven@stebalien.com>
This commit is contained in:
Steven Allen 2017-09-14 11:52:14 -07:00
parent 8deaaa8d8c
commit adfbecf3f7
8 changed files with 68 additions and 21 deletions

View File

@ -40,6 +40,7 @@ type BitSwapNetwork interface {
type MessageSender interface {
SendMsg(context.Context, bsmsg.BitSwapMessage) error
Close() error
Reset() error
}
// Implement Receiver to receive messages from the BitSwapNetwork

View File

@ -56,6 +56,10 @@ func (s *streamMessageSender) Close() error {
return s.s.Close()
}
func (s *streamMessageSender) Reset() error {
return s.s.Reset()
}
func (s *streamMessageSender) SendMsg(ctx context.Context, msg bsmsg.BitSwapMessage) error {
return msgToStream(ctx, s.s, msg)
}
@ -121,9 +125,14 @@ func (bsnet *impl) SendMessage(
if err != nil {
return err
}
defer s.Close()
return msgToStream(ctx, s, outgoing)
err = msgToStream(ctx, s, outgoing)
if err != nil {
s.Reset()
} else {
s.Close()
}
return err
}
func (bsnet *impl) SetDelegate(r Receiver) {
@ -180,6 +189,7 @@ func (bsnet *impl) handleNewStream(s inet.Stream) {
defer s.Close()
if bsnet.receiver == nil {
s.Reset()
return
}
@ -188,6 +198,7 @@ func (bsnet *impl) handleNewStream(s inet.Stream) {
received, err := bsmsg.FromPBReader(reader)
if err != nil {
if err != io.EOF {
s.Reset()
go bsnet.receiver.ReceiveError(err)
log.Debugf("bitswap net handleNewStream from %s error: %s", s.Conn().RemotePeer(), err)
}

View File

@ -133,6 +133,10 @@ func (mp *messagePasser) Close() error {
return nil
}
func (mp *messagePasser) Reset() error {
return nil
}
func (n *networkClient) NewMessageSender(ctx context.Context, p peer.ID) (bsnet.MessageSender, error) {
return &messagePasser{
net: n.network,

View File

@ -172,18 +172,19 @@ func (pm *WantManager) stopPeerHandler(p peer.ID) {
}
func (mq *msgQueue) runQueue(ctx context.Context) {
defer func() {
if mq.sender != nil {
mq.sender.Close()
}
}()
for {
select {
case <-mq.work: // there is work to be done
mq.doWork(ctx)
case <-mq.done:
if mq.sender != nil {
mq.sender.Close()
}
return
case <-ctx.Done():
if mq.sender != nil {
mq.sender.Reset()
}
return
}
}
@ -218,7 +219,7 @@ func (mq *msgQueue) doWork(ctx context.Context) {
}
log.Infof("bitswap send error: %s", err)
mq.sender.Close()
mq.sender.Reset()
mq.sender = nil
select {

View File

@ -64,7 +64,7 @@ func (p2p *P2P) Dial(ctx context.Context, addr ma.Multiaddr, peer peer.ID, proto
case "tcp", "tcp4", "tcp6":
listener, err := manet.Listen(bindAddr)
if err != nil {
if err2 := remote.Close(); err2 != nil {
if err2 := remote.Reset(); err2 != nil {
return nil, err2
}
return nil, err
@ -158,7 +158,7 @@ func (p2p *P2P) registerStreamHandler(ctx2 context.Context, protocol string) (*P
select {
case list.conCh <- s:
case <-ctx.Done():
s.Close()
s.Reset()
}
})
@ -198,7 +198,7 @@ func (p2p *P2P) acceptStreams(listenerInfo *ListenerInfo, listener Listener) {
local, err := manet.Dial(listenerInfo.Address)
if err != nil {
remote.Close()
remote.Reset()
continue
}

View File

@ -4,6 +4,8 @@ import (
"fmt"
"io"
net "gx/ipfs/QmNa31VPzC561NWwRsJLE7nGYZYuuD2QfpK2b1q9BK54J1/go-libp2p-net"
manet "gx/ipfs/QmX3U3YXCQ6UYBxq2LVWF8dARS1hPUTEYLrSx654Qyxyw6/go-multiaddr-net"
ma "gx/ipfs/QmXY77cVe7rVRQXZZQRioukUM7aRW3BTcAgJe12MCtb3Ji/go-multiaddr"
peer "gx/ipfs/QmXYjuNuxVzXKJCfWasQk1RqkhVLDM9jtUKhqc2WPQmFSB/go-libp2p-peer"
)
@ -76,8 +78,8 @@ type StreamInfo struct {
RemotePeer peer.ID
RemoteAddr ma.Multiaddr
Local io.ReadWriteCloser
Remote io.ReadWriteCloser
Local manet.Conn
Remote net.Stream
Registry *StreamRegistry
}
@ -90,15 +92,31 @@ func (s *StreamInfo) Close() error {
return nil
}
// Reset closes stream endpoints and deregisters it
func (s *StreamInfo) Reset() error {
s.Local.Close()
s.Remote.Reset()
s.Registry.Deregister(s.HandlerID)
return nil
}
func (s *StreamInfo) startStreaming() {
go func() {
io.Copy(s.Local, s.Remote)
s.Close()
_, err := io.Copy(s.Local, s.Remote)
if err != nil {
s.Reset()
} else {
s.Close()
}
}()
go func() {
io.Copy(s.Remote, s.Local)
s.Close()
_, err := io.Copy(s.Remote, s.Local)
if err != nil {
s.Reset()
} else {
s.Close()
}
}()
}

View File

@ -42,6 +42,7 @@ func (lb *Loopback) HandleStream(s inet.Stream) {
pbr := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
var incoming dhtpb.Message
if err := pbr.ReadMsg(&incoming); err != nil {
s.Reset()
log.Debug(err)
return
}
@ -51,6 +52,8 @@ func (lb *Loopback) HandleStream(s inet.Stream) {
pbw := ggio.NewDelimitedWriter(s)
if err := pbw.WriteMsg(outgoing); err != nil {
return // TODO logerr
s.Reset()
log.Debug(err)
return
}
}

View File

@ -60,7 +60,7 @@ func (px *standard) Bootstrap(ctx context.Context) error {
func (p *standard) HandleStream(s inet.Stream) {
// TODO(brian): Should clients be able to satisfy requests?
log.Error("supernode client received (dropped) a routing message from", s.Conn().RemotePeer())
s.Close()
s.Reset()
}
const replicationFactor = 2
@ -102,9 +102,15 @@ func (px *standard) sendMessage(ctx context.Context, m *dhtpb.Message, remote pe
if err != nil {
return err
}
defer s.Close()
pbw := ggio.NewDelimitedWriter(s)
return pbw.WriteMsg(m)
err = pbw.WriteMsg(m)
if err == nil {
s.Close()
} else {
s.Reset()
}
return err
}
// SendRequest sends the request to each remote sequentially (randomized order),
@ -139,17 +145,20 @@ func (px *standard) sendRequest(ctx context.Context, m *dhtpb.Message, remote pe
r := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
w := ggio.NewDelimitedWriter(s)
if err = w.WriteMsg(m); err != nil {
s.Reset()
e.SetError(err)
return nil, err
}
response := &dhtpb.Message{}
if err = r.ReadMsg(response); err != nil {
s.Reset()
e.SetError(err)
return nil, err
}
// need ctx expiration?
if response == nil {
s.Reset()
err := errors.New("no response to request")
e.SetError(err)
return nil, err