Skip to content

Commit

Permalink
split downloader and getter
Browse files Browse the repository at this point in the history
  • Loading branch information
wty-Bryant committed Feb 13, 2025
1 parent 93559ed commit dd3bcc7
Show file tree
Hide file tree
Showing 7 changed files with 618 additions and 485 deletions.
39 changes: 18 additions & 21 deletions feature/s3/transfermanager/api_op_DownloadObject.go
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ func (o *DownloadObjectOutput) mapFromGetObjectOutput(out *s3.GetObjectOutput, c
// download. These options are copies of the original Options instance, the client of which DownloadObject is called from.
// Modifying the options will not impact the original Client and Options instance.
func (c *Client) DownloadObject(ctx context.Context, input *DownloadObjectInput, opts ...func(*Options)) (*DownloadObjectOutput, error) {
i := downloader{in: input, options: c.options.Copy(), w: input.WriterAt}
i := downloader{in: input, options: c.options.Copy()}
for _, opt := range opts {
opt(&i.options)
}
Expand All @@ -541,7 +541,6 @@ type downloader struct {
options Options
in *DownloadObjectInput
out *DownloadObjectOutput
w io.WriterAt

wg sync.WaitGroup
m sync.Mutex
Expand Down Expand Up @@ -571,7 +570,7 @@ func (d *downloader) download(ctx context.Context) (*DownloadObjectOutput, error
return d.singleDownload(ctx, clientOptions...)
}

var output *GetObjectOutput
var output *DownloadObjectOutput
if d.options.MultipartDownloadType == types.MultipartDownloadTypePart {
if d.in.Range != "" {
return d.singleDownload(ctx, clientOptions...)
Expand All @@ -583,7 +582,7 @@ func (d *downloader) download(ctx context.Context) (*DownloadObjectOutput, error

if output.PartsCount > 1 {
partSize := output.ContentLength
ch := make(chan dlchunk, d.options.Concurrency)
ch := make(chan dlChunk, d.options.Concurrency)
for i := 0; i < d.options.Concurrency; i++ {
d.wg.Add(1)
go d.downloadPart(ctx, ch, clientOptions...)
Expand All @@ -594,25 +593,23 @@ func (d *downloader) download(ctx context.Context) (*DownloadObjectOutput, error
break
}

ch <- dlchunk{w: d.w, start: d.pos - d.offset, part: i}
ch <- dlChunk{w: d.in.WriterAt, start: d.pos - d.offset, part: i}
d.pos += partSize
}

close(ch)
d.wg.Wait()
}
} else {
var total int64
if d.in.Range == "" {
output = d.getChunk(ctx, 0, d.byteRange(), clientOptions...)
total = d.getTotalBytes()
} else {
d.pos, d.totalBytes = d.getDownloadRange()
d.offset = d.pos
total = d.totalBytes
}
total := d.totalBytes

ch := make(chan dlchunk, d.options.Concurrency)
ch := make(chan dlChunk, d.options.Concurrency)
for i := 0; i < d.options.Concurrency; i++ {
d.wg.Add(1)
go d.downloadPart(ctx, ch, clientOptions...)
Expand All @@ -625,7 +622,7 @@ func (d *downloader) download(ctx context.Context) (*DownloadObjectOutput, error
}

// Queue the next range of bytes to read.
ch <- dlchunk{w: d.w, start: d.pos - d.offset, withRange: d.byteRange()}
ch <- dlChunk{w: d.in.WriterAt, start: d.pos - d.offset, withRange: d.byteRange()}
d.pos += d.options.PartSizeBytes
}

Expand Down Expand Up @@ -659,17 +656,17 @@ func (d *downloader) init(ctx context.Context) error {
}

func (d *downloader) singleDownload(ctx context.Context, clientOptions ...func(*s3.Options)) (*DownloadObjectOutput, error) {
chunk := dlchunk{w: d.w}
d.in.PartNumber = 0
chunk := dlChunk{w: d.in.WriterAt}
// d.in.PartNumber = 0
output, err := d.downloadChunk(ctx, chunk, clientOptions...)
if err != nil {
return output, err
return nil, err
}

return output, err
return output, nil
}

func (d *downloader) downloadPart(ctx context.Context, ch chan dlchunk, clientOptions ...func(*s3.Options)) {
func (d *downloader) downloadPart(ctx context.Context, ch chan dlChunk, clientOptions ...func(*s3.Options)) {
defer d.wg.Done()
for {
chunk, ok := <-ch
Expand All @@ -690,8 +687,8 @@ func (d *downloader) downloadPart(ctx context.Context, ch chan dlchunk, clientOp

// getChunk grabs a chunk of data from the body.
// Not thread safe. Should only used when grabbing data on a single thread.
func (d *downloader) getChunk(ctx context.Context, part int32, rng string, clientOptions ...func(*s3.Options)) *GetObjectOutput {
chunk := dlchunk{w: d.w, start: d.pos - d.offset, part: part, withRange: rng}
func (d *downloader) getChunk(ctx context.Context, part int32, rng string, clientOptions ...func(*s3.Options)) *DownloadObjectOutput {
chunk := dlChunk{w: d.in.WriterAt, start: d.pos - d.offset, part: part, withRange: rng}

output, err := d.downloadChunk(ctx, chunk, clientOptions...)
if err != nil {
Expand All @@ -705,7 +702,7 @@ func (d *downloader) getChunk(ctx context.Context, part int32, rng string, clien
}

// downloadChunk downloads the chunk from s3
func (d *downloader) downloadChunk(ctx context.Context, chunk dlchunk, clientOptions ...func(*s3.Options)) (*DownloadObjectOutput, error) {
func (d *downloader) downloadChunk(ctx context.Context, chunk dlChunk, clientOptions ...func(*s3.Options)) (*DownloadObjectOutput, error) {
params := d.in.mapGetObjectInput(!d.options.DisableChecksumValidation)
if chunk.part != 0 {
params.PartNumber = aws.Int32(chunk.part)
Expand Down Expand Up @@ -750,7 +747,7 @@ func (d *downloader) downloadChunk(ctx context.Context, chunk dlchunk, clientOpt
return output, err
}

func (d *downloader) tryDownloadChunk(ctx context.Context, params *s3.GetObjectInput, chunk *dlchunk, clientOptions ...func(*s3.Options)) (*s3.GetObjectOutput, int64, error) {
func (d *downloader) tryDownloadChunk(ctx context.Context, params *s3.GetObjectInput, chunk *dlChunk, clientOptions ...func(*s3.Options)) (*s3.GetObjectOutput, int64, error) {
out, err := d.options.S3.GetObject(ctx, params, clientOptions...)
if err != nil {
return nil, 0, err
Expand Down Expand Up @@ -865,7 +862,7 @@ func (d *downloader) setErr(e error) {
d.err = e
}

type dlchunk struct {
type dlChunk struct {
w io.WriterAt

start int64
Expand All @@ -875,7 +872,7 @@ type dlchunk struct {
withRange string
}

func (c *dlchunk) Write(p []byte) (int, error) {
func (c *dlChunk) Write(p []byte) (int, error) {
n, err := c.w.WriteAt(p, c.start+c.cur)
c.cur += int64(n)

Expand Down
24 changes: 24 additions & 0 deletions feature/s3/transfermanager/api_op_DownloadObject_integ_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
//go:build integration
// +build integration

package transfermanager

import (
"bytes"
"strings"
"testing"
)

func TestInteg_DownloadObject(t *testing.T) {
cases := map[string]getObjectTestData{
"seekable body": {Body: strings.NewReader("hello world"), ExpectBody: []byte("hello world")},
"empty string body": {Body: strings.NewReader(""), ExpectBody: []byte("")},
"multipart download body": {Body: bytes.NewReader(largeObjectBuf), ExpectBody: largeObjectBuf},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
testDownloadObject(t, setupMetadata.Buckets.Source.Name, c)
})
}
}
Loading

0 comments on commit dd3bcc7

Please sign in to comment.