Skip to content

Commit

Permalink
Added accountid to AWSClient and set it early in the initialization p…
Browse files Browse the repository at this point in the history
…hase. We use iam.GetUser(nil) scattered around to get the account id, but this isn't the most reliable method. GetAccountId now uses one more method (sts:GetCallerIdentity) to get the account id, this works with federated users.
  • Loading branch information
bigkraig committed May 5, 2016
1 parent a6ce5f1 commit 8dadb51
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 31 deletions.
17 changes: 13 additions & 4 deletions builtin/providers/aws/auth_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ import (
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/hashicorp/go-cleanhttp"
)

func GetAccountId(iamconn *iam.IAM, authProviderName string) (string, error) {
func GetAccountId(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, error) {
// If we have creds from instance profile, we can use metadata API
if authProviderName == ec2rolecreds.ProviderName {
log.Println("[DEBUG] Trying to get account ID via AWS Metadata API")
Expand All @@ -42,16 +43,24 @@ func GetAccountId(iamconn *iam.IAM, authProviderName string) (string, error) {
return parseAccountIdFromArn(*outUser.User.Arn)
}

// Then try IAM ListRoles
awsErr, ok := err.(awserr.Error)
// AccessDenied and ValidationError can be raised
// if credentials belong to federated profile, so we ignore these
if !ok || (awsErr.Code() != "AccessDenied" && awsErr.Code() != "ValidationError") {
return "", fmt.Errorf("Failed getting account ID via 'iam:GetUser': %s", err)
}

log.Printf("[DEBUG] Getting account ID via iam:GetUser failed: %s", err)
log.Println("[DEBUG] Trying to get account ID via iam:ListRoles instead")

// Then try STS GetCallerIdentity
log.Println("[DEBUG] Trying to get account ID via sts:GetCallerIdentity")
outCallerIdentity, err := stsconn.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err == nil {
return *outCallerIdentity.Account, nil
}
log.Printf("[DEBUG] Getting account ID via sts:GetCallerIdentity failed: %s", err)

// Then try IAM ListRoles
log.Println("[DEBUG] Trying to get account ID via iam:ListRoles")
outRoles, err := iamconn.ListRoles(&iam.ListRolesInput{
MaxItems: aws.Int64(int64(1)),
})
Expand Down
87 changes: 69 additions & 18 deletions builtin/providers/aws/auth_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
)

func TestAWSGetAccountId_shouldBeValid_fromEC2Role(t *testing.T) {
Expand All @@ -28,10 +29,10 @@ func TestAWSGetAccountId_shouldBeValid_fromEC2Role(t *testing.T) {
defer awsTs()

iamEndpoints := []*iamEndpoint{}
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()

id, err := GetAccountId(iamConn, ec2rolecreds.ProviderName)
id, err := GetAccountId(iamConn, stsConn, ec2rolecreds.ProviderName)
if err != nil {
t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err)
}
Expand All @@ -55,10 +56,10 @@ func TestAWSGetAccountId_shouldBeValid_EC2RoleHasPriority(t *testing.T) {
Response: &iamResponse{200, iamResponse_GetUser_valid, "text/xml"},
},
}
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()

id, err := GetAccountId(iamConn, ec2rolecreds.ProviderName)
id, err := GetAccountId(iamConn, stsConn, ec2rolecreds.ProviderName)
if err != nil {
t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err)
}
Expand All @@ -76,10 +77,36 @@ func TestAWSGetAccountId_shouldBeValid_fromIamUser(t *testing.T) {
Response: &iamResponse{200, iamResponse_GetUser_valid, "text/xml"},
},
}
ts, iamConn := getMockedAwsIamApi(iamEndpoints)

ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()

id, err := GetAccountId(iamConn, "")
id, err := GetAccountId(iamConn, stsConn, "")
if err != nil {
t.Fatalf("Getting account ID via GetUser failed: %s", err)
}

expectedAccountId := "123456789012"
if id != expectedAccountId {
t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id)
}
}

