Skip to content

Commit

Permalink
Merge pull request #3173 from terraform-providers/f-provider-infer-pa…
Browse files Browse the repository at this point in the history
…rtition

provider: Infer AWS partition from configured region
  • Loading branch information
bflad authored Feb 5, 2018
2 parents 2c89cac + c09b938 commit 8427cf8
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 77 deletions.
30 changes: 15 additions & 15 deletions aws/auth_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ import (
"fmt"
"log"
"os"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/arn"
"github.com/aws/aws-sdk-go/aws/awserr"
awsCredentials "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
Expand All @@ -23,7 +23,7 @@ import (
"github.com/hashicorp/go-multierror"
)

func GetAccountInfo(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, string, error) {
func GetAccountID(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, error) {
var errors error
// If we have creds from instance profile, we can use metadata API
if authProviderName == ec2rolecreds.ProviderName {
Expand All @@ -33,13 +33,13 @@ func GetAccountInfo(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string)
setOptionalEndpoint(cfg)
sess, err := session.NewSession(cfg)
if err != nil {
return "", "", errwrap.Wrapf("Error creating AWS session: {{err}}", err)
return "", errwrap.Wrapf("Error creating AWS session: {{err}}", err)
}

metadataClient := ec2metadata.New(sess)
info, err := metadataClient.IAMInfo()
if err == nil {
return parseAccountInfoFromArn(info.InstanceProfileArn)
return parseAccountIDFromArn(info.InstanceProfileArn)
}
log.Printf("[DEBUG] Failed to get account info from metadata service: %s", err)
errors = multierror.Append(errors, err)
Expand All @@ -55,14 +55,14 @@ func GetAccountInfo(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string)
log.Println("[DEBUG] Trying to get account ID via iam:GetUser")
outUser, err := iamconn.GetUser(nil)
if err == nil {
return parseAccountInfoFromArn(*outUser.User.Arn)
return parseAccountIDFromArn(*outUser.User.Arn)
}
errors = multierror.Append(errors, err)
awsErr, ok := err.(awserr.Error)
// AccessDenied and ValidationError can be raised
// if credentials belong to federated profile, so we ignore these
if !ok || (awsErr.Code() != "AccessDenied" && awsErr.Code() != "ValidationError" && awsErr.Code() != "InvalidClientTokenId") {
return "", "", fmt.Errorf("Failed getting account ID via 'iam:GetUser': %s", err)
return "", fmt.Errorf("Failed getting account ID via 'iam:GetUser': %s", err)
}
log.Printf("[DEBUG] Getting account ID via iam:GetUser failed: %s", err)
}
Expand All @@ -71,7 +71,7 @@ func GetAccountInfo(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string)
log.Println("[DEBUG] Trying to get account ID via sts:GetCallerIdentity")
outCallerIdentity, err := stsconn.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err == nil {
return parseAccountInfoFromArn(*outCallerIdentity.Arn)
return parseAccountIDFromArn(*outCallerIdentity.Arn)
}
log.Printf("[DEBUG] Getting account ID via sts:GetCallerIdentity failed: %s", err)
errors = multierror.Append(errors, err)
Expand All @@ -84,25 +84,25 @@ func GetAccountInfo(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string)
if err != nil {
log.Printf("[DEBUG] Failed to get account ID via iam:ListRoles: %s", err)
errors = multierror.Append(errors, err)
return "", "", fmt.Errorf("Failed getting account ID via all available methods. Errors: %s", errors)
return "", fmt.Errorf("Failed getting account ID via all available methods. Errors: %s", errors)
}

if len(outRoles.Roles) < 1 {
err = fmt.Errorf("Failed to get account ID via iam:ListRoles: No roles available")
log.Printf("[DEBUG] %s", err)
errors = multierror.Append(errors, err)
return "", "", fmt.Errorf("Failed getting account ID via all available methods. Errors: %s", errors)
return "", fmt.Errorf("Failed getting account ID via all available methods. Errors: %s", errors)
}

return parseAccountInfoFromArn(*outRoles.Roles[0].Arn)
return parseAccountIDFromArn(*outRoles.Roles[0].Arn)
}

func parseAccountInfoFromArn(arn string) (string, string, error) {
parts := strings.Split(arn, ":")
if len(parts) < 5 {
return "", "", fmt.Errorf("Unable to parse ID from invalid ARN: %q", arn)
func parseAccountIDFromArn(inputARN string) (string, error) {
arn, err := arn.Parse(inputARN)
if err != nil {
return "", fmt.Errorf("Unable to parse ID from invalid ARN: %q", arn)
}
return parts[1], parts[4], nil
return arn.AccountID, nil
}

// This function is responsible for reading credentials from the
Expand Down
80 changes: 21 additions & 59 deletions aws/auth_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
"github.com/aws/aws-sdk-go/service/sts"
)

func TestAWSGetAccountInfo_shouldBeValid_fromEC2Role(t *testing.T) {
func TestAWSGetAccountID_shouldBeValid_fromEC2Role(t *testing.T) {
resetEnv := unsetEnv(t)
defer resetEnv()
// capture the test server's close method, to call after the test returns
Expand All @@ -33,23 +33,18 @@ func TestAWSGetAccountInfo_shouldBeValid_fromEC2Role(t *testing.T) {
iamConn := iam.New(emptySess)
stsConn := sts.New(emptySess)

part, id, err := GetAccountInfo(iamConn, stsConn, ec2rolecreds.ProviderName)
id, err := GetAccountID(iamConn, stsConn, ec2rolecreds.ProviderName)
if err != nil {
t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err)
}

expectedPart := "aws"
if part != expectedPart {
t.Fatalf("Expected partition: %s, given: %s", expectedPart, part)
}

expectedAccountId := "123456789013"
if id != expectedAccountId {
t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id)
}
}

func TestAWSGetAccountInfo_shouldBeValid_EC2RoleHasPriority(t *testing.T) {
func TestAWSGetAccountID_shouldBeValid_EC2RoleHasPriority(t *testing.T) {
resetEnv := unsetEnv(t)
defer resetEnv()
// capture the test server's close method, to call after the test returns
Expand All @@ -75,23 +70,18 @@ func TestAWSGetAccountInfo_shouldBeValid_EC2RoleHasPriority(t *testing.T) {
}
stsConn := sts.New(stsSess)

part, id, err := GetAccountInfo(iamConn, stsConn, ec2rolecreds.ProviderName)
id, err := GetAccountID(iamConn, stsConn, ec2rolecreds.ProviderName)
if err != nil {
t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err)
}

expectedPart := "aws"
if part != expectedPart {
t.Fatalf("Expected partition: %s, given: %s", expectedPart, part)
}

expectedAccountId := "123456789013"
if id != expectedAccountId {
t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id)
}
}

func TestAWSGetAccountInfo_shouldBeValid_fromIamUser(t *testing.T) {
func TestAWSGetAccountID_shouldBeValid_fromIamUser(t *testing.T) {
iamEndpoints := []*awsMockEndpoint{
{
Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Expand All @@ -113,23 +103,18 @@ func TestAWSGetAccountInfo_shouldBeValid_fromIamUser(t *testing.T) {
iamConn := iam.New(iamSess)
stsConn := sts.New(stsSess)

part, id, err := GetAccountInfo(iamConn, stsConn, "")
id, err := GetAccountID(iamConn, stsConn, "")
if err != nil {
t.Fatalf("Getting account ID via GetUser failed: %s", err)
}

expectedPart := "aws"
if part != expectedPart {
t.Fatalf("Expected partition: %s, given: %s", expectedPart, part)
}

expectedAccountId := "123456789012"
if id != expectedAccountId {
t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id)
}
}

func TestAWSGetAccountInfo_shouldBeValid_fromGetCallerIdentity(t *testing.T) {
func TestAWSGetAccountID_shouldBeValid_fromGetCallerIdentity(t *testing.T) {
iamEndpoints := []*awsMockEndpoint{
{
Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Expand All @@ -154,10 +139,10 @@ func TestAWSGetAccountInfo_shouldBeValid_fromGetCallerIdentity(t *testing.T) {
t.Fatal(err)
}

testGetAccountInfo(t, iamSess, stsSess, credentials.SharedCredsProviderName)
testGetAccountID(t, iamSess, stsSess, credentials.SharedCredsProviderName)
}

func TestAWSGetAccountInfo_shouldBeValid_EC2RoleFallsBackToCallerIdentity(t *testing.T) {
func TestAWSGetAccountID_shouldBeValid_EC2RoleFallsBackToCallerIdentity(t *testing.T) {
// This mimics the metadata service mocked by Hologram (https://github.com/AdRoll/hologram)
resetEnv := unsetEnv(t)
defer resetEnv()
Expand Down Expand Up @@ -189,10 +174,10 @@ func TestAWSGetAccountInfo_shouldBeValid_EC2RoleFallsBackToCallerIdentity(t *tes
t.Fatal(err)
}

testGetAccountInfo(t, iamSess, stsSess, ec2rolecreds.ProviderName)
testGetAccountID(t, iamSess, stsSess, ec2rolecreds.ProviderName)
}

func TestAWSGetAccountInfo_shouldBeValid_fromIamListRoles(t *testing.T) {
func TestAWSGetAccountID_shouldBeValid_fromIamListRoles(t *testing.T) {
iamEndpoints := []*awsMockEndpoint{
{
Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Expand Down Expand Up @@ -224,23 +209,18 @@ func TestAWSGetAccountInfo_shouldBeValid_fromIamListRoles(t *testing.T) {
iamConn := iam.New(iamSess)
stsConn := sts.New(stsSess)

part, id, err := GetAccountInfo(iamConn, stsConn, "")
id, err := GetAccountID(iamConn, stsConn, "")
if err != nil {
t.Fatalf("Getting account ID via ListRoles failed: %s", err)
}

expectedPart := "aws"
if part != expectedPart {
t.Fatalf("Expected partition: %s, given: %s", expectedPart, part)
}

expectedAccountId := "123456789012"
if id != expectedAccountId {
t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id)
}
}

func TestAWSGetAccountInfo_shouldBeValid_federatedRole(t *testing.T) {
func TestAWSGetAccountID_shouldBeValid_federatedRole(t *testing.T) {
iamEndpoints := []*awsMockEndpoint{
{
Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Expand All @@ -266,23 +246,18 @@ func TestAWSGetAccountInfo_shouldBeValid_federatedRole(t *testing.T) {
iamConn := iam.New(iamSess)
stsConn := sts.New(stsSess)

part, id, err := GetAccountInfo(iamConn, stsConn, "")
id, err := GetAccountID(iamConn, stsConn, "")
if err != nil {
t.Fatalf("Getting account ID via ListRoles failed: %s", err)
}

expectedPart := "aws"
if part != expectedPart {
t.Fatalf("Expected partition: %s, given: %s", expectedPart, part)
}

expectedAccountId := "123456789012"
if id != expectedAccountId {
t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id)
}
}

func TestAWSGetAccountInfo_shouldError_unauthorizedFromIam(t *testing.T) {
func TestAWSGetAccountID_shouldError_unauthorizedFromIam(t *testing.T) {
iamEndpoints := []*awsMockEndpoint{
{
Request: &awsMockRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Expand All @@ -308,37 +283,29 @@ func TestAWSGetAccountInfo_shouldError_unauthorizedFromIam(t *testing.T) {
iamConn := iam.New(iamSess)
stsConn := sts.New(stsSess)

part, id, err := GetAccountInfo(iamConn, stsConn, "")
id, err := GetAccountID(iamConn, stsConn, "")
if err == nil {
t.Fatal("Expected error when getting account ID")
}

if part != "" {
t.Fatalf("Expected no partition, given: %s", part)
}

if id != "" {
t.Fatalf("Expected no account ID, given: %s", id)
}
}

func TestAWSParseAccountInfoFromArn(t *testing.T) {
func TestAWSParseAccountIDFromArn(t *testing.T) {
validArn := "arn:aws:iam::101636750127:instance-profile/aws-elasticbeanstalk-ec2-role"
expectedPart := "aws"
expectedId := "101636750127"
part, id, err := parseAccountInfoFromArn(validArn)
id, err := parseAccountIDFromArn(validArn)
if err != nil {
t.Fatalf("Expected no error when parsing valid ARN: %s", err)
}
if part != expectedPart {
t.Fatalf("Parsed part doesn't match with expected (%q != %q)", part, expectedPart)
}
if id != expectedId {
t.Fatalf("Parsed id doesn't match with expected (%q != %q)", id, expectedId)
}

invalidArn := "blablah"
part, id, err = parseAccountInfoFromArn(invalidArn)
id, err = parseAccountIDFromArn(invalidArn)
if err == nil {
t.Fatalf("Expected error when parsing invalid ARN (%q)", invalidArn)
}
Expand Down Expand Up @@ -671,21 +638,16 @@ func TestAWSGetCredentials_shouldBeENV(t *testing.T) {
}
}

func testGetAccountInfo(t *testing.T, iamSess, stsSess *session.Session, credProviderName string) {
func testGetAccountID(t *testing.T, iamSess, stsSess *session.Session, credProviderName string) {

iamConn := iam.New(iamSess)
stsConn := sts.New(stsSess)

part, id, err := GetAccountInfo(iamConn, stsConn, credProviderName)
id, err := GetAccountID(iamConn, stsConn, credProviderName)
if err != nil {
t.Fatalf("Getting account ID failed: %s", err)
}

expectedPart := "aws"
if part != expectedPart {
t.Fatalf("Expected partition: %s, given: %s", expectedPart, part)
}

expectedAccountId := "123456789012"
if id != expectedAccountId {
t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id)
Expand Down
10 changes: 7 additions & 3 deletions aws/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,11 +364,15 @@ func (c *Config) Client() (interface{}, error) {
}
}

// Infer AWS partition from configured region
if partition, ok := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), client.region); ok {
client.partition = partition.ID()
}

if !c.SkipRequestingAccountId {
partition, accountId, err := GetAccountInfo(client.iamconn, client.stsconn, cp.ProviderName)
accountID, err := GetAccountID(client.iamconn, client.stsconn, cp.ProviderName)
if err == nil {
client.partition = partition
client.accountid = accountId
client.accountid = accountID
}
}

Expand Down

0 comments on commit 8427cf8

Please sign in to comment.