Skip to content

Commit

Permalink
blob/s3blob: support URL query parameters to override aws.Config fiel…
Browse files Browse the repository at this point in the history
…ds (#1359)
  • Loading branch information
vangent authored Feb 28, 2019
1 parent 283bafc commit 31355ab
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 9 deletions.
14 changes: 14 additions & 0 deletions aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,17 @@ func SessionConfig(sess *session.Session) *aws.Config {
func ConfigCredentials(cfg *aws.Config) *credentials.Credentials {
return cfg.Credentials
}

// ConfigOverrider implements client.ConfigProvider by overlaying a list of
// configurations over a base configuration provider.
type ConfigOverrider struct {
Base client.ConfigProvider
Configs []*aws.Config
}

// ClientConfig calls the base provider's ClientConfig method with co.Configs
// followed by the arguments given to ClientConfig.
func (co ConfigOverrider) ClientConfig(serviceName string, cfgs ...*aws.Config) client.Config {
cfgs = append(co.Configs[:len(co.Configs):len(co.Configs)], cfgs...)
return co.Base.ClientConfig(serviceName, cfgs...)
}
4 changes: 3 additions & 1 deletion blob/s3blob/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"log"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"gocloud.dev/blob"
"gocloud.dev/blob/s3blob"
Expand All @@ -26,7 +27,8 @@ import (
func Example() {
// Establish an AWS session.
// See https://docs.aws.amazon.com/sdk-for-go/api/aws/session/ for more info.
session, err := session.NewSession(nil)
// The region must match the region for "my_bucket".
session, err := session.NewSession(&aws.Config{Region: aws.String("us-west-1")})
if err != nil {
log.Fatal(err)
}
Expand Down
66 changes: 58 additions & 8 deletions blob/s3blob/s3blob.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ import (
"strings"
"sync"

gcaws "gocloud.dev/aws"
"gocloud.dev/blob"
"gocloud.dev/blob/driver"
"gocloud.dev/gcerrors"
Expand All @@ -81,6 +82,11 @@ func init() {
}

// URLOpener opens S3 URLs like "s3://mybucket".
// The following query options are supported:
// - region: The AWS region for requests; sets aws.Config.Region.
// - endpoint: The endpoint URL (hostname only or fully qualified URI); sets aws.Config.Endpoint.
// - disableSSL: A value of "true" disables SSL when sending requests; sets aws.Config.DisableSSL.
// - s3ForcePathStyle: A value of "true" forces the request to use path-style addressing; sets aws.Config.S3ForcePathStyle.
type URLOpener struct {
// ConfigProvider must be set to a non-nil value.
ConfigProvider client.ConfigProvider
Expand Down Expand Up @@ -120,10 +126,54 @@ const Scheme = "s3"

// OpenBucketURL opens the S3 bucket with the same name as the host in the URL.
func (o *URLOpener) OpenBucketURL(ctx context.Context, u *url.URL) (*blob.Bucket, error) {
for k := range u.Query() {
return nil, fmt.Errorf("open S3 bucket %q: unknown query parameter %s", u, k)
configProvider := &gcaws.ConfigOverrider{
Base: o.ConfigProvider,
}
return OpenBucket(ctx, o.ConfigProvider, u.Host, &o.Options)
overrideCfg, err := o.forParams(ctx, u.Query())
if err != nil {
return nil, fmt.Errorf("open bucket %v: %v", u, err)
}
if overrideCfg != nil {
configProvider.Configs = append(configProvider.Configs, overrideCfg)
}
return OpenBucket(ctx, configProvider, u.Host, &o.Options)
}

var legalQueryParam = map[string]bool{
"region": true,
"endpoint": true,
"disableSSL": true,
"s3ForcePathStyle": true,
}

func (o *URLOpener) forParams(ctx context.Context, q url.Values) (*aws.Config, error) {
for k := range q {
if !legalQueryParam[k] {
return nil, fmt.Errorf("unknown S3 query parameter %s", k)
}
}
var cfg aws.Config
override := false
if region := q["region"]; len(region) > 0 {
cfg.Region = aws.String(region[0])
override = true
}
if endpoint := q["endpoint"]; len(endpoint) > 0 {
cfg.Endpoint = aws.String(endpoint[0])
override = true
}
if disableSSL := q["disableSSL"]; len(disableSSL) > 0 {
cfg.DisableSSL = aws.Bool(disableSSL[0] == "true")
override = true
}
if s3ForcePathStyle := q["s3ForcePathStyle"]; len(s3ForcePathStyle) > 0 {
cfg.S3ForcePathStyle = aws.Bool(s3ForcePathStyle[0] == "true")
override = true
}
if !override {
return nil, nil
}
return &cfg, nil
}

// Options sets options for constructing a *blob.Bucket backed by fileblob.
Expand All @@ -139,13 +189,14 @@ func openBucket(ctx context.Context, sess client.ConfigProvider, bucketName stri
}
return &bucket{
name: bucketName,
sess: sess,
client: s3.New(sess),
}, nil
}

// OpenBucket returns a *blob.Bucket backed by S3. See the package documentation
// for an example.
// OpenBucket returns a *blob.Bucket backed by S3.
// AWS buckets are bound to a region; sess must have been created using an
// aws.Config with Region set to the right region for bucketName.
// See the package documentation for an example.
func OpenBucket(ctx context.Context, sess client.ConfigProvider, bucketName string, opts *Options) (*blob.Bucket, error) {
drv, err := openBucket(ctx, sess, bucketName, opts)
if err != nil {
Expand Down Expand Up @@ -260,7 +311,6 @@ func (w *writer) Close() error {
// bucket represents an S3 bucket and handles read, write and delete operations.
type bucket struct {
name string
sess client.ConfigProvider
client *s3.S3
}

Expand Down Expand Up @@ -525,7 +575,7 @@ func unescapeKey(key string) string {
// NewTypedWriter implements driver.NewTypedWriter.
func (b *bucket) NewTypedWriter(ctx context.Context, key string, contentType string, opts *driver.WriterOptions) (driver.Writer, error) {
key = escapeKey(key)
uploader := s3manager.NewUploader(b.sess, func(u *s3manager.Uploader) {
uploader := s3manager.NewUploaderWithClient(b.client, func(u *s3manager.Uploader) {
if opts.BufferSize != 0 {
u.PartSize = int64(opts.BufferSize)
}
Expand Down
70 changes: 70 additions & 0 deletions blob/s3blob/s3blob_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ import (
"context"
"errors"
"fmt"
"github.com/google/go-cmp/cmp"
"net/http"
"net/url"
"testing"

"github.com/aws/aws-sdk-go/aws"
Expand Down Expand Up @@ -220,3 +222,71 @@ func TestOpenBucket(t *testing.T) {
})
}
}

func TestURLOpenerForParams(t *testing.T) {
ctx := context.Background()

tests := []struct {
name string
query url.Values
wantCfg *aws.Config
wantErr bool
}{
{
name: "No overrides",
query: url.Values{},
wantCfg: nil,
},
{
name: "Invalid query parameter",
query: url.Values{"foo": {"bar"}},
wantErr: true,
},
{
name: "Region",
query: url.Values{"region": {"my_region"}},
wantCfg: &aws.Config{Region: aws.String("my_region")},
},
{
name: "Endpoint",
query: url.Values{"endpoint": {"foo"}},
wantCfg: &aws.Config{Endpoint: aws.String("foo")},
},
{
name: "DisableSSL true",
query: url.Values{"disableSSL": {"true"}},
wantCfg: &aws.Config{DisableSSL: aws.Bool(true)},
},
{
name: "DisableSSL false",
query: url.Values{"disableSSL": {"not-true"}},
wantCfg: &aws.Config{DisableSSL: aws.Bool(false)},
},
{
name: "S3ForcePathStyle true",
query: url.Values{"s3ForcePathStyle": {"true"}},
wantCfg: &aws.Config{S3ForcePathStyle: aws.Bool(true)},
},
{
name: "S3ForcePathStyle false",
query: url.Values{"s3ForcePathStyle": {"not-true"}},
wantCfg: &aws.Config{S3ForcePathStyle: aws.Bool(false)},
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
u := &URLOpener{}
got, err := u.forParams(ctx, test.query)
if (err != nil) != test.wantErr {
t.Errorf("got err %v want error %v", err, test.wantErr)
}
if err != nil {
return
}
if diff := cmp.Diff(got, test.wantCfg); diff != "" {
t.Errorf("opener.forParams(...) diff (-want +got):\n%s", diff)
}
})
}
}

0 comments on commit 31355ab

Please sign in to comment.