diff --git a/transport/http/checksum_middleware.go b/transport/http/checksum_middleware.go index 2ec7cbaee..bc4ad6e79 100644 --- a/transport/http/checksum_middleware.go +++ b/transport/http/checksum_middleware.go @@ -45,6 +45,11 @@ func (m *contentMD5Checksum) HandleBuild( stream := req.GetStream() // compute checksum if payload is explicit if stream != nil { + if !req.IsStreamSeekable() { + return out, metadata, fmt.Errorf( + "unseekable stream is not supported for computing md5 checksum") + } + v, err := computeMD5Checksum(stream) if err != nil { return out, metadata, fmt.Errorf("error computing md5 checksum, %w", err) diff --git a/transport/http/checksum_middleware_test.go b/transport/http/checksum_middleware_test.go index 6a6dfc89b..a911d68d9 100644 --- a/transport/http/checksum_middleware_test.go +++ b/transport/http/checksum_middleware_test.go @@ -35,7 +35,7 @@ func TestChecksumMiddleware(t *testing.T) { "nil body": {}, "unseekable payload": { payload: bytes.NewBuffer([]byte(`xyz`)), - expectError: "error rewinding request stream", + expectError: "unseekable stream is not supported", }, } @@ -61,6 +61,7 @@ func TestChecksumMiddleware(t *testing.T) { if e, a := c.expectError, err.Error(); !strings.Contains(a, e) { t.Fatalf("expect error to contain %q, got %v", e, a) } + return } else if err != nil { t.Fatalf("expect no error, got %v", err) } diff --git a/transport/http/request.go b/transport/http/request.go index 6a759ff3e..ffac684f4 100644 --- a/transport/http/request.go +++ b/transport/http/request.go @@ -45,19 +45,23 @@ func (r *Request) Clone() *Request { // to the request and ok set. If the length cannot be determined, an error will // be returned. func (r *Request) StreamLength() (size int64, ok bool, err error) { - if r.stream == nil { + return streamLength(r.stream, r.isStreamSeekable, r.streamStartPos) +} + +func streamLength(stream io.Reader, seekable bool, startPos int64) (size int64, ok bool, err error) { + if stream == nil { return 0, true, nil } - if l, ok := r.stream.(interface{ Len() int }); ok { + if l, ok := stream.(interface{ Len() int }); ok { return int64(l.Len()), true, nil } - if !r.isStreamSeekable { + if !seekable { return 0, false, nil } - s := r.stream.(io.Seeker) + s := stream.(io.Seeker) endOffset, err := s.Seek(0, io.SeekEnd) if err != nil { return 0, false, err @@ -69,12 +73,12 @@ func (r *Request) StreamLength() (size int64, ok bool, err error) { // file, and wants to skip the first N bytes uploading the rest. The // application would move the file's offset N bytes, then hand it off to // the SDK to send the remaining. The SDK should respect that initial offset. - _, err = s.Seek(r.streamStartPos, io.SeekStart) + _, err = s.Seek(startPos, io.SeekStart) if err != nil { return 0, false, err } - return endOffset - r.streamStartPos, true, nil + return endOffset - startPos, true, nil } // RewindStream will rewind the io.Reader to the relative start position if it @@ -103,8 +107,9 @@ func (r *Request) IsStreamSeekable() bool { return r.isStreamSeekable } -// SetStream returns a clone of the request with the stream set to the provided reader. -// May return an error if the provided reader is seekable but returns an error. +// SetStream returns a clone of the request with the stream set to the provided +// reader. May return an error if the provided reader is seekable but returns +// an error. func (r *Request) SetStream(reader io.Reader) (rc *Request, err error) { rc = r.Clone() @@ -112,18 +117,31 @@ func (r *Request) SetStream(reader io.Reader) (rc *Request, err error) { reader = nil } + var isStreamSeekable bool + var streamStartPos int64 switch v := reader.(type) { case io.Seeker: n, err := v.Seek(0, io.SeekCurrent) if err != nil { return r, err } - rc.isStreamSeekable = true - rc.streamStartPos = n + isStreamSeekable = true + streamStartPos = n default: - rc.isStreamSeekable = false + // If the stream length can be determined, and is determined to be empty, + // use a nil stream to prevent confusion between empty vs not-empty + // streams. + length, ok, err := streamLength(reader, false, 0) + if err != nil { + return nil, err + } else if ok && length == 0 { + reader = nil + } } + rc.stream = reader + rc.isStreamSeekable = isStreamSeekable + rc.streamStartPos = streamStartPos return rc, err } diff --git a/transport/http/request_test.go b/transport/http/request_test.go index 602dab1d2..af88fbc36 100644 --- a/transport/http/request_test.go +++ b/transport/http/request_test.go @@ -20,8 +20,12 @@ func TestRequestRewindable(t *testing.T) { "rewindable": { Stream: bytes.NewReader([]byte{}), }, - "not rewindable": { - Stream: bytes.NewBuffer([]byte{}), + "empty not rewindable": { + Stream: bytes.NewBuffer([]byte{}), + // ExpectErr: "stream is not seekable", + }, + "not empty not rewindable": { + Stream: bytes.NewBuffer([]byte("abc123")), ExpectErr: "stream is not seekable", }, "nil stream": {}, @@ -121,7 +125,7 @@ func TestRequestSetStream(t *testing.T) { }, "empty unseekable stream": { reader: bytes.NewBuffer([]byte{}), - expectNilStream: false, + expectNilStream: true, expectNilBody: true, }, "empty seekable stream": {