diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index 3ad885b904c..c4a7382063c 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -1,6 +1,8 @@ ### SDK Features ### SDK Enhancements +* `service/s3/s3manager`: adding cleanup function to batch objects [#1375](https://github.com/aws/aws-sdk-go/issues/1375) + * This enhancement will add an After field that will be called after each iteration of the batch operation. ### SDK Bugs * `aws/signer/v4`: checking length on `stripExcessSpaces` [#1372](https://github.com/aws/aws-sdk-go/issues/1372) diff --git a/service/s3/s3manager/batch.go b/service/s3/s3manager/batch.go index ddeb771aa5f..630cea5fa7d 100644 --- a/service/s3/s3manager/batch.go +++ b/service/s3/s3manager/batch.go @@ -43,6 +43,14 @@ type Error struct { Key *string } +func newError(err error, bucket, key *string) Error { + return Error{ + err, + bucket, + key, + } +} + func (err *Error) Error() string { return err.OrigErr.Error() } @@ -239,6 +247,8 @@ func NewBatchDelete(c client.ConfigProvider, options ...func(*BatchDelete)) *Bat // BatchDeleteObject is a wrapper object for calling the batch delete operation. type BatchDeleteObject struct { Object *s3.DeleteObjectInput + // After will run after each iteration during the batch process + After func() error } // DeleteObjectsIterator is an interface that uses the scanner pattern to iterate @@ -277,15 +287,17 @@ func (iter *DeleteObjectsIterator) DeleteObject() BatchDeleteObject { func (d *BatchDelete) Delete(ctx aws.Context, iter BatchDeleteIterator) error { var errs []Error for iter.Next() { - object := iter.DeleteObject().Object - if _, err := d.Client.DeleteObjectWithContext(ctx, object); err != nil { - s3Err := Error{ - OrigErr: err, - Bucket: object.Bucket, - Key: object.Key, - } - - errs = append(errs, s3Err) + object := iter.DeleteObject() + if _, err := d.Client.DeleteObjectWithContext(ctx, object.Object); err != nil { + errs = append(errs, newError(err, object.Object.Bucket, object.Object.Key)) + } + + if object.After == nil { + continue + } + + if err := object.After(); err != nil { + errs = append(errs, newError(err, object.Object.Bucket, object.Object.Key)) } } @@ -307,6 +319,8 @@ type BatchDownloadIterator interface { type BatchDownloadObject struct { Object *s3.GetObjectInput Writer io.WriterAt + // After will run after each iteration during the batch process + After func() error } // DownloadObjectsIterator implements the BatchDownloadIterator interface and allows for batched @@ -382,4 +396,6 @@ func (batcher *UploadObjectsIterator) UploadObject() BatchUploadObject { // BatchUploadObject contains all necessary information to run a batch operation once. type BatchUploadObject struct { Object *UploadInput + // After will run after each iteration during the batch process + After func() error } diff --git a/service/s3/s3manager/batch_test.go b/service/s3/s3manager/batch_test.go index 0a800fe445a..6bf8f9788bb 100644 --- a/service/s3/s3manager/batch_test.go +++ b/service/s3/s3manager/batch_test.go @@ -447,7 +447,9 @@ func TestBatchUpload(t *testing.T) { type mockClient struct { s3iface.S3API - index int + Put func() (*s3.PutObjectOutput, error) + Get func() (*s3.GetObjectOutput, error) + List func() (*s3.ListObjectsOutput, error) responses []response } @@ -457,37 +459,25 @@ type response struct { } func (client *mockClient) PutObject(input *s3.PutObjectInput) (*s3.PutObjectOutput, error) { - resp := client.responses[client.index] - client.index++ - return resp.out.(*s3.PutObjectOutput), resp.err + return client.Put() } func (client *mockClient) PutObjectRequest(input *s3.PutObjectInput) (*request.Request, *s3.PutObjectOutput) { - resp := client.responses[client.index] req, _ := client.S3API.PutObjectRequest(input) req.Handlers.Clear() - req.Data = resp.out - req.Error = resp.err - - client.index++ - return req, resp.out.(*s3.PutObjectOutput) + req.Data, req.Error = client.Put() + return req, req.Data.(*s3.PutObjectOutput) } func (client *mockClient) ListObjects(input *s3.ListObjectsInput) (*s3.ListObjectsOutput, error) { - resp := client.responses[client.index] - client.index++ - return resp.out.(*s3.ListObjectsOutput), resp.err + return client.List() } func (client *mockClient) ListObjectsRequest(input *s3.ListObjectsInput) (*request.Request, *s3.ListObjectsOutput) { - resp := client.responses[client.index] req, _ := client.S3API.ListObjectsRequest(input) req.Handlers.Clear() - req.Data = resp.out - req.Error = resp.err - - client.index++ - return req, resp.out.(*s3.ListObjectsOutput) + req.Data, req.Error = client.List() + return req, req.Data.(*s3.ListObjectsOutput) } func TestBatchError(t *testing.T) { @@ -500,26 +490,38 @@ func TestBatchError(t *testing.T) { Region: aws.String("foo"), Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"), }) + + index := 0 + responses := []response{ + { + &s3.PutObjectOutput{}, + errors.New("Foo"), + }, + { + &s3.PutObjectOutput{}, + nil, + }, + { + &s3.PutObjectOutput{}, + nil, + }, + { + &s3.PutObjectOutput{}, + errors.New("Bar"), + }, + } + svc := &mockClient{ - s3.New(sess), - 0, - []response{ - { - &s3.PutObjectOutput{}, - errors.New("Foo"), - }, - { - &s3.PutObjectOutput{}, - nil, - }, - { - &s3.PutObjectOutput{}, - nil, - }, - { - &s3.PutObjectOutput{}, - errors.New("Bar"), - }, + S3API: s3.New(sess), + Put: func() (*s3.PutObjectOutput, error) { + resp := responses[index] + index++ + return resp.out.(*s3.PutObjectOutput), resp.err + }, + List: func() (*s3.ListObjectsOutput, error) { + resp := responses[index] + index++ + return resp.out.(*s3.ListObjectsOutput), resp.err }, } uploader := NewUploaderWithClient(svc) @@ -590,8 +592,141 @@ func TestBatchError(t *testing.T) { t.Error("Expected error, but received nil") } - if svc.index != len(objects) { - t.Errorf("Expected %d, but received %d", len(objects), svc.index) + if index != len(objects) { + t.Errorf("Expected %d, but received %d", len(objects), index) + } + +} + +type testAfterIter struct { + afterDelete bool + afterDownload bool + afterUpload bool + index int +} + +func (iter *testAfterIter) Next() bool { + next := (iter.index & 1) == 0 + iter.index++ + return next +} + +func (iter *testAfterIter) Err() error { + return nil +} + +func (iter *testAfterIter) DeleteObject() BatchDeleteObject { + return BatchDeleteObject{ + Object: &s3.DeleteObjectInput{ + Bucket: aws.String("foo"), + Key: aws.String("foo"), + }, + After: func() error { + iter.afterDelete = true + return nil + }, + } +} + +func (iter *testAfterIter) DownloadObject() BatchDownloadObject { + return BatchDownloadObject{ + Object: &s3.GetObjectInput{ + Bucket: aws.String("foo"), + Key: aws.String("foo"), + }, + Writer: aws.NewWriteAtBuffer([]byte{}), + After: func() error { + iter.afterDownload = true + return nil + }, } +} +func (iter *testAfterIter) UploadObject() BatchUploadObject { + return BatchUploadObject{ + Object: &UploadInput{ + Bucket: aws.String("foo"), + Key: aws.String("foo"), + Body: strings.NewReader("bar"), + }, + After: func() error { + iter.afterUpload = true + return nil + }, + } +} + +func TestAfter(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + })) + + sess := session.New(&aws.Config{ + Endpoint: &server.URL, + S3ForcePathStyle: aws.Bool(true), + Region: aws.String("foo"), + Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"), + }) + + index := 0 + responses := []response{ + { + &s3.PutObjectOutput{}, + nil, + }, + { + &s3.GetObjectOutput{}, + nil, + }, + { + &s3.DeleteObjectOutput{}, + nil, + }, + } + + svc := &mockClient{ + S3API: s3.New(sess), + Put: func() (*s3.PutObjectOutput, error) { + resp := responses[index] + index++ + return resp.out.(*s3.PutObjectOutput), resp.err + }, + Get: func() (*s3.GetObjectOutput, error) { + resp := responses[index] + index++ + return resp.out.(*s3.GetObjectOutput), resp.err + }, + List: func() (*s3.ListObjectsOutput, error) { + resp := responses[index] + index++ + return resp.out.(*s3.ListObjectsOutput), resp.err + }, + } + uploader := NewUploaderWithClient(svc) + downloader := NewDownloaderWithClient(svc) + deleter := NewBatchDeleteWithClient(svc) + + iter := &testAfterIter{} + if err := uploader.UploadWithIterator(aws.BackgroundContext(), iter); err != nil { + t.Error(err) + } + + if err := downloader.DownloadWithIterator(aws.BackgroundContext(), iter); err != nil { + t.Error(err) + } + + if err := deleter.Delete(aws.BackgroundContext(), iter); err != nil { + t.Error(err) + } + + if !iter.afterDelete { + t.Error("Expected 'afterDelete' to be true, but received false") + } + + if !iter.afterDownload { + t.Error("Expected 'afterDownload' to be true, but received false") + } + + if !iter.afterUpload { + t.Error("Expected 'afterUpload' to be true, but received false") + } } diff --git a/service/s3/s3manager/download.go b/service/s3/s3manager/download.go index 7a2168e21d0..d30f2b6b3ca 100644 --- a/service/s3/s3manager/download.go +++ b/service/s3/s3manager/download.go @@ -231,13 +231,15 @@ func (d Downloader) DownloadWithIterator(ctx aws.Context, iter BatchDownloadIter for iter.Next() { object := iter.DownloadObject() if _, err := d.DownloadWithContext(ctx, object.Writer, object.Object, opts...); err != nil { - s3Err := Error{ - OrigErr: err, - Bucket: object.Object.Bucket, - Key: object.Object.Key, - } + errs = append(errs, newError(err, object.Object.Bucket, object.Object.Key)) + } + + if object.After == nil { + continue + } - errs = append(errs, s3Err) + if err := object.After(); err != nil { + errs = append(errs, newError(err, object.Object.Bucket, object.Object.Key)) } } diff --git a/service/s3/s3manager/upload.go b/service/s3/s3manager/upload.go index 56894042fd3..fc1f47205b1 100644 --- a/service/s3/s3manager/upload.go +++ b/service/s3/s3manager/upload.go @@ -396,12 +396,26 @@ func (u Uploader) UploadWithContext(ctx aws.Context, input *UploadInput, opts .. func (u Uploader) UploadWithIterator(ctx aws.Context, iter BatchUploadIterator, opts ...func(*Uploader)) error { var errs []Error for iter.Next() { - object := iter.UploadObject().Object - if _, err := u.UploadWithContext(ctx, object, opts...); err != nil { + object := iter.UploadObject() + if _, err := u.UploadWithContext(ctx, object.Object, opts...); err != nil { s3Err := Error{ OrigErr: err, - Bucket: object.Bucket, - Key: object.Key, + Bucket: object.Object.Bucket, + Key: object.Object.Key, + } + + errs = append(errs, s3Err) + } + + if object.After == nil { + continue + } + + if err := object.After(); err != nil { + s3Err := Error{ + OrigErr: err, + Bucket: object.Object.Bucket, + Key: object.Object.Key, } errs = append(errs, s3Err)