func TestAWSGetAccountId_shouldBeValid_fromGetCallerIdentity(t *testing.T) {
iamEndpoints := []*iamEndpoint{
&iamEndpoint{
Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Response: &iamResponse{403, iamResponse_GetUser_unauthorized, "text/xml"},
},
&iamEndpoint{
Request: &iamRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"},
Response: &iamResponse{200, stsResponse_GetCallerIdentity_valid, "text/xml"},
},
}
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()

id, err := GetAccountId(iamConn, stsConn, "")
if err != nil {
t.Fatalf("Getting account ID via GetUser failed: %s", err)
}
Expand All @@ -96,15 +123,19 @@ func TestAWSGetAccountId_shouldBeValid_fromIamListRoles(t *testing.T) {
Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Response: &iamResponse{403, iamResponse_GetUser_unauthorized, "text/xml"},
},
&iamEndpoint{
Request: &iamRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"},
Response: &iamResponse{403, stsResponse_GetCallerIdentity_unauthorized, "text/xml"},
},
&iamEndpoint{
Request: &iamRequest{"POST", "/", "Action=ListRoles&MaxItems=1&Version=2010-05-08"},
Response: &iamResponse{200, iamResponse_ListRoles_valid, "text/xml"},
},
}
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()

id, err := GetAccountId(iamConn, "")
id, err := GetAccountId(iamConn, stsConn, "")
if err != nil {
t.Fatalf("Getting account ID via ListRoles failed: %s", err)
}
Expand All @@ -126,10 +157,10 @@ func TestAWSGetAccountId_shouldBeValid_federatedRole(t *testing.T) {
Response: &iamResponse{200, iamResponse_ListRoles_valid, "text/xml"},
},
}
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()

id, err := GetAccountId(iamConn, "")
id, err := GetAccountId(iamConn, stsConn, "")
if err != nil {
t.Fatalf("Getting account ID via ListRoles failed: %s", err)
}
Expand All @@ -151,10 +182,10 @@ func TestAWSGetAccountId_shouldError_unauthorizedFromIam(t *testing.T) {
Response: &iamResponse{403, iamResponse_ListRoles_unauthorized, "text/xml"},
},
}
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()

id, err := GetAccountId(iamConn, "")
id, err := GetAccountId(iamConn, stsConn, "")
if err == nil {
t.Fatal("Expected error when getting account ID")
}
Expand Down Expand Up @@ -586,15 +617,15 @@ func invalidAwsEnv(t *testing.T) func() {
return ts.Close
}

