diff --git a/port_forwarding_test.go b/port_forwarding_test.go index 46a1917a52..471736150b 100644 --- a/port_forwarding_test.go +++ b/port_forwarding_test.go @@ -6,9 +6,12 @@ import ( "io" "net" "net/http" + "net/http/httptest" "testing" "time" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" tcexec "github.com/testcontainers/testcontainers-go/exec" "github.com/testcontainers/testcontainers-go/network" @@ -52,21 +55,10 @@ func TestExposeHostPorts(t *testing.T) { t.Run(tc.name, func(tt *testing.T) { freePorts := make([]int, tc.numberOfPorts) for i := range freePorts { - freePort, err := getFreePort() - if err != nil { - tt.Fatal(err) - } - - freePorts[i] = freePort - - // create an http server for each port - server, err := createHttpServer(freePort) - if err != nil { - tt.Fatal(err) - } - go func() { - _ = server.ListenAndServe() - }() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, expectedResponse) + })) + freePorts[i] = server.Listener.Addr().(*net.TCPAddr).Port tt.Cleanup(func() { server.Close() }) @@ -87,13 +79,10 @@ func TestExposeHostPorts(t *testing.T) { if tc.hasNetwork { var err error nw, err = network.New(context.Background()) - if err != nil { - tt.Fatal(err) - } + require.NoError(tt, err) + tt.Cleanup(func() { - if err := nw.Remove(context.Background()); err != nil { - tt.Fatal(err) - } + require.NoError(tt, nw.Remove(context.Background())) }) req.Networks = []string{nw.Name} @@ -108,13 +97,9 @@ func TestExposeHostPorts(t *testing.T) { } c, err := testcontainers.GenericContainer(ctx, req) - if err != nil { - tt.Fatal(err) - } + require.NoError(tt, err) tt.Cleanup(func() { - if err := c.Terminate(context.Background()); err != nil { - tt.Fatal(err) - } + require.NoError(tt, c.Terminate(context.Background())) }) if tc.hasHostAccess { @@ -174,30 +159,3 @@ func assertContainerHasNoHostAccess(t *testing.T, c testcontainers.Container, po } } } - -func createHttpServer(port int) (*http.Server, error) { - server := &http.Server{ - Addr: fmt.Sprintf(":%d", port), - } - - server.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, expectedResponse) - }) - - return server, nil -} - -// getFreePort asks the kernel for a free open port that is ready to use. -func getFreePort() (int, error) { - addr, err := net.ResolveTCPAddr("tcp", "localhost:0") - if err != nil { - return 0, err - } - - l, err := net.ListenTCP("tcp", addr) - if err != nil { - return 0, err - } - defer l.Close() - return l.Addr().(*net.TCPAddr).Port, nil -}