Skip to content

Commit

Permalink
net: add ListenConfig, Dialer.Control to permit socket opts before li…
Browse files Browse the repository at this point in the history
…sten/dial

Existing implementation does not provide a way to set options such as
SO_REUSEPORT, that has to be set prior the socket being bound.

New exposed API:
pkg net, method (*ListenConfig) Listen(context.Context, string, string) (Listener, error)
pkg net, method (*ListenConfig) ListenPacket(context.Context, string, string) (PacketConn, error)
pkg net, type ListenConfig struct
pkg net, type ListenConfig struct, Control func(string, string, syscall.RawConn) error
pkg net, type Dialer struct, Control func(string, string, syscall.RawConn) error

Fixes #9661

Change-Id: If4d275711f823df72d3ac5cc3858651a6a57cccb
Reviewed-on: https://go-review.googlesource.com/72810
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
  • Loading branch information
AudriusButkevicius authored and ianlancetaylor committed May 30, 2018
1 parent cc6e568 commit 3c4d3bd
Show file tree
Hide file tree
Showing 12 changed files with 350 additions and 74 deletions.
123 changes: 87 additions & 36 deletions src/net/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"context"
"internal/nettrace"
"internal/poll"
"syscall"
"time"
)

Expand Down Expand Up @@ -70,6 +71,14 @@ type Dialer struct {
//
// Deprecated: Use DialContext instead.
Cancel <-chan struct{}

// If Control is not nil, it is called after creating the network
// connection but before actually dialing.
//
// Network and address parameters passed to Control method are not
// necessarily the ones passed to Dial. For example, passing "tcp" to Dial
// will cause the Control function to be called with "tcp4" or "tcp6".
Control func(network, address string, c syscall.RawConn) error
}

func minNonzeroTime(a, b time.Time) time.Time {
Expand Down Expand Up @@ -563,8 +572,82 @@ func (sd *sysDialer) dialSingle(ctx context.Context, ra Addr) (c Conn, err error
return c, nil
}

// ListenConfig contains options for listening to an address.
type ListenConfig struct {
// If Control is not nil, it is called after creating the network
// connection but before binding it to the operating system.
//
// Network and address parameters passed to Control method are not
// necessarily the ones passed to Listen. For example, passing "tcp" to
// Listen will cause the Control function to be called with "tcp4" or "tcp6".
Control func(network, address string, c syscall.RawConn) error
}

// Listen announces on the local network address.
//
// See func Listen for a description of the network and address
// parameters.
func (lc *ListenConfig) Listen(ctx context.Context, network, address string) (Listener, error) {
addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil)
if err != nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
}
sl := &sysListener{
ListenConfig: *lc,
network: network,
address: address,
}
var l Listener
la := addrs.first(isIPv4)
switch la := la.(type) {
case *TCPAddr:
l, err = sl.listenTCP(ctx, la)
case *UnixAddr:
l, err = sl.listenUnix(ctx, la)
default:
return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
}
if err != nil {
return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: err} // l is non-nil interface containing nil pointer
}
return l, nil
}

// ListenPacket announces on the local network address.
//
// See func ListenPacket for a description of the network and address
// parameters.
func (lc *ListenConfig) ListenPacket(ctx context.Context, network, address string) (PacketConn, error) {
addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil)
if err != nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
}
sl := &sysListener{
ListenConfig: *lc,
network: network,
address: address,
}
var c PacketConn
la := addrs.first(isIPv4)
switch la := la.(type) {
case *UDPAddr:
c, err = sl.listenUDP(ctx, la)
case *IPAddr:
c, err = sl.listenIP(ctx, la)
case *UnixAddr:
c, err = sl.listenUnixgram(ctx, la)
default:
return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
}
if err != nil {
return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: err} // c is non-nil interface containing nil pointer
}
return c, nil
}

// sysListener contains a Listen's parameters and configuration.
type sysListener struct {
ListenConfig
network, address string
}

Expand All @@ -587,23 +670,8 @@ type sysListener struct {
// See func Dial for a description of the network and address
// parameters.
func Listen(network, address string) (Listener, error) {
addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", network, address, nil)
if err != nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
}
var l Listener
switch la := addrs.first(isIPv4).(type) {
case *TCPAddr:
l, err = ListenTCP(network, la)
case *UnixAddr:
l, err = ListenUnix(network, la)
default:
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
}
if err != nil {
return nil, err // l is non-nil interface containing nil pointer
}
return l, nil
var lc ListenConfig
return lc.Listen(context.Background(), network, address)
}

// ListenPacket announces on the local network address.
Expand All @@ -629,23 +697,6 @@ func Listen(network, address string) (Listener, error) {
// See func Dial for a description of the network and address
// parameters.
func ListenPacket(network, address string) (PacketConn, error) {
addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", network, address, nil)
if err != nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
}
var l PacketConn
switch la := addrs.first(isIPv4).(type) {
case *UDPAddr:
l, err = ListenUDP(network, la)
case *IPAddr:
l, err = ListenIP(network, la)
case *UnixAddr:
l, err = ListenUnixgram(network, la)
default:
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
}
if err != nil {
return nil, err // l is non-nil interface containing nil pointer
}
return l, nil
var lc ListenConfig
return lc.ListenPacket(context.Background(), network, address)
}
51 changes: 51 additions & 0 deletions src/net/dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,57 @@ func TestDialListenerAddr(t *testing.T) {
c.Close()
}

func TestDialerControl(t *testing.T) {
switch runtime.GOOS {
case "nacl", "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
}

t.Run("StreamDial", func(t *testing.T) {
for _, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} {
if !testableNetwork(network) {
continue
}
ln, err := newLocalListener(network)
if err != nil {
t.Error(err)
continue
}
defer ln.Close()
d := Dialer{Control: controlOnConnSetup}
c, err := d.Dial(network, ln.Addr().String())
if err != nil {
t.Error(err)
continue
}
c.Close()
}
})
t.Run("PacketDial", func(t *testing.T) {
for _, network := range []string{"udp", "udp4", "udp6", "unixgram"} {
if !testableNetwork(network) {
continue
}
c1, err := newLocalPacketListener(network)
if err != nil {
t.Error(err)
continue
}
if network == "unixgram" {
defer os.Remove(c1.LocalAddr().String())
}
defer c1.Close()
d := Dialer{Control: controlOnConnSetup}
c2, err := d.Dial(network, c1.LocalAddr().String())
if err != nil {
t.Error(err)
continue
}
c2.Close()
}
})
}

// mustHaveExternalNetwork is like testenv.MustHaveExternalNetwork
// except that it won't skip testing on non-iOS builders.
func mustHaveExternalNetwork(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions src/net/iprawsock_posix.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func (sd *sysDialer) dialIP(ctx context.Context, laddr, raddr *IPAddr) (*IPConn,
default:
return nil, UnknownNetworkError(sd.network)
}
fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial")
fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial", sd.Dialer.Control)
if err != nil {
return nil, err
}
Expand All @@ -139,7 +139,7 @@ func (sl *sysListener) listenIP(ctx context.Context, laddr *IPAddr) (*IPConn, er
default:
return nil, UnknownNetworkError(sl.network)
}
fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen")
fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen", sl.ListenConfig.Control)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions src/net/ipsock_posix.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,12 @@ func favoriteAddrFamily(network string, laddr, raddr sockaddr, mode string) (fam
return syscall.AF_INET6, false
}

func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string) (fd *netFD, err error) {
func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string, ctrlFn func(string, string, syscall.RawConn) error) (fd *netFD, err error) {
if (runtime.GOOS == "windows" || runtime.GOOS == "openbsd" || runtime.GOOS == "nacl") && mode == "dial" && raddr.isWildcard() {
raddr = raddr.toLocal(net)
}
family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode)
return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr)
return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr, ctrlFn)
}

func ipToSockaddr(family int, ip IP, port int, zone string) (syscall.Sockaddr, error) {
Expand Down
54 changes: 54 additions & 0 deletions src/net/listen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package net

import (
"context"
"fmt"
"internal/testenv"
"os"
Expand Down Expand Up @@ -729,3 +730,56 @@ func TestClosingListener(t *testing.T) {
}
ln2.Close()
}

func TestListenConfigControl(t *testing.T) {
switch runtime.GOOS {
case "nacl", "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
}

t.Run("StreamListen", func(t *testing.T) {
for _, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} {
if !testableNetwork(network) {
continue
}
ln, err := newLocalListener(network)
if err != nil {
t.Error(err)
continue
}
address := ln.Addr().String()
ln.Close()
lc := ListenConfig{Control: controlOnConnSetup}
ln, err = lc.Listen(context.Background(), network, address)
if err != nil {
t.Error(err)
continue
}
ln.Close()
}
})
t.Run("PacketListen", func(t *testing.T) {
for _, network := range []string{"udp", "udp4", "udp6", "unixgram"} {
if !testableNetwork(network) {
continue
}
c, err := newLocalPacketListener(network)
if err != nil {
t.Error(err)
continue
}
address := c.LocalAddr().String()
c.Close()
if network == "unixgram" {
os.Remove(address)
}
lc := ListenConfig{Control: controlOnConnSetup}
c, err = lc.ListenPacket(context.Background(), network, address)
if err != nil {
t.Error(err)
continue
}
c.Close()
}
})
}
4 changes: 4 additions & 0 deletions src/net/rawconn_stub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@ func writeRawConn(c syscall.RawConn, b []byte) error {
func controlRawConn(c syscall.RawConn, addr Addr) error {
return errors.New("not supported")
}

func controlOnConnSetup(network string, address string, c syscall.RawConn) error {
return nil
}
38 changes: 37 additions & 1 deletion src/net/rawconn_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@

package net

import "syscall"
import (
"errors"
"syscall"
)

func readRawConn(c syscall.RawConn, b []byte) (int, error) {
var operr error
Expand Down Expand Up @@ -89,3 +92,36 @@ func controlRawConn(c syscall.RawConn, addr Addr) error {
}
return nil
}

func controlOnConnSetup(network string, address string, c syscall.RawConn) error {
var operr error
var fn func(uintptr)
switch network {
case "tcp", "udp", "ip":
return errors.New("ambiguous network: " + network)
case "unix", "unixpacket", "unixgram":
fn = func(s uintptr) {
_, operr = syscall.GetsockoptInt(int(s), syscall.SOL_SOCKET, syscall.SO_ERROR)
}
default:
switch network[len(network)-1] {
case '4':
fn = func(s uintptr) {
operr = syscall.SetsockoptInt(int(s), syscall.IPPROTO_IP, syscall.IP_TTL, 1)
}
case '6':
fn = func(s uintptr) {
operr = syscall.SetsockoptInt(int(s), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, 1)
}
default:
return errors.New("unknown network: " + network)
}
}
if err := c.Control(fn); err != nil {
return err
}
if operr != nil {
return operr
}
return nil
}
Loading

0 comments on commit 3c4d3bd

Please sign in to comment.