Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions azkv/keysource.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ type MasterKey struct {
// using TokenCredential.ApplyToMasterKey.
// If nil, azidentity.NewDefaultAzureCredential is used.
tokenCredential azcore.TokenCredential
// clientOptions contains the azkeys.ClientOptions used by the Azure client.
clientOptions *azkeys.ClientOptions
}

// NewMasterKey creates a new MasterKey from a URL, key name and version,
Expand Down Expand Up @@ -118,6 +120,23 @@ func (t TokenCredential) ApplyToMasterKey(key *MasterKey) {
key.tokenCredential = t.token
}

// ClientOptions is a wrapper around azkeys.ClientOptions to allow
// configuration of the Azure Key Vault client.
type ClientOptions struct {
o *azkeys.ClientOptions
}

// NewClientOptions creates a new ClientOptions with the provided
// azkeys.ClientOptions.
func NewClientOptions(o *azkeys.ClientOptions) *ClientOptions {
return &ClientOptions{o: o}
}

// ApplyToMasterKey configures the ClientOptions on the provided key.
func (c ClientOptions) ApplyToMasterKey(key *MasterKey) {
key.clientOptions = c.o
}

// Encrypt takes a SOPS data key, encrypts it with Azure Key Vault, and stores
// the result in the EncryptedKey field.
//
Expand All @@ -135,7 +154,7 @@ func (key *MasterKey) EncryptContext(ctx context.Context, dataKey []byte) error
return fmt.Errorf("failed to get Azure token credential to encrypt data: %w", err)
}