// getMockedAwsIamApi establishes a httptest server to simulate behaviour
// of a real AWS' IAM server
func getMockedAwsIamApi(endpoints []*iamEndpoint) (func(), *iam.IAM) {
// getMockedAwsIamStsApi establishes a httptest server to simulate behaviour
// of a real AWS' IAM & STS server
func getMockedAwsIamStsApi(endpoints []*iamEndpoint) (func(), *iam.IAM, *sts.STS) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
buf := new(bytes.Buffer)
buf.ReadFrom(r.Body)
requestBody := buf.String()

log.Printf("[DEBUG] Received IAM API %q request to %q: %s",
log.Printf("[DEBUG] Received API %q request to %q: %s",
r.Method, r.RequestURI, requestBody)

for _, e := range endpoints {
Expand Down Expand Up @@ -624,8 +655,8 @@ func getMockedAwsIamApi(endpoints []*iamEndpoint) (func(), *iam.IAM) {
CredentialsChainVerboseErrors: aws.Bool(true),
})
iamConn := iam.New(sess)

return ts.Close, iamConn
stsConn := sts.New(sess)
return ts.Close, iamConn, stsConn
}

func getEnv() *currentEnv {
Expand Down Expand Up @@ -718,6 +749,26 @@ const iamResponse_GetUser_unauthorized = `<ErrorResponse xmlns="https://iam.amaz
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
</ErrorResponse>`

const stsResponse_GetCallerIdentity_valid = `<GetCallerIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<GetCallerIdentityResult>
<Arn>arn:aws:iam::123456789012:user/Alice</Arn>
<UserId>AKIAI44QH8DHBEXAMPLE</UserId>
<Account>123456789012</Account>
</GetCallerIdentityResult>
<ResponseMetadata>
<RequestId>01234567-89ab-cdef-0123-456789abcdef</RequestId>
</ResponseMetadata>
</GetCallerIdentityResponse>`

const stsResponse_GetCallerIdentity_unauthorized = `<ErrorResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<Error>
<Type>Sender</Type>
<Code>AccessDenied</Code>
<Message>User: arn:aws:iam::123456789012:user/Bob is not authorized to perform: sts:GetCallerIdentity</Message>
</Error>
<RequestId>01234567-89ab-cdef-0123-456789abcdef</RequestId>
</ErrorResponse>`

const iamResponse_GetUser_federatedFailure = `<ErrorResponse xmlns="https://iam.amazonaws.com/doc/2010-05-08/">
<Error>
<Type>Sender</Type>
Expand Down
25 changes: 16 additions & 9 deletions builtin/providers/aws/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import (
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/sns"
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/aws/aws-sdk-go/service/sts"
)

type Config struct {
Expand Down Expand Up @@ -92,8 +93,10 @@ type AWSClient struct {
s3conn *s3.S3
sqsconn *sqs.SQS
snsconn *sns.SNS
stsconn *sts.STS
redshiftconn *redshift.Redshift
r53conn *route53.Route53
accountid string
region string
rdsconn *rds.RDS
iamconn *iam.IAM
Expand Down Expand Up @@ -172,6 +175,9 @@ func (c *Config) Client() (interface{}, error) {
awsIamSess := sess.Copy(&aws.Config{Endpoint: aws.String(c.IamEndpoint)})
client.iamconn = iam.New(awsIamSess)

log.Println("[INFO] Initializing STS connection")
client.stsconn = sts.New(sess)

err = c.ValidateCredentials(client.iamconn)
if err != nil {
errs = append(errs, err)
Expand All @@ -185,6 +191,11 @@ func (c *Config) Client() (interface{}, error) {
// http://docs.aws.amazon.com/general/latest/gr/sigv4_changes.html
usEast1Sess := sess.Copy(&aws.Config{Region: aws.String("us-east-1")})

accountId, err := GetAccountId(client.iamconn, client.stsconn, cp.ProviderName)
if err == nil {
client.accountid = accountId
}

log.Println("[INFO] Initializing DynamoDB connection")
dynamoSess := sess.Copy(&aws.Config{Endpoint: aws.String(c.DynamoDBEndpoint)})
client.dynamodbconn = dynamodb.New(dynamoSess)
Expand Down Expand Up @@ -215,7 +226,7 @@ func (c *Config) Client() (interface{}, error) {
log.Println("[INFO] Initializing Elastic Beanstalk Connection")
client.elasticbeanstalkconn = elasticbeanstalk.New(sess)

authErr := c.ValidateAccountId(client.iamconn, cp.ProviderName)
authErr := c.ValidateAccountId(client.accountid)
if authErr != nil {
errs = append(errs, authErr)
}
Expand Down Expand Up @@ -338,32 +349,28 @@ func (c *Config) ValidateCredentials(iamconn *iam.IAM) error {

// ValidateAccountId returns a context-specific error if the configured account
// id is explicitly forbidden or not authorised; and nil if it is authorised.
func (c *Config) ValidateAccountId(iamconn *iam.IAM, authProviderName string) error {
func (c *Config) ValidateAccountId(accountId string) error {
if c.AllowedAccountIds == nil && c.ForbiddenAccountIds == nil {
return nil
}

log.Printf("[INFO] Validating account ID")
account_id, err := GetAccountId(iamconn, authProviderName)
if err != nil {
return err
}

if c.ForbiddenAccountIds != nil {
for _, id := range c.ForbiddenAccountIds {
if id == account_id {
if id == accountId {
return fmt.Errorf("Forbidden account ID (%s)", id)
}
}
}

if c.AllowedAccountIds != nil {
for _, id := range c.AllowedAccountIds {
if id == account_id {
if id == accountId {
return nil
}
}
return fmt.Errorf("Account ID not allowed (%s)", account_id)
return fmt.Errorf("Account ID not allowed (%s)", accountId)
}

return nil
Expand Down

0 comments on commit 8dadb51

Please sign in to comment.