From c4fec32eef492d227dc01258446f0a27ab3d76fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn=20Erik=20Pedersen?= Date: Mon, 10 Oct 2016 10:35:18 +0200 Subject: [PATCH] Restore performance for the simple case ``` benchmark old ns/op new ns/op delta BenchmarkGetBool-4 1021 479 -53.09% BenchmarkGetBoolFromMap-4 6.56 6.39 -2.59% benchmark old allocs new allocs delta BenchmarkGetBool-4 6 4 -33.33% BenchmarkGetBoolFromMap-4 0 0 +0.00% benchmark old bytes new bytes delta BenchmarkGetBool-4 113 49 -56.64% BenchmarkGetBoolFromMap-4 0 0 +0.00% ``` Fixes #249 Fixes https://github.com/spf13/hugo/issues/2536 --- viper.go | 71 +++++++++++++++++++++++++---------- viper_test.go | 102 +++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 133 insertions(+), 40 deletions(-) diff --git a/viper.go b/viper.go index 8f2784944f..a03540cd54 100644 --- a/viper.go +++ b/viper.go @@ -399,17 +399,42 @@ func (v *Viper) providerPathExists(p *defaultRemoteProvider) bool { return false } +// searchMapForKey may end up traversing the map if the key references a nested +// item (foo.bar), but will use a fast path for the common case. +// Note: This assumes that the key given is already lowercase. +func (v *Viper) searchMapForKey(source map[string]interface{}, lcaseKey string) interface{} { + if !strings.Contains(lcaseKey, v.keyDelim) { + v, ok := source[lcaseKey] + if ok { + return v + } + return nil + } + + path := strings.Split(lcaseKey, v.keyDelim) + return v.searchMap(source, path) +} + // searchMap recursively searches for a value for path in source map. // Returns nil if not found. +// Note: This assumes that the path entries are lower cased. func (v *Viper) searchMap(source map[string]interface{}, path []string) interface{} { if len(path) == 0 { return source } + // Fast path + if len(path) == 1 { + if v, ok := source[path[0]]; ok { + return v + } + return nil + } + var ok bool var next interface{} for k, v := range source { - if strings.ToLower(k) == strings.ToLower(path[0]) { + if k == path[0] { ok = true next = v break @@ -594,8 +619,8 @@ func (v *Viper) Get(key string) interface{} { valType := val if v.typeByDefValue { - path := strings.Split(lcaseKey, v.keyDelim) - defVal := v.searchMap(v.defaults, path) + // TODO(bep) this branch isn't covered by a single test. + defVal := v.searchMapForKey(v.defaults, lcaseKey) if defVal != nil { valType = defVal } @@ -841,32 +866,39 @@ func (v *Viper) BindEnv(input ...string) error { // Viper will check in the following order: // flag, env, config file, key/value store, default. // Viper will check to see if an alias exists first. -func (v *Viper) find(key string) interface{} { - var val interface{} - var exists bool +// Note: this assumes a lower-cased key given. +func (v *Viper) find(lcaseKey string) interface{} { + + var ( + val interface{} + exists bool + path = strings.Split(lcaseKey, v.keyDelim) + nested = len(path) > 1 + ) // compute the path through the nested maps to the nested value - path := strings.Split(key, v.keyDelim) - if shadow := v.isPathShadowedInDeepMap(path, castMapStringToMapInterface(v.aliases)); shadow != "" { + if nested && v.isPathShadowedInDeepMap(path, castMapStringToMapInterface(v.aliases)) != "" { return nil } // if the requested key is an alias, then return the proper key - key = v.realKey(key) - // re-compute the path - path = strings.Split(key, v.keyDelim) + lcaseKey = v.realKey(lcaseKey) // Set() override first - val = v.searchMap(v.override, path) + val = v.searchMapForKey(v.override, lcaseKey) if val != nil { return val } - if shadow := v.isPathShadowedInDeepMap(path, v.override); shadow != "" { + + path = strings.Split(lcaseKey, v.keyDelim) + nested = len(path) > 1 + + if nested && v.isPathShadowedInDeepMap(path, v.override) != "" { return nil } // PFlag override next - flag, exists := v.pflags[key] + flag, exists := v.pflags[lcaseKey] if exists && flag.HasChanged() { switch flag.ValueType() { case "int", "int8", "int16", "int32", "int64": @@ -880,7 +912,8 @@ func (v *Viper) find(key string) interface{} { return flag.ValueString() } } - if shadow := v.isPathShadowedInFlatMap(path, v.pflags); shadow != "" { + + if nested && v.isPathShadowedInFlatMap(path, v.pflags) != "" { return nil } @@ -888,14 +921,14 @@ func (v *Viper) find(key string) interface{} { if v.automaticEnvApplied { // even if it hasn't been registered, if automaticEnv is used, // check any Get request - if val = v.getEnv(v.mergeWithEnvPrefix(key)); val != "" { + if val = v.getEnv(v.mergeWithEnvPrefix(lcaseKey)); val != "" { return val } - if shadow := v.isPathShadowedInAutoEnv(path); shadow != "" { + if nested && v.isPathShadowedInAutoEnv(path) != "" { return nil } } - envkey, exists := v.env[key] + envkey, exists := v.env[lcaseKey] if exists { if val = v.getEnv(envkey); val != "" { return val @@ -934,7 +967,7 @@ func (v *Viper) find(key string) interface{} { // last chance: if no other value is returned and a flag does exist for the value, // get the flag's value even if the flag's value has not changed - if flag, exists := v.pflags[key]; exists { + if flag, exists := v.pflags[lcaseKey]; exists { switch flag.ValueType() { case "int", "int8", "int16", "int32", "int64": return cast.ToInt(flag.ValueString()) diff --git a/viper_test.go b/viper_test.go index 02d6eb1c81..f2cc56a3ce 100644 --- a/viper_test.go +++ b/viper_test.go @@ -18,6 +18,8 @@ import ( "testing" "time" + "github.com/spf13/cast" + "github.com/spf13/pflag" "github.com/stretchr/testify/assert" ) @@ -131,12 +133,18 @@ func initConfigs() { unmarshalReader(remote, v.kvstore) } -func initYAML() { +func initConfig(typ, config string) { Reset() - SetConfigType("yaml") - r := bytes.NewReader(yamlExample) + SetConfigType(typ) + r := strings.NewReader(config) - unmarshalReader(r, v.config) + if err := unmarshalReader(r, v.config); err != nil { + panic(err) + } +} + +func initYAML() { + initConfig("yaml", string(yamlExample)) } func initJSON() { @@ -435,13 +443,8 @@ func TestAllKeys(t *testing.T) { assert.Equal(t, all, AllSettings()) } -func TestCaseInSensitive(t *testing.T) { - assert.Equal(t, true, Get("hacker")) - Set("Title", "Checking Case") - assert.Equal(t, "Checking Case", Get("tItle")) -} - func TestAliasesOfAliases(t *testing.T) { + Set("Title", "Checking Case") RegisterAlias("Foo", "Bar") RegisterAlias("Bar", "Title") assert.Equal(t, "Checking Case", Get("FOO")) @@ -538,7 +541,6 @@ func TestBindPFlag(t *testing.T) { } func TestBoundCaseSensitivity(t *testing.T) { - assert.Equal(t, "brown", Get("eyes")) BindEnv("eYEs", "TURTLE_EYES") @@ -917,8 +919,19 @@ func TestSetConfigNameClearsFileCache(t *testing.T) { } func TestShadowedNestedValue(t *testing.T) { + + config := `name: steve +clothing: + jacket: leather + trousers: denim + pants: + size: large +` + initConfig("yaml", config) + + assert.Equal(t, "steve", GetString("name")) + polyester := "polyester" - initYAML() SetDefault("clothing.shirt", polyester) SetDefault("clothing.jacket.price", 100) @@ -942,18 +955,65 @@ func TestDotParameter(t *testing.T) { assert.Equal(t, expected, actual) } -func TestGetBool(t *testing.T) { - key := "BooleanKey" - v = New() - v.Set(key, true) - if !v.GetBool(key) { - t.Fatal("GetBool returned false") - } - if v.GetBool("NotFound") { - t.Fatal("GetBool returned true") +func TestCaseInSensitive(t *testing.T) { + for _, config := range []struct { + typ string + content string + }{ + {"yaml", ` +aBcD: 1 +eF: + gH: 2 + iJk: 3 + Lm: + nO: 4 + P: + Q: 5 + R: 6 +`}, + {"json", `{ + "aBcD": 1, + "eF": { + "iJk": 3, + "Lm": { + "P": { + "Q": 5, + "R": 6 + }, + "nO": 4 + }, + "gH": 2 + } +}`}, + {"toml", `aBcD = 1 +[eF] +gH = 2 +iJk = 3 +[eF.Lm] +nO = 4 +[eF.Lm.P] +Q = 5 +R = 6 +`}, + } { + doTestCaseInSensitive(t, config.typ, config.content) } } +func doTestCaseInSensitive(t *testing.T, typ, config string) { + initConfig(typ, config) + Set("RfD", true) + assert.Equal(t, true, Get("rfd")) + assert.Equal(t, true, Get("rFD")) + assert.Equal(t, 1, cast.ToInt(Get("abcd"))) + assert.Equal(t, 1, cast.ToInt(Get("Abcd"))) + assert.Equal(t, 2, cast.ToInt(Get("ef.gh"))) + assert.Equal(t, 3, cast.ToInt(Get("ef.ijk"))) + assert.Equal(t, 4, cast.ToInt(Get("ef.lm.no"))) + assert.Equal(t, 5, cast.ToInt(Get("ef.lm.p.q"))) + +} + func BenchmarkGetBool(b *testing.B) { key := "BenchmarkGetBool" v = New()