Skip to content

Commit

Permalink
Resolve bound_iam_principal_arn to internal AWS ID
Browse files Browse the repository at this point in the history
This adds a (now-default) option on roles to resolve the
bound_iam_principal_arn (when using AWS IAM auth) to AWS's internal
unique ID. The primary reason for this is to prevent a particular role
or user from being deleted and recreated with the same ARN and thus
taking over Vault permissions that were intended to be bound to the
previous ARN, which more closely mimics AWS's behavior.

By preferentially resolving via the internal unqiue ID rather than the
ARN, this also fixes the issue in hashicorp#2729
  • Loading branch information
joelthompson committed May 26, 2017
1 parent 20eadd3 commit 0f8e818
Show file tree
Hide file tree
Showing 9 changed files with 492 additions and 126 deletions.
81 changes: 81 additions & 0 deletions builtin/credential/aws/backend.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package awsauth

import (
"fmt"
"sync"
"time"

"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/hashicorp/vault/logical"
Expand Down Expand Up @@ -54,6 +56,15 @@ type backend struct {
// When the credentials are modified or deleted, all the cached client objects
// will be flushed. The empty STS role signifies the master account
IAMClientsMap map[string]map[string]*iam.IAM

// AWS Account ID of the "default" AWS credentials
// This cache avoids the need to call GetCallerIdentity repeatedly to learn it
// We can't store this because, in certain pathological cases, it could change
// out from under us, such as a standby and active Vault server in different AWS
// accounts using their IAM instance profile to get their credentials.
defaultAWSAccountID string

resolveArnToUniqueId func(logical.Storage, string) (string, error)
}

func Backend(conf *logical.BackendConfig) (*backend, error) {
Expand All @@ -65,6 +76,8 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
IAMClientsMap: make(map[string]map[string]*iam.IAM),
}

b.resolveArnToUniqueId = b.resolveArnToRealUniqueId

b.Backend = &framework.Backend{
PeriodicFunc: b.periodicFunc,
AuthRenew: b.pathLoginRenew,
Expand Down Expand Up @@ -171,7 +184,75 @@ func (b *backend) invalidate(key string) {
defer b.configMutex.Unlock()
b.flushCachedEC2Clients()
b.flushCachedIAMClients()
b.defaultAWSAccountID = ""
}
}

// Putting this here so we can inject a fake resolver into the backend for unit testing
// purposes
func (b *backend) resolveArnToRealUniqueId(s logical.Storage, arn string) (string, error) {
entity, err := parseIamArn(arn)
if err != nil {
return "", err
}
// This odd-looking code is here because IAM is an inherently global service. IAM and STS ARNs
// don't have regions in them, and there is only a single global endpoint for IAM; see
// http://docs.aws.amazon.com/general/latest/gr/rande.html#iam_region
// However, the ARNs do have a partition in them, because the GovCloud and China partitions DO
// have their own separate endpoints, and the partition is encoded in the ARN. If Amazon's Go SDK
// would allow us to pass a partition back to the IAM client, it would be much simpler. But it
// doesn't appear that's possible, so in order to properly support GovCloud and China, we do a
// circular dance of extracting the partition from the ARN, finding any arbitrary region in the
// partition, and passing that region back back to the SDK, so that the SDK can figure out the
// proper partition from the arbitrary region we passed in to look up the endpoint.
// Sigh
region := getAnyRegionForAwsPartition(entity.Partition)
if region == nil {
return "", fmt.Errorf("Unable to resolve partition %q to a region", entity.Partition)
}
iamClient, err := b.clientIAM(s, region.ID(), entity.AccountNumber)
if err != nil {
return "", err
}

switch entity.Type {
case "user":
userInfo, err := iamClient.GetUser(&iam.GetUserInput{UserName: &entity.FriendlyName})
if err != nil {
return "", err
}
return *userInfo.User.UserId, nil
case "role":
roleInfo, err := iamClient.GetRole(&iam.GetRoleInput{RoleName: &entity.FriendlyName})
if err != nil {
return "", err
}
return *roleInfo.Role.RoleId, nil
case "instance-profile":
profileInfo, err := iamClient.GetInstanceProfile(&iam.GetInstanceProfileInput{InstanceProfileName: &entity.FriendlyName})
if err != nil {
return "", err
}
return *profileInfo.InstanceProfile.InstanceProfileId, nil
default:
return "", fmt.Errorf("unrecognized error type %#v", entity.Type)
}
}

// Adapted from https://docs.aws.amazon.com/sdk-for-go/api/aws/endpoints/
// the "Enumerating Regions and Endpoint Metadata" section
func getAnyRegionForAwsPartition(partitionId string) *endpoints.Region {
resolver := endpoints.DefaultResolver()
partitions := resolver.(endpoints.EnumPartitions).Partitions()

for _, p := range partitions {
if p.ID() == partitionId {
for _, r := range p.Regions() {
return &r
}
}
}
return nil
}

const backendHelp = `
Expand Down
33 changes: 29 additions & 4 deletions builtin/credential/aws/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1346,7 +1346,7 @@ func TestBackendAcc_LoginWithCallerIdentity(t *testing.T) {
if err != nil {
t.Fatalf("Received error retrieving identity: %s", err)
}
testIdentityArn, _, _, err := parseIamArn(*testIdentity.Arn)
entity, err := parseIamArn(*testIdentity.Arn)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1385,7 +1385,7 @@ func TestBackendAcc_LoginWithCallerIdentity(t *testing.T) {

// configuring the valid role we'll be able to login to
roleData := map[string]interface{}{
"bound_iam_principal_arn": testIdentityArn,
"bound_iam_principal_arn": entity.canonicalArn(),
"policies": "root",
"auth_type": iamAuthType,
}
Expand Down Expand Up @@ -1417,8 +1417,17 @@ func TestBackendAcc_LoginWithCallerIdentity(t *testing.T) {
t.Fatalf("bad: failed to create role; resp:%#v\nerr:%v", resp, err)
}

fakeArn := "arn:aws:iam::123456789012:role/FakeRole"
fakeArnResolver := func(s logical.Storage, arn string) (string, error) {
if arn == fakeArn {
return fmt.Sprintf("FakeUniqueIdFor%s", fakeArn), nil
}
return b.resolveArnToRealUniqueId(s, arn)
}
b.resolveArnToUniqueId = fakeArnResolver

// now we're creating the invalid role we won't be able to login to
roleData["bound_iam_principal_arn"] = "arn:aws:iam::123456789012:role/FakeRole"
roleData["bound_iam_principal_arn"] = fakeArn
roleRequest.Path = "role/" + testInvalidRoleName
resp, err = b.HandleRequest(roleRequest)
if err != nil || (resp != nil && resp.IsError()) {
Expand Down Expand Up @@ -1491,7 +1500,7 @@ func TestBackendAcc_LoginWithCallerIdentity(t *testing.T) {
t.Errorf("bad: expected failed login due to bad auth type: resp:%#v\nerr:%v", resp, err)
}

// finally, the happy path tests :)
// finally, the happy path test :)

loginData["role"] = testValidRoleName
resp, err = b.HandleRequest(loginRequest)
Expand All @@ -1501,4 +1510,20 @@ func TestBackendAcc_LoginWithCallerIdentity(t *testing.T) {
if resp == nil || resp.Auth == nil || resp.IsError() {
t.Errorf("bad: expected valid login: resp:%#v", resp)
}

// Now, fake out the unique ID resolver to ensure we fail login if the unique ID
// changes from under us
b.resolveArnToUniqueId = resolveArnToFakeUniqueId
// First, we need to update the role to force Vault to use our fake resolver to
// pick up the fake user ID
roleData["bound_iam_principal_arn"] = entity.canonicalArn()
roleRequest.Path = "role/" + testValidRoleName
resp, err = b.HandleRequest(roleRequest)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: failed to recreate role: resp:%#v\nerr:%v", resp, err)
}
resp, err = b.HandleRequest(loginRequest)
if err != nil || resp == nil || !resp.IsError() {
t.Errorf("bad: expected failed login due to changed AWS role ID: resp: %#v\nerr:%v", resp, err)
}
}
66 changes: 51 additions & 15 deletions builtin/credential/aws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/vault/helper/awsutil"
"github.com/hashicorp/vault/logical"
Expand Down Expand Up @@ -70,7 +71,7 @@ func (b *backend) getRawClientConfig(s logical.Storage, region, clientType strin
// It uses getRawClientConfig to obtain config for the runtime environemnt, and if
// stsRole is a non-empty string, it will use AssumeRole to obtain a set of assumed
// credentials. The credentials will expire after 15 minutes but will auto-refresh.
func (b *backend) getClientConfig(s logical.Storage, region, stsRole, clientType string) (*aws.Config, error) {
func (b *backend) getClientConfig(s logical.Storage, region, stsRole, accountID, clientType string) (*aws.Config, error) {

config, err := b.getRawClientConfig(s, region, clientType)
if err != nil {
Expand All @@ -80,20 +81,36 @@ func (b *backend) getClientConfig(s logical.Storage, region, stsRole, clientType
return nil, fmt.Errorf("could not compile valid credentials through the default provider chain")
}

stsConfig, err := b.getRawClientConfig(s, region, "sts")
if stsConfig == nil {
return nil, fmt.Errorf("could not configure STS client")
}
if err != nil {
return nil, err
}
if stsRole != "" {
assumeRoleConfig, err := b.getRawClientConfig(s, region, "sts")
if err != nil {
return nil, err
}
if assumeRoleConfig == nil {
return nil, fmt.Errorf("could not configure STS client")
}
assumedCredentials := stscreds.NewCredentials(session.New(assumeRoleConfig), stsRole)
assumedCredentials := stscreds.NewCredentials(session.New(stsConfig), stsRole)
// Test that we actually have permissions to assume the role
if _, err = assumedCredentials.Get(); err != nil {
return nil, err
}
config.Credentials = assumedCredentials
} else {
if b.defaultAWSAccountID == "" {
client := sts.New(session.New(stsConfig))
if client == nil {
return nil, fmt.Errorf("could not obtain sts client: %v", err)
}
inputParams := &sts.GetCallerIdentityInput{}
identity, err := client.GetCallerIdentity(inputParams)
if err != nil {
return nil, fmt.Errorf("unable to fetch current caller: %v", err)
}
b.defaultAWSAccountID = *identity.Account
}
if b.defaultAWSAccountID != accountID {
return nil, fmt.Errorf("unable to fetch client for account ID %s -- default client is for account %s", accountID, b.defaultAWSAccountID)
}
}

return config, nil
Expand Down Expand Up @@ -121,8 +138,25 @@ func (b *backend) flushCachedIAMClients() {
}
}

func (b *backend) stsRoleForAccount(s logical.Storage, accountID string) (string, error) {
// Check if an STS configuration exists for the AWS account
sts, err := b.lockedAwsStsEntry(s, accountID)
if err != nil {
return "", fmt.Errorf("error fetching STS config for account ID %q: %q\n", accountID, err)
}
// An empty STS role signifies the master account
if sts != nil {
return sts.StsRole, nil
}
return "", nil
}

// clientEC2 creates a client to interact with AWS EC2 API
func (b *backend) clientEC2(s logical.Storage, region string, stsRole string) (*ec2.EC2, error) {
func (b *backend) clientEC2(s logical.Storage, region, accountID string) (*ec2.EC2, error) {
stsRole, err := b.stsRoleForAccount(s, accountID)
if err != nil {
return nil, err
}
b.configMutex.RLock()
if b.EC2ClientsMap[region] != nil && b.EC2ClientsMap[region][stsRole] != nil {
defer b.configMutex.RUnlock()
Expand All @@ -142,8 +176,7 @@ func (b *backend) clientEC2(s logical.Storage, region string, stsRole string) (*

// Create an AWS config object using a chain of providers
var awsConfig *aws.Config
var err error
awsConfig, err = b.getClientConfig(s, region, stsRole, "ec2")
awsConfig, err = b.getClientConfig(s, region, stsRole, accountID, "ec2")

if err != nil {
return nil, err
Expand All @@ -168,7 +201,11 @@ func (b *backend) clientEC2(s logical.Storage, region string, stsRole string) (*
}

// clientIAM creates a client to interact with AWS IAM API
func (b *backend) clientIAM(s logical.Storage, region string, stsRole string) (*iam.IAM, error) {
func (b *backend) clientIAM(s logical.Storage, region, accountID string) (*iam.IAM, error) {
stsRole, err := b.stsRoleForAccount(s, accountID)
if err != nil {
return nil, err
}
b.configMutex.RLock()
if b.IAMClientsMap[region] != nil && b.IAMClientsMap[region][stsRole] != nil {
defer b.configMutex.RUnlock()
Expand All @@ -188,8 +225,7 @@ func (b *backend) clientIAM(s logical.Storage, region string, stsRole string) (*

// Create an AWS config object using a chain of providers
var awsConfig *aws.Config
var err error
awsConfig, err = b.getClientConfig(s, region, stsRole, "iam")
awsConfig, err = b.getClientConfig(s, region, stsRole, accountID, "iam")

if err != nil {
return nil, err
Expand Down
4 changes: 4 additions & 0 deletions builtin/credential/aws/path_config_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ func (b *backend) pathConfigClientDelete(
// Remove all the cached EC2 client objects in the backend.
b.flushCachedIAMClients()

// unset the cached default AWS account ID
b.defaultAWSAccountID = ""

return nil, nil
}

Expand Down Expand Up @@ -234,6 +237,7 @@ func (b *backend) pathConfigClientCreateUpdate(
if changedCreds {
b.flushCachedEC2Clients()
b.flushCachedIAMClients()
b.defaultAWSAccountID = ""
}

return nil, nil
Expand Down
Loading

0 comments on commit 0f8e818

Please sign in to comment.