From c09b938476f81d517c448b59fbbc9dfb94c52bf3 Mon Sep 17 00:00:00 2001 From: Brian Flad Date: Mon, 29 Jan 2018 00:05:25 -0500 Subject: [PATCH] provider: Infer AWS partition from configured region --- aws/auth_helpers.go | 30 +++++++-------- aws/auth_helpers_test.go | 80 +++++++++++----------------------------- aws/config.go | 11 ++++-- 3 files changed, 44 insertions(+), 77 deletions(-) diff --git a/aws/auth_helpers.go b/aws/auth_helpers.go index ca183ff1d46..50221f56f43 100644 --- a/aws/auth_helpers.go +++ b/aws/auth_helpers.go @@ -5,10 +5,10 @@ import ( "fmt" "log" "os" - "strings" "time" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/arn" "github.com/aws/aws-sdk-go/aws/awserr" awsCredentials "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds" @@ -23,7 +23,7 @@ import ( "github.com/hashicorp/go-multierror" ) -func GetAccountInfo(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, string, error) { +func GetAccountID(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, error) { var errors error // If we have creds from instance profile, we can use metadata API if authProviderName == ec2rolecreds.ProviderName { @@ -33,13 +33,13 @@ func GetAccountInfo(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) setOptionalEndpoint(cfg) sess, err := session.NewSession(cfg) if err != nil { - return "", "", errwrap.Wrapf("Error creating AWS session: {{err}}", err) + return "", errwrap.Wrapf("Error creating AWS session: {{err}}", err) } metadataClient := ec2metadata.New(sess) info, err := metadataClient.IAMInfo() if err == nil { - return parseAccountInfoFromArn(info.InstanceProfileArn) + return parseAccountIDFromArn(info.InstanceProfileArn) } log.Printf("[DEBUG] Failed to get account info from metadata service: %s", err) errors = multierror.Append(errors, err) @@ -55,14 +55,14 @@ func GetAccountInfo(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) log.Println("[DEBUG] Trying to get account ID via iam:GetUser") outUser, err := iamconn.GetUser(nil) if err == nil { - return parseAccountInfoFromArn(*outUser.User.Arn) + return parseAccountIDFromArn(*outUser.User.Arn) } errors = multierror.Append(errors, err) awsErr, ok := err.(awserr.Error) // AccessDenied and ValidationError can be raised // if credentials belong to federated profile, so we ignore these if !ok || (awsErr.Code() != "AccessDenied" && awsErr.Code() != "ValidationError" && awsErr.Code() != "InvalidClientTokenId") { - return "", "", fmt.Errorf("Failed getting account ID via 'iam:GetUser': %s", err) + return "", fmt.Errorf("Failed getting account ID via 'iam:GetUser': %s", err) } log.Printf("[DEBUG] Getting account ID via iam:GetUser failed: %s", err) } @@ -71,7 +71,7 @@ func GetAccountInfo(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) log.Println("[DEBUG] Trying to get account ID via sts:GetCallerIdentity") outCallerIdentity, err := stsconn.GetCallerIdentity(&sts.GetCallerIdentityInput{}) if err == nil { - return parseAccountInfoFromArn(*outCallerIdentity.Arn) + return parseAccountIDFromArn(*outCallerIdentity.Arn) } log.Printf("[DEBUG] Getting account ID via sts:GetCallerIdentity failed: %s", err) errors = multierror.Append(errors, err) @@ -84,25 +84,25 @@ func GetAccountInfo(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) if err != nil { log.Printf("[DEBUG] Failed to get account ID via iam:ListRoles: %s", err) errors = multierror.Append(errors, err) - return "", "", fmt.Errorf("Failed getting account ID via all available methods. Errors: %s", errors) + return "", fmt.Errorf("Failed getting account ID via all available methods. Errors: %s", errors) } if len(outRoles.Roles) < 1 { err = fmt.Errorf("Failed to get account ID via iam:ListRoles: No roles available") log.Printf("[DEBUG] %s", err) errors = multierror.Append(errors, err) - return "", "", fmt.Errorf("Failed getting account ID via all available methods. Errors: %s", errors) + return "", fmt.Errorf("Failed getting account ID via all available methods. Errors: %s", errors) } - return parseAccountInfoFromArn(*outRoles.Roles[0].Arn) + return parseAccountIDFromArn(*outRoles.Roles[0].Arn) } -func parseAccountInfoFromArn(arn string) (string, string, error) { - parts := strings.Split(arn, ":") - if len(parts) < 5 { - return "", "", fmt.Errorf("Unable to parse ID from invalid ARN: %q", arn) +func parseAccountIDFromArn(inputARN string) (string, error) { + arn, err := arn.Parse(inputARN) + if err != nil { + return "", fmt.Errorf("Unable to parse ID from invalid ARN: %q", arn) } - return parts[1], parts[4], nil + return arn.AccountID, nil } // This function is responsible for reading credentials from the diff --git a/aws/auth_helpers_test.go b/aws/auth_helpers_test.go index 661981329d7..de7d8162aa6 100644 --- a/aws/auth_helpers_test.go +++ b/aws/auth_helpers_test.go @@ -17,7 +17,7 @@ import ( "github.com/aws/aws-sdk-go/service/sts" ) -func TestAWSGetAccountInfo_shouldBeValid_fromEC2Role(t *testing.T) { +func TestAWSGetAccountID_shouldBeValid_fromEC2Role(t *testing.T) { resetEnv := unsetEnv(t) defer resetEnv() // capture the test server's close method, to call after the test returns @@ -33,23 +33,18 @@ func TestAWSGetAccountInfo_shouldBeValid_fromEC2Role(t *testing.T) { iamConn := iam.New(emptySess) stsConn := sts.New(emptySess) - part, id, err := GetAccountInfo(iamConn, stsConn, ec2rolecreds.ProviderName) + id, err := GetAccountID(iamConn, stsConn, ec2rolecreds.ProviderName) if err != nil { t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err) } - expectedPart := "aws" - if part != expectedPart { - t.Fatalf("Expected partition: %s, given: %s", expectedPart, part) - } - expectedAccountId := "123456789013" if id != expectedAccountId { t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id) } } -func TestAWSGetAccountInfo_shouldBeValid_EC2RoleHasPriority(t *testing.T) { +func TestAWSGetAccountID_shouldBeValid_EC2RoleHasPriority(t *testing.T) { resetEnv := unsetEnv(t) defer resetEnv() // capture the test server's close method, to call after the test returns @@ -75,23 +70,18 @@ func TestAWSGetAccountInfo_shouldBeValid_EC2RoleHasPriority(t *testing.T) { } stsConn := sts.New(stsSess) - part, id, err := GetAccountInfo(iamConn, stsConn, ec2rolecreds.ProviderName) + id, err := GetAccountID(iamConn, stsConn, ec2rolecreds.ProviderName) if err != nil { t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err) } - expectedPart := "aws" - if part != expectedPart { - t.Fatalf("Expected partition: %s, given: %s", expectedPart, part) - } - expectedAccountId := "123456789013" if id != expectedAccountId { t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id) } } -func TestAWSGetAccountInfo_shouldBeValid_fromIamUser(t *testing.T) { +func TestAWSGetAccountID_shouldBeValid_fromIamUser(t *testing.T) { iamEndpoints := []*awsMockEndpoint{ { Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, @@ -113,23 +103,18 @@ func TestAWSGetAccountInfo_shouldBeValid_fromIamUser(t *testing.T) { iamConn := iam.New(iamSess) stsConn := sts.New(stsSess) - part, id, err := GetAccountInfo(iamConn, stsConn, "") + id, err := GetAccountID(iamConn, stsConn, "") if err != nil { t.Fatalf("Getting account ID via GetUser failed: %s", err) } - expectedPart := "aws" - if part != expectedPart { - t.Fatalf("Expected partition: %s, given: %s", expectedPart, part) - } - expectedAccountId := "123456789012" if id != expectedAccountId { t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id) } } -func TestAWSGetAccountInfo_shouldBeValid_fromGetCallerIdentity(t *testing.T) { +func TestAWSGetAccountID_shouldBeValid_fromGetCallerIdentity(t *testing.T) { iamEndpoints := []*awsMockEndpoint{ { Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, @@ -154,10 +139,10 @@ func TestAWSGetAccountInfo_shouldBeValid_fromGetCallerIdentity(t *testing.T) { t.Fatal(err) } - testGetAccountInfo(t, iamSess, stsSess, credentials.SharedCredsProviderName) + testGetAccountID(t, iamSess, stsSess, credentials.SharedCredsProviderName) } -func TestAWSGetAccountInfo_shouldBeValid_EC2RoleFallsBackToCallerIdentity(t *testing.T) { +func TestAWSGetAccountID_shouldBeValid_EC2RoleFallsBackToCallerIdentity(t *testing.T) { // This mimics the metadata service mocked by Hologram (https://github.com/AdRoll/hologram) resetEnv := unsetEnv(t) defer resetEnv() @@ -189,10 +174,10 @@ func TestAWSGetAccountInfo_shouldBeValid_EC2RoleFallsBackToCallerIdentity(t *tes t.Fatal(err) } - testGetAccountInfo(t, iamSess, stsSess, ec2rolecreds.ProviderName) + testGetAccountID(t, iamSess, stsSess, ec2rolecreds.ProviderName) } -func TestAWSGetAccountInfo_shouldBeValid_fromIamListRoles(t *testing.T) { +func TestAWSGetAccountID_shouldBeValid_fromIamListRoles(t *testing.T) { iamEndpoints := []*awsMockEndpoint{ { Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, @@ -224,23 +209,18 @@ func TestAWSGetAccountInfo_shouldBeValid_fromIamListRoles(t *testing.T) { iamConn := iam.New(iamSess) stsConn := sts.New(stsSess) - part, id, err := GetAccountInfo(iamConn, stsConn, "") + id, err := GetAccountID(iamConn, stsConn, "") if err != nil { t.Fatalf("Getting account ID via ListRoles failed: %s", err) } - expectedPart := "aws" - if part != expectedPart { - t.Fatalf("Expected partition: %s, given: %s", expectedPart, part) - } - expectedAccountId := "123456789012" if id != expectedAccountId { t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id) } } -func TestAWSGetAccountInfo_shouldBeValid_federatedRole(t *testing.T) { +func TestAWSGetAccountID_shouldBeValid_federatedRole(t *testing.T) { iamEndpoints := []*awsMockEndpoint{ { Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, @@ -266,23 +246,18 @@ func TestAWSGetAccountInfo_shouldBeValid_federatedRole(t *testing.T) { iamConn := iam.New(iamSess) stsConn := sts.New(stsSess) - part, id, err := GetAccountInfo(iamConn, stsConn, "") + id, err := GetAccountID(iamConn, stsConn, "") if err != nil { t.Fatalf("Getting account ID via ListRoles failed: %s", err) } - expectedPart := "aws" - if part != expectedPart { - t.Fatalf("Expected partition: %s, given: %s", expectedPart, part) - } - expectedAccountId := "123456789012" if id != expectedAccountId { t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id) } } -func TestAWSGetAccountInfo_shouldError_unauthorizedFromIam(t *testing.T) { +func TestAWSGetAccountID_shouldError_unauthorizedFromIam(t *testing.T) { iamEndpoints := []*awsMockEndpoint{ { Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, @@ -308,37 +283,29 @@ func TestAWSGetAccountInfo_shouldError_unauthorizedFromIam(t *testing.T) { iamConn := iam.New(iamSess) stsConn := sts.New(stsSess) - part, id, err := GetAccountInfo(iamConn, stsConn, "") + id, err := GetAccountID(iamConn, stsConn, "") if err == nil { t.Fatal("Expected error when getting account ID") } - if part != "" { - t.Fatalf("Expected no partition, given: %s", part) - } - if id != "" { t.Fatalf("Expected no account ID, given: %s", id) } } -func TestAWSParseAccountInfoFromArn(t *testing.T) { +func TestAWSParseAccountIDFromArn(t *testing.T) { validArn := "arn:aws:iam::101636750127:instance-profile/aws-elasticbeanstalk-ec2-role" - expectedPart := "aws" expectedId := "101636750127" - part, id, err := parseAccountInfoFromArn(validArn) + id, err := parseAccountIDFromArn(validArn) if err != nil { t.Fatalf("Expected no error when parsing valid ARN: %s", err) } - if part != expectedPart { - t.Fatalf("Parsed part doesn't match with expected (%q != %q)", part, expectedPart) - } if id != expectedId { t.Fatalf("Parsed id doesn't match with expected (%q != %q)", id, expectedId) } invalidArn := "blablah" - part, id, err = parseAccountInfoFromArn(invalidArn) + id, err = parseAccountIDFromArn(invalidArn) if err == nil { t.Fatalf("Expected error when parsing invalid ARN (%q)", invalidArn) } @@ -671,21 +638,16 @@ func TestAWSGetCredentials_shouldBeENV(t *testing.T) { } } -func testGetAccountInfo(t *testing.T, iamSess, stsSess *session.Session, credProviderName string) { +func testGetAccountID(t *testing.T, iamSess, stsSess *session.Session, credProviderName string) { iamConn := iam.New(iamSess) stsConn := sts.New(stsSess) - part, id, err := GetAccountInfo(iamConn, stsConn, credProviderName) + id, err := GetAccountID(iamConn, stsConn, credProviderName) if err != nil { t.Fatalf("Getting account ID failed: %s", err) } - expectedPart := "aws" - if part != expectedPart { - t.Fatalf("Expected partition: %s, given: %s", expectedPart, part) - } - expectedAccountId := "123456789012" if id != expectedAccountId { t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id) diff --git a/aws/config.go b/aws/config.go index 99a2e7347e3..4a7d6e168e6 100644 --- a/aws/config.go +++ b/aws/config.go @@ -12,6 +12,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/acm" @@ -370,11 +371,15 @@ func (c *Config) Client() (interface{}, error) { } } + // Infer AWS partition from configured region + if partition, ok := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), client.region); ok { + client.partition = partition.ID() + } + if !c.SkipRequestingAccountId { - partition, accountId, err := GetAccountInfo(client.iamconn, client.stsconn, cp.ProviderName) + accountID, err := GetAccountID(client.iamconn, client.stsconn, cp.ProviderName) if err == nil { - client.partition = partition - client.accountid = accountId + client.accountid = accountID } }