Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

provider: Infer AWS partition from configured region #3173

Merged
merged 1 commit into from
Feb 5, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions aws/auth_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ import (
"fmt"
"log"
"os"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/arn"
"github.com/aws/aws-sdk-go/aws/awserr"
awsCredentials "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
Expand All @@ -23,7 +23,7 @@ import (
"github.com/hashicorp/go-multierror"
)

func GetAccountInfo(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, string, error) {
func GetAccountID(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, error) {
var errors error
// If we have creds from instance profile, we can use metadata API
if authProviderName == ec2rolecreds.ProviderName {
Expand All @@ -33,13 +33,13 @@ func GetAccountInfo(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string)
setOptionalEndpoint(cfg)
sess, err := session.NewSession(cfg)
if err != nil {
return "", "", errwrap.Wrapf("Error creating AWS session: {{err}}", err)
return "", errwrap.Wrapf("Error creating AWS session: {{err}}", err)
}

metadataClient := ec2metadata.New(sess)
info, err := metadataClient.IAMInfo()
if err == nil {
return parseAccountInfoFromArn(info.InstanceProfileArn)
return parseAccountIDFromArn(info.InstanceProfileArn)
}
log.Printf("[DEBUG] Failed to get account info from metadata service: %s", err)
errors = multierror.Append(errors, err)
Expand All @@ -55,14 +55,14 @@ func GetAccountInfo(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string)
log.Println("[DEBUG] Trying to get account ID via iam:GetUser")
outUser, err := iamconn.GetUser(nil)
if err == nil {
return parseAccountInfoFromArn(*outUser.User.Arn)
return parseAccountIDFromArn(*outUser.User.Arn)
}
errors = multierror.Append(errors, err)
awsErr, ok := err.(awserr.Error)
// AccessDenied and ValidationError can be raised
// if credentials belong to federated profile, so we ignore these
if !ok || (awsErr.Code() != "AccessDenied" && awsErr.Code() != "ValidationError" && awsErr.Code() != "InvalidClientTokenId") {
return "", "", fmt.Errorf("Failed getting account ID via 'iam:GetUser': %s", err)
return "", fmt.Errorf("Failed getting account ID via 'iam:GetUser': %s", err)
}
log.Printf("[DEBUG] Getting account ID via iam:GetUser failed: %s", err)
}
Expand All @@ -71,7 +71,7 @@ func GetAccountInfo(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string)
log.Println("[DEBUG] Trying to get account ID via sts:GetCallerIdentity")
outCallerIdentity, err := stsconn.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err == nil {
return parseAccountInfoFromArn(*outCallerIdentity.Arn)
return parseAccountIDFromArn(*outCallerIdentity.Arn)
}
log.Printf("[DEBUG] Getting account ID via sts:GetCallerIdentity failed: %s", err)
errors = multierror.Append(errors, err)
Expand All @@ -84,25 +84,25 @@ func GetAccountInfo(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string)
if err != nil {
log.Printf("[DEBUG] Failed to get account ID via iam:ListRoles: %s", err)
errors = multierror.Append(errors, err)
return "", "", fmt.Errorf("Failed getting account ID via all available methods. Errors: %s", errors)
return "", fmt.Errorf("Failed getting account ID via all available methods. Errors: %s", errors)
}

if len(outRoles.Roles) < 1 {
err = fmt.Errorf("Failed to get account ID via iam:ListRoles: No roles available")
log.Printf("[DEBUG] %s", err)
errors = multierror.Append(errors, err)
return "", "", fmt.Errorf("Failed getting account ID via all available methods. Errors: %s", errors)
return "", fmt.Errorf("Failed getting account ID via all available methods. Errors: %s", errors)
}

return parseAccountInfoFromArn(*outRoles.Roles[0].Arn)
return parseAccountIDFromArn(*outRoles.Roles[0].Arn)
}

func parseAccountInfoFromArn(arn string) (string, string, error) {
parts := strings.Split(arn, ":")
if len(parts) < 5 {
return "", "", fmt.Errorf("Unable to parse ID from invalid ARN: %q", arn)
func parseAccountIDFromArn(inputARN string) (string, error) {
arn, err := arn.Parse(inputARN)
if err != nil {
return "", fmt.Errorf("Unable to parse ID from invalid ARN: %q", arn)
}
return parts[1], parts[4], nil
return arn.AccountID, nil
}

// This function is responsible for reading credentials from the
Expand Down
80 changes: 21 additions & 59 deletions aws/auth_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
"github.com/aws/aws-sdk-go/service/sts"
)

func TestAWSGetAccountInfo_shouldBeValid_fromEC2Role(t *testing.T) {
func TestAWSGetAccountID_shouldBeValid_fromEC2Role(t *testing.T) {
resetEnv := unsetEnv(t)
defer resetEnv()
// capture the test server's close method, to call after the test returns
Expand All @@ -33,23 +33,18 @@ func TestAWSGetAccountInfo_shouldBeValid_fromEC2Role(t *testing.T) {
iamConn := iam.New(emptySess)
stsConn := sts.New(emptySess)

part, id, err := GetAccountInfo(iamConn, stsConn, ec2rolecreds.ProviderName)
id, err := GetAccountID(iamConn, stsConn, ec2rolecreds.ProviderName)
if err != nil {
t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err)
}

expectedPart := "aws"
if part != expectedPart {
t.Fatalf("Expected partition: %s, given: %s", expectedPart, part)
}

expectedAccountId := "123456789013"
if id != expectedAccountId {
t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id)
}
}

func TestAWSGetAccountInfo_shouldBeValid_EC2RoleHasPriority(t *testing.T) {
func TestAWSGetAccountID_shouldBeValid_EC2RoleHasPriority(t *testing.T) {
resetEnv := unsetEnv(t)
defer resetEnv()
// capture the test server's close method, to call after the test returns
Expand All @@ -75,23 +70,18 @@ func TestAWSGetAccountInfo_shouldBeValid_EC2RoleHasPriority(t *testing.T) {
}
stsConn := sts.New(stsSess)

part, id, err := GetAccountInfo(iamConn, stsConn, ec2rolecreds.ProviderName)
id, err := GetAccountID(iamConn, stsConn, ec2rolecreds.ProviderName)
if err != nil {
t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err)
}

expectedPart := "aws"
if part != expectedPart {
t.Fatalf("Expected partition: %s, given: %s", expectedPart, part)
}

expectedAccountId := "123456789013"
if id != expectedAccountId {
t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id)
}
}

func TestAWSGetAccountInfo_shouldBeValid_fromIamUser(t *testing.T) {
func TestAWSGetAccountID_shouldBeValid_fromIamUser(t *testing.T) {
iamEndpoints := []*awsMockEndpoint{
{
Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Expand All @@ -113,23 +103,18 @@ func TestAWSGetAccountInfo_shouldBeValid_fromIamUser(t *testing.T) {
iamConn := iam.New(iamSess)
stsConn := sts.New(stsSess)

part, id, err := GetAccountInfo(iamConn, stsConn, "")
id, err := GetAccountID(iamConn, stsConn, "")
if err != nil {
t.Fatalf("Getting account ID via GetUser failed: %s", err)
}

expectedPart := "aws"
if part != expectedPart {
t.Fatalf("Expected partition: %s, given: %s", expectedPart, part)
}

expectedAccountId := "123456789012"
if id != expectedAccountId {
t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id)
}
}

func TestAWSGetAccountInfo_shouldBeValid_fromGetCallerIdentity(t *testing.T) {
func TestAWSGetAccountID_shouldBeValid_fromGetCallerIdentity(t *testing.T) {
iamEndpoints := []*awsMockEndpoint{
{
Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Expand All @@ -154,10 +139,10 @@ func TestAWSGetAccountInfo_shouldBeValid_fromGetCallerIdentity(t *testing.T) {
t.Fatal(err)
}

testGetAccountInfo(t, iamSess, stsSess, credentials.SharedCredsProviderName)
testGetAccountID(t, iamSess, stsSess, credentials.SharedCredsProviderName)
}

func TestAWSGetAccountInfo_shouldBeValid_EC2RoleFallsBackToCallerIdentity(t *testing.T) {
func TestAWSGetAccountID_shouldBeValid_EC2RoleFallsBackToCallerIdentity(t *testing.T) {
// This mimics the metadata service mocked by Hologram (https://github.com/AdRoll/hologram)
resetEnv := unsetEnv(t)
defer resetEnv()
Expand Down Expand Up @@ -189,10 +174,10 @@ func TestAWSGetAccountInfo_shouldBeValid_EC2RoleFallsBackToCallerIdentity(t *tes
t.Fatal(err)
}

testGetAccountInfo(t, iamSess, stsSess, ec2rolecreds.ProviderName)
testGetAccountID(t, iamSess, stsSess, ec2rolecreds.ProviderName)
}

func TestAWSGetAccountInfo_shouldBeValid_fromIamListRoles(t *testing.T) {
func TestAWSGetAccountID_shouldBeValid_fromIamListRoles(t *testing.T) {
iamEndpoints := []*awsMockEndpoint{
{
Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Expand Down Expand Up @@ -224,23 +209,18 @@ func TestAWSGetAccountInfo_shouldBeValid_fromIamListRoles(t *testing.T) {
iamConn := iam.New(iamSess)
stsConn := sts.New(stsSess)

part, id, err := GetAccountInfo(iamConn, stsConn, "")
id, err := GetAccountID(iamConn, stsConn, "")
if err != nil {
t.Fatalf("Getting account ID via ListRoles failed: %s", err)
}

expectedPart := "aws"
if part != expectedPart {
t.Fatalf("Expected partition: %s, given: %s", expectedPart, part)
}

expectedAccountId := "123456789012"
if id != expectedAccountId {
t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id)
}
}

func TestAWSGetAccountInfo_shouldBeValid_federatedRole(t *testing.T) {
func TestAWSGetAccountID_shouldBeValid_federatedRole(t *testing.T) {
iamEndpoints := []*awsMockEndpoint{
{
Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Expand All @@ -266,23 +246,18 @@ func TestAWSGetAccountInfo_shouldBeValid_federatedRole(t *testing.T) {
iamConn := iam.New(iamSess)
stsConn := sts.New(stsSess)

part, id, err := GetAccountInfo(iamConn, stsConn, "")
id, err := GetAccountID(iamConn, stsConn, "")
if err != nil {
t.Fatalf("Getting account ID via ListRoles failed: %s", err)
}

expectedPart := "aws"
if part != expectedPart {
t.Fatalf("Expected partition: %s, given: %s", expectedPart, part)
}

expectedAccountId := "123456789012"
if id != expectedAccountId {
t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id)
}
}

func TestAWSGetAccountInfo_shouldError_unauthorizedFromIam(t *testing.T) {
func TestAWSGetAccountID_shouldError_unauthorizedFromIam(t *testing.T) {
iamEndpoints := []*awsMockEndpoint{
{
Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Expand All @@ -308,37 +283,29 @@ func TestAWSGetAccountInfo_shouldError_unauthorizedFromIam(t *testing.T) {
iamConn := iam.New(iamSess)
stsConn := sts.New(stsSess)

part, id, err := GetAccountInfo(iamConn, stsConn, "")
id, err := GetAccountID(iamConn, stsConn, "")
if err == nil {
t.Fatal("Expected error when getting account ID")
}

if part != "" {
t.Fatalf("Expected no partition, given: %s", part)
}

if id != "" {
t.Fatalf("Expected no account ID, given: %s", id)
}
}

func TestAWSParseAccountInfoFromArn(t *testing.T) {
func TestAWSParseAccountIDFromArn(t *testing.T) {
validArn := "arn:aws:iam::101636750127:instance-profile/aws-elasticbeanstalk-ec2-role"
expectedPart := "aws"
expectedId := "101636750127"
part, id, err := parseAccountInfoFromArn(validArn)
id, err := parseAccountIDFromArn(validArn)
if err != nil {
t.Fatalf("Expected no error when parsing valid ARN: %s", err)
}
if part != expectedPart {
t.Fatalf("Parsed part doesn't match with expected (%q != %q)", part, expectedPart)
}
if id != expectedId {
t.Fatalf("Parsed id doesn't match with expected (%q != %q)", id, expectedId)
}

invalidArn := "blablah"
part, id, err = parseAccountInfoFromArn(invalidArn)
id, err = parseAccountIDFromArn(invalidArn)
if err == nil {
t.Fatalf("Expected error when parsing invalid ARN (%q)", invalidArn)
}
Expand Down Expand Up @@ -671,21 +638,16 @@ func TestAWSGetCredentials_shouldBeENV(t *testing.T) {
}
}

func testGetAccountInfo(t *testing.T, iamSess, stsSess *session.Session, credProviderName string) {
func testGetAccountID(t *testing.T, iamSess, stsSess *session.Session, credProviderName string) {

iamConn := iam.New(iamSess)
stsConn := sts.New(stsSess)

part, id, err := GetAccountInfo(iamConn, stsConn, credProviderName)
id, err := GetAccountID(iamConn, stsConn, credProviderName)
if err != nil {
t.Fatalf("Getting account ID failed: %s", err)
}

expectedPart := "aws"
if part != expectedPart {
t.Fatalf("Expected partition: %s, given: %s", expectedPart, part)
}

expectedAccountId := "123456789012"
if id != expectedAccountId {
t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id)
Expand Down
11 changes: 8 additions & 3 deletions aws/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/acm"
Expand Down Expand Up @@ -370,11 +371,15 @@ func (c *Config) Client() (interface{}, error) {
}
}

// Infer AWS partition from configured region
if partition, ok := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), client.region); ok {
client.partition = partition.ID()
}

if !c.SkipRequestingAccountId {
partition, accountId, err := GetAccountInfo(client.iamconn, client.stsconn, cp.ProviderName)
accountID, err := GetAccountID(client.iamconn, client.stsconn, cp.ProviderName)
if err == nil {
client.partition = partition
client.accountid = accountId
client.accountid = accountID
}
}

Expand Down