c, err := azkeys.NewClient(key.VaultURL, token, nil)
c, err := azkeys.NewClient(key.VaultURL, token, key.clientOptions)
if err != nil {
log.WithFields(logrus.Fields{"key": key.Name, "version": key.Version}).Info("Encryption failed")
return fmt.Errorf("failed to construct Azure Key Vault client to encrypt data: %w", err)
Expand Down Expand Up @@ -198,7 +217,7 @@ func (key *MasterKey) DecryptContext(ctx context.Context) ([]byte, error) {
return nil, fmt.Errorf("failed to base64 decode Azure Key Vault encrypted key: %w", err)
}

c, err := azkeys.NewClient(key.VaultURL, token, nil)
c, err := azkeys.NewClient(key.VaultURL, token, key.clientOptions)
if err != nil {
log.WithFields(logrus.Fields{"key": key.Name, "version": key.Version}).Info("Decryption failed")
return nil, fmt.Errorf("failed to construct Azure Key Vault client to decrypt data: %w", err)
Expand Down
17 changes: 16 additions & 1 deletion gcpkms/keysource.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ type MasterKey struct {
// Mostly useful for testing at present, to wire the client to a mock
// server.
grpcConn *grpc.ClientConn
// grpcDialOpts are the gRPC dial options used to create the gRPC connection.
grpcDialOpts []grpc.DialOption
}

// NewMasterKeyFromResourceID creates a new MasterKey with the provided resource
Expand Down Expand Up @@ -116,6 +118,14 @@ func (c CredentialJSON) ApplyToMasterKey(key *MasterKey) {
key.credentialJSON = c
}

// DialOptions are the gRPC dial options used to create the gRPC connection.
type DialOptions []grpc.DialOption

// ApplyToMasterKey configures the DialOptions on the provided key.
func (d DialOptions) ApplyToMasterKey(key *MasterKey) {
key.grpcDialOpts = d
}

// Encrypt takes a SOPS data key, encrypts it with GCP KMS, and stores the
// result in the EncryptedKey field.
//
Expand Down Expand Up @@ -275,8 +285,13 @@ func (key *MasterKey) newKMSClient(ctx context.Context) (*kms.KeyManagementClien
}
}

if key.grpcConn != nil {
switch {
case key.grpcConn != nil:
opts = append(opts, option.WithGRPCConn(key.grpcConn))
case len(key.grpcDialOpts) > 0:
for _, opt := range key.grpcDialOpts {
opts = append(opts, option.WithGRPCDialOption(opt))
}
}

client, err := kms.NewKeyManagementClient(ctx, opts...)
Expand Down
29 changes: 26 additions & 3 deletions hcvault/keysource.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path"
Expand Down Expand Up @@ -71,6 +72,8 @@ type MasterKey struct {
// Token.ApplyToMasterKey. If empty, the default client configuration
// is used, before falling back to the token stored in defaultTokenFile.
token string
// httpClient is used to override the default HTTP client used by the Vault client.
httpClient *http.Client
}

// NewMasterKeysFromURIs creates a list of MasterKeys from a list of Vault
Expand Down Expand Up @@ -129,6 +132,22 @@ func NewMasterKey(address, enginePath, keyName string) *MasterKey {
return key
}

// HTTPClient is a wrapper around http.Client used for configuring the
// Vault client.
type HTTPClient struct {
hc *http.Client
}

// NewHTTPClient creates a new HTTPClient with the provided http.Client.
func NewHTTPClient(hc *http.Client) *HTTPClient {
return &HTTPClient{hc: hc}
}

// ApplyToMasterKey configures the HTTP client on the provided key.
func (h HTTPClient) ApplyToMasterKey(key *MasterKey) {
key.httpClient = h.hc
}

// Encrypt takes a SOPS data key, encrypts it with Vault Transit, and stores
// the result in the EncryptedKey field.
//
Expand All @@ -142,7 +161,7 @@ func (key *MasterKey) Encrypt(dataKey []byte) error {
func (key *MasterKey) EncryptContext(ctx context.Context, dataKey []byte) error {
fullPath := key.encryptPath()

client, err := vaultClient(key.VaultAddress, key.token)
client, err := vaultClient(key.VaultAddress, key.token, key.httpClient)
if err != nil {
log.WithField("Path", fullPath).Info("Encryption failed")
return err
Expand Down Expand Up @@ -194,7 +213,7 @@ func (key *MasterKey) Decrypt() ([]byte, error) {
func (key *MasterKey) DecryptContext(ctx context.Context) ([]byte, error) {
fullPath := key.decryptPath()

client, err := vaultClient(key.VaultAddress, key.token)
client, err := vaultClient(key.VaultAddress, key.token, key.httpClient)
if err != nil {
log.WithField("Path", fullPath).Info("Decryption failed")
return nil, err
Expand Down Expand Up @@ -308,10 +327,14 @@ func dataKeyFromSecret(secret *api.Secret) ([]byte, error) {

// vaultClient returns a new Vault client, configured with the given address
// and token.
func vaultClient(address, token string) (*api.Client, error) {
func vaultClient(address, token string, hc *http.Client) (*api.Client, error) {
cfg := api.DefaultConfig()
cfg.Address = address

if hc != nil {
cfg.HttpClient = hc
}

client, err := api.NewClient(cfg)
if err != nil {
return nil, fmt.Errorf("cannot create Vault client: %w", err)
Expand Down
16 changes: 8 additions & 8 deletions hcvault/keysource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func TestMasterKey_Encrypt(t *testing.T) {
assert.NoError(t, key.Encrypt(dataKey))
assert.NotEmpty(t, key.EncryptedKey)

client, err := vaultClient(key.VaultAddress, key.token)
client, err := vaultClient(key.VaultAddress, key.token, nil)
assert.NoError(t, err)

payload := decryptPayload(key.EncryptedKey)
Expand Down Expand Up @@ -230,7 +230,7 @@ func TestMasterKey_Decrypt(t *testing.T) {
(Token(testVaultToken)).ApplyToMasterKey(key)
assert.NoError(t, createVaultKey(key))

client, err := vaultClient(key.VaultAddress, key.token)
client, err := vaultClient(key.VaultAddress, key.token, nil)
assert.NoError(t, err)

dataKey := []byte("the heart of a shrimp is located in its head")
Expand Down Expand Up @@ -368,7 +368,7 @@ func Test_vaultClient(t *testing.T) {
t.Setenv("VAULT_TOKEN", "")
t.Setenv("HOME", tmpDir)

got, err := vaultClient(testVaultAddress, "")
got, err := vaultClient(testVaultAddress, "", nil)
assert.NoError(t, err)
assert.NotNil(t, got)
assert.Empty(t, got.Token())
Expand All @@ -378,7 +378,7 @@ func Test_vaultClient(t *testing.T) {
token := "test-token"
t.Setenv("VAULT_TOKEN", token)

got, err := vaultClient(testVaultAddress, "")
got, err := vaultClient(testVaultAddress, "", nil)
assert.NoError(t, err)
assert.NotNil(t, got)
assert.Equal(t, token, got.Token())
Expand All @@ -388,7 +388,7 @@ func Test_vaultClient(t *testing.T) {
ignored := "test-token"
t.Setenv("VAULT_TOKEN", ignored)

got, err := vaultClient(testVaultAddress, testVaultToken)
got, err := vaultClient(testVaultAddress, testVaultToken, nil)
assert.NoError(t, err)
assert.NotNil(t, got)
assert.Equal(t, testVaultToken, got.Token())
Expand All @@ -407,7 +407,7 @@ func Test_vaultClient(t *testing.T) {
t.Setenv("VAULT_TOKEN", "")
t.Setenv("HOME", tmpDir)

got, err := vaultClient(testVaultAddress, "")
got, err := vaultClient(testVaultAddress, "", nil)
assert.NoError(t, err)
assert.NotNil(t, got)
assert.Equal(t, token, got.Token())
Expand Down Expand Up @@ -487,7 +487,7 @@ func Test_engineAndKeyFromPath(t *testing.T) {

// enableVaultTransit enables the Vault Transit backend on the given enginePath.
func enableVaultTransit(address, token, enginePath string) error {
client, err := vaultClient(address, token)
client, err := vaultClient(address, token, nil)
if err != nil {
return fmt.Errorf("cannot create Vault client: %w", err)
}
Expand All @@ -504,7 +504,7 @@ func enableVaultTransit(address, token, enginePath string) error {
// createVaultKey creates a new RSA-4096 Vault key using the data from the
// provided MasterKey.
func createVaultKey(key *MasterKey) error {
client, err := vaultClient(key.VaultAddress, key.token)
client, err := vaultClient(key.VaultAddress, key.token, nil)
if err != nil {
return fmt.Errorf("cannot create Vault client: %w", err)
}
Expand Down
22 changes: 22 additions & 0 deletions kms/keysource.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"context"
"encoding/base64"
"fmt"
"net/http"
"os"
"regexp"
"sort"
Expand Down Expand Up @@ -79,6 +80,8 @@ type MasterKey struct {
// injected using e.g. an environment variable. The field is not publicly
// exposed, nor configurable.
baseEndpoint string
// httpClient is used to override the default HTTP client used by the AWS client.
httpClient *http.Client
}

// NewMasterKey creates a new MasterKey from an ARN, role and context, setting
Expand Down Expand Up @@ -233,6 +236,22 @@ func (c CredentialsProvider) ApplyToMasterKey(key *MasterKey) {
key.credentialsProvider = c.provider
}

// HTTPClient is a wrapper around http.Client used for configuring the
// AWS KMS client.
type HTTPClient struct {
hc *http.Client
}

// NewHTTPClient creates a new HTTPClient with the provided http.Client.
func NewHTTPClient(hc *http.Client) *HTTPClient {
return &HTTPClient{hc: hc}
}

// ApplyToMasterKey configures the HTTP client on the provided key.
func (h HTTPClient) ApplyToMasterKey(key *MasterKey) {
key.httpClient = h.hc
}

// Encrypt takes a SOPS data key, encrypts it with KMS and stores the result
// in the EncryptedKey field.
//
Expand Down Expand Up @@ -385,6 +404,9 @@ func (key MasterKey) createKMSConfig(ctx context.Context) (*aws.Config, error) {
lo.SharedConfigProfile = key.AwsProfile
}
lo.Region = region
if key.httpClient != nil {
lo.HTTPClient = key.httpClient
}
return nil
})
if err != nil {
Expand Down
Loading