Skip to content

Commit

Permalink
credentials: support AWS_CONTAINER_CREDENTIALS_FULL_URI (#1185)
Browse files Browse the repository at this point in the history
This environment variable allows the definition of a full URI to 
an ECS task endpoint on a loop-back address.

It is documented sparsely compared to 
AWS_CONTAINER_CREDENTIALS_RELATIVE_URI but is 
implemented in every major AWS SDK.

See:
- https://github.com/aws/aws-sdk-go/blob/master/aws/defaults/defaults.go#L117
- https://docs.aws.amazon.com/AWSJavaScriptSDK/latest/AWS/ECSCredentials.html
- 
https://docs.aws.amazon.com/AWSJavaSDK/latest/javadoc/com/amazonaws/auth/EC2ContainerCredentialsProviderWrapper.html
  • Loading branch information
ribbybibby authored and nitisht committed Nov 25, 2019
1 parent 3721161 commit d18cd1c
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 6 deletions.
54 changes: 48 additions & 6 deletions pkg/credentials/iam_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
Expand Down Expand Up @@ -57,14 +58,35 @@ const (
)

// https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-iam-roles.html
func getEndpoint(endpoint string) (string, bool) {
func getEndpoint(endpoint string) (string, bool, error) {
ecsFullURI := os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI")
ecsURI := os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI")

if endpoint != "" {
return endpoint, os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != ""
return endpoint, ecsURI != "" || ecsFullURI != "", nil
}
if ecsFullURI != "" {
u, err := url.Parse(ecsFullURI)
if err != nil {
return "", false, err
}
host := u.Hostname()
if host == "" {
return "", false, fmt.Errorf("can't parse host from uri: %s", ecsFullURI)
}

if loopback, err := isLoopback(host); loopback {
return ecsFullURI, true, nil
} else if err != nil {
return "", false, err
} else {
return "", false, fmt.Errorf("host is not on a loopback address: %s", host)
}
}
if ecsURI := os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"); ecsURI != "" {
return fmt.Sprintf("%s%s", defaultECSRoleEndpoint, ecsURI), true
if ecsURI != "" {
return fmt.Sprintf("%s%s", defaultECSRoleEndpoint, ecsURI), true, nil
}
return defaultIAMRoleEndpoint, false
return defaultIAMRoleEndpoint, false, nil
}

// NewIAM returns a pointer to a new Credentials object wrapping the IAM.
Expand All @@ -82,9 +104,14 @@ func NewIAM(endpoint string) *Credentials {
// Error will be returned if the request fails, or unable to extract
// the desired
func (m *IAM) Retrieve() (Value, error) {
endpoint, isEcsTask := getEndpoint(m.endpoint)
var roleCreds ec2RoleCredRespBody
var err error

endpoint, isEcsTask, err := getEndpoint(m.endpoint)
if err != nil {
return Value{}, err
}

if isEcsTask {
roleCreds, err = getEcsTaskCredentials(m.Client, endpoint)
} else {
Expand Down Expand Up @@ -248,3 +275,18 @@ func getCredentials(client *http.Client, endpoint string) (ec2RoleCredRespBody,

return respCreds, nil
}

// isLoopback identifies if a host is on a loopback address
func isLoopback(host string) (bool, error) {
ips, err := net.LookupHost(host)
if err != nil {
return false, err
}
for _, ip := range ips {
if !net.ParseIP(ip).IsLoopback() {
return false, nil
}
}

return true, nil
}
30 changes: 30 additions & 0 deletions pkg/credentials/iam_aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,33 @@ func TestEcsTask(t *testing.T) {
t.Error("Expected creds to be expired.")
}
}

func TestEcsTaskFullURI(t *testing.T) {
server := initEcsTaskTestServer("2014-12-16T01:51:37Z")
defer server.Close()
p := &IAM{
Client: http.DefaultClient,
}
os.Setenv("AWS_CONTAINER_CREDENTIALS_FULL_URI",
fmt.Sprintf("%s%s", server.URL, "/v2/credentials?id=task_credential_id"))
creds, err := p.Retrieve()
os.Unsetenv("AWS_CONTAINER_CREDENTIALS_FULL_URI")
if err != nil {
t.Errorf("Unexpected failure %s", err)
}
if "accessKey" != creds.AccessKeyID {
t.Errorf("Expected \"accessKey\", got %s", creds.AccessKeyID)
}

if "secret" != creds.SecretAccessKey {
t.Errorf("Expected \"secret\", got %s", creds.SecretAccessKey)
}

if "token" != creds.SessionToken {
t.Errorf("Expected \"token\", got %s", creds.SessionToken)
}

if !p.IsExpired() {
t.Error("Expected creds to be expired.")
}
}

0 comments on commit d18cd1c

Please sign in to comment.