diff --git a/pkg/credentials/credentials.go b/pkg/credentials/credentials.go index aacdcf05e..e214bedec 100644 --- a/pkg/credentials/credentials.go +++ b/pkg/credentials/credentials.go @@ -151,7 +151,7 @@ type ProviderEnvironmentVariable struct { } func Any(credentialsFilePath string) (map[string]string, error) { - credentialsFinder, err := newCredsFinder(credentialsFilePath, TypeUniversal) + credentialsFinder, err := newCredentialsFinder(withYAMLFile(credentialsFilePath)) if err != nil { return nil, err } @@ -159,7 +159,7 @@ func Any(credentialsFilePath string) (map[string]string, error) { creds := map[string]string{} for _, key := range allKeys { - if val := credentialsFinder(key); val != "" { + if val := credentialsFinder.get(key); val != "" { creds[key] = val // NB: We want to use Equinix Metal env vars everywhere, even if // users has PACKET_ env vars on their systems. @@ -178,11 +178,12 @@ func Any(credentialsFilePath string) (map[string]string, error) { // ProviderCredentials implements fetching credentials for each supported provider func ProviderCredentials(cloudProvider kubeoneapi.CloudProviderSpec, credentialsFilePath string, credentialsType Type) (map[string]string, error) { - credentialsFinder, err := newCredsFinder(credentialsFilePath, credentialsType) + credentialsFinderStore, err := newCredentialsFinder(withYAMLFile(credentialsFilePath), withType(credentialsType)) if err != nil { return nil, err } + credentialsFinder := credentialsFinderStore.lookupFunc() switch { case cloudProvider.AWS != nil: return credentialsFinder.aws() @@ -277,43 +278,76 @@ func ProviderCredentials(cloudProvider kubeoneapi.CloudProviderSpec, credentials } } -func newCredsFinder(credentialsFilePath string, credentialsType Type) (lookupFunc, error) { - staticMap := map[string]string{} - finder := func(name string) string { - switch { - case credentialsType != TypeUniversal: - typedName := string(credentialsType) + "_" + name - if val := os.Getenv(typedName); val != "" { - return val - } - if val, ok := staticMap[typedName]; ok && val != "" { - return val - } +func withYAMLFile(filePath string) func(*credentialsFinder) error { + return func(cf *credentialsFinder) error { + if filePath == "" { + return nil + } - fallthrough - default: - if val := os.Getenv(name); val != "" { - return val - } + buf, err := os.ReadFile(filePath) + if err != nil { + return fail.Runtime(err, "reading credentials file") + } - return staticMap[name] + if err = yaml.Unmarshal(buf, &cf.static); err != nil { + return fail.Runtime(err, "unmarshalling credentials file") } + + return nil } +} + +func withType(typ Type) func(*credentialsFinder) error { + return func(cf *credentialsFinder) error { + cf.typ = typ - if credentialsFilePath == "" { - return finder, nil + return nil } +} - buf, err := os.ReadFile(credentialsFilePath) - if err != nil { - return nil, fail.Runtime(err, "loading credentials file") +func newCredentialsFinder(opts ...func(*credentialsFinder) error) (*credentialsFinder, error) { + cf := credentialsFinder{ + static: map[string]string{}, + dynamic: os.Getenv, } - if err = yaml.Unmarshal(buf, &staticMap); err != nil { - return nil, fail.Runtime(err, "unmarshalling credentials file") + for _, optFn := range opts { + if err := optFn(&cf); err != nil { + return nil, err + } + } + + return &cf, nil +} + +type credentialsFinder struct { + static map[string]string + dynamic func(string) string + typ Type +} + +func (cf *credentialsFinder) lookupFunc() lookupFunc { return cf.get } + +func (cf *credentialsFinder) typedKey(name string) string { + return string(cf.typ) + "_" + name +} + +func (cf *credentialsFinder) fetch(name string) string { + if val := cf.static[name]; val != "" { + return val + } + + return cf.dynamic(name) +} + +func (cf *credentialsFinder) get(name string) string { + if cf.typ != TypeUniversal { + if val := cf.fetch(cf.typedKey(name)); val != "" { + return val + } } - return finder, nil + return cf.fetch(name) } // lookupFunc is function that retrieves credentials from the sources diff --git a/pkg/credentials/credentials_test.go b/pkg/credentials/credentials_test.go index bd8461b7f..5e2adc40b 100644 --- a/pkg/credentials/credentials_test.go +++ b/pkg/credentials/credentials_test.go @@ -248,3 +248,95 @@ func TestVmwareCloudDirectorValidationFunc(t *testing.T) { }) } } + +func TestCredentialsFinder(t *testing.T) { + withDynamicFixture := func(dynamicFn func(string) string) func(*credentialsFinder) error { + return func(cf *credentialsFinder) error { + cf.dynamic = dynamicFn + + return nil + } + } + + withStaticFixture := func(static map[string]string) func(*credentialsFinder) error { + return func(cf *credentialsFinder) error { + cf.static = static + + return nil + } + } + + tests := []struct { + name string + key string + want string + opts []func(*credentialsFinder) error + }{ + { + name: "static universal", + key: "key1", + want: "val1", + opts: []func(*credentialsFinder) error{ + withStaticFixture(map[string]string{ + "key1": "val1", + }), + }, + }, + { + name: "static with type OSM", + key: "key1", + want: "OSM_val1", + opts: []func(*credentialsFinder) error{ + withType(TypeOSM), + withStaticFixture(map[string]string{ + "OSM_key1": "OSM_val1", + }), + }, + }, + { + name: "dynamic with type OSM", + key: "key1", + want: "OSM_val1", + opts: []func(*credentialsFinder) error{ + withType(TypeOSM), + withStaticFixture(map[string]string{ + "key1": "from_static", + }), + withDynamicFixture(func(key string) string { + return map[string]string{ + "OSM_key1": "OSM_val1", + }[key] + }), + }, + }, + { + name: "static precedence over dynamic with type OSM", + key: "key1", + want: "from_static", + opts: []func(*credentialsFinder) error{ + withType(TypeOSM), + withStaticFixture(map[string]string{ + "OSM_key1": "from_static", + }), + withDynamicFixture(func(key string) string { + return map[string]string{ + "OSM_key1": "from_dynamic", + }[key] + }), + }, + }, + } + + for _, tcase := range tests { + t.Run(tcase.name, func(t *testing.T) { + finder, err := newCredentialsFinder(tcase.opts...) + if err != nil { + t.Fatalf("got unexpcted error: %v", err) + } + + if result := finder.get(tcase.key); result != tcase.want { + t.Errorf("get(%q)=%q, want %q", tcase.key, result, tcase.want) + } + }) + } +}