diff --git a/conn_linux.go b/conn_linux.go index 61af3a2..61aa592 100644 --- a/conn_linux.go +++ b/conn_linux.go @@ -43,6 +43,7 @@ type socket interface { SetWriteDeadline(t time.Time) error SetSockoptSockFprog(level, opt int, fprog *unix.SockFprog) error SetSockoptInt(level, opt, value int) error + GetSockoptInt(level, opt int) (int, error) } // dial is the entry point for Dial. dial opens a netlink socket using @@ -276,21 +277,67 @@ func (c *conn) SetWriteDeadline(t time.Time) error { // SetReadBuffer sets the size of the operating system's receive buffer // associated with the Conn. func (c *conn) SetReadBuffer(bytes int) error { - return os.NewSyscallError("setsockopt", c.s.SetSockoptInt( + // First try SO_RCVBUFFORCE. Given necessary permissions this syscall ignores limits. + err := os.NewSyscallError("setsockopt", c.s.SetSockoptInt( unix.SOL_SOCKET, - unix.SO_RCVBUF, + unix.SO_RCVBUFFORCE, bytes, )) + if err != nil { + // If SO_SNDBUFFORCE fails, try SO_RCVBUF + err = os.NewSyscallError("setsockopt", c.s.SetSockoptInt( + unix.SOL_SOCKET, + unix.SO_RCVBUF, + bytes, + )) + } + return err } // SetReadBuffer sets the size of the operating system's transmit buffer // associated with the Conn. func (c *conn) SetWriteBuffer(bytes int) error { - return os.NewSyscallError("setsockopt", c.s.SetSockoptInt( + // First try SO_SNDBUFFORCE. Given necessary permissions this syscall ignores limits. + err := os.NewSyscallError("setsockopt", c.s.SetSockoptInt( unix.SOL_SOCKET, - unix.SO_SNDBUF, + unix.SO_SNDBUFFORCE, bytes, )) + if err != nil { + // If SO_SNDBUFFORCE fails, try SO_SNDBUF + err = os.NewSyscallError("setsockopt", c.s.SetSockoptInt( + unix.SOL_SOCKET, + unix.SO_SNDBUF, + bytes, + )) + } + return err +} + +// GetReadBuffer retrieves the size of the operating system's receive buffer +// associated with the Conn. +func (c *conn) GetReadBuffer() (int, error) { + value, err := c.s.GetSockoptInt( + unix.SOL_SOCKET, + unix.SO_RCVBUF, + ) + if err != nil { + return 0, os.NewSyscallError("getsockopt", err) + } + return value, nil +} + +// GetWriteBuffer retrieves the size of the operating system's transmit buffer +// associated with the Conn. +func (c *conn) GetWriteBuffer() (int, error) { + value, err := c.s.GetSockoptInt( + unix.SOL_SOCKET, + unix.SO_SNDBUF, + ) + if err != nil { + return 0, os.NewSyscallError("getsockopt", err) + } + return value, nil } // linuxOption converts a ConnOption to its Linux value. @@ -597,6 +644,21 @@ func (s *sysSocket) SetSockoptInt(level, opt, value int) error { return err } +func (s *sysSocket) GetSockoptInt(level, opt int) (int, error) { + var ( + value int + err error + ) + doErr := s.control(func(fd int) { + value, err = unix.GetsockoptInt(fd, level, opt) + }) + if doErr != nil { + return 0, doErr + } + + return value, err +} + func (s *sysSocket) SetSockoptSockFprog(level, opt int, fprog *unix.SockFprog) error { var err error doErr := s.control(func(fd int) { diff --git a/conn_linux_test.go b/conn_linux_test.go index 17d1ed9..b3a2264 100644 --- a/conn_linux_test.go +++ b/conn_linux_test.go @@ -449,12 +449,12 @@ func TestLinuxConnSetBuffers(t *testing.T) { want := []setSockopt{ { level: unix.SOL_SOCKET, - opt: unix.SO_RCVBUF, + opt: unix.SO_RCVBUFFORCE, value: n, }, { level: unix.SOL_SOCKET, - opt: unix.SO_SNDBUF, + opt: unix.SO_SNDBUFFORCE, value: n, }, } @@ -679,6 +679,16 @@ func (s *testSocket) SetSockoptInt(level, opt, value int) error { return nil } +func (s *testSocket) GetSockoptInt(level, opt int) (int, error) { + for i := len(s.setSockopt)-1; i >= 0; i-- { + if s.setSockopt[i].level == level && s.setSockopt[i].opt == opt { + return s.setSockopt[i].value, nil + } + } + + return 0, errors.New("getsockopt without preceding setsockopt") +} + func (s *testSocket) SetSockoptSockFprog(_, _ int, _ *unix.SockFprog) error { panic("netlink: testSocket.SetSockoptSockFprog not currently implemented") }