Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

blob/s3blob: support URL query parameters to override aws.Config fields #1359

Merged
merged 16 commits into from
Feb 28, 2019
14 changes: 14 additions & 0 deletions aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,17 @@ func SessionConfig(sess *session.Session) *aws.Config {
func ConfigCredentials(cfg *aws.Config) *credentials.Credentials {
return cfg.Credentials
}

// ConfigOverrider implements client.ConfigProvider by overlaying a list of
// configurations over a base configuration provider.
type ConfigOverrider struct {
Base client.ConfigProvider
Configs []*aws.Config
}

// ClientConfig calls the base provider's ClientConfig method with co.Configs
// followed by the arguments given to ClientConfig.
func (co ConfigOverrider) ClientConfig(serviceName string, cfgs ...*aws.Config) client.Config {
cfgs = append(co.Configs[:len(co.Configs):len(co.Configs)], cfgs...)
return co.Base.ClientConfig(serviceName, cfgs...)
}
4 changes: 3 additions & 1 deletion blob/s3blob/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"log"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"gocloud.dev/blob"
"gocloud.dev/blob/s3blob"
Expand All @@ -26,7 +27,8 @@ import (
func Example() {
// Establish an AWS session.
// See https://docs.aws.amazon.com/sdk-for-go/api/aws/session/ for more info.
session, err := session.NewSession(nil)
// The region must match the region for "my_bucket".
session, err := session.NewSession(&aws.Config{Region: aws.String("us-west-1")})
if err != nil {
log.Fatal(err)
}
Expand Down
66 changes: 58 additions & 8 deletions blob/s3blob/s3blob.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ import (
"strings"
"sync"

gcaws "gocloud.dev/aws"
"gocloud.dev/blob"
"gocloud.dev/blob/driver"
"gocloud.dev/gcerrors"
Expand All @@ -81,6 +82,11 @@ func init() {
}

// URLOpener opens S3 URLs like "s3://mybucket".
// The following query options are supported:
// - region: The AWS region for requests; sets aws.Config.Region.
// - endpoint: The endpoint URL (hostname only or fully qualified URI); sets aws.Config.Endpoint.
// - disableSSL: A value of "true" disables SSL when sending requests; sets aws.Config.DisableSSL.
// - s3ForcePathStyle: A value of "true" forces the request to use path-style addressing; sets aws.Config.S3ForcePathStyle.
type URLOpener struct {
// ConfigProvider must be set to a non-nil value.
ConfigProvider client.ConfigProvider
Expand Down Expand Up @@ -120,10 +126,54 @@ const Scheme = "s3"

// OpenBucketURL opens the S3 bucket with the same name as the host in the URL.
func (o *URLOpener) OpenBucketURL(ctx context.Context, u *url.URL) (*blob.Bucket, error) {
for k := range u.Query() {
return nil, fmt.Errorf("open S3 bucket %q: unknown query parameter %s", u, k)
configProvider := &gcaws.ConfigOverrider{
Base: o.ConfigProvider,
}
return OpenBucket(ctx, o.ConfigProvider, u.Host, &o.Options)
overrideCfg, err := o.forParams(ctx, u.Query())
if err != nil {
return nil, fmt.Errorf("open bucket %v: %v", u, err)
}
if overrideCfg != nil {
configProvider.Configs = append(configProvider.Configs, overrideCfg)
}
return OpenBucket(ctx, configProvider, u.Host, &o.Options)
}

var legalQueryParam = map[string]bool{
"region": true,
"endpoint": true,
"disableSSL": true,
"s3ForcePathStyle": true,
}

func (o *URLOpener) forParams(ctx context.Context, q url.Values) (*aws.Config, error) {
for k := range q {
if !legalQueryParam[k] {
return nil, fmt.Errorf("unknown S3 query parameter %s", k)
}
}
var cfg aws.Config
override := false
if region := q["region"]; len(region) > 0 {
cfg.Region = aws.String(region[0])
override = true
}
if endpoint := q["endpoint"]; len(endpoint) > 0 {
zombiezen marked this conversation as resolved.
Show resolved Hide resolved
cfg.Endpoint = aws.String(endpoint[0])
override = true
}
if disableSSL := q["disableSSL"]; len(disableSSL) > 0 {
cfg.DisableSSL = aws.Bool(disableSSL[0] == "true")
override = true
}
if s3ForcePathStyle := q["s3ForcePathStyle"]; len(s3ForcePathStyle) > 0 {
cfg.S3ForcePathStyle = aws.Bool(s3ForcePathStyle[0] == "true")
override = true
}
if !override {
return nil, nil
}
return &cfg, nil
}

// Options sets options for constructing a *blob.Bucket backed by fileblob.
Expand All @@ -139,13 +189,14 @@ func openBucket(ctx context.Context, sess client.ConfigProvider, bucketName stri
}
return &bucket{
name: bucketName,
sess: sess,
client: s3.New(sess),
}, nil
}

// OpenBucket returns a *blob.Bucket backed by S3. See the package documentation
// for an example.
// OpenBucket returns a *blob.Bucket backed by S3.
// AWS buckets are bound to a region; sess must have been created using an
// aws.Config with Region set to the right region for bucketName.
// See the package documentation for an example.
func OpenBucket(ctx context.Context, sess client.ConfigProvider, bucketName string, opts *Options) (*blob.Bucket, error) {
drv, err := openBucket(ctx, sess, bucketName, opts)
if err != nil {
Expand Down Expand Up @@ -260,7 +311,6 @@ func (w *writer) Close() error {
// bucket represents an S3 bucket and handles read, write and delete operations.
type bucket struct {
name string
sess client.ConfigProvider
client *s3.S3
}

Expand Down Expand Up @@ -525,7 +575,7 @@ func unescapeKey(key string) string {
// NewTypedWriter implements driver.NewTypedWriter.
func (b *bucket) NewTypedWriter(ctx context.Context, key string, contentType string, opts *driver.WriterOptions) (driver.Writer, error) {
key = escapeKey(key)
uploader := s3manager.NewUploader(b.sess, func(u *s3manager.Uploader) {
uploader := s3manager.NewUploaderWithClient(b.client, func(u *s3manager.Uploader) {
if opts.BufferSize != 0 {
u.PartSize = int64(opts.BufferSize)
}
Expand Down
70 changes: 70 additions & 0 deletions blob/s3blob/s3blob_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ import (
"context"
"errors"
"fmt"
"github.com/google/go-cmp/cmp"
"net/http"
"net/url"
"testing"

"github.com/aws/aws-sdk-go/aws"
Expand Down Expand Up @@ -220,3 +222,71 @@ func TestOpenBucket(t *testing.T) {
})
}
}

func TestURLOpenerForParams(t *testing.T) {
ctx := context.Background()

tests := []struct {
name string
query url.Values
wantCfg *aws.Config
wantErr bool
}{
{
name: "No overrides",
query: url.Values{},
wantCfg: nil,
},
{
name: "Invalid query parameter",
query: url.Values{"foo": {"bar"}},
wantErr: true,
},
{
name: "Region",
query: url.Values{"region": {"my_region"}},
wantCfg: &aws.Config{Region: aws.String("my_region")},
},
{
name: "Endpoint",
query: url.Values{"endpoint": {"foo"}},
wantCfg: &aws.Config{Endpoint: aws.String("foo")},
},
{
name: "DisableSSL true",
query: url.Values{"disableSSL": {"true"}},
wantCfg: &aws.Config{DisableSSL: aws.Bool(true)},
},
{
name: "DisableSSL false",
query: url.Values{"disableSSL": {"not-true"}},
wantCfg: &aws.Config{DisableSSL: aws.Bool(false)},
},
{
name: "S3ForcePathStyle true",
query: url.Values{"s3ForcePathStyle": {"true"}},
wantCfg: &aws.Config{S3ForcePathStyle: aws.Bool(true)},
},
{
name: "S3ForcePathStyle false",
query: url.Values{"s3ForcePathStyle": {"not-true"}},
wantCfg: &aws.Config{S3ForcePathStyle: aws.Bool(false)},
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
u := &URLOpener{}
got, err := u.forParams(ctx, test.query)
if (err != nil) != test.wantErr {
t.Errorf("got err %v want error %v", err, test.wantErr)
}
if err != nil {
return
}
if diff := cmp.Diff(got, test.wantCfg); diff != "" {
t.Errorf("opener.forParams(...) diff (-want +got):\n%s", diff)
}
})
}
}