diff --git a/testsuite/devserver.go b/testsuite/devserver.go index 8b922a66f..ba10b5e7e 100644 --- a/testsuite/devserver.go +++ b/testsuite/devserver.go @@ -339,9 +339,9 @@ func extractZip(r io.Reader, toExtract string, w io.Writer) error { // Returns a connected client created using the provided options. func waitServerReady(ctx context.Context, options client.Options) (client.Client, error) { var returnedClient client.Client - lastErr := retryFor(600, 100*time.Millisecond, func() error { + lastErr := retryFor(ctx, 600, 100*time.Millisecond, func() error { var err error - returnedClient, err = client.Dial(options) + returnedClient, err = client.DialContext(ctx, options) return err }) if lastErr != nil { @@ -351,7 +351,7 @@ func waitServerReady(ctx context.Context, options client.Options) (client.Client } // retryFor retries some function until it returns nil or runs out of attempts. Wait interval between attempts. -func retryFor(maxAttempts int, interval time.Duration, cond func() error) error { +func retryFor(ctx context.Context, maxAttempts int, interval time.Duration, cond func() error) error { if maxAttempts < 1 { // this is used internally, okay to panic panic("maxAttempts should be at least 1") @@ -363,7 +363,12 @@ func retryFor(maxAttempts int, interval time.Duration, cond func() error) error } else { lastErr = curE } - time.Sleep(interval) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(interval): + // Try again. + } } return lastErr } diff --git a/testsuite/devserver_internal_test.go b/testsuite/devserver_internal_test.go new file mode 100644 index 000000000..92ecfe794 --- /dev/null +++ b/testsuite/devserver_internal_test.go @@ -0,0 +1,62 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package testsuite + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.temporal.io/sdk/client" + "go.temporal.io/sdk/internal/log" +) + +func TestWaitServerReady_respectsTimeout(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + + hostPort, err := getFreeHostPort() + require.NoError(t, err, "get free host port") + + startTime := time.Now() + _, err = waitServerReady(ctx, client.Options{ + HostPort: hostPort, + Namespace: "default", + Logger: log.NewNopLogger(), + }) + require.Error(t, err, "Dial should fail") + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.WithinDuration(t, + startTime.Add(time.Millisecond), + time.Now(), + 5*time.Millisecond, + // Even though the timeout is only a millisecond, + // we'll allow for a slack of up to 5 milliseconds + // to account for slow CI machines. + // Anything smaller than 1 second is fine to use here. + ) +}