diff --git a/exchange/bitswap/session.go b/exchange/bitswap/session.go index 53db1a28a..3128cb0a0 100644 --- a/exchange/bitswap/session.go +++ b/exchange/bitswap/session.go @@ -78,13 +78,28 @@ func (bs *Bitswap) NewSession(ctx context.Context) *Session { return s } +func (bs *Bitswap) removeSession(s *Session) { + bs.sessLk.Lock() + defer bs.sessLk.Unlock() + for i := 0; i < len(bs.sessions); i++ { + if bs.sessions[i] == s { + bs.sessions[i] = bs.sessions[len(bs.sessions)-1] + bs.sessions = bs.sessions[:len(bs.sessions)-1] + return + } + } +} + type blkRecv struct { from peer.ID blk blocks.Block } func (s *Session) receiveBlockFrom(from peer.ID, blk blocks.Block) { - s.incoming <- blkRecv{from: from, blk: blk} + select { + case s.incoming <- blkRecv{from: from, blk: blk}: + case <-s.ctx.Done(): + } } type interestReq struct { @@ -105,7 +120,13 @@ func (s *Session) isLiveWant(c *cid.Cid) bool { c: c, resp: resp, } - return <-resp + + select { + case want := <-resp: + return want + case <-s.ctx.Done(): + return false + } } func (s *Session) interestedIn(c *cid.Cid) bool { @@ -194,6 +215,7 @@ func (s *Session) run(ctx context.Context) { lwchk.resp <- s.cidIsWanted(lwchk.c) case <-ctx.Done(): s.tick.Stop() + s.bs.removeSession(s) return } } diff --git a/exchange/bitswap/session_test.go b/exchange/bitswap/session_test.go index dfdae79cb..6d981eb4b 100644 --- a/exchange/bitswap/session_test.go +++ b/exchange/bitswap/session_test.go @@ -242,3 +242,46 @@ func TestPutAfterSessionCacheEvict(t *testing.T) { t.Fatal("timed out waiting for block") } } + +func TestMultipleSessions(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + vnet := getVirtualNetwork() + sesgen := NewTestSessionGenerator(vnet) + defer sesgen.Close() + bgen := blocksutil.NewBlockGenerator() + + blk := bgen.Blocks(1)[0] + inst := sesgen.Instances(2) + + a := inst[0] + b := inst[1] + + ctx1, cancel1 := context.WithCancel(ctx) + ses := a.Exchange.NewSession(ctx1) + + blkch, err := ses.GetBlocks(ctx, []*cid.Cid{blk.Cid()}) + if err != nil { + t.Fatal(err) + } + cancel1() + + ses2 := a.Exchange.NewSession(ctx) + blkch2, err := ses2.GetBlocks(ctx, []*cid.Cid{blk.Cid()}) + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 10) + if err := b.Exchange.HasBlock(blk); err != nil { + t.Fatal(err) + } + + select { + case <-blkch2: + case <-time.After(time.Second * 20): + t.Fatal("bad juju") + } + _ = blkch +}