From bd636e1e95cf9947d95efbf64d2fe9e924c177b3 Mon Sep 17 00:00:00 2001 From: Juan Batiz-Benet Date: Sat, 13 Dec 2014 09:03:40 -0800 Subject: [PATCH] muxer io --- net/mux/mux.go | 154 +++++++++++------------- net/mux/mux_test.go | 278 +++++++++++--------------------------------- 2 files changed, 136 insertions(+), 296 deletions(-) diff --git a/net/mux/mux.go b/net/mux/mux.go index 4f54890e3..b835c0e9c 100644 --- a/net/mux/mux.go +++ b/net/mux/mux.go @@ -3,16 +3,15 @@ package mux import ( "errors" + "fmt" "sync" - conn "github.com/jbenet/go-ipfs/net/conn" msg "github.com/jbenet/go-ipfs/net/message" pb "github.com/jbenet/go-ipfs/net/mux/internal/pb" u "github.com/jbenet/go-ipfs/util" - ctxc "github.com/jbenet/go-ipfs/util/ctxcloser" - context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" proto "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto" + router "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-router" ) var log = u.Logger("muxer") @@ -30,7 +29,10 @@ var ( // encapsulates and decapsulates when interfacing with its Protocols. The // Protocols do not encounter their ProtocolID. type Protocol interface { - GetPipe() *msg.Pipe + ProtocolID() pb.ProtocolID + + // Node is a router.Node, for message connectivity. + router.Node } // ProtocolMap maps ProtocolIDs to Protocols. @@ -39,9 +41,15 @@ type ProtocolMap map[pb.ProtocolID]Protocol // Muxer is a simple multiplexor that reads + writes to Incoming and Outgoing // channels. It multiplexes various protocols, wrapping and unwrapping data // with a ProtocolID. +// +// implements router.Node and router.Route type Muxer struct { + local router.Address + uplink router.Node + // Protocols are the multiplexed services. Protocols ProtocolMap + mapLock sync.Mutex bwiLock sync.Mutex bwIn uint64 @@ -50,32 +58,16 @@ type Muxer struct { bwoLock sync.Mutex bwOut uint64 msgOut uint64 - - *msg.Pipe - ctxc.ContextCloser } // NewMuxer constructs a muxer given a protocol map. -func NewMuxer(ctx context.Context, mp ProtocolMap) *Muxer { - m := &Muxer{ - Protocols: mp, - Pipe: msg.NewPipe(10), - ContextCloser: ctxc.NewContextCloser(ctx, nil), +// uplink is a Node to send all outgoing traffic to. +func NewMuxer(local router.Address, uplink router.Node) *Muxer { + return &Muxer{ + local: local, + uplink: uplink, + Protocols: ProtocolMap{}, } - - m.Children().Add(1) - go m.handleIncomingMessages() - for pid, proto := range m.Protocols { - m.Children().Add(1) - go m.handleOutgoingMessages(pid, proto) - } - - return m -} - -// GetPipe implements the Protocol interface -func (m *Muxer) GetPipe() *msg.Pipe { - return m.Pipe } // GetMessageCounts return the in/out message count measured over this muxer. @@ -104,6 +96,9 @@ func (m *Muxer) GetBandwidthTotals() (in uint64, out uint64) { // AddProtocol adds a Protocol with given ProtocolID to the Muxer. func (m *Muxer) AddProtocol(p Protocol, pid pb.ProtocolID) error { + m.mapLock.Lock() + defer m.mapLock.Unlock() + if _, found := m.Protocols[pid]; found { return errors.New("Another protocol already using this ProtocolID") } @@ -112,98 +107,89 @@ func (m *Muxer) AddProtocol(p Protocol, pid pb.ProtocolID) error { return nil } -// handleIncoming consumes the messages on the m.Incoming channel and -// routes them appropriately (to the protocols). -func (m *Muxer) handleIncomingMessages() { - defer m.Children().Done() +func (m *Muxer) Address() router.Address { + return m.local +} - for { - select { - case <-m.Closing(): - return +func (m *Muxer) HandlePacket(p router.Packet, from router.Node) error { + pkt, ok := p.(*msg.Packet) + if !ok { + return msg.ErrInvalidPayload + } - case msg, more := <-m.Incoming: - if !more { - return - } - m.Children().Add(1) - go m.handleIncomingMessage(msg) - } + if from == m.uplink { + return m.handleIncomingPacket(pkt, from) + } else { + return m.handleOutgoingPacket(pkt, from) } } -// handleIncomingMessage routes message to the appropriate protocol. -func (m *Muxer) handleIncomingMessage(m1 msg.NetMessage) { - defer m.Children().Done() +// handleIncomingPacket routes message to the appropriate protocol. +func (m *Muxer) handleIncomingPacket(p *msg.Packet, _ router.Node) error { m.bwiLock.Lock() // TODO: compensate for overhead - m.bwIn += uint64(len(m1.Data())) + m.bwIn += uint64(len(p.Data)) m.msgIn++ m.bwiLock.Unlock() - data, pid, err := unwrapData(m1.Data()) + data, pid, err := unwrapData(p.Data) if err != nil { - log.Errorf("muxer de-serializing error: %v", err) - return + return fmt.Errorf("muxer de-serializing error: %v", err) } - conn.ReleaseBuffer(m1.Data()) - m2 := msg.New(m1.Peer(), data) + // TODO: fix this when mpool is fixed. + // conn.ReleaseBuffer(m1.Data()) + + p.Data = data + + m.mapLock.Lock() proto, found := m.Protocols[pid] + m.mapLock.Unlock() + if !found { - log.Errorf("muxer unknown protocol %v", pid) - return + return fmt.Errorf("muxer: unknown protocol %v", pid) } - select { - case proto.GetPipe().Incoming <- m2: - case <-m.Closing(): - return - } + log.Debugf("muxer: outgoing packet %d -> %s", proto.ProtocolID(), m.uplink.Address()) + return proto.HandlePacket(p, m) } -// handleOutgoingMessages consumes the messages on the proto.Outgoing channel, -// wraps them and sends them out. -func (m *Muxer) handleOutgoingMessages(pid pb.ProtocolID, proto Protocol) { - defer m.Children().Done() +// handleOutgoingMessages sends out messages to the outside world +func (m *Muxer) handleOutgoingPacket(p *msg.Packet, from router.Node) error { - for { - select { - case msg, more := <-proto.GetPipe().Outgoing: - if !more { - return - } - m.handleOutgoingMessage(pid, msg) - - case <-m.Closing(): - return + var pid pb.ProtocolID + var proto Protocol + m.mapLock.Lock() + for pid2, proto2 := range m.Protocols { + if proto2 == from { + pid = pid2 + proto = proto2 + break } } -} + m.mapLock.Unlock() -// handleOutgoingMessage wraps out a message and sends it out the -func (m *Muxer) handleOutgoingMessage(pid pb.ProtocolID, m1 msg.NetMessage) { + if proto == nil { + return errors.New("muxer: packet sent from unknown protocol") + } - data, err := wrapData(m1.Data(), pid) + var err error + p.Data, err = wrapData(p.Data, pid) if err != nil { - log.Errorf("muxer serializing error: %v", err) - return + return fmt.Errorf("muxer serializing error: %v", err) } m.bwoLock.Lock() // TODO: compensate for overhead // TODO(jbenet): switch this to a goroutine to prevent sync waiting. - m.bwOut += uint64(len(data)) + m.bwOut += uint64(len(p.Data)) m.msgOut++ m.bwoLock.Unlock() - m2 := msg.New(m1.Peer(), data) - select { - case m.GetPipe().Outgoing <- m2: - case <-m.Closing(): - return - } + // TODO: add multiple uplinks + log.Debugf("muxer: incoming packet %s -> %d", m.uplink.Address(), proto.ProtocolID()) + return m.uplink.HandlePacket(p, m) } func wrapData(data []byte, pid pb.ProtocolID) ([]byte, error) { diff --git a/net/mux/mux_test.go b/net/mux/mux_test.go index 7401541c4..6cfd223d8 100644 --- a/net/mux/mux_test.go +++ b/net/mux/mux_test.go @@ -2,10 +2,7 @@ package mux import ( "bytes" - "fmt" - "sync" "testing" - "time" mh "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multihash" msg "github.com/jbenet/go-ipfs/net/message" @@ -13,15 +10,35 @@ import ( peer "github.com/jbenet/go-ipfs/peer" testutil "github.com/jbenet/go-ipfs/util/testutil" - context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" + router "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-router" ) type TestProtocol struct { - *msg.Pipe + mux *Muxer + pid pb.ProtocolID + msg []*msg.Packet } -func (t *TestProtocol) GetPipe() *msg.Pipe { - return t.Pipe +func (t *TestProtocol) ProtocolID() pb.ProtocolID { + return t.pid +} + +func (t *TestProtocol) Address() router.Address { + return t.pid +} + +func (t *TestProtocol) HandlePacket(p router.Packet, from router.Node) error { + pkt, ok := p.(*msg.Packet) + if !ok { + return msg.ErrInvalidPayload + } + + log.Debugf("TestProtocol %d got: %v", t, p) + if from == t.mux { + t.msg = append(t.msg, pkt) + return nil + } + return t.mux.HandlePacket(p, t) } func newPeer(t *testing.T, id string) peer.Peer { @@ -34,14 +51,14 @@ func newPeer(t *testing.T, id string) peer.Peer { return testutil.NewPeerWithID(peer.ID(mh)) } -func testMsg(t *testing.T, m msg.NetMessage, data []byte) { - if !bytes.Equal(data, m.Data()) { - t.Errorf("Data does not match: %v != %v", data, m.Data()) +func testMsg(t *testing.T, m *msg.Packet, data []byte) { + if !bytes.Equal(data, m.Data) { + t.Errorf("Data does not match: %v != %v", data, m.Data) } } -func testWrappedMsg(t *testing.T, m msg.NetMessage, pid pb.ProtocolID, data []byte) { - data2, pid2, err := unwrapData(m.Data()) +func testWrappedMsg(t *testing.T, m *msg.Packet, pid pb.ProtocolID, data []byte) { + data2, pid2, err := unwrapData(m.Data) if err != nil { t.Error(err) } @@ -56,228 +73,65 @@ func testWrappedMsg(t *testing.T, m msg.NetMessage, pid pb.ProtocolID, data []by } func TestSimpleMuxer(t *testing.T) { - ctx := context.Background() - // setup - p1 := &TestProtocol{Pipe: msg.NewPipe(10)} - p2 := &TestProtocol{Pipe: msg.NewPipe(10)} + peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa") + peer2 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275bbbbbb") + + uplink := router.NewQueueNode("queue", make(chan router.Packet, 10)) + mux1 := NewMuxer(string(peer1.ID()), uplink) + pid1 := pb.ProtocolID_Test pid2 := pb.ProtocolID_Routing - mux1 := NewMuxer(ctx, ProtocolMap{ - pid1: p1, - pid2: p2, - }) - peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa") - // peer2 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275bbbbbb") + p1 := &TestProtocol{mux1, pid1, nil} + p2 := &TestProtocol{mux1, pid2, nil} + mux1.AddProtocol(p1, pid1) + mux1.AddProtocol(p2, pid2) // test outgoing p1 for _, s := range []string{"foo", "bar", "baz"} { - p1.Outgoing <- msg.New(peer1, []byte(s)) - testWrappedMsg(t, <-mux1.Outgoing, pid1, []byte(s)) + + pkt := msg.Packet{Src: peer1, Dst: peer2, Data: []byte(s)} + if err := p1.HandlePacket(&pkt, nil); err != nil { + t.Fatal(err) + } + testWrappedMsg(t, (<-uplink.Queue()).(*msg.Packet), pid1, []byte(s)) } // test incoming p1 - for _, s := range []string{"foo", "bar", "baz"} { + for i, s := range []string{"foo", "bar", "baz"} { d, err := wrapData([]byte(s), pid1) if err != nil { t.Error(err) } - mux1.Incoming <- msg.New(peer1, d) - testMsg(t, <-p1.Incoming, []byte(s)) + + pkt := msg.Packet{Src: peer1, Dst: peer2, Data: d} + if err := mux1.HandlePacket(&pkt, uplink); err != nil { + t.Fatal(err) + } + testMsg(t, p1.msg[i], []byte(s)) } // test outgoing p2 for _, s := range []string{"foo", "bar", "baz"} { - p2.Outgoing <- msg.New(peer1, []byte(s)) - testWrappedMsg(t, <-mux1.Outgoing, pid2, []byte(s)) + + pkt := msg.Packet{Src: peer1, Dst: peer2, Data: []byte(s)} + if err := p2.HandlePacket(&pkt, nil); err != nil { + t.Fatal(err) + } + testWrappedMsg(t, (<-uplink.Queue()).(*msg.Packet), pid2, []byte(s)) } // test incoming p2 - for _, s := range []string{"foo", "bar", "baz"} { + for i, s := range []string{"foo", "bar", "baz"} { d, err := wrapData([]byte(s), pid2) if err != nil { - t.Error(err) - } - mux1.Incoming <- msg.New(peer1, d) - testMsg(t, <-p2.Incoming, []byte(s)) - } -} - -func TestSimultMuxer(t *testing.T) { - if testing.Short() { - t.SkipNow() - } - // run muxer - ctx, cancel := context.WithCancel(context.Background()) - - // setup - p1 := &TestProtocol{Pipe: msg.NewPipe(10)} - p2 := &TestProtocol{Pipe: msg.NewPipe(10)} - pid1 := pb.ProtocolID_Test - pid2 := pb.ProtocolID_Identify - mux1 := NewMuxer(ctx, ProtocolMap{ - pid1: p1, - pid2: p2, - }) - peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa") - // peer2 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275bbbbbb") - - // counts - total := 10000 - speed := time.Microsecond * 1 - counts := [2][2][2]int{} - var countsLock sync.Mutex - - // run producers at every end sending incrementing messages - produceOut := func(pid pb.ProtocolID, size int) { - limiter := time.Tick(speed) - for i := 0; i < size; i++ { - <-limiter - s := fmt.Sprintf("proto %v out %v", pid, i) - m := msg.New(peer1, []byte(s)) - mux1.Protocols[pid].GetPipe().Outgoing <- m - countsLock.Lock() - counts[pid][0][0]++ - countsLock.Unlock() - // log.Debug("sent %v", s) - } - } - - produceIn := func(pid pb.ProtocolID, size int) { - limiter := time.Tick(speed) - for i := 0; i < size; i++ { - <-limiter - s := fmt.Sprintf("proto %v in %v", pid, i) - d, err := wrapData([]byte(s), pid) - if err != nil { - t.Error(err) - } - - m := msg.New(peer1, d) - mux1.Incoming <- m - countsLock.Lock() - counts[pid][1][0]++ - countsLock.Unlock() - // log.Debug("sent %v", s) - } - } - - consumeOut := func() { - for { - select { - case m := <-mux1.Outgoing: - data, pid, err := unwrapData(m.Data()) - if err != nil { - t.Error(err) - } - - // log.Debug("got %v", string(data)) - _ = data - countsLock.Lock() - counts[pid][1][1]++ - countsLock.Unlock() - - case <-ctx.Done(): - return - } - } - } - - consumeIn := func(pid pb.ProtocolID) { - for { - select { - case m := <-mux1.Protocols[pid].GetPipe().Incoming: - countsLock.Lock() - counts[pid][0][1]++ - countsLock.Unlock() - // log.Debug("got %v", string(m.Data())) - _ = m - case <-ctx.Done(): - return - } - } - } - - go produceOut(pid1, total) - go produceOut(pid2, total) - go produceIn(pid1, total) - go produceIn(pid2, total) - go consumeOut() - go consumeIn(pid1) - go consumeIn(pid2) - - limiter := time.Tick(speed) - for { - <-limiter - countsLock.Lock() - got := counts[0][0][0] + counts[0][0][1] + - counts[0][1][0] + counts[0][1][1] + - counts[1][0][0] + counts[1][0][1] + - counts[1][1][0] + counts[1][1][1] - countsLock.Unlock() - - if got == total*8 { - cancel() - return - } - } - -} - -func TestStopping(t *testing.T) { - ctx := context.Background() - - // setup - p1 := &TestProtocol{Pipe: msg.NewPipe(10)} - p2 := &TestProtocol{Pipe: msg.NewPipe(10)} - pid1 := pb.ProtocolID_Test - pid2 := pb.ProtocolID_Identify - mux1 := NewMuxer(ctx, ProtocolMap{ - pid1: p1, - pid2: p2, - }) - peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa") - // peer2 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275bbbbbb") - - // test outgoing p1 - for _, s := range []string{"foo1", "bar1", "baz1"} { - p1.Outgoing <- msg.New(peer1, []byte(s)) - testWrappedMsg(t, <-mux1.Outgoing, pid1, []byte(s)) - } - - // test incoming p1 - for _, s := range []string{"foo2", "bar2", "baz2"} { - d, err := wrapData([]byte(s), pid1) - if err != nil { - t.Error(err) - } - mux1.Incoming <- msg.New(peer1, d) - testMsg(t, <-p1.Incoming, []byte(s)) - } - - mux1.Close() // waits - - // test outgoing p1 - for _, s := range []string{"foo3", "bar3", "baz3"} { - p1.Outgoing <- msg.New(peer1, []byte(s)) - select { - case m := <-mux1.Outgoing: - t.Errorf("should not have received anything. Got: %v", string(m.Data())) - case <-time.After(time.Millisecond): - } - } - - // test incoming p1 - for _, s := range []string{"foo4", "bar4", "baz4"} { - d, err := wrapData([]byte(s), pid1) - if err != nil { - t.Error(err) - } - mux1.Incoming <- msg.New(peer1, d) - select { - case <-p1.Incoming: - t.Error("should not have received anything.") - case <-time.After(time.Millisecond): + t.Fatal(err) } + + pkt := msg.Packet{Src: peer1, Dst: peer2, Data: d} + if err := mux1.HandlePacket(&pkt, uplink); err != nil { + t.Fatal(err) + } + testMsg(t, p2.msg[i], []byte(s)) } }