Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Custom CA Bundle configuration #122

Merged
merged 5 commits into from
Feb 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .semgrep/imports.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,6 @@ rules:
- metavariable-regex:
metavariable: "$X"
regex: '^"github.com/aws/aws-sdk-go-v2/.+"$'
- pattern-not: |
import ("github.com/aws/aws-sdk-go-v2/aws/transport/http")
severity: ERROR
40 changes: 14 additions & 26 deletions aws_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ import (
"github.com/aws/smithy-go/middleware"
"github.com/hashicorp/aws-sdk-go-base/v2/internal/constants"
"github.com/hashicorp/aws-sdk-go-base/v2/internal/endpoints"
"github.com/hashicorp/aws-sdk-go-base/v2/internal/httpclient"
"github.com/hashicorp/go-multierror"
"github.com/mitchellh/go-homedir"
"github.com/hashicorp/aws-sdk-go-base/v2/internal/expand"
)

func GetAwsConfig(ctx context.Context, c *Config) (aws.Config, error) {
Expand Down Expand Up @@ -140,7 +138,7 @@ func GetAwsAccountIDAndPartition(ctx context.Context, awsConfig aws.Config, c *C
}

func commonLoadOptions(c *Config) ([]func(*config.LoadOptions) error, error) {
httpClient, err := httpclient.DefaultHttpClient(c)
httpClient, err := defaultHttpClient(c)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -172,16 +170,26 @@ func commonLoadOptions(c *Config) ([]func(*config.LoadOptions) error, error) {
}

if len(c.SharedConfigFiles) > 0 {
configFiles, err := expandFilePaths(c.SharedConfigFiles)
configFiles, err := expand.FilePaths(c.SharedConfigFiles)
if err != nil {
return nil, fmt.Errorf("error expanding shared config files: %w", err)
return nil, fmt.Errorf("expanding shared config files: %w", err)
}
loadOptions = append(
loadOptions,
config.WithSharedConfigFiles(configFiles),
)
}

if c.CustomCABundle != "" {
reader, err := c.CustomCABundleReader()
if err != nil {
return nil, err
}
loadOptions = append(loadOptions,
config.WithCustomCABundle(reader),
)
}

if c.EC2MetadataServiceEndpoint != "" {
loadOptions = append(loadOptions,
config.WithEC2IMDSEndpoint(c.EC2MetadataServiceEndpoint),
Expand Down Expand Up @@ -222,23 +230,3 @@ func commonLoadOptions(c *Config) ([]func(*config.LoadOptions) error, error) {

return loadOptions, nil
}

func expandFilePaths(in []string) ([]string, error) {
var errs *multierror.Error
result := make([]string, 0, len(in))
for _, v := range in {
p, err := expandFilePath(v)
if err != nil {
errs = multierror.Append(errs, err)
continue
}
result = append(result, p)
}
return result, errs.ErrorOrNil()
}

func expandFilePath(in string) (s string, err error) {
e := os.ExpandEnv(in)
s, err = homedir.Expand(e)
return
}
240 changes: 182 additions & 58 deletions aws_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import (
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"path/filepath"
"reflect"
"runtime"
"strings"
Expand Down Expand Up @@ -2019,6 +2021,186 @@ ec2_metadata_service_endpoint_mode = IPv4
}
}

func TestCustomCABundle(t *testing.T) {
testCases := map[string]struct {
Config *Config
SetConfig bool
SetEnvironmentVariable bool
SetSharedConfigurationFile bool
ExpandEnvVars bool
EnvironmentVariables map[string]string
ExpectTLSClientConfigRootCAsSet bool
}{
"no configuration": {
Config: &Config{
AccessKey: servicemocks.MockStaticAccessKey,
Region: "us-east-1",
SecretKey: servicemocks.MockStaticSecretKey,
},
ExpectTLSClientConfigRootCAsSet: false,
},

"config": {
Config: &Config{
AccessKey: servicemocks.MockStaticAccessKey,
Region: "us-east-1",
SecretKey: servicemocks.MockStaticSecretKey,
},
SetConfig: true,
ExpectTLSClientConfigRootCAsSet: true,
},

"expanded config": {
Config: &Config{
AccessKey: servicemocks.MockStaticAccessKey,
Region: "us-east-1",
SecretKey: servicemocks.MockStaticSecretKey,
},
SetConfig: true,
ExpandEnvVars: true,
ExpectTLSClientConfigRootCAsSet: true,
},

"envvar": {
Config: &Config{
AccessKey: servicemocks.MockStaticAccessKey,
Region: "us-east-1",
SecretKey: servicemocks.MockStaticSecretKey,
},
SetEnvironmentVariable: true,
ExpectTLSClientConfigRootCAsSet: true,
},

// Not implemented in AWS SDK for Go v2: https://github.com/aws/aws-sdk-go-v2/issues/1589
// "shared configuration file": {
// Config: &Config{
// AccessKey: servicemocks.MockStaticAccessKey,
// Region: "us-east-1",
// SecretKey: servicemocks.MockStaticSecretKey,
// },
// SetSharedConfigurationFile: true,
// ExpectTLSClientConfigRootCAsSet: true,
// },

"config overrides envvar": {
Config: &Config{
AccessKey: servicemocks.MockStaticAccessKey,
Region: "us-east-1",
SecretKey: servicemocks.MockStaticSecretKey,
},
SetConfig: true,
EnvironmentVariables: map[string]string{
"AWS_CA_BUNDLE": "no-such-file",
},
ExpectTLSClientConfigRootCAsSet: true,
},

// Not implemented in AWS SDK for Go v2: https://github.com/aws/aws-sdk-go-v2/issues/1589
// "envvar overrides shared configuration": {
// Config: &Config{
// AccessKey: servicemocks.MockStaticAccessKey,
// Region: "us-east-1",
// SecretKey: servicemocks.MockStaticSecretKey,
// },
// EnvironmentVariables: map[string]string{
// "AWS_CA_BUNDLE": EC2MetadataEndpointModeIPv6,
// },
// SharedConfigurationFile: `
// [default]
// ec2_metadata_service_endpoint_mode = IPv4
// `,
// ExpectTLSClientConfigRootCAsSet: true,
// },
}

for testName, testCase := range testCases {
testCase := testCase

t.Run(testName, func(t *testing.T) {
oldEnv := servicemocks.InitSessionTestEnv()
defer servicemocks.PopEnv(oldEnv)

for k, v := range testCase.EnvironmentVariables {
os.Setenv(k, v)
}

tempdir, err := ioutil.TempDir("", "temp")
if err != nil {
t.Fatalf("error creating temp dir: %s", err)
}
defer os.Remove(tempdir)
os.Setenv("TMPDIR", tempdir)

pemFile, err := servicemocks.TempPEMFile()
defer os.Remove(pemFile)
if err != nil {
t.Fatalf("error creating PEM file: %s", err)
}

if testCase.ExpandEnvVars {
tmpdir := os.Getenv("TMPDIR")
rel, err := filepath.Rel(tmpdir, pemFile)
if err != nil {
t.Fatalf("error making path relative: %s", err)
}
t.Logf("relative: %s", rel)
pemFile = filepath.Join("$TMPDIR", rel)
t.Logf("env tempfile: %s", pemFile)
}

if testCase.SetConfig {
testCase.Config.CustomCABundle = pemFile
}

if testCase.SetEnvironmentVariable {
os.Setenv("AWS_CA_BUNDLE", pemFile)
}

if testCase.SetSharedConfigurationFile {
file, err := ioutil.TempFile("", "aws-sdk-go-base-shared-configuration-file")

if err != nil {
t.Fatalf("unexpected error creating temporary shared configuration file: %s", err)
}

defer os.Remove(file.Name())

err = ioutil.WriteFile(
file.Name(),
[]byte(fmt.Sprintf(`
[default]
ca_bundle = %s
`, pemFile)),
0600)

if err != nil {
t.Fatalf("unexpected error writing shared configuration file: %s", err)
}

testCase.Config.SharedConfigFiles = []string{file.Name()}
}

testCase.Config.SkipCredsValidation = true

awsConfig, err := GetAwsConfig(context.Background(), testCase.Config)
if err != nil {
t.Fatalf("error in GetAwsConfig() '%[1]T': %[1]s", err)
}

type transportGetter interface {
GetTransport() *http.Transport
}

trGetter := awsConfig.HTTPClient.(transportGetter)
tr := trGetter.GetTransport()

if a, e := tr.TLSClientConfig.RootCAs != nil, testCase.ExpectTLSClientConfigRootCAsSet; a != e {
t.Errorf("expected(%t) CA Bundle, got: %t", e, a)
}
})
}
}

func TestGetAwsConfigWithAccountIDAndPartition(t *testing.T) {
oldEnv := servicemocks.InitSessionTestEnv()
defer servicemocks.PopEnv(oldEnv)
Expand Down Expand Up @@ -2350,61 +2532,3 @@ func (r *withNoDelay) RetryDelay(attempt int, err error) (time.Duration, error)

return 0 * time.Second, nil
}

func TestExpandFilePath(t *testing.T) {
testcases := map[string]struct {
path string
expected string
envvars map[string]string
}{
"filename": {
path: "file",
expected: "file",
},
"file in current dir": {
path: "./file",
expected: "./file",
},
"file with tilde": {
path: "~/file",
expected: "/my/home/dir/file",
envvars: map[string]string{
"HOME": "/my/home/dir",
},
},
"file with envvar": {
path: "$HOME/file",
expected: "/home/dir/file",
envvars: map[string]string{
"HOME": "/home/dir",
},
},
"full file in envvar": {
path: "$CONF_FILE",
expected: "/path/to/conf/file",
envvars: map[string]string{
"CONF_FILE": "/path/to/conf/file",
},
},
}

for name, testcase := range testcases {
t.Run(name, func(t *testing.T) {
oldEnv := servicemocks.StashEnv()
defer servicemocks.PopEnv(oldEnv)

for k, v := range testcase.envvars {
os.Setenv(k, v)
}

a, err := expandFilePath(testcase.path)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}

if a != testcase.expected {
t.Errorf("expected expansion to %q, got %q", testcase.expected, a)
}
})
}
}
5 changes: 3 additions & 2 deletions credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/sts/types"
"github.com/hashicorp/aws-sdk-go-base/v2/internal/expand"
)

func getCredentialsProvider(ctx context.Context, c *Config) (aws.CredentialsProvider, error) {
Expand Down Expand Up @@ -41,9 +42,9 @@ func getCredentialsProvider(ctx context.Context, c *Config) (aws.CredentialsProv
)
}
if len(c.SharedCredentialsFiles) > 0 {
credsFiles, err := expandFilePaths(c.SharedCredentialsFiles)
credsFiles, err := expand.FilePaths(c.SharedCredentialsFiles)
if err != nil {
return nil, fmt.Errorf("error expanding shared credentials files: %w", err)
return nil, fmt.Errorf("expanding shared credentials files: %w", err)
}
loadOptions = append(
loadOptions,
Expand Down
4 changes: 2 additions & 2 deletions internal/httpclient/http_client.go → http_client.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package httpclient
package awsbase

import (
"fmt"
Expand All @@ -9,7 +9,7 @@ import (
"github.com/hashicorp/aws-sdk-go-base/v2/internal/config"
)

func DefaultHttpClient(c *config.Config) (*awshttp.BuildableClient, error) {
func defaultHttpClient(c *config.Config) (*awshttp.BuildableClient, error) {
var err error

httpClient := awshttp.NewBuildableClient().
Expand Down
Loading