Skip to content

Commit

Permalink
provider: Allow AssumeRoleARN and credential validation call to short…
Browse files Browse the repository at this point in the history
…cut account ID and partition lookup

* If provider configuration is set to assume role, use ARN to bypass account information calls
* If provider credential validation is enabled, use results to bypass account information calls
* Refactor GetAccountID back to GetAccountInformation to ensure non-SDK-default partitions can be handled
* Refactor GetAccountInformation testing
  • Loading branch information
bflad committed Jul 13, 2018
1 parent 90538bc commit cdfa7ea
Show file tree
Hide file tree
Showing 3 changed files with 470 additions and 333 deletions.
157 changes: 101 additions & 56 deletions aws/auth_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,92 +18,137 @@ import (
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-multierror"
)

func GetAccountID(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, error) {
var errors error
func GetAccountInformation(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, string, error) {
var accountID, partition string
var err, errors error

// First, try STS GetCallerIdentity
log.Println("[DEBUG] Trying to get account ID via sts:GetCallerIdentity")
outCallerIdentity, err := stsconn.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err == nil {
return parseAccountIDFromArn(*outCallerIdentity.Arn)
if authProviderName == ec2rolecreds.ProviderName {
accountID, partition, err = GetAccountInformationFromEC2Metadata()
} else {
accountID, partition, err = GetAccountInformationFromIAMGetUser(iamconn)
}
if accountID != "" {
return accountID, partition, nil
}
log.Printf("[DEBUG] Getting account ID via sts:GetCallerIdentity failed: %s", err)
errors = multierror.Append(errors, err)

// If we have creds from instance profile, we can use metadata API
if authProviderName == ec2rolecreds.ProviderName {
log.Println("[DEBUG] Trying to get account ID via AWS Metadata API")
accountID, partition, err = GetAccountInformationFromSTSGetCallerIdentity(stsconn)
if accountID != "" {
return accountID, partition, nil
}
errors = multierror.Append(errors, err)

cfg := &aws.Config{}
setOptionalEndpoint(cfg)
sess, err := session.NewSession(cfg)
if err != nil {
return "", errwrap.Wrapf("Error creating AWS session: {{err}}", err)
}
accountID, partition, err = GetAccountInformationFromIAMListRoles(iamconn)
if accountID != "" {
return accountID, partition, nil
}
errors = multierror.Append(errors, err)

metadataClient := ec2metadata.New(sess)
info, err := metadataClient.IAMInfo()
if err == nil {
return parseAccountIDFromArn(info.InstanceProfileArn)
}
log.Printf("[DEBUG] Failed to get account info from metadata service: %s", err)
errors = multierror.Append(errors, err)
return accountID, partition, errors
}

func GetAccountInformationFromEC2Metadata() (string, string, error) {
log.Println("[DEBUG] Trying to get account information via EC2 Metadata")

cfg := &aws.Config{}
setOptionalEndpoint(cfg)
sess, err := session.NewSession(cfg)
if err != nil {
return "", "", fmt.Errorf("error creating EC2 Metadata session: %s", err)
}

metadataClient := ec2metadata.New(sess)
info, err := metadataClient.IAMInfo()
if err != nil {
// We can end up here if there's an issue with the instance metadata service
// or if we're getting credentials from AdRoll's Hologram (in which case IAMInfo will
// error out). In any event, if we can't get account info here, we should try
// the other methods available.
// If we have creds from something that looks like an IAM instance profile, but
// we were unable to retrieve account info from the instance profile, it's probably
// a safe assumption that we're not an IAM user
} else {
// Creds aren't from an IAM instance profile, so try try iam:GetUser
log.Println("[DEBUG] Trying to get account ID via iam:GetUser")
outUser, err := iamconn.GetUser(nil)
if err == nil {
return parseAccountIDFromArn(*outUser.User.Arn)
}
errors = multierror.Append(errors, err)
awsErr, ok := err.(awserr.Error)
// error out).
err = fmt.Errorf("failed getting account information via EC2 Metadata IAM information: %s", err)
log.Printf("[DEBUG] %s", err)
return "", "", err
}

return parseAccountInformationFromARN(info.InstanceProfileArn)
}

func GetAccountInformationFromIAMGetUser(iamconn *iam.IAM) (string, string, error) {
log.Println("[DEBUG] Trying to get account information via iam:GetUser")

output, err := iamconn.GetUser(&iam.GetUserInput{})
if err != nil {
// AccessDenied and ValidationError can be raised
// if credentials belong to federated profile, so we ignore these
if !ok || (awsErr.Code() != "AccessDenied" && awsErr.Code() != "ValidationError" && awsErr.Code() != "InvalidClientTokenId") {
return "", fmt.Errorf("Failed getting account ID via 'iam:GetUser': %s", err)
if isAWSErr(err, "AccessDenied", "") {
return "", "", nil
}
if isAWSErr(err, "InvalidClientTokenId", "") {
return "", "", nil
}
if isAWSErr(err, "ValidationError", "") {
return "", "", nil
}
log.Printf("[DEBUG] Getting account ID via iam:GetUser failed: %s", err)
err = fmt.Errorf("failed getting account information via iam:GetUser: %s", err)
log.Printf("[DEBUG] %s", err)
return "", "", err
}

if output == nil || output.User == nil {
err = errors.New("empty iam:GetUser response")
log.Printf("[DEBUG] %s", err)
return "", "", err
}

// Then try IAM ListRoles
log.Println("[DEBUG] Trying to get account ID via iam:ListRoles")
outRoles, err := iamconn.ListRoles(&iam.ListRolesInput{
return parseAccountInformationFromARN(aws.StringValue(output.User.Arn))
}

func GetAccountInformationFromIAMListRoles(iamconn *iam.IAM) (string, string, error) {
log.Println("[DEBUG] Trying to get account information via iam:ListRoles")

output, err := iamconn.ListRoles(&iam.ListRolesInput{
MaxItems: aws.Int64(int64(1)),
})
if err != nil {
log.Printf("[DEBUG] Failed to get account ID via iam:ListRoles: %s", err)
errors = multierror.Append(errors, err)
return "", fmt.Errorf("Failed getting account ID via all available methods. Errors: %s", errors)
err = fmt.Errorf("failed getting account information via iam:ListRoles: %s", err)
log.Printf("[DEBUG] %s", err)
return "", "", err
}

if output == nil || len(output.Roles) < 1 {
err = fmt.Errorf("empty iam:ListRoles response")
log.Printf("[DEBUG] %s", err)
return "", "", err
}

return parseAccountInformationFromARN(aws.StringValue(output.Roles[0].Arn))
}

func GetAccountInformationFromSTSGetCallerIdentity(stsconn *sts.STS) (string, string, error) {
log.Println("[DEBUG] Trying to get account information via sts:GetCallerIdentity")

output, err := stsconn.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err != nil {
return "", "", fmt.Errorf("error calling sts:GetCallerIdentity: %s", err)
}

if len(outRoles.Roles) < 1 {
err = fmt.Errorf("Failed to get account ID via iam:ListRoles: No roles available")
if output == nil || output.Arn == nil {
err = errors.New("empty sts:GetCallerIdentity response")
log.Printf("[DEBUG] %s", err)
errors = multierror.Append(errors, err)
return "", fmt.Errorf("Failed getting account ID via all available methods. Errors: %s", errors)
return "", "", err
}

return parseAccountIDFromArn(*outRoles.Roles[0].Arn)
return parseAccountInformationFromARN(aws.StringValue(output.Arn))
}

func parseAccountIDFromArn(inputARN string) (string, error) {
func parseAccountInformationFromARN(inputARN string) (string, string, error) {
arn, err := arn.Parse(inputARN)
if err != nil {
return "", fmt.Errorf("Unable to parse ID from invalid ARN: %q", arn)
return "", "", fmt.Errorf("error parsing ARN (%s): %s", inputARN, err)
}
return arn.AccountID, nil
return arn.AccountID, arn.Partition, nil
}

// This function is responsible for reading credentials from the
Expand Down
Loading

0 comments on commit cdfa7ea

Please sign in to comment.