From 444f47d7f5a81f442d2d6f95b320bb0e5e4cb3cd Mon Sep 17 00:00:00 2001 From: Juan Batiz-Benet Date: Wed, 17 Dec 2014 06:58:55 -0800 Subject: [PATCH] mock2: link map fixes --- net/mock2/mock_link.go | 4 +- net/mock2/mock_net.go | 58 ++++++++----- net/mock2/mock_test.go | 188 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 225 insertions(+), 25 deletions(-) create mode 100644 net/mock2/mock_test.go diff --git a/net/mock2/mock_link.go b/net/mock2/mock_link.go index 731a32ead..bc81d427b 100644 --- a/net/mock2/mock_link.go +++ b/net/mock2/mock_link.go @@ -19,8 +19,8 @@ type link struct { sync.RWMutex } -func newLink(mn *mocknet) *link { - return &link{mock: mn, opts: mn.linkDefaults} +func newLink(mn *mocknet, opts LinkOptions) *link { + return &link{mock: mn, opts: opts} } func (l *link) newConnPair() (*conn, *conn) { diff --git a/net/mock2/mock_net.go b/net/mock2/mock_net.go index cbedf9327..a4696c5a6 100644 --- a/net/mock2/mock_net.go +++ b/net/mock2/mock_net.go @@ -164,18 +164,21 @@ func (mn *mocknet) validate(n inet.Network) (*peernet, error) { } func (mn *mocknet) LinkNets(n1, n2 inet.Network) (Link, error) { - mn.Lock() - defer mn.Unlock() + mn.RLock() + n1r, err1 := mn.validate(n1) + n2r, err2 := mn.validate(n1) + ld := mn.linkDefaults + mn.RUnlock() - if _, err := mn.validate(n1); err != nil { - return nil, err + if err1 != nil { + return nil, err1 + } + if err2 != nil { + return nil, err2 } - if _, err := mn.validate(n2); err != nil { - return nil, err - } - - l := newLink(mn) + l := newLink(mn, ld) + l.nets = append(l.nets, n1r, n2r) mn.addLink(l) return l, nil } @@ -209,13 +212,31 @@ func (mn *mocknet) UnlinkNets(n1, n2 inet.Network) error { return mn.UnlinkPeers(n1.LocalPeer(), n2.LocalPeer()) } +// get from the links map. and lazily contruct. +func (mn *mocknet) linksMapGet(p1, p2 peer.Peer) *map[*link]struct{} { + l1, found := mn.links[pid(p1)] + if !found { + mn.links[pid(p1)] = map[peerID]map[*link]struct{}{} + l1 = mn.links[pid(p1)] // so we make sure it's there. + } + + l2, found := l1[pid(p2)] + if !found { + m := map[*link]struct{}{} + l1[pid(p2)] = m + l2 = l1[pid(p2)] + } + + return &l2 +} + func (mn *mocknet) addLink(l *link) { mn.Lock() defer mn.Unlock() n1, n2 := l.nets[0], l.nets[1] - mn.links[pid(n1.peer)][pid(n2.peer)][l] = struct{}{} - mn.links[pid(n2.peer)][pid(n1.peer)][l] = struct{}{} + (*mn.linksMapGet(n1.peer, n2.peer))[l] = struct{}{} + (*mn.linksMapGet(n2.peer, n1.peer))[l] = struct{}{} } func (mn *mocknet) removeLink(l *link) { @@ -223,8 +244,8 @@ func (mn *mocknet) removeLink(l *link) { defer mn.Unlock() n1, n2 := l.nets[0], l.nets[1] - delete(mn.links[pid(n1.peer)][pid(n2.peer)], l) - delete(mn.links[pid(n2.peer)][pid(n1.peer)], l) + delete(*mn.linksMapGet(n1.peer, n2.peer), l) + delete(*mn.linksMapGet(n2.peer, n1.peer), l) } func (mn *mocknet) ConnectAll() error { @@ -263,16 +284,7 @@ func (mn *mocknet) LinksBetweenPeers(p1, p2 peer.Peer) []Link { mn.RLock() defer mn.RUnlock() - ls1, found := mn.links[pid(p1)] - if !found { - return nil - } - - ls2, found := ls1[pid(p2)] - if !found { - return nil - } - + ls2 := *mn.linksMapGet(p1, p2) cp := make([]Link, 0, len(ls2)) for l := range ls2 { cp = append(cp, l) diff --git a/net/mock2/mock_test.go b/net/mock2/mock_test.go new file mode 100644 index 000000000..b9cf83ecf --- /dev/null +++ b/net/mock2/mock_test.go @@ -0,0 +1,188 @@ +package mocknet + +import ( + "bytes" + "io" + "math/rand" + "sync" + "testing" + + inet "github.com/jbenet/go-ipfs/net" + + context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" +) + +// func TestNetworkSetup(t *testing.T) { + +// p1 := testutil.RandPeer() +// p2 := testutil.RandPeer() +// p3 := testutil.RandPeer() +// peers := []peer.Peer{p1, p2, p3} + +// nets, err := MakeNetworks(context.Background(), peers) +// if err != nil { +// t.Fatal(err) +// } + +// // check things + +// if len(nets) != 3 { +// t.Error("nets must be 3") +// } + +// for i, n := range nets { +// if n.local != peers[i] { +// t.Error("peer mismatch") +// } + +// if len(n.conns) != len(nets) { +// t.Error("conn mismatch") +// } + +// for _, c := range n.conns { +// if c.remote.conns[n.local] == nil { +// t.Error("conn other side fail") +// } +// if c.remote.conns[n.local].remote.local != n.local { +// t.Error("conn other side fail") +// } +// } + +// } + +// } + +func TestStreams(t *testing.T) { + + mn, err := FullMeshConnected(context.Background(), 3) + if err != nil { + t.Fatal(err) + } + + handler := func(s inet.Stream) { + go func() { + b := make([]byte, 4) + if _, err := io.ReadFull(s, b); err != nil { + panic(err) + } + if !bytes.Equal(b, []byte("beep")) { + panic("bytes mismatch") + } + if _, err := s.Write([]byte("boop")); err != nil { + panic(err) + } + s.Close() + }() + } + + nets := mn.Nets() + for _, n := range nets { + n.SetHandler(inet.ProtocolDHT, handler) + } + + s, err := nets[0].NewStream(inet.ProtocolDHT, nets[1].LocalPeer()) + if err != nil { + t.Fatal(err) + } + + if _, err := s.Write([]byte("beep")); err != nil { + panic(err) + } + b := make([]byte, 4) + if _, err := io.ReadFull(s, b); err != nil { + panic(err) + } + if !bytes.Equal(b, []byte("boop")) { + panic("bytes mismatch 2") + } + +} + +func makePinger(st string, n int) func(inet.Stream) { + return func(s inet.Stream) { + go func() { + defer s.Close() + + for i := 0; i < n; i++ { + b := make([]byte, 4+len(st)) + if _, err := s.Write([]byte("ping" + st)); err != nil { + panic(err) + } + if _, err := io.ReadFull(s, b); err != nil { + panic(err) + } + if !bytes.Equal(b, []byte("pong"+st)) { + panic("bytes mismatch") + } + } + }() + } +} + +func makePonger(st string) func(inet.Stream) { + return func(s inet.Stream) { + go func() { + defer s.Close() + + for { + b := make([]byte, 4+len(st)) + if _, err := io.ReadFull(s, b); err != nil { + if err == io.EOF { + return + } + panic(err) + } + if !bytes.Equal(b, []byte("ping"+st)) { + panic("bytes mismatch") + } + if _, err := s.Write([]byte("pong" + st)); err != nil { + panic(err) + } + } + }() + } +} + +func TestStreamsStress(t *testing.T) { + + mn, err := FullMeshConnected(context.Background(), 100) + if err != nil { + t.Fatal(err) + } + + protos := []inet.ProtocolID{ + inet.ProtocolDHT, + inet.ProtocolBitswap, + inet.ProtocolDiag, + } + + nets := mn.Nets() + for _, n := range nets { + for _, p := range protos { + n.SetHandler(p, makePonger(string(p))) + } + } + + var wg sync.WaitGroup + for i := 0; i < 1000; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + from := rand.Intn(len(nets)) + to := rand.Intn(len(nets)) + p := rand.Intn(3) + proto := protos[p] + log.Debug("%d (%s) %d (%s) %d (%s)", from, nets[from], to, nets[to], p, protos[p]) + s, err := nets[from].NewStream(protos[p], nets[to].LocalPeer()) + if err != nil { + panic(err) + } + + log.Infof("%d start pinging", i) + makePinger(string(proto), rand.Intn(100))(s) + log.Infof("%d done pinging", i) + }(i) + } + + wg.Done() +}