diff --git a/aws/client/default_retryer.go b/aws/client/default_retryer.go index a397b0d044c..c1c68076b27 100644 --- a/aws/client/default_retryer.go +++ b/aws/client/default_retryer.go @@ -8,18 +8,18 @@ import ( "github.com/aws/aws-sdk-go/internal/sdkrand" ) -// DefaultRetryer implements basic retry logic using exponential backoff for +// DefaultRetryer implements basic retry logic using exponential back off for // most services. If you want to implement custom retry logic, implement the // request.Retryer interface or create a structure type that composes this // struct and override the specific methods. For example, to override only // the MaxRetries method: // -// type retryer struct { -// client.DefaultRetryer -// } +// type retryer struct { +// client.DefaultRetryer +// } // -// // This implementation always has 100 max retries -// func (d retryer) MaxRetries() int { return 100 } +// // This implementation always has 100 max retries +// func (d retryer) MaxRetries() int { return 100 } type DefaultRetryer struct { NumMaxRetries int } @@ -34,8 +34,8 @@ func (d DefaultRetryer) MaxRetries() int { func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration { // Set the upper limit of delay in retrying at ~five minutes minTime := 30 - throttle := d.shouldThrottle(r) - if throttle { + isThrottle := r.IsErrorThrottle() + if isThrottle { if delay, ok := getRetryDelay(r); ok { return delay } @@ -44,7 +44,7 @@ func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration { } retryCount := r.RetryCount - if throttle && retryCount > 8 { + if isThrottle && retryCount > 8 { retryCount = 8 } else if retryCount > 13 { retryCount = 13 @@ -65,21 +65,8 @@ func (d DefaultRetryer) ShouldRetry(r *request.Request) bool { if r.HTTPResponse.StatusCode >= 500 && r.HTTPResponse.StatusCode != 501 { return true } - return r.IsErrorRetryable() || d.shouldThrottle(r) -} - -// ShouldThrottle returns true if the request should be throttled. -func (d DefaultRetryer) shouldThrottle(r *request.Request) bool { - switch r.HTTPResponse.StatusCode { - case 429: - case 502: - case 503: - case 504: - default: - return r.IsErrorThrottle() - } - return true + return r.IsErrorRetryable() || r.IsErrorThrottle() } // This will look in the Retry-After header, RFC 7231, for how long diff --git a/aws/client/default_retryer_test.go b/aws/client/default_retryer_test.go index fddba4e49e2..0c6e068a1db 100644 --- a/aws/client/default_retryer_test.go +++ b/aws/client/default_retryer_test.go @@ -60,7 +60,7 @@ func TestRetryThrottleStatusCodes(t *testing.T) { d := DefaultRetryer{NumMaxRetries: 10} for i, c := range cases { - throttle := d.shouldThrottle(&c.r) + throttle := c.r.IsErrorThrottle() retry := d.ShouldRetry(&c.r) if e, a := c.expectThrottle, throttle; e != a { diff --git a/aws/corehandlers/handlers.go b/aws/corehandlers/handlers.go index f8853d78af2..0c60e612ea5 100644 --- a/aws/corehandlers/handlers.go +++ b/aws/corehandlers/handlers.go @@ -159,9 +159,9 @@ func handleSendError(r *request.Request, err error) { Body: ioutil.NopCloser(bytes.NewReader([]byte{})), } } - // Catch all other request errors. + // Catch all request errors, and let the default retrier determine + // if the error is retryable. r.Error = awserr.New("RequestError", "send request failed", err) - r.Retryable = aws.Bool(true) // network errors are retryable // Override the error with a context canceled error, if that was canceled. ctx := r.Context() @@ -184,37 +184,39 @@ var ValidateResponseHandler = request.NamedHandler{Name: "core.ValidateResponseH // AfterRetryHandler performs final checks to determine if the request should // be retried and how long to delay. -var AfterRetryHandler = request.NamedHandler{Name: "core.AfterRetryHandler", Fn: func(r *request.Request) { - // If one of the other handlers already set the retry state - // we don't want to override it based on the service's state - if r.Retryable == nil || aws.BoolValue(r.Config.EnforceShouldRetryCheck) { - r.Retryable = aws.Bool(r.ShouldRetry(r)) - } +var AfterRetryHandler = request.NamedHandler{ + Name: "core.AfterRetryHandler", + Fn: func(r *request.Request) { + // If one of the other handlers already set the retry state + // we don't want to override it based on the service's state + if r.Retryable == nil || aws.BoolValue(r.Config.EnforceShouldRetryCheck) { + r.Retryable = aws.Bool(r.ShouldRetry(r)) + } - if r.WillRetry() { - r.RetryDelay = r.RetryRules(r) + if r.WillRetry() { + r.RetryDelay = r.RetryRules(r) - if sleepFn := r.Config.SleepDelay; sleepFn != nil { - // Support SleepDelay for backwards compatibility and testing - sleepFn(r.RetryDelay) - } else if err := aws.SleepWithContext(r.Context(), r.RetryDelay); err != nil { - r.Error = awserr.New(request.CanceledErrorCode, - "request context canceled", err) - r.Retryable = aws.Bool(false) - return - } + if sleepFn := r.Config.SleepDelay; sleepFn != nil { + // Support SleepDelay for backwards compatibility and testing + sleepFn(r.RetryDelay) + } else if err := aws.SleepWithContext(r.Context(), r.RetryDelay); err != nil { + r.Error = awserr.New(request.CanceledErrorCode, + "request context canceled", err) + r.Retryable = aws.Bool(false) + return + } - // when the expired token exception occurs the credentials - // need to be expired locally so that the next request to - // get credentials will trigger a credentials refresh. - if r.IsErrorExpired() { - r.Config.Credentials.Expire() - } + // when the expired token exception occurs the credentials + // need to be expired locally so that the next request to + // get credentials will trigger a credentials refresh. + if r.IsErrorExpired() { + r.Config.Credentials.Expire() + } - r.RetryCount++ - r.Error = nil - } -}} + r.RetryCount++ + r.Error = nil + } + }} // ValidateEndpointHandler is a request handler to validate a request had the // appropriate Region and Endpoint set. Will set r.Error if the endpoint or diff --git a/aws/request/http_request_retry_test.go b/aws/request/http_request_retry_test.go index fcdd1ce819b..8cc3b043d5e 100644 --- a/aws/request/http_request_retry_test.go +++ b/aws/request/http_request_retry_test.go @@ -1,7 +1,6 @@ package request_test import ( - "errors" "strings" "testing" @@ -14,14 +13,15 @@ func TestRequestCancelRetry(t *testing.T) { c := make(chan struct{}) reqNum := 0 - s := mock.NewMockClient(aws.NewConfig().WithMaxRetries(10)) + s := mock.NewMockClient(&aws.Config{ + MaxRetries: aws.Int(1), + }) s.Handlers.Validate.Clear() s.Handlers.Unmarshal.Clear() s.Handlers.UnmarshalMeta.Clear() s.Handlers.UnmarshalError.Clear() s.Handlers.Send.PushFront(func(r *request.Request) { reqNum++ - r.Error = errors.New("net/http: request canceled") }) out := &testData{} r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) diff --git a/aws/request/request.go b/aws/request/request.go index e7c9b2b61af..4d8f3daa5bd 100644 --- a/aws/request/request.go +++ b/aws/request/request.go @@ -4,7 +4,6 @@ import ( "bytes" "fmt" "io" - "net" "net/http" "net/url" "reflect" @@ -65,6 +64,11 @@ type Request struct { LastSignedAt time.Time DisableFollowRedirects bool + // Additional API error codes that the SDK will retry. + RetryCodes []string + // Additional API error codes that the SDK will retry as throttled requests. + ThrottleCodes []string + // A value greater than 0 instructs the request to be signed as Presigned URL // You should not set this field directly. Instead use Request's // Presign or PresignRequest methods. @@ -498,21 +502,17 @@ func (r *Request) Send() error { if err := r.sendRequest(); err == nil { return nil - } else if !shouldRetryError(r.Error) { + } + r.Handlers.Retry.Run(r) + r.Handlers.AfterRetry.Run(r) + + if r.Error != nil || !aws.BoolValue(r.Retryable) { + return r.Error + } + + if err := r.prepareRetry(); err != nil { + r.Error = err return err - } else { - r.Handlers.Retry.Run(r) - r.Handlers.AfterRetry.Run(r) - - if r.Error != nil || !aws.BoolValue(r.Retryable) { - return r.Error - } - - if err := r.prepareRetry(); err != nil { - r.Error = err - return err - } - continue } } } @@ -596,51 +596,6 @@ func AddToUserAgent(r *Request, s string) { r.HTTPRequest.Header.Set("User-Agent", s) } -type temporary interface { - Temporary() bool -} - -func shouldRetryError(origErr error) bool { - switch err := origErr.(type) { - case awserr.Error: - if err.Code() == CanceledErrorCode { - return false - } - return shouldRetryError(err.OrigErr()) - case *url.Error: - if strings.Contains(err.Error(), "connection refused") { - // Refused connections should be retried as the service may not yet - // be running on the port. Go TCP dial considers refused - // connections as not temporary. - return true - } - // *url.Error only implements Temporary after golang 1.6 but since - // url.Error only wraps the error: - return shouldRetryError(err.Err) - case temporary: - if netErr, ok := err.(*net.OpError); ok && netErr.Op == "dial" { - return true - } - // If the error is temporary, we want to allow continuation of the - // retry process - return err.Temporary() || isErrConnectionReset(origErr) - case nil: - // `awserr.Error.OrigErr()` can be nil, meaning there was an error but - // because we don't know the cause, it is marked as retryable. See - // TestRequest4xxUnretryable for an example. - return true - default: - switch err.Error() { - case "net/http: request canceled", - "net/http: request canceled while waiting for connection": - // known 1.5 error case when an http request is cancelled - return false - } - // here we don't know the error; so we allow a retry. - return true - } -} - // SanitizeHostForHeader removes default port from host and updates request.Host func SanitizeHostForHeader(r *http.Request) { host := getHost(r) diff --git a/aws/request/request_retry_test.go b/aws/request/request_retry_test.go index 60a8dfbf147..d42f960e157 100644 --- a/aws/request/request_retry_test.go +++ b/aws/request/request_retry_test.go @@ -31,12 +31,12 @@ func TestShouldRetryError_timeout(t *testing.T) { tr := &http.Transport{} defer tr.CloseIdleConnections() - cli := http.Client{ + client := http.Client{ Timeout: time.Nanosecond, Transport: tr, } - resp, err := cli.Do(newRequest(t, "https://179.179.179.179/no/such/host")) + resp, err := client.Do(newRequest(t, "https://179.179.179.179/no/such/host")) if resp != nil { resp.Body.Close() } @@ -53,7 +53,7 @@ func TestShouldRetryError_timeout(t *testing.T) { func TestShouldRetryError_cancelled(t *testing.T) { tr := &http.Transport{} defer tr.CloseIdleConnections() - cli := http.Client{ + client := http.Client{ Transport: tr, } @@ -82,7 +82,7 @@ func TestShouldRetryError_cancelled(t *testing.T) { close(ch) // request is cancelled before anything }() - resp, err := cli.Do(r) + resp, err := client.Do(r) if resp != nil { resp.Body.Close() } diff --git a/aws/request/request_test.go b/aws/request/request_test.go index a7a229baade..477c761ea72 100644 --- a/aws/request/request_test.go +++ b/aws/request/request_test.go @@ -181,13 +181,18 @@ func TestRequestRecoverRetry4xxRetryable(t *testing.T) { // test that retries don't occur for 4xx status codes with a response type that can't be retried func TestRequest4xxUnretryable(t *testing.T) { - s := awstesting.NewClient(aws.NewConfig().WithMaxRetries(10)) + s := awstesting.NewClient(&aws.Config{ + MaxRetries: aws.Int(1), + }) s.Handlers.Validate.Clear() s.Handlers.Unmarshal.PushBack(unmarshal) s.Handlers.UnmarshalError.PushBack(unmarshalError) s.Handlers.Send.Clear() // mock sending s.Handlers.Send.PushBack(func(r *request.Request) { - r.HTTPResponse = &http.Response{StatusCode: 401, Body: body(`{"__type":"SignatureDoesNotMatch","message":"Signature does not match."}`)} + r.HTTPResponse = &http.Response{ + StatusCode: 401, + Body: body(`{"__type":"SignatureDoesNotMatch","message":"Signature does not match."}`), + } }) out := &testData{} r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) @@ -580,7 +585,7 @@ func TestIsSerializationErrorRetryable(t *testing.T) { Error: c.err, } if r.IsErrorRetryable() != c.expected { - t.Errorf("Case %d: Expected %v, but received %v", i+1, c.expected, !c.expected) + t.Errorf("Case %d: Expected %v, but received %v", i, c.expected, !c.expected) } } } @@ -1124,7 +1129,6 @@ func TestRequestBodySeekFails(t *testing.T) { if err == nil { t.Fatal("expect error, but got none") } - t.Log("Error:", err) aerr := err.(awserr.Error) if e, a := request.ErrCodeSerialization, aerr.Code(); e != a { diff --git a/aws/request/retryer.go b/aws/request/retryer.go index d0aa54c6d10..4c2d46a9641 100644 --- a/aws/request/retryer.go +++ b/aws/request/retryer.go @@ -1,6 +1,9 @@ package request import ( + "net" + "net/url" + "strings" "time" "github.com/aws/aws-sdk-go/aws" @@ -108,32 +111,87 @@ func isNestedErrorRetryable(parentErr awserr.Error) bool { // IsErrorRetryable returns whether the error is retryable, based on its Code. // Returns false if error is nil. func IsErrorRetryable(err error) bool { - if err != nil { - if aerr, ok := err.(awserr.Error); ok { - return isCodeRetryable(aerr.Code()) || isNestedErrorRetryable(aerr) + return shouldRetryError(err) +} + +type temporary interface { + Temporary() bool +} + +func shouldRetryError(origErr error) bool { + switch err := origErr.(type) { + case awserr.Error: + if err.Code() == CanceledErrorCode { + return false + } + if isNestedErrorRetryable(err) { + return true + } + + origErr := err.OrigErr() + var shouldRetry bool + if origErr != nil { + shouldRetry := shouldRetryError(origErr) + if err.Code() == "RequestError" && !shouldRetry { + return false + } + } + if isCodeRetryable(err.Code()) { + return true + } + return shouldRetry + + case *url.Error: + if strings.Contains(err.Error(), "connection refused") { + // Refused connections should be retried as the service may not yet + // be running on the port. Go TCP dial considers refused + // connections as not temporary. + return true } + // *url.Error only implements Temporary after golang 1.6 but since + // url.Error only wraps the error: + return shouldRetryError(err.Err) + + case temporary: + if netErr, ok := err.(*net.OpError); ok && netErr.Op == "dial" { + return true + } + // If the error is temporary, we want to allow continuation of the + // retry process + return err.Temporary() || isErrConnectionReset(origErr) + + case nil: + // `awserr.Error.OrigErr()` can be nil, meaning there was an error but + // because we don't know the cause, it is marked as retryable. See + // TestRequest4xxUnretryable for an example. + return true + + default: + switch err.Error() { + case "net/http: request canceled", + "net/http: request canceled while waiting for connection": + // known 1.5 error case when an http request is cancelled + return false + } + // here we don't know the error; so we allow a retry. + return true } - return false } // IsErrorThrottle returns whether the error is to be throttled based on its code. // Returns false if error is nil. func IsErrorThrottle(err error) bool { - if err != nil { - if aerr, ok := err.(awserr.Error); ok { - return isCodeThrottle(aerr.Code()) - } + if aerr, ok := err.(awserr.Error); ok && aerr != nil { + return isCodeThrottle(aerr.Code()) } return false } -// IsErrorExpiredCreds returns whether the error code is a credential expiry error. -// Returns false if error is nil. +// IsErrorExpiredCreds returns whether the error code is a credential expiry +// error. Returns false if error is nil. func IsErrorExpiredCreds(err error) bool { - if err != nil { - if aerr, ok := err.(awserr.Error); ok { - return isCodeExpiredCreds(aerr.Code()) - } + if aerr, ok := err.(awserr.Error); ok && aerr != nil { + return isCodeExpiredCreds(aerr.Code()) } return false } @@ -143,17 +201,47 @@ func IsErrorExpiredCreds(err error) bool { // // Alias for the utility function IsErrorRetryable func (r *Request) IsErrorRetryable() bool { + if r.Error == nil { + return false + } + if isErrCode(r.Error, r.RetryCodes) { + return true + } + return IsErrorRetryable(r.Error) } -// IsErrorThrottle returns whether the error is to be throttled based on its code. -// Returns false if the request has no Error set +// IsErrorThrottle returns whether the error is to be throttled based on its +// code. Returns false if the request has no Error set. // // Alias for the utility function IsErrorThrottle func (r *Request) IsErrorThrottle() bool { + if isErrCode(r.Error, r.ThrottleCodes) { + return true + } + + if r.HTTPResponse != nil { + switch r.HTTPResponse.StatusCode { + case 429, 502, 503, 504: + return true + } + } + return IsErrorThrottle(r.Error) } +func isErrCode(err error, codes []string) bool { + if aerr, ok := err.(awserr.Error); ok { + for _, code := range codes { + if code == aerr.Code() { + return true + } + } + } + + return false +} + // IsErrorExpired returns whether the error code is a credential expiry error. // Returns false if the request has no Error set. //