From da734fa4c8cc6156534972cf269ffb4beb8435ee Mon Sep 17 00:00:00 2001 From: Juan Batiz-Benet Date: Sun, 21 Dec 2014 01:43:58 -0800 Subject: [PATCH] swarm: learned to use a ConnHandler --- net/swarm/swarm.go | 24 ++++++++++- net/swarm/swarm_conn.go | 4 ++ net/swarm/swarm_listen.go | 33 ++++++++-------- net/swarm/swarm_test.go | 83 +++++++++++++++++++++++++++------------ 4 files changed, 101 insertions(+), 43 deletions(-) diff --git a/net/swarm/swarm.go b/net/swarm/swarm.go index baffa5423..a8a0fc7b2 100644 --- a/net/swarm/swarm.go +++ b/net/swarm/swarm.go @@ -24,6 +24,7 @@ type Swarm struct { swarm *ps.Swarm local peer.ID peers peer.Peerstore + connh ConnHandler cg ctxgroup.ContextGroup } @@ -41,7 +42,7 @@ func NewSwarm(ctx context.Context, listenAddrs []ma.Multiaddr, // configure Swarm s.cg.SetTeardown(s.teardown) - s.swarm.SetConnHandler(s.connHandler) + s.SetConnHandler(nil) // make sure to setup our own conn handler. return s, s.listen(listenAddrs) } @@ -65,6 +66,27 @@ func (s *Swarm) StreamSwarm() *ps.Swarm { return s.swarm } +// SetConnHandler assigns the handler for new connections. +// See peerstream. You will rarely use this. See SetStreamHandler +func (s *Swarm) SetConnHandler(handler ConnHandler) { + + // handler is nil if user wants to clear the old handler. + if handler == nil { + s.swarm.SetConnHandler(func(psconn *ps.Conn) { + s.connHandler(psconn) + }) + return + } + + s.swarm.SetConnHandler(func(psconn *ps.Conn) { + // sc is nil if closed in our handler. + if sc := s.connHandler(psconn); sc != nil { + // call the user's handler. in a goroutine for sync safety. + go handler(sc) + } + }) +} + // SetStreamHandler assigns the handler for new streams. // See peerstream. func (s *Swarm) SetStreamHandler(handler StreamHandler) { diff --git a/net/swarm/swarm_conn.go b/net/swarm/swarm_conn.go index 0525d73f3..95fe21273 100644 --- a/net/swarm/swarm_conn.go +++ b/net/swarm/swarm_conn.go @@ -24,6 +24,10 @@ import ( // layers do build up pieces of functionality. and they're all just io.RW :) ) type Conn ps.Conn +// ConnHandler is called when new conns are opened from remote peers. +// See peerstream.ConnHandler +type ConnHandler func(*Conn) + func (c *Conn) StreamConn() *ps.Conn { return (*ps.Conn)(c) } diff --git a/net/swarm/swarm_listen.go b/net/swarm/swarm_listen.go index c984a9276..bcc55cad6 100644 --- a/net/swarm/swarm_listen.go +++ b/net/swarm/swarm_listen.go @@ -65,21 +65,22 @@ func (s *Swarm) setupListener(maddr ma.Multiaddr) error { // here we configure it slightly. Note that this is sequential, so if anything // will take a while do it in a goroutine. // See https://godoc.org/github.com/jbenet/go-peerstream for more information -func (s *Swarm) connHandler(c *ps.Conn) { - go func() { - ctx := context.Background() - // this context is for running the handshake, which -- when receiveing connections - // -- we have no bound on beyond what the transport protocol bounds it at. - // note that setup + the handshake are bounded by underlying io. - // (i.e. if TCP or UDP disconnects (or the swarm closes), we're done. - // Q: why not have a shorter handshake? think about an HTTP server on really slow conns. - // as long as the conn is live (TCP says its online), it tries its best. we follow suit.) +func (s *Swarm) connHandler(c *ps.Conn) *Conn { + ctx := context.Background() + // this context is for running the handshake, which -- when receiveing connections + // -- we have no bound on beyond what the transport protocol bounds it at. + // note that setup + the handshake are bounded by underlying io. + // (i.e. if TCP or UDP disconnects (or the swarm closes), we're done. + // Q: why not have a shorter handshake? think about an HTTP server on really slow conns. + // as long as the conn is live (TCP says its online), it tries its best. we follow suit.) - if _, err := s.newConnSetup(ctx, c); err != nil { - log.Error(err) - log.Event(ctx, "newConnHandlerDisconnect", lgbl.NetConn(c.NetConn()), lgbl.Error(err)) - c.Close() // boom. close it. - return - } - }() + sc, err := s.newConnSetup(ctx, c) + if err != nil { + log.Error(err) + log.Event(ctx, "newConnHandlerDisconnect", lgbl.NetConn(c.NetConn()), lgbl.Error(err)) + c.Close() // boom. close it. + return nil + } + + return sc } diff --git a/net/swarm/swarm_test.go b/net/swarm/swarm_test.go index f13b80c23..c0a1ab9fa 100644 --- a/net/swarm/swarm_test.go +++ b/net/swarm/swarm_test.go @@ -72,6 +72,34 @@ func makeSwarms(ctx context.Context, t *testing.T, num int) ([]*Swarm, []testuti return swarms, peersnp } +func connectSwarms(t *testing.T, ctx context.Context, swarms []*Swarm, peersnp []testutil.PeerNetParams) { + + var wg sync.WaitGroup + connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { + // TODO: make a DialAddr func. + s.peers.AddAddress(dst, addr) + if _, err := s.Dial(ctx, dst); err != nil { + t.Fatal("error swarm dialing to peer", err) + } + wg.Done() + } + + log.Info("Connecting swarms simultaneously.") + for _, s := range swarms { + for _, p := range peersnp { + if p.ID != s.local { // don't connect to self. + wg.Add(1) + connect(s, p.ID, p.Addr) + } + } + } + wg.Wait() + + for _, s := range swarms { + log.Infof("%s swarm routing table: %s", s.local, s.Peers()) + } +} + func SubtestSwarm(t *testing.T, SwarmNum int, MsgNum int) { // t.Skip("skipping for another test") @@ -79,32 +107,7 @@ func SubtestSwarm(t *testing.T, SwarmNum int, MsgNum int) { swarms, peersnp := makeSwarms(ctx, t, SwarmNum) // connect everyone - { - var wg sync.WaitGroup - connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { - // TODO: make a DialAddr func. - s.peers.AddAddress(dst, addr) - if _, err := s.Dial(ctx, dst); err != nil { - t.Fatal("error swarm dialing to peer", err) - } - wg.Done() - } - - log.Info("Connecting swarms simultaneously.") - for _, s := range swarms { - for _, p := range peersnp { - if p.ID != s.local { // don't connect to self. - wg.Add(1) - connect(s, p.ID, p.Addr) - } - } - } - wg.Wait() - - for _, s := range swarms { - log.Infof("%s swarm routing table: %s", s.local, s.Peers()) - } - } + connectSwarms(t, ctx, swarms, peersnp) // ping/pong for _, s1 := range swarms { @@ -229,3 +232,31 @@ func TestSwarm(t *testing.T) { swarms := 5 SubtestSwarm(t, swarms, msgs) } + +func TestConnHandler(t *testing.T) { + // t.Skip("skipping for another test") + + ctx := context.Background() + swarms, peersnp := makeSwarms(ctx, t, 5) + + gotconn := make(chan struct{}, 10) + swarms[0].SetConnHandler(func(conn *Conn) { + gotconn <- struct{}{} + }) + + connectSwarms(t, ctx, swarms, peersnp) + + <-time.After(time.Millisecond) + // should've gotten 5 by now. + close(gotconn) + + expect := 4 + actual := 0 + for _ = range gotconn { + actual++ + } + + if actual != expect { + t.Fatal("should have connected to %d swarms. got: %d", actual, expect) + } +}