diff --git a/pkg/util/util.go b/pkg/util/util.go index ad09034bd61..24af2c17aee 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -782,12 +782,11 @@ func SafeGetBool(b *bool) bool { // IsPortFree checks if the port on localhost is free to use func IsPortFree(port int) bool { - address := fmt.Sprintf("localhost:%d", port) + address := fmt.Sprintf("0.0.0.0:%d", port) listener, err := net.Listen("tcp", address) if err != nil { return false } - _ = listener.Addr().(*net.TCPAddr).Port err = listener.Close() return err == nil } diff --git a/pkg/util/util_test.go b/pkg/util/util_test.go index 7b837b1199e..a64c34621ad 100644 --- a/pkg/util/util_test.go +++ b/pkg/util/util_test.go @@ -2735,3 +2735,129 @@ bar }) } } + +func TestIsPortFree(t *testing.T) { + type serverCloser interface { + Close() + } + type args struct { + port int + portProvider func() (int, serverCloser, error) + } + type test struct { + name string + args args + want bool + } + tests := []test{ + { + name: "negative port should return an error, handles as false", + args: args{port: -10}, + want: false, + }, + { + name: "0 should always be free", + args: args{port: 0}, + want: true, + }, + { + name: "random port bound on 127.0.0.1", + args: args{ + portProvider: func() (int, serverCloser, error) { + s := httptest.NewServer(nil) + _, p, err := net.SplitHostPort(strings.TrimPrefix(s.URL, "http://")) + if err != nil { + return 0, s, err + } + port, err := strconv.Atoi(p) + if err != nil { + return 0, s, err + } + return port, s, nil + }, + }, + want: false, + }, + { + name: "random port bound on 127.0.0.1 and checking 0 as input", + args: args{ + portProvider: func() (int, serverCloser, error) { + s := httptest.NewServer(nil) + return 0, s, nil + }, + }, + want: true, + }, + { + name: "random port bound on 0.0.0.0 and checking 0 as input", + args: args{ + portProvider: func() (int, serverCloser, error) { + // Intentionally not using httptest.Server, which listens to 127.0.0.1 + l, err := net.Listen("tcp", "0.0.0.0:0") + if err != nil { + return 0, nil, err + } + s := &httptest.Server{ + Listener: l, + Config: &http.Server{}, + } + s.Start() + + return 0, s, nil + }, + }, + want: true, + }, + { + name: "random port bound on 0.0.0.0", + args: args{ + portProvider: func() (int, serverCloser, error) { + // Intentionally not using httptest.Server, which listens to 127.0.0.1 + l, err := net.Listen("tcp", "0.0.0.0:0") + if err != nil { + return 0, nil, err + } + s := &httptest.Server{ + Listener: l, + Config: &http.Server{}, + } + s.Start() + + _, p, err := net.SplitHostPort(strings.TrimPrefix(s.URL, "http://")) + if err != nil { + return 0, s, err + } + port, err := strconv.Atoi(p) + if err != nil { + return 0, s, err + } + return port, s, nil + }, + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + port := tt.args.port + var s serverCloser + var err error + if tt.args.portProvider != nil { + port, s, err = tt.args.portProvider() + if s != nil { + defer s.Close() + } + if err != nil { + t.Errorf("error while computing port: %v", err) + return + } + } + + if got := IsPortFree(port); got != tt.want { + t.Errorf("IsPortFree() = %v, want %v", got, tt.want) + } + }) + } + +}