This commit is contained in:
Juan Batiz-Benet 2014-12-13 09:03:40 -08:00
parent d94593a955
commit bd636e1e95
2 changed files with 136 additions and 296 deletions

View File

@ -3,16 +3,15 @@ package mux
import (
"errors"
"fmt"
"sync"
conn "github.com/jbenet/go-ipfs/net/conn"
msg "github.com/jbenet/go-ipfs/net/message"
pb "github.com/jbenet/go-ipfs/net/mux/internal/pb"
u "github.com/jbenet/go-ipfs/util"
ctxc "github.com/jbenet/go-ipfs/util/ctxcloser"
context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
proto "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto"
router "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-router"
)
var log = u.Logger("muxer")
@ -30,7 +29,10 @@ var (
// encapsulates and decapsulates when interfacing with its Protocols. The
// Protocols do not encounter their ProtocolID.
type Protocol interface {
GetPipe() *msg.Pipe
ProtocolID() pb.ProtocolID
// Node is a router.Node, for message connectivity.
router.Node
}
// ProtocolMap maps ProtocolIDs to Protocols.
@ -39,9 +41,15 @@ type ProtocolMap map[pb.ProtocolID]Protocol
// Muxer is a simple multiplexor that reads + writes to Incoming and Outgoing
// channels. It multiplexes various protocols, wrapping and unwrapping data
// with a ProtocolID.
//
// implements router.Node and router.Route
type Muxer struct {
local router.Address
uplink router.Node
// Protocols are the multiplexed services.
Protocols ProtocolMap
mapLock sync.Mutex
bwiLock sync.Mutex
bwIn uint64
@ -50,32 +58,16 @@ type Muxer struct {
bwoLock sync.Mutex
bwOut uint64
msgOut uint64
*msg.Pipe
ctxc.ContextCloser
}
// NewMuxer constructs a muxer given a protocol map.
func NewMuxer(ctx context.Context, mp ProtocolMap) *Muxer {
m := &Muxer{
Protocols: mp,
Pipe: msg.NewPipe(10),
ContextCloser: ctxc.NewContextCloser(ctx, nil),
// uplink is a Node to send all outgoing traffic to.
func NewMuxer(local router.Address, uplink router.Node) *Muxer {
return &Muxer{
local: local,
uplink: uplink,
Protocols: ProtocolMap{},
}
m.Children().Add(1)
go m.handleIncomingMessages()
for pid, proto := range m.Protocols {
m.Children().Add(1)
go m.handleOutgoingMessages(pid, proto)
}
return m
}
// GetPipe implements the Protocol interface
func (m *Muxer) GetPipe() *msg.Pipe {
return m.Pipe
}
// GetMessageCounts return the in/out message count measured over this muxer.
@ -104,6 +96,9 @@ func (m *Muxer) GetBandwidthTotals() (in uint64, out uint64) {
// AddProtocol adds a Protocol with given ProtocolID to the Muxer.
func (m *Muxer) AddProtocol(p Protocol, pid pb.ProtocolID) error {
m.mapLock.Lock()
defer m.mapLock.Unlock()
if _, found := m.Protocols[pid]; found {
return errors.New("Another protocol already using this ProtocolID")
}
@ -112,98 +107,89 @@ func (m *Muxer) AddProtocol(p Protocol, pid pb.ProtocolID) error {
return nil
}
// handleIncoming consumes the messages on the m.Incoming channel and
// routes them appropriately (to the protocols).
func (m *Muxer) handleIncomingMessages() {
defer m.Children().Done()
// Address returns the muxer's local address.
// It is part of the router.Node interface the Muxer implements.
func (m *Muxer) Address() router.Address {
	return m.local
}
for {
select {
case <-m.Closing():
return
func (m *Muxer) HandlePacket(p router.Packet, from router.Node) error {
pkt, ok := p.(*msg.Packet)
if !ok {
return msg.ErrInvalidPayload
}
case msg, more := <-m.Incoming:
if !more {
return
}
m.Children().Add(1)
go m.handleIncomingMessage(msg)
}
if from == m.uplink {
return m.handleIncomingPacket(pkt, from)
} else {
return m.handleOutgoingPacket(pkt, from)
}
}
// handleIncomingMessage routes message to the appropriate protocol.
func (m *Muxer) handleIncomingMessage(m1 msg.NetMessage) {
defer m.Children().Done()
// handleIncomingPacket routes message to the appropriate protocol.
func (m *Muxer) handleIncomingPacket(p *msg.Packet, _ router.Node) error {
m.bwiLock.Lock()
// TODO: compensate for overhead
m.bwIn += uint64(len(m1.Data()))
m.bwIn += uint64(len(p.Data))
m.msgIn++
m.bwiLock.Unlock()
data, pid, err := unwrapData(m1.Data())
data, pid, err := unwrapData(p.Data)
if err != nil {
log.Errorf("muxer de-serializing error: %v", err)
return
return fmt.Errorf("muxer de-serializing error: %v", err)
}
conn.ReleaseBuffer(m1.Data())
m2 := msg.New(m1.Peer(), data)
// TODO: fix this when mpool is fixed.
// conn.ReleaseBuffer(m1.Data())
p.Data = data
m.mapLock.Lock()
proto, found := m.Protocols[pid]
m.mapLock.Unlock()
if !found {
log.Errorf("muxer unknown protocol %v", pid)
return
return fmt.Errorf("muxer: unknown protocol %v", pid)
}
select {
case proto.GetPipe().Incoming <- m2:
case <-m.Closing():
return
}
log.Debugf("muxer: outgoing packet %d -> %s", proto.ProtocolID(), m.uplink.Address())
return proto.HandlePacket(p, m)
}
// handleOutgoingMessages consumes the messages on the proto.Outgoing channel,
// wraps them and sends them out.
func (m *Muxer) handleOutgoingMessages(pid pb.ProtocolID, proto Protocol) {
defer m.Children().Done()
// handleOutgoingMessages sends out messages to the outside world
func (m *Muxer) handleOutgoingPacket(p *msg.Packet, from router.Node) error {
for {
select {
case msg, more := <-proto.GetPipe().Outgoing:
if !more {
return
}
m.handleOutgoingMessage(pid, msg)
case <-m.Closing():
return
var pid pb.ProtocolID
var proto Protocol
m.mapLock.Lock()
for pid2, proto2 := range m.Protocols {
if proto2 == from {
pid = pid2
proto = proto2
break
}
}
}
m.mapLock.Unlock()
// handleOutgoingMessage wraps out a message and sends it out the
func (m *Muxer) handleOutgoingMessage(pid pb.ProtocolID, m1 msg.NetMessage) {
if proto == nil {
return errors.New("muxer: packet sent from unknown protocol")
}
data, err := wrapData(m1.Data(), pid)
var err error
p.Data, err = wrapData(p.Data, pid)
if err != nil {
log.Errorf("muxer serializing error: %v", err)
return
return fmt.Errorf("muxer serializing error: %v", err)
}
m.bwoLock.Lock()
// TODO: compensate for overhead
// TODO(jbenet): switch this to a goroutine to prevent sync waiting.
m.bwOut += uint64(len(data))
m.bwOut += uint64(len(p.Data))
m.msgOut++
m.bwoLock.Unlock()
m2 := msg.New(m1.Peer(), data)
select {
case m.GetPipe().Outgoing <- m2:
case <-m.Closing():
return
}
// TODO: add multiple uplinks
log.Debugf("muxer: incoming packet %s -> %d", m.uplink.Address(), proto.ProtocolID())
return m.uplink.HandlePacket(p, m)
}
func wrapData(data []byte, pid pb.ProtocolID) ([]byte, error) {

View File

@ -2,10 +2,7 @@ package mux
import (
"bytes"
"fmt"
"sync"
"testing"
"time"
mh "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multihash"
msg "github.com/jbenet/go-ipfs/net/message"
@ -13,15 +10,35 @@ import (
peer "github.com/jbenet/go-ipfs/peer"
testutil "github.com/jbenet/go-ipfs/util/testutil"
context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
router "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-router"
)
type TestProtocol struct {
*msg.Pipe
mux *Muxer
pid pb.ProtocolID
msg []*msg.Packet
}
func (t *TestProtocol) GetPipe() *msg.Pipe {
return t.Pipe
// ProtocolID returns the protocol id this test protocol was registered under.
func (t *TestProtocol) ProtocolID() pb.ProtocolID {
	return t.pid
}
// Address implements router.Node; the protocol id doubles as the node address.
func (t *TestProtocol) Address() router.Address {
	return t.pid
}
// HandlePacket implements router.Node for the test protocol.
// Packets arriving from the muxer are recorded in t.msg for later
// inspection; packets from anywhere else are forwarded up to the muxer.
func (t *TestProtocol) HandlePacket(p router.Packet, from router.Node) error {
	pkt, ok := p.(*msg.Packet)
	if !ok {
		return msg.ErrInvalidPayload
	}
	// Fixed: previously logged `t` (a *TestProtocol) with the %d verb,
	// which is a go vet printf mismatch. Log the protocol id instead.
	log.Debugf("TestProtocol %d got: %v", t.pid, p)

	if from == t.mux {
		t.msg = append(t.msg, pkt)
		return nil
	}
	return t.mux.HandlePacket(p, t)
}
func newPeer(t *testing.T, id string) peer.Peer {
@ -34,14 +51,14 @@ func newPeer(t *testing.T, id string) peer.Peer {
return testutil.NewPeerWithID(peer.ID(mh))
}
func testMsg(t *testing.T, m msg.NetMessage, data []byte) {
if !bytes.Equal(data, m.Data()) {
t.Errorf("Data does not match: %v != %v", data, m.Data())
// testMsg fails the test when the packet's payload differs from data.
func testMsg(t *testing.T, m *msg.Packet, data []byte) {
	if bytes.Equal(m.Data, data) {
		return
	}
	t.Errorf("Data does not match: %v != %v", data, m.Data)
}
func testWrappedMsg(t *testing.T, m msg.NetMessage, pid pb.ProtocolID, data []byte) {
data2, pid2, err := unwrapData(m.Data())
func testWrappedMsg(t *testing.T, m *msg.Packet, pid pb.ProtocolID, data []byte) {
data2, pid2, err := unwrapData(m.Data)
if err != nil {
t.Error(err)
}
@ -56,228 +73,65 @@ func testWrappedMsg(t *testing.T, m msg.NetMessage, pid pb.ProtocolID, data []by
}
func TestSimpleMuxer(t *testing.T) {
ctx := context.Background()
// setup
p1 := &TestProtocol{Pipe: msg.NewPipe(10)}
p2 := &TestProtocol{Pipe: msg.NewPipe(10)}
peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa")
peer2 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275bbbbbb")
uplink := router.NewQueueNode("queue", make(chan router.Packet, 10))
mux1 := NewMuxer(string(peer1.ID()), uplink)
pid1 := pb.ProtocolID_Test
pid2 := pb.ProtocolID_Routing
mux1 := NewMuxer(ctx, ProtocolMap{
pid1: p1,
pid2: p2,
})
peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa")
// peer2 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275bbbbbb")
p1 := &TestProtocol{mux1, pid1, nil}
p2 := &TestProtocol{mux1, pid2, nil}
mux1.AddProtocol(p1, pid1)
mux1.AddProtocol(p2, pid2)
// test outgoing p1
for _, s := range []string{"foo", "bar", "baz"} {
p1.Outgoing <- msg.New(peer1, []byte(s))
testWrappedMsg(t, <-mux1.Outgoing, pid1, []byte(s))
pkt := msg.Packet{Src: peer1, Dst: peer2, Data: []byte(s)}
if err := p1.HandlePacket(&pkt, nil); err != nil {
t.Fatal(err)
}
testWrappedMsg(t, (<-uplink.Queue()).(*msg.Packet), pid1, []byte(s))
}
// test incoming p1
for _, s := range []string{"foo", "bar", "baz"} {
for i, s := range []string{"foo", "bar", "baz"} {
d, err := wrapData([]byte(s), pid1)
if err != nil {
t.Error(err)
}
mux1.Incoming <- msg.New(peer1, d)
testMsg(t, <-p1.Incoming, []byte(s))
pkt := msg.Packet{Src: peer1, Dst: peer2, Data: d}
if err := mux1.HandlePacket(&pkt, uplink); err != nil {
t.Fatal(err)
}
testMsg(t, p1.msg[i], []byte(s))
}
// test outgoing p2
for _, s := range []string{"foo", "bar", "baz"} {
p2.Outgoing <- msg.New(peer1, []byte(s))
testWrappedMsg(t, <-mux1.Outgoing, pid2, []byte(s))
pkt := msg.Packet{Src: peer1, Dst: peer2, Data: []byte(s)}
if err := p2.HandlePacket(&pkt, nil); err != nil {
t.Fatal(err)
}
testWrappedMsg(t, (<-uplink.Queue()).(*msg.Packet), pid2, []byte(s))
}
// test incoming p2
for _, s := range []string{"foo", "bar", "baz"} {
for i, s := range []string{"foo", "bar", "baz"} {
d, err := wrapData([]byte(s), pid2)
if err != nil {
t.Error(err)
}
mux1.Incoming <- msg.New(peer1, d)
testMsg(t, <-p2.Incoming, []byte(s))
}
}
func TestSimultMuxer(t *testing.T) {
if testing.Short() {
t.SkipNow()
}
// run muxer
ctx, cancel := context.WithCancel(context.Background())
// setup
p1 := &TestProtocol{Pipe: msg.NewPipe(10)}
p2 := &TestProtocol{Pipe: msg.NewPipe(10)}
pid1 := pb.ProtocolID_Test
pid2 := pb.ProtocolID_Identify
mux1 := NewMuxer(ctx, ProtocolMap{
pid1: p1,
pid2: p2,
})
peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa")
// peer2 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275bbbbbb")
// counts
total := 10000
speed := time.Microsecond * 1
counts := [2][2][2]int{}
var countsLock sync.Mutex
// run producers at every end sending incrementing messages
produceOut := func(pid pb.ProtocolID, size int) {
limiter := time.Tick(speed)
for i := 0; i < size; i++ {
<-limiter
s := fmt.Sprintf("proto %v out %v", pid, i)
m := msg.New(peer1, []byte(s))
mux1.Protocols[pid].GetPipe().Outgoing <- m
countsLock.Lock()
counts[pid][0][0]++
countsLock.Unlock()
// log.Debug("sent %v", s)
}
}
produceIn := func(pid pb.ProtocolID, size int) {
limiter := time.Tick(speed)
for i := 0; i < size; i++ {
<-limiter
s := fmt.Sprintf("proto %v in %v", pid, i)
d, err := wrapData([]byte(s), pid)
if err != nil {
t.Error(err)
}
m := msg.New(peer1, d)
mux1.Incoming <- m
countsLock.Lock()
counts[pid][1][0]++
countsLock.Unlock()
// log.Debug("sent %v", s)
}
}
consumeOut := func() {
for {
select {
case m := <-mux1.Outgoing:
data, pid, err := unwrapData(m.Data())
if err != nil {
t.Error(err)
}
// log.Debug("got %v", string(data))
_ = data
countsLock.Lock()
counts[pid][1][1]++
countsLock.Unlock()
case <-ctx.Done():
return
}
}
}
consumeIn := func(pid pb.ProtocolID) {
for {
select {
case m := <-mux1.Protocols[pid].GetPipe().Incoming:
countsLock.Lock()
counts[pid][0][1]++
countsLock.Unlock()
// log.Debug("got %v", string(m.Data()))
_ = m
case <-ctx.Done():
return
}
}
}
go produceOut(pid1, total)
go produceOut(pid2, total)
go produceIn(pid1, total)
go produceIn(pid2, total)
go consumeOut()
go consumeIn(pid1)
go consumeIn(pid2)
limiter := time.Tick(speed)
for {
<-limiter
countsLock.Lock()
got := counts[0][0][0] + counts[0][0][1] +
counts[0][1][0] + counts[0][1][1] +
counts[1][0][0] + counts[1][0][1] +
counts[1][1][0] + counts[1][1][1]
countsLock.Unlock()
if got == total*8 {
cancel()
return
}
}
}
func TestStopping(t *testing.T) {
ctx := context.Background()
// setup
p1 := &TestProtocol{Pipe: msg.NewPipe(10)}
p2 := &TestProtocol{Pipe: msg.NewPipe(10)}
pid1 := pb.ProtocolID_Test
pid2 := pb.ProtocolID_Identify
mux1 := NewMuxer(ctx, ProtocolMap{
pid1: p1,
pid2: p2,
})
peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa")
// peer2 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275bbbbbb")
// test outgoing p1
for _, s := range []string{"foo1", "bar1", "baz1"} {
p1.Outgoing <- msg.New(peer1, []byte(s))
testWrappedMsg(t, <-mux1.Outgoing, pid1, []byte(s))
}
// test incoming p1
for _, s := range []string{"foo2", "bar2", "baz2"} {
d, err := wrapData([]byte(s), pid1)
if err != nil {
t.Error(err)
}
mux1.Incoming <- msg.New(peer1, d)
testMsg(t, <-p1.Incoming, []byte(s))
}
mux1.Close() // waits
// test outgoing p1
for _, s := range []string{"foo3", "bar3", "baz3"} {
p1.Outgoing <- msg.New(peer1, []byte(s))
select {
case m := <-mux1.Outgoing:
t.Errorf("should not have received anything. Got: %v", string(m.Data()))
case <-time.After(time.Millisecond):
}
}
// test incoming p1
for _, s := range []string{"foo4", "bar4", "baz4"} {
d, err := wrapData([]byte(s), pid1)
if err != nil {
t.Error(err)
}
mux1.Incoming <- msg.New(peer1, d)
select {
case <-p1.Incoming:
t.Error("should not have received anything.")
case <-time.After(time.Millisecond):
t.Fatal(err)
}
pkt := msg.Packet{Src: peer1, Dst: peer2, Data: d}
if err := mux1.HandlePacket(&pkt, uplink); err != nil {
t.Fatal(err)
}
testMsg(t, p2.msg[i], []byte(s))
}
}