Skip to content

Commit

Permalink
Add sso session and token provider support (#4885)
Browse files Browse the repository at this point in the history
Update sso credential provider logic to support both sso token provider and legacy sso config, which can all be resolved from updated shared config profile and sso session section.
  • Loading branch information
wty-Bryant authored Jul 6, 2023
1 parent 75e508d commit d744468
Show file tree
Hide file tree
Showing 22 changed files with 1,255 additions and 93 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG_PENDING.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
### SDK Enhancements

### SDK Bugs
* `aws/credentials/ssocreds`: Implement SSO token provider support for `sso-session` in AWS shared config.
* Fixes [4649](https://github.com/aws/aws-sdk-go/issues/4649)
50 changes: 50 additions & 0 deletions aws/auth/bearer/token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package bearer

import (
"github.com/aws/aws-sdk-go/aws"
"time"
)

// Token provides a type wrapping a bearer token and expiration metadata.
type Token struct {
Value string

CanExpire bool
Expires time.Time
}

// Expired returns if the token's Expires time is before or equal to the time
// provided. If CanExpire is false, Expired will always return false.
func (t Token) Expired(now time.Time) bool {
if !t.CanExpire {
return false
}
now = now.Round(0)
return now.Equal(t.Expires) || now.After(t.Expires)
}

// TokenProvider provides interface for retrieving bearer tokens.
type TokenProvider interface {
RetrieveBearerToken(aws.Context) (Token, error)
}

// TokenProviderFunc provides a helper utility to wrap a function as a type
// that implements the TokenProvider interface.
type TokenProviderFunc func(aws.Context) (Token, error)

// RetrieveBearerToken calls the wrapped function, returning the Token or
// error.
func (fn TokenProviderFunc) RetrieveBearerToken(ctx aws.Context) (Token, error) {
return fn(ctx)
}

// StaticTokenProvider provides a utility for wrapping a static bearer token
// value within an implementation of a token provider.
type StaticTokenProvider struct {
Token Token
}

// RetrieveBearerToken returns the static token specified.
func (s StaticTokenProvider) RetrieveBearerToken(aws.Context) (Token, error) {
return s.Token, nil
}
75 changes: 41 additions & 34 deletions aws/credentials/ssocreds/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ import (
"crypto/sha1"
"encoding/hex"
"encoding/json"
"fmt"
"io/ioutil"
"path/filepath"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/auth/bearer"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
Expand Down Expand Up @@ -55,6 +55,19 @@ type Provider struct {

// The URL that points to the organization's AWS Single Sign-On (AWS SSO) user portal.
StartURL string

// The filepath the cached token will be retrieved from. If unset Provider will
// use the startURL to determine the filepath at.
//
// ~/.aws/sso/cache/<sha1-hex-encoded-startURL>.json
//
// If custom cached token filepath is used, the Provider's startUrl
// parameter will be ignored.
CachedTokenFilepath string

// Used by the SSOCredentialProvider if a token configuration
// profile is used in the shared config
TokenProvider bearer.TokenProvider
}

// NewCredentials returns a new AWS Single Sign-On (AWS SSO) credential provider. The ConfigProvider is expected to be configured
Expand Down Expand Up @@ -89,13 +102,31 @@ func (p *Provider) Retrieve() (credentials.Value, error) {
// RetrieveWithContext retrieves temporary AWS credentials from the configured Amazon Single Sign-On (AWS SSO) user portal
// by exchanging the accessToken present in ~/.aws/sso/cache.
func (p *Provider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
tokenFile, err := loadTokenFile(p.StartURL)
if err != nil {
return credentials.Value{}, err
var accessToken *string
if p.TokenProvider != nil {
token, err := p.TokenProvider.RetrieveBearerToken(ctx)
if err != nil {
return credentials.Value{}, err
}
accessToken = &token.Value
} else {
if p.CachedTokenFilepath == "" {
cachedTokenFilePath, err := getCachedFilePath(p.StartURL)
if err != nil {
return credentials.Value{}, err
}
p.CachedTokenFilepath = cachedTokenFilePath
}

tokenFile, err := loadTokenFile(p.CachedTokenFilepath)
if err != nil {
return credentials.Value{}, err
}
accessToken = &tokenFile.AccessToken
}

output, err := p.Client.GetRoleCredentialsWithContext(ctx, &sso.GetRoleCredentialsInput{
AccessToken: &tokenFile.AccessToken,
AccessToken: accessToken,
AccountId: &p.AccountID,
RoleName: &p.RoleName,
})
Expand All @@ -114,32 +145,13 @@ func (p *Provider) RetrieveWithContext(ctx credentials.Context) (credentials.Val
}, nil
}

func getCacheFileName(url string) (string, error) {
func getCachedFilePath(startUrl string) (string, error) {
hash := sha1.New()
_, err := hash.Write([]byte(url))
_, err := hash.Write([]byte(startUrl))
if err != nil {
return "", err
}
return strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json", nil
}

type rfc3339 time.Time

func (r *rfc3339) UnmarshalJSON(bytes []byte) error {
var value string

if err := json.Unmarshal(bytes, &value); err != nil {
return err
}

parse, err := time.Parse(time.RFC3339, value)
if err != nil {
return fmt.Errorf("expected RFC3339 timestamp: %v", err)
}

*r = rfc3339(parse)

return nil
return filepath.Join(defaultCacheLocation(), strings.ToLower(hex.EncodeToString(hash.Sum(nil)))+".json"), nil
}

type token struct {
Expand All @@ -153,13 +165,8 @@ func (t token) Expired() bool {
return nowTime().Round(0).After(time.Time(t.ExpiresAt))
}

func loadTokenFile(startURL string) (t token, err error) {
key, err := getCacheFileName(startURL)
if err != nil {
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err)
}

fileBytes, err := ioutil.ReadFile(filepath.Join(defaultCacheLocation(), key))
func loadTokenFile(cachedTokenPath string) (t token, err error) {
fileBytes, err := ioutil.ReadFile(cachedTokenPath)
if err != nil {
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err)
}
Expand Down
129 changes: 106 additions & 23 deletions aws/credentials/ssocreds/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ package ssocreds

import (
"fmt"
"path/filepath"
"reflect"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/auth/bearer"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/sso"
Expand All @@ -24,14 +26,25 @@ type mockClient struct {
Output *sso.GetRoleCredentialsOutput
Err error

ExpectedAccountID string
ExpectedAccessToken string
ExpectedRoleName string
ExpectedClientRegion string
ExpectedAccountID string
ExpectedAccessToken string
ExpectedRoleName string

Response func(mockClient) (*sso.GetRoleCredentialsOutput, error)
}

type mockTokenProvider struct {
Response func() (bearer.Token, error)
}

func (p *mockTokenProvider) RetrieveBearerToken(ctx aws.Context) (bearer.Token, error) {
if p.Response == nil {
return bearer.Token{}, nil
}

return p.Response()
}

func (m mockClient) GetRoleCredentialsWithContext(ctx aws.Context, params *sso.GetRoleCredentialsInput, _ ...request.Option) (*sso.GetRoleCredentialsOutput, error) {
m.t.Helper()

Expand Down Expand Up @@ -88,11 +101,12 @@ func TestProvider(t *testing.T) {
defer restoreTime()

cases := map[string]struct {
Client mockClient
AccountID string
Region string
RoleName string
StartURL string
Client mockClient
AccountID string
RoleName string
StartURL string
CachedTokenFilePath string
TokenProvider *mockTokenProvider

ExpectedErr bool
ExpectedCredentials credentials.Value
Expand All @@ -104,10 +118,9 @@ func TestProvider(t *testing.T) {
},
"valid required parameter values": {
Client: mockClient{
ExpectedAccountID: "012345678901",
ExpectedRoleName: "TestRole",
ExpectedClientRegion: "us-west-2",
ExpectedAccessToken: "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl",
ExpectedAccountID: "012345678901",
ExpectedRoleName: "TestRole",
ExpectedAccessToken: "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl",
Response: func(mock mockClient) (*sso.GetRoleCredentialsOutput, error) {
return &sso.GetRoleCredentialsOutput{
RoleCredentials: &sso.RoleCredentials{
Expand All @@ -120,7 +133,6 @@ func TestProvider(t *testing.T) {
},
},
AccountID: "012345678901",
Region: "us-west-2",
RoleName: "TestRole",
StartURL: "https://valid-required-only",
ExpectedCredentials: credentials.Value{
Expand All @@ -131,22 +143,89 @@ func TestProvider(t *testing.T) {
},
ExpectedExpire: time.Date(2021, 01, 20, 21, 22, 23, 0.123e9, time.UTC),
},
"custom cached token file": {
Client: mockClient{
ExpectedAccountID: "012345678901",
ExpectedRoleName: "TestRole",
ExpectedAccessToken: "ZhbHVldGhpcyBpcyBub3QgYSByZWFsIH",
Response: func(mock mockClient) (*sso.GetRoleCredentialsOutput, error) {
return &sso.GetRoleCredentialsOutput{
RoleCredentials: &sso.RoleCredentials{
AccessKeyId: aws.String("AccessKey"),
SecretAccessKey: aws.String("SecretKey"),
SessionToken: aws.String("SessionToken"),
Expiration: aws.Int64(1611177743123),
},
}, nil
},
},
CachedTokenFilePath: filepath.Join("testdata", "custom_cached_token.json"),
AccountID: "012345678901",
RoleName: "TestRole",
ExpectedCredentials: credentials.Value{
AccessKeyID: "AccessKey",
SecretAccessKey: "SecretKey",
SessionToken: "SessionToken",
ProviderName: ProviderName,
},
ExpectedExpire: time.Date(2021, 01, 20, 21, 22, 23, 0.123e9, time.UTC),
},
"access token retrieved by token provider": {
Client: mockClient{
ExpectedAccountID: "012345678901",
ExpectedRoleName: "TestRole",
ExpectedAccessToken: "WFsIHZhbHVldGhpcyBpcyBub3QgYSByZ",
Response: func(mock mockClient) (*sso.GetRoleCredentialsOutput, error) {
return &sso.GetRoleCredentialsOutput{
RoleCredentials: &sso.RoleCredentials{
AccessKeyId: aws.String("AccessKey"),
SecretAccessKey: aws.String("SecretKey"),
SessionToken: aws.String("SessionToken"),
Expiration: aws.Int64(1611177743123),
},
}, nil
},
},
TokenProvider: &mockTokenProvider{
Response: func() (bearer.Token, error) {
return bearer.Token{
Value: "WFsIHZhbHVldGhpcyBpcyBub3QgYSByZ",
}, nil
},
},
AccountID: "012345678901",
RoleName: "TestRole",
StartURL: "ignored value",
ExpectedCredentials: credentials.Value{
AccessKeyID: "AccessKey",
SecretAccessKey: "SecretKey",
SessionToken: "SessionToken",
ProviderName: ProviderName,
},
ExpectedExpire: time.Date(2021, 01, 20, 21, 22, 23, 0.123e9, time.UTC),
},
"token provider return error": {
TokenProvider: &mockTokenProvider{
Response: func() (bearer.Token, error) {
return bearer.Token{}, fmt.Errorf("mock token provider return error")
},
},
ExpectedErr: true,
},
"expired access token": {
StartURL: "https://expired",
ExpectedErr: true,
},
"api error": {
Client: mockClient{
ExpectedAccountID: "012345678901",
ExpectedRoleName: "TestRole",
ExpectedClientRegion: "us-west-2",
ExpectedAccessToken: "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl",
ExpectedAccountID: "012345678901",
ExpectedRoleName: "TestRole",
ExpectedAccessToken: "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl",
Response: func(mock mockClient) (*sso.GetRoleCredentialsOutput, error) {
return nil, fmt.Errorf("api error")
},
},
AccountID: "012345678901",
Region: "us-west-2",
RoleName: "TestRole",
StartURL: "https://valid-required-only",
ExpectedErr: true,
Expand All @@ -158,10 +237,14 @@ func TestProvider(t *testing.T) {
tt.Client.t = t

provider := &Provider{
Client: tt.Client,
AccountID: tt.AccountID,
RoleName: tt.RoleName,
StartURL: tt.StartURL,
Client: tt.Client,
AccountID: tt.AccountID,
RoleName: tt.RoleName,
StartURL: tt.StartURL,
CachedTokenFilepath: tt.CachedTokenFilePath,
}
if tt.TokenProvider != nil {
provider.TokenProvider = tt.TokenProvider
}

provider.Expiry.CurrentTime = nowTime
Expand Down
Loading

0 comments on commit d744468

Please sign in to comment.