Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

provider/aws: Add support for AssumeRole prior to operations #8638

Merged
merged 3 commits into from
Sep 3, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 67 additions & 4 deletions builtin/providers/aws/auth_helpers.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package aws

import (
"errors"
"fmt"
"log"
"os"
Expand All @@ -11,6 +12,7 @@ import (
"github.com/aws/aws-sdk-go/aws/awserr"
awsCredentials "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
Expand Down Expand Up @@ -75,7 +77,7 @@ func GetAccountId(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (
}

if len(outRoles.Roles) < 1 {
return "", fmt.Errorf("Failed getting account ID via 'iam:ListRoles': No roles available")
return "", errors.New("Failed getting account ID via 'iam:ListRoles': No roles available")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why errors.New and not errWrap?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are generating a new error rather than wrapping one produced by an upstream API?

}

return parseAccountIdFromArn(*outRoles.Roles[0].Arn)
Expand All @@ -92,7 +94,7 @@ func parseAccountIdFromArn(arn string) (string, error) {
// This function is responsible for reading credentials from the
// environment in the case that they're not explicitly specified
// in the Terraform configuration.
func GetCredentials(c *Config) *awsCredentials.Credentials {
func GetCredentials(c *Config) (*awsCredentials.Credentials, error) {
// build a chain provider, lazy-evaulated by aws-sdk
providers := []awsCredentials.Provider{
&awsCredentials.StaticProvider{Value: awsCredentials.Value{
Expand Down Expand Up @@ -126,7 +128,7 @@ func GetCredentials(c *Config) *awsCredentials.Credentials {
providers = append(providers, &ec2rolecreds.EC2RoleProvider{
Client: metadataClient,
})
log.Printf("[INFO] AWS EC2 instance detected via default metadata" +
log.Print("[INFO] AWS EC2 instance detected via default metadata" +
" API endpoint, EC2RoleProvider added to the auth chain")
} else {
if usedEndpoint == "" {
Expand All @@ -137,7 +139,68 @@ func GetCredentials(c *Config) *awsCredentials.Credentials {
}
}

return awsCredentials.NewChainCredentials(providers)
// This is the "normal" flow (i.e. not assuming a role)
if c.AssumeRoleARN == "" {
return awsCredentials.NewChainCredentials(providers), nil
}

// Otherwise we need to construct and STS client with the main credentials, and verify
// that we can assume the defined role.
log.Printf("[INFO] Attempting to AssumeRole %s (SessionName: %q, ExternalId: %q)",
c.AssumeRoleARN, c.AssumeRoleSessionName, c.AssumeRoleExternalID)

creds := awsCredentials.NewChainCredentials(providers)
cp, err := creds.Get()
if err != nil {
if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == "NoCredentialProviders" {
return nil, errors.New(`No valid credential sources found for AWS Provider.
Please see https://terraform.io/docs/providers/aws/index.html for more information on
providing credentials for the AWS Provider`)
}

return nil, fmt.Errorf("Error loading credentials for AWS Provider: %s", err)
}

log.Printf("[INFO] AWS Auth provider used: %q", cp.ProviderName)

awsConfig := &aws.Config{
Credentials: creds,
Region: aws.String(c.Region),
MaxRetries: aws.Int(c.MaxRetries),
HTTPClient: cleanhttp.DefaultClient(),
S3ForcePathStyle: aws.Bool(c.S3ForcePathStyle),
}

stsclient := sts.New(session.New(awsConfig))
assumeRoleProvider := &stscreds.AssumeRoleProvider{
Client: stsclient,
RoleARN: c.AssumeRoleARN,
}
if c.AssumeRoleSessionName != "" {
assumeRoleProvider.RoleSessionName = c.AssumeRoleSessionName
}
if c.AssumeRoleExternalID != "" {
assumeRoleProvider.ExternalID = aws.String(c.AssumeRoleExternalID)
}

providers = []awsCredentials.Provider{assumeRoleProvider}

assumeRoleCreds := awsCredentials.NewChainCredentials(providers)
_, err = assumeRoleCreds.Get()
if err != nil {
if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == "NoCredentialProviders" {
return nil, fmt.Errorf("The role %q cannot be assumed.\n\n"+
" There are a number of possible causes of this - the most common are:\n"+
" * The credentials used in order to assume the role are invalid\n"+
" * The credentials do not have appropriate permission to assume the role\n"+
" * The role ARN is not valid",
c.AssumeRoleARN)
}

return nil, fmt.Errorf("Error loading credentials for AWS Provider: %s", err)
}

return assumeRoleCreds, nil
}

func setOptionalEndpoint(cfg *aws.Config) string {
Expand Down
98 changes: 70 additions & 28 deletions builtin/providers/aws/auth_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func TestAWSGetAccountId_shouldBeValid_EC2RoleHasPriority(t *testing.T) {
defer awsTs()

iamEndpoints := []*iamEndpoint{
&iamEndpoint{
{
Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Response: &iamResponse{200, iamResponse_GetUser_valid, "text/xml"},
},
Expand All @@ -72,7 +72,7 @@ func TestAWSGetAccountId_shouldBeValid_EC2RoleHasPriority(t *testing.T) {

func TestAWSGetAccountId_shouldBeValid_fromIamUser(t *testing.T) {
iamEndpoints := []*iamEndpoint{
&iamEndpoint{
{
Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Response: &iamResponse{200, iamResponse_GetUser_valid, "text/xml"},
},
Expand All @@ -94,11 +94,11 @@ func TestAWSGetAccountId_shouldBeValid_fromIamUser(t *testing.T) {

func TestAWSGetAccountId_shouldBeValid_fromGetCallerIdentity(t *testing.T) {
iamEndpoints := []*iamEndpoint{
&iamEndpoint{
{
Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Response: &iamResponse{403, iamResponse_GetUser_unauthorized, "text/xml"},
},
&iamEndpoint{
{
Request: &iamRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"},
Response: &iamResponse{200, stsResponse_GetCallerIdentity_valid, "text/xml"},
},
Expand All @@ -119,15 +119,15 @@ func TestAWSGetAccountId_shouldBeValid_fromGetCallerIdentity(t *testing.T) {

func TestAWSGetAccountId_shouldBeValid_fromIamListRoles(t *testing.T) {
iamEndpoints := []*iamEndpoint{
&iamEndpoint{
{
Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Response: &iamResponse{403, iamResponse_GetUser_unauthorized, "text/xml"},
},
&iamEndpoint{
{
Request: &iamRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"},
Response: &iamResponse{403, stsResponse_GetCallerIdentity_unauthorized, "text/xml"},
},
&iamEndpoint{
{
Request: &iamRequest{"POST", "/", "Action=ListRoles&MaxItems=1&Version=2010-05-08"},
Response: &iamResponse{200, iamResponse_ListRoles_valid, "text/xml"},
},
Expand All @@ -148,11 +148,11 @@ func TestAWSGetAccountId_shouldBeValid_fromIamListRoles(t *testing.T) {

func TestAWSGetAccountId_shouldBeValid_federatedRole(t *testing.T) {
iamEndpoints := []*iamEndpoint{
&iamEndpoint{
{
Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Response: &iamResponse{400, iamResponse_GetUser_federatedFailure, "text/xml"},
},
&iamEndpoint{
{
Request: &iamRequest{"POST", "/", "Action=ListRoles&MaxItems=1&Version=2010-05-08"},
Response: &iamResponse{200, iamResponse_ListRoles_valid, "text/xml"},
},
Expand All @@ -173,11 +173,11 @@ func TestAWSGetAccountId_shouldBeValid_federatedRole(t *testing.T) {

func TestAWSGetAccountId_shouldError_unauthorizedFromIam(t *testing.T) {
iamEndpoints := []*iamEndpoint{
&iamEndpoint{
{
Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Response: &iamResponse{403, iamResponse_GetUser_unauthorized, "text/xml"},
},
&iamEndpoint{
{
Request: &iamRequest{"POST", "/", "Action=ListRoles&MaxItems=1&Version=2010-05-08"},
Response: &iamResponse{403, iamResponse_ListRoles_unauthorized, "text/xml"},
},
Expand Down Expand Up @@ -218,15 +218,20 @@ func TestAWSGetCredentials_shouldError(t *testing.T) {
defer resetEnv()
cfg := Config{}

c := GetCredentials(&cfg)
_, err := c.Get()
c, err := GetCredentials(&cfg)
if awsErr, ok := err.(awserr.Error); ok {
if awsErr.Code() != "NoCredentialProviders" {
t.Fatal("Expected NoCredentialProviders error")
}
}
_, err = c.Get()
if awsErr, ok := err.(awserr.Error); ok {
if awsErr.Code() != "NoCredentialProviders" {
t.Fatalf("Expected NoCredentialProviders error")
t.Fatal("Expected NoCredentialProviders error")
}
}
if err == nil {
t.Fatalf("Expected an error with empty env, keys, and IAM in AWS Config")
t.Fatal("Expected an error with empty env, keys, and IAM in AWS Config")
}
}

Expand All @@ -251,14 +256,19 @@ func TestAWSGetCredentials_shouldBeStatic(t *testing.T) {
Token: c.Token,
}

creds := GetCredentials(&cfg)
creds, err := GetCredentials(&cfg)
if err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should check the err before we check if the creds are nil

t.Fatalf("Error gettings creds: %s", err)
}
if creds == nil {
t.Fatalf("Expected a static creds provider to be returned")
t.Fatal("Expected a static creds provider to be returned")
}

v, err := creds.Get()
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}

if v.AccessKeyID != c.Key {
t.Fatalf("AccessKeyID mismatch, expected: (%s), got (%s)", c.Key, v.AccessKeyID)
}
Expand Down Expand Up @@ -286,9 +296,12 @@ func TestAWSGetCredentials_shouldIAM(t *testing.T) {
// An empty config, no key supplied
cfg := Config{}

creds := GetCredentials(&cfg)
creds, err := GetCredentials(&cfg)
if err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same - err check first before creds?

t.Fatalf("Error gettings creds: %s", err)
}
if creds == nil {
t.Fatalf("Expected a static creds provider to be returned")
t.Fatal("Expected a static creds provider to be returned")
}

v, err := creds.Get()
Expand Down Expand Up @@ -335,10 +348,14 @@ func TestAWSGetCredentials_shouldIgnoreIAM(t *testing.T) {
Token: c.Token,
}

creds := GetCredentials(&cfg)
creds, err := GetCredentials(&cfg)
if err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Err before creds? :)

t.Fatalf("Error gettings creds: %s", err)
}
if creds == nil {
t.Fatalf("Expected a static creds provider to be returned")
t.Fatal("Expected a static creds provider to be returned")
}

v, err := creds.Get()
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
Expand All @@ -362,7 +379,14 @@ func TestAWSGetCredentials_shouldErrorWithInvalidEndpoint(t *testing.T) {
ts := invalidAwsEnv(t)
defer ts()

creds := GetCredentials(&Config{})
creds, err := GetCredentials(&Config{})
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
if creds == nil {
t.Fatal("Expected a static creds provider to be returned")
}

v, err := creds.Get()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to check the creds are not nil as elsewhere?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

if err == nil {
t.Fatal("Expected error returned when getting creds w/ invalid EC2 endpoint")
Expand All @@ -380,11 +404,17 @@ func TestAWSGetCredentials_shouldIgnoreInvalidEndpoint(t *testing.T) {
ts := invalidAwsEnv(t)
defer ts()

creds := GetCredentials(&Config{AccessKey: "accessKey", SecretKey: "secretKey"})
creds, err := GetCredentials(&Config{AccessKey: "accessKey", SecretKey: "secretKey"})
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
v, err := creds.Get()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check that creds are not nil ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, added this now.

if err != nil {
t.Fatalf("Getting static credentials w/ invalid EC2 endpoint failed: %s", err)
}
if creds == nil {
t.Fatal("Expected a static creds provider to be returned")
}

if v.ProviderName != "StaticProvider" {
t.Fatalf("Expected provider name to be %q, %q given", "StaticProvider", v.ProviderName)
Expand All @@ -406,10 +436,14 @@ func TestAWSGetCredentials_shouldCatchEC2RoleProvider(t *testing.T) {
ts := awsEnv(t)
defer ts()

creds := GetCredentials(&Config{})
creds, err := GetCredentials(&Config{})
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
if creds == nil {
t.Fatalf("Expected an EC2Role creds provider to be returned")
t.Fatal("Expected an EC2Role creds provider to be returned")
}

v, err := creds.Get()
if err != nil {
t.Fatalf("Expected no error when getting creds: %s", err)
Expand Down Expand Up @@ -452,10 +486,14 @@ func TestAWSGetCredentials_shouldBeShared(t *testing.T) {
t.Fatalf("Error resetting env var AWS_SHARED_CREDENTIALS_FILE: %s", err)
}

creds := GetCredentials(&Config{Profile: "myprofile", CredsFilename: file.Name()})
creds, err := GetCredentials(&Config{Profile: "myprofile", CredsFilename: file.Name()})
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
if creds == nil {
t.Fatalf("Expected a provider chain to be returned")
t.Fatal("Expected a provider chain to be returned")
}

v, err := creds.Get()
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
Expand All @@ -479,10 +517,14 @@ func TestAWSGetCredentials_shouldBeENV(t *testing.T) {
defer resetEnv()

cfg := Config{}
creds := GetCredentials(&cfg)
creds, err := GetCredentials(&cfg)
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
if creds == nil {
t.Fatalf("Expected a static creds provider to be returned")
}

v, err := creds.Get()
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
Expand Down
Loading