diff --git a/altsrc/map_input_source.go b/altsrc/map_input_source.go index 37709c77ed..41221b1875 100644 --- a/altsrc/map_input_source.go +++ b/altsrc/map_input_source.go @@ -71,24 +71,28 @@ func (fsm *MapInputSource) Int(name string) (int, error) { func (fsm *MapInputSource) Duration(name string) (time.Duration, error) { otherGenericValue, exists := fsm.valueMap[name] if exists { - otherValue, isType := otherGenericValue.(time.Duration) - if !isType { - return 0, incorrectTypeForFlagError(name, "duration", otherGenericValue) - } - return otherValue, nil + return castDuration(name, otherGenericValue) } nestedGenericValue, exists := nestedVal(name, fsm.valueMap) if exists { - otherValue, isType := nestedGenericValue.(time.Duration) - if !isType { - return 0, incorrectTypeForFlagError(name, "duration", nestedGenericValue) - } - return otherValue, nil + return castDuration(name, nestedGenericValue) } return 0, nil } +func castDuration(name string, value interface{}) (time.Duration, error) { + if otherValue, isType := value.(time.Duration); isType { + return otherValue, nil + } + otherStringValue, isType := value.(string) + parsedValue, err := time.ParseDuration(otherStringValue) + if !isType || err != nil { + return 0, incorrectTypeForFlagError(name, "duration", value) + } + return parsedValue, nil +} + // Float64 returns an float64 from the map if it exists otherwise returns 0 func (fsm *MapInputSource) Float64(name string) (float64, error) { otherGenericValue, exists := fsm.valueMap[name] diff --git a/altsrc/map_input_source_test.go b/altsrc/map_input_source_test.go new file mode 100644 index 0000000000..a921d0493a --- /dev/null +++ b/altsrc/map_input_source_test.go @@ -0,0 +1,25 @@ +package altsrc + +import ( + "testing" + "time" +) + +func TestMapDuration(t *testing.T) { + inputSource := &MapInputSource{ + file: "test", + valueMap: map[interface{}]interface{}{ + "duration_of_duration_type": time.Minute, + "duration_of_string_type": "1m", + "duration_of_int_type": 1000, + }, + } + d, err := inputSource.Duration("duration_of_duration_type") + expect(t, time.Minute, d) + expect(t, nil, err) + d, err = inputSource.Duration("duration_of_string_type") + expect(t, time.Minute, d) + expect(t, nil, err) + d, err = inputSource.Duration("duration_of_int_type") + refute(t, nil, err) +}