diff --git a/config/config.go b/config/config.go index 06ad57370..fa2e24062 100644 --- a/config/config.go +++ b/config/config.go @@ -170,23 +170,86 @@ type destinationRule struct { } type creationRule struct { - PathRegex string `yaml:"path_regex"` - KMS string - AwsProfile string `yaml:"aws_profile"` - Age string `yaml:"age"` - PGP string - GCPKMS string `yaml:"gcp_kms"` - AzureKeyVault string `yaml:"azure_keyvault"` - VaultURI string `yaml:"hc_vault_transit_uri"` - KeyGroups []keyGroup `yaml:"key_groups"` - ShamirThreshold int `yaml:"shamir_threshold"` - UnencryptedSuffix string `yaml:"unencrypted_suffix"` - EncryptedSuffix string `yaml:"encrypted_suffix"` - UnencryptedRegex string `yaml:"unencrypted_regex"` - EncryptedRegex string `yaml:"encrypted_regex"` - UnencryptedCommentRegex string `yaml:"unencrypted_comment_regex"` - EncryptedCommentRegex string `yaml:"encrypted_comment_regex"` - MACOnlyEncrypted bool `yaml:"mac_only_encrypted"` + PathRegex string `yaml:"path_regex"` + KMS interface{} `yaml:"kms"` // string or []string + AwsProfile string `yaml:"aws_profile"` + Age interface{} `yaml:"age"` // string or []string + PGP interface{} `yaml:"pgp"` // string or []string + GCPKMS interface{} `yaml:"gcp_kms"` // string or []string + AzureKeyVault interface{} `yaml:"azure_keyvault"` // string or []string + VaultURI interface{} `yaml:"hc_vault_transit_uri"` // string or []string + KeyGroups []keyGroup `yaml:"key_groups"` + ShamirThreshold int `yaml:"shamir_threshold"` + UnencryptedSuffix string `yaml:"unencrypted_suffix"` + EncryptedSuffix string `yaml:"encrypted_suffix"` + UnencryptedRegex string `yaml:"unencrypted_regex"` + EncryptedRegex string `yaml:"encrypted_regex"` + UnencryptedCommentRegex string `yaml:"unencrypted_comment_regex"` + EncryptedCommentRegex string `yaml:"encrypted_comment_regex"` + MACOnlyEncrypted bool `yaml:"mac_only_encrypted"` +} + +// Helper methods to safely extract keys as []string +func (c *creationRule) GetKMSKeys() ([]string, error) { + return parseKeyField(c.KMS, "kms") +} + +func (c *creationRule) GetAgeKeys() ([]string, error) { + return parseKeyField(c.Age, "age") +} + +func (c *creationRule) GetPGPKeys() ([]string, error) { + return parseKeyField(c.PGP, "pgp") +} + +func (c *creationRule) GetGCPKMSKeys() ([]string, error) { + return parseKeyField(c.GCPKMS, "gcp_kms") +} + +func (c *creationRule) GetAzureKeyVaultKeys() ([]string, error) { + return parseKeyField(c.AzureKeyVault, "azure_keyvault") +} + +func (c *creationRule) GetVaultURIs() ([]string, error) { + return parseKeyField(c.VaultURI, "hc_vault_transit_uri") +} + +// Utility function to handle both string and []string +func parseKeyField(field interface{}, fieldName string) ([]string, error) { + if field == nil { + return []string{}, nil + } + + switch v := field.(type) { + case string: + if v == "" { + return []string{}, nil + } + // Existing CSV parsing logic + keys := strings.Split(v, ",") + result := make([]string, 0, len(keys)) + for _, key := range keys { + trimmed := strings.TrimSpace(key) + if trimmed != "" { // Skip empty strings (fixes trailing comma issue) + result = append(result, trimmed) + } + } + return result, nil + case []interface{}: + result := make([]string, len(v)) + for i, item := range v { + if str, ok := item.(string); ok { + result[i] = str + } else { + return nil, fmt.Errorf("invalid %s key configuration: expected string in list, got %T", fieldName, item) + } + } + return result, nil + case []string: + return v, nil + default: + return nil, fmt.Errorf("invalid %s key configuration: expected string, []string, or nil, got %T", fieldName, field) + } } func NewStoresConfig() *StoresConfig { @@ -279,6 +342,14 @@ func extractMasterKeys(group keyGroup) (sops.KeyGroup, error) { return deduplicateKeygroup(keyGroup), nil } +func getKeysWithValidation(getKeysFunc func() ([]string, error), keyType string) ([]string, error) { + keys, err := getKeysFunc() + if err != nil { + return nil, fmt.Errorf("invalid %s key configuration: %w", keyType, err) + } + return keys, nil +} + func getKeyGroupsFromCreationRule(cRule *creationRule, kmsEncryptionContext map[string]*string) ([]sops.KeyGroup, error) { var groups []sops.KeyGroup if len(cRule.KeyGroups) > 0 { @@ -291,8 +362,13 @@ func getKeyGroupsFromCreationRule(cRule *creationRule, kmsEncryptionContext map[ } } else { var keyGroup sops.KeyGroup - if cRule.Age != "" { - ageKeys, err := age.MasterKeysFromRecipients(cRule.Age) + ageKeys, err := getKeysWithValidation(cRule.GetAgeKeys, "age") + if err != nil { + return nil, err + } + + if len(ageKeys) > 0 { + ageKeys, err := age.MasterKeysFromRecipients(strings.Join(ageKeys, ",")) if err != nil { return nil, err } else { @@ -301,23 +377,43 @@ func getKeyGroupsFromCreationRule(cRule *creationRule, kmsEncryptionContext map[ } } } - for _, k := range pgp.MasterKeysFromFingerprintString(cRule.PGP) { + pgpKeys, err := getKeysWithValidation(cRule.GetPGPKeys, "pgp") + if err != nil { + return nil, err + } + for _, k := range pgp.MasterKeysFromFingerprintString(strings.Join(pgpKeys, ",")) { keyGroup = append(keyGroup, k) } - for _, k := range kms.MasterKeysFromArnString(cRule.KMS, kmsEncryptionContext, cRule.AwsProfile) { + kmsKeys, err := getKeysWithValidation(cRule.GetKMSKeys, "kms") + if err != nil { + return nil, err + } + for _, k := range kms.MasterKeysFromArnString(strings.Join(kmsKeys, ","), kmsEncryptionContext, cRule.AwsProfile) { keyGroup = append(keyGroup, k) } - for _, k := range gcpkms.MasterKeysFromResourceIDString(cRule.GCPKMS) { + gcpkmsKeys, err := getKeysWithValidation(cRule.GetGCPKMSKeys, "gcpkms") + if err != nil { + return nil, err + } + for _, k := range gcpkms.MasterKeysFromResourceIDString(strings.Join(gcpkmsKeys, ",")) { keyGroup = append(keyGroup, k) } - azureKeys, err := azkv.MasterKeysFromURLs(cRule.AzureKeyVault) + azKeys, err := getKeysWithValidation(cRule.GetAzureKeyVaultKeys, "azure_keyvault") + if err != nil { + return nil, err + } + azureKeys, err := azkv.MasterKeysFromURLs(strings.Join(azKeys, ",")) if err != nil { return nil, err } for _, k := range azureKeys { keyGroup = append(keyGroup, k) } - vaultKeys, err := hcvault.NewMasterKeysFromURIs(cRule.VaultURI) + vaultKeyUris, err := getKeysWithValidation(cRule.GetVaultURIs, "vault") + if err != nil { + return nil, err + } + vaultKeys, err := hcvault.NewMasterKeysFromURIs(strings.Join(vaultKeyUris, ",")) if err != nil { return nil, err } diff --git a/config/config_test.go b/config/config_test.go index cb8340a7f..7c869fedc 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -718,3 +718,40 @@ func TestLoadConfigFileWithVaultDestinationRules(t *testing.T) { assert.NotNil(t, conf.Destination) assert.Contains(t, conf.Destination.Path("barfoo"), "/v1/kv/barfoo/barfoo") } + +func TestCreationRuleNativeKeyLists(t *testing.T) { + var sampleConfigWithNativeKeyLists = []byte(` +creation_rules: + - path_regex: native_list* + pgp: + - "85D77543B3D624B63CEA9E6DBC17301B491B3F21" # name@email.com + - "FBC7B9E2A4F9289AC0C1D4843D16CEE4A27381B4" # server_XYZ + kms: + - "arn:aws:kms:us-east-1:123456789012:key/12345678-1234-1234-1234-123456789012" + age: + - "age1ql3z7hjy54pw3hyww5ayyfg7zqgvc7w3j2elw8zmrj2kg5sfn9aqmcac8p" + gcp_kms: + - "projects/test-project/locations/global/keyRings/test-ring/cryptoKeys/test-key" + hc_vault_transit_uri: + - "https://vault.example.com:8200/v1/transit/keys/key1" +`) + conf, err := parseCreationRuleForFile(parseConfigFile(sampleConfigWithNativeKeyLists, t), "/conf/path", "native_list_test", nil) + assert.Nil(t, err) + if conf == nil { + t.Fatal("Expected configuration but got nil") + } + + assert.True(t, len(conf.KeyGroups) == 1) + assert.True(t, len(conf.KeyGroups[0]) == 6) + + keyTypeCounts := make(map[string]int) + for _, key := range conf.KeyGroups[0] { + keyTypeCounts[key.TypeToIdentifier()]++ + } + + assert.Equal(t, 2, keyTypeCounts["pgp"]) + assert.Equal(t, 1, keyTypeCounts["kms"]) + assert.Equal(t, 1, keyTypeCounts["age"]) + assert.Equal(t, 1, keyTypeCounts["gcp_kms"]) + assert.Equal(t, 1, keyTypeCounts["hc_vault"]) +}