Skip to content

Commit

Permalink
refactor: move concurrent client/funcs to be completely separate from…
Browse files Browse the repository at this point in the history
… existing S3 wrappers

Based on review
  • Loading branch information
CallumNZ committed Dec 13, 2023
1 parent 074be75 commit 1d2f64d
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 84 deletions.
58 changes: 3 additions & 55 deletions aws/s3/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@ import (
)

type S3 struct {
client *s3.Client
uploader *manager.Uploader
downloader *manager.Downloader
concurrenter *ConcurrencyManager
client *s3.Client
uploader *manager.Uploader
downloader *manager.Downloader
}

type Meta map[string]string
Expand Down Expand Up @@ -86,29 +85,6 @@ func (s3 *S3) AddDownloader() error {
return nil
}

// AddConcurrencyManager builds a ConcurrencyManager and stores it on the
// S3 struct's concurrenter field. This enables the use of the GetAll function,
// which downloads multiple files at once while retaining order.
// Ensure that the HTTP client option used to create the S3 client is configured to
// make use of the specified maxConnections. Also, ensure that the S3 Client
// has access to maxBytes in memory to avoid out of memory errors.
func (s3 *S3) AddConcurrencyManager(maxConnections, maxConnectionsPerRequest, maxBytes int) error {
	// Validate preconditions in order; each case mirrors one rejected input class.
	switch {
	case !s3.Ready():
		return errors.New("S3 client needs to be initialised to add concurrency limits")
	case maxConnections <= 0 || maxConnectionsPerRequest <= 0 || maxBytes <= 0:
		return errors.New("all parameters must be greater than 0")
	case maxConnections > maxBytes:
		return errors.New("max bytes must be greater than or equal to max connections")
	case maxConnectionsPerRequest > maxConnections:
		return errors.New("max connections must be greater than or equal to max connections per request")
	}
	s3.concurrenter = NewConcurrencyManager(maxConnections, maxConnectionsPerRequest, maxBytes)
	return nil
}

// getConfig returns the default AWS Config struct.
func getConfig() (aws.Config, error) {
if os.Getenv("AWS_REGION") == "" {
Expand Down Expand Up @@ -166,34 +142,6 @@ func (s *S3) Get(bucket, key, version string, b *bytes.Buffer) error {
return err
}

// GetAll gets the objects specified from bucket and writes the resulting HydratedFiles
// to the returned output channel. The closure of this channel is handled, however it's the caller's
// responsibility to purge the channel, and handle any errors present in the HydratedFiles.
// If the S3 client's ConcurrencyManager is not initialised before calling GetAll, an output channel
// containing a single HydratedFile with an error is returned.
// Version can be empty, but must be the same for all objects.
func (s *S3) GetAll(bucket, version string, objects []types.Object) chan HydratedFile {
	// Without a concurrency manager we cannot fan out; report the failure
	// through the same channel type the caller already expects.
	if s.concurrenter == nil {
		errCh := make(chan HydratedFile, 1)
		errCh <- HydratedFile{Error: errors.New("error getting files from S3, concurrenter not initialised")}
		close(errCh)
		return errCh
	}

	// fetch downloads a single object and packages the result (or error)
	// into a HydratedFile for the concurrency manager to emit in order.
	fetch := func(obj types.Object) HydratedFile {
		objKey := aws.ToString(obj.Key)
		// Pre-size the buffer to the object's reported size to avoid regrowth.
		objBuf := bytes.NewBuffer(make([]byte, 0, obj.Size))
		getErr := s.Get(bucket, objKey, version, objBuf)
		return HydratedFile{
			Key:   objKey,
			Data:  objBuf.Bytes(),
			Error: getErr,
		}
	}

	return s.concurrenter.Process(fetch, objects)
}

// GetByteRange gets the specified byte range of an object referred to by key and version
// from bucket and writes it into b. Version can be empty.
// See https://www.rfc-editor.org/rfc/rfc9110.html#name-byte-ranges for examples
Expand Down
66 changes: 63 additions & 3 deletions aws/s3/concurrency.go → aws/s3/s3_concurrent.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package s3

import (
"bytes"
"context"
"errors"
"sync"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
)

Expand Down Expand Up @@ -38,6 +41,11 @@ import (
// to be finished via the WorkerGroup's sync.WaitGroup, after which it closes up the remaining channels.
// 10. Each worker is returned to the worker pool.

// S3Concurrent wraps an embedded S3 client together with a ConcurrencyManager,
// enabling concurrent, order-preserving downloads via its GetAll method.
// Construct it with NewConcurrent; a zero-value S3Concurrent has a nil manager.
type S3Concurrent struct {
	S3
	// manager coordinates the worker/memory pools used by GetAll.
	manager *ConcurrencyManager
}

type ConcurrencyManager struct {
memoryPool chan int64
workerPool chan *worker
Expand All @@ -59,9 +67,33 @@ type worker struct {
output chan HydratedFile
}

// NewConcurrencyManager returns a new ConcurrencyManager set with the given specifications.
// This enables the use of the S3Client's GetAll function.
func NewConcurrencyManager(maxWorkers, maxWorkersPerRequest, maxBytes int) *ConcurrencyManager {
// NewConcurrent returns an S3Concurrent client, which embeds the given S3 client, along with a ConcurrencyManager
// to allow the use of the GetAll function. The GetAll function can download multiple files at once while retaining order.
// Ensure that the HTTP client option used to create the S3 client is configured to
// make use of the specified maxConnections. Also, ensure that the S3 Client
// has access to maxBytes in memory to avoid out of memory errors.
func NewConcurrent(s3Client S3, maxConnections, maxConnectionsPerRequest, maxBytes int) (S3Concurrent, error) {
	var none S3Concurrent

	// Reject invalid configurations up front; each case mirrors one input class.
	switch {
	case !s3Client.Ready():
		return none, errors.New("S3 client needs to be initialised to add concurrency limits")
	case maxConnections <= 0 || maxConnectionsPerRequest <= 0 || maxBytes <= 0:
		return none, errors.New("all parameters must be greater than 0")
	case maxConnections > maxBytes:
		return none, errors.New("max bytes must be greater than or equal to max connections")
	case maxConnectionsPerRequest > maxConnections:
		return none, errors.New("max connections must be greater than or equal to max connections per request")
	}

	concurrent := S3Concurrent{
		S3:      s3Client,
		manager: newConcurrencyManager(maxConnections, maxConnectionsPerRequest, maxBytes),
	}
	return concurrent, nil
}

// newConcurrencyManager returns a new ConcurrencyManager set with the given specifications.
func newConcurrencyManager(maxWorkers, maxWorkersPerRequest, maxBytes int) *ConcurrencyManager {
cm := ConcurrencyManager{}

// Create worker pool
Expand Down Expand Up @@ -92,6 +124,34 @@ func NewConcurrencyManager(maxWorkers, maxWorkersPerRequest, maxBytes int) *Conc
return &cm
}

// GetAll gets the objects specified from bucket and writes the resulting HydratedFiles
// to the returned output channel. The closure of this channel is handled, however it's the caller's
// responsibility to purge the channel, and handle any errors present in the HydratedFiles.
// If the ConcurrencyManager is not initialised before calling GetAll, an output channel
// containing a single HydratedFile with an error is returned.
// Version can be empty, but must be the same for all objects.
func (s *S3Concurrent) GetAll(bucket, version string, objects []types.Object) chan HydratedFile {
	// A nil manager means this client was not built via NewConcurrent; surface
	// the failure through the same channel type the caller already expects.
	if s.manager == nil {
		errCh := make(chan HydratedFile, 1)
		errCh <- HydratedFile{Error: errors.New("error getting files from S3, Concurrency Manager not initialised")}
		close(errCh)
		return errCh
	}

	// fetch downloads one object and packages the outcome (data or error)
	// into a HydratedFile for the manager to emit in order.
	fetch := func(obj types.Object) HydratedFile {
		objKey := aws.ToString(obj.Key)
		// Pre-size the buffer to the object's reported size to avoid regrowth.
		objBuf := bytes.NewBuffer(make([]byte, 0, obj.Size))
		getErr := s.Get(bucket, objKey, version, objBuf)
		return HydratedFile{
			Key:   objKey,
			Data:  objBuf.Bytes(),
			Error: getErr,
		}
	}

	return s.manager.Process(fetch, objects)
}

// getWorker retrieves a worker from the manager's worker pool.
func (cm *ConcurrencyManager) getWorker() *worker {
return <-cm.workerPool
Expand Down
52 changes: 26 additions & 26 deletions aws/s3/s3_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,45 +320,45 @@ func TestS3GetAll(t *testing.T) {
defer teardown()

// ASSERT parameter errors.
emptyClient := S3{}
err := emptyClient.AddConcurrencyManager(100, 10, 1000)
emptyBaseClient := S3{}
_, err := NewConcurrent(emptyBaseClient, 100, 10, 1000)
assert.NotNil(t, err)

client, _ := New()
err = client.AddConcurrencyManager(0, 100, 1000)
baseClient, _ := New()
_, err = NewConcurrent(baseClient, 0, 100, 1000)
assert.NotNil(t, err)
err = client.AddConcurrencyManager(100, 0, 1000)
_, err = NewConcurrent(baseClient, 100, 0, 1000)
assert.NotNil(t, err)
err = client.AddConcurrencyManager(100, 100, 0)
_, err = NewConcurrent(baseClient, 100, 100, 0)
assert.NotNil(t, err)
err = client.AddConcurrencyManager(100, 10, 99)
_, err = NewConcurrent(baseClient, 100, 10, 99)
assert.NotNil(t, err)
err = client.AddConcurrencyManager(100, 101, 1000)
_, err = NewConcurrent(baseClient, 100, 101, 1000)
assert.NotNil(t, err)

err = client.AddConcurrencyManager(100, 10, 1000)
client, err := NewConcurrent(baseClient, 100, 10, 1000)
require.Nil(t, err, fmt.Sprintf("error creating s3 client concurrency manager: %v", err))

// ASSERT computed fields.
assert.Equal(t, 100, len(client.concurrenter.workerPool))
assert.Equal(t, 100, len(client.concurrenter.memoryPool))
assert.Equal(t, int64(10), client.concurrenter.memoryChunkSize)
assert.Equal(t, 10, client.concurrenter.maxWorkersPerRequest)
assert.Equal(t, 100, len(client.manager.workerPool))
assert.Equal(t, 100, len(client.manager.memoryPool))
assert.Equal(t, int64(10), client.manager.memoryChunkSize)
assert.Equal(t, 10, client.manager.maxWorkersPerRequest)

// ASSERT memory chunk size is correct in memory pool.
chunk := <-client.concurrenter.memoryPool
chunk := <-client.manager.memoryPool
assert.Equal(t, int64(10), chunk)
client.concurrenter.memoryPool <- chunk
client.manager.memoryPool <- chunk

// ASSERT worker/memory get/release methods work expectedly.
w := client.concurrenter.getWorker()
assert.Equal(t, 99, len(client.concurrenter.workerPool))
client.concurrenter.returnWorker(w)
assert.Equal(t, 100, len(client.concurrenter.workerPool))
client.concurrenter.secureMemory(20)
assert.Equal(t, 98, len(client.concurrenter.memoryPool))
client.concurrenter.releaseMemory(20)
assert.Equal(t, 100, len(client.concurrenter.memoryPool))
w := client.manager.getWorker()
assert.Equal(t, 99, len(client.manager.workerPool))
client.manager.returnWorker(w)
assert.Equal(t, 100, len(client.manager.workerPool))
client.manager.secureMemory(20)
assert.Equal(t, 98, len(client.manager.memoryPool))
client.manager.releaseMemory(20)
assert.Equal(t, 100, len(client.manager.memoryPool))

// ARRANGE bucket with test objects.
total := 20
Expand All @@ -384,11 +384,11 @@ func TestS3GetAll(t *testing.T) {

// ASSERT all workers and memory returned to pools.
time.Sleep(2 * time.Second)
assert.Equal(t, 100, len(client.concurrenter.workerPool))
assert.Equal(t, 100, len(client.concurrenter.memoryPool))
assert.Equal(t, 100, len(client.manager.workerPool))
assert.Equal(t, 100, len(client.manager.memoryPool))

// ASSERT that process blocked when all memory secured.
client.concurrenter.secureMemory(1000)
client.manager.secureMemory(1000)
output2 := client.GetAll(testBucket, "", objects)

for {
Expand Down

0 comments on commit 1d2f64d

Please sign in to comment.