diff --git a/backend/azure/azure.go b/backend/azure/azure.go index b0b117d4..0d3fe230 100644 --- a/backend/azure/azure.go +++ b/backend/azure/azure.go @@ -411,7 +411,7 @@ func (az *Azure) DeleteBucketTagging(ctx context.Context, bucket string) error { return az.PutBucketTagging(ctx, bucket, nil) } -func (az *Azure) GetObject(ctx context.Context, input *s3.GetObjectInput, writer io.Writer) (*s3.GetObjectOutput, error) { +func (az *Azure) GetObject(ctx context.Context, input *s3.GetObjectInput) (*s3.GetObjectOutput, error) { var opts *azblob.DownloadStreamOptions if *input.Range != "" { offset, count, err := backend.ParseRange(0, *input.Range) @@ -429,12 +429,6 @@ func (az *Azure) GetObject(ctx context.Context, input *s3.GetObjectInput, writer if err != nil { return nil, azureErrToS3Err(err) } - defer blobDownloadResponse.Body.Close() - - _, err = io.Copy(writer, blobDownloadResponse.Body) - if err != nil { - return nil, fmt.Errorf("copy data: %w", err) - } var tagcount int32 if blobDownloadResponse.TagCount != nil { @@ -451,6 +445,7 @@ func (az *Azure) GetObject(ctx context.Context, input *s3.GetObjectInput, writer Metadata: parseAzMetadata(blobDownloadResponse.Metadata), TagCount: &tagcount, ContentRange: blobDownloadResponse.ContentRange, + Body: blobDownloadResponse.Body, }, nil } diff --git a/backend/backend.go b/backend/backend.go index 6ebc705d..2951887d 100644 --- a/backend/backend.go +++ b/backend/backend.go @@ -18,7 +18,6 @@ import ( "bufio" "context" "fmt" - "io" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3/types" @@ -60,7 +59,7 @@ type Backend interface { // standard object operations PutObject(context.Context, *s3.PutObjectInput) (string, error) HeadObject(context.Context, *s3.HeadObjectInput) (*s3.HeadObjectOutput, error) - GetObject(context.Context, *s3.GetObjectInput, io.Writer) (*s3.GetObjectOutput, error) + GetObject(context.Context, *s3.GetObjectInput) (*s3.GetObjectOutput, error) GetObjectAcl(context.Context, *s3.GetObjectAclInput) (*s3.GetObjectAclOutput, error) GetObjectAttributes(context.Context, *s3.GetObjectAttributesInput) (s3response.GetObjectAttributesResult, error) CopyObject(context.Context, *s3.CopyObjectInput) (*s3.CopyObjectOutput, error) @@ -180,7 +179,7 @@ func (BackendUnsupported) PutObject(context.Context, *s3.PutObjectInput) (string func (BackendUnsupported) HeadObject(context.Context, *s3.HeadObjectInput) (*s3.HeadObjectOutput, error) { return nil, s3err.GetAPIError(s3err.ErrNotImplemented) } -func (BackendUnsupported) GetObject(context.Context, *s3.GetObjectInput, io.Writer) (*s3.GetObjectOutput, error) { +func (BackendUnsupported) GetObject(context.Context, *s3.GetObjectInput) (*s3.GetObjectOutput, error) { return nil, s3err.GetAPIError(s3err.ErrNotImplemented) } func (BackendUnsupported) GetObjectAcl(context.Context, *s3.GetObjectAclInput) (*s3.GetObjectAclOutput, error) { diff --git a/backend/common.go b/backend/common.go index 42c449c6..57ebe5ce 100644 --- a/backend/common.go +++ b/backend/common.go @@ -18,7 +18,9 @@ import ( "crypto/md5" "encoding/hex" "fmt" + "io" "net/http" + "os" "strconv" "strings" "time" @@ -128,3 +130,16 @@ func md5String(data []byte) string { sum := md5.Sum(data) return hex.EncodeToString(sum[:]) } + +type FileSectionReadCloser struct { + R io.Reader + F *os.File +} + +func (f *FileSectionReadCloser) Read(p []byte) (int, error) { + return f.R.Read(p) +} + +func (f *FileSectionReadCloser) Close() error { + return f.F.Close() +} diff --git a/backend/posix/posix.go b/backend/posix/posix.go index f92d0fc4..70b3d7f2 100644 --- a/backend/posix/posix.go +++ b/backend/posix/posix.go @@ -1593,7 +1593,7 @@ func (p *Posix) DeleteObjects(ctx context.Context, input *s3.DeleteObjectsInput) }, nil } -func (p *Posix) GetObject(_ context.Context, input *s3.GetObjectInput, writer io.Writer) (*s3.GetObjectOutput, error) { +func (p *Posix) GetObject(_ context.Context, input *s3.GetObjectInput) (*s3.GetObjectOutput, error) { if input.Bucket == nil { return nil, s3err.GetAPIError(s3err.ErrInvalidBucketName) } @@ -1637,11 +1637,11 @@ func (p *Posix) GetObject(_ context.Context, input *s3.GetObjectInput, writer io } if length == -1 { - length = objSize - startOffset + 1 + length = objSize - startOffset } - if startOffset+length > objSize+1 { - length = objSize - startOffset + 1 + if startOffset+length > objSize { + length = objSize - startOffset } var contentRange string @@ -1684,21 +1684,6 @@ func (p *Posix) GetObject(_ context.Context, input *s3.GetObjectInput, writer io }, nil } - f, err := os.Open(objPath) - if errors.Is(err, fs.ErrNotExist) { - return nil, s3err.GetAPIError(s3err.ErrNoSuchKey) - } - if err != nil { - return nil, fmt.Errorf("open object: %w", err) - } - defer f.Close() - - rdr := io.NewSectionReader(f, startOffset, length) - _, err = io.Copy(writer, rdr) - if err != nil { - return nil, fmt.Errorf("copy data: %w", err) - } - userMetaData := make(map[string]string) contentType, contentEncoding := p.loadUserMetaData(bucket, object, userMetaData) @@ -1719,6 +1704,16 @@ func (p *Posix) GetObject(_ context.Context, input *s3.GetObjectInput, writer io tagCount = &tgCount } + f, err := os.Open(objPath) + if errors.Is(err, fs.ErrNotExist) { + return nil, s3err.GetAPIError(s3err.ErrNoSuchKey) + } + if err != nil { + return nil, fmt.Errorf("open object: %w", err) + } + + rdr := io.NewSectionReader(f, startOffset, length) + return &s3.GetObjectOutput{ AcceptRanges: &acceptRange, ContentLength: &length, @@ -1729,6 +1724,7 @@ func (p *Posix) GetObject(_ context.Context, input *s3.GetObjectInput, writer io Metadata: userMetaData, TagCount: tagCount, ContentRange: &contentRange, + Body: &backend.FileSectionReadCloser{R: rdr, F: f}, }, nil } diff --git a/backend/s3proxy/s3.go b/backend/s3proxy/s3.go index 6b2010c5..7d7c6fbd 100644 --- a/backend/s3proxy/s3.go +++ b/backend/s3proxy/s3.go @@ -314,17 +314,11 @@ func (s *S3Proxy) HeadObject(ctx context.Context, input *s3.HeadObjectInput) (*s return out, handleError(err) } -func (s *S3Proxy) GetObject(ctx context.Context, input *s3.GetObjectInput, w io.Writer) (*s3.GetObjectOutput, error) { +func (s *S3Proxy) GetObject(ctx context.Context, input *s3.GetObjectInput) (*s3.GetObjectOutput, error) { output, err := s.client.GetObject(ctx, input) if err != nil { return nil, handleError(err) } - defer output.Body.Close() - - _, err = io.Copy(w, output.Body) - if err != nil { - return nil, err - } return output, nil } diff --git a/backend/scoutfs/scoutfs.go b/backend/scoutfs/scoutfs.go index 4c52e92a..606475f4 100644 --- a/backend/scoutfs/scoutfs.go +++ b/backend/scoutfs/scoutfs.go @@ -589,7 +589,7 @@ func (s *ScoutFS) retrieveUploadId(bucket, object string) (string, [32]byte, err return entries[0].Name(), sum, nil } -func (s *ScoutFS) GetObject(_ context.Context, input *s3.GetObjectInput, writer io.Writer) (*s3.GetObjectOutput, error) { +func (s *ScoutFS) GetObject(_ context.Context, input *s3.GetObjectInput) (*s3.GetObjectOutput, error) { bucket := *input.Bucket object := *input.Key acceptRange := *input.Range @@ -658,13 +658,8 @@ func (s *ScoutFS) GetObject(_ context.Context, input *s3.GetObjectInput, writer if err != nil { return nil, fmt.Errorf("open object: %w", err) } - defer f.Close() rdr := io.NewSectionReader(f, startOffset, length) - _, err = io.Copy(writer, rdr) - if err != nil { - return nil, fmt.Errorf("copy data: %w", err) - } userMetaData := make(map[string]string) @@ -694,6 +689,7 @@ func (s *ScoutFS) GetObject(_ context.Context, input *s3.GetObjectInput, writer TagCount: &tagCount, StorageClass: types.StorageClassStandard, ContentRange: &contentRange, + Body: &backend.FileSectionReadCloser{R: rdr, F: f}, }, nil } diff --git a/s3api/controllers/backend_moq_test.go b/s3api/controllers/backend_moq_test.go index 3a8aef31..f5e59575 100644 --- a/s3api/controllers/backend_moq_test.go +++ b/s3api/controllers/backend_moq_test.go @@ -10,7 +10,6 @@ import ( "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/versity/versitygw/backend" "github.com/versity/versitygw/s3response" - "io" "sync" ) @@ -78,7 +77,7 @@ var _ backend.Backend = &BackendMock{} // GetBucketVersioningFunc: func(contextMoqParam context.Context, bucket string) (*s3.GetBucketVersioningOutput, error) { // panic("mock out the GetBucketVersioning method") // }, -// GetObjectFunc: func(contextMoqParam context.Context, getObjectInput *s3.GetObjectInput, writer io.Writer) (*s3.GetObjectOutput, error) { +// GetObjectFunc: func(contextMoqParam context.Context, getObjectInput *s3.GetObjectInput) (*s3.GetObjectOutput, error) { // panic("mock out the GetObject method") // }, // GetObjectAclFunc: func(contextMoqParam context.Context, getObjectAclInput *s3.GetObjectAclInput) (*s3.GetObjectAclOutput, error) { @@ -239,7 +238,7 @@ type BackendMock struct { GetBucketVersioningFunc func(contextMoqParam context.Context, bucket string) (*s3.GetBucketVersioningOutput, error) // GetObjectFunc mocks the GetObject method. - GetObjectFunc func(contextMoqParam context.Context, getObjectInput *s3.GetObjectInput, writer io.Writer) (*s3.GetObjectOutput, error) + GetObjectFunc func(contextMoqParam context.Context, getObjectInput *s3.GetObjectInput) (*s3.GetObjectOutput, error) // GetObjectAclFunc mocks the GetObjectAcl method. GetObjectAclFunc func(contextMoqParam context.Context, getObjectAclInput *s3.GetObjectAclInput) (*s3.GetObjectAclOutput, error) @@ -477,8 +476,6 @@ type BackendMock struct { ContextMoqParam context.Context // GetObjectInput is the getObjectInput argument value. GetObjectInput *s3.GetObjectInput - // Writer is the writer argument value. - Writer io.Writer } // GetObjectAcl holds details about calls to the GetObjectAcl method. GetObjectAcl []struct { @@ -1449,23 +1446,21 @@ func (mock *BackendMock) GetBucketVersioningCalls() []struct { } // GetObject calls GetObjectFunc. -func (mock *BackendMock) GetObject(contextMoqParam context.Context, getObjectInput *s3.GetObjectInput, writer io.Writer) (*s3.GetObjectOutput, error) { +func (mock *BackendMock) GetObject(contextMoqParam context.Context, getObjectInput *s3.GetObjectInput) (*s3.GetObjectOutput, error) { if mock.GetObjectFunc == nil { panic("BackendMock.GetObjectFunc: method is nil but Backend.GetObject was just called") } callInfo := struct { ContextMoqParam context.Context GetObjectInput *s3.GetObjectInput - Writer io.Writer }{ ContextMoqParam: contextMoqParam, GetObjectInput: getObjectInput, - Writer: writer, } mock.lockGetObject.Lock() mock.calls.GetObject = append(mock.calls.GetObject, callInfo) mock.lockGetObject.Unlock() - return mock.GetObjectFunc(contextMoqParam, getObjectInput, writer) + return mock.GetObjectFunc(contextMoqParam, getObjectInput) } // GetObjectCalls gets all the calls that were made to GetObject. @@ -1475,12 +1470,10 @@ func (mock *BackendMock) GetObject(contextMoqParam context.Context, getObjectInp func (mock *BackendMock) GetObjectCalls() []struct { ContextMoqParam context.Context GetObjectInput *s3.GetObjectInput - Writer io.Writer } { var calls []struct { ContextMoqParam context.Context GetObjectInput *s3.GetObjectInput - Writer io.Writer } mock.lockGetObject.RLock() calls = mock.calls.GetObject diff --git a/s3api/controllers/base.go b/s3api/controllers/base.go index ce1475e3..edfd6fe6 100644 --- a/s3api/controllers/base.go +++ b/s3api/controllers/base.go @@ -402,7 +402,7 @@ func (c S3ApiController) GetActions(ctx *fiber.Ctx) error { Key: &key, Range: &acceptRange, VersionId: &versionId, - }, ctx.Response().BodyWriter()) + }) if err != nil { return SendResponse(ctx, err, &MetaOpts{ @@ -412,15 +412,6 @@ func (c S3ApiController) GetActions(ctx *fiber.Ctx) error { BucketOwner: parsedAcl.Owner, }) } - if res == nil { - return SendResponse(ctx, fmt.Errorf("get object nil response"), - &MetaOpts{ - Logger: c.logger, - MetricsMng: c.mm, - Action: metrics.ActionGetObject, - BucketOwner: parsedAcl.Owner, - }) - } utils.SetMetaHeaders(ctx, res.Metadata) var lastmod string @@ -429,10 +420,6 @@ func (c S3ApiController) GetActions(ctx *fiber.Ctx) error { } utils.SetResponseHeaders(ctx, []utils.CustomHeader{ - { - Key: "Content-Length", - Value: fmt.Sprint(getint64(res.ContentLength)), - }, { Key: "Content-Type", Value: getstring(res.ContentType), @@ -477,6 +464,10 @@ func (c S3ApiController) GetActions(ctx *fiber.Ctx) error { status = http.StatusPartialContent } + if res.Body != nil { + ctx.Response().SetBodyStream(res.Body, int(getint64(res.ContentLength))) + } + return SendResponse(ctx, nil, &MetaOpts{ Logger: c.logger, diff --git a/s3api/controllers/base_test.go b/s3api/controllers/base_test.go index 4aa6ff8b..6469b983 100644 --- a/s3api/controllers/base_test.go +++ b/s3api/controllers/base_test.go @@ -19,7 +19,6 @@ import ( "context" "encoding/json" "fmt" - "io" "net/http" "net/http/httptest" "reflect" @@ -191,7 +190,7 @@ func TestS3ApiController_GetActions(t *testing.T) { GetObjectAttributesFunc: func(context.Context, *s3.GetObjectAttributesInput) (s3response.GetObjectAttributesResult, error) { return s3response.GetObjectAttributesResult{}, nil }, - GetObjectFunc: func(context.Context, *s3.GetObjectInput, io.Writer) (*s3.GetObjectOutput, error) { + GetObjectFunc: func(context.Context, *s3.GetObjectInput) (*s3.GetObjectOutput, error) { return &s3.GetObjectOutput{ Metadata: map[string]string{"hello": "world"}, ContentType: getPtr("application/xml"), diff --git a/tests/integration/tests.go b/tests/integration/tests.go index 719f9fe7..5c92570e 100644 --- a/tests/integration/tests.go +++ b/tests/integration/tests.go @@ -3287,30 +3287,30 @@ func GetObject_invalid_ranges(s *S3Conf) error { } ctx, cancel = context.WithTimeout(context.Background(), shortTimeout) - resp, err := s3client.GetObject(ctx, &s3.GetObjectInput{ + _, err = s3client.GetObject(ctx, &s3.GetObjectInput{ Bucket: &bucket, Key: &obj, - Range: getPtr("bytes=1500-999999999999"), + Range: getPtr("bytes=0-0"), }) cancel() - if err != nil { + if err := checkApiErr(err, s3err.GetAPIError(s3err.ErrInvalidRange)); err != nil { return err } - if *resp.ContentLength != dataLength-1500 { - return fmt.Errorf("expected content-length to be %v, instead got %v", dataLength-1500, *resp.ContentLength) - } - ctx, cancel = context.WithTimeout(context.Background(), shortTimeout) - _, err = s3client.GetObject(ctx, &s3.GetObjectInput{ + resp, err := s3client.GetObject(ctx, &s3.GetObjectInput{ Bucket: &bucket, Key: &obj, - Range: getPtr("bytes=0-0"), + Range: getPtr("bytes=1500-999999999999"), }) cancel() - if err := checkApiErr(err, s3err.GetAPIError(s3err.ErrInvalidRange)); err != nil { + if err != nil { return err } + + if *resp.ContentLength != dataLength-1500 { + return fmt.Errorf("expected content-length to be %v, instead got %v", dataLength-1500, *resp.ContentLength) + } return nil }) } @@ -3419,7 +3419,7 @@ func GetObject_by_range_success(s *S3Conf) error { return fmt.Errorf("expected accept range: %v, instead got: %v", rangeString, getString(out.AcceptRanges)) } b, err := io.ReadAll(out.Body) - if err != nil { + if err != nil && !errors.Is(err, io.EOF) { return err } @@ -3443,7 +3443,7 @@ func GetObject_by_range_success(s *S3Conf) error { defer out.Body.Close() b, err = io.ReadAll(out.Body) - if err != nil { + if err != nil && !errors.Is(err, io.EOF) { return err }