diff --git a/viper.go b/viper.go index 7de2e78e4..53ded8a34 100644 --- a/viper.go +++ b/viper.go @@ -1111,7 +1111,32 @@ func Unmarshal(rawVal any, opts ...DecoderConfigOption) error { } func (v *Viper) Unmarshal(rawVal any, opts ...DecoderConfigOption) error { - return decode(v.AllSettings(), defaultDecoderConfig(rawVal, opts...)) + // TODO: make this optional? + structKeys, err := v.decodeStructKeys(rawVal, opts...) + if err != nil { + return err + } + + // TODO: struct keys should be enough? + return decode(v.getSettings(append(v.AllKeys(), structKeys...)), defaultDecoderConfig(rawVal, opts...)) +} + +func (v *Viper) decodeStructKeys(input any, opts ...DecoderConfigOption) ([]string, error) { + var structKeyMap map[string]any + + err := decode(input, defaultDecoderConfig(&structKeyMap, opts...)) + if err != nil { + return nil, err + } + + flattenedStructKeyMap := v.flattenAndMergeMap(map[string]bool{}, structKeyMap, "") + + r := make([]string, 0, len(flattenedStructKeyMap)) + for v := range flattenedStructKeyMap { + r = append(r, v) + } + + return r, nil } // defaultDecoderConfig returns default mapstructure.DecoderConfig with support @@ -2098,9 +2123,13 @@ outer: func AllSettings() map[string]any { return v.AllSettings() } func (v *Viper) AllSettings() map[string]any { + return v.getSettings(v.AllKeys()) +} + +func (v *Viper) getSettings(keys []string) map[string]any { m := map[string]any{} // start from the list of keys, and construct the map one value at a time - for _, k := range v.AllKeys() { + for _, k := range keys { value := v.Get(k) if value == nil { // should not happen, since AllKeys() returns only keys holding a value, diff --git a/viper_test.go b/viper_test.go index 0e416e7df..b8274a90b 100644 --- a/viper_test.go +++ b/viper_test.go @@ -948,6 +948,105 @@ func TestUnmarshalWithDecoderOptions(t *testing.T) { }, &C) } +func TestUnmarshalWithAutomaticEnv(t *testing.T) { + t.Setenv("PORT", "1313") + t.Setenv("NAME", "Steve") + t.Setenv("DURATION", "1s1ms") + t.Setenv("MODES", "1,2,3") + t.Setenv("SECRET", "42") + t.Setenv("FILESYSTEM_SIZE", "4096") + + type AuthConfig struct { + Secret string `mapstructure:"secret"` + } + + type StorageConfig struct { + Size int `mapstructure:"size"` + } + + type Configuration struct { + Port int `mapstructure:"port"` + Name string `mapstructure:"name"` + Duration time.Duration `mapstructure:"duration"` + + // Infer name from struct + Modes []int + + // Squash nested struct (omit prefix) + Authentication AuthConfig `mapstructure:",squash"` + + // Different key + Storage StorageConfig `mapstructure:"filesystem"` + + // Omitted field + Flag bool `mapstructure:"flag"` + } + + v := New() + v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + v.AutomaticEnv() + + t.Run("OK", func(t *testing.T) { + var config Configuration + if err := v.Unmarshal(&config); err != nil { + t.Fatalf("unable to decode into struct, %v", err) + } + + assert.Equal( + t, + Configuration{ + Name: "Steve", + Port: 1313, + Duration: time.Second + time.Millisecond, + Modes: []int{1, 2, 3}, + Authentication: AuthConfig{ + Secret: "42", + }, + Storage: StorageConfig{ + Size: 4096, + }, + }, + config, + ) + }) + + t.Run("Precedence", func(t *testing.T) { + var config Configuration + + v.Set("port", 1234) + if err := v.Unmarshal(&config); err != nil { + t.Fatalf("unable to decode into struct, %v", err) + } + + assert.Equal( + t, + Configuration{ + Name: "Steve", + Port: 1234, + Duration: time.Second + time.Millisecond, + Modes: []int{1, 2, 3}, + Authentication: AuthConfig{ + Secret: "42", + }, + Storage: StorageConfig{ + Size: 4096, + }, + }, + config, + ) + }) + + t.Run("Unset", func(t *testing.T) { + var config Configuration + + err := v.Unmarshal(&config, func(config *mapstructure.DecoderConfig) { + config.ErrorUnset = true + }) + + assert.Error(t, err, "expected viper.Unmarshal to return error due to unset field 'FLAG'") + }) +} + func TestBindPFlags(t *testing.T) { v := New() // create independent Viper object flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError)