diff --git a/exchange/bitswap/session.go b/exchange/bitswap/session.go index 049be4e9e..bc824dbee 100644 --- a/exchange/bitswap/session.go +++ b/exchange/bitswap/session.go @@ -84,6 +84,14 @@ func (bs *Bitswap) NewSession(ctx context.Context) *Session { func (bs *Bitswap) removeSession(s *Session) { s.notif.Shutdown() + + live := make([]*cid.Cid, 0, len(s.liveWants)) + for c := range s.liveWants { + cs, _ := cid.Cast([]byte(c)) + live = append(live, cs) + } + bs.CancelWants(live, s.id) + bs.sessLk.Lock() defer bs.sessLk.Unlock() for i := 0; i < len(bs.sessions); i++ { diff --git a/exchange/bitswap/session_test.go b/exchange/bitswap/session_test.go index 645890454..2fe4672b0 100644 --- a/exchange/bitswap/session_test.go +++ b/exchange/bitswap/session_test.go @@ -285,3 +285,36 @@ func TestMultipleSessions(t *testing.T) { } _ = blkch } + +func TestWantlistClearsOnCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + vnet := getVirtualNetwork() + sesgen := NewTestSessionGenerator(vnet) + defer sesgen.Close() + bgen := blocksutil.NewBlockGenerator() + + blks := bgen.Blocks(10) + var cids []*cid.Cid + for _, blk := range blks { + cids = append(cids, blk.Cid()) + } + + inst := sesgen.Instances(1) + + a := inst[0] + + ctx1, cancel1 := context.WithCancel(ctx) + ses := a.Exchange.NewSession(ctx1) + + _, err := ses.GetBlocks(ctx, cids) + if err != nil { + t.Fatal(err) + } + cancel1() + + if len(a.Exchange.GetWantlist()) > 0 { + t.Fatal("expected empty wantlist") + } +}