Skip to content

Commit

Permalink
Merge pull request #37851 from hashicorp/b-route53-fips
Browse files Browse the repository at this point in the history
Fixes region overrides when using custom endpoints
  • Loading branch information
gdavison authored Jun 7, 2024
2 parents 190d01e + 3b4a4fe commit 1dedb65
Show file tree
Hide file tree
Showing 239 changed files with 12,280 additions and 7,383 deletions.
27 changes: 27 additions & 0 deletions .changelog/37851.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
```release-note:bug
service/chatbot: Correctly overrides region when using custom endpoint.
```

```release-note:bug
service/costoptimizationhub: Correctly overrides region when using custom endpoint.
```

```release-note:bug
service/cur: Correctly overrides region when using custom endpoint.
```

```release-note:bug
service/globalaccelerator: Correctly overrides region when using custom endpoint.
```

```release-note:bug
service/route53: Correctly overrides region when using custom endpoint.
```

```release-note:bug
service/route53domains: Correctly overrides region when using custom endpoint.
```

```release-note:bug
service/shield: Correctly overrides region when using custom endpoint.
```
104 changes: 80 additions & 24 deletions internal/generate/serviceendpointtests/file.gtpl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
{{- if ne .GoV2Package "" }}
"errors"
"reflect"
{{- end }}
"fmt"
"maps"
Expand All @@ -14,29 +15,29 @@ import (
{{- end }}
"os"
"path/filepath"
"reflect"
"strings"
"testing"

{{ if ne .GoV1Package "" }}
{{ if .ImportAWS_V1 }}
aws_sdkv1 "github.com/aws/aws-sdk-go/aws"
{{ end -}}
{{ if eq .GoV2Package "" }}"github.com/aws/aws-sdk-go/aws/endpoints"{{ end }}
{{- if eq .GoV2Package "" }}
"github.com/aws/aws-sdk-go/aws/endpoints"
{{- end }}
{{ .GoV1Package }}_sdkv1 "github.com/aws/aws-sdk-go/service/{{ .GoV1Package }}"
{{- end }}
{{- if ne .V1AlternateInputPackage "" }}
{{ .V1AlternateInputPackage }}_sdkv1 "github.com/aws/aws-sdk-go/service/{{ .V1AlternateInputPackage }}"
{{- end -}}
{{- if ne .GoV2Package "" }}
aws_sdkv2 "github.com/aws/aws-sdk-go-v2/aws"
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
{{ .GoV2Package }}_sdkv2 "github.com/aws/aws-sdk-go-v2/service/{{ .GoV2Package }}"
{{- if .ImportAwsTypes }}
awstypes "github.com/aws/aws-sdk-go-v2/service/{{ .GoV2Package }}/types"
{{- end }}
{{- end }}
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
{{- end }}
"github.com/google/go-cmp/cmp"
"github.com/hashicorp/aws-sdk-go-base/v2/servicemocks"
{{- if gt (len .Aliases) 0 }}
Expand Down Expand Up @@ -70,11 +71,17 @@ type configFile struct {
type caseExpectations struct {
diags diag.Diagnostics
endpoint string
region string
}

type apiCallParams struct {
endpoint string
region string
}

type setupFunc func(setup *caseSetup)

type callFunc func(ctx context.Context, t *testing.T, meta *conns.AWSClient) string
type callFunc func(ctx context.Context, t *testing.T, meta *conns.AWSClient) apiCallParams

const (
packageNameConfigEndpoint = "https://packagename-config.endpoint.test/"
Expand Down Expand Up @@ -109,13 +116,24 @@ const (
{{ end }}
)

const (
expectedCallRegion = {{ if .OverrideRegion }}"{{ .OverrideRegion }}"{{ else }}"{{ .Region }}"{{ end }} //lintignore:AWSAT003
)

func TestEndpointConfiguration(t *testing.T) { //nolint:paralleltest // uses t.Setenv
const region = "{{ .Region }}" //lintignore:AWSAT003
const providerRegion = "{{ .Region }}" //lintignore:AWSAT003
{{ if .OverrideRegionRegionalEndpoint -}}
// {{ .HumanFriendly }} uses a regional endpoint but is only available in one region or a limited number of regions.
// The provider overrides the region for {{ .HumanFriendly }}, but the AWS SDK's endpoint resolution returns one for the current region.
const expectedEndpointRegion = "{{ .OverrideRegion }}" //lintignore:AWSAT003
{{ else -}}
const expectedEndpointRegion = providerRegion
{{ end }}

testcases := map[string]endpointTestCase{
"no config": {
with: []setupFunc{withNoConfig},
expected: expectDefaultEndpoint(region),
expected: expectDefaultEndpoint(expectedEndpointRegion),
},

// Package name endpoint on Config
Expand Down Expand Up @@ -456,7 +474,7 @@ func TestEndpointConfiguration(t *testing.T) { //nolint:paralleltest // uses t.S
with: []setupFunc{
withUseFIPSInConfig,
},
expected: expectDefaultFIPSEndpoint(region),
expected: expectDefaultFIPSEndpoint(expectedEndpointRegion),
},

"use fips config with package name endpoint config": {
Expand All @@ -474,7 +492,7 @@ func TestEndpointConfiguration(t *testing.T) { //nolint:paralleltest // uses t.S
testcase := testcase

t.Run(name, func(t *testing.T) {
testEndpointCase(t, region, testcase, callServiceV1)
testEndpointCase(t, providerRegion, testcase, callServiceV1)
})
}
})
Expand All @@ -484,7 +502,7 @@ func TestEndpointConfiguration(t *testing.T) { //nolint:paralleltest // uses t.S
testcase := testcase

t.Run(name, func(t *testing.T) {
testEndpointCase(t, region, testcase, callServiceV2)
testEndpointCase(t, providerRegion, testcase, callServiceV2)
})
}
})
Expand All @@ -493,7 +511,7 @@ func TestEndpointConfiguration(t *testing.T) { //nolint:paralleltest // uses t.S
testcase := testcase

t.Run(name, func(t *testing.T) {
testEndpointCase(t, region, testcase, callService)
testEndpointCase(t, providerRegion, testcase, callService)
})
}
{{ end -}}
Expand Down Expand Up @@ -577,19 +595,20 @@ func defaultFIPSEndpoint(region string) string {
}

{{ if ne .GoV2Package "" }}
func callService{{ if ne .GoV1Package "" }}V2{{ end }}(ctx context.Context, t *testing.T, meta *conns.AWSClient) string {
func callService{{ if ne .GoV1Package "" }}V2{{ end }}(ctx context.Context, t *testing.T, meta *conns.AWSClient) apiCallParams {
t.Helper()

var endpoint string

client := meta.{{ .ProviderNameUpper }}Client(ctx)

var result apiCallParams

_, err := client.{{ .APICall }}(ctx, &{{ .GoV2Package }}_sdkv2.{{ .APICall }}Input{
{{ if ne .APICallParams "" }}{{ .APICallParams }},{{ end }}
},
func(opts *{{ .GoV2Package }}_sdkv2.Options) {
opts.APIOptions = append(opts.APIOptions,
addRetrieveEndpointURLMiddleware(t, &endpoint),
addRetrieveEndpointURLMiddleware(t, &result.endpoint),
addRetrieveRegionMiddleware(&result.region),
addCancelRequestMiddleware(),
)
},
Expand All @@ -600,12 +619,12 @@ func callService{{ if ne .GoV1Package "" }}V2{{ end }}(ctx context.Context, t *t
t.Fatalf("Unexpected error: %s", err)
}

return endpoint
return result
}
{{ end }}

{{ if ne .GoV1Package "" }}
func callService{{ if ne .GoV2Package "" }}V1{{ end }}(ctx context.Context, t *testing.T, meta *conns.AWSClient) string {
func callService{{ if ne .GoV2Package "" }}V1{{ end }}(ctx context.Context, t *testing.T, meta *conns.AWSClient) apiCallParams {
t.Helper()

client := meta.{{ .ProviderNameUpper }}Conn(ctx)
Expand All @@ -619,9 +638,10 @@ func callService{{ if ne .GoV2Package "" }}V1{{ end }}(ctx context.Context, t *t

req.HTTPRequest.URL.Path = "/"

endpoint := req.HTTPRequest.URL.String()

return endpoint
return apiCallParams{
endpoint: req.HTTPRequest.URL.String(),
region: aws_sdkv1.StringValue(client.Config.Region),
}
}
{{ end }}

Expand Down Expand Up @@ -699,38 +719,44 @@ func withUseFIPSInConfig(setup *caseSetup) {
func expectDefaultEndpoint(region string) caseExpectations {
return caseExpectations{
endpoint: defaultEndpoint(region),
region: expectedCallRegion,
}
}

func expectDefaultFIPSEndpoint(region string) caseExpectations {
return caseExpectations{
endpoint: defaultFIPSEndpoint(region),
region: expectedCallRegion,
}
}

func expectPackageNameConfigEndpoint() caseExpectations {
return caseExpectations{
endpoint: packageNameConfigEndpoint,
region: expectedCallRegion,
}
}

{{ range $i, $alias := .Aliases }}
func expectAliasName{{ $i }}ConfigEndpoint() caseExpectations {
return caseExpectations{
endpoint: aliasName{{ $i }}ConfigEndpoint,
region: expectedCallRegion,
}
}
{{ end }}

func expectAwsEnvVarEndpoint() caseExpectations {
return caseExpectations{
endpoint: awsServiceEnvvarEndpoint,
region: expectedCallRegion,
}
}

func expectBaseEnvVarEndpoint() caseExpectations {
return caseExpectations{
endpoint: baseEnvvarEndpoint,
region: expectedCallRegion,
}
}

Expand All @@ -741,6 +767,7 @@ func expectTfAwsEnvVarEndpoint() caseExpectations {
diags: diag.Diagnostics{
provider.DeprecatedEnvVarDiag(tfAwsEnvVar, awsEnvVar),
},
region: expectedCallRegion,
}
}
{{ end }}
Expand All @@ -752,19 +779,22 @@ func expectDeprecatedEnvVarEndpoint() caseExpectations {
diags: diag.Diagnostics{
provider.DeprecatedEnvVarDiag(deprecatedEnvVar, awsEnvVar),
},
region: expectedCallRegion,
}
}
{{ end }}

func expectServiceConfigFileEndpoint() caseExpectations {
return caseExpectations{
endpoint: serviceConfigFileEndpoint,
region: expectedCallRegion,
}
}

func expectBaseConfigFileEndpoint() caseExpectations {
return caseExpectations{
endpoint: baseConfigFileEndpoint,
region: expectedCallRegion,
}
}

Expand Down Expand Up @@ -828,13 +858,18 @@ func testEndpointCase(t *testing.T, region string, testcase endpointTestCase, ca

meta := p.Meta().(*conns.AWSClient)

endpoint := callF(ctx, t, meta)
callParams := callF(ctx, t, meta)

if e, a := testcase.expected.endpoint, callParams.endpoint; e != a {
t.Errorf("expected endpoint %q, got %q", e, a)
}

if endpoint != testcase.expected.endpoint {
t.Errorf("expected endpoint %q, got %q", testcase.expected.endpoint, endpoint)
if e, a := testcase.expected.region, callParams.region; e != a {
t.Errorf("expected region %q, got %q", e, a)
}
}

{{ if ne .GoV2Package "" }}
func addRetrieveEndpointURLMiddleware(t *testing.T, endpoint *string) func(*middleware.Stack) error {
return func(stack *middleware.Stack) error {
return stack.Finalize.Add(
Expand Down Expand Up @@ -865,6 +900,26 @@ func retrieveEndpointURLMiddleware(t *testing.T, endpoint *string) middleware.Fi
})
}

func addRetrieveRegionMiddleware(region *string) func(*middleware.Stack) error {
return func(stack *middleware.Stack) error {
return stack.Serialize.Add(
retrieveRegionMiddleware(region),
middleware.After,
)
}
}

func retrieveRegionMiddleware(region *string) middleware.SerializeMiddleware {
return middleware.SerializeMiddlewareFunc(
"Test: Retrieve Region",
func(ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler) (middleware.SerializeOutput, middleware.Metadata, error) {
*region = awsmiddleware.GetRegion(ctx)

return next.HandleSerialize(ctx, in)
},
)
}

var errCancelOperation = fmt.Errorf("Test: Canceling request")

func addCancelRequestMiddleware() func(*middleware.Stack) error {
Expand Down Expand Up @@ -897,6 +952,7 @@ func fullValueTypeName(v reflect.Value) string {
requestType := v.Type()
return fmt.Sprintf("%s.%s", requestType.PkgPath(), requestType.Name())
}
{{ end }}

func generateSharedConfigFile(config configFile) string {
var buf strings.Builder
Expand Down
Loading

0 comments on commit 1dedb65

Please sign in to comment.