diff --git a/p2p/net/interface.go b/p2p/net/interface.go index 74add2bb6..577c496cb 100644 --- a/p2p/net/interface.go +++ b/p2p/net/interface.go @@ -111,6 +111,10 @@ type Dialer interface { // ConnsToPeer returns the connections in this Netowrk for given peer. ConnsToPeer(p peer.ID) []Conn + + // Notify/StopNotify register and unregister a notifiee for signals + Notify(Notifiee) + StopNotify(Notifiee) } // Connectedness signals the capacity for a connection with a given node. @@ -131,3 +135,16 @@ const ( // (should signal "made effort, failed") CannotConnect ) + +// Notifiee is an interface for an object wishing to receive +// notifications from a Network. +type Notifiee interface { + Connected(Network, Conn) // called when a connection opened + Disconnected(Network, Conn) // called when a connection closed + OpenedStream(Network, Stream) // called when a stream opened + ClosedStream(Network, Stream) // called when a stream closed + + // TODO + // PeerConnected(Network, peer.ID) // called when a peer connected + // PeerDisconnected(Network, peer.ID) // called when a peer disconnected +} diff --git a/p2p/net/mock/mock_conn.go b/p2p/net/mock/mock_conn.go index 7e58eaae5..23a4ead62 100644 --- a/p2p/net/mock/mock_conn.go +++ b/p2p/net/mock/mock_conn.go @@ -37,6 +37,9 @@ func (c *conn) Close() error { s.Close() } c.net.removeConn(c) + c.net.notifyAll(func(n inet.Notifiee) { + n.Disconnected(c.net, c) + }) return nil } @@ -73,11 +76,17 @@ func (c *conn) allStreams() []inet.Stream { func (c *conn) remoteOpenedStream(s *stream) { c.addStream(s) c.net.handleNewStream(s) + c.net.notifyAll(func(n inet.Notifiee) { + n.OpenedStream(c.net, s) + }) } func (c *conn) openStream() *stream { sl, sr := c.link.newStreamPair() c.addStream(sl) + c.net.notifyAll(func(n inet.Notifiee) { + n.OpenedStream(c.net, sl) + }) c.rconn.remoteOpenedStream(sr) return sl } diff --git a/p2p/net/mock/mock_notif_test.go b/p2p/net/mock/mock_notif_test.go new file mode 100644 index 000000000..1b80ecb96 --- /dev/null +++ b/p2p/net/mock/mock_notif_test.go @@ -0,0 +1,198 @@ +package mocknet + +import ( + "testing" + "time" + + inet "github.com/jbenet/go-ipfs/p2p/net" + + context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" +) + +func TestNotifications(t *testing.T) { + t.Parallel() + + mn, err := FullMeshLinked(context.Background(), 5) + if err != nil { + t.Fatal(err) + } + + timeout := 5 * time.Second + + // signup notifs + nets := mn.Nets() + notifiees := make([]*netNotifiee, len(nets)) + for i, pn := range nets { + n := newNetNotifiee() + pn.Notify(n) + notifiees[i] = n + } + + // connect all + for _, n1 := range nets { + for _, n2 := range nets { + if n1 == n2 { + continue + } + if _, err := mn.ConnectNets(n1, n2); err != nil { + t.Fatal(err) + } + } + } + + // test everyone got the correct connection opened calls + for i, s := range nets { + n := notifiees[i] + for _, s2 := range nets { + cos := s.ConnsToPeer(s2.LocalPeer()) + func() { + for i := 0; i < len(cos); i++ { + var c inet.Conn + select { + case c = <-n.connected: + case <-time.After(timeout): + t.Fatal("timeout") + } + for _, c2 := range cos { + if c == c2 { + t.Log("got notif for conn") + return + } + } + t.Error("connection not found") + } + }() + } + } + + complement := func(c inet.Conn) (inet.Network, *netNotifiee, *conn) { + for i, s := range nets { + for _, c2 := range s.Conns() { + if c2.(*conn).rconn == c { + return s, notifiees[i], c2.(*conn) + } + } + } + t.Fatal("complementary conn not found", c) + return nil, nil, nil + } + + testOCStream := func(n *netNotifiee, s inet.Stream) { + var s2 inet.Stream + select { + case s2 = <-n.openedStream: + t.Log("got notif for opened stream") + case <-time.After(timeout): + t.Fatal("timeout") + } + if s != nil && s != s2 { + t.Fatalf("got incorrect stream %p %p", s, s2) + } + + select { + case s2 = <-n.closedStream: + t.Log("got notif for closed stream") + case <-time.After(timeout): + t.Fatal("timeout") + } + if s != nil && s != s2 { + t.Fatalf("got incorrect stream %p %p", s, s2) + } + } + + streams := make(chan inet.Stream) + for _, s := range nets { + s.SetStreamHandler(func(s inet.Stream) { + streams <- s + s.Close() + }) + } + + // there's one stream per conn that we need to drain.... + // unsure where these are coming from + for i, _ := range nets { + n := notifiees[i] + testOCStream(n, nil) + testOCStream(n, nil) + testOCStream(n, nil) + testOCStream(n, nil) + } + + // open a streams in each conn + for i, s := range nets { + conns := s.Conns() + for _, c := range conns { + _, n2, c2 := complement(c) + st1, err := c.NewStream() + if err != nil { + t.Error(err) + } else { + t.Logf("%s %s <--%p--> %s %s", c.LocalPeer(), c.LocalMultiaddr(), st1, c.RemotePeer(), c.RemoteMultiaddr()) + // st1.Write([]byte("hello")) + st1.Close() + st2 := <-streams + t.Logf("%s %s <--%p--> %s %s", c2.LocalPeer(), c2.LocalMultiaddr(), st2, c2.RemotePeer(), c2.RemoteMultiaddr()) + testOCStream(notifiees[i], st1) + testOCStream(n2, st2) + } + } + } + + // close conns + for i, s := range nets { + n := notifiees[i] + for _, c := range s.Conns() { + _, n2, c2 := complement(c) + c.(*conn).Close() + c2.Close() + + var c3, c4 inet.Conn + select { + case c3 = <-n.disconnected: + case <-time.After(timeout): + t.Fatal("timeout") + } + if c != c3 { + t.Fatal("got incorrect conn", c, c3) + } + + select { + case c4 = <-n2.disconnected: + case <-time.After(timeout): + t.Fatal("timeout") + } + if c2 != c4 { + t.Fatal("got incorrect conn", c, c2) + } + } + } +} + +type netNotifiee struct { + connected chan inet.Conn + disconnected chan inet.Conn + openedStream chan inet.Stream + closedStream chan inet.Stream +} + +func newNetNotifiee() *netNotifiee { + return &netNotifiee{ + connected: make(chan inet.Conn), + disconnected: make(chan inet.Conn), + openedStream: make(chan inet.Stream), + closedStream: make(chan inet.Stream), + } +} + +func (nn *netNotifiee) Connected(n inet.Network, v inet.Conn) { + nn.connected <- v +} +func (nn *netNotifiee) Disconnected(n inet.Network, v inet.Conn) { + nn.disconnected <- v +} +func (nn *netNotifiee) OpenedStream(n inet.Network, v inet.Stream) { + nn.openedStream <- v +} +func (nn *netNotifiee) ClosedStream(n inet.Network, v inet.Stream) { + nn.closedStream <- v +} diff --git a/p2p/net/mock/mock_peernet.go b/p2p/net/mock/mock_peernet.go index 633a32762..b6bbbaa61 100644 --- a/p2p/net/mock/mock_peernet.go +++ b/p2p/net/mock/mock_peernet.go @@ -31,6 +31,9 @@ type peernet struct { streamHandler inet.StreamHandler connHandler inet.ConnHandler + notifmu sync.RWMutex + notifs map[inet.Notifiee]struct{} + cg ctxgroup.ContextGroup sync.RWMutex } @@ -58,6 +61,8 @@ func newPeernet(ctx context.Context, m *mocknet, k ic.PrivKey, connsByPeer: map[peer.ID]map[*conn]struct{}{}, connsByLink: map[*link]map[*conn]struct{}{}, + + notifs: make(map[inet.Notifiee]struct{}), } n.cg.SetTeardown(n.teardown) @@ -163,6 +168,9 @@ func (pn *peernet) openConn(r peer.ID, l *link) *conn { lc, rc := l.newConnPair(pn) log.Debugf("%s opening connection to %s", pn.LocalPeer(), lc.RemotePeer()) pn.addConn(lc) + pn.notifyAll(func(n inet.Notifiee) { + n.Connected(pn, lc) + }) rc.net.remoteOpenedConn(rc) return lc } @@ -171,6 +179,9 @@ func (pn *peernet) remoteOpenedConn(c *conn) { log.Debugf("%s accepting connection from %s", pn.LocalPeer(), c.RemotePeer()) pn.addConn(c) pn.handleNewConn(c) + pn.notifyAll(func(n inet.Notifiee) { + n.Connected(pn, c) + }) } // addConn constructs and adds a connection @@ -201,13 +212,13 @@ func (pn *peernet) removeConn(c *conn) { cs, found := pn.connsByLink[c.link] if !found || len(cs) < 1 { - panic("attempting to remove a conn that doesnt exist") + panic(fmt.Sprintf("attempting to remove a conn that doesnt exist %p", c.link)) } delete(cs, c) cs, found = pn.connsByPeer[c.remote] if !found { - panic("attempting to remove a conn that doesnt exist") + panic(fmt.Sprintf("attempting to remove a conn that doesnt exist %p", c.remote)) } delete(cs, c) } @@ -360,3 +371,28 @@ func (pn *peernet) SetConnHandler(h inet.ConnHandler) { pn.connHandler = h pn.Unlock() } + +// Notify signs up Notifiee to receive signals when events happen +func (pn *peernet) Notify(f inet.Notifiee) { + pn.notifmu.Lock() + pn.notifs[f] = struct{}{} + pn.notifmu.Unlock() +} + +// StopNotify unregisters Notifiee fromr receiving signals +func (pn *peernet) StopNotify(f inet.Notifiee) { + pn.notifmu.Lock() + delete(pn.notifs, f) + pn.notifmu.Unlock() +} + +// notifyAll runs the notification function on all Notifiees +func (pn *peernet) notifyAll(notification func(f inet.Notifiee)) { + pn.notifmu.RLock() + for n := range pn.notifs { + // make sure we dont block + // and they dont block each other. + go notification(n) + } + pn.notifmu.RUnlock() +} diff --git a/p2p/net/mock/mock_stream.go b/p2p/net/mock/mock_stream.go index 71a0ba66d..e116bb117 100644 --- a/p2p/net/mock/mock_stream.go +++ b/p2p/net/mock/mock_stream.go @@ -19,8 +19,11 @@ func (s *stream) Close() error { r.Close() } if w, ok := (s.Writer).(io.Closer); ok { - return w.Close() + w.Close() } + s.conn.net.notifyAll(func(n inet.Notifiee) { + n.ClosedStream(s.conn.net, s) + }) return nil } diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index f167dec4c..56e88445c 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -4,6 +4,7 @@ package swarm import ( "fmt" + "sync" "time" inet "github.com/jbenet/go-ipfs/p2p/net" @@ -38,6 +39,9 @@ type Swarm struct { backf dialbackoff dialT time.Duration // mainly for tests + notifmu sync.RWMutex + notifs map[inet.Notifiee]ps.Notifiee + cg ctxgroup.ContextGroup } @@ -54,11 +58,12 @@ func NewSwarm(ctx context.Context, listenAddrs []ma.Multiaddr, } s := &Swarm{ - swarm: ps.NewSwarm(PSTransport), - local: local, - peers: peers, - cg: ctxgroup.WithContext(ctx), - dialT: DialTimeout, + swarm: ps.NewSwarm(PSTransport), + local: local, + peers: peers, + cg: ctxgroup.WithContext(ctx), + dialT: DialTimeout, + notifs: make(map[inet.Notifiee]ps.Notifiee), } // configure Swarm @@ -177,3 +182,51 @@ func (s *Swarm) Peers() []peer.ID { func (s *Swarm) LocalPeer() peer.ID { return s.local } + +// Notify signs up Notifiee to receive signals when events happen +func (s *Swarm) Notify(f inet.Notifiee) { + // wrap with our notifiee, to translate function calls + n := &ps2netNotifee{net: (*Network)(s), not: f} + + s.notifmu.Lock() + s.notifs[f] = n + s.notifmu.Unlock() + + // register for notifications in the peer swarm. + s.swarm.Notify(n) +} + +// StopNotify unregisters Notifiee fromr receiving signals +func (s *Swarm) StopNotify(f inet.Notifiee) { + s.notifmu.Lock() + n, found := s.notifs[f] + if found { + delete(s.notifs, f) + } + s.notifmu.Unlock() + + if found { + s.swarm.StopNotify(n) + } +} + +type ps2netNotifee struct { + net *Network + not inet.Notifiee +} + +func (n *ps2netNotifee) Connected(c *ps.Conn) { + n.not.Connected(n.net, inet.Conn((*Conn)(c))) +} + +func (n *ps2netNotifee) Disconnected(c *ps.Conn) { + n.not.Disconnected(n.net, inet.Conn((*Conn)(c))) +} + +func (n *ps2netNotifee) OpenedStream(s *ps.Stream) { + n.not.OpenedStream(n.net, inet.Stream((*Stream)(s))) +} + +func (n *ps2netNotifee) ClosedStream(s *ps.Stream) { + n.not.ClosedStream(n.net, inet.Stream((*Stream)(s))) +} diff --git a/p2p/net/swarm/swarm_net.go b/p2p/net/swarm/swarm_net.go index 5df744747..561ceef82 100644 --- a/p2p/net/swarm/swarm_net.go +++ b/p2p/net/swarm/swarm_net.go @@ -154,3 +154,13 @@ func (n *Network) SetConnHandler(h inet.ConnHandler) { func (n *Network) String() string { return fmt.Sprintf("", n.LocalPeer()) } + +// Notify signs up Notifiee to receive signals when events happen +func (n *Network) Notify(f inet.Notifiee) { + n.Swarm().Notify(f) +} + +// StopNotify unregisters Notifiee fromr receiving signals +func (n *Network) StopNotify(f inet.Notifiee) { + n.Swarm().StopNotify(f) +} diff --git a/p2p/net/swarm/swarm_notif_test.go b/p2p/net/swarm/swarm_notif_test.go new file mode 100644 index 000000000..3469a1b32 --- /dev/null +++ b/p2p/net/swarm/swarm_notif_test.go @@ -0,0 +1,186 @@ +package swarm + +import ( + "testing" + "time" + + inet "github.com/jbenet/go-ipfs/p2p/net" + + context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" +) + +func TestNotifications(t *testing.T) { + t.Parallel() + + ctx := context.Background() + swarms := makeSwarms(ctx, t, 5) + defer func() { + for _, s := range swarms { + s.Close() + } + }() + + timeout := 5 * time.Second + + // signup notifs + notifiees := make([]*netNotifiee, len(swarms)) + for i, swarm := range swarms { + n := newNetNotifiee() + swarm.Notify(n) + notifiees[i] = n + } + + connectSwarms(t, ctx, swarms) + + <-time.After(time.Millisecond) + // should've gotten 5 by now. + + // test everyone got the correct connection opened calls + for i, s := range swarms { + n := notifiees[i] + for _, s2 := range swarms { + if s == s2 { + continue + } + + cos := s.ConnectionsToPeer(s2.LocalPeer()) + func() { + for i := 0; i < len(cos); i++ { + var c inet.Conn + select { + case c = <-n.connected: + case <-time.After(timeout): + t.Fatal("timeout") + } + for _, c2 := range cos { + if c == c2 { + t.Log("got notif for conn", c) + return + } + } + t.Error("connection not found", c) + } + }() + } + } + + complement := func(c inet.Conn) (*Swarm, *netNotifiee, *Conn) { + for i, s := range swarms { + for _, c2 := range s.Connections() { + if c.LocalMultiaddr().Equal(c2.RemoteMultiaddr()) && + c2.LocalMultiaddr().Equal(c.RemoteMultiaddr()) { + return s, notifiees[i], c2 + } + } + } + t.Fatal("complementary conn not found", c) + return nil, nil, nil + } + + testOCStream := func(n *netNotifiee, s inet.Stream) { + var s2 inet.Stream + select { + case s2 = <-n.openedStream: + t.Log("got notif for opened stream") + case <-time.After(timeout): + t.Fatal("timeout") + } + if s != s2 { + t.Fatal("got incorrect stream", s.Conn(), s2.Conn()) + } + + select { + case s2 = <-n.closedStream: + t.Log("got notif for closed stream") + case <-time.After(timeout): + t.Fatal("timeout") + } + if s != s2 { + t.Fatal("got incorrect stream", s.Conn(), s2.Conn()) + } + } + + streams := make(chan inet.Stream) + for _, s := range swarms { + s.SetStreamHandler(func(s inet.Stream) { + streams <- s + s.Close() + }) + } + + // open a streams in each conn + for i, s := range swarms { + for _, c := range s.Connections() { + _, n2, _ := complement(c) + + st1, err := c.NewStream() + if err != nil { + t.Error(err) + } else { + st1.Write([]byte("hello")) + st1.Close() + testOCStream(notifiees[i], st1) + st2 := <-streams + testOCStream(n2, st2) + } + } + } + + // close conns + for i, s := range swarms { + n := notifiees[i] + for _, c := range s.Connections() { + _, n2, c2 := complement(c) + c.Close() + c2.Close() + + var c3, c4 inet.Conn + select { + case c3 = <-n.disconnected: + case <-time.After(timeout): + t.Fatal("timeout") + } + if c != c3 { + t.Fatal("got incorrect conn", c, c3) + } + + select { + case c4 = <-n2.disconnected: + case <-time.After(timeout): + t.Fatal("timeout") + } + if c2 != c4 { + t.Fatal("got incorrect conn", c, c2) + } + } + } +} + +type netNotifiee struct { + connected chan inet.Conn + disconnected chan inet.Conn + openedStream chan inet.Stream + closedStream chan inet.Stream +} + +func newNetNotifiee() *netNotifiee { + return &netNotifiee{ + connected: make(chan inet.Conn), + disconnected: make(chan inet.Conn), + openedStream: make(chan inet.Stream), + closedStream: make(chan inet.Stream), + } +} + +func (nn *netNotifiee) Connected(n inet.Network, v inet.Conn) { + nn.connected <- v +} +func (nn *netNotifiee) Disconnected(n inet.Network, v inet.Conn) { + nn.disconnected <- v +} +func (nn *netNotifiee) OpenedStream(n inet.Network, v inet.Stream) { + nn.openedStream <- v +} +func (nn *netNotifiee) ClosedStream(n inet.Network, v inet.Stream) { + nn.closedStream <- v +}