diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index bfa814e1ea5..9ddc9589b03 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -1,13 +1,13 @@ ### SDK Features ### SDK Enhancements -* `aws/endpoints`: Expose DNSSuffix for partitions ([#368](https://github.com/aws/aws-sdk-go/pull/368)) +* `aws/endpoints`: Expose DNSSuffix for partitions ([#369](https://github.com/aws/aws-sdk-go-v2/pull/369)) * Exposes the underlying partition metadata's DNSSuffix value via the `DNSSuffix` method on the endpoint's `Partition` type. This allows access to the partition's DNS suffix, e.g. "amazon.com". - * Fixes [#347](https://github.com/aws/aws-sdk-go/issues/347) + * Fixes [#347](https://github.com/aws/aws-sdk-go-v2/issues/347) * `private/protocol`: Add support for parsing fractional timestamp ([#367](https://github.com/aws/aws-sdk-go-v2/pull/367)) * Fixes the SDK's ability to parse fractional unix timestamp values and adds tests. * Fixes [#365](https://github.com/aws/aws-sdk-go-v2/issues/365) -* `aws/ec2metadata`: Add marketplaceProductCodes to EC2 Instance Identity Document +* `aws/ec2metadata`: Add marketplaceProductCodes to EC2 Instance Identity Document ([#374](https://github.com/aws/aws-sdk-go-v2/pull/374)) * Adds `MarketplaceProductCodes` to the EC2 Instance Metadata's Identity Document. The ec2metadata client will now retrieve these values if they are available. * Related to: [aws/aws-sdk-go#2781](https://github.com/aws/aws-sdk-go/issues/2781) @@ -15,4 +15,6 @@ * `aws`: Fixes bug in calculating throttled retry delay ([#373](https://github.com/aws/aws-sdk-go-v2/pull/373)) * The `Retry-After` duration specified in the request is now added to the Retry delay for throttled exception. Adds test for retry delays for throttled exceptions. Fixes bug where the throttled retry's math was off. 
* Fixes [#45](https://github.com/aws/aws-sdk-go-v2/issues/45) - +* `aws` : Adds missing sdk error checking when seeking readers [#379](https://github.com/aws/aws-sdk-go-v2/pull/379). + * Adds support for nonseekable io.Reader. Adds support for streamed payloads for unsigned body request. + * Fixes [#371](https://github.com/aws/aws-sdk-go-v2/issues/371) diff --git a/aws/client_logger.go b/aws/client_logger.go index 5f1c90d5bcc..2b424c068f2 100644 --- a/aws/client_logger.go +++ b/aws/client_logger.go @@ -43,6 +43,8 @@ func (reader *teeReaderCloser) Close() error { func logRequest(r *Request) { logBody := r.Config.LogLevel.Matches(LogDebugWithHTTPBody) + bodySeekable := IsReaderSeekable(r.Body) + dumpedBody, err := httputil.DumpRequestOut(r.HTTPRequest, logBody) if err != nil { r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg, r.Metadata.ServiceName, r.Operation.Name, err)) @@ -50,10 +52,17 @@ func logRequest(r *Request) { } if logBody { - // Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's - // Body as a NoOpCloser and will not be reset after read by the HTTP - // client reader. - r.ResetBody() + if !bodySeekable { + r.SetReaderBody(ReadSeekCloser(r.HTTPRequest.Body)) + } + + // Reset the request body because dumpRequest will re-wrap the + // r.HTTPRequest's Body as a NoOpCloser and will not be reset + // after read by the HTTP client reader. 
+ if err := r.Error; err != nil { + r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg, r.Metadata.ServiceName, r.Operation.Name, err)) + return + } } r.Config.Logger.Log(fmt.Sprintf(logReqMsg, r.Metadata.ServiceName, r.Operation.Name, string(dumpedBody))) diff --git a/aws/client_logger_test.go b/aws/client_logger_test.go index 993f313dbd2..7d78ff550b0 100644 --- a/aws/client_logger_test.go +++ b/aws/client_logger_test.go @@ -2,7 +2,11 @@ package aws import ( "bytes" + "fmt" "io" + "io/ioutil" + "reflect" + "runtime" "testing" ) @@ -55,3 +59,87 @@ func TestLogWriter(t *testing.T) { t.Errorf("Expected %q, but received %q", expected, lw.buf.String()) } } + +func TestLogRequest(t *testing.T) { + cases := []struct { + Body io.ReadSeeker + ExpectBody []byte + LogLevel LogLevel + }{ + { + Body: ReadSeekCloser(bytes.NewBuffer([]byte("body content"))), + ExpectBody: []byte("body content"), + }, + { + Body: ReadSeekCloser(bytes.NewBuffer([]byte("body content"))), + LogLevel: LogDebugWithHTTPBody, + ExpectBody: []byte("body content"), + }, + { + Body: bytes.NewReader([]byte("body content")), + ExpectBody: []byte("body content"), + }, + { + Body: bytes.NewReader([]byte("body content")), + LogLevel: LogDebugWithHTTPBody, + ExpectBody: []byte("body content"), + }, + } + + for i, c := range cases { + var logW bytes.Buffer + req := New( + Config{ + EndpointResolver: ResolveWithEndpointURL("https://endpoint"), + Credentials: AnonymousCredentials, + Logger: &bufLogger{w: &logW}, + LogLevel: c.LogLevel, + Region: "mock-region", + }, + Metadata{ + EndpointsID: "https://mock-service.mock-region.amazonaws.com", + }, + testHandlers(), + nil, + &Operation{ + Name: "APIName", + HTTPMethod: "POST", + HTTPPath: "/", + }, + struct{}{}, nil, + ) + req.SetReaderBody(c.Body) + req.Build() + + logRequest(req) + + b, err := ioutil.ReadAll(req.HTTPRequest.Body) + if err != nil { + t.Fatalf("%d, expect to read SDK request Body", i) + } + + if e, a := c.ExpectBody, b; !reflect.DeepEqual(e, a) { + 
t.Errorf("%d, expect %v body, got %v", i, e, a) + } + } +} + +type bufLogger struct { + w *bytes.Buffer +} + +func (l *bufLogger) Log(args ...interface{}) { + fmt.Fprintln(l.w, args...) +} + +func testHandlers() Handlers { + var handlers Handlers + handler := NamedHandler{ + Name: "core.SDKVersionUserAgentHandler", + Fn: MakeAddToUserAgentHandler(SDKName, SDKVersion, + runtime.Version(), runtime.GOOS, runtime.GOARCH), + } + handlers.Build.PushBackNamed(handler) + + return handlers +} diff --git a/aws/defaults/handlers.go b/aws/defaults/handlers.go index ee8f183dd61..86e155ca35a 100644 --- a/aws/defaults/handlers.go +++ b/aws/defaults/handlers.go @@ -40,9 +40,19 @@ var BuildContentLengthHandler = aws.NamedHandler{Name: "core.BuildContentLengthH case lener: length = int64(body.Len()) case io.Seeker: - r.BodyStart, _ = body.Seek(0, 1) - end, _ := body.Seek(0, 2) - body.Seek(r.BodyStart, 0) // make sure to seek back to original location + var err error + r.BodyStart, err = body.Seek(0, io.SeekCurrent) + if err != nil { + r.Error = awserr.New(aws.ErrCodeSerialization, "failed to determine start of the request body", err) + } + end, err := body.Seek(0, io.SeekEnd) + if err != nil { + r.Error = awserr.New(aws.ErrCodeSerialization, "failed to determine end of the request body", err) + } + _, err = body.Seek(r.BodyStart, io.SeekStart) // make sure to seek back to original location + if err != nil { + r.Error = awserr.New(aws.ErrCodeSerialization, "failed to seek back to the original location", err) + } length = end - r.BodyStart default: panic("Cannot get length of body, must provide `ContentLength`") diff --git a/aws/offset_reader.go b/aws/offset_reader.go index cb4614cd3a5..fa7f2eec81e 100644 --- a/aws/offset_reader.go +++ b/aws/offset_reader.go @@ -13,12 +13,14 @@ type offsetReader struct { closed bool } -func newOffsetReader(buf io.ReadSeeker, offset int64) *offsetReader { +func newOffsetReader(buf io.ReadSeeker, offset int64) (*offsetReader, error) { reader := 
&offsetReader{} - buf.Seek(offset, 0) - + _, err := buf.Seek(offset, io.SeekStart) + if err != nil { + return nil, err + } reader.buf = buf - return reader + return reader, nil } // Close will close the instance of the offset reader's access to @@ -52,7 +54,9 @@ func (o *offsetReader) Seek(offset int64, whence int) (int64, error) { // CloseAndCopy will return a new offsetReader with a copy of the old buffer // and close the old buffer. -func (o *offsetReader) CloseAndCopy(offset int64) *offsetReader { - o.Close() +func (o *offsetReader) CloseAndCopy(offset int64) (*offsetReader, error) { + if err := o.Close(); err != nil { + return nil, err + } return newOffsetReader(o.buf, offset) } diff --git a/aws/offset_reader_test.go b/aws/offset_reader_test.go index 78891ffa20d..26f6819a75e 100644 --- a/aws/offset_reader_test.go +++ b/aws/offset_reader_test.go @@ -21,7 +21,7 @@ func TestOffsetReaderRead(t *testing.T) { t.Errorf("expect %v, got %v", e, a) } if err != nil { - t.Errorf("expect no error, got %v", err) + t.Fatalf("expect no error, got %v", err) } if e, a := buf, tempBuf; !bytes.Equal(e, a) { t.Errorf("expect %v, got %v", e, a) @@ -30,27 +30,30 @@ func TestOffsetReaderRead(t *testing.T) { func TestOffsetReaderSeek(t *testing.T) { buf := []byte("testData") - reader := newOffsetReader(bytes.NewReader(buf), 0) + reader, err := newOffsetReader(bytes.NewReader(buf), 0) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } - orig, err := reader.Seek(0, 1) + orig, err := reader.Seek(0, io.SeekCurrent) if err != nil { - t.Errorf("expect no error, got %v", err) + t.Fatalf("expect no error, got %v", err) } if e, a := int64(0), orig; e != a { t.Errorf("expect %v, got %v", e, a) } - n, err := reader.Seek(0, 2) + n, err := reader.Seek(0, io.SeekEnd) if err != nil { - t.Errorf("expect no error, got %v", err) + t.Fatalf("expect no error, got %v", err) } if e, a := int64(len(buf)), n; e != a { t.Errorf("expect %v, got %v", e, a) } - n, err = reader.Seek(orig, 0) + n, 
err = reader.Seek(orig, io.SeekStart) if err != nil { - t.Errorf("expect no error, got %v", err) + t.Fatalf("expect no error, got %v", err) } if e, a := int64(0), n; e != a { t.Errorf("expect %v, got %v", e, a) @@ -81,8 +84,10 @@ func TestOffsetReaderCloseAndCopy(t *testing.T) { tempBuf := make([]byte, len(buf)) reader := &offsetReader{buf: bytes.NewReader(buf)} - newReader := reader.CloseAndCopy(0) - + newReader, err := reader.CloseAndCopy(0) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } n, err := reader.Read(tempBuf) if e, a := 0, n; e != a { t.Errorf("expect %v, got %v", e, a) @@ -96,7 +101,7 @@ func TestOffsetReaderCloseAndCopy(t *testing.T) { t.Errorf("expect %v, got %v", e, a) } if err != nil { - t.Errorf("expect no error, got %v", err) + t.Fatalf("expect no error, got %v", err) } if e, a := buf, tempBuf; !bytes.Equal(e, a) { t.Errorf("expect %v, got %v", e, a) @@ -108,13 +113,16 @@ func TestOffsetReaderCloseAndCopyOffset(t *testing.T) { tempBuf := make([]byte, len(buf)) reader := &offsetReader{buf: bytes.NewReader(buf)} - newReader := reader.CloseAndCopy(4) + newReader, err := reader.CloseAndCopy(4) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } n, err := newReader.Read(tempBuf) if e, a := n, len(buf)-4; e != a { t.Errorf("expect %v, got %v", e, a) } if err != nil { - t.Errorf("expect no error, got %v", err) + t.Fatalf("expect no error, got %v", err) } expected := []byte{'D', 'a', 't', 'a', 0, 0, 0, 0} diff --git a/aws/request.go b/aws/request.go index 11c561d5249..59475eb12d1 100644 --- a/aws/request.go +++ b/aws/request.go @@ -100,6 +100,7 @@ func New(cfg Config, metadata Metadata, handlers Handlers, // TODO need better way of handling this error... NewRequest should return error. 
endpoint, err := cfg.EndpointResolver.ResolveEndpoint(metadata.EndpointsID, cfg.Region) + if err == nil { // TODO so ugly metadata.Endpoint = endpoint.URL @@ -227,6 +228,9 @@ func (r *Request) SetContext(ctx context.Context) { // WillRetry returns if the request's can be retried. func (r *Request) WillRetry() bool { + if !IsReaderSeekable(r.Body) && r.HTTPRequest.Body != NoBody { + return false + } return r.Error != nil && BoolValue(r.Retryable) && r.RetryCount < r.MaxRetries() } @@ -251,6 +255,17 @@ func (r *Request) SetStringBody(s string) { // SetReaderBody will set the request's body reader. func (r *Request) SetReaderBody(reader io.ReadSeeker) { r.Body = reader + if IsReaderSeekable(reader) { + var err error + // Get the Bodies current offset so retries will start from the same + // initial position. + r.BodyStart, err = reader.Seek(0, io.SeekCurrent) + if err != nil { + r.Error = awserr.New(ErrCodeSerialization, + "failed to determine start of request body", err) + return + } + } r.ResetBody() } @@ -350,12 +365,16 @@ func (r *Request) Sign() error { return r.Error } -func (r *Request) getNextRequestBody() (io.ReadCloser, error) { +func (r *Request) getNextRequestBody() (body io.ReadCloser, err error) { if r.safeBody != nil { r.safeBody.Close() } - r.safeBody = newOffsetReader(r.Body, r.BodyStart) + r.safeBody, err = newOffsetReader(r.Body, r.BodyStart) + if err != nil { + return nil, awserr.New(ErrCodeSerialization, + "failed to get request body error", err) + } // Go 1.8 tightened and clarified the rules code needs to use when building // requests with the http package. Go 1.8 removed the automatic detection @@ -370,12 +389,12 @@ func (r *Request) getNextRequestBody() (io.ReadCloser, error) { // of the SDK if they used that field. 
// // Related golang/go#18257 - l, err := computeBodyLength(r.Body) + l, err := SeekerLen(r.Body) if err != nil { - return nil, awserr.New(ErrCodeSerialization, "failed to compute request body size", err) + return nil, awserr.New(ErrCodeSerialization, + "failed to compute request body size", err) } - var body io.ReadCloser if l == 0 { body = NoBody } else if l > 0 { @@ -388,7 +407,8 @@ func (r *Request) getNextRequestBody() (io.ReadCloser, error) { // Transfer-Encoding: chunked bodies for these methods. // // This would only happen if a ReaderSeekerCloser was used with - // a io.Reader that was not also an io.Seeker. + // a io.Reader that was not also an io.Seeker, or did not + // implement Len() method. switch r.Operation.HTTPMethod { case "GET", "HEAD", "DELETE": body = NoBody @@ -400,42 +420,6 @@ func (r *Request) getNextRequestBody() (io.ReadCloser, error) { return body, nil } -// Attempts to compute the length of the body of the reader using the -// io.Seeker interface. If the value is not seekable because of being -// a ReaderSeekerCloser without an unerlying Seeker -1 will be returned. -// If no error occurs the length of the body will be returned. -func computeBodyLength(r io.ReadSeeker) (int64, error) { - seekable := true - // Determine if the seeker is actually seekable. ReaderSeekerCloser - // hides the fact that a io.Readers might not actually be seekable. - switch v := r.(type) { - case ReaderSeekerCloser: - seekable = v.IsSeeker() - case *ReaderSeekerCloser: - seekable = v.IsSeeker() - } - if !seekable { - return -1, nil - } - - curOffset, err := r.Seek(0, 1) - if err != nil { - return 0, err - } - - endOffset, err := r.Seek(0, 2) - if err != nil { - return 0, err - } - - _, err = r.Seek(curOffset, 0) - if err != nil { - return 0, err - } - - return endOffset - curOffset, nil -} - // GetBody will return an io.ReadSeeker of the Request's underlying // input body with a concurrency safe wrapper. 
func (r *Request) GetBody() io.ReadSeeker { @@ -585,3 +569,72 @@ func shouldRetryCancel(r *Request) bool { errStr != "net/http: request canceled while waiting for connection") } + +// SanitizeHostForHeader removes default port from host and updates request.Host +func SanitizeHostForHeader(r *http.Request) { + host := getHost(r) + port := portOnly(host) + if port != "" && isDefaultPort(r.URL.Scheme, port) { + r.Host = stripPort(host) + } +} + +// Returns host from request +func getHost(r *http.Request) string { + if r.Host != "" { + return r.Host + } + + return r.URL.Host +} + +// Hostname returns u.Host, without any port number. +// +// If Host is an IPv6 literal with a port number, Hostname returns the +// IPv6 literal without the square brackets. IPv6 literals may include +// a zone identifier. +// +// Copied from the Go 1.8 standard library (net/url) +func stripPort(hostport string) string { + colon := strings.IndexByte(hostport, ':') + if colon == -1 { + return hostport + } + if i := strings.IndexByte(hostport, ']'); i != -1 { + return strings.TrimPrefix(hostport[:i], "[") + } + return hostport[:colon] +} + +// Port returns the port part of u.Host, without the leading colon. +// If u.Host doesn't contain a port, Port returns an empty string. +// +// Copied from the Go 1.8 standard library (net/url) +func portOnly(hostport string) string { + colon := strings.IndexByte(hostport, ':') + if colon == -1 { + return "" + } + if i := strings.Index(hostport, "]:"); i != -1 { + return hostport[i+len("]:"):] + } + if strings.Contains(hostport, "]") { + return "" + } + return hostport[colon+len(":"):] +} + +// Returns true if the specified URI is using the standard port +// (i.e. 
port 80 for HTTP URIs or 443 for HTTPS URIs) +func isDefaultPort(scheme, port string) bool { + if port == "" { + return true + } + + lowerCaseScheme := strings.ToLower(scheme) + if (lowerCaseScheme == "http" && port == "80") || (lowerCaseScheme == "https" && port == "443") { + return true + } + + return false +} diff --git a/aws/request_1_8.go b/aws/request_1_8.go index ad81c3d8bfb..71ea240d549 100644 --- a/aws/request_1_8.go +++ b/aws/request_1_8.go @@ -4,6 +4,8 @@ package aws import ( "net/http" + + "github.com/aws/aws-sdk-go-v2/aws/awserr" ) // NoBody is a http.NoBody reader instructing Go HTTP client to not include @@ -24,7 +26,8 @@ var NoBody = http.NoBody func (r *Request) ResetBody() { body, err := r.getNextRequestBody() if err != nil { - r.Error = err + r.Error = awserr.New(ErrCodeSerialization, + "failed to reset request body", err) return } diff --git a/aws/request_resetbody_test.go b/aws/request_resetbody_test.go index 92c4929e729..164494a6345 100644 --- a/aws/request_resetbody_test.go +++ b/aws/request_resetbody_test.go @@ -2,6 +2,7 @@ package aws import ( "bytes" + "io" "net/http" "strings" "testing" @@ -23,21 +24,70 @@ func TestResetBody_WithBodyContents(t *testing.T) { } } -func TestResetBody_ExcludeUnseekableBodyByMethod(t *testing.T) { +type mockReader struct{} + +func (mockReader) Read([]byte) (int, error) { + return 0, io.EOF +} + +func TestResetBody_ExcludeEmptyUnseekableBodyByMethod(t *testing.T) { cases := []struct { Method string + Body io.ReadSeeker IsNoBody bool }{ - {"GET", true}, - {"HEAD", true}, - {"DELETE", true}, - {"PUT", false}, - {"PATCH", false}, - {"POST", false}, + { + Method: "GET", + IsNoBody: true, + Body: ReadSeekCloser(mockReader{}), + }, + { + Method: "HEAD", + IsNoBody: true, + Body: ReadSeekCloser(mockReader{}), + }, + { + Method: "DELETE", + IsNoBody: true, + Body: ReadSeekCloser(mockReader{}), + }, + { + Method: "PUT", + IsNoBody: false, + Body: ReadSeekCloser(mockReader{}), + }, + { + Method: "PATCH", + IsNoBody: 
false, + Body: ReadSeekCloser(mockReader{}), + }, + { + Method: "POST", + IsNoBody: false, + Body: ReadSeekCloser(mockReader{}), + }, + { + Method: "GET", + IsNoBody: false, + Body: ReadSeekCloser(bytes.NewBuffer([]byte("abc"))), + }, + { + Method: "GET", + IsNoBody: true, + Body: ReadSeekCloser(bytes.NewBuffer(nil)), + }, + { + Method: "POST", + IsNoBody: false, + Body: ReadSeekCloser(bytes.NewBuffer([]byte("abc"))), + }, + { + Method: "POST", + IsNoBody: true, + Body: ReadSeekCloser(bytes.NewBuffer(nil)), + }, } - reader := ReadSeekCloser(bytes.NewBuffer([]byte("abc"))) - for i, c := range cases { r := Request{ HTTPRequest: &http.Request{}, @@ -46,7 +96,7 @@ func TestResetBody_ExcludeUnseekableBodyByMethod(t *testing.T) { }, } - r.SetReaderBody(reader) + r.SetReaderBody(c.Body) if a, e := r.HTTPRequest.Body == NoBody, c.IsNoBody; a != e { t.Errorf("%d, expect body to be set to noBody(%t), but was %t", i, e, a) diff --git a/aws/request_test.go b/aws/request_test.go index 12de9fab3d4..92c4eca33be 100644 --- a/aws/request_test.go +++ b/aws/request_test.go @@ -940,3 +940,61 @@ func TestRequest_TemporaryRetry(t *testing.T) { t.Errorf("expect temporary error, was not") } } + +func TestSanitizeHostForHeader(t *testing.T) { + cases := []struct { + url string + expectedRequestHost string + }{ + {"https://estest.us-east-1.es.amazonaws.com:443", "estest.us-east-1.es.amazonaws.com"}, + {"https://estest.us-east-1.es.amazonaws.com", "estest.us-east-1.es.amazonaws.com"}, + {"https://localhost:9200", "localhost:9200"}, + {"http://localhost:80", "localhost"}, + {"http://localhost:8080", "localhost:8080"}, + } + + for _, c := range cases { + r, _ := http.NewRequest("GET", c.url, nil) + aws.SanitizeHostForHeader(r) + + if h := r.Host; h != c.expectedRequestHost { + t.Errorf("expect %v host, got %q", c.expectedRequestHost, h) + } + } +} + +func TestRequestBodySeekFails(t *testing.T) { + s := awstesting.NewClient(unit.Config()) + s.Handlers.Validate.Clear() + s.Handlers.Build.Clear() 
+ + out := &testData{} + r := s.NewRequest(&aws.Operation{Name: "Operation"}, nil, out) + r.SetReaderBody(&stubSeekFail{ + Err: fmt.Errorf("failed to seek reader"), + }) + err := r.Send() + if err == nil { + t.Fatal("expect error, but got none") + } + + aerr := err.(awserr.Error) + if e, a := aws.ErrCodeSerialization, aerr.Code(); e != a { + t.Errorf("expect %v error code, got %v", e, a) + } + +} + +type stubSeekFail struct { + Err error +} + +func (f *stubSeekFail) Read(b []byte) (int, error) { + return len(b), nil +} +func (f *stubSeekFail) ReadAt(b []byte, offset int64) (int, error) { + return len(b), nil +} +func (f *stubSeekFail) Seek(offset int64, mode int) (int64, error) { + return 0, f.Err +} diff --git a/aws/signer/v4/v4.go b/aws/signer/v4/v4.go index ddbade68697..4e4ea0e2b9b 100644 --- a/aws/signer/v4/v4.go +++ b/aws/signer/v4/v4.go @@ -96,25 +96,25 @@ var ignoredHeaders = rules{ var requiredSignedHeaders = rules{ whitelist{ mapRule{ - "Cache-Control": struct{}{}, - "Content-Disposition": struct{}{}, - "Content-Encoding": struct{}{}, - "Content-Language": struct{}{}, - "Content-Md5": struct{}{}, - "Content-Type": struct{}{}, - "Expires": struct{}{}, - "If-Match": struct{}{}, - "If-Modified-Since": struct{}{}, - "If-None-Match": struct{}{}, - "If-Unmodified-Since": struct{}{}, - "Range": struct{}{}, - "X-Amz-Acl": struct{}{}, - "X-Amz-Copy-Source": struct{}{}, - "X-Amz-Copy-Source-If-Match": struct{}{}, - "X-Amz-Copy-Source-If-Modified-Since": struct{}{}, - "X-Amz-Copy-Source-If-None-Match": struct{}{}, - "X-Amz-Copy-Source-If-Unmodified-Since": struct{}{}, - "X-Amz-Copy-Source-Range": struct{}{}, + "Cache-Control": struct{}{}, + "Content-Disposition": struct{}{}, + "Content-Encoding": struct{}{}, + "Content-Language": struct{}{}, + "Content-Md5": struct{}{}, + "Content-Type": struct{}{}, + "Expires": struct{}{}, + "If-Match": struct{}{}, + "If-Modified-Since": struct{}{}, + "If-None-Match": struct{}{}, + "If-Unmodified-Since": struct{}{}, + "Range": 
struct{}{}, + "X-Amz-Acl": struct{}{}, + "X-Amz-Copy-Source": struct{}{}, + "X-Amz-Copy-Source-If-Match": struct{}{}, + "X-Amz-Copy-Source-If-Modified-Since": struct{}{}, + "X-Amz-Copy-Source-If-None-Match": struct{}{}, + "X-Amz-Copy-Source-If-Unmodified-Since": struct{}{}, + "X-Amz-Copy-Source-Range": struct{}{}, "X-Amz-Copy-Source-Server-Side-Encryption-Customer-Algorithm": struct{}{}, "X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key": struct{}{}, "X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key-Md5": struct{}{}, @@ -329,8 +329,11 @@ func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, regi return http.Header{}, err } + aws.SanitizeHostForHeader(ctx.Request) ctx.assignAmzQueryValues() - ctx.build(v4.DisableHeaderHoisting) + if err := ctx.build(v4.DisableHeaderHoisting); err != nil { + return nil, err + } // If the request is not presigned the body should be attached to it. This // prevents the confusion of wanting to send a signed request without @@ -483,11 +486,13 @@ func (v4 *Signer) logSigningInfo(ctx *signingCtx) { v4.Logger.Log(msg) } -func (ctx *signingCtx) build(disableHeaderHoisting bool) { +func (ctx *signingCtx) build(disableHeaderHoisting bool) error { ctx.buildTime() // no depends ctx.buildCredentialString() // no depends - ctx.buildBodyDigest() + if err := ctx.buildBodyDigest(); err != nil { + return err + } unsignedHeaders := ctx.Request.Header if ctx.isPresign { @@ -515,6 +520,7 @@ func (ctx *signingCtx) build(disableHeaderHoisting bool) { } ctx.Request.Header.Set("Authorization", strings.Join(parts, ", ")) } + return nil } func (ctx *signingCtx) buildTime() { @@ -641,7 +647,7 @@ func (ctx *signingCtx) buildSignature() { ctx.signature = hex.EncodeToString(signature) } -func (ctx *signingCtx) buildBodyDigest() { +func (ctx *signingCtx) buildBodyDigest() error { hash := ctx.Request.Header.Get("X-Amz-Content-Sha256") if hash == "" { includeSHA256Header := ctx.unsignedPayload || @@ -656,7 +662,15 @@ func (ctx 
*signingCtx) buildBodyDigest() { } else if ctx.Body == nil { hash = emptyStringSHA256 } else { - hash = hex.EncodeToString(makeSha256Reader(ctx.Body)) + + if !aws.IsReaderSeekable(ctx.Body) { + return fmt.Errorf("cannot use unseekable request body %T, for signed request with body", ctx.Body) + } + hashBytes, err := makeSha256Reader(ctx.Body) + if err != nil { + return err + } + hash = hex.EncodeToString(hashBytes) } if includeSHA256Header { @@ -664,6 +678,7 @@ func (ctx *signingCtx) buildBodyDigest() { } } ctx.bodyDigest = hash + return nil } // isRequestSigned returns if the request is currently signed or presigned @@ -701,13 +716,19 @@ func makeSha256(data []byte) []byte { return hash.Sum(nil) } -func makeSha256Reader(reader io.ReadSeeker) []byte { +func makeSha256Reader(reader io.ReadSeeker) (hashBytes []byte, err error) { hash := sha256.New() - start, _ := reader.Seek(0, 1) - defer reader.Seek(start, 0) + start, err := reader.Seek(0, io.SeekCurrent) + if err != nil { + return nil, err + } + defer func() { + // ensure error is return if unable to seek back to start if payload + _, err = reader.Seek(start, io.SeekStart) + }() io.Copy(hash, reader) - return hash.Sum(nil) + return hash.Sum(nil), nil } const doubleSpace = " " diff --git a/aws/signer/v4/v4_test.go b/aws/signer/v4/v4_test.go index 4e14b16944a..5a1ff7448e6 100644 --- a/aws/signer/v4/v4_test.go +++ b/aws/signer/v4/v4_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "reflect" + "strconv" "strings" "testing" "time" @@ -61,17 +62,43 @@ func TestStripExcessHeaders(t *testing.T) { } func buildRequest(serviceName, region, body string) (*http.Request, io.ReadSeeker) { - endpoint := "https://" + serviceName + "." 
+ region + ".amazonaws.com" + reader := strings.NewReader(body) - req, _ := http.NewRequest("POST", endpoint, reader) + return buildRequestWithBodyReader(serviceName, region, reader) +} + +func buildRequestWithBodyReader(serviceName, region string, body io.Reader) (*http.Request, io.ReadSeeker) { + var bodyLen int + + type lenner interface { + Len() int + } + if lr, ok := body.(lenner); ok { + bodyLen = lr.Len() + } + + endpoint := "https://" + serviceName + "." + region + ".amazonaws.com" + req, _ := http.NewRequest("POST", endpoint, body) req.URL.Opaque = "//example.org/bucket/key-._~,!@#$%^&*()" - req.Header.Add("X-Amz-Target", "prefix.Operation") - req.Header.Add("Content-Type", "application/x-amz-json-1.0") - req.Header.Add("Content-Length", string(len(body))) - req.Header.Add("X-Amz-Meta-Other-Header", "some-value=!@#$%^&* (+)") + req.Header.Set("X-Amz-Target", "prefix.Operation") + req.Header.Set("Content-Type", "application/x-amz-json-1.0") + + if bodyLen > 0 { + req.Header.Set("Content-Length", strconv.Itoa(bodyLen)) + } + + req.Header.Set("X-Amz-Meta-Other-Header", "some-value=!@#$%^&* (+)") req.Header.Add("X-Amz-Meta-Other-Header_With_Underscore", "some-value=!@#$%^&* (+)") req.Header.Add("X-amz-Meta-Other-Header_With_Underscore", "some-value=!@#$%^&* (+)") - return req, reader + + var seeker io.ReadSeeker + if sr, ok := body.(io.ReadSeeker); ok { + seeker = sr + } else { + seeker = aws.ReadSeekCloser(body) + } + + return req, seeker } func buildSigner() Signer { @@ -101,7 +128,7 @@ func TestPresignRequest(t *testing.T) { expectedDate := "19700101T000000Z" expectedHeaders := "content-length;content-type;host;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore" - expectedSig := "ea7856749041f727690c580569738282e99c79355fe0d8f125d3b5535d2ece83" + expectedSig := "122f0b9e091e4ba84286097e2b3404a1f1f4c4aad479adda95b7dff0ccbe5581" expectedCred := "AKID/19700101/us-east-1/dynamodb/aws4_request" expectedTarget := "prefix.Operation" @@ -135,7 +162,7 
@@ func TestPresignBodyWithArrayRequest(t *testing.T) { expectedDate := "19700101T000000Z" expectedHeaders := "content-length;content-type;host;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore" - expectedSig := "fef6002062400bbf526d70f1a6456abc0fb2e213fe1416012737eebd42a62924" + expectedSig := "e3ac55addee8711b76c6d608d762cff285fe8b627a057f8b5ec9268cf82c08b1" expectedCred := "AKID/19700101/us-east-1/dynamodb/aws4_request" expectedTarget := "prefix.Operation" @@ -166,7 +193,7 @@ func TestSignRequest(t *testing.T) { signer.Sign(req, body, "dynamodb", "us-east-1", time.Unix(0, 0)) expectedDate := "19700101T000000Z" - expectedSig := "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore;x-amz-security-token;x-amz-target, Signature=ea766cabd2ec977d955a3c2bae1ae54f4515d70752f2207618396f20aa85bd21" + expectedSig := "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore;x-amz-security-token;x-amz-target, Signature=a518299330494908a70222cec6899f6f32f297f8595f6df1776d998936652ad9" q := req.Header if e, a := expectedSig, q.Get("Authorization"); e != a { @@ -177,6 +204,53 @@ func TestSignRequest(t *testing.T) { } } +func TestSignUnseekableBody(t *testing.T) { + req, body := buildRequestWithBodyReader("mock-service", "mock-region", bytes.NewBuffer([]byte("hello"))) + signer := buildSigner() + _, err := signer.Sign(req, body, "mock-service", "mock-region", time.Now()) + if err == nil { + t.Fatalf("expect error signing request") + } + + if e, a := "unseekable request body", err.Error(); !strings.Contains(a, e) { + t.Errorf("expect %q to be in %q", e, a) + } +} + +func TestSignUnsignedPayloadUnseekableBody(t *testing.T) { + req, body := buildRequestWithBodyReader("mock-service", 
"mock-region", bytes.NewBuffer([]byte("hello"))) + + signer := buildSigner() + signer.UnsignedPayload = true + + _, err := signer.Sign(req, body, "mock-service", "mock-region", time.Now()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + hash := req.Header.Get("X-Amz-Content-Sha256") + if e, a := "UNSIGNED-PAYLOAD", hash; e != a { + t.Errorf("expect %v, got %v", e, a) + } +} + +func TestSignPreComputedHashUnseekableBody(t *testing.T) { + req, body := buildRequestWithBodyReader("mock-service", "mock-region", bytes.NewBuffer([]byte("hello"))) + + signer := buildSigner() + + req.Header.Set("X-Amz-Content-Sha256", "some-content-sha256") + _, err := signer.Sign(req, body, "mock-service", "mock-region", time.Now()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + hash := req.Header.Get("X-Amz-Content-Sha256") + if e, a := "some-content-sha256", hash; e != a { + t.Errorf("expect %v, got %v", e, a) + } +} + func TestSignBodyS3(t *testing.T) { req, body := buildRequest("s3", "us-east-1", "hello") signer := buildSigner() diff --git a/aws/types.go b/aws/types.go index 0e2d864e10a..0dfb25bb810 100644 --- a/aws/types.go +++ b/aws/types.go @@ -5,13 +5,18 @@ import ( "sync" ) -// ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser. Should -// only be used with an io.Reader that is also an io.Seeker. Doing so may -// cause request signature errors, or request body's not sent for GET, HEAD -// and DELETE HTTP methods. +// ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser. Allows the +// SDK to accept an io.Reader that is not also an io.Seeker for unsigned +// streaming payload API operations. // -// Deprecated: Should only be used with io.ReadSeeker. If using for -// S3 PutObject to stream content use s3manager.Uploader instead. 
+// A readSeekCloser wrapping a nonseekable io.Reader used in an API operation's +// input will prevent that operation being retried in the case of +// network errors, and cause operation requests to fail if the operation +// requires payload signing. +// +// Note: If using with S3 PutObject to stream an object upload, the SDK's S3 +// Upload Manager(s3manager.Uploader) provides support for streaming +// with the ability to retry network errors. func ReadSeekCloser(r io.Reader) ReaderSeekerCloser { return ReaderSeekerCloser{r} } @@ -22,10 +27,92 @@ type ReaderSeekerCloser struct { r io.Reader } +// IsReaderSeekable returns if the underlying reader type can be seeked. An +// io.Reader might not actually be seekable if it is the ReaderSeekerCloser +// type. +func IsReaderSeekable(r io.Reader) bool { + switch v := r.(type) { + case ReaderSeekerCloser: + return v.IsSeeker() + case *ReaderSeekerCloser: + return v.IsSeeker() + case io.ReadSeeker: + return true + default: + return false + } +} + +// SeekerLen attempts to get the number of bytes remaining at the seeker's +// current position. Returns the number of bytes remaining or error. +func SeekerLen(s io.Seeker) (int64, error) { + // Determine if the seeker is actually seekable. ReaderSeekerCloser + // hides the fact that an io.Reader might not actually be seekable. + switch v := s.(type) { + case ReaderSeekerCloser: + return v.GetLen() + case *ReaderSeekerCloser: + return v.GetLen() + } + + return seekerLen(s) +} + +// GetLen returns the length of the bytes remaining in the underlying reader. +// Checks first for Len(), then io.Seeker to determine the size of the +// underlying reader. +// +// Will return -1 if the length cannot be determined.
+func (r ReaderSeekerCloser) GetLen() (int64, error) { + if l, ok := r.HasLen(); ok { + return int64(l), nil + } + + if s, ok := r.r.(io.Seeker); ok { + return seekerLen(s) + } + + return -1, nil +} + +func seekerLen(s io.Seeker) (int64, error) { + curOffset, err := s.Seek(0, io.SeekCurrent) + if err != nil { + return 0, err + } + + endOffset, err := s.Seek(0, io.SeekEnd) + if err != nil { + return 0, err + } + + _, err = s.Seek(curOffset, io.SeekStart) + if err != nil { + return 0, err + } + + return endOffset - curOffset, nil +} + +// HasLen returns the length of the underlying reader if the value implements +// the Len() int method. +func (r ReaderSeekerCloser) HasLen() (int, bool) { + type lenner interface { + Len() int + } + + if lr, ok := r.r.(lenner); ok { + return lr.Len(), true + } + + return 0, false +} + // Read reads from the reader up to size of p. The number of bytes read, and // error if it occurred will be returned. // -// If the reader is not an io.Reader zero bytes read, and nil error will be returned. +// If the reader is not an io.Reader zero bytes read, and nil error will be +// returned. // // Performs the same functionality as io.Reader Read func (r ReaderSeekerCloser) Read(p []byte) (int, error) { diff --git a/example/service/mediastoredata/streamingNonSeekableReader/README.md b/example/service/mediastoredata/streamingNonSeekableReader/README.md new file mode 100644 index 00000000000..da426371113 --- /dev/null +++ b/example/service/mediastoredata/streamingNonSeekableReader/README.md @@ -0,0 +1,17 @@ +# Example + +This example demonstrates how you can use the AWS Elemental MediaStore +API PutObject operation with a non-seekable io.Reader. + +# Usage + +The example will create an Elemental MediaStore container, and upload a +contrived non-seekable io.Reader to that container.
Using the SDK's +[aws.ReadSeekCloser](https://docs.aws.amazon.com/sdk-for-go/v2/api/aws/#ReadSeekCloser) +utility for wrapping the `io.Reader` in a value the +[mediastore#PutObjectInput](https://docs.aws.amazon.com/sdk-for-go/v2/api/service/mediastoredata/#PutObjectInput).Body will accept. + +The example will attempt to create the container if it does not already exist. + +```sh +AWS_REGION= go run -tags example main.go diff --git a/example/service/mediastoredata/streamingNonSeekableReader/main.go b/example/service/mediastoredata/streamingNonSeekableReader/main.go new file mode 100644 index 00000000000..23c7bfa7f83 --- /dev/null +++ b/example/service/mediastoredata/streamingNonSeekableReader/main.go @@ -0,0 +1,88 @@ +// +build example + +package main + +import ( + "fmt" + "io" + "log" + "math/rand" + "os" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/mediastore" + "github.com/aws/aws-sdk-go-v2/service/mediastoredata" +) + +func main() { + containerName := os.Args[1] + objectPath := os.Args[2] + + // Create an AWS Elemental MediaStore Data client using default config. + config := aws.Config{} + dataSvc, err := getMediaStoreDataClient(containerName, config) + if err != nil { + log.Fatalf("failed to create client, %v", err) + } + + // Create a random reader to simulate an unseekable reader, wrap the reader + // in an io.LimitReader to prevent uploading forever. + randReader := rand.New(rand.NewSource(0)) + reader := io.LimitReader(randReader, 1024*1024 /* 1MB */) + + // Wrap the unseekable reader with the SDK's ReadSeekCloser. This type will + // allow the SDK to use the nonseekable reader. + body := aws.ReadSeekCloser(reader) + + // Make the PutObject API call with the nonseekable reader, causing the SDK + // to send the request body payload as chunked transfer encoding.
+ dataSvc.PutObjectRequest(&mediastoredata.PutObjectInput{ + Path: &objectPath, + Body: body, + }) + + fmt.Println("object uploaded") +} + +// getMediaStoreDataClient uses the AWS Elemental MediaStore API to get the +// endpoint for a container. If the container endpoint can be retrieved a AWS +// Elemental MediaStore Data client will be created and returned. Otherwise +// error is returned. +func getMediaStoreDataClient(containerName string, config aws.Config) (*mediastoredata.Client, error) { + endpoint, err := containerEndpoint(containerName, config) + if err != nil { + return nil, err + } + config.EndpointResolver = aws.ResolveWithEndpointURL(aws.StringValue(endpoint)) + dataSvc := mediastoredata.New(config) + + return dataSvc, nil +} + +// ContainerEndpoint will attempt to get the endpoint for a container, +// returning error if the container doesn't exist, or is not active within a +// timeout. +func containerEndpoint(name string, config aws.Config) (*string, error) { + for i := 0; i < 3; i++ { + ctrlSvc := mediastore.New(config) + descContainerRequest := ctrlSvc.DescribeContainerRequest(&mediastore.DescribeContainerInput{ + ContainerName: &name, + }) + + descResp, err := descContainerRequest.Send(descContainerRequest.Context()) + if err != nil { + return nil, err + } + + if status := descResp.Container.Status; status != "ACTIVE" { + log.Println("waiting for container to be active, ", status) + time.Sleep(10 * time.Second) + continue + } + + return descResp.Container.Endpoint, nil + } + + return nil, fmt.Errorf("container is not active") +} diff --git a/example/service/s3/loggingUploadObjectReadBehavior/README.md b/example/service/s3/loggingUploadObjectReadBehavior/README.md new file mode 100644 index 00000000000..be1b28fe544 --- /dev/null +++ b/example/service/s3/loggingUploadObjectReadBehavior/README.md @@ -0,0 +1,16 @@ + +# Example + +This example shows how you could wrap the reader of an file being +uploaded to Amazon S3 with a logger that will log the 
usage of the +reader, and print call stacks when the reader's Read, Seek, or ReadAt +methods encounter an error. + +# Usage + +This example uses the bucket name, key, and local file name passed to +upload the local file to S3 as the key into the bucket. + +```sh +AWS_REGION=us-west-2 AWS_PROFILE=default go run . "mybucket" "10MB.file" ./10MB.file +``` \ No newline at end of file diff --git a/example/service/s3/loggingUploadObjectReadBehavior/main.go b/example/service/s3/loggingUploadObjectReadBehavior/main.go new file mode 100644 index 00000000000..69f3c8973b5 --- /dev/null +++ b/example/service/s3/loggingUploadObjectReadBehavior/main.go @@ -0,0 +1,114 @@ +package main + +import ( + "fmt" + "io" + "log" + "os" + "runtime/debug" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3/s3manager" +) + +// Usage: +// go run -tags example +// +// Example: +// AWS_REGION=us-west-2 AWS_PROFILE=default go run . "mybucket" "10MB.file" ./10MB.file +func main() { + + config := aws.Config{} + uploader := s3manager.NewUploader(config) + + file, err := os.Open(os.Args[3]) + if err != nil { + log.Fatalf("failed to open file, %v", err) + } + defer file.Close() + + // Wrap the readSeeker with a logger that will log usage, and stack traces + // on errors. + readLogger := NewReadLogger(file, config.Logger) + + // Upload with read logger + resp, err := uploader.Upload(&s3manager.UploadInput{ + Bucket: &os.Args[1], + Key: &os.Args[2], + Body: readLogger, + }, func(u *s3manager.Uploader) { + u.Concurrency = 1 + u.RequestOptions = append(u.RequestOptions, func(r *aws.Request) { + }) + }) + + fmt.Println(resp, err) +} + +// Logger is a logger used for logging the reader's usage. +type Logger interface { + Log(args ...interface{}) +} + +// ReadSeeker interface provides the interface for a Reader, Seeker, and ReadAt. +type ReadSeeker interface { + io.ReadSeeker + io.ReaderAt +} + +// ReadLogger wraps a reader with logging for access.
+type ReadLogger struct { + reader ReadSeeker + logger Logger +} + +// NewReadLogger returns a ReadLogger that wraps the passed in ReadSeeker (Reader, +// Seeker, ReadAt) with a logger. +func NewReadLogger(r ReadSeeker, logger Logger) *ReadLogger { + return &ReadLogger{ + reader: r, + logger: logger, + } +} + +// Seek offsets the reader's current position for the next read. +func (s *ReadLogger) Seek(offset int64, mode int) (int64, error) { + newOffset, err := s.reader.Seek(offset, mode) + msg := fmt.Sprintf( + "ReadLogger.Seek(offset:%d, mode:%d) (newOffset:%d, err:%v)", + offset, mode, newOffset, err) + if err != nil { + msg += fmt.Sprintf("\n\tStack:\n%s", string(debug.Stack())) + } + + s.logger.Log(msg) + return newOffset, err +} + +// Read attempts to read from the reader, returning the bytes read, or error. +func (s *ReadLogger) Read(b []byte) (int, error) { + n, err := s.reader.Read(b) + msg := fmt.Sprintf( + "ReadLogger.Read(len(bytes):%d) (read:%d, err:%v)", + len(b), n, err) + if err != nil { + msg += fmt.Sprintf("\n\tStack:\n%s", string(debug.Stack())) + } + + s.logger.Log(msg) + return n, err +} + +// ReadAt will read the underlying reader starting at the offset.
+func (s *ReadLogger) ReadAt(b []byte, offset int64) (int, error) { + n, err := s.reader.ReadAt(b, offset) + msg := fmt.Sprintf( + "ReadLogger.ReadAt(len(bytes):%d, offset:%d) (read:%d, err:%v)", + len(b), offset, n, err) + if err != nil { + msg += fmt.Sprintf("\n\tStack:\n%s", string(debug.Stack())) + } + + s.logger.Log(msg) + return n, err +} diff --git a/internal/awstesting/cmd/op_crawler/main.go b/internal/awstesting/cmd/op_crawler/main.go index 5ad9d4cf87d..104b742a470 100644 --- a/internal/awstesting/cmd/op_crawler/main.go +++ b/internal/awstesting/cmd/op_crawler/main.go @@ -436,17 +436,17 @@ func computeBodyLength(r io.ReadSeeker) (int64, error) { return -1, nil } - curOffset, err := r.Seek(0, 1) + curOffset, err := r.Seek(0, io.SeekCurrent) if err != nil { return 0, err } - endOffset, err := r.Seek(0, 2) + endOffset, err := r.Seek(0, io.SeekEnd) if err != nil { return 0, err } - _, err = r.Seek(curOffset, 0) + _, err = r.Seek(curOffset, io.SeekStart) if err != nil { return 0, err } diff --git a/private/model/api/customization_passes.go b/private/model/api/customization_passes.go index c1eb4831105..7cb793395e9 100644 --- a/private/model/api/customization_passes.go +++ b/private/model/api/customization_passes.go @@ -55,13 +55,13 @@ func (a *API) customizationPasses() { // Backfill the authentication type for cognito identity and sts. // Removes the need for the customizations in these services. 
- "cognitoidentity": backfillAuthType("none", + "cognitoidentity": backfillAuthType(NoneAuthType, "GetId", "GetOpenIdToken", "UnlinkIdentity", "GetCredentialsForIdentity", ), - "sts": backfillAuthType("none", + "sts": backfillAuthType(NoneAuthType, "AssumeRoleWithSAML", "AssumeRoleWithWebIdentity", ), @@ -226,7 +226,7 @@ func rdsCustomizations(a *API) { } } } -func backfillAuthType(typ string, opNames ...string) func(*API) { +func backfillAuthType(typ AuthType, opNames ...string) func(*API) { return func(a *API) { for _, opName := range opNames { op, ok := a.Operations[opName] diff --git a/private/model/api/load.go b/private/model/api/load.go index db7e3dbca79..356bbe5abde 100644 --- a/private/model/api/load.go +++ b/private/model/api/load.go @@ -182,6 +182,7 @@ func (a *API) Setup() { //a.findEndpointDiscoveryOp() a.suppressEventStreams() a.customizationPasses() + a.injectUnboundedOutputStreaming() if !a.NoRemoveUnusedShapes { a.removeUnusedShapes() diff --git a/private/model/api/operation.go b/private/model/api/operation.go index e1bf3651f1e..2cbf0ac28c7 100644 --- a/private/model/api/operation.go +++ b/private/model/api/operation.go @@ -22,8 +22,8 @@ type Operation struct { OutputRef ShapeRef `json:"output"` ErrorRefs []ShapeRef `json:"errors"` Paginator *Paginator - Deprecated bool `json:"deprecated"` - AuthType string `json:"authtype"` + Deprecated bool `json:"deprecated"` + AuthType AuthType `json:"authtype"` imports map[string]bool CustomBuildHandlers []string `json:"-"` Endpoint *EndpointTrait `json:"endpoint"` @@ -54,14 +54,25 @@ func (o *Operation) HasOutput() bool { return o.OutputRef.ShapeName != "" } +// AuthType provides the enumeration of AuthType trait. +type AuthType string + +// Enumeration values for AuthType trait +const ( + NoneAuthType AuthType = "none" + V4UnsignedBodyAuthType AuthType = "v4-unsigned-body" +) + // GetSigner returns the signer to use for a request. 
func (o *Operation) GetSigner() string { buf := bytes.NewBuffer(nil) switch o.AuthType { - case "none": + case NoneAuthType: + o.API.AddSDKImport("aws") + buf.WriteString("req.Config.Credentials = aws.AnonymousCredentials") - case "v4-unsigned-body": + case V4UnsignedBodyAuthType: o.API.AddSDKImport("aws/signer/v4") buf.WriteString("req.Handlers.Sign.Remove(v4.SignRequestHandler)\n") diff --git a/private/model/api/passes.go b/private/model/api/passes.go index 4de222a1152..cbc3a8d196e 100644 --- a/private/model/api/passes.go +++ b/private/model/api/passes.go @@ -377,3 +377,24 @@ func (a *API) setMetadataEndpointsKey() { a.Metadata.EndpointsID = a.Metadata.EndpointPrefix } } + +func (a *API) injectUnboundedOutputStreaming() { + for _, op := range a.Operations { + if op.AuthType != V4UnsignedBodyAuthType { + continue + } + for _, ref := range op.InputRef.Shape.MemberRefs { + if ref.Streaming || ref.Shape.Streaming { + if len(ref.Documentation) != 0 { + ref.Documentation += ` +//` + } + ref.Documentation += ` +// To use an non-seekable io.Reader for this request wrap the io.Reader with +// "aws.ReadSeekCloser". The SDK will not retry request errors for non-seekable +// readers. This will allow the SDK to send the reader's payload as chunked +// transfer encoding.` + } + } + } +} diff --git a/service/glacier/treehash.go b/service/glacier/treehash.go index e1ee0aa5b7b..4b7a88d8ebd 100644 --- a/service/glacier/treehash.go +++ b/service/glacier/treehash.go @@ -18,8 +18,8 @@ type Hash struct { // // See http://docs.aws.amazon.com/amazonglacier/latest/dev/checksum-calculations.html for more information. 
func ComputeHashes(r io.ReadSeeker) Hash { - r.Seek(0, 0) // Read the whole stream - defer r.Seek(0, 0) // Rewind stream at end + start, _ := r.Seek(0, io.SeekCurrent) // Read the whole stream + defer r.Seek(start, io.SeekStart) // Rewind stream at end buf := make([]byte, bufsize) hashes := [][]byte{} diff --git a/service/glacier/treehash_test.go b/service/glacier/treehash_test.go index 3cd0d6db4d6..351806894bb 100644 --- a/service/glacier/treehash_test.go +++ b/service/glacier/treehash_test.go @@ -13,7 +13,7 @@ func ExampleComputeHashes() { r := testCreateReader() h := glacier.ComputeHashes(r) - n, _ := r.Seek(0, 1) // Check position after checksumming + n, _ := r.Seek(0, io.SeekCurrent) // Check position after checksumming fmt.Printf("linear: %x\n", h.LinearHash) fmt.Printf("tree: %x\n", h.TreeHash) diff --git a/service/lexruntimeservice/api_op_PostContent.go b/service/lexruntimeservice/api_op_PostContent.go index ae392f6f1e7..38940b16522 100644 --- a/service/lexruntimeservice/api_op_PostContent.go +++ b/service/lexruntimeservice/api_op_PostContent.go @@ -70,6 +70,11 @@ type PostContentInput struct { // that captures all of the audio data before sending. In general, you get better // performance if you stream audio data rather than buffering the data locally. // + // To use an non-seekable io.Reader for this request wrap the io.Reader with + // "aws.ReadSeekCloser". The SDK will not retry request errors for non-seekable + // readers. This will allow the SDK to send the reader's payload as chunked + // transfer encoding. + // // InputStream is a required field InputStream io.ReadSeeker `locationName:"inputStream" type:"blob" required:"true"` diff --git a/service/mediastoredata/api_op_PutObject.go b/service/mediastoredata/api_op_PutObject.go index 52dde487c90..bfe5ce1e4a1 100644 --- a/service/mediastoredata/api_op_PutObject.go +++ b/service/mediastoredata/api_op_PutObject.go @@ -18,6 +18,11 @@ type PutObjectInput struct { // The bytes to be stored. 
// + // To use an non-seekable io.Reader for this request wrap the io.Reader with + // "aws.ReadSeekCloser". The SDK will not retry request errors for non-seekable + // readers. This will allow the SDK to send the reader's payload as chunked + // transfer encoding. + // // Body is a required field Body io.ReadSeeker `type:"blob" required:"true"` diff --git a/service/s3/content_md5.go b/service/s3/content_md5.go index f03445e59f6..df87bcbea74 100644 --- a/service/s3/content_md5.go +++ b/service/s3/content_md5.go @@ -22,7 +22,7 @@ func contentMD5(r *request.Request) { r.Error = awserr.New("ContentMD5", "failed to read body", err) return } - _, err = r.Body.Seek(0, 0) + _, err = r.Body.Seek(0, io.SeekStart) if err != nil { r.Error = awserr.New("ContentMD5", "failed to seek body", err) return diff --git a/service/s3/s3crypto/encryption_client.go b/service/s3/s3crypto/encryption_client.go index 787eba8c5b9..4581488e4df 100644 --- a/service/s3/s3crypto/encryption_client.go +++ b/service/s3/s3crypto/encryption_client.go @@ -72,12 +72,11 @@ func (c *EncryptionClient) PutObjectRequest(input *s3.PutObjectInput) s3.PutObje req := c.S3Client.PutObjectRequest(input) // Get Size of file - n, err := input.Body.Seek(0, 2) + n, err := aws.SeekerLen(input.Body) if err != nil { req.Error = err return req } - input.Body.Seek(0, 0) dst, err := getWriterStore(req.Request, c.TempFolderPath, n >= c.MinFileSize) if err != nil { @@ -116,7 +115,7 @@ func (c *EncryptionClient) PutObjectRequest(input *s3.PutObjectInput) s3.PutObje shaHex := hex.EncodeToString(sha.GetValue()) req.HTTPRequest.Header.Set("X-Amz-Content-Sha256", shaHex) - dst.Seek(0, 0) + dst.Seek(0, io.SeekStart) input.Body = dst err = c.SaveStrategy.Save(env, r) diff --git a/service/s3/s3crypto/helper_test.go b/service/s3/s3crypto/helper_test.go index 7f06bda0267..f981992be71 100644 --- a/service/s3/s3crypto/helper_test.go +++ b/service/s3/s3crypto/helper_test.go @@ -2,6 +2,7 @@ package s3crypto import ( "bytes" + "io" "testing" 
) @@ -55,7 +56,7 @@ func TestBytesReadWriteSeeker_Write(t *testing.T) { func TestBytesReadWriteSeeker_Seek(t *testing.T) { b := &bytesReadWriteSeeker{[]byte{1, 2, 3}, 0} expected := []byte{2, 3} - m, err := b.Seek(1, 0) + m, err := b.Seek(1, io.SeekStart) if err != nil { t.Errorf("expected no error, but received %v", err) diff --git a/service/s3/s3manager/upload.go b/service/s3/s3manager/upload.go index 2e4b4026f48..7300b6ae471 100644 --- a/service/s3/s3manager/upload.go +++ b/service/s3/s3manager/upload.go @@ -351,7 +351,9 @@ type uploader struct { // internal logic for deciding whether to upload a single part or use a // multipart upload. func (u *uploader) upload() (*UploadOutput, error) { - u.init() + if err := u.init(); err != nil { + return nil, awserr.New("ReadRequestBody", "unable to initillize upload", err) + } if u.cfg.PartSize < MinUploadPartSize { msg := fmt.Sprintf("part size must be at least %d bytes", MinUploadPartSize) @@ -371,7 +373,7 @@ func (u *uploader) upload() (*UploadOutput, error) { } // init will initialize all default options. -func (u *uploader) init() { +func (u *uploader) init() error { if u.cfg.Concurrency == 0 { u.cfg.Concurrency = DefaultUploadConcurrency } @@ -380,22 +382,19 @@ func (u *uploader) init() { } // Try to get the total size for some optimizations - u.initSize() + return u.initSize() } // initSize tries to detect the total stream size, setting u.totalSize. If // the size is not known, totalSize is set to -1. -func (u *uploader) initSize() { +func (u *uploader) initSize() error { u.totalSize = -1 switch r := u.in.Body.(type) { case io.Seeker: - pos, _ := r.Seek(0, 1) - defer r.Seek(pos, 0) - - n, err := r.Seek(0, 2) + n, err := aws.SeekerLen(r) if err != nil { - return + return err } u.totalSize = n @@ -407,6 +406,7 @@ func (u *uploader) initSize() { u.cfg.PartSize = (u.totalSize / int64(u.cfg.MaxUploadParts)) + 1 } } + return nil } // nextReader returns a seekable reader representing the next packet of data. 
diff --git a/service/s3/statusok_error.go b/service/s3/statusok_error.go index 050b631a608..a0d410a8573 100644 --- a/service/s3/statusok_error.go +++ b/service/s3/statusok_error.go @@ -2,6 +2,7 @@ package s3 import ( "bytes" + "io" "io/ioutil" "net/http" @@ -17,7 +18,7 @@ func copyMultipartStatusOKUnmarhsalError(r *request.Request) { } body := bytes.NewReader(b) r.HTTPResponse.Body = ioutil.NopCloser(body) - defer body.Seek(0, 0) + defer body.Seek(0, io.SeekStart) if body.Len() == 0 { // If there is no body don't attempt to parse the body.