diff --git a/routing/dht/providers.go b/routing/dht/providers.go index 2b7fa2cbd..74c79b8e9 100644 --- a/routing/dht/providers.go +++ b/routing/dht/providers.go @@ -10,22 +10,25 @@ import ( context "github.com/ipfs/go-ipfs/Godeps/_workspace/src/golang.org/x/net/context" ) -type providerInfo struct { - Creation time.Time - Value peer.ID -} - type ProviderManager struct { - providers map[key.Key][]*providerInfo + // all non channel fields are meant to be accessed only within + // the run method + providers map[key.Key]*providerSet local map[key.Key]struct{} lpeer peer.ID - getlocal chan chan []key.Key - newprovs chan *addProv - getprovs chan *getProv - period time.Duration + + getlocal chan chan []key.Key + newprovs chan *addProv + getprovs chan *getProv + period time.Duration ctxgroup.ContextGroup } +type providerSet struct { + providers []peer.ID + set map[peer.ID]time.Time +} + type addProv struct { k key.Key val peer.ID @@ -40,7 +43,7 @@ func NewProviderManager(ctx context.Context, local peer.ID) *ProviderManager { pm := new(ProviderManager) pm.getprovs = make(chan *getProv) pm.newprovs = make(chan *addProv) - pm.providers = make(map[key.Key][]*providerInfo) + pm.providers = make(map[key.Key]*providerSet) pm.getlocal = make(chan chan []key.Key) pm.local = make(map[key.Key]struct{}) pm.ContextGroup = ctxgroup.WithContext(ctx) @@ -61,18 +64,20 @@ func (pm *ProviderManager) run() { if np.val == pm.lpeer { pm.local[np.k] = struct{}{} } - pi := new(providerInfo) - pi.Creation = time.Now() - pi.Value = np.val - arr := pm.providers[np.k] - pm.providers[np.k] = append(arr, pi) + provs, ok := pm.providers[np.k] + if !ok { + provs = newProviderSet() + pm.providers[np.k] = provs + } + provs.Add(np.val) case gp := <-pm.getprovs: var parr []peer.ID - provs := pm.providers[gp.k] - for _, p := range provs { - parr = append(parr, p.Value) + provs, ok := pm.providers[gp.k] + if ok { + parr = provs.providers } + gp.resp <- parr case lc := <-pm.getlocal: @@ -83,14 +88,16 @@ func (pm *ProviderManager) run() { lc <- keys case <-tick.C: - for k, provs := range pm.providers { - var filtered []*providerInfo - for _, p := range provs { - if time.Now().Sub(p.Creation) < time.Hour*24 { + for _, provs := range pm.providers { + var filtered []peer.ID + for p, t := range provs.set { + if time.Now().Sub(t) > time.Hour*24 { + delete(provs.set, p) + } else { filtered = append(filtered, p) } } - pm.providers[k] = filtered + provs.providers = filtered } case <-pm.Closing(): @@ -133,3 +140,18 @@ func (pm *ProviderManager) GetLocal() []key.Key { pm.getlocal <- resp return <-resp } + +func newProviderSet() *providerSet { + return &providerSet{ + set: make(map[peer.ID]time.Time), + } +} + +func (ps *providerSet) Add(p peer.ID) { + _, found := ps.set[p] + if !found { + ps.providers = append(ps.providers, p) + } + + ps.set[p] = time.Now() +}