Skip to content

Commit

Permalink
net: move dial and listen functions under sysDialer, sysListener
Browse files Browse the repository at this point in the history
Updates #9661

Change-Id: I237e7502cb9faad6dece1e25b1a503739c54d826
Reviewed-on: https://go-review.googlesource.com/115175
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
  • Loading branch information
AudriusButkevicius authored and bradfitz committed May 29, 2018
1 parent a4330ed commit c6295e7
Show file tree
Hide file tree
Showing 14 changed files with 111 additions and 91 deletions.
53 changes: 29 additions & 24 deletions src/net/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,8 @@ func DialTimeout(network, address string, timeout time.Duration) (Conn, error) {
return d.Dial(network, address)
}

// dialParam contains a Dial's parameters and configuration.
type dialParam struct {
// sysDialer contains a Dial's parameters and configuration.
type sysDialer struct {
Dialer
network, address string
}
Expand Down Expand Up @@ -377,7 +377,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn
return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
}

dp := &dialParam{
sd := &sysDialer{
Dialer: *d,
network: network,
address: address,
Expand All @@ -392,9 +392,9 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn

var c Conn
if len(fallbacks) > 0 {
c, err = dialParallel(ctx, dp, primaries, fallbacks)
c, err = sd.dialParallel(ctx, primaries, fallbacks)
} else {
c, err = dialSerial(ctx, dp, primaries)
c, err = sd.dialSerial(ctx, primaries)
}
if err != nil {
return nil, err
Expand All @@ -412,9 +412,9 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn
// head start. It returns the first established connection and
// closes the others. Otherwise it returns an error from the first
// primary address.
func dialParallel(ctx context.Context, dp *dialParam, primaries, fallbacks addrList) (Conn, error) {
func (sd *sysDialer) dialParallel(ctx context.Context, primaries, fallbacks addrList) (Conn, error) {
if len(fallbacks) == 0 {
return dialSerial(ctx, dp, primaries)
return sd.dialSerial(ctx, primaries)
}

returned := make(chan struct{})
Expand All @@ -433,7 +433,7 @@ func dialParallel(ctx context.Context, dp *dialParam, primaries, fallbacks addrL
if !primary {
ras = fallbacks
}
c, err := dialSerial(ctx, dp, ras)
c, err := sd.dialSerial(ctx, ras)
select {
case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
case <-returned:
Expand All @@ -451,7 +451,7 @@ func dialParallel(ctx context.Context, dp *dialParam, primaries, fallbacks addrL
go startRacer(primaryCtx, true)

// Start the timer for the fallback racer.
fallbackTimer := time.NewTimer(dp.fallbackDelay())
fallbackTimer := time.NewTimer(sd.fallbackDelay())
defer fallbackTimer.Stop()

for {
Expand Down Expand Up @@ -486,13 +486,13 @@ func dialParallel(ctx context.Context, dp *dialParam, primaries, fallbacks addrL

// dialSerial connects to a list of addresses in sequence, returning
// either the first successful connection, or the first error.
func dialSerial(ctx context.Context, dp *dialParam, ras addrList) (Conn, error) {
func (sd *sysDialer) dialSerial(ctx context.Context, ras addrList) (Conn, error) {
var firstErr error // The error from the first address is most relevant.

for i, ra := range ras {
select {
case <-ctx.Done():
return nil, &OpError{Op: "dial", Net: dp.network, Source: dp.LocalAddr, Addr: ra, Err: mapErr(ctx.Err())}
return nil, &OpError{Op: "dial", Net: sd.network, Source: sd.LocalAddr, Addr: ra, Err: mapErr(ctx.Err())}
default:
}

Expand All @@ -501,7 +501,7 @@ func dialSerial(ctx context.Context, dp *dialParam, ras addrList) (Conn, error)
if err != nil {
// Ran out of time.
if firstErr == nil {
firstErr = &OpError{Op: "dial", Net: dp.network, Source: dp.LocalAddr, Addr: ra, Err: err}
firstErr = &OpError{Op: "dial", Net: sd.network, Source: sd.LocalAddr, Addr: ra, Err: err}
}
break
}
Expand All @@ -512,7 +512,7 @@ func dialSerial(ctx context.Context, dp *dialParam, ras addrList) (Conn, error)
defer cancel()
}

c, err := dialSingle(dialCtx, dp, ra)
c, err := sd.dialSingle(dialCtx, ra)
if err == nil {
return c, nil
}
Expand All @@ -522,47 +522,52 @@ func dialSerial(ctx context.Context, dp *dialParam, ras addrList) (Conn, error)
}

if firstErr == nil {
firstErr = &OpError{Op: "dial", Net: dp.network, Source: nil, Addr: nil, Err: errMissingAddress}
firstErr = &OpError{Op: "dial", Net: sd.network, Source: nil, Addr: nil, Err: errMissingAddress}
}
return nil, firstErr
}

// dialSingle attempts to establish and returns a single connection to
// the destination address.
func dialSingle(ctx context.Context, dp *dialParam, ra Addr) (c Conn, err error) {
func (sd *sysDialer) dialSingle(ctx context.Context, ra Addr) (c Conn, err error) {
trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace)
if trace != nil {
raStr := ra.String()
if trace.ConnectStart != nil {
trace.ConnectStart(dp.network, raStr)
trace.ConnectStart(sd.network, raStr)
}
if trace.ConnectDone != nil {
defer func() { trace.ConnectDone(dp.network, raStr, err) }()
defer func() { trace.ConnectDone(sd.network, raStr, err) }()
}
}
la := dp.LocalAddr
la := sd.LocalAddr
switch ra := ra.(type) {
case *TCPAddr:
la, _ := la.(*TCPAddr)
c, err = dialTCP(ctx, dp.network, la, ra)
c, err = sd.dialTCP(ctx, la, ra)
case *UDPAddr:
la, _ := la.(*UDPAddr)
c, err = dialUDP(ctx, dp.network, la, ra)
c, err = sd.dialUDP(ctx, la, ra)
case *IPAddr:
la, _ := la.(*IPAddr)
c, err = dialIP(ctx, dp.network, la, ra)
c, err = sd.dialIP(ctx, la, ra)
case *UnixAddr:
la, _ := la.(*UnixAddr)
c, err = dialUnix(ctx, dp.network, la, ra)
c, err = sd.dialUnix(ctx, la, ra)
default:
return nil, &OpError{Op: "dial", Net: dp.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: dp.address}}
return nil, &OpError{Op: "dial", Net: sd.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: sd.address}}
}
if err != nil {
return nil, &OpError{Op: "dial", Net: dp.network, Source: la, Addr: ra, Err: err} // c is non-nil interface containing nil pointer
return nil, &OpError{Op: "dial", Net: sd.network, Source: la, Addr: ra, Err: err} // c is non-nil interface containing nil pointer
}
return c, nil
}

// sysListener contains a Listen's parameters and configuration.
type sysListener struct {
network, address string
}

// Listen announces on the local network address.
//
// The network must be "tcp", "tcp4", "tcp6", "unix" or "unixpacket".
Expand Down
18 changes: 10 additions & 8 deletions src/net/dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,9 @@ const (
// In some environments, the slow IPs may be explicitly unreachable, and fail
// more quickly than expected. This test hook prevents dialTCP from returning
// before the deadline.
func slowDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
c, err := doDialTCP(ctx, net, laddr, raddr)
func slowDialTCP(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
sd := &sysDialer{network: network, address: raddr.String()}
c, err := sd.doDialTCP(ctx, laddr, raddr)
if ParseIP(slowDst4).Equal(raddr.IP) || ParseIP(slowDst6).Equal(raddr.IP) {
// Wait for the deadline, or indefinitely if none exists.
<-ctx.Done()
Expand Down Expand Up @@ -295,12 +296,12 @@ func TestDialParallel(t *testing.T) {
FallbackDelay: fallbackDelay,
}
startTime := time.Now()
dp := &dialParam{
sd := &sysDialer{
Dialer: d,
network: "tcp",
address: "?",
}
c, err := dialParallel(context.Background(), dp, primaries, fallbacks)
c, err := sd.dialParallel(context.Background(), primaries, fallbacks)
elapsed := time.Since(startTime)

if c != nil {
Expand Down Expand Up @@ -331,7 +332,7 @@ func TestDialParallel(t *testing.T) {
wg.Done()
}()
startTime = time.Now()
c, err = dialParallel(ctx, dp, primaries, fallbacks)
c, err = sd.dialParallel(ctx, primaries, fallbacks)
if c != nil {
c.Close()
}
Expand Down Expand Up @@ -467,13 +468,14 @@ func TestDialParallelSpuriousConnection(t *testing.T) {
// Now ignore the provided context (which will be canceled) and use a
// different one to make sure this completes with a valid connection,
// which we hope to be closed below:
return doDialTCP(context.Background(), net, laddr, raddr)
sd := &sysDialer{network: net, address: raddr.String()}
return sd.doDialTCP(context.Background(), laddr, raddr)
}

d := Dialer{
FallbackDelay: fallbackDelay,
}
dp := &dialParam{
sd := &sysDialer{
Dialer: d,
network: "tcp",
address: "?",
Expand All @@ -488,7 +490,7 @@ func TestDialParallelSpuriousConnection(t *testing.T) {
}

// dialParallel returns one connection (and closes the other.)
c, err := dialParallel(context.Background(), dp, makeAddr("127.0.0.1"), makeAddr("::1"))
c, err := sd.dialParallel(context.Background(), makeAddr("127.0.0.1"), makeAddr("::1"))
if err != nil {
t.Fatal(err)
}
Expand Down
12 changes: 10 additions & 2 deletions src/net/iprawsock.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,11 @@ func newIPConn(fd *netFD) *IPConn { return &IPConn{conn{fd}} }
// If the IP field of raddr is nil or an unspecified IP address, the
// local system is assumed.
func DialIP(network string, laddr, raddr *IPAddr) (*IPConn, error) {
c, err := dialIP(context.Background(), network, laddr, raddr)
if raddr == nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
}
sd := &sysDialer{network: network, address: raddr.String()}
c, err := sd.dialIP(context.Background(), laddr, raddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
}
Expand All @@ -224,7 +228,11 @@ func DialIP(network string, laddr, raddr *IPAddr) (*IPConn, error) {
// ListenIP listens on all available IP addresses of the local system
// except multicast IP addresses.
func ListenIP(network string, laddr *IPAddr) (*IPConn, error) {
c, err := listenIP(context.Background(), network, laddr)
if laddr == nil {
laddr = &IPAddr{}
}
sl := &sysListener{network: network, address: laddr.String()}
c, err := sl.listenIP(context.Background(), laddr)
if err != nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: err}
}
Expand Down
4 changes: 2 additions & 2 deletions src/net/iprawsock_plan9.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ func (c *IPConn) writeMsg(b, oob []byte, addr *IPAddr) (n, oobn int, err error)
return 0, 0, syscall.EPLAN9
}

func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn, error) {
func (sd *sysDialer) dialIP(ctx context.Context, laddr, raddr *IPAddr) (*IPConn, error) {
return nil, syscall.EPLAN9
}

func listenIP(ctx context.Context, netProto string, laddr *IPAddr) (*IPConn, error) {
func (sl *sysListener) listenIP(ctx context.Context, laddr *IPAddr) (*IPConn, error) {
return nil, syscall.EPLAN9
}
15 changes: 6 additions & 9 deletions src/net/iprawsock_posix.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,15 @@ func (c *IPConn) writeMsg(b, oob []byte, addr *IPAddr) (n, oobn int, err error)
return c.fd.writeMsg(b, oob, sa)
}

func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn, error) {
network, proto, err := parseNetwork(ctx, netProto, true)
func (sd *sysDialer) dialIP(ctx context.Context, laddr, raddr *IPAddr) (*IPConn, error) {
network, proto, err := parseNetwork(ctx, sd.network, true)
if err != nil {
return nil, err
}
switch network {
case "ip", "ip4", "ip6":
default:
return nil, UnknownNetworkError(netProto)
}
if raddr == nil {
return nil, errMissingAddress
return nil, UnknownNetworkError(sd.network)
}
fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial")
if err != nil {
Expand All @@ -132,15 +129,15 @@ func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn
return newIPConn(fd), nil
}

func listenIP(ctx context.Context, netProto string, laddr *IPAddr) (*IPConn, error) {
network, proto, err := parseNetwork(ctx, netProto, true)
func (sl *sysListener) listenIP(ctx context.Context, laddr *IPAddr) (*IPConn, error) {
network, proto, err := parseNetwork(ctx, sl.network, true)
if err != nil {
return nil, err
}
switch network {
case "ip", "ip4", "ip6":
default:
return nil, UnknownNetworkError(netProto)
return nil, UnknownNetworkError(sl.network)
}
fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen")
if err != nil {
Expand Down
6 changes: 4 additions & 2 deletions src/net/tcpsock.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,8 @@ func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
if raddr == nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
}
c, err := dialTCP(context.Background(), network, laddr, raddr)
sd := &sysDialer{network: network, address: raddr.String()}
c, err := sd.dialTCP(context.Background(), laddr, raddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
}
Expand Down Expand Up @@ -328,7 +329,8 @@ func ListenTCP(network string, laddr *TCPAddr) (*TCPListener, error) {
if laddr == nil {
laddr = &TCPAddr{}
}
ln, err := listenTCP(context.Background(), network, laddr)
sl := &sysListener{network: network, address: laddr.String()}
ln, err := sl.listenTCP(context.Background(), laddr)
if err != nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: err}
}
Expand Down
18 changes: 9 additions & 9 deletions src/net/tcpsock_plan9.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,23 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
return genericReadFrom(c, r)
}

func dialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
if testHookDialTCP != nil {
return testHookDialTCP(ctx, net, laddr, raddr)
return testHookDialTCP(ctx, sd.network, laddr, raddr)
}
return doDialTCP(ctx, net, laddr, raddr)
return sd.doDialTCP(ctx, laddr, raddr)
}

func doDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
switch net {
func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
switch sd.network {
case "tcp", "tcp4", "tcp6":
default:
return nil, UnknownNetworkError(net)
return nil, UnknownNetworkError(sd.network)
}
if raddr == nil {
return nil, errMissingAddress
}
fd, err := dialPlan9(ctx, net, laddr, raddr)
fd, err := dialPlan9(ctx, sd.network, laddr, raddr)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -69,8 +69,8 @@ func (ln *TCPListener) file() (*os.File, error) {
return f, nil
}

func listenTCP(ctx context.Context, network string, laddr *TCPAddr) (*TCPListener, error) {
fd, err := listenPlan9(ctx, network, laddr)
func (sl *sysListener) listenTCP(ctx context.Context, laddr *TCPAddr) (*TCPListener, error) {
fd, err := listenPlan9(ctx, sl.network, laddr)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit c6295e7

Please sign in to comment.