Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

aws: Add singleflight support to SafeCredentialsProvider #503

Merged
merged 6 commits into from
Mar 17, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ LINTIGNOREINFLECTS3UPLOAD='service/s3/s3manager/upload\.go:.+struct field SSEKMS
LINTIGNOREDEPS='vendor/.+\.go'
LINTIGNOREPKGCOMMENT='service/[^/]+/doc_custom.go:.+package comment should be of the form'
LINTIGNOREENDPOINTS='aws/endpoints/defaults.go:.+(method|const) .+ should be '
LINTIGNORESINGLEFIGHT='internal/sync/singleflight/singleflight.go:.+error should be the last type'
UNIT_TEST_TAGS="example codegen awsinclude"
ALL_TAGS="example codegen awsinclude integration perftest sdktool"

Expand Down Expand Up @@ -145,7 +146,16 @@ verify: lint vet sdkv1check
lint:
@echo "go lint SDK and vendor packages"
@lint=`golint ./...`; \
dolint=`echo "$$lint" | grep -E -v -e ${LINTIGNOREDOC} -e ${LINTIGNORECONST} -e ${LINTIGNORESTUTTER} -e ${LINTIGNOREINFLECT} -e ${LINTIGNOREDEPS} -e ${LINTIGNOREINFLECTS3UPLOAD} -e ${LINTIGNOREPKGCOMMENT} -e ${LINTIGNOREENDPOINTS}`; \
dolint=`echo "$$lint" | grep -E -v \
-e ${LINTIGNOREDOC} \
-e ${LINTIGNORECONST} \
-e ${LINTIGNORESTUTTER} \
-e ${LINTIGNOREINFLECT} \
-e ${LINTIGNOREDEPS} \
-e ${LINTIGNOREINFLECTS3UPLOAD} \
-e ${LINTIGNOREPKGCOMMENT} \
-e ${LINTIGNOREENDPOINTS} \
-e ${LINTIGNORESINGLEFIGHT}`; \
echo "$$dolint"; \
if [ "$$dolint" != "" ]; then exit 1; fi

Expand Down
4 changes: 2 additions & 2 deletions aws/chain_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ func NewChainProvider(providers []CredentialsProvider) *ChainProvider {
//
// If a provider is found it will be cached and any calls to IsExpired()
// will return the expired state of the cached provider.
func (c *ChainProvider) retrieveFn(ctx context.Context) (Credentials, error) {
func (c *ChainProvider) retrieveFn() (Credentials, error) {
var errs []error
for _, p := range c.Providers {
creds, err := p.Retrieve(ctx)
creds, err := p.Retrieve(context.Background())
if err == nil {
return creds, nil
}
Expand Down
23 changes: 16 additions & 7 deletions aws/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ package aws
import (
"context"
"math"
"sync"
"sync/atomic"
"time"

"github.com/aws/aws-sdk-go-v2/aws/awserr"
"github.com/aws/aws-sdk-go-v2/internal/sdk"
"github.com/aws/aws-sdk-go-v2/internal/sync/singleflight"
)

// NeverExpire is the time identifier used when a credential provider's
Expand Down Expand Up @@ -83,10 +84,10 @@ type CredentialsProvider interface {
// SafeCredentialsProvider provides caching and concurrency safe credentials
// retrieval via the RetrieveFn.
type SafeCredentialsProvider struct {
RetrieveFn func(ctx context.Context) (Credentials, error)
RetrieveFn func() (Credentials, error)

creds atomic.Value
m sync.Mutex
sf singleflight.Group
}

// Retrieve returns the credentials. If the credentials have already been
Expand All @@ -99,15 +100,23 @@ func (p *SafeCredentialsProvider) Retrieve(ctx context.Context) (Credentials, er
return *creds, nil
}

p.m.Lock()
defer p.m.Unlock()
resCh := p.sf.DoChan("", p.singleRetrieve)
select {
case res := <-resCh:
return res.Val.(Credentials), res.Err
case <-ctx.Done():
return Credentials{}, awserr.New("RequestCanceled",
"request context canceled", ctx.Err())
}
}

// Make sure another goroutine didn't already update the credentials.
func (p *SafeCredentialsProvider) singleRetrieve() (interface{}, error) {
if creds := p.getCreds(); creds != nil {
return *creds, nil
}

creds, err := p.RetrieveFn(ctx)
creds, err := p.RetrieveFn()

if err != nil {
return Credentials{}, err
}
Expand Down
4 changes: 2 additions & 2 deletions aws/credentials_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

func BenchmarkSafeCredentialsProvider_Retrieve(b *testing.B) {
retrieveFn := func(ctx context.Context) (Credentials, error) {
retrieveFn := func() (Credentials, error) {
return Credentials{
AccessKeyID: "key",
SecretAccessKey: "secret",
Expand Down Expand Up @@ -45,7 +45,7 @@ func BenchmarkSafeCredentialsProvider_Retrieve(b *testing.B) {
}

func BenchmarkSafeCredentialsProvider_Retrieve_Invalidate(b *testing.B) {
retrieveFn := func(ctx context.Context) (Credentials, error) {
retrieveFn := func() (Credentials, error) {
time.Sleep(time.Millisecond)
return Credentials{
AccessKeyID: "key",
Expand Down
8 changes: 4 additions & 4 deletions aws/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestSafeCredentialsProvider_Cache(t *testing.T) {

var called bool
p := &SafeCredentialsProvider{
RetrieveFn: func(ctx context.Context) (Credentials, error) {
RetrieveFn: func() (Credentials, error) {
if called {
t.Fatalf("expect RetrieveFn to only be called once")
}
Expand Down Expand Up @@ -108,7 +108,7 @@ func TestSafeCredentialsProvider_Expires(t *testing.T) {
for _, c := range cases {
var called int
p := &SafeCredentialsProvider{
RetrieveFn: func(ctx context.Context) (Credentials, error) {
RetrieveFn: func() (Credentials, error) {
called++
return c.Creds(), nil
},
Expand All @@ -132,7 +132,7 @@ func TestSafeCredentialsProvider_Expires(t *testing.T) {

func TestSafeCredentialsProvider_Error(t *testing.T) {
p := &SafeCredentialsProvider{
RetrieveFn: func(ctx context.Context) (Credentials, error) {
RetrieveFn: func() (Credentials, error) {
return Credentials{}, fmt.Errorf("failed")
},
}
Expand All @@ -156,7 +156,7 @@ func TestSafeCredentialsProvider_Race(t *testing.T) {
}
var called bool
p := &SafeCredentialsProvider{
RetrieveFn: func(ctx context.Context) (Credentials, error) {
RetrieveFn: func() (Credentials, error) {
time.Sleep(time.Duration(rand.Intn(10)) * time.Millisecond)
if called {
t.Fatalf("expect RetrieveFn only called once")
Expand Down
6 changes: 3 additions & 3 deletions aws/ec2rolecreds/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ func New(client *ec2metadata.Client, options ...func(*ProviderOptions)) *Provide
// Retrieve retrieves credentials from the EC2 service.
// Error will be returned if the request fails, or unable to extract
// the desired credentials.
func (p *Provider) retrieveFn(ctx context.Context) (aws.Credentials, error) {
credsList, err := requestCredList(ctx, p.client)
func (p *Provider) retrieveFn() (aws.Credentials, error) {
credsList, err := requestCredList(context.Background(), p.client)
if err != nil {
return aws.Credentials{}, err
}
Expand All @@ -80,7 +80,7 @@ func (p *Provider) retrieveFn(ctx context.Context) (aws.Credentials, error) {
}
credsName := credsList[0]

roleCreds, err := requestCred(ctx, p.client, credsName)
roleCreds, err := requestCred(context.Background(), p.client, credsName)
if err != nil {
return aws.Credentials{}, err
}
Expand Down
8 changes: 3 additions & 5 deletions aws/endpointcreds/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
package endpointcreds

import (
"context"
"encoding/json"
"time"

Expand Down Expand Up @@ -99,8 +98,8 @@ func New(cfg aws.Config, options ...func(*ProviderOptions)) *Provider {

// Retrieve will attempt to request the credentials from the endpoint the Provider
// was configured for. And error will be returned if the retrieval fails.
func (p *Provider) retrieveFn(ctx context.Context) (aws.Credentials, error) {
resp, err := p.getCredentials(ctx)
func (p *Provider) retrieveFn() (aws.Credentials, error) {
resp, err := p.getCredentials()
if err != nil {
return aws.Credentials{},
awserr.New("CredentialsEndpointError", "failed to load credentials", err)
Expand Down Expand Up @@ -133,15 +132,14 @@ type errorOutput struct {
Message string `json:"message"`
}

func (p *Provider) getCredentials(ctx context.Context) (*getCredentialsOutput, error) {
func (p *Provider) getCredentials() (*getCredentialsOutput, error) {
op := &aws.Operation{
Name: "GetCredentials",
HTTPMethod: "GET",
}

out := &getCredentialsOutput{}
req := p.client.NewRequest(op, nil, out)
req.SetContext(ctx)
req.HTTPRequest.Header.Set("Accept", "application/json")
if authToken := p.options.AuthorizationToken; len(authToken) != 0 {
req.HTTPRequest.Header.Set("Authorization", authToken)
Expand Down
12 changes: 6 additions & 6 deletions aws/processcreds/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ type credentialProcessResponse struct {
}

// retrieveFn executes the 'credential_process' and returns the credentials.
func (p *Provider) retrieveFn(ctx context.Context) (aws.Credentials, error) {
out, err := p.executeCredentialProcess(ctx)
func (p *Provider) retrieveFn() (aws.Credentials, error) {
out, err := p.executeCredentialProcess()
if err != nil {
return aws.Credentials{Source: ProviderName}, err
}
Expand Down Expand Up @@ -253,7 +253,7 @@ func (p *Provider) retrieveFn(ctx context.Context) (aws.Credentials, error) {
}

// prepareCommand prepares the command to be executed.
func (p *Provider) prepareCommand(ctx context.Context) (context.Context, context.CancelFunc, error) {
func (p *Provider) prepareCommand() (context.Context, context.CancelFunc, error) {

var cmdArgs []string
if runtime.GOOS == "windows" {
Expand All @@ -278,7 +278,7 @@ func (p *Provider) prepareCommand(ctx context.Context) (context.Context, context
}
}

timeoutCtx, cancelFunc := context.WithTimeout(ctx, p.options.Timeout)
timeoutCtx, cancelFunc := context.WithTimeout(context.Background(), p.options.Timeout)

cmdArgs = append(cmdArgs, p.originalCommand...)
p.command = exec.CommandContext(timeoutCtx, cmdArgs[0], cmdArgs[1:]...)
Expand All @@ -289,8 +289,8 @@ func (p *Provider) prepareCommand(ctx context.Context) (context.Context, context

// executeCredentialProcess starts the credential process on the OS and
// returns the results or an error.
func (p *Provider) executeCredentialProcess(ctx context.Context) ([]byte, error) {
ctx, cancelFunc, err := p.prepareCommand(ctx)
func (p *Provider) executeCredentialProcess() ([]byte, error) {
ctx, cancelFunc, err := p.prepareCommand()
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions aws/stscreds/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ func NewAssumeRoleProvider(client AssumeRoler, roleARN string, options ...func(*
}

// Retrieve generates a new set of temporary credentials using STS.
func (p *AssumeRoleProvider) retrieveFn(ctx context.Context) (aws.Credentials, error) {
func (p *AssumeRoleProvider) retrieveFn() (aws.Credentials, error) {
// Apply defaults where parameters are not set.
if len(p.options.RoleSessionName) == 0 {
// Try to work out a role name that will hopefully end up unique.
Expand Down Expand Up @@ -246,7 +246,7 @@ func (p *AssumeRoleProvider) retrieveFn(ctx context.Context) (aws.Credentials, e
}

req := p.client.AssumeRoleRequest(input)
resp, err := req.Send(ctx)
resp, err := req.Send(context.Background())
if err != nil {
return aws.Credentials{Source: ProviderName}, err
}
Expand Down
4 changes: 2 additions & 2 deletions aws/stscreds/web_identity_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func NewWebIdentityRoleProvider(svc stsiface.ClientAPI, roleARN, roleSessionName
// retrieve will attempt to assume a role from a token which is located at
// 'WebIdentityTokenFilePath' specified destination and if that is empty an
// error will be returned.
func (p *WebIdentityRoleProvider) retrieveFn(ctx context.Context) (aws.Credentials, error) {
func (p *WebIdentityRoleProvider) retrieveFn() (aws.Credentials, error) {
b, err := p.tokenRetriever.GetIdentityToken()
if err != nil {
return aws.Credentials{}, awserr.New(ErrCodeWebIdentity, "failed to retrieve jwt from provide source", err)
Expand All @@ -104,7 +104,7 @@ func (p *WebIdentityRoleProvider) retrieveFn(ctx context.Context) (aws.Credentia
// InvalidIdentityToken error is a temporary error that can occur
// when assuming an Role with a JWT web identity token.
req.Retryer = retry.AddWithErrorCodes(req.Retryer, sts.ErrCodeInvalidIdentityTokenException)
resp, err := req.Send(ctx)
resp, err := req.Send(context.Background())
if err != nil {
return aws.Credentials{}, awserr.New(ErrCodeWebIdentity, "failed to retrieve credentials", err)
}
Expand Down
27 changes: 27 additions & 0 deletions internal/sync/singleflight/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
Copyright (c) 2009 The Go Authors. All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:

* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Loading