diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 911db387..cf184eba 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -2,9 +2,9 @@ name: CI on: push: - branches: [ master ] pull_request: - branches: [ master ] + branches: + - master jobs: diff --git a/sockets/sockets.go b/sockets/sockets.go index b0eae239..2aed7dd6 100644 --- a/sockets/sockets.go +++ b/sockets/sockets.go @@ -2,13 +2,19 @@ package sockets import ( + "context" "errors" + "fmt" "net" "net/http" + "syscall" "time" ) -const defaultTimeout = 10 * time.Second +const ( + defaultTimeout = 10 * time.Second + maxUnixSocketPathSize = len(syscall.RawSockaddrUnix{}.Path) +) // ErrProtocolNotAvailable is returned when a given transport protocol is not provided by the operating system. var ErrProtocolNotAvailable = errors.New("protocol not available") @@ -35,3 +41,18 @@ func ConfigureTransport(tr *http.Transport, proto, addr string) error { } return nil } + +func configureUnixTransport(tr *http.Transport, proto, addr string) error { + if len(addr) > maxUnixSocketPathSize { + return fmt.Errorf("unix socket path %q is too long", addr) + } + // No need for compression in local communications. + tr.DisableCompression = true + dialer := &net.Dialer{ + Timeout: defaultTimeout, + } + tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + return dialer.DialContext(ctx, proto, addr) + } + return nil +} diff --git a/sockets/sockets_unix.go b/sockets/sockets_unix.go index 78a34a98..4c469271 100644 --- a/sockets/sockets_unix.go +++ b/sockets/sockets_unix.go @@ -3,31 +3,12 @@ package sockets import ( - "context" - "fmt" "net" "net/http" "syscall" "time" ) -const maxUnixSocketPathSize = len(syscall.RawSockaddrUnix{}.Path) - -func configureUnixTransport(tr *http.Transport, proto, addr string) error { - if len(addr) > maxUnixSocketPathSize { - return fmt.Errorf("unix socket path %q is too long", addr) - } - // No need for compression in local communications. - tr.DisableCompression = true - dialer := &net.Dialer{ - Timeout: defaultTimeout, - } - tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { - return dialer.DialContext(ctx, proto, addr) - } - return nil -} - func configureNpipeTransport(tr *http.Transport, proto, addr string) error { return ErrProtocolNotAvailable } diff --git a/sockets/sockets_windows.go b/sockets/sockets_windows.go index 7acafc5a..d4f2e788 100644 --- a/sockets/sockets_windows.go +++ b/sockets/sockets_windows.go @@ -9,10 +9,6 @@ import ( "github.com/Microsoft/go-winio" ) -func configureUnixTransport(tr *http.Transport, proto, addr string) error { - return ErrProtocolNotAvailable -} - func configureNpipeTransport(tr *http.Transport, proto, addr string) error { // No need for compression in local communications. tr.DisableCompression = true diff --git a/sockets/unix_socket.go b/sockets/unix_socket.go index b9233521..be6aa713 100644 --- a/sockets/unix_socket.go +++ b/sockets/unix_socket.go @@ -1,5 +1,3 @@ -//go:build !windows - /* Package sockets is a simple unix domain socket wrapper. @@ -90,22 +88,7 @@ func NewUnixSocketWithOpts(path string, opts ...SockOption) (net.Listener, error return nil, err } - // net.Listen does not allow for permissions to be set. As a result, when - // specifying custom permissions ("WithChmod()"), there is a short time - // between creating the socket and applying the permissions, during which - // the socket permissions are Less restrictive than desired. - // - // To work around this limitation of net.Listen(), we temporarily set the - // umask to 0777, which forces the socket to be created with 000 permissions - // (i.e.: no access for anyone). After that, WithChmod() must be used to set - // the desired permissions. - // - // We don't use "defer" here, to reset the umask to its original value as soon - // as possible. Ideally we'd be able to detect if WithChmod() was passed as - // an option, and skip changing umask if default permissions are used. - origUmask := syscall.Umask(0o777) - l, err := net.Listen("unix", path) - syscall.Umask(origUmask) + l, err := listenUnix(path) if err != nil { return nil, err } diff --git a/sockets/unix_socket_test.go b/sockets/unix_socket_test.go index e4ae0e37..527d433c 100644 --- a/sockets/unix_socket_test.go +++ b/sockets/unix_socket_test.go @@ -1,12 +1,9 @@ -//go:build !windows - package sockets import ( "fmt" "net" "os" - "syscall" "testing" ) @@ -52,26 +49,16 @@ func TestNewUnixSocket(t *testing.T) { } func TestUnixSocketWithOpts(t *testing.T) { - uid, gid := os.Getuid(), os.Getgid() - perms := os.FileMode(0o660) - path := "/tmp/test.sock" - echoStr := "hello" - l, err := NewUnixSocketWithOpts(path, WithChown(uid, gid), WithChmod(perms)) + socketFile, err := os.CreateTemp("", "test*.sock") if err != nil { t.Fatal(err) } + socketFile.Close() + defer os.Remove(socketFile.Name()) + + l := createTestUnixSocket(t, socketFile.Name()) defer l.Close() - p, err := os.Stat(path) - if err != nil { - t.Fatal(err) - } - if p.Mode().Perm() != perms { - t.Fatalf("unexpected file permissions: expected: %#o, got: %#o", perms, p.Mode().Perm()) - } - if stat, ok := p.Sys().(*syscall.Stat_t); ok { - if stat.Uid != uint32(uid) || stat.Gid != uint32(gid) { - t.Fatalf("unexpected file ownership: expected: %d:%d, got: %d:%d", uid, gid, stat.Uid, stat.Gid) - } - } - runTest(t, path, l, echoStr) + + echoStr := "hello" + runTest(t, socketFile.Name(), l, echoStr) } diff --git a/sockets/unix_socket_test_unix.go b/sockets/unix_socket_test_unix.go new file mode 100644 index 00000000..ae2e7321 --- /dev/null +++ b/sockets/unix_socket_test_unix.go @@ -0,0 +1,32 @@ +//go:build !windows + +package sockets + +import ( + "net" + "os" + "syscall" + "testing" +) + +func createTestUnixSocket(t *testing.T, path string) (listener net.Listener) { + uid, gid := os.Getuid(), os.Getgid() + perms := os.FileMode(0660) + l, err := NewUnixSocketWithOpts(path, WithChown(uid, gid), WithChmod(perms)) + if err != nil { + t.Fatal(err) + } + p, err := os.Stat(path) + if err != nil { + t.Fatal(err) + } + if p.Mode().Perm() != perms { + t.Fatalf("unexpected file permissions: expected: %#o, got: %#o", perms, p.Mode().Perm()) + } + if stat, ok := p.Sys().(*syscall.Stat_t); ok { + if stat.Uid != uint32(uid) || stat.Gid != uint32(gid) { + t.Fatalf("unexpected file ownership: expected: %d:%d, got: %d:%d", uid, gid, stat.Uid, stat.Gid) + } + } + return l +} diff --git a/sockets/unix_socket_test_windows.go b/sockets/unix_socket_test_windows.go new file mode 100644 index 00000000..e68aca0b --- /dev/null +++ b/sockets/unix_socket_test_windows.go @@ -0,0 +1,14 @@ +package sockets + +import ( + "net" + "testing" +) + +func createTestUnixSocket(t *testing.T, path string) (listener net.Listener) { + l, err := NewUnixSocketWithOpts(path) + if err != nil { + t.Fatal(err) + } + return l +} diff --git a/sockets/unix_socket_unix.go b/sockets/unix_socket_unix.go new file mode 100644 index 00000000..3fbc982e --- /dev/null +++ b/sockets/unix_socket_unix.go @@ -0,0 +1,28 @@ +//go:build !windows + +package sockets + +import ( + "net" + "syscall" +) + +func listenUnix(path string) (net.Listener, error) { + // net.Listen does not allow for permissions to be set. As a result, when + // specifying custom permissions ("WithChmod()"), there is a short time + // between creating the socket and applying the permissions, during which + // the socket permissions are Less restrictive than desired. + // + // To work around this limitation of net.Listen(), we temporarily set the + // umask to 0777, which forces the socket to be created with 000 permissions + // (i.e.: no access for anyone). After that, WithChmod() must be used to set + // the desired permissions. + // + // We don't use "defer" here, to reset the umask to its original value as soon + // as possible. Ideally we'd be able to detect if WithChmod() was passed as + // an option, and skip changing umask if default permissions are used. + origUmask := syscall.Umask(0o777) + l, err := net.Listen("unix", path) + syscall.Umask(origUmask) + return l, err +} diff --git a/sockets/unix_socket_windows.go b/sockets/unix_socket_windows.go new file mode 100644 index 00000000..5ec29e05 --- /dev/null +++ b/sockets/unix_socket_windows.go @@ -0,0 +1,7 @@ +package sockets + +import "net" + +func listenUnix(path string) (net.Listener, error) { + return net.Listen("unix", path) +}