From 8d08e1e3d6e5cffe7becd9c43e8cb92885ee019f Mon Sep 17 00:00:00 2001 From: Juan Batiz-Benet Date: Thu, 22 Jan 2015 05:23:45 -0800 Subject: [PATCH] reuseport: respect dialer timeout --- Godeps/Godeps.json | 2 +- .../github.com/jbenet/go-reuseport/addr.go | 12 ++ .../jbenet/go-reuseport/impl_unix.go | 124 ++++++------ .../jbenet/go-reuseport/interface.go | 29 ++- .../jbenet/go-reuseport/poll/error.go | 9 + .../jbenet/go-reuseport/poll/poll_bsd.go | 59 ++++++ .../jbenet/go-reuseport/poll/poll_linux.go | 51 +++++ .../go-reuseport/poll/poll_unsupported.go | 11 ++ .../jbenet/go-reuseport/reuse_test.go | 182 +++++++++++++++++- 9 files changed, 408 insertions(+), 71 deletions(-) create mode 100644 Godeps/_workspace/src/github.com/jbenet/go-reuseport/poll/error.go create mode 100644 Godeps/_workspace/src/github.com/jbenet/go-reuseport/poll/poll_bsd.go create mode 100644 Godeps/_workspace/src/github.com/jbenet/go-reuseport/poll/poll_linux.go create mode 100644 Godeps/_workspace/src/github.com/jbenet/go-reuseport/poll/poll_unsupported.go diff --git a/Godeps/Godeps.json b/Godeps/Godeps.json index 9fc3041ee..f7c7ab7cc 100644 --- a/Godeps/Godeps.json +++ b/Godeps/Godeps.json @@ -160,7 +160,7 @@ }, { "ImportPath": "github.com/jbenet/go-reuseport", - "Rev": "a2e454f12a99b8898c41f9dcebae6c35dc3efa3a" + "Rev": "6924153aded2d61c89a83c8f0738ed4e8df9191f" }, { "ImportPath": "github.com/jbenet/go-sockaddr/net", diff --git a/Godeps/_workspace/src/github.com/jbenet/go-reuseport/addr.go b/Godeps/_workspace/src/github.com/jbenet/go-reuseport/addr.go index cfffc7c8c..f793a21d1 100644 --- a/Godeps/_workspace/src/github.com/jbenet/go-reuseport/addr.go +++ b/Godeps/_workspace/src/github.com/jbenet/go-reuseport/addr.go @@ -18,3 +18,15 @@ func ResolveAddr(network, address string) (net.Addr, error) { return net.ResolveUnixAddr(network, address) } } + +// conn is a struct that stores a raddr to get around: +// * https://github.com/golang/go/issues/9661#issuecomment-71043147 +// * https://gist.github.com/jbenet/5c191d698fe9ec58c49d +type conn struct { + net.Conn + raddr net.Addr +} + +func (c *conn) RemoteAddr() net.Addr { + return c.raddr +} diff --git a/Godeps/_workspace/src/github.com/jbenet/go-reuseport/impl_unix.go b/Godeps/_workspace/src/github.com/jbenet/go-reuseport/impl_unix.go index ed8415abd..93727fed9 100644 --- a/Godeps/_workspace/src/github.com/jbenet/go-reuseport/impl_unix.go +++ b/Godeps/_workspace/src/github.com/jbenet/go-reuseport/impl_unix.go @@ -9,6 +9,7 @@ import ( "syscall" "time" + poll "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-reuseport/poll" sockaddrnet "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-sockaddr/net" ) @@ -30,20 +31,18 @@ func socket(family, socktype, protocol int) (fd int, err error) { return -1, err } - // set non-blocking until after connect, because we cant poll using runtime :( + // cant set it until after connect // if err = syscall.SetNonblock(fd, true); err != nil { // syscall.Close(fd) // return -1, err // } if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, soReuseAddr, 1); err != nil { - // fmt.Println("reuse addr failed") syscall.Close(fd) return -1, err } if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, soReusePort, 1); err != nil { - // fmt.Println("reuse port failed") syscall.Close(fd) return -1, err } @@ -51,7 +50,6 @@ func socket(family, socktype, protocol int) (fd int, err error) { // set setLinger to 5 as reusing exact same (srcip:srcport, dstip:dstport) // will otherwise fail on connect. if err = setLinger(fd, 5); err != nil { - // fmt.Println("linger failed") syscall.Close(fd) return -1, err } @@ -68,13 +66,13 @@ func dial(dialer net.Dialer, netw, addr string) (c net.Conn, err error) { lprotocol int rprotocol int file *os.File + deadline time.Time remoteSockaddr syscall.Sockaddr localSockaddr syscall.Sockaddr ) netAddr, err := ResolveAddr(netw, addr) if err != nil { - // fmt.Println("resolve addr failed") return nil, err } @@ -84,6 +82,13 @@ func dial(dialer net.Dialer, netw, addr string) (c net.Conn, err error) { return nil, ErrUnsupportedProtocol } + switch { + case !dialer.Deadline.IsZero(): + deadline = dialer.Deadline + case dialer.Timeout != 0: + deadline = time.Now().Add(dialer.Timeout) + } + localSockaddr = sockaddrnet.NetAddrToSockaddr(dialer.LocalAddr) remoteSockaddr = sockaddrnet.NetAddrToSockaddr(netAddr) @@ -109,18 +114,29 @@ func dial(dialer net.Dialer, netw, addr string) (c net.Conn, err error) { // look at dialTCP in http://golang.org/src/net/tcpsock_posix.go .... ! // here we just try again 3 times. for i := 0; i < 3; i++ { + if !deadline.IsZero() && deadline.Before(time.Now()) { + err = errTimeout + break + } + if fd, err = socket(rfamily, socktype, rprotocol); err != nil { return nil, err } - if err = syscall.Bind(fd, localSockaddr); err != nil { - // fmt.Println("bind failed") + if localSockaddr != nil { + if err = syscall.Bind(fd, localSockaddr); err != nil { + syscall.Close(fd) + return nil, err + } + } + + if err = syscall.SetNonblock(fd, true); err != nil { syscall.Close(fd) return nil, err } - if err = connect(fd, remoteSockaddr); err != nil { + + if err = connect(fd, remoteSockaddr, deadline); err != nil { syscall.Close(fd) - // fmt.Println("connect failed", localSockaddr, err) continue // try again. } @@ -133,48 +149,40 @@ func dial(dialer net.Dialer, netw, addr string) (c net.Conn, err error) { if rprotocol == syscall.IPPROTO_TCP { // by default golang/net sets TCP no delay to true. if err = setNoDelay(fd, true); err != nil { - // fmt.Println("set no delay failed") syscall.Close(fd) return nil, err } } - if err = syscall.SetNonblock(fd, true); err != nil { + // File Name get be nil + file = os.NewFile(uintptr(fd), filePrefix+strconv.Itoa(os.Getpid())) + if c, err = net.FileConn(file); err != nil { syscall.Close(fd) return nil, err } - switch socktype { - case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET: - - // File Name get be nil - file = os.NewFile(uintptr(fd), filePrefix+strconv.Itoa(os.Getpid())) - if c, err = net.FileConn(file); err != nil { - // fmt.Println("fileconn failed") - syscall.Close(fd) - return nil, err - } - - case syscall.SOCK_DGRAM: - - // File Name get be nil - file = os.NewFile(uintptr(fd), filePrefix+strconv.Itoa(os.Getpid())) - if c, err = net.FileConn(file); err != nil { - // fmt.Println("fileconn failed") - syscall.Close(fd) - return nil, err - } - } - if err = file.Close(); err != nil { - // fmt.Println("file close failed") syscall.Close(fd) return nil, err } + // c = wrapConnWithRemoteAddr(c, netAddr) return c, err } +// there's a rare case where dial returns successfully but for some reason the +// RemoteAddr is not yet set. So, since we know what raddr should be, we just +// wrap it. This is not ideal in that sometimes getpeername() may return a +// different addr. But until this is fixed, best way to do it. +// * https://gist.github.com/jbenet/5c191d698fe9ec58c49d +// * https://github.com/golang/go/issues/9661#issuecomment-71043147 +func wrapConnWithRemoteAddr(c net.Conn, raddr net.Addr) net.Conn { + if c.RemoteAddr() == nil { + return &conn{Conn: c, raddr: raddr} + } + return c // it's fine, no need to wrap. +} + func listen(netw, addr string) (fd int, err error) { var ( family int @@ -185,7 +193,6 @@ func listen(netw, addr string) (fd int, err error) { netAddr, err := ResolveAddr(netw, addr) if err != nil { - // fmt.Println("resolve addr failed") return -1, err } @@ -205,7 +212,6 @@ func listen(netw, addr string) (fd int, err error) { } if err = syscall.Bind(fd, sockaddr); err != nil { - // fmt.Println("bind failed") syscall.Close(fd) return -1, err } @@ -213,7 +219,6 @@ func listen(netw, addr string) (fd int, err error) { if protocol == syscall.IPPROTO_TCP { // by default golang/net sets TCP no delay to true. if err = setNoDelay(fd, true); err != nil { - // fmt.Println("set no delay failed") syscall.Close(fd) return -1, err } @@ -239,20 +244,17 @@ func listenStream(netw, addr string) (l net.Listener, err error) { // Set backlog size to the maximum if err = syscall.Listen(fd, syscall.SOMAXCONN); err != nil { - // fmt.Println("listen failed") syscall.Close(fd) return nil, err } file = os.NewFile(uintptr(fd), filePrefix+strconv.Itoa(os.Getpid())) if l, err = net.FileListener(file); err != nil { - // fmt.Println("filelistener failed") syscall.Close(fd) return nil, err } if err = file.Close(); err != nil { - // fmt.Println("file close failed") syscall.Close(fd) return nil, err } @@ -272,13 +274,11 @@ func listenPacket(netw, addr string) (p net.PacketConn, err error) { file = os.NewFile(uintptr(fd), filePrefix+strconv.Itoa(os.Getpid())) if p, err = net.FilePacketConn(file); err != nil { - // fmt.Println("filelistener failed") syscall.Close(fd) return nil, err } if err = file.Close(); err != nil { - // fmt.Println("file close failed") syscall.Close(fd) return nil, err } @@ -298,13 +298,11 @@ func listenUDP(netw, addr string) (c net.Conn, err error) { file = os.NewFile(uintptr(fd), filePrefix+strconv.Itoa(os.Getpid())) if c, err = net.FileConn(file); err != nil { - // fmt.Println("filelistener failed") syscall.Close(fd) return nil, err } if err = file.Close(); err != nil { - // fmt.Println("file close failed") syscall.Close(fd) return nil, err } @@ -313,26 +311,36 @@ func listenUDP(netw, addr string) (c net.Conn, err error) { } // this is close to the connect() function inside stdlib/net -func connect(fd int, ra syscall.Sockaddr) error { +func connect(fd int, ra syscall.Sockaddr, deadline time.Time) error { switch err := syscall.Connect(fd, ra); err { case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR: case nil, syscall.EISCONN: + if !deadline.IsZero() && deadline.Before(time.Now()) { + return errTimeout + } return nil default: return err } - var err error - start := time.Now() + poller, err := poll.New(fd) + if err != nil { + return err + } + for { + if err = poller.WaitWrite(deadline); err != nil { + return err + } + // if err := fd.pd.WaitWrite(); err != nil { // return err // } // i'd use the above fd.pd.WaitWrite to poll io correctly, just like net sockets... - // but of course, it uses fucking runtime_* functions that _cannot_ be used by - // non-go-stdlib source... seriously guys, what kind of bullshit is that!? + // but of course, it uses the damn runtime_* functions that _cannot_ be used by + // non-go-stdlib source... seriously guys, this is not nice. // we're relegated to using syscall.Select (what nightmare that is) or using - // a simple but totally bogus time-based wait. garbage. + // a simple but totally bogus time-based wait. such garbage. var nerr int nerr, err = syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_ERROR) if err != nil { @@ -340,14 +348,22 @@ func connect(fd int, ra syscall.Sockaddr) error { } switch err = syscall.Errno(nerr); err { case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR: - if time.Now().Sub(start) > time.Second { - return err - } - <-time.After(20 * time.Microsecond) + continue case syscall.Errno(0), syscall.EISCONN: + if !deadline.IsZero() && deadline.Before(time.Now()) { + return errTimeout + } return nil default: return err } } } + +var errTimeout = &timeoutError{} + +type timeoutError struct{} + +func (e *timeoutError) Error() string { return "i/o timeout" } +func (e *timeoutError) Timeout() bool { return true } +func (e *timeoutError) Temporary() bool { return true } diff --git a/Godeps/_workspace/src/github.com/jbenet/go-reuseport/interface.go b/Godeps/_workspace/src/github.com/jbenet/go-reuseport/interface.go index 5433b691b..d20b3b4c6 100644 --- a/Godeps/_workspace/src/github.com/jbenet/go-reuseport/interface.go +++ b/Godeps/_workspace/src/github.com/jbenet/go-reuseport/interface.go @@ -91,21 +91,20 @@ type Dialer struct { // Returns a net.Conn created from a file discriptor for a socket // with SO_REUSEPORT and SO_REUSEADDR option set. func (d *Dialer) Dial(network, address string) (net.Conn, error) { - c, err := dial(d.D, network, address) - if err != nil { - return nil, err + if !available() { + return nil, syscall.Errno(syscall.ENOPROTOOPT) } - // there's a rare case where dial returns successfully but for some reason the - // RemoteAddr is not yet set. We wait here a while until it is, and if too long - // passes, we fail. This is horrendous. - for start := time.Now(); c.RemoteAddr() == nil; { - if time.Now().Sub(start) > (time.Millisecond * 500) { - c.Close() - return nil, ErrReuseFailed - } - - <-time.After(20 * time.Microsecond) - } - return c, nil + return dial(d.D, network, address) +} + +func (d *Dialer) deadline(def time.Duration) time.Time { + switch { + case !d.D.Deadline.IsZero(): + return d.D.Deadline + case d.D.Timeout != 0: + return time.Now().Add(d.D.Timeout) + default: + return time.Now().Add(def) + } } diff --git a/Godeps/_workspace/src/github.com/jbenet/go-reuseport/poll/error.go b/Godeps/_workspace/src/github.com/jbenet/go-reuseport/poll/error.go new file mode 100644 index 000000000..c13fb8132 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jbenet/go-reuseport/poll/error.go @@ -0,0 +1,9 @@ +package poll + +var errTimeout = &timeoutError{} + +type timeoutError struct{} + +func (e *timeoutError) Error() string { return "i/o timeout" } +func (e *timeoutError) Timeout() bool { return true } +func (e *timeoutError) Temporary() bool { return true } diff --git a/Godeps/_workspace/src/github.com/jbenet/go-reuseport/poll/poll_bsd.go b/Godeps/_workspace/src/github.com/jbenet/go-reuseport/poll/poll_bsd.go new file mode 100644 index 000000000..058f21b85 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jbenet/go-reuseport/poll/poll_bsd.go @@ -0,0 +1,59 @@ +// +build darwin freebsd dragonfly netbsd openbsd + +package poll + +import ( + "syscall" + "time" +) + +type Poller struct { + kqfd int + event syscall.Kevent_t +} + +func New(fd int) (p *Poller, err error) { + p = &Poller{} + + p.kqfd, err = syscall.Kqueue() + if p.kqfd == -1 || err != nil { + return nil, err + } + + p.event = syscall.Kevent_t{ + Ident: uint64(fd), + Filter: syscall.EVFILT_WRITE, + Flags: syscall.EV_ADD | syscall.EV_ENABLE | syscall.EV_ONESHOT, + Fflags: 0, + Data: 0, + Udata: nil, + } + return p, nil +} + +func (p *Poller) Close() error { + return syscall.Close(p.kqfd) +} + +func (p *Poller) WaitWrite(deadline time.Time) error { + + // setup timeout + var timeout *syscall.Timespec + if !deadline.IsZero() { + d := deadline.Sub(time.Now()) + t := syscall.NsecToTimespec(d.Nanoseconds()) + timeout = &t + } + + // wait on kevent + events := make([]syscall.Kevent_t, 1) + n, err := syscall.Kevent(p.kqfd, []syscall.Kevent_t{p.event}, events, timeout) + if err != nil { + return err + } + + if n < 1 { + return errTimeout + } + return nil +} diff --git a/Godeps/_workspace/src/github.com/jbenet/go-reuseport/poll/poll_linux.go b/Godeps/_workspace/src/github.com/jbenet/go-reuseport/poll/poll_linux.go new file mode 100644 index 000000000..b63a00091 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jbenet/go-reuseport/poll/poll_linux.go @@ -0,0 +1,51 @@ +// +build linux + +package poll + +import ( + "syscall" + "time" +) + +type Poller struct { + epfd int + event syscall.EpollEvent + events [32]syscall.EpollEvent +} + +func New(fd int) (p *Poller, err error) { + p = &Poller{} + if p.epfd, err = syscall.EpollCreate1(0); err != nil { + return nil, err + } + + p.event.Events = syscall.EPOLLOUT + p.event.Fd = int32(fd) + if err = syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_ADD, fd, &p.event); err != nil { + p.Close() + return nil, err + } + + return p, nil +} + +func (p *Poller) Close() error { + return syscall.Close(p.epfd) +} + +func (p *Poller) WaitWrite(deadline time.Time) error { + msec := -1 + if !deadline.IsZero() { + d := deadline.Sub(time.Now()) + msec = int(d.Nanoseconds() / 1000000) // ms!? omg... + } + + n, err := syscall.EpollWait(p.epfd, p.events[:], msec) + if err != nil { + return err + } + if n < 1 { + return errTimeout + } + return nil +} diff --git a/Godeps/_workspace/src/github.com/jbenet/go-reuseport/poll/poll_unsupported.go b/Godeps/_workspace/src/github.com/jbenet/go-reuseport/poll/poll_unsupported.go new file mode 100644 index 000000000..bc0e5ed83 --- /dev/null +++ b/Godeps/_workspace/src/github.com/jbenet/go-reuseport/poll/poll_unsupported.go @@ -0,0 +1,11 @@ +// +build windows plan9 + +package poll + +import ( + "errors" +) + +func WaitWrite(fd int) error { + return errors.New("platform not supported") +} diff --git a/Godeps/_workspace/src/github.com/jbenet/go-reuseport/reuse_test.go b/Godeps/_workspace/src/github.com/jbenet/go-reuseport/reuse_test.go index 0dc77abba..1607077f4 100644 --- a/Godeps/_workspace/src/github.com/jbenet/go-reuseport/reuse_test.go +++ b/Godeps/_workspace/src/github.com/jbenet/go-reuseport/reuse_test.go @@ -2,11 +2,15 @@ package reuseport import ( "bytes" + "errors" + "fmt" "io" "net" "os" "strings" + "sync" "testing" + "time" ) func echo(c net.Conn) { @@ -226,7 +230,7 @@ func TestStreamListenDialSamePort(t *testing.T) { c1, err := Dial(network, l1.Addr().String(), l2.Addr().String()) if err != nil { - t.Fatal(err) + t.Fatal(err, network, l1.Addr().String(), l2.Addr().String()) continue } defer c1.Close() @@ -260,6 +264,120 @@ func TestStreamListenDialSamePort(t *testing.T) { } } +func TestStreamListenDialSamePortStressManyMsgs(t *testing.T) { + testCases := [][]string{ + []string{"tcp", "127.0.0.1:0"}, + []string{"tcp4", "127.0.0.1:0"}, + []string{"tcp6", "[::]:0"}, + } + + for _, tcase := range testCases { + subestStreamListenDialSamePortStress(t, tcase[0], tcase[1], 2, 1000) + } +} + +func TestStreamListenDialSamePortStressManyNodes(t *testing.T) { + testCases := [][]string{ + []string{"tcp", "127.0.0.1:0"}, + []string{"tcp4", "127.0.0.1:0"}, + []string{"tcp6", "[::]:0"}, + } + + for _, tcase := range testCases { + subestStreamListenDialSamePortStress(t, tcase[0], tcase[1], 50, 1) + } +} + +func TestStreamListenDialSamePortStressManyMsgsManyNodes(t *testing.T) { + testCases := [][]string{ + []string{"tcp", "127.0.0.1:0"}, + []string{"tcp4", "127.0.0.1:0"}, + []string{"tcp6", "[::]:0"}, + } + + for _, tcase := range testCases { + subestStreamListenDialSamePortStress(t, tcase[0], tcase[1], 50, 100) + } +} + +func subestStreamListenDialSamePortStress(t *testing.T, network, addr string, nodes int, msgs int) { + t.Logf("testing %s:%s %d nodes %d msgs", network, addr, nodes, msgs) + + var ls []net.Listener + for i := 0; i < nodes; i++ { + l, err := Listen(network, addr) + if err != nil { + t.Fatal(err) + } + defer l.Close() + go acceptAndEcho(l) + ls = append(ls, l) + t.Logf("listening %s", l.Addr()) + } + + // connect them all + var cs []net.Conn + for i := 0; i < nodes; i++ { + for j := 0; j < i; j++ { + if i == j { + continue // cannot do self. + } + + ia := ls[i].Addr().String() + ja := ls[j].Addr().String() + c, err := Dial(network, ia, ja) + if err != nil { + t.Fatal(network, ia, ja, err) + } + defer c.Close() + cs = append(cs, c) + t.Logf("dialed %s --> %s", c.LocalAddr(), c.RemoteAddr()) + } + } + + errs := make(chan error) + + send := func(c net.Conn, buf []byte) { + if _, err := c.Write(buf); err != nil { + errs <- err + } + } + + recv := func(c net.Conn, buf []byte) { + buf2 := make([]byte, len(buf)) + if _, err := c.Read(buf2); err != nil { + errs <- err + } + if !bytes.Equal(buf, buf2) { + errs <- fmt.Errorf("recv failure: %s <--> %s -- %s %s", c.RemoteAddr(), c.LocalAddr(), buf, buf2) + } + } + + t.Logf("sending %d msgs per conn", msgs) + go func() { + var wg sync.WaitGroup + for _, c := range cs { + wg.Add(1) + go func(c net.Conn) { + defer wg.Done() + for i := 0; i < msgs; i++ { + msg := []byte(fmt.Sprintf("message %d", i)) + send(c, msg) + recv(c, msg) + } + }(c) + } + wg.Wait() + close(errs) + }() + + for err := range errs { + if err != nil { + t.Error(err) + } + } +} + func TestPacketListenDialSamePort(t *testing.T) { any := [][]string{ @@ -343,6 +461,68 @@ func TestPacketListenDialSamePort(t *testing.T) { } } +func TestDialRespectsTimeout(t *testing.T) { + + testCases := [][]string{ + []string{"tcp", "127.0.0.1:6780", "1.2.3.4:6781"}, + []string{"tcp4", "127.0.0.1:6782", "1.2.3.4:6783"}, + []string{"tcp6", "[::1]:6784", "[::2]:6785"}, + } + + timeout := 50 * time.Millisecond + + for _, tcase := range testCases { + network := tcase[0] + laddr := tcase[1] + raddr := tcase[2] + + // l, err := Listen(network, raddr) + // if err != nil { + // t.Error("without a listener it wont work") + // continue + // } + // defer l.Close() + + nladdr, err := ResolveAddr(network, laddr) + if err != nil { + t.Error("failed to resolve addr", network, laddr, err) + continue + } + t.Log("testing", network, nladdr, raddr) + + d := Dialer{ + D: net.Dialer{ + LocalAddr: nil, + Timeout: timeout, + }, + } + + errs := make(chan error) + go func() { + c, err := d.Dial(network, raddr) + if err == nil { + errs <- errors.New("should've not connected") + c.Close() + return + } + close(errs) // success! + }() + + ErrDrain: + select { + case <-time.After(5 * time.Second): + t.Fatal("took too long") + case err, more := <-errs: + if !more { + break + } + t.Error(err) + goto ErrDrain + } + + } +} + func TestUnixNotSupported(t *testing.T) { testCases := [][]string{