diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index 8a1927a39ca..37916bb84d8 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -3,3 +3,6 @@ ### SDK Enhancements ### SDK Bugs +* `aws/ec2metadata`: Add support for EC2 IMDS endpoint from environment variable ([#3504](https://github.com/aws/aws-sdk-go/pull/3504)) + * Adds support for specifying a custom EC2 IMDS endpoint from the environment variable, `AWS_EC2_METADATA_SERVICE_ENDPOINT`. + * The `aws/session#Options` struct also has a new field, `EC2IMDSEndpoint`. This field can be used to configure the custom endpoint of the EC2 IMDS client. The option only applies to EC2 IMDS clients created after the Session with `EC2IMDSEndpoint` is specified. diff --git a/aws/defaults/defaults_test.go b/aws/defaults/defaults_test.go index a27cf9b9ef5..2ab6ec24607 100644 --- a/aws/defaults/defaults_test.go +++ b/aws/defaults/defaults_test.go @@ -120,7 +120,7 @@ func TestDefaultEC2RoleProvider(t *testing.T) { if ec2Provider == nil { t.Fatalf("expect provider not to be nil, but was") } - if e, a := "http://169.254.169.254/latest", ec2Provider.Client.Endpoint; e != a { + if e, a := "http://169.254.169.254", ec2Provider.Client.Endpoint; e != a { t.Errorf("expect %q endpoint, got %q", e, a) } } diff --git a/aws/ec2metadata/api.go b/aws/ec2metadata/api.go index a716c021cf3..69fa63dc08f 100644 --- a/aws/ec2metadata/api.go +++ b/aws/ec2metadata/api.go @@ -20,7 +20,7 @@ func (c *EC2Metadata) getToken(ctx aws.Context, duration time.Duration) (tokenOu op := &request.Operation{ Name: "GetToken", HTTPMethod: "PUT", - HTTPPath: "/api/token", + HTTPPath: "/latest/api/token", } var output tokenOutput @@ -62,7 +62,7 @@ func (c *EC2Metadata) GetMetadataWithContext(ctx aws.Context, p string) (string, op := &request.Operation{ Name: "GetMetadata", HTTPMethod: "GET", - HTTPPath: sdkuri.PathJoin("/meta-data", p), + HTTPPath: sdkuri.PathJoin("/latest/meta-data", p), } output := &metadataOutput{} @@ -88,7 +88,7 @@ func (c *EC2Metadata) GetUserDataWithContext(ctx aws.Context) (string, error) { op := &request.Operation{ Name: "GetUserData", HTTPMethod: "GET", - HTTPPath: "/user-data", + HTTPPath: "/latest/user-data", } output := &metadataOutput{} @@ -113,7 +113,7 @@ func (c *EC2Metadata) GetDynamicDataWithContext(ctx aws.Context, p string) (stri op := &request.Operation{ Name: "GetDynamicData", HTTPMethod: "GET", - HTTPPath: sdkuri.PathJoin("/dynamic", p), + HTTPPath: sdkuri.PathJoin("/latest/dynamic", p), } output := &metadataOutput{} diff --git a/aws/ec2metadata/api_test.go b/aws/ec2metadata/api_test.go index 9139cc424f6..8e81c6dd33f 100644 --- a/aws/ec2metadata/api_test.go +++ b/aws/ec2metadata/api_test.go @@ -22,6 +22,7 @@ import ( "github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/awstesting/unit" + "github.com/aws/aws-sdk-go/internal/sdktesting" ) const instanceIdentityDocument = `{ @@ -106,22 +107,22 @@ func newTestServer(t *testing.T, testType testType, testServer *testServer) *htt switch testType { case SecureTestType: mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.secureGetTokenHandler)) - mux.HandleFunc("/latest/", testServer.secureGetLatestHandler) + mux.HandleFunc("/", testServer.secureGetLatestHandler) case InsecureTestType: mux.HandleFunc("/latest/api/token", testServer.insecureGetTokenHandler) - mux.HandleFunc("/latest/", testServer.insecureGetLatestHandler) + mux.HandleFunc("/", testServer.insecureGetLatestHandler) case BadRequestTestType: mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.badRequestGetTokenHandler)) - mux.HandleFunc("/latest/", testServer.badRequestGetLatestHandler) + mux.HandleFunc("/", testServer.badRequestGetLatestHandler) case ServerErrorForTokenTestType: mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.serverErrorGetTokenHandler)) - mux.HandleFunc("/latest/", testServer.insecureGetLatestHandler) + mux.HandleFunc("/", testServer.insecureGetLatestHandler) case pageNotFoundForTokenTestType: mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.pageNotFoundGetTokenHandler)) - mux.HandleFunc("/latest/", testServer.insecureGetLatestHandler) + mux.HandleFunc("/", testServer.insecureGetLatestHandler) case pageNotFoundWith401TestType: mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.pageNotFoundGetTokenHandler)) - mux.HandleFunc("/latest/", testServer.unauthorizedGetLatestHandler) + mux.HandleFunc("/", testServer.unauthorizedGetLatestHandler) } @@ -204,17 +205,17 @@ func (opListProvider *operationListProvider) addToOperationPerformedList(r *requ } func TestEndpoint(t *testing.T) { + restoreEnvFn := sdktesting.StashEnv() + defer restoreEnvFn() + c := ec2metadata.New(unit.Session) op := &request.Operation{ Name: "GetMetadata", HTTPMethod: "GET", - HTTPPath: path.Join("/", "meta-data", "testpath"), + HTTPPath: path.Join("/latest", "meta-data", "testpath"), } req := c.NewRequest(op, nil, nil) - if e, a := "http://169.254.169.254/latest", req.ClientInfo.Endpoint; e != a { - t.Errorf("expect %v, got %v", e, a) - } if e, a := "http://169.254.169.254/latest/meta-data/testpath", req.HTTPRequest.URL.String(); e != a { t.Errorf("expect %v, got %v", e, a) } @@ -289,7 +290,9 @@ func TestGetMetadata(t *testing.T) { op := &operationListProvider{} - c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}) + c := ec2metadata.New(unit.Session, &aws.Config{ + Endpoint: aws.String(server.URL), + }) c.Handlers.Complete.PushBack(op.addToOperationPerformedList) resp, err := c.GetMetadata("some/path") @@ -340,7 +343,9 @@ func TestGetUserData_Error(t *testing.T) { })) defer server.Close() - c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}) + c := ec2metadata.New(unit.Session, &aws.Config{ + Endpoint: aws.String(server.URL), + }) resp, err := c.GetUserData() if err == nil { @@ -425,7 +430,9 @@ func TestGetRegion(t *testing.T) { op := &operationListProvider{} - c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}) + c := ec2metadata.New(unit.Session, &aws.Config{ + Endpoint: aws.String(server.URL), + }) c.Handlers.Complete.PushBack(op.addToOperationPerformedList) resp, err := c.Region() @@ -494,7 +501,9 @@ func TestMetadataIAMInfo_success(t *testing.T) { op := &operationListProvider{} - c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}) + c := ec2metadata.New(unit.Session, &aws.Config{ + Endpoint: aws.String(server.URL), + }) c.Handlers.Complete.PushBack(op.addToOperationPerformedList) iamInfo, err := c.IAMInfo() @@ -570,7 +579,9 @@ func TestMetadataIAMInfo_failure(t *testing.T) { op := &operationListProvider{} - c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}) + c := ec2metadata.New(unit.Session, &aws.Config{ + Endpoint: aws.String(server.URL), + }) c.Handlers.Complete.PushBack(op.addToOperationPerformedList) iamInfo, err := c.IAMInfo() @@ -675,7 +686,9 @@ func TestEC2RoleProviderInstanceIdentity(t *testing.T) { op := &operationListProvider{} - c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}) + c := ec2metadata.New(unit.Session, &aws.Config{ + Endpoint: aws.String(server.URL), + }) c.Handlers.Complete.PushBack(op.addToOperationPerformedList) doc, err := c.GetInstanceIdentityDocument() @@ -719,7 +732,9 @@ func TestEC2MetadataRetryFailure(t *testing.T) { server := httptest.NewServer(mux) defer server.Close() - c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}) + c := ec2metadata.New(unit.Session, &aws.Config{ + Endpoint: aws.String(server.URL), + }) c.Handlers.AfterRetry.PushBack(func(i *request.Request) { t.Logf("%v received, retrying operation %v", i.HTTPResponse.StatusCode, i.Operation.Name) @@ -774,7 +789,9 @@ func TestEC2MetadataRetryOnce(t *testing.T) { server := httptest.NewServer(mux) defer server.Close() - c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}) + c := ec2metadata.New(unit.Session, &aws.Config{ + Endpoint: aws.String(server.URL), + }) // Handler on client that logs if retried c.Handlers.AfterRetry.PushBack(func(i *request.Request) { @@ -807,7 +824,9 @@ func TestEC2Metadata_Concurrency(t *testing.T) { server := newTestServer(t, SecureTestType, ts) defer server.Close() - c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}) + c := ec2metadata.New(unit.Session, &aws.Config{ + Endpoint: aws.String(server.URL), + }) var wg sync.WaitGroup wg.Add(10) @@ -838,11 +857,13 @@ func TestRequestOnMetadata(t *testing.T) { server := newTestServer(t, SecureTestType, ts) defer server.Close() - c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}) + c := ec2metadata.New(unit.Session, &aws.Config{ + Endpoint: aws.String(server.URL), + }) req := c.NewRequest(&request.Operation{ Name: "Ec2Metadata request", HTTPMethod: "GET", - HTTPPath: "/latest", + HTTPPath: "/latest/foo", Paginator: nil, BeforePresignFn: nil, }, nil, nil) @@ -878,7 +899,9 @@ func TestExhaustiveRetryToFetchToken(t *testing.T) { op := &operationListProvider{} - c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}) + c := ec2metadata.New(unit.Session, &aws.Config{ + Endpoint: aws.String(server.URL), + }) c.Handlers.Complete.PushBack(op.addToOperationPerformedList) resp, err := c.GetMetadata("/some/path") @@ -930,7 +953,9 @@ func TestExhaustiveRetryWith401(t *testing.T) { op := &operationListProvider{} - c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}) + c := ec2metadata.New(unit.Session, &aws.Config{ + Endpoint: aws.String(server.URL), + }) c.Handlers.Complete.PushBack(op.addToOperationPerformedList) resp, err := c.GetMetadata("/some/path") @@ -991,7 +1016,9 @@ func TestRequestTimeOut(t *testing.T) { op := &operationListProvider{} - c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}) + c := ec2metadata.New(unit.Session, &aws.Config{ + Endpoint: aws.String(server.URL), + }) // for test, change the timeout to 100 ms c.Config.HTTPClient.Timeout = 100 * time.Millisecond @@ -1068,7 +1095,9 @@ func TestTokenExpiredBehavior(t *testing.T) { op := &operationListProvider{} - c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")}) + c := ec2metadata.New(unit.Session, &aws.Config{ + Endpoint: aws.String(server.URL), + }) c.Handlers.Complete.PushBack(op.addToOperationPerformedList) resp, err := c.GetMetadata("/some/path") diff --git a/aws/ec2metadata/service.go b/aws/ec2metadata/service.go index dc7e051e0c0..8f35b3464ba 100644 --- a/aws/ec2metadata/service.go +++ b/aws/ec2metadata/service.go @@ -5,6 +5,10 @@ // variable "AWS_EC2_METADATA_DISABLED=true". This environment variable set to // true instructs the SDK to disable the EC2 Metadata client. The client cannot // be used while the environment variable is set to true, (case insensitive). +// +// The endpoint of the EC2 IMDS client can be configured via the environment +// variable, AWS_EC2_METADATA_SERVICE_ENDPOINT when creating the client with a +// Session. See aws/session#Options.EC2IMDSEndpoint for more details. package ec2metadata import ( @@ -12,6 +16,7 @@ import ( "errors" "io" "net/http" + "net/url" "os" "strconv" "strings" @@ -69,6 +74,9 @@ func New(p client.ConfigProvider, cfgs ...*aws.Config) *EC2Metadata { // a client when not using a session. Generally using just New with a session // is preferred. // +// Will remove the URL path from the endpoint provided to ensure the EC2 IMDS +// client is able to communicate with the EC2 IMDS API. +// // If an unmodified HTTP client is provided from the stdlib default, or no client // the EC2RoleProvider's EC2Metadata HTTP client's timeout will be shortened. // To disable this set Config.EC2MetadataDisableTimeoutOverride to false. Enabled by default. @@ -86,6 +94,15 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio cfg.MaxRetries = aws.Int(2) } + if u, err := url.Parse(endpoint); err == nil { + // Remove path from the endpoint since it will be added by requests. + // This is an artifact of the SDK adding `/latest` to the endpoint for + // EC2 IMDS, but this is now moved to the operation definition. + u.Path = "" + u.RawPath = "" + endpoint = u.String() + } + svc := &EC2Metadata{ Client: client.New( cfg, diff --git a/aws/ec2metadata/service_test.go b/aws/ec2metadata/service_test.go index bd94856fd5a..fea0d666c2c 100644 --- a/aws/ec2metadata/service_test.go +++ b/aws/ec2metadata/service_test.go @@ -1,3 +1,5 @@ +// +build go1.7 + package ec2metadata_test import ( @@ -89,9 +91,7 @@ func TestClientDisableIMDS(t *testing.T) { os.Setenv("AWS_EC2_METADATA_DISABLED", "true") - svc := ec2metadata.New(unit.Session, &aws.Config{ - LogLevel: aws.LogLevel(aws.LogDebugWithHTTPBody), - }) + svc := ec2metadata.New(unit.Session) resp, err := svc.GetUserData() if err == nil { t.Fatalf("expect error, got none") @@ -109,6 +109,37 @@ func TestClientDisableIMDS(t *testing.T) { } } +func TestClientStripPath(t *testing.T) { + cases := map[string]struct { + Endpoint string + Expect string + }{ + "no change": { + Endpoint: "http://example.aws", + Expect: "http://example.aws", + }, + "strip path": { + Endpoint: "http://example.aws/foo", + Expect: "http://example.aws", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + restoreEnvFn := sdktesting.StashEnv() + defer restoreEnvFn() + + svc := ec2metadata.New(unit.Session, &aws.Config{ + Endpoint: aws.String(c.Endpoint), + }) + + if e, a := c.Expect, svc.ClientInfo.Endpoint; e != a { + t.Errorf("expect %v endpoint, got %v", e, a) + } + }) + } +} + func runEC2MetadataClients(t *testing.T, cfg *aws.Config, atOnce int) { var wg sync.WaitGroup wg.Add(atOnce) diff --git a/aws/session/credentials_test.go b/aws/session/credentials_test.go index 703d9fa2cfc..e12e03126dc 100644 --- a/aws/session/credentials_test.go +++ b/aws/session/credentials_test.go @@ -24,6 +24,26 @@ import ( "github.com/aws/aws-sdk-go/service/sts" ) +func newEc2MetadataServer(key, secret string, closeAfterGetCreds bool) *httptest.Server { + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/latest/meta-data/iam/security-credentials/RoleName" { + w.Write([]byte(fmt.Sprintf(ec2MetadataResponse, key, secret))) + + if closeAfterGetCreds { + go server.Close() + } + } else if r.URL.Path == "/latest/meta-data/iam/security-credentials/" { + w.Write([]byte("RoleName")) + } else { + w.Write([]byte("")) + } + })) + + return server +} + func setupCredentialsEndpoints(t *testing.T) (endpoints.Resolver, func()) { origECSEndpoint := shareddefaults.ECSContainerCredentialsURI @@ -37,16 +57,7 @@ func setupCredentialsEndpoints(t *testing.T) (endpoints.Resolver, func()) { })) shareddefaults.ECSContainerCredentialsURI = ecsMetadataServer.URL - ec2MetadataServer := httptest.NewServer(http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/meta-data/iam/security-credentials/RoleName" { - w.Write([]byte(ec2MetadataResponse)) - } else if r.URL.Path == "/meta-data/iam/security-credentials/" { - w.Write([]byte("RoleName")) - } else { - w.Write([]byte("")) - } - })) + ec2MetadataServer := newEc2MetadataServer("ec2_key", "ec2_secret", false) stsServer := httptest.NewServer(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { @@ -87,15 +98,16 @@ func TestSharedConfigCredentialSource(t *testing.T) { const configFile = "testdata/credential_source_config" cases := []struct { - name string - profile string - sessOptProfile string - expectedError error - expectedAccessKey string - expectedSecretKey string - expectedChain []string - init func() - dependentOnOS bool + name string + profile string + sessOptProfile string + sessOptEC2IMDSEndpoint string + expectedError error + expectedAccessKey string + expectedSecretKey string + expectedChain []string + init func() + dependentOnOS bool }{ { name: "credential source and source profile", @@ -128,6 +140,16 @@ func TestSharedConfigCredentialSource(t *testing.T) { expectedAccessKey: "AKID", expectedSecretKey: "SECRET", }, + { + name: "ec2metadata custom EC2 IMDS endpoint, env var", + profile: "not-exists-profile", + expectedAccessKey: "ec2_custom_key", + expectedSecretKey: "ec2_custom_secret", + init: func() { + altServer := newEc2MetadataServer("ec2_custom_key", "ec2_custom_secret", true) + os.Setenv("AWS_EC2_METADATA_SERVICE_ENDPOINT", altServer.URL) + }, + }, { name: "ecs container credential source", profile: "ecscontainer", @@ -219,7 +241,8 @@ func TestSharedConfigCredentialSource(t *testing.T) { Logger: t, EndpointResolver: endpointResolver, }, - Handlers: handlers, + Handlers: handlers, + EC2IMDSEndpoint: c.sessOptEC2IMDSEndpoint, }) if e, a := c.expectedError, err; e != a { t.Fatalf("expected %v, but received %v", e, a) @@ -262,8 +285,8 @@ const ecsResponse = `{ const ec2MetadataResponse = `{ "Code": "Success", "Type": "AWS-HMAC", - "AccessKeyId" : "ec2-access-key", - "SecretAccessKey" : "ec2-secret-key", + "AccessKeyId" : "%s", + "SecretAccessKey" : "%s", "Token" : "token", "Expiration" : "2100-01-01T00:00:00Z", "LastUpdated" : "2009-11-23T0:00:00Z" diff --git a/aws/session/doc.go b/aws/session/doc.go index 7ec66e7e589..cc461bd3230 100644 --- a/aws/session/doc.go +++ b/aws/session/doc.go @@ -241,5 +241,22 @@ over the AWS_CA_BUNDLE environment variable, and will be used if both are set. Setting a custom HTTPClient in the aws.Config options will override this setting. To use this option and custom HTTP client, the HTTP client needs to be provided when creating the session. Not the service client. + +The endpoint of the EC2 IMDS client can be configured via the environment +variable, AWS_EC2_METADATA_SERVICE_ENDPOINT when creating the client with a +Session. See Options.EC2IMDSEndpoint for more details. + + AWS_EC2_METADATA_SERVICE_ENDPOINT=http://169.254.169.254 + +If using an URL with an IPv6 address literal, the IPv6 address +component must be enclosed in square brackets. + + AWS_EC2_METADATA_SERVICE_ENDPOINT=http://[::1] + +The custom EC2 IMDS endpoint can also be specified via the Session options. + + sess, err := session.NewSessionWithOptions(session.Options{ + EC2IMDSEndpoint: "http://[::1]", + }) */ package session diff --git a/aws/session/env_config.go b/aws/session/env_config.go index c1e0e9c9543..d67c261d74f 100644 --- a/aws/session/env_config.go +++ b/aws/session/env_config.go @@ -148,6 +148,11 @@ type envConfig struct { // // AWS_S3_USE_ARN_REGION=true S3UseARNRegion bool + + // Specifies the alternative endpoint to use for EC2 IMDS. + // + // AWS_EC2_METADATA_SERVICE_ENDPOINT=http://[::1] + EC2IMDSEndpoint string } var ( @@ -211,6 +216,9 @@ var ( s3UseARNRegionEnvKey = []string{ "AWS_S3_USE_ARN_REGION", } + ec2IMDSEndpointEnvKey = []string{ + "AWS_EC2_METADATA_SERVICE_ENDPOINT", + } ) // loadEnvConfig retrieves the SDK's environment configuration. @@ -332,6 +340,8 @@ func envConfigLoad(enableSharedConfig bool) (envConfig, error) { } } + setFromEnvVal(&cfg.EC2IMDSEndpoint, ec2IMDSEndpointEnvKey) + return cfg, nil } diff --git a/aws/session/env_config_test.go b/aws/session/env_config_test.go index e106dbf0784..ebc294ad8a0 100644 --- a/aws/session/env_config_test.go +++ b/aws/session/env_config_test.go @@ -302,6 +302,16 @@ func TestLoadEnvConfig(t *testing.T) { SharedConfigFile: shareddefaults.SharedConfigFilename(), }, }, + { + Env: map[string]string{ + "AWS_EC2_METADATA_SERVICE_ENDPOINT": "http://example.aws", + }, + Config: envConfig{ + EC2IMDSEndpoint: "http://example.aws", + SharedCredentialsFile: shareddefaults.SharedCredentialsFilename(), + SharedConfigFile: shareddefaults.SharedConfigFilename(), + }, + }, } for i, c := range cases { diff --git a/aws/session/session.go b/aws/session/session.go index 0ff49960510..6430a7f1526 100644 --- a/aws/session/session.go +++ b/aws/session/session.go @@ -48,6 +48,8 @@ var ErrSharedConfigInvalidCredSource = awserr.New(ErrCodeSharedConfig, "credenti type Session struct { Config *aws.Config Handlers request.Handlers + + options Options } // New creates a new instance of the handlers merging in the provided configs @@ -99,7 +101,7 @@ func New(cfgs ...*aws.Config) *Session { return s } - s := deprecatedNewSession(cfgs...) + s := deprecatedNewSession(envCfg, cfgs...) if envErr != nil { msg := "failed to load env config" s.logDeprecatedNewSessionError(msg, envErr, cfgs) @@ -243,6 +245,23 @@ type Options struct { // function to initialize this value before changing the handlers to be // used by the SDK. Handlers request.Handlers + + // Allows specifying a custom endpoint to be used by the EC2 IMDS client + // when making requests to the EC2 IMDS API. The must endpoint value must + // include protocol prefix. + // + // If unset, will the EC2 IMDS client will use its default endpoint. + // + // Can also be specified via the environment variable, + // AWS_EC2_METADATA_SERVICE_ENDPOINT. + // + // AWS_EC2_METADATA_SERVICE_ENDPOINT=http://169.254.169.254 + // + // If using an URL with an IPv6 address literal, the IPv6 address + // component must be enclosed in square brackets. + // + // AWS_EC2_METADATA_SERVICE_ENDPOINT=http://[::1] + EC2IMDSEndpoint string } // NewSessionWithOptions returns a new Session created from SDK defaults, config files, @@ -329,7 +348,25 @@ func Must(sess *Session, err error) *Session { return sess } -func deprecatedNewSession(cfgs ...*aws.Config) *Session { +// Wraps the endpoint resolver with a resolver that will return a custom +// endpoint for EC2 IMDS. +func wrapEC2IMDSEndpoint(resolver endpoints.Resolver, endpoint string) endpoints.Resolver { + return endpoints.ResolverFunc( + func(service, region string, opts ...func(*endpoints.Options)) ( + endpoints.ResolvedEndpoint, error, + ) { + if service == ec2MetadataServiceID { + return endpoints.ResolvedEndpoint{ + URL: endpoint, + SigningName: ec2MetadataServiceID, + SigningRegion: region, + }, nil + } + return resolver.EndpointFor(service, region) + }) +} + +func deprecatedNewSession(envCfg envConfig, cfgs ...*aws.Config) *Session { cfg := defaults.Config() handlers := defaults.Handlers() @@ -341,6 +378,11 @@ func deprecatedNewSession(cfgs ...*aws.Config) *Session { // endpoints for service client configurations. cfg.EndpointResolver = endpoints.DefaultResolver() } + + if len(envCfg.EC2IMDSEndpoint) != 0 { + cfg.EndpointResolver = wrapEC2IMDSEndpoint(cfg.EndpointResolver, envCfg.EC2IMDSEndpoint) + } + cfg.Credentials = defaults.CredChain(cfg, handlers) // Reapply any passed in configs to override credentials if set @@ -349,6 +391,9 @@ func deprecatedNewSession(cfgs ...*aws.Config) *Session { s := &Session{ Config: cfg, Handlers: handlers, + options: Options{ + EC2IMDSEndpoint: envCfg.EC2IMDSEndpoint, + }, } initHandlers(s) @@ -418,6 +463,7 @@ func newSession(opts Options, envCfg envConfig, cfgs ...*aws.Config) (*Session, s := &Session{ Config: cfg, Handlers: handlers, + options: opts, } initHandlers(s) @@ -570,6 +616,14 @@ func mergeConfigSrcs(cfg, userCfg *aws.Config, endpoints.LegacyS3UsEast1Endpoint, }) + ec2IMDSEndpoint := sessOpts.EC2IMDSEndpoint + if len(ec2IMDSEndpoint) == 0 { + ec2IMDSEndpoint = envCfg.EC2IMDSEndpoint + } + if len(ec2IMDSEndpoint) != 0 { + cfg.EndpointResolver = wrapEC2IMDSEndpoint(cfg.EndpointResolver, ec2IMDSEndpoint) + } + // Configure credentials if not already set by the user when creating the // Session. if cfg.Credentials == credentials.AnonymousCredentials && userCfg.Credentials == nil { @@ -627,6 +681,7 @@ func (s *Session) Copy(cfgs ...*aws.Config) *Session { newSession := &Session{ Config: s.Config.Copy(cfgs...), Handlers: s.Handlers.Copy(), + options: s.options, } initHandlers(newSession) @@ -665,6 +720,8 @@ func (s *Session) ClientConfig(service string, cfgs ...*aws.Config) client.Confi } } +const ec2MetadataServiceID = "ec2metadata" + func (s *Session) resolveEndpoint(service, region string, cfg *aws.Config) (endpoints.ResolvedEndpoint, error) { if ep := aws.StringValue(cfg.Endpoint); len(ep) != 0 { diff --git a/aws/session/session_test.go b/aws/session/session_test.go index e03e985b0b3..fd1df282099 100644 --- a/aws/session/session_test.go +++ b/aws/session/session_test.go @@ -663,3 +663,65 @@ func TestSession_RegionalEndpoints(t *testing.T) { }) } } + +func TestSession_ClientConfig_ResolveEndpoint(t *testing.T) { + cases := map[string]struct { + Service string + Region string + Env map[string]string + Options Options + ExpectEndpoint string + }{ + "IMDS custom endpoint from env": { + Service: ec2MetadataServiceID, + Region: "ignored", + Env: map[string]string{ + "AWS_EC2_METADATA_SERVICE_ENDPOINT": "http://example.aws", + }, + ExpectEndpoint: "http://example.aws", + }, + "IMDS custom endpoint from aws.Config": { + Service: ec2MetadataServiceID, + Region: "ignored", + Options: Options{ + EC2IMDSEndpoint: "http://example.aws", + }, + ExpectEndpoint: "http://example.aws", + }, + "IMDS custom endpoint from aws.Config and env": { + Service: ec2MetadataServiceID, + Region: "ignored", + Env: map[string]string{ + "AWS_EC2_METADATA_SERVICE_ENDPOINT": "http://wrong.example.aws", + }, + Options: Options{ + EC2IMDSEndpoint: "http://correct.example.aws", + }, + ExpectEndpoint: "http://correct.example.aws", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + restoreEnvFn := initSessionTestEnv() + defer restoreEnvFn() + + for k, v := range c.Env { + os.Setenv(k, v) + } + + s, err := NewSessionWithOptions(c.Options) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + clientCfg := s.ClientConfig(c.Service, &aws.Config{ + Region: aws.String(c.Region), + }) + + if e, a := c.ExpectEndpoint, clientCfg.Endpoint; e != a { + t.Errorf("expect %v, got %v", e, a) + } + }) + } +}