diff --git a/src/net/dial.go b/src/net/dial.go index 3ea049ca460e7..b1a5ca7cd53f5 100644 --- a/src/net/dial.go +++ b/src/net/dial.go @@ -8,6 +8,7 @@ import ( "context" "internal/nettrace" "internal/poll" + "syscall" "time" ) @@ -70,6 +71,14 @@ type Dialer struct { // // Deprecated: Use DialContext instead. Cancel <-chan struct{} + + // If Control is not nil, it is called after creating the network + // connection but before actually dialing. + // + // Network and address parameters passed to Control method are not + // necessarily the ones passed to Dial. For example, passing "tcp" to Dial + // will cause the Control function to be called with "tcp4" or "tcp6". + Control func(network, address string, c syscall.RawConn) error } func minNonzeroTime(a, b time.Time) time.Time { @@ -563,8 +572,82 @@ func (sd *sysDialer) dialSingle(ctx context.Context, ra Addr) (c Conn, err error return c, nil } +// ListenConfig contains options for listening to an address. +type ListenConfig struct { + // If Control is not nil, it is called after creating the network + // connection but before binding it to the operating system. + // + // Network and address parameters passed to Control method are not + // necessarily the ones passed to Listen. For example, passing "tcp" to + // Listen will cause the Control function to be called with "tcp4" or "tcp6". + Control func(network, address string, c syscall.RawConn) error +} + +// Listen announces on the local network address. +// +// See func Listen for a description of the network and address +// parameters. +func (lc *ListenConfig) Listen(ctx context.Context, network, address string) (Listener, error) { + addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil) + if err != nil { + return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err} + } + sl := &sysListener{ + ListenConfig: *lc, + network: network, + address: address, + } + var l Listener + la := addrs.first(isIPv4) + switch la := la.(type) { + case *TCPAddr: + l, err = sl.listenTCP(ctx, la) + case *UnixAddr: + l, err = sl.listenUnix(ctx, la) + default: + return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}} + } + if err != nil { + return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: err} // l is non-nil interface containing nil pointer + } + return l, nil +} + +// ListenPacket announces on the local network address. +// +// See func ListenPacket for a description of the network and address +// parameters. +func (lc *ListenConfig) ListenPacket(ctx context.Context, network, address string) (PacketConn, error) { + addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil) + if err != nil { + return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err} + } + sl := &sysListener{ + ListenConfig: *lc, + network: network, + address: address, + } + var c PacketConn + la := addrs.first(isIPv4) + switch la := la.(type) { + case *UDPAddr: + c, err = sl.listenUDP(ctx, la) + case *IPAddr: + c, err = sl.listenIP(ctx, la) + case *UnixAddr: + c, err = sl.listenUnixgram(ctx, la) + default: + return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}} + } + if err != nil { + return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: err} // c is non-nil interface containing nil pointer + } + return c, nil +} + // sysListener contains a Listen's parameters and configuration. type sysListener struct { + ListenConfig network, address string } @@ -587,23 +670,8 @@ type sysListener struct { // See func Dial for a description of the network and address // parameters. func Listen(network, address string) (Listener, error) { - addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", network, address, nil) - if err != nil { - return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err} - } - var l Listener - switch la := addrs.first(isIPv4).(type) { - case *TCPAddr: - l, err = ListenTCP(network, la) - case *UnixAddr: - l, err = ListenUnix(network, la) - default: - return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}} - } - if err != nil { - return nil, err // l is non-nil interface containing nil pointer - } - return l, nil + var lc ListenConfig + return lc.Listen(context.Background(), network, address) } // ListenPacket announces on the local network address. @@ -629,23 +697,6 @@ func Listen(network, address string) (Listener, error) { // See func Dial for a description of the network and address // parameters. func ListenPacket(network, address string) (PacketConn, error) { - addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", network, address, nil) - if err != nil { - return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err} - } - var l PacketConn - switch la := addrs.first(isIPv4).(type) { - case *UDPAddr: - l, err = ListenUDP(network, la) - case *IPAddr: - l, err = ListenIP(network, la) - case *UnixAddr: - l, err = ListenUnixgram(network, la) - default: - return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}} - } - if err != nil { - return nil, err // l is non-nil interface containing nil pointer - } - return l, nil + var lc ListenConfig + return lc.ListenPacket(context.Background(), network, address) } diff --git a/src/net/dial_test.go b/src/net/dial_test.go index 96d8921ec8525..3934ad8648360 100644 --- a/src/net/dial_test.go +++ b/src/net/dial_test.go @@ -912,6 +912,57 @@ func TestDialListenerAddr(t *testing.T) { c.Close() } +func TestDialerControl(t *testing.T) { + switch runtime.GOOS { + case "nacl", "plan9": + t.Skipf("not supported on %s", runtime.GOOS) + } + + t.Run("StreamDial", func(t *testing.T) { + for _, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} { + if !testableNetwork(network) { + continue + } + ln, err := newLocalListener(network) + if err != nil { + t.Error(err) + continue + } + defer ln.Close() + d := Dialer{Control: controlOnConnSetup} + c, err := d.Dial(network, ln.Addr().String()) + if err != nil { + t.Error(err) + continue + } + c.Close() + } + }) + t.Run("PacketDial", func(t *testing.T) { + for _, network := range []string{"udp", "udp4", "udp6", "unixgram"} { + if !testableNetwork(network) { + continue + } + c1, err := newLocalPacketListener(network) + if err != nil { + t.Error(err) + continue + } + if network == "unixgram" { + defer os.Remove(c1.LocalAddr().String()) + } + defer c1.Close() + d := Dialer{Control: controlOnConnSetup} + c2, err := d.Dial(network, c1.LocalAddr().String()) + if err != nil { + t.Error(err) + continue + } + c2.Close() + } + }) +} + // mustHaveExternalNetwork is like testenv.MustHaveExternalNetwork // except that it won't skip testing on non-iOS builders. func mustHaveExternalNetwork(t *testing.T) { diff --git a/src/net/iprawsock_posix.go b/src/net/iprawsock_posix.go index 7dafd20bf6861..b2f57916433eb 100644 --- a/src/net/iprawsock_posix.go +++ b/src/net/iprawsock_posix.go @@ -122,7 +122,7 @@ func (sd *sysDialer) dialIP(ctx context.Context, laddr, raddr *IPAddr) (*IPConn, default: return nil, UnknownNetworkError(sd.network) } - fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial") + fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial", sd.Dialer.Control) if err != nil { return nil, err } @@ -139,7 +139,7 @@ func (sl *sysListener) listenIP(ctx context.Context, laddr *IPAddr) (*IPConn, er default: return nil, UnknownNetworkError(sl.network) } - fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen") + fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen", sl.ListenConfig.Control) if err != nil { return nil, err } diff --git a/src/net/ipsock_posix.go b/src/net/ipsock_posix.go index 8372aaa7423cf..eddd4118fa231 100644 --- a/src/net/ipsock_posix.go +++ b/src/net/ipsock_posix.go @@ -133,12 +133,12 @@ func favoriteAddrFamily(network string, laddr, raddr sockaddr, mode string) (fam return syscall.AF_INET6, false } -func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string) (fd *netFD, err error) { +func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string, ctrlFn func(string, string, syscall.RawConn) error) (fd *netFD, err error) { if (runtime.GOOS == "windows" || runtime.GOOS == "openbsd" || runtime.GOOS == "nacl") && mode == "dial" && raddr.isWildcard() { raddr = raddr.toLocal(net) } family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode) - return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr) + return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr, ctrlFn) } func ipToSockaddr(family int, ip IP, port int, zone string) (syscall.Sockaddr, error) { diff --git a/src/net/listen_test.go b/src/net/listen_test.go index 96624f98ce53f..ffd38d79506d6 100644 --- a/src/net/listen_test.go +++ b/src/net/listen_test.go @@ -7,6 +7,7 @@ package net import ( + "context" "fmt" "internal/testenv" "os" @@ -729,3 +730,56 @@ func TestClosingListener(t *testing.T) { } ln2.Close() } + +func TestListenConfigControl(t *testing.T) { + switch runtime.GOOS { + case "nacl", "plan9": + t.Skipf("not supported on %s", runtime.GOOS) + } + + t.Run("StreamListen", func(t *testing.T) { + for _, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} { + if !testableNetwork(network) { + continue + } + ln, err := newLocalListener(network) + if err != nil { + t.Error(err) + continue + } + address := ln.Addr().String() + ln.Close() + lc := ListenConfig{Control: controlOnConnSetup} + ln, err = lc.Listen(context.Background(), network, address) + if err != nil { + t.Error(err) + continue + } + ln.Close() + } + }) + t.Run("PacketListen", func(t *testing.T) { + for _, network := range []string{"udp", "udp4", "udp6", "unixgram"} { + if !testableNetwork(network) { + continue + } + c, err := newLocalPacketListener(network) + if err != nil { + t.Error(err) + continue + } + address := c.LocalAddr().String() + c.Close() + if network == "unixgram" { + os.Remove(address) + } + lc := ListenConfig{Control: controlOnConnSetup} + c, err = lc.ListenPacket(context.Background(), network, address) + if err != nil { + t.Error(err) + continue + } + c.Close() + } + }) +} diff --git a/src/net/rawconn_stub_test.go b/src/net/rawconn_stub_test.go index 391b4d188e2e2..3e3b6bf5b2a1a 100644 --- a/src/net/rawconn_stub_test.go +++ b/src/net/rawconn_stub_test.go @@ -22,3 +22,7 @@ func writeRawConn(c syscall.RawConn, b []byte) error { func controlRawConn(c syscall.RawConn, addr Addr) error { return errors.New("not supported") } + +func controlOnConnSetup(network string, address string, c syscall.RawConn) error { + return nil +} diff --git a/src/net/rawconn_unix_test.go b/src/net/rawconn_unix_test.go index 2fe4d2c6bace6..a720a8a4a3e27 100644 --- a/src/net/rawconn_unix_test.go +++ b/src/net/rawconn_unix_test.go @@ -6,7 +6,10 @@ package net -import "syscall" +import ( + "errors" + "syscall" +) func readRawConn(c syscall.RawConn, b []byte) (int, error) { var operr error @@ -89,3 +92,36 @@ func controlRawConn(c syscall.RawConn, addr Addr) error { } return nil } + +func controlOnConnSetup(network string, address string, c syscall.RawConn) error { + var operr error + var fn func(uintptr) + switch network { + case "tcp", "udp", "ip": + return errors.New("ambiguous network: " + network) + case "unix", "unixpacket", "unixgram": + fn = func(s uintptr) { + _, operr = syscall.GetsockoptInt(int(s), syscall.SOL_SOCKET, syscall.SO_ERROR) + } + default: + switch network[len(network)-1] { + case '4': + fn = func(s uintptr) { + operr = syscall.SetsockoptInt(int(s), syscall.IPPROTO_IP, syscall.IP_TTL, 1) + } + case '6': + fn = func(s uintptr) { + operr = syscall.SetsockoptInt(int(s), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, 1) + } + default: + return errors.New("unknown network: " + network) + } + } + if err := c.Control(fn); err != nil { + return err + } + if operr != nil { + return operr + } + return nil +} diff --git a/src/net/rawconn_windows_test.go b/src/net/rawconn_windows_test.go index 6df101e9de4b2..2774c97e5c82e 100644 --- a/src/net/rawconn_windows_test.go +++ b/src/net/rawconn_windows_test.go @@ -5,6 +5,7 @@ package net import ( + "errors" "syscall" "unsafe" ) @@ -96,3 +97,32 @@ func controlRawConn(c syscall.RawConn, addr Addr) error { } return nil } + +func controlOnConnSetup(network string, address string, c syscall.RawConn) error { + var operr error + var fn func(uintptr) + switch network { + case "tcp", "udp", "ip": + return errors.New("ambiguous network: " + network) + default: + switch network[len(network)-1] { + case '4': + fn = func(s uintptr) { + operr = syscall.SetsockoptInt(syscall.Handle(s), syscall.IPPROTO_IP, syscall.IP_TTL, 1) + } + case '6': + fn = func(s uintptr) { + operr = syscall.SetsockoptInt(syscall.Handle(s), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, 1) + } + default: + return errors.New("unknown network: " + network) + } + } + if err := c.Control(fn); err != nil { + return err + } + if operr != nil { + return operr + } + return nil +} diff --git a/src/net/sock_posix.go b/src/net/sock_posix.go index 8cfc42eb7e66f..00ff3fd39394c 100644 --- a/src/net/sock_posix.go +++ b/src/net/sock_posix.go @@ -38,7 +38,7 @@ type sockaddr interface { // socket returns a network file descriptor that is ready for // asynchronous I/O using the network poller. -func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr) (fd *netFD, err error) { +func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) (fd *netFD, err error) { s, err := sysSocket(family, sotype, proto) if err != nil { return nil, err @@ -77,26 +77,41 @@ func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only if laddr != nil && raddr == nil { switch sotype { case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET: - if err := fd.listenStream(laddr, listenerBacklog); err != nil { + if err := fd.listenStream(laddr, listenerBacklog, ctrlFn); err != nil { fd.Close() return nil, err } return fd, nil case syscall.SOCK_DGRAM: - if err := fd.listenDatagram(laddr); err != nil { + if err := fd.listenDatagram(laddr, ctrlFn); err != nil { fd.Close() return nil, err } return fd, nil } } - if err := fd.dial(ctx, laddr, raddr); err != nil { + if err := fd.dial(ctx, laddr, raddr, ctrlFn); err != nil { fd.Close() return nil, err } return fd, nil } +func (fd *netFD) ctrlNetwork() string { + switch fd.net { + case "unix", "unixgram", "unixpacket": + return fd.net + } + switch fd.net[len(fd.net)-1] { + case '4', '6': + return fd.net + } + if fd.family == syscall.AF_INET { + return fd.net + "4" + } + return fd.net + "6" +} + func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr { switch fd.family { case syscall.AF_INET, syscall.AF_INET6: @@ -121,14 +136,29 @@ func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr { return func(syscall.Sockaddr) Addr { return nil } } -func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr) error { +func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) error { + if ctrlFn != nil { + c, err := newRawConn(fd) + if err != nil { + return err + } + var ctrlAddr string + if raddr != nil { + ctrlAddr = raddr.String() + } else if laddr != nil { + ctrlAddr = laddr.String() + } + if err := ctrlFn(fd.ctrlNetwork(), ctrlAddr, c); err != nil { + return err + } + } var err error var lsa syscall.Sockaddr if laddr != nil { if lsa, err = laddr.sockaddr(fd.family); err != nil { return err } else if lsa != nil { - if err := syscall.Bind(fd.pfd.Sysfd, lsa); err != nil { + if err = syscall.Bind(fd.pfd.Sysfd, lsa); err != nil { return os.NewSyscallError("bind", err) } } @@ -165,29 +195,39 @@ func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr) error { return nil } -func (fd *netFD) listenStream(laddr sockaddr, backlog int) error { - if err := setDefaultListenerSockopts(fd.pfd.Sysfd); err != nil { +func (fd *netFD) listenStream(laddr sockaddr, backlog int, ctrlFn func(string, string, syscall.RawConn) error) error { + var err error + if err = setDefaultListenerSockopts(fd.pfd.Sysfd); err != nil { return err } - if lsa, err := laddr.sockaddr(fd.family); err != nil { + var lsa syscall.Sockaddr + if lsa, err = laddr.sockaddr(fd.family); err != nil { return err - } else if lsa != nil { - if err := syscall.Bind(fd.pfd.Sysfd, lsa); err != nil { - return os.NewSyscallError("bind", err) + } + if ctrlFn != nil { + c, err := newRawConn(fd) + if err != nil { + return err + } + if err := ctrlFn(fd.ctrlNetwork(), laddr.String(), c); err != nil { + return err } } - if err := listenFunc(fd.pfd.Sysfd, backlog); err != nil { + if err = syscall.Bind(fd.pfd.Sysfd, lsa); err != nil { + return os.NewSyscallError("bind", err) + } + if err = listenFunc(fd.pfd.Sysfd, backlog); err != nil { return os.NewSyscallError("listen", err) } - if err := fd.init(); err != nil { + if err = fd.init(); err != nil { return err } - lsa, _ := syscall.Getsockname(fd.pfd.Sysfd) + lsa, _ = syscall.Getsockname(fd.pfd.Sysfd) fd.setAddr(fd.addrFunc()(lsa), nil) return nil } -func (fd *netFD) listenDatagram(laddr sockaddr) error { +func (fd *netFD) listenDatagram(laddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) error { switch addr := laddr.(type) { case *UDPAddr: // We provide a socket that listens to a wildcard @@ -211,17 +251,27 @@ func (fd *netFD) listenDatagram(laddr sockaddr) error { laddr = &addr } } - if lsa, err := laddr.sockaddr(fd.family); err != nil { + var err error + var lsa syscall.Sockaddr + if lsa, err = laddr.sockaddr(fd.family); err != nil { return err - } else if lsa != nil { - if err := syscall.Bind(fd.pfd.Sysfd, lsa); err != nil { - return os.NewSyscallError("bind", err) + } + if ctrlFn != nil { + c, err := newRawConn(fd) + if err != nil { + return err + } + if err := ctrlFn(fd.ctrlNetwork(), laddr.String(), c); err != nil { + return err } } - if err := fd.init(); err != nil { + if err = syscall.Bind(fd.pfd.Sysfd, lsa); err != nil { + return os.NewSyscallError("bind", err) + } + if err = fd.init(); err != nil { return err } - lsa, _ := syscall.Getsockname(fd.pfd.Sysfd) + lsa, _ = syscall.Getsockname(fd.pfd.Sysfd) fd.setAddr(fd.addrFunc()(lsa), nil) return nil } diff --git a/src/net/tcpsock_posix.go b/src/net/tcpsock_posix.go index 6061c16986c22..bcf7592d35f79 100644 --- a/src/net/tcpsock_posix.go +++ b/src/net/tcpsock_posix.go @@ -62,7 +62,7 @@ func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPCo } func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) { - fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial") + fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", sd.Dialer.Control) // TCP has a rarely used mechanism called a 'simultaneous connection' in // which Dial("tcp", addr1, addr2) run on the machine at addr1 can @@ -92,7 +92,7 @@ func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCP if err == nil { fd.Close() } - fd, err = internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial") + fd, err = internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", sd.Dialer.Control) } if err != nil { @@ -156,7 +156,7 @@ func (ln *TCPListener) file() (*os.File, error) { } func (sl *sysListener) listenTCP(ctx context.Context, laddr *TCPAddr) (*TCPListener, error) { - fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_STREAM, 0, "listen") + fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_STREAM, 0, "listen", sl.ListenConfig.Control) if err != nil { return nil, err } diff --git a/src/net/udpsock_posix.go b/src/net/udpsock_posix.go index 4e96f4781df47..8f4b71c01ec80 100644 --- a/src/net/udpsock_posix.go +++ b/src/net/udpsock_posix.go @@ -95,7 +95,7 @@ func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error } func (sd *sysDialer) dialUDP(ctx context.Context, laddr, raddr *UDPAddr) (*UDPConn, error) { - fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial") + fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial", sd.Dialer.Control) if err != nil { return nil, err } @@ -103,7 +103,7 @@ func (sd *sysDialer) dialUDP(ctx context.Context, laddr, raddr *UDPAddr) (*UDPCo } func (sl *sysListener) listenUDP(ctx context.Context, laddr *UDPAddr) (*UDPConn, error) { - fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen") + fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen", sl.ListenConfig.Control) if err != nil { return nil, err } @@ -111,7 +111,7 @@ func (sl *sysListener) listenUDP(ctx context.Context, laddr *UDPAddr) (*UDPConn, } func (sl *sysListener) listenMulticastUDP(ctx context.Context, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) { - fd, err := internetSocket(ctx, sl.network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen") + fd, err := internetSocket(ctx, sl.network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen", sl.ListenConfig.Control) if err != nil { return nil, err } diff --git a/src/net/unixsock_posix.go b/src/net/unixsock_posix.go index f627567af5f0f..2495da1d253fa 100644 --- a/src/net/unixsock_posix.go +++ b/src/net/unixsock_posix.go @@ -13,7 +13,7 @@ import ( "syscall" ) -func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode string) (*netFD, error) { +func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode string, ctrlFn func(string, string, syscall.RawConn) error) (*netFD, error) { var sotype int switch net { case "unix": @@ -42,7 +42,7 @@ func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode str return nil, errors.New("unknown mode: " + mode) } - fd, err := socket(ctx, net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr) + fd, err := socket(ctx, net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, ctrlFn) if err != nil { return nil, err } @@ -151,7 +151,7 @@ func (c *UnixConn) writeMsg(b, oob []byte, addr *UnixAddr) (n, oobn int, err err } func (sd *sysDialer) dialUnix(ctx context.Context, laddr, raddr *UnixAddr) (*UnixConn, error) { - fd, err := unixSocket(ctx, sd.network, laddr, raddr, "dial") + fd, err := unixSocket(ctx, sd.network, laddr, raddr, "dial", sd.Dialer.Control) if err != nil { return nil, err } @@ -207,7 +207,7 @@ func (l *UnixListener) SetUnlinkOnClose(unlink bool) { } func (sl *sysListener) listenUnix(ctx context.Context, laddr *UnixAddr) (*UnixListener, error) { - fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen") + fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", sl.ListenConfig.Control) if err != nil { return nil, err } @@ -215,7 +215,7 @@ func (sl *sysListener) listenUnix(ctx context.Context, laddr *UnixAddr) (*UnixLi } func (sl *sysListener) listenUnixgram(ctx context.Context, laddr *UnixAddr) (*UnixConn, error) { - fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen") + fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", sl.ListenConfig.Control) if err != nil { return nil, err }