diff --git a/common/util.go b/common/util.go index d672e8d4cfe..bc12bd55b29 100644 --- a/common/util.go +++ b/common/util.go @@ -23,6 +23,7 @@ package common import ( "context" "encoding/json" + "errors" "fmt" "math" "math/rand" @@ -244,14 +245,22 @@ func ToServiceTransientError(err error) error { // IsServiceTransientError checks if the error is a transient error. func IsServiceTransientError(err error) bool { - switch err.(type) { - case *types.InternalServiceError: + + var ( + typesInternalServiceError *types.InternalServiceError + typesServiceBusyError *types.ServiceBusyError + typesShardOwnershipLostError *types.ShardOwnershipLostError + yarpcErrorsStatus *yarpcerrors.Status + ) + + switch { + case errors.As(err, &typesInternalServiceError): return true - case *types.ServiceBusyError: + case errors.As(err, &typesServiceBusyError): return true - case *types.ShardOwnershipLostError: + case errors.As(err, &typesShardOwnershipLostError): return true - case *yarpcerrors.Status: + case errors.As(err, &yarpcErrorsStatus): // We only selectively retry the following yarpc errors client can safe retry with a backoff if yarpcerrors.IsUnavailable(err) || yarpcerrors.IsUnknown(err) || diff --git a/common/util_test.go b/common/util_test.go index 5934bf2a5f9..fa1e0fb6cac 100644 --- a/common/util_test.go +++ b/common/util_test.go @@ -46,23 +46,53 @@ import ( "github.com/uber/cadence/common/types" ) -func TestIsServiceTransientError_ContextTimeout(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) - defer cancel() - time.Sleep(100 * time.Millisecond) - - require.False(t, IsServiceTransientError(ctx.Err())) -} - -func TestIsServiceTransientError_YARPCDeadlineExceeded(t *testing.T) { - yarpcErr := yarpcerrors.DeadlineExceededErrorf("yarpc deadline exceeded") - require.False(t, IsServiceTransientError(yarpcErr)) -} +func TestIsServiceTransientError(t *testing.T) { + for name, c := range map[string]struct { + err error + want bool + }{ + "ContextTimeout": { + err: context.DeadlineExceeded, + want: false, + }, + "YARPCDeadlineExceeded": { + err: yarpcerrors.DeadlineExceededErrorf("yarpc deadline exceeded"), + want: false, + }, + "YARPCUnavailable": { + err: yarpcerrors.UnavailableErrorf("yarpc unavailable"), + want: true, + }, + "YARPCUnavailable wrapped": { + err: fmt.Errorf("wrapped err: %w", yarpcerrors.UnavailableErrorf("yarpc unavailable")), + want: true, + }, + "YARPCUnknown": { + err: yarpcerrors.UnknownErrorf("yarpc unknown"), + want: true, + }, + "YARPCInternal": { + err: yarpcerrors.InternalErrorf("yarpc internal"), + want: true, + }, + "ContextCancel": { + err: context.Canceled, + want: false, + }, + "ServiceBusyError": { + err: &types.ServiceBusyError{}, + want: true, + }, + "ShardOwnershipLostError": { + err: &types.ShardOwnershipLostError{}, + want: true, + }, + } { + t.Run(name, func(t *testing.T) { + require.Equal(t, c.want, IsServiceTransientError(c.err)) + }) + } -func TestIsServiceTransientError_ContextCancel(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - require.False(t, IsServiceTransientError(ctx.Err())) } func TestIsContextTimeoutError(t *testing.T) {