Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GH-1275] Support for AWS access via IAMs AssumeRole functionality #8506

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions builtin/providers/aws/auth_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ import (
"github.com/aws/aws-sdk-go/aws/awserr"
awsCredentials "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-multierror"
)

func GetAccountId(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, error) {
Expand Down Expand Up @@ -92,7 +94,9 @@ func parseAccountIdFromArn(arn string) (string, error) {
// This function is responsible for reading credentials from the
// environment in the case that they're not explicitly specified
// in the Terraform configuration.
func GetCredentials(c *Config) *awsCredentials.Credentials {
func GetCredentials(c *Config) (*awsCredentials.Credentials, error) {
var errs []error

// build a chain provider, lazy-evaulated by aws-sdk
providers := []awsCredentials.Provider{
&awsCredentials.StaticProvider{Value: awsCredentials.Value{
Expand Down Expand Up @@ -137,7 +141,40 @@ func GetCredentials(c *Config) *awsCredentials.Credentials {
}
}

return awsCredentials.NewChainCredentials(providers)
if c.RoleArn != "" {
log.Printf("[INFO] attempting to assume role %s", c.RoleArn)

creds := awsCredentials.NewChainCredentials(providers)
cp, err := creds.Get()
if err != nil {
if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == "NoCredentialProviders" {
errs = append(errs, fmt.Errorf(`No valid credential sources found for AWS Provider.
Please see https://terraform.io/docs/providers/aws/index.html for more information on
providing credentials for the AWS Provider`))
} else {
errs = append(errs, fmt.Errorf("Error loading credentials for AWS Provider: %s", err))
}
return nil, &multierror.Error{Errors: errs}
}

log.Printf("[INFO] AWS Auth provider used: %q", cp.ProviderName)

awsConfig := &aws.Config{
Credentials: creds,
Region: aws.String(c.Region),
MaxRetries: aws.Int(c.MaxRetries),
HTTPClient: cleanhttp.DefaultClient(),
S3ForcePathStyle: aws.Bool(c.S3ForcePathStyle),
}

stsclient := sts.New(session.New(awsConfig))
providers = []awsCredentials.Provider{&stscreds.AssumeRoleProvider{
Client: stsclient,
RoleARN: c.RoleArn,
}}
}

return awsCredentials.NewChainCredentials(providers), nil
}

func setOptionalEndpoint(cfg *aws.Config) string {
Expand Down
50 changes: 39 additions & 11 deletions builtin/providers/aws/auth_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,13 @@ func TestAWSGetCredentials_shouldError(t *testing.T) {
defer resetEnv()
cfg := Config{}

c := GetCredentials(&cfg)
_, err := c.Get()
c, err := GetCredentials(&cfg)
if awsErr, ok := err.(awserr.Error); ok {
if awsErr.Code() != "NoCredentialProviders" {
t.Fatalf("Expected NoCredentialProviders error")
}
}
_, err = c.Get()
if awsErr, ok := err.(awserr.Error); ok {
if awsErr.Code() != "NoCredentialProviders" {
t.Fatalf("Expected NoCredentialProviders error")
Expand Down Expand Up @@ -251,10 +256,13 @@ func TestAWSGetCredentials_shouldBeStatic(t *testing.T) {
Token: c.Token,
}

creds := GetCredentials(&cfg)
creds, err := GetCredentials(&cfg)
if creds == nil {
t.Fatalf("Expected a static creds provider to be returned")
}
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
v, err := creds.Get()
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
Expand Down Expand Up @@ -286,11 +294,13 @@ func TestAWSGetCredentials_shouldIAM(t *testing.T) {
// An empty config, no key supplied
cfg := Config{}

creds := GetCredentials(&cfg)
creds, err := GetCredentials(&cfg)
if creds == nil {
t.Fatalf("Expected a static creds provider to be returned")
}

if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
v, err := creds.Get()
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
Expand Down Expand Up @@ -335,10 +345,13 @@ func TestAWSGetCredentials_shouldIgnoreIAM(t *testing.T) {
Token: c.Token,
}

creds := GetCredentials(&cfg)
creds, err := GetCredentials(&cfg)
if creds == nil {
t.Fatalf("Expected a static creds provider to be returned")
}
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
v, err := creds.Get()
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
Expand All @@ -362,7 +375,10 @@ func TestAWSGetCredentials_shouldErrorWithInvalidEndpoint(t *testing.T) {
ts := invalidAwsEnv(t)
defer ts()

creds := GetCredentials(&Config{})
creds, err := GetCredentials(&Config{})
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
v, err := creds.Get()
if err == nil {
t.Fatal("Expected error returned when getting creds w/ invalid EC2 endpoint")
Expand All @@ -380,7 +396,10 @@ func TestAWSGetCredentials_shouldIgnoreInvalidEndpoint(t *testing.T) {
ts := invalidAwsEnv(t)
defer ts()

creds := GetCredentials(&Config{AccessKey: "accessKey", SecretKey: "secretKey"})
creds, err := GetCredentials(&Config{AccessKey: "accessKey", SecretKey: "secretKey"})
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
v, err := creds.Get()
if err != nil {
t.Fatalf("Getting static credentials w/ invalid EC2 endpoint failed: %s", err)
Expand All @@ -406,10 +425,13 @@ func TestAWSGetCredentials_shouldCatchEC2RoleProvider(t *testing.T) {
ts := awsEnv(t)
defer ts()

creds := GetCredentials(&Config{})
creds, err := GetCredentials(&Config{})
if creds == nil {
t.Fatalf("Expected an EC2Role creds provider to be returned")
}
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
v, err := creds.Get()
if err != nil {
t.Fatalf("Expected no error when getting creds: %s", err)
Expand Down Expand Up @@ -452,10 +474,13 @@ func TestAWSGetCredentials_shouldBeShared(t *testing.T) {
t.Fatalf("Error resetting env var AWS_SHARED_CREDENTIALS_FILE: %s", err)
}

creds := GetCredentials(&Config{Profile: "myprofile", CredsFilename: file.Name()})
creds, err := GetCredentials(&Config{Profile: "myprofile", CredsFilename: file.Name()})
if creds == nil {
t.Fatalf("Expected a provider chain to be returned")
}
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
v, err := creds.Get()
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
Expand All @@ -479,10 +504,13 @@ func TestAWSGetCredentials_shouldBeENV(t *testing.T) {
defer resetEnv()

cfg := Config{}
creds := GetCredentials(&cfg)
creds, err := GetCredentials(&cfg)
if creds == nil {
t.Fatalf("Expected a static creds provider to be returned")
}
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
v, err := creds.Get()
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
Expand Down
6 changes: 5 additions & 1 deletion builtin/providers/aws/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ type Config struct {
Profile string
Token string
Region string
RoleArn string
MaxRetries int

AllowedAccountIds []interface{}
Expand Down Expand Up @@ -150,7 +151,10 @@ func (c *Config) Client() (interface{}, error) {
client.region = c.Region

log.Println("[INFO] Building AWS auth structure")
creds := GetCredentials(c)
creds, err := GetCredentials(c)
if err != nil {
return nil, &multierror.Error{Errors: errs}
}
// Call Get to check for credential provider. If nothing found, we'll get an
// error, and we can present it nicely to the user
cp, err := creds.Get()
Expand Down
10 changes: 10 additions & 0 deletions builtin/providers/aws/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ func Provider() terraform.ResourceProvider {
InputDefault: "us-east-1",
},

"role_arn": &schema.Schema{
Type: schema.TypeString,
Optional: true,
Default: "",
Description: descriptions["role_arn"],
},

"max_retries": &schema.Schema{
Type: schema.TypeInt,
Optional: true,
Expand Down Expand Up @@ -351,6 +358,8 @@ func init() {
"profile": "The profile for API operations. If not set, the default profile\n" +
"created with `aws configure` will be used.",

"role_arn": "The role to be assumed using the supplied access_key and secret_key",

"shared_credentials_file": "The path to the shared credentials file. If not set\n" +
"this defaults to ~/.aws/credentials.",

Expand Down Expand Up @@ -402,6 +411,7 @@ func providerConfigure(d *schema.ResourceData) (interface{}, error) {
CredsFilename: d.Get("shared_credentials_file").(string),
Token: d.Get("token").(string),
Region: d.Get("region").(string),
RoleArn: d.Get("role_arn").(string),
MaxRetries: d.Get("max_retries").(int),
DynamoDBEndpoint: d.Get("dynamodb_endpoint").(string),
KinesisEndpoint: d.Get("kinesis_endpoint").(string),
Expand Down
4 changes: 2 additions & 2 deletions state/remote/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func s3Factory(conf map[string]string) (Client, error) {
kmsKeyID := conf["kms_key_id"]

var errs []error
creds := terraformAws.GetCredentials(&terraformAws.Config{
creds, err := terraformAws.GetCredentials(&terraformAws.Config{
AccessKey: conf["access_key"],
SecretKey: conf["secret_key"],
Token: conf["token"],
Expand All @@ -69,7 +69,7 @@ func s3Factory(conf map[string]string) (Client, error) {
})
// Call Get to check for credential provider. If nothing found, we'll get an
// error, and we can present it nicely to the user
_, err := creds.Get()
_, err = creds.Get()
if err != nil {
if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == "NoCredentialProviders" {
errs = append(errs, fmt.Errorf(`No valid credential sources found for AWS S3 remote.
Expand Down
13 changes: 13 additions & 0 deletions website/source/docs/providers/aws/index.html.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,19 @@ You can provide custom metadata API endpoint via `AWS_METADATA_ENDPOINT` variabl
which expects the endpoint URL including the version
and defaults to `http://169.254.169.254:80/latest`.

###Assume role

If provided with a role arn, terraform will attempt to assume this role
using the supplied credentials.

Usage:

```
provider "aws" {
role_arn = "arn:aws:iam::ACCOUNT_ID:role/ROLE_NAME"
}
```

## Argument Reference

The following arguments are supported in the `provider` block:
Expand Down