diff --git a/routing/dht/dht.go b/routing/dht/dht.go index 146de751b..b0a2a0481 100644 --- a/routing/dht/dht.go +++ b/routing/dht/dht.go @@ -100,10 +100,11 @@ func (dht *IpfsDHT) Connect(addr *ma.Multiaddr) (*peer.Peer, error) { func (dht *IpfsDHT) handleMessages() { u.DOut("Begin message handling routine\n") - ch := dht.network.GetChan() + errs := dht.network.GetErrChan() + dhtmes := dht.network.GetChannel(swarm.PBWrapper_DHT_MESSAGE) for { select { - case mes, ok := <-ch.Incoming: + case mes, ok := <-dhtmes: if !ok { u.DOut("handleMessages closing, bad recv on incoming\n") return @@ -147,7 +148,7 @@ func (dht *IpfsDHT) handleMessages() { u.PErr("Recieved invalid message type") } - case err := <-ch.Errors: + case err := <-errs: u.PErr("dht err: %s\n", err) case <-dht.shutdown: return diff --git a/routing/dht/dht_test.go b/routing/dht/dht_test.go index 6296d1029..92d0931fb 100644 --- a/routing/dht/dht_test.go +++ b/routing/dht/dht_test.go @@ -132,8 +132,8 @@ func TestValueGetSet(t *testing.T) { dhtA.Start() dhtB.Start() - errsa := dhtA.network.GetChan().Errors - errsb := dhtB.network.GetChan().Errors + errsa := dhtA.network.GetErrChan() + errsb := dhtB.network.GetErrChan() go func() { select { case err := <-errsa: diff --git a/routing/dht/ext_test.go b/routing/dht/ext_test.go index 8d1e74ba5..79cfd27bc 100644 --- a/routing/dht/ext_test.go +++ b/routing/dht/ext_test.go @@ -66,8 +66,12 @@ func (f *fauxNet) Send(mes *swarm.Message) { f.Chan.Outgoing <- mes } -func (f *fauxNet) GetChan() *swarm.Chan { - return f.Chan +func (f *fauxNet) GetErrChan() chan error { + return f.Chan.Errors +} + +func (f *fauxNet) GetChannel(t swarm.PBWrapper_MessageType) chan *swarm.Message { + return f.Chan.Incoming } func (f *fauxNet) Connect(addr *ma.Multiaddr) (*peer.Peer, error) { @@ -167,7 +171,6 @@ func _randPeer() *peer.Peer { } func TestNotFound(t *testing.T) { - u.Debug = true fn := newFauxNet() fn.Listen() @@ -225,3 +228,64 @@ func TestNotFound(t *testing.T) { } t.Fatal("Expected to recieve an error.") } + +// If less than K nodes are in the entire network, it should fail when we make +// a GET rpc and nobody has the value +func TestLessThanKResponses(t *testing.T) { + u.Debug = false + fn := newFauxNet() + fn.Listen() + + local := new(peer.Peer) + local.ID = peer.ID("test_peer") + + d := NewDHT(local, fn) + d.Start() + + var ps []*peer.Peer + for i := 0; i < 5; i++ { + ps = append(ps, _randPeer()) + d.Update(ps[i]) + } + other := _randPeer() + + // Reply with random peers to every message + fn.AddHandler(func(mes *swarm.Message) *swarm.Message { + t.Log("Handling message...") + pmes := new(PBDHTMessage) + err := proto.Unmarshal(mes.Data, pmes) + if err != nil { + t.Fatal(err) + } + + switch pmes.GetType() { + case PBDHTMessage_GET_VALUE: + resp := Message{ + Type: pmes.GetType(), + ID: pmes.GetId(), + Response: true, + Success: false, + Peers: []*peer.Peer{other}, + } + + return swarm.NewMessage(mes.Peer, resp.ToProtobuf()) + default: + panic("Shouldnt recieve this.") + } + + }) + + _, err := d.GetValue(u.Key("hello"), time.Second*30) + if err != nil { + switch err { + case u.ErrNotFound: + //Success! + return + case u.ErrTimeout: + t.Fatal("Should not have gotten timeout!") + default: + t.Fatalf("Got unexpected error: %s", err) + } + } + t.Fatal("Expected to recieve an error.") +} diff --git a/routing/dht/providers.go b/routing/dht/providers.go index 3dc3b7b05..fdf8d6581 100644 --- a/routing/dht/providers.go +++ b/routing/dht/providers.go @@ -3,24 +3,24 @@ package dht import ( "time" - u "github.com/jbenet/go-ipfs/util" peer "github.com/jbenet/go-ipfs/peer" + u "github.com/jbenet/go-ipfs/util" ) type ProviderManager struct { providers map[u.Key][]*providerInfo - newprovs chan *addProv - getprovs chan *getProv - halt chan struct{} + newprovs chan *addProv + getprovs chan *getProv + halt chan struct{} } type addProv struct { - k u.Key + k u.Key val *peer.Peer } type getProv struct { - k u.Key + k u.Key resp chan []*peer.Peer } @@ -55,7 +55,7 @@ func (pm *ProviderManager) run() { for k, provs := range pm.providers { var filtered []*providerInfo for _, p := range provs { - if time.Now().Sub(p.Creation) < time.Hour * 24 { + if time.Now().Sub(p.Creation) < time.Hour*24 { filtered = append(filtered, p) } } @@ -69,7 +69,7 @@ func (pm *ProviderManager) run() { func (pm *ProviderManager) AddProvider(k u.Key, val *peer.Peer) { pm.newprovs <- &addProv{ - k: k, + k: k, val: val, } } diff --git a/routing/dht/routing.go b/routing/dht/routing.go index e3d34325d..3a4ebd33d 100644 --- a/routing/dht/routing.go +++ b/routing/dht/routing.go @@ -164,7 +164,8 @@ func (dht *IpfsDHT) GetValue(key u.Key, timeout time.Duration) ([]byte, error) { case p := <-npeerChan: count++ if count >= KValue { - break + errChan <- u.ErrNotFound + return } c.Increment() @@ -172,40 +173,38 @@ func (dht *IpfsDHT) GetValue(key u.Key, timeout time.Duration) ([]byte, error) { default: if c.Size() == 0 { errChan <- u.ErrNotFound + return } } } }() process := func() { - for { - select { - case p, ok := <-procPeer: - if !ok || p == nil { - c.Decrement() - return - } - val, peers, err := dht.getValueOrPeers(p, key, timeout/4, routeLevel) - if err != nil { - u.DErr("%v\n", err.Error()) - c.Decrement() - continue - } - if val != nil { - valChan <- val - c.Decrement() - return - } - - for _, np := range peers { - // TODO: filter out peers that arent closer - if !pset.Contains(np) && pset.Size() < KValue { - pset.Add(np) //This is racey... make a single function to do operation - npeerChan <- np - } - } + for p := range procPeer { + if p == nil { c.Decrement() + return } + val, peers, err := dht.getValueOrPeers(p, key, timeout/4, routeLevel) + if err != nil { + u.DErr("%v\n", err.Error()) + c.Decrement() + continue + } + if val != nil { + valChan <- val + c.Decrement() + return + } + + for _, np := range peers { + // TODO: filter out peers that arent closer + if !pset.Contains(np) && pset.Size() < KValue { + pset.Add(np) //This is racey... make a single function to do operation + npeerChan <- np + } + } + c.Decrement() } } diff --git a/swarm/conn.go b/swarm/conn.go index 05ccb3057..072b53437 100644 --- a/swarm/conn.go +++ b/swarm/conn.go @@ -40,8 +40,6 @@ func Dial(network string, peer *peer.Peer) (*Conn, error) { return nil, err } - fmt.Printf("Making connection to: %s\n", host) - nconn, err := net.Dial(network, host) if err != nil { return nil, err diff --git a/swarm/interface.go b/swarm/interface.go index 9a70890e6..3bfcd233b 100644 --- a/swarm/interface.go +++ b/swarm/interface.go @@ -14,7 +14,8 @@ type Network interface { Listen() error ConnectNew(*ma.Multiaddr) (*peer.Peer, error) GetConnection(id peer.ID, addr *ma.Multiaddr) (*peer.Peer, error) - GetChan() *Chan + GetErrChan() chan error + GetChannel(PBWrapper_MessageType) chan *Message Close() Drop(*peer.Peer) error } diff --git a/swarm/swarm.go b/swarm/swarm.go index 8c96be132..926aa8910 100644 --- a/swarm/swarm.go +++ b/swarm/swarm.go @@ -84,6 +84,10 @@ type Swarm struct { conns ConnMap connsLock sync.RWMutex + filterChans map[PBWrapper_MessageType]chan *Message + toFilter chan *Message + newFilters chan *newFilterInfo + local *peer.Peer listeners []net.Listener } @@ -91,10 +95,14 @@ type Swarm struct { // NewSwarm constructs a Swarm, with a Chan. func NewSwarm(local *peer.Peer) *Swarm { s := &Swarm{ - Chan: NewChan(10), - conns: ConnMap{}, - local: local, + Chan: NewChan(10), + conns: ConnMap{}, + local: local, + filterChans: make(map[PBWrapper_MessageType]chan *Message), + toFilter: make(chan *Message, 32), + newFilters: make(chan *newFilterInfo), } + go s.routeMessages() go s.fanOut() return s } @@ -299,15 +307,8 @@ func (s *Swarm) fanIn(conn *Conn) { goto out } - wrapper, err := Unwrap(data) - if err != nil { - s.Error(err) - continue - } - - // wrap it for consumers. - msg := &Message{Peer: conn.Peer, Data: wrapper.GetMessage()} - s.Chan.Incoming <- msg + msg := &Message{Peer: conn.Peer, Data: data} + s.toFilter <- msg } } out: @@ -317,6 +318,39 @@ out: s.connsLock.Unlock() } +type newFilterInfo struct { + Type PBWrapper_MessageType + resp chan chan *Message +} + +func (s *Swarm) routeMessages() { + for { + select { + case mes, ok := <-s.toFilter: + if !ok { + return + } + wrapper, err := Unwrap(mes.Data) + if err != nil { + u.PErr("error in route messages: %s\n", err) + } + + ch, ok := s.filterChans[PBWrapper_MessageType(wrapper.GetType())] + if !ok { + u.PErr("Received message with invalid type: %d\n", wrapper.GetType()) + continue + } + + mes.Data = wrapper.GetMessage() + ch <- mes + case gchan := <-s.newFilters: + nch := make(chan *Message) + s.filterChans[gchan.Type] = nch + gchan.resp <- nch + } + } +} + func (s *Swarm) Find(key u.Key) *peer.Peer { s.connsLock.RLock() defer s.connsLock.RUnlock() @@ -414,8 +448,8 @@ func (s *Swarm) Error(e error) { s.Chan.Errors <- e } -func (s *Swarm) GetChan() *Chan { - return s.Chan +func (s *Swarm) GetErrChan() chan error { + return s.Chan.Errors } func Wrap(data []byte, typ PBWrapper_MessageType) ([]byte, error) { @@ -439,5 +473,15 @@ func Unwrap(data []byte) (*PBWrapper, error) { return mes, nil } +func (s *Swarm) GetChannel(typ PBWrapper_MessageType) chan *Message { + nfi := &newFilterInfo{ + Type: typ, + resp: make(chan chan *Message), + } + s.newFilters <- nfi + + return <-nfi.resp +} + // Temporary to ensure that the Swarm always matches the Network interface as we are changing it var _ Network = &Swarm{}