From cdfa7ea8a9767447677f424e5b6a95a64b95eb37 Mon Sep 17 00:00:00 2001 From: Brian Flad Date: Fri, 13 Jul 2018 01:16:15 -0400 Subject: [PATCH] provider: Allow AssumeRoleARN and credential validation call to shortcut account ID and partition lookup * If provider configuration is set to assume role, use ARN to bypass account information calls * If provider credential validation is enabled, use results to bypass account information calls * Refactor GetAccountID back to GetAccountInformation to ensure non-SDK-default partitions can be handled * Refactor GetAccountInformation testing --- aws/auth_helpers.go | 157 ++++++---- aws/auth_helpers_test.go | 603 ++++++++++++++++++++++----------------- aws/config.go | 43 +-- 3 files changed, 470 insertions(+), 333 deletions(-) diff --git a/aws/auth_helpers.go b/aws/auth_helpers.go index 5f65a47c45b..355910fec27 100644 --- a/aws/auth_helpers.go +++ b/aws/auth_helpers.go @@ -18,92 +18,137 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/sts" - "github.com/hashicorp/errwrap" "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-multierror" ) -func GetAccountID(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, error) { - var errors error +func GetAccountInformation(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, string, error) { + var accountID, partition string + var err, errors error - // First, try STS GetCallerIdentity - log.Println("[DEBUG] Trying to get account ID via sts:GetCallerIdentity") - outCallerIdentity, err := stsconn.GetCallerIdentity(&sts.GetCallerIdentityInput{}) - if err == nil { - return parseAccountIDFromArn(*outCallerIdentity.Arn) + if authProviderName == ec2rolecreds.ProviderName { + accountID, partition, err = GetAccountInformationFromEC2Metadata() + } else { + accountID, partition, err = GetAccountInformationFromIAMGetUser(iamconn) + } + if accountID != "" { + return accountID, partition, nil } - log.Printf("[DEBUG] Getting account ID via sts:GetCallerIdentity failed: %s", err) errors = multierror.Append(errors, err) - // If we have creds from instance profile, we can use metadata API - if authProviderName == ec2rolecreds.ProviderName { - log.Println("[DEBUG] Trying to get account ID via AWS Metadata API") + accountID, partition, err = GetAccountInformationFromSTSGetCallerIdentity(stsconn) + if accountID != "" { + return accountID, partition, nil + } + errors = multierror.Append(errors, err) - cfg := &aws.Config{} - setOptionalEndpoint(cfg) - sess, err := session.NewSession(cfg) - if err != nil { - return "", errwrap.Wrapf("Error creating AWS session: {{err}}", err) - } + accountID, partition, err = GetAccountInformationFromIAMListRoles(iamconn) + if accountID != "" { + return accountID, partition, nil + } + errors = multierror.Append(errors, err) - metadataClient := ec2metadata.New(sess) - info, err := metadataClient.IAMInfo() - if err == nil { - return parseAccountIDFromArn(info.InstanceProfileArn) - } - log.Printf("[DEBUG] Failed to get account info from metadata service: %s", err) - errors = multierror.Append(errors, err) + return accountID, partition, errors +} + +func GetAccountInformationFromEC2Metadata() (string, string, error) { + log.Println("[DEBUG] Trying to get account information via EC2 Metadata") + + cfg := &aws.Config{} + setOptionalEndpoint(cfg) + sess, err := session.NewSession(cfg) + if err != nil { + return "", "", fmt.Errorf("error creating EC2 Metadata session: %s", err) + } + + metadataClient := ec2metadata.New(sess) + info, err := metadataClient.IAMInfo() + if err != nil { // We can end up here if there's an issue with the instance metadata service // or if we're getting credentials from AdRoll's Hologram (in which case IAMInfo will - // error out). In any event, if we can't get account info here, we should try - // the other methods available. - // If we have creds from something that looks like an IAM instance profile, but - // we were unable to retrieve account info from the instance profile, it's probably - // a safe assumption that we're not an IAM user - } else { - // Creds aren't from an IAM instance profile, so try try iam:GetUser - log.Println("[DEBUG] Trying to get account ID via iam:GetUser") - outUser, err := iamconn.GetUser(nil) - if err == nil { - return parseAccountIDFromArn(*outUser.User.Arn) - } - errors = multierror.Append(errors, err) - awsErr, ok := err.(awserr.Error) + // error out). + err = fmt.Errorf("failed getting account information via EC2 Metadata IAM information: %s", err) + log.Printf("[DEBUG] %s", err) + return "", "", err + } + + return parseAccountInformationFromARN(info.InstanceProfileArn) +} + +func GetAccountInformationFromIAMGetUser(iamconn *iam.IAM) (string, string, error) { + log.Println("[DEBUG] Trying to get account information via iam:GetUser") + + output, err := iamconn.GetUser(&iam.GetUserInput{}) + if err != nil { // AccessDenied and ValidationError can be raised // if credentials belong to federated profile, so we ignore these - if !ok || (awsErr.Code() != "AccessDenied" && awsErr.Code() != "ValidationError" && awsErr.Code() != "InvalidClientTokenId") { - return "", fmt.Errorf("Failed getting account ID via 'iam:GetUser': %s", err) + if isAWSErr(err, "AccessDenied", "") { + return "", "", nil + } + if isAWSErr(err, "InvalidClientTokenId", "") { + return "", "", nil + } + if isAWSErr(err, "ValidationError", "") { + return "", "", nil } - log.Printf("[DEBUG] Getting account ID via iam:GetUser failed: %s", err) + err = fmt.Errorf("failed getting account information via iam:GetUser: %s", err) + log.Printf("[DEBUG] %s", err) + return "", "", err + } + + if output == nil || output.User == nil { + err = errors.New("empty iam:GetUser response") + log.Printf("[DEBUG] %s", err) + return "", "", err } - // Then try IAM ListRoles - log.Println("[DEBUG] Trying to get account ID via iam:ListRoles") - outRoles, err := iamconn.ListRoles(&iam.ListRolesInput{ + return parseAccountInformationFromARN(aws.StringValue(output.User.Arn)) +} + +func GetAccountInformationFromIAMListRoles(iamconn *iam.IAM) (string, string, error) { + log.Println("[DEBUG] Trying to get account information via iam:ListRoles") + + output, err := iamconn.ListRoles(&iam.ListRolesInput{ MaxItems: aws.Int64(int64(1)), }) if err != nil { - log.Printf("[DEBUG] Failed to get account ID via iam:ListRoles: %s", err) - errors = multierror.Append(errors, err) - return "", fmt.Errorf("Failed getting account ID via all available methods. Errors: %s", errors) + err = fmt.Errorf("failed getting account information via iam:ListRoles: %s", err) + log.Printf("[DEBUG] %s", err) + return "", "", err + } + + if output == nil || len(output.Roles) < 1 { + err = fmt.Errorf("empty iam:ListRoles response") + log.Printf("[DEBUG] %s", err) + return "", "", err + } + + return parseAccountInformationFromARN(aws.StringValue(output.Roles[0].Arn)) +} + +func GetAccountInformationFromSTSGetCallerIdentity(stsconn *sts.STS) (string, string, error) { + log.Println("[DEBUG] Trying to get account information via sts:GetCallerIdentity") + + output, err := stsconn.GetCallerIdentity(&sts.GetCallerIdentityInput{}) + if err != nil { + return "", "", fmt.Errorf("error calling sts:GetCallerIdentity: %s", err) } - if len(outRoles.Roles) < 1 { - err = fmt.Errorf("Failed to get account ID via iam:ListRoles: No roles available") + if output == nil || output.Arn == nil { + err = errors.New("empty sts:GetCallerIdentity response") log.Printf("[DEBUG] %s", err) - errors = multierror.Append(errors, err) - return "", fmt.Errorf("Failed getting account ID via all available methods. Errors: %s", errors) + return "", "", err } - return parseAccountIDFromArn(*outRoles.Roles[0].Arn) + return parseAccountInformationFromARN(aws.StringValue(output.Arn)) } -func parseAccountIDFromArn(inputARN string) (string, error) { +func parseAccountInformationFromARN(inputARN string) (string, string, error) { arn, err := arn.Parse(inputARN) if err != nil { - return "", fmt.Errorf("Unable to parse ID from invalid ARN: %q", arn) + return "", "", fmt.Errorf("error parsing ARN (%s): %s", inputARN, err) } - return arn.AccountID, nil + return arn.AccountID, arn.Partition, nil } // This function is responsible for reading credentials from the diff --git a/aws/auth_helpers_test.go b/aws/auth_helpers_test.go index de7d8162aa6..ded7104ed4b 100644 --- a/aws/auth_helpers_test.go +++ b/aws/auth_helpers_test.go @@ -10,304 +10,393 @@ import ( "testing" "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/sts" ) -func TestAWSGetAccountID_shouldBeValid_fromEC2Role(t *testing.T) { - resetEnv := unsetEnv(t) - defer resetEnv() - // capture the test server's close method, to call after the test returns - awsTs := awsMetadataApiMock(append(securityCredentialsEndpoints, instanceIdEndpoint, iamInfoEndpoint)) - defer awsTs() - - closeEmpty, emptySess, err := getMockedAwsApiSession("zero", []*awsMockEndpoint{}) - defer closeEmpty() - if err != nil { - t.Fatal(err) - } - - iamConn := iam.New(emptySess) - stsConn := sts.New(emptySess) - - id, err := GetAccountID(iamConn, stsConn, ec2rolecreds.ProviderName) - if err != nil { - t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err) - } - - expectedAccountId := "123456789013" - if id != expectedAccountId { - t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id) - } -} - -func TestAWSGetAccountID_shouldBeValid_EC2RoleHasPriority(t *testing.T) { - resetEnv := unsetEnv(t) - defer resetEnv() - // capture the test server's close method, to call after the test returns - awsTs := awsMetadataApiMock(append(securityCredentialsEndpoints, instanceIdEndpoint, iamInfoEndpoint)) - defer awsTs() - - iamEndpoints := []*awsMockEndpoint{ +func TestGetAccountInformation(t *testing.T) { + var testCases = []struct { + Description string + AuthProviderName string + EC2MetadataEndpoints []*endpoint + IAMEndpoints []*awsMockEndpoint + STSEndpoints []*awsMockEndpoint + ErrCount int + ExpectedAccountID string + ExpectedPartition string + }{ { - Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, - Response: &awsMockResponse{200, iamResponse_GetUser_valid, "text/xml"}, + Description: "EC2 Metadata over iam:GetUser when using EC2 Instance Profile", + AuthProviderName: ec2rolecreds.ProviderName, + EC2MetadataEndpoints: append(ec2metadata_securityCredentialsEndpoints, ec2metadata_instanceIdEndpoint, ec2metadata_iamInfoEndpoint), + IAMEndpoints: []*awsMockEndpoint{ + { + Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, + Response: &awsMockResponse{200, iamResponse_GetUser_valid, "text/xml"}, + }, + }, + ExpectedAccountID: ec2metadata_iamInfoEndpoint_expectedAccountID, + ExpectedPartition: ec2metadata_iamInfoEndpoint_expectedPartition, }, - } - closeIam, iamSess, err := getMockedAwsApiSession("IAM", iamEndpoints) - defer closeIam() - if err != nil { - t.Fatal(err) - } - iamConn := iam.New(iamSess) - closeSts, stsSess, err := getMockedAwsApiSession("STS", []*awsMockEndpoint{}) - defer closeSts() - if err != nil { - t.Fatal(err) - } - stsConn := sts.New(stsSess) - - id, err := GetAccountID(iamConn, stsConn, ec2rolecreds.ProviderName) - if err != nil { - t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err) - } - - expectedAccountId := "123456789013" - if id != expectedAccountId { - t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id) - } -} - -func TestAWSGetAccountID_shouldBeValid_fromIamUser(t *testing.T) { - iamEndpoints := []*awsMockEndpoint{ { - Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, - Response: &awsMockResponse{200, iamResponse_GetUser_valid, "text/xml"}, + Description: "Mimic the metadata service mocked by Hologram (https://github.com/AdRoll/hologram)", + AuthProviderName: ec2rolecreds.ProviderName, + EC2MetadataEndpoints: ec2metadata_securityCredentialsEndpoints, + IAMEndpoints: []*awsMockEndpoint{ + { + Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, + Response: &awsMockResponse{403, iamResponse_GetUser_unauthorized, "text/xml"}, + }, + }, + STSEndpoints: []*awsMockEndpoint{ + { + Request: &awsMockRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"}, + Response: &awsMockResponse{200, stsResponse_GetCallerIdentity_valid, "text/xml"}, + }, + }, + ExpectedAccountID: stsResponse_GetCallerIdentity_valid_expectedAccountID, + ExpectedPartition: stsResponse_GetCallerIdentity_valid_expectedPartition, + }, + { + Description: "iam:ListRoles if iam:GetUser AccessDenied and sts:GetCallerIdentity fails", + IAMEndpoints: []*awsMockEndpoint{ + { + Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, + Response: &awsMockResponse{403, iamResponse_GetUser_unauthorized, "text/xml"}, + }, + { + Request: &awsMockRequest{"POST", "/", "Action=ListRoles&MaxItems=1&Version=2010-05-08"}, + Response: &awsMockResponse{200, iamResponse_ListRoles_valid, "text/xml"}, + }, + }, + STSEndpoints: []*awsMockEndpoint{ + { + Request: &awsMockRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"}, + Response: &awsMockResponse{403, stsResponse_GetCallerIdentity_unauthorized, "text/xml"}, + }, + }, + ExpectedAccountID: iamResponse_ListRoles_valid_expectedAccountID, + ExpectedPartition: iamResponse_ListRoles_valid_expectedPartition, + }, + { + Description: "iam:ListRoles if iam:GetUser ValidationError and sts:GetCallerIdentity fails", + IAMEndpoints: []*awsMockEndpoint{ + { + Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, + Response: &awsMockResponse{400, iamResponse_GetUser_federatedFailure, "text/xml"}, + }, + { + Request: &awsMockRequest{"POST", "/", "Action=ListRoles&MaxItems=1&Version=2010-05-08"}, + Response: &awsMockResponse{200, iamResponse_ListRoles_valid, "text/xml"}, + }, + }, + STSEndpoints: []*awsMockEndpoint{ + { + Request: &awsMockRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"}, + Response: &awsMockResponse{403, stsResponse_GetCallerIdentity_unauthorized, "text/xml"}, + }, + }, + ExpectedAccountID: iamResponse_ListRoles_valid_expectedAccountID, + ExpectedPartition: iamResponse_ListRoles_valid_expectedPartition, + }, + { + Description: "Error when all endpoints fail", + IAMEndpoints: []*awsMockEndpoint{ + { + Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, + Response: &awsMockResponse{400, iamResponse_GetUser_federatedFailure, "text/xml"}, + }, + { + Request: &awsMockRequest{"POST", "/", "Action=ListRoles&MaxItems=1&Version=2010-05-08"}, + Response: &awsMockResponse{403, iamResponse_ListRoles_unauthorized, "text/xml"}, + }, + }, + STSEndpoints: []*awsMockEndpoint{ + { + Request: &awsMockRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"}, + Response: &awsMockResponse{403, stsResponse_GetCallerIdentity_unauthorized, "text/xml"}, + }, + }, + ErrCount: 1, }, } - closeIam, iamSess, err := getMockedAwsApiSession("IAM", iamEndpoints) - defer closeIam() - if err != nil { - t.Fatal(err) - } - closeSts, stsSess, err := getMockedAwsApiSession("STS", []*awsMockEndpoint{}) - defer closeSts() - if err != nil { - t.Fatal(err) - } - - iamConn := iam.New(iamSess) - stsConn := sts.New(stsSess) + for _, testCase := range testCases { + resetEnv := unsetEnv(t) + defer resetEnv() + // capture the test server's close method, to call after the test returns + awsTs := awsMetadataApiMock(testCase.EC2MetadataEndpoints) + defer awsTs() - id, err := GetAccountID(iamConn, stsConn, "") - if err != nil { - t.Fatalf("Getting account ID via GetUser failed: %s", err) - } + closeIam, iamSess, err := getMockedAwsApiSession("IAM", testCase.IAMEndpoints) + defer closeIam() + if err != nil { + t.Fatal(err) + } - expectedAccountId := "123456789012" - if id != expectedAccountId { - t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id) - } -} + closeSts, stsSess, err := getMockedAwsApiSession("STS", testCase.STSEndpoints) + defer closeSts() + if err != nil { + t.Fatal(err) + } -func TestAWSGetAccountID_shouldBeValid_fromGetCallerIdentity(t *testing.T) { - iamEndpoints := []*awsMockEndpoint{ - { - Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, - Response: &awsMockResponse{403, iamResponse_GetUser_unauthorized, "text/xml"}, - }, - } - closeIam, iamSess, err := getMockedAwsApiSession("IAM", iamEndpoints) - defer closeIam() - if err != nil { - t.Fatal(err) - } + iamConn := iam.New(iamSess) + stsConn := sts.New(stsSess) - stsEndpoints := []*awsMockEndpoint{ - { - Request: &awsMockRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"}, - Response: &awsMockResponse{200, stsResponse_GetCallerIdentity_valid, "text/xml"}, - }, - } - closeSts, stsSess, err := getMockedAwsApiSession("STS", stsEndpoints) - defer closeSts() - if err != nil { - t.Fatal(err) + accountID, partition, err := GetAccountInformation(iamConn, stsConn, testCase.AuthProviderName) + if err != nil && testCase.ErrCount == 0 { + t.Fatalf("%s: Expected no error, received error: %s", testCase.Description, err) + } + if err == nil && testCase.ErrCount > 0 { + t.Fatalf("%s: Expected %d error(s), received none", testCase.Description, testCase.ErrCount) + } + if accountID != testCase.ExpectedAccountID { + t.Fatalf("%s: Parsed account ID doesn't match with expected (%q != %q)", testCase.Description, accountID, testCase.ExpectedAccountID) + } + if partition != testCase.ExpectedPartition { + t.Fatalf("%s: Parsed partition doesn't match with expected (%q != %q)", testCase.Description, partition, testCase.ExpectedPartition) + } } - - testGetAccountID(t, iamSess, stsSess, credentials.SharedCredsProviderName) } -func TestAWSGetAccountID_shouldBeValid_EC2RoleFallsBackToCallerIdentity(t *testing.T) { - // This mimics the metadata service mocked by Hologram (https://github.com/AdRoll/hologram) +func TestGetAccountInformationFromEC2Metadata(t *testing.T) { resetEnv := unsetEnv(t) defer resetEnv() - - awsTs := awsMetadataApiMock(securityCredentialsEndpoints) + // capture the test server's close method, to call after the test returns + awsTs := awsMetadataApiMock(append(ec2metadata_securityCredentialsEndpoints, ec2metadata_instanceIdEndpoint, ec2metadata_iamInfoEndpoint)) defer awsTs() - iamEndpoints := []*awsMockEndpoint{ - { - Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, - Response: &awsMockResponse{403, iamResponse_GetUser_unauthorized, "text/xml"}, - }, - } - closeIam, iamSess, err := getMockedAwsApiSession("IAM", iamEndpoints) - defer closeIam() + id, partition, err := GetAccountInformationFromEC2Metadata() if err != nil { - t.Fatal(err) + t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err) } - stsEndpoints := []*awsMockEndpoint{ - { - Request: &awsMockRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"}, - Response: &awsMockResponse{200, stsResponse_GetCallerIdentity_valid, "text/xml"}, - }, + if id != ec2metadata_iamInfoEndpoint_expectedAccountID { + t.Fatalf("Expected account ID: %s, given: %s", ec2metadata_iamInfoEndpoint_expectedAccountID, id) } - closeSts, stsSess, err := getMockedAwsApiSession("STS", stsEndpoints) - defer closeSts() - if err != nil { - t.Fatal(err) + if partition != ec2metadata_iamInfoEndpoint_expectedPartition { + t.Fatalf("Expected partition: %s, given: %s", ec2metadata_iamInfoEndpoint_expectedPartition, partition) } - - testGetAccountID(t, iamSess, stsSess, ec2rolecreds.ProviderName) } -func TestAWSGetAccountID_shouldBeValid_fromIamListRoles(t *testing.T) { - iamEndpoints := []*awsMockEndpoint{ +func TestGetAccountInformationFromIAMGetUser(t *testing.T) { + var testCases = []struct { + MockEndpoints []*awsMockEndpoint + ErrCount int + ExpectedAccountID string + ExpectedPartition string + }{ { - Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, - Response: &awsMockResponse{403, iamResponse_GetUser_unauthorized, "text/xml"}, + MockEndpoints: []*awsMockEndpoint{ + { + Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, + Response: &awsMockResponse{400, iamResponse_GetUser_federatedFailure, "text/xml"}, + }, + }, + // We ignore this error + ErrCount: 0, }, { - Request: &awsMockRequest{"POST", "/", "Action=ListRoles&MaxItems=1&Version=2010-05-08"}, - Response: &awsMockResponse{200, iamResponse_ListRoles_valid, "text/xml"}, + MockEndpoints: []*awsMockEndpoint{ + { + Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, + Response: &awsMockResponse{403, iamResponse_GetUser_unauthorized, "text/xml"}, + }, + }, + // We ignore this error + ErrCount: 0, }, - } - closeIam, iamSess, err := getMockedAwsApiSession("IAM", iamEndpoints) - defer closeIam() - if err != nil { - t.Fatal(err) - } - - stsEndpoints := []*awsMockEndpoint{ { - Request: &awsMockRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"}, - Response: &awsMockResponse{403, stsResponse_GetCallerIdentity_unauthorized, "text/xml"}, + MockEndpoints: []*awsMockEndpoint{ + { + Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, + Response: &awsMockResponse{200, iamResponse_GetUser_valid, "text/xml"}, + }, + }, + ExpectedAccountID: iamResponse_GetUser_valid_expectedAccountID, + ExpectedPartition: iamResponse_GetUser_valid_expectedPartition, }, } - closeSts, stsSess, err := getMockedAwsApiSession("STS", stsEndpoints) - defer closeSts() - if err != nil { - t.Fatal(err) - } - iamConn := iam.New(iamSess) - stsConn := sts.New(stsSess) + for _, testCase := range testCases { + closeIam, iamSess, err := getMockedAwsApiSession("IAM", testCase.MockEndpoints) + defer closeIam() + if err != nil { + t.Fatal(err) + } - id, err := GetAccountID(iamConn, stsConn, "") - if err != nil { - t.Fatalf("Getting account ID via ListRoles failed: %s", err) - } + iamConn := iam.New(iamSess) - expectedAccountId := "123456789012" - if id != expectedAccountId { - t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id) + accountID, partition, err := GetAccountInformationFromIAMGetUser(iamConn) + if err != nil && testCase.ErrCount == 0 { + t.Fatalf("Expected no error, received error: %s", err) + } + if err == nil && testCase.ErrCount > 0 { + t.Fatalf("Expected %d error(s), received none", testCase.ErrCount) + } + if accountID != testCase.ExpectedAccountID { + t.Fatalf("Parsed account ID doesn't match with expected (%q != %q)", accountID, testCase.ExpectedAccountID) + } + if partition != testCase.ExpectedPartition { + t.Fatalf("Parsed partition doesn't match with expected (%q != %q)", partition, testCase.ExpectedPartition) + } } } -func TestAWSGetAccountID_shouldBeValid_federatedRole(t *testing.T) { - iamEndpoints := []*awsMockEndpoint{ +func TestGetAccountInformationFromIAMListRoles(t *testing.T) { + var testCases = []struct { + MockEndpoints []*awsMockEndpoint + ErrCount int + ExpectedAccountID string + ExpectedPartition string + }{ { - Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, - Response: &awsMockResponse{400, iamResponse_GetUser_federatedFailure, "text/xml"}, + MockEndpoints: []*awsMockEndpoint{ + { + Request: &awsMockRequest{"POST", "/", "Action=ListRoles&MaxItems=1&Version=2010-05-08"}, + Response: &awsMockResponse{403, iamResponse_ListRoles_unauthorized, "text/xml"}, + }, + }, + ErrCount: 1, }, { - Request: &awsMockRequest{"POST", "/", "Action=ListRoles&MaxItems=1&Version=2010-05-08"}, - Response: &awsMockResponse{200, iamResponse_ListRoles_valid, "text/xml"}, + MockEndpoints: []*awsMockEndpoint{ + { + Request: &awsMockRequest{"POST", "/", "Action=ListRoles&MaxItems=1&Version=2010-05-08"}, + Response: &awsMockResponse{200, iamResponse_ListRoles_valid, "text/xml"}, + }, + }, + ExpectedAccountID: iamResponse_ListRoles_valid_expectedAccountID, + ExpectedPartition: iamResponse_ListRoles_valid_expectedPartition, }, } - closeIam, iamSess, err := getMockedAwsApiSession("IAM", iamEndpoints) - defer closeIam() - if err != nil { - t.Fatal(err) - } - closeSts, stsSess, err := getMockedAwsApiSession("STS", []*awsMockEndpoint{}) - defer closeSts() - if err != nil { - t.Fatal(err) - } - - iamConn := iam.New(iamSess) - stsConn := sts.New(stsSess) + for _, testCase := range testCases { + closeIam, iamSess, err := getMockedAwsApiSession("IAM", testCase.MockEndpoints) + defer closeIam() + if err != nil { + t.Fatal(err) + } - id, err := GetAccountID(iamConn, stsConn, "") - if err != nil { - t.Fatalf("Getting account ID via ListRoles failed: %s", err) - } + iamConn := iam.New(iamSess) - expectedAccountId := "123456789012" - if id != expectedAccountId { - t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id) + accountID, partition, err := GetAccountInformationFromIAMListRoles(iamConn) + if err != nil && testCase.ErrCount == 0 { + t.Fatalf("Expected no error, received error: %s", err) + } + if err == nil && testCase.ErrCount > 0 { + t.Fatalf("Expected %d error(s), received none", testCase.ErrCount) + } + if accountID != testCase.ExpectedAccountID { + t.Fatalf("Parsed account ID doesn't match with expected (%q != %q)", accountID, testCase.ExpectedAccountID) + } + if partition != testCase.ExpectedPartition { + t.Fatalf("Parsed partition doesn't match with expected (%q != %q)", partition, testCase.ExpectedPartition) + } } } -func TestAWSGetAccountID_shouldError_unauthorizedFromIam(t *testing.T) { - iamEndpoints := []*awsMockEndpoint{ +func TestGetAccountInformationFromSTSGetCallerIdentity(t *testing.T) { + var testCases = []struct { + MockEndpoints []*awsMockEndpoint + ErrCount int + ExpectedAccountID string + ExpectedPartition string + }{ { - Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"}, - Response: &awsMockResponse{403, iamResponse_GetUser_unauthorized, "text/xml"}, + MockEndpoints: []*awsMockEndpoint{ + { + Request: &awsMockRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"}, + Response: &awsMockResponse{403, stsResponse_GetCallerIdentity_unauthorized, "text/xml"}, + }, + }, + ErrCount: 1, }, { - Request: &awsMockRequest{"POST", "/", "Action=ListRoles&MaxItems=1&Version=2010-05-08"}, - Response: &awsMockResponse{403, iamResponse_ListRoles_unauthorized, "text/xml"}, + MockEndpoints: []*awsMockEndpoint{ + { + Request: &awsMockRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"}, + Response: &awsMockResponse{200, stsResponse_GetCallerIdentity_valid, "text/xml"}, + }, + }, + ExpectedAccountID: stsResponse_GetCallerIdentity_valid_expectedAccountID, + ExpectedPartition: stsResponse_GetCallerIdentity_valid_expectedPartition, }, } - closeIam, iamSess, err := getMockedAwsApiSession("IAM", iamEndpoints) - defer closeIam() - if err != nil { - t.Fatal(err) - } - closeSts, stsSess, err := getMockedAwsApiSession("STS", []*awsMockEndpoint{}) - defer closeSts() - if err != nil { - t.Fatal(err) - } - - iamConn := iam.New(iamSess) - stsConn := sts.New(stsSess) + for _, testCase := range testCases { + closeSts, stsSess, err := getMockedAwsApiSession("STS", testCase.MockEndpoints) + defer closeSts() + if err != nil { + t.Fatal(err) + } - id, err := GetAccountID(iamConn, stsConn, "") - if err == nil { - t.Fatal("Expected error when getting account ID") - } + stsConn := sts.New(stsSess) - if id != "" { - t.Fatalf("Expected no account ID, given: %s", id) + accountID, partition, err := GetAccountInformationFromSTSGetCallerIdentity(stsConn) + if err != nil && testCase.ErrCount == 0 { + t.Fatalf("Expected no error, received error: %s", err) + } + if err == nil && testCase.ErrCount > 0 { + t.Fatalf("Expected %d error(s), received none", testCase.ErrCount) + } + if accountID != testCase.ExpectedAccountID { + t.Fatalf("Parsed account ID doesn't match with expected (%q != %q)", accountID, testCase.ExpectedAccountID) + } + if partition != testCase.ExpectedPartition { + t.Fatalf("Parsed partition doesn't match with expected (%q != %q)", partition, testCase.ExpectedPartition) + } } } -func TestAWSParseAccountIDFromArn(t *testing.T) { - validArn := "arn:aws:iam::101636750127:instance-profile/aws-elasticbeanstalk-ec2-role" - expectedId := "101636750127" - id, err := parseAccountIDFromArn(validArn) - if err != nil { - t.Fatalf("Expected no error when parsing valid ARN: %s", err) - } - if id != expectedId { - t.Fatalf("Parsed id doesn't match with expected (%q != %q)", id, expectedId) +func TestAWSParseAccountInformationFromARN(t *testing.T) { + var testCases = []struct { + InputARN string + ErrCount int + ExpectedAccountID string + ExpectedPartition string + }{ + { + InputARN: "invalid-arn", + ErrCount: 1, + }, + { + InputARN: "arn:aws:iam::123456789012:instance-profile/name", + ExpectedAccountID: "123456789012", + ExpectedPartition: "aws", + }, + { + InputARN: "arn:aws:iam::123456789012:user/name", + ExpectedAccountID: "123456789012", + ExpectedPartition: "aws", + }, + { + InputARN: "arn:aws:sts::123456789012:assumed-role/name", + ExpectedAccountID: "123456789012", + ExpectedPartition: "aws", + }, + { + InputARN: "arn:aws-us-gov:sts::123456789012:assumed-role/name", + ExpectedAccountID: "123456789012", + ExpectedPartition: "aws-us-gov", + }, } - invalidArn := "blablah" - id, err = parseAccountIDFromArn(invalidArn) - if err == nil { - t.Fatalf("Expected error when parsing invalid ARN (%q)", invalidArn) + for _, testCase := range testCases { + accountID, partition, err := parseAccountInformationFromARN(testCase.InputARN) + if err != nil && testCase.ErrCount == 0 { + t.Fatalf("Expected no error when parsing ARN, received error: %s", err) + } + if err == nil && testCase.ErrCount > 0 { + t.Fatalf("Expected %d error(s) when parsing ARN, received none", testCase.ErrCount) + } + if accountID != testCase.ExpectedAccountID { + t.Fatalf("Parsed account ID doesn't match with expected (%q != %q)", accountID, testCase.ExpectedAccountID) + } + if partition != testCase.ExpectedPartition { + t.Fatalf("Parsed partition doesn't match with expected (%q != %q)", partition, testCase.ExpectedPartition) + } } } @@ -388,7 +477,7 @@ func TestAWSGetCredentials_shouldIAM(t *testing.T) { defer resetEnv() // capture the test server's close method, to call after the test returns - ts := awsMetadataApiMock(append(securityCredentialsEndpoints, instanceIdEndpoint, iamInfoEndpoint)) + ts := awsMetadataApiMock(append(ec2metadata_securityCredentialsEndpoints, ec2metadata_instanceIdEndpoint, ec2metadata_iamInfoEndpoint)) defer ts() // An empty config, no key supplied @@ -424,7 +513,7 @@ func TestAWSGetCredentials_shouldIgnoreIAM(t *testing.T) { resetEnv := unsetEnv(t) defer resetEnv() // capture the test server's close method, to call after the test returns - ts := awsMetadataApiMock(append(securityCredentialsEndpoints, instanceIdEndpoint, iamInfoEndpoint)) + ts := awsMetadataApiMock(append(ec2metadata_securityCredentialsEndpoints, ec2metadata_instanceIdEndpoint, ec2metadata_iamInfoEndpoint)) defer ts() simple := []struct { Key, Secret, Token string @@ -531,7 +620,7 @@ func TestAWSGetCredentials_shouldCatchEC2RoleProvider(t *testing.T) { resetEnv := unsetEnv(t) defer resetEnv() // capture the test server's close method, to call after the test returns - ts := awsMetadataApiMock(append(securityCredentialsEndpoints, instanceIdEndpoint, iamInfoEndpoint)) + ts := awsMetadataApiMock(append(ec2metadata_securityCredentialsEndpoints, ec2metadata_instanceIdEndpoint, ec2metadata_iamInfoEndpoint)) defer ts() creds, err := GetCredentials(&Config{}) @@ -638,22 +727,6 @@ func TestAWSGetCredentials_shouldBeENV(t *testing.T) { } } -func testGetAccountID(t *testing.T, iamSess, stsSess *session.Session, credProviderName string) { - - iamConn := iam.New(iamSess) - stsConn := sts.New(stsSess) - - id, err := GetAccountID(iamConn, stsConn, credProviderName) - if err != nil { - t.Fatalf("Getting account ID failed: %s", err) - } - - expectedAccountId := "123456789012" - if id != expectedAccountId { - t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id) - } -} - // unsetEnv unsets environment variables for testing a "clean slate" with no // credentials in the environment func unsetEnv(t *testing.T) func() { @@ -791,12 +864,12 @@ type endpoint struct { Body string `json:"body"` } -var instanceIdEndpoint = &endpoint{ +var ec2metadata_instanceIdEndpoint = &endpoint{ Uri: "/latest/meta-data/instance-id", Body: "mock-instance-id", } -var securityCredentialsEndpoints = []*endpoint{ +var ec2metadata_securityCredentialsEndpoints = []*endpoint{ &endpoint{ Uri: "/latest/meta-data/iam/security-credentials", Body: "test_role", @@ -807,18 +880,21 @@ var securityCredentialsEndpoints = []*endpoint{ }, } -var iamInfoEndpoint = &endpoint{ +var ec2metadata_iamInfoEndpoint = &endpoint{ Uri: "/latest/meta-data/iam/info", - Body: "{\"Code\": \"Success\",\"LastUpdated\": \"2016-03-17T12:27:32Z\",\"InstanceProfileArn\": \"arn:aws:iam::123456789013:instance-profile/my-instance-profile\",\"InstanceProfileId\": \"AIPAABCDEFGHIJKLMN123\"}", + Body: "{\"Code\": \"Success\",\"LastUpdated\": \"2016-03-17T12:27:32Z\",\"InstanceProfileArn\": \"arn:aws:iam::000000000000:instance-profile/my-instance-profile\",\"InstanceProfileId\": \"AIPAABCDEFGHIJKLMN123\"}", } +const ec2metadata_iamInfoEndpoint_expectedAccountID = `000000000000` +const ec2metadata_iamInfoEndpoint_expectedPartition = `aws` + const iamResponse_GetUser_valid = ` AIDACKCEVSQ6C2EXAMPLE /division_abc/subdivision_xyz/ Bob - arn:aws:iam::123456789012:user/division_abc/subdivision_xyz/Bob + arn:aws:iam::111111111111:user/division_abc/subdivision_xyz/Bob 2013-10-02T17:01:44Z 2014-10-10T14:37:51Z @@ -828,6 +904,9 @@ const iamResponse_GetUser_valid = ` Sender @@ -839,15 +918,18 @@ const iamResponse_GetUser_unauthorized = ` - arn:aws:iam::123456789012:user/Alice + arn:aws:iam::222222222222:user/Alice AKIAI44QH8DHBEXAMPLE - 123456789012 + 222222222222 01234567-89ab-cdef-0123-456789abcdef ` +const stsResponse_GetCallerIdentity_valid_expectedAccountID = `222222222222` +const stsResponse_GetCallerIdentity_valid_expectedPartition = `aws` + const stsResponse_GetCallerIdentity_unauthorized = ` Sender @@ -876,7 +958,7 @@ const iamResponse_ListRoles_valid = ` ` +const iamResponse_ListRoles_valid_expectedAccountID = `444444444444` +const iamResponse_ListRoles_valid_expectedPartition = `aws` + const iamResponse_ListRoles_unauthorized = ` Sender diff --git a/aws/config.go b/aws/config.go index 9e0e2fa738c..371a00aa015 100644 --- a/aws/config.go +++ b/aws/config.go @@ -418,26 +418,32 @@ func (c *Config) Client() (interface{}, error) { log.Println("[INFO] Initializing DeviceFarm SDK connection") client.devicefarmconn = devicefarm.New(awsDeviceFarmSess) - // These two services need to be set up early so we can check on AccountID + // Beyond verifying credentials (if enabled), we use the next set of logic + // to determine two pieces of information required for manually assembling + // resource ARNs when they are not available in the service API: + // * client.accountid + // * client.partition client.iamconn = iam.New(awsIamSess) client.stsconn = sts.New(awsStsSess) + if c.AssumeRoleARN != "" { + client.accountid, client.partition, _ = parseAccountInformationFromARN(c.AssumeRoleARN) + } + + // Validate credentials early and fail before we do any graph walking. if !c.SkipCredsValidation { - err = c.ValidateCredentials(client.stsconn) + var err error + client.accountid, client.partition, err = GetAccountInformationFromSTSGetCallerIdentity(client.stsconn) if err != nil { - return nil, err + return nil, fmt.Errorf("error validating provider credentials: %s", err) } } - // Infer AWS partition from configured region - if partition, ok := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), client.region); ok { - client.partition = partition.ID() - } - - if !c.SkipRequestingAccountId { - accountID, err := GetAccountID(client.iamconn, client.stsconn, cp.ProviderName) - if err == nil { - client.accountid = accountID + if client.accountid == "" && !c.SkipRequestingAccountId { + var err error + client.accountid, client.partition, err = GetAccountInformation(client.iamconn, client.stsconn, cp.ProviderName) + if err != nil { + return nil, fmt.Errorf("Failed getting account information via all available methods. Errors: %s", err) } } @@ -446,6 +452,13 @@ func (c *Config) Client() (interface{}, error) { return nil, authErr } + // Infer AWS partition from configured region if we still need it + if client.partition == "" { + if partition, ok := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), client.region); ok { + client.partition = partition.ID() + } + } + client.ec2conn = ec2.New(awsEc2Sess) if !c.SkipGetEC2Platforms { @@ -612,12 +625,6 @@ func (c *Config) ValidateRegion() error { return fmt.Errorf("Not a valid region: %s", c.Region) } -// Validate credentials early and fail before we do any graph walking. -func (c *Config) ValidateCredentials(stsconn *sts.STS) error { - _, err := stsconn.GetCallerIdentity(&sts.GetCallerIdentityInput{}) - return err -} - // ValidateAccountId returns a context-specific error if the configured account // id is explicitly forbidden or not authorised; and nil if it is authorised. func (c *Config) ValidateAccountId(accountId string) error {