diff --git a/azkv/keysource.go b/azkv/keysource.go index b4eef54c6..97761f529 100644 --- a/azkv/keysource.go +++ b/azkv/keysource.go @@ -60,6 +60,8 @@ type MasterKey struct { // using TokenCredential.ApplyToMasterKey. // If nil, azidentity.NewDefaultAzureCredential is used. tokenCredential azcore.TokenCredential + // clientOptions contains the azkeys.ClientOptions used by the Azure client. + clientOptions *azkeys.ClientOptions } // NewMasterKey creates a new MasterKey from a URL, key name and version, @@ -118,6 +120,23 @@ func (t TokenCredential) ApplyToMasterKey(key *MasterKey) { key.tokenCredential = t.token } +// ClientOptions is a wrapper around azkeys.ClientOptions to allow +// configuration of the Azure Key Vault client. +type ClientOptions struct { + o *azkeys.ClientOptions +} + +// NewClientOptions creates a new ClientOptions with the provided +// azkeys.ClientOptions. +func NewClientOptions(o *azkeys.ClientOptions) *ClientOptions { + return &ClientOptions{o: o} +} + +// ApplyToMasterKey configures the ClientOptions on the provided key. +func (c ClientOptions) ApplyToMasterKey(key *MasterKey) { + key.clientOptions = c.o +} + // Encrypt takes a SOPS data key, encrypts it with Azure Key Vault, and stores // the result in the EncryptedKey field. // @@ -135,7 +154,7 @@ func (key *MasterKey) EncryptContext(ctx context.Context, dataKey []byte) error return fmt.Errorf("failed to get Azure token credential to encrypt data: %w", err) } - c, err := azkeys.NewClient(key.VaultURL, token, nil) + c, err := azkeys.NewClient(key.VaultURL, token, key.clientOptions) if err != nil { log.WithFields(logrus.Fields{"key": key.Name, "version": key.Version}).Info("Encryption failed") return fmt.Errorf("failed to construct Azure Key Vault client to encrypt data: %w", err) @@ -198,7 +217,7 @@ func (key *MasterKey) DecryptContext(ctx context.Context) ([]byte, error) { return nil, fmt.Errorf("failed to base64 decode Azure Key Vault encrypted key: %w", err) } - c, err := azkeys.NewClient(key.VaultURL, token, nil) + c, err := azkeys.NewClient(key.VaultURL, token, key.clientOptions) if err != nil { log.WithFields(logrus.Fields{"key": key.Name, "version": key.Version}).Info("Decryption failed") return nil, fmt.Errorf("failed to construct Azure Key Vault client to decrypt data: %w", err) diff --git a/gcpkms/keysource.go b/gcpkms/keysource.go index cc834b055..1969e8b90 100644 --- a/gcpkms/keysource.go +++ b/gcpkms/keysource.go @@ -66,6 +66,8 @@ type MasterKey struct { // Mostly useful for testing at present, to wire the client to a mock // server. grpcConn *grpc.ClientConn + // grpcDialOpts are the gRPC dial options used to create the gRPC connection. + grpcDialOpts []grpc.DialOption } // NewMasterKeyFromResourceID creates a new MasterKey with the provided resource @@ -116,6 +118,14 @@ func (c CredentialJSON) ApplyToMasterKey(key *MasterKey) { key.credentialJSON = c } +// DialOptions are the gRPC dial options used to create the gRPC connection. +type DialOptions []grpc.DialOption + +// ApplyToMasterKey configures the DialOptions on the provided key. +func (d DialOptions) ApplyToMasterKey(key *MasterKey) { + key.grpcDialOpts = d +} + // Encrypt takes a SOPS data key, encrypts it with GCP KMS, and stores the // result in the EncryptedKey field. // @@ -275,8 +285,13 @@ func (key *MasterKey) newKMSClient(ctx context.Context) (*kms.KeyManagementClien } } - if key.grpcConn != nil { + switch { + case key.grpcConn != nil: opts = append(opts, option.WithGRPCConn(key.grpcConn)) + case len(key.grpcDialOpts) > 0: + for _, opt := range key.grpcDialOpts { + opts = append(opts, option.WithGRPCDialOption(opt)) + } } client, err := kms.NewKeyManagementClient(ctx, opts...) diff --git a/hcvault/keysource.go b/hcvault/keysource.go index 2b1a9909f..67706e71e 100644 --- a/hcvault/keysource.go +++ b/hcvault/keysource.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "net/http" "net/url" "os" "path" @@ -71,6 +72,8 @@ type MasterKey struct { // Token.ApplyToMasterKey. If empty, the default client configuration // is used, before falling back to the token stored in defaultTokenFile. token string + // httpClient is used to override the default HTTP client used by the Vault client. + httpClient *http.Client } // NewMasterKeysFromURIs creates a list of MasterKeys from a list of Vault @@ -129,6 +132,22 @@ func NewMasterKey(address, enginePath, keyName string) *MasterKey { return key } +// HTTPClient is a wrapper around http.Client used for configuring the +// Vault client. +type HTTPClient struct { + hc *http.Client +} + +// NewHTTPClient creates a new HTTPClient with the provided http.Client. +func NewHTTPClient(hc *http.Client) *HTTPClient { + return &HTTPClient{hc: hc} +} + +// ApplyToMasterKey configures the HTTP client on the provided key. +func (h HTTPClient) ApplyToMasterKey(key *MasterKey) { + key.httpClient = h.hc +} + // Encrypt takes a SOPS data key, encrypts it with Vault Transit, and stores // the result in the EncryptedKey field. // @@ -142,7 +161,7 @@ func (key *MasterKey) Encrypt(dataKey []byte) error { func (key *MasterKey) EncryptContext(ctx context.Context, dataKey []byte) error { fullPath := key.encryptPath() - client, err := vaultClient(key.VaultAddress, key.token) + client, err := vaultClient(key.VaultAddress, key.token, key.httpClient) if err != nil { log.WithField("Path", fullPath).Info("Encryption failed") return err @@ -194,7 +213,7 @@ func (key *MasterKey) Decrypt() ([]byte, error) { func (key *MasterKey) DecryptContext(ctx context.Context) ([]byte, error) { fullPath := key.decryptPath() - client, err := vaultClient(key.VaultAddress, key.token) + client, err := vaultClient(key.VaultAddress, key.token, key.httpClient) if err != nil { log.WithField("Path", fullPath).Info("Decryption failed") return nil, err @@ -308,10 +327,14 @@ func dataKeyFromSecret(secret *api.Secret) ([]byte, error) { // vaultClient returns a new Vault client, configured with the given address // and token. -func vaultClient(address, token string) (*api.Client, error) { +func vaultClient(address, token string, hc *http.Client) (*api.Client, error) { cfg := api.DefaultConfig() cfg.Address = address + if hc != nil { + cfg.HttpClient = hc + } + client, err := api.NewClient(cfg) if err != nil { return nil, fmt.Errorf("cannot create Vault client: %w", err) diff --git a/hcvault/keysource_test.go b/hcvault/keysource_test.go index 646895335..02d5a13e2 100644 --- a/hcvault/keysource_test.go +++ b/hcvault/keysource_test.go @@ -187,7 +187,7 @@ func TestMasterKey_Encrypt(t *testing.T) { assert.NoError(t, key.Encrypt(dataKey)) assert.NotEmpty(t, key.EncryptedKey) - client, err := vaultClient(key.VaultAddress, key.token) + client, err := vaultClient(key.VaultAddress, key.token, nil) assert.NoError(t, err) payload := decryptPayload(key.EncryptedKey) @@ -230,7 +230,7 @@ func TestMasterKey_Decrypt(t *testing.T) { (Token(testVaultToken)).ApplyToMasterKey(key) assert.NoError(t, createVaultKey(key)) - client, err := vaultClient(key.VaultAddress, key.token) + client, err := vaultClient(key.VaultAddress, key.token, nil) assert.NoError(t, err) dataKey := []byte("the heart of a shrimp is located in its head") @@ -368,7 +368,7 @@ func Test_vaultClient(t *testing.T) { t.Setenv("VAULT_TOKEN", "") t.Setenv("HOME", tmpDir) - got, err := vaultClient(testVaultAddress, "") + got, err := vaultClient(testVaultAddress, "", nil) assert.NoError(t, err) assert.NotNil(t, got) assert.Empty(t, got.Token()) @@ -378,7 +378,7 @@ func Test_vaultClient(t *testing.T) { token := "test-token" t.Setenv("VAULT_TOKEN", token) - got, err := vaultClient(testVaultAddress, "") + got, err := vaultClient(testVaultAddress, "", nil) assert.NoError(t, err) assert.NotNil(t, got) assert.Equal(t, token, got.Token()) @@ -388,7 +388,7 @@ func Test_vaultClient(t *testing.T) { ignored := "test-token" t.Setenv("VAULT_TOKEN", ignored) - got, err := vaultClient(testVaultAddress, testVaultToken) + got, err := vaultClient(testVaultAddress, testVaultToken, nil) assert.NoError(t, err) assert.NotNil(t, got) assert.Equal(t, testVaultToken, got.Token()) @@ -407,7 +407,7 @@ func Test_vaultClient(t *testing.T) { t.Setenv("VAULT_TOKEN", "") t.Setenv("HOME", tmpDir) - got, err := vaultClient(testVaultAddress, "") + got, err := vaultClient(testVaultAddress, "", nil) assert.NoError(t, err) assert.NotNil(t, got) assert.Equal(t, token, got.Token()) @@ -487,7 +487,7 @@ func Test_engineAndKeyFromPath(t *testing.T) { // enableVaultTransit enables the Vault Transit backend on the given enginePath. func enableVaultTransit(address, token, enginePath string) error { - client, err := vaultClient(address, token) + client, err := vaultClient(address, token, nil) if err != nil { return fmt.Errorf("cannot create Vault client: %w", err) } @@ -504,7 +504,7 @@ func enableVaultTransit(address, token, enginePath string) error { // createVaultKey creates a new RSA-4096 Vault key using the data from the // provided MasterKey. func createVaultKey(key *MasterKey) error { - client, err := vaultClient(key.VaultAddress, key.token) + client, err := vaultClient(key.VaultAddress, key.token, nil) if err != nil { return fmt.Errorf("cannot create Vault client: %w", err) } diff --git a/kms/keysource.go b/kms/keysource.go index 515c4b853..bdb963722 100644 --- a/kms/keysource.go +++ b/kms/keysource.go @@ -9,6 +9,7 @@ import ( "context" "encoding/base64" "fmt" + "net/http" "os" "regexp" "sort" @@ -79,6 +80,8 @@ type MasterKey struct { // injected using e.g. an environment variable. The field is not publicly // exposed, nor configurable. baseEndpoint string + // httpClient is used to override the default HTTP client used by the AWS client. + httpClient *http.Client } // NewMasterKey creates a new MasterKey from an ARN, role and context, setting @@ -233,6 +236,22 @@ func (c CredentialsProvider) ApplyToMasterKey(key *MasterKey) { key.credentialsProvider = c.provider } +// HTTPClient is a wrapper around http.Client used for configuring the +// AWS KMS client. +type HTTPClient struct { + hc *http.Client +} + +// NewHTTPClient creates a new HTTPClient with the provided http.Client. +func NewHTTPClient(hc *http.Client) *HTTPClient { + return &HTTPClient{hc: hc} +} + +// ApplyToMasterKey configures the HTTP client on the provided key. +func (h HTTPClient) ApplyToMasterKey(key *MasterKey) { + key.httpClient = h.hc +} + // Encrypt takes a SOPS data key, encrypts it with KMS and stores the result // in the EncryptedKey field. // @@ -385,6 +404,9 @@ func (key MasterKey) createKMSConfig(ctx context.Context) (*aws.Config, error) { lo.SharedConfigProfile = key.AwsProfile } lo.Region = region + if key.httpClient != nil { + lo.HTTPClient = key.httpClient + } return nil }) if err != nil {