Skip to content

Commit

Permalink
Merge pull request #122 from hashicorp/custom-ca-bundle
Browse files Browse the repository at this point in the history
Add Custom CA Bundle configuration
  • Loading branch information
gdavison authored Feb 17, 2022
2 parents 3038b71 + 8ae03cd commit ec08b43
Show file tree
Hide file tree
Showing 20 changed files with 736 additions and 173 deletions.
2 changes: 2 additions & 0 deletions .semgrep/imports.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,6 @@ rules:
- metavariable-regex:
metavariable: "$X"
regex: '^"github.com/aws/aws-sdk-go-v2/.+"$'
- pattern-not: |
import ("github.com/aws/aws-sdk-go-v2/aws/transport/http")
severity: ERROR
40 changes: 14 additions & 26 deletions aws_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ import (
"github.com/aws/smithy-go/middleware"
"github.com/hashicorp/aws-sdk-go-base/v2/internal/constants"
"github.com/hashicorp/aws-sdk-go-base/v2/internal/endpoints"
"github.com/hashicorp/aws-sdk-go-base/v2/internal/httpclient"
"github.com/hashicorp/go-multierror"
"github.com/mitchellh/go-homedir"
"github.com/hashicorp/aws-sdk-go-base/v2/internal/expand"
)

func GetAwsConfig(ctx context.Context, c *Config) (aws.Config, error) {
Expand Down Expand Up @@ -140,7 +138,7 @@ func GetAwsAccountIDAndPartition(ctx context.Context, awsConfig aws.Config, c *C
}

func commonLoadOptions(c *Config) ([]func(*config.LoadOptions) error, error) {
httpClient, err := httpclient.DefaultHttpClient(c)
httpClient, err := defaultHttpClient(c)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -172,16 +170,26 @@ func commonLoadOptions(c *Config) ([]func(*config.LoadOptions) error, error) {
}

if len(c.SharedConfigFiles) > 0 {
configFiles, err := expandFilePaths(c.SharedConfigFiles)
configFiles, err := expand.FilePaths(c.SharedConfigFiles)
if err != nil {
return nil, fmt.Errorf("error expanding shared config files: %w", err)
return nil, fmt.Errorf("expanding shared config files: %w", err)
}
loadOptions = append(
loadOptions,
config.WithSharedConfigFiles(configFiles),
)
}

if c.CustomCABundle != "" {
reader, err := c.CustomCABundleReader()
if err != nil {
return nil, err
}
loadOptions = append(loadOptions,
config.WithCustomCABundle(reader),
)
}

if c.EC2MetadataServiceEndpoint != "" {
loadOptions = append(loadOptions,
config.WithEC2IMDSEndpoint(c.EC2MetadataServiceEndpoint),
Expand Down Expand Up @@ -222,23 +230,3 @@ func commonLoadOptions(c *Config) ([]func(*config.LoadOptions) error, error) {

return loadOptions, nil
}

func expandFilePaths(in []string) ([]string, error) {
var errs *multierror.Error
result := make([]string, 0, len(in))
for _, v := range in {
p, err := expandFilePath(v)
if err != nil {
errs = multierror.Append(errs, err)
continue
}
result = append(result, p)
}
return result, errs.ErrorOrNil()
}

func expandFilePath(in string) (s string, err error) {
e := os.ExpandEnv(in)
s, err = homedir.Expand(e)
return
}
240 changes: 182 additions & 58 deletions aws_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import (
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"path/filepath"
"reflect"
"runtime"
"strings"
Expand Down Expand Up @@ -2019,6 +2021,186 @@ ec2_metadata_service_endpoint_mode = IPv4
}
}

func TestCustomCABundle(t *testing.T) {
testCases := map[string]struct {
Config *Config
SetConfig bool
SetEnvironmentVariable bool
SetSharedConfigurationFile bool
ExpandEnvVars bool
EnvironmentVariables map[string]string
ExpectTLSClientConfigRootCAsSet bool
}{
"no configuration": {
Config: &Config{
AccessKey: servicemocks.MockStaticAccessKey,
Region: "us-east-1",
SecretKey: servicemocks.MockStaticSecretKey,
},
ExpectTLSClientConfigRootCAsSet: false,
},

"config": {
Config: &Config{
AccessKey: servicemocks.MockStaticAccessKey,
Region: "us-east-1",
SecretKey: servicemocks.MockStaticSecretKey,
},
SetConfig: true,
ExpectTLSClientConfigRootCAsSet: true,
},

"expanded config": {
Config: &Config{
AccessKey: servicemocks.MockStaticAccessKey,
Region: "us-east-1",
SecretKey: servicemocks.MockStaticSecretKey,
},
SetConfig: true,
ExpandEnvVars: true,
ExpectTLSClientConfigRootCAsSet: true,
},

"envvar": {
Config: &Config{
AccessKey: servicemocks.MockStaticAccessKey,
Region: "us-east-1",
SecretKey: servicemocks.MockStaticSecretKey,
},
SetEnvironmentVariable: true,
ExpectTLSClientConfigRootCAsSet: true,
},

// Not implemented in AWS SDK for Go v2: https://github.com/aws/aws-sdk-go-v2/issues/1589
// "shared configuration file": {
// Config: &Config{
// AccessKey: servicemocks.MockStaticAccessKey,
// Region: "us-east-1",
// SecretKey: servicemocks.MockStaticSecretKey,
// },
// SetSharedConfigurationFile: true,
// ExpectTLSClientConfigRootCAsSet: true,
// },

"config overrides envvar": {
Config: &Config{
AccessKey: servicemocks.MockStaticAccessKey,
Region: "us-east-1",
SecretKey: servicemocks.MockStaticSecretKey,
},
SetConfig: true,
EnvironmentVariables: map[string]string{
"AWS_CA_BUNDLE": "no-such-file",
},
ExpectTLSClientConfigRootCAsSet: true,
},

// Not implemented in AWS SDK for Go v2: https://github.com/aws/aws-sdk-go-v2/issues/1589
// "envvar overrides shared configuration": {
// Config: &Config{
// AccessKey: servicemocks.MockStaticAccessKey,
// Region: "us-east-1",
// SecretKey: servicemocks.MockStaticSecretKey,
// },
// EnvironmentVariables: map[string]string{
// "AWS_CA_BUNDLE": EC2MetadataEndpointModeIPv6,
// },
// SharedConfigurationFile: `
// [default]
// ec2_metadata_service_endpoint_mode = IPv4
// `,
// ExpectTLSClientConfigRootCAsSet: true,
// },
}

for testName, testCase := range testCases {
testCase := testCase

t.Run(testName, func(t *testing.T) {
oldEnv := servicemocks.InitSessionTestEnv()
defer servicemocks.PopEnv(oldEnv)

for k, v := range testCase.EnvironmentVariables {
os.Setenv(k, v)
}

tempdir, err := ioutil.TempDir("", "temp")
if err != nil {
t.Fatalf("error creating temp dir: %s", err)
}
defer os.Remove(tempdir)
os.Setenv("TMPDIR", tempdir)

pemFile, err := servicemocks.TempPEMFile()
defer os.Remove(pemFile)
if err != nil {
t.Fatalf("error creating PEM file: %s", err)
}

if testCase.ExpandEnvVars {
tmpdir := os.Getenv("TMPDIR")
rel, err := filepath.Rel(tmpdir, pemFile)
if err != nil {
t.Fatalf("error making path relative: %s", err)
}
t.Logf("relative: %s", rel)
pemFile = filepath.Join("$TMPDIR", rel)
t.Logf("env tempfile: %s", pemFile)
}

if testCase.SetConfig {
testCase.Config.CustomCABundle = pemFile
}

if testCase.SetEnvironmentVariable {
os.Setenv("AWS_CA_BUNDLE", pemFile)
}

if testCase.SetSharedConfigurationFile {
file, err := ioutil.TempFile("", "aws-sdk-go-base-shared-configuration-file")

if err != nil {
t.Fatalf("unexpected error creating temporary shared configuration file: %s", err)
}

defer os.Remove(file.Name())

err = ioutil.WriteFile(
file.Name(),
[]byte(fmt.Sprintf(`
[default]
ca_bundle = %s
`, pemFile)),
0600)

if err != nil {
t.Fatalf("unexpected error writing shared configuration file: %s", err)
}

testCase.Config.SharedConfigFiles = []string{file.Name()}
}

testCase.Config.SkipCredsValidation = true

awsConfig, err := GetAwsConfig(context.Background(), testCase.Config)
if err != nil {
t.Fatalf("error in GetAwsConfig() '%[1]T': %[1]s", err)
}

type transportGetter interface {
GetTransport() *http.Transport
}

trGetter := awsConfig.HTTPClient.(transportGetter)
tr := trGetter.GetTransport()

if a, e := tr.TLSClientConfig.RootCAs != nil, testCase.ExpectTLSClientConfigRootCAsSet; a != e {
t.Errorf("expected(%t) CA Bundle, got: %t", e, a)
}
})
}
}

func TestGetAwsConfigWithAccountIDAndPartition(t *testing.T) {
oldEnv := servicemocks.InitSessionTestEnv()
defer servicemocks.PopEnv(oldEnv)
Expand Down Expand Up @@ -2350,61 +2532,3 @@ func (r *withNoDelay) RetryDelay(attempt int, err error) (time.Duration, error)

return 0 * time.Second, nil
}

func TestExpandFilePath(t *testing.T) {
testcases := map[string]struct {
path string
expected string
envvars map[string]string
}{
"filename": {
path: "file",
expected: "file",
},
"file in current dir": {
path: "./file",
expected: "./file",
},
"file with tilde": {
path: "~/file",
expected: "/my/home/dir/file",
envvars: map[string]string{
"HOME": "/my/home/dir",
},
},
"file with envvar": {
path: "$HOME/file",
expected: "/home/dir/file",
envvars: map[string]string{
"HOME": "/home/dir",
},
},
"full file in envvar": {
path: "$CONF_FILE",
expected: "/path/to/conf/file",
envvars: map[string]string{
"CONF_FILE": "/path/to/conf/file",
},
},
}

for name, testcase := range testcases {
t.Run(name, func(t *testing.T) {
oldEnv := servicemocks.StashEnv()
defer servicemocks.PopEnv(oldEnv)

for k, v := range testcase.envvars {
os.Setenv(k, v)
}

a, err := expandFilePath(testcase.path)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}

if a != testcase.expected {
t.Errorf("expected expansion to %q, got %q", testcase.expected, a)
}
})
}
}
5 changes: 3 additions & 2 deletions credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/sts/types"
"github.com/hashicorp/aws-sdk-go-base/v2/internal/expand"
)

func getCredentialsProvider(ctx context.Context, c *Config) (aws.CredentialsProvider, error) {
Expand Down Expand Up @@ -41,9 +42,9 @@ func getCredentialsProvider(ctx context.Context, c *Config) (aws.CredentialsProv
)
}
if len(c.SharedCredentialsFiles) > 0 {
credsFiles, err := expandFilePaths(c.SharedCredentialsFiles)
credsFiles, err := expand.FilePaths(c.SharedCredentialsFiles)
if err != nil {
return nil, fmt.Errorf("error expanding shared credentials files: %w", err)
return nil, fmt.Errorf("expanding shared credentials files: %w", err)
}
loadOptions = append(
loadOptions,
Expand Down
4 changes: 2 additions & 2 deletions internal/httpclient/http_client.go → http_client.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package httpclient
package awsbase

import (
"fmt"
Expand All @@ -9,7 +9,7 @@ import (
"github.com/hashicorp/aws-sdk-go-base/v2/internal/config"
)

func DefaultHttpClient(c *config.Config) (*awshttp.BuildableClient, error) {
func defaultHttpClient(c *config.Config) (*awshttp.BuildableClient, error) {
var err error

httpClient := awshttp.NewBuildableClient().
Expand Down
Loading

0 comments on commit ec08b43

Please sign in to comment.