diff --git a/aws/request/request_test.go b/aws/request/request_test.go index eec650b6d69..0082fa8b37c 100644 --- a/aws/request/request_test.go +++ b/aws/request/request_test.go @@ -778,3 +778,65 @@ func TestIsNoBodyReader(t *testing.T) { } } } + +func TestRequest_TemporaryRetry(t *testing.T) { + done := make(chan struct{}) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", "1024") + w.WriteHeader(http.StatusOK) + + w.Write(make([]byte, 100)) + + f := w.(http.Flusher) + f.Flush() + + <-done + })) + + client := &http.Client{ + Timeout: 100 * time.Millisecond, + } + + svc := awstesting.NewClient(&aws.Config{ + Region: unit.Session.Config.Region, + MaxRetries: aws.Int(1), + HTTPClient: client, + DisableSSL: aws.Bool(true), + Endpoint: aws.String(server.URL), + }) + + req := svc.NewRequest(&request.Operation{ + Name: "name", HTTPMethod: "GET", HTTPPath: "/path", + }, &struct{}{}, &struct{}{}) + + req.Handlers.Unmarshal.PushBack(func(r *request.Request) { + defer req.HTTPResponse.Body.Close() + _, err := io.Copy(ioutil.Discard, req.HTTPResponse.Body) + r.Error = awserr.New(request.ErrCodeSerialization, "error", err) + }) + + err := req.Send() + if err == nil { + t.Errorf("expect error, got none") + } + close(done) + + aerr := err.(awserr.Error) + if e, a := request.ErrCodeSerialization, aerr.Code(); e != a { + t.Errorf("expect %q error code, got %q", e, a) + } + + if e, a := 1, req.RetryCount; e != a { + t.Errorf("expect %d retries, got %d", e, a) + } + + type temporary interface { + Temporary() bool + } + + terr := aerr.OrigErr().(temporary) + if !terr.Temporary() { + t.Errorf("expect temporary error, was not") + } +} diff --git a/aws/request/retryer.go b/aws/request/retryer.go index e36aa3822d8..8d369c1b8c5 100644 --- a/aws/request/retryer.go +++ b/aws/request/retryer.go @@ -74,6 +74,10 @@ var validParentCodes = map[string]struct{}{ ErrCodeRead: struct{}{}, } +type temporaryError interface { + Temporary() bool +} + func isNestedErrorRetryable(parentErr awserr.Error) bool { if parentErr == nil { return false @@ -92,6 +96,10 @@ func isNestedErrorRetryable(parentErr awserr.Error) bool { return isCodeRetryable(aerr.Code()) } + if t, ok := err.(temporaryError); ok { + return t.Temporary() + } + return isErrConnectionReset(err) } diff --git a/aws/request/retryer_test.go b/aws/request/retryer_test.go index b1926e3d6f0..a8787487e81 100644 --- a/aws/request/retryer_test.go +++ b/aws/request/retryer_test.go @@ -1,10 +1,10 @@ package request import ( + "errors" + "fmt" "testing" - "github.com/stretchr/testify/assert" - "github.com/aws/aws-sdk-go/aws/awserr" ) @@ -12,5 +12,51 @@ func TestRequestThrottling(t *testing.T) { req := Request{} req.Error = awserr.New("Throttling", "", nil) - assert.True(t, req.IsErrorThrottle()) + if e, a := true, req.IsErrorThrottle(); e != a { + t.Errorf("expect %t to be throttled, was %t", e, a) + } +} + +type mockTempError bool + +func (e mockTempError) Error() string { + return fmt.Sprintf("mock temporary error: %t", e.Temporary()) +} +func (e mockTempError) Temporary() bool { + return bool(e) +} + +func TestIsErrorRetryable(t *testing.T) { + cases := []struct { + Err error + IsTemp bool + }{ + { + Err: awserr.New(ErrCodeSerialization, "temporary error", mockTempError(true)), + IsTemp: true, + }, + { + Err: awserr.New(ErrCodeSerialization, "temporary error", mockTempError(false)), + IsTemp: false, + }, + { + Err: awserr.New(ErrCodeSerialization, "some error", errors.New("blah")), + IsTemp: false, + }, + { + Err: awserr.New("SomeError", "some error", nil), + IsTemp: false, + }, + { + Err: awserr.New("RequestError", "some error", nil), + IsTemp: true, + }, + } + + for i, c := range cases { + retryable := IsErrorRetryable(c.Err) + if e, a := c.IsTemp, retryable; e != a { + t.Errorf("%d, expect %t temporary error, got %t", i, e, a) + } + } }