Skip to content

Commit

Permalink
Fix issue of not being able to decode to pointers of common types suc…
Browse files Browse the repository at this point in the history
…h as *time.Time
  • Loading branch information
greencoda committed Feb 2, 2025
1 parent ced4103 commit 40b6e95
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 1 deletion.
2 changes: 1 addition & 1 deletion commonDecoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func decodeURL(targetValue reflect.Value, sourceValue any) error {
return fmt.Errorf("%w: %w", errCannotParseURL, err)
}

targetValue.Set(reflect.ValueOf(parsedURL))
targetValue.Set(reflect.ValueOf(*parsedURL))

return nil
}
23 changes: 23 additions & 0 deletions commonDecoders_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,29 @@ func (s *CommonDecodersTestSuite) Test_Decode_Time() {
s.NoError(decodeErr)
}

func (s *CommonDecodersTestSuite) Test_Decode_Time_Ptr() {
s.valueContainer.On("Errors").Return([]error{})
s.valueContainer.On("Get").Return([]any{map[string]any{"test_time": "2025-01-13T16:00:00+09:00"}})

loadErr := s.configSet.Load(s.valueContainer)
s.Require().NoError(loadErr)

type targetStruct struct {
TestTime *time.Time `cfg:"test_time"`
}

var (
target targetStruct
expected = time.Date(2025, 1, 13, 16, 0, 0, 0, time.FixedZone("", 9*60*60))
)

decodeErr := s.configSet.Decode(&target)

s.Require().NotNil(target.TestTime)
s.Equal(expected, *target.TestTime)
s.NoError(decodeErr)
}

func (s *CommonDecodersTestSuite) Test_Decode_Time_FromNil() {
s.valueContainer.On("Errors").Return([]error{})
s.valueContainer.On("Get").Return([]any{map[string]any{"test_time": nil}})
Expand Down
9 changes: 9 additions & 0 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,15 @@ func (c *ConfigSet) decodeField(targetValue reflect.Value, fieldOpts fieldOption
)

if commonDecoder := getCommonDecoder(targetValue.Type()); commonDecoder != nil {
// Handle pointer types
for targetValue.Kind() == reflect.Ptr {
if targetValue.IsNil() {
targetValue.Set(reflect.New(targetValue.Type().Elem()))
}

targetValue = targetValue.Elem()
}

if err := commonDecoder(targetValue, fieldConfigValue); err != nil {
if fieldOpts.strict {
return 0, fmt.Errorf("error decoding field value: %w", err)
Expand Down
54 changes: 54 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ func (o *unmarshalerNumber) UnmarshalText(raw []byte) error {
return fmt.Errorf("%w: %w", errCannotParseInt, err)
}

if n < 0 || n > 255 {
return fmt.Errorf("%w: value out of range for uint8", errCannotParseInt)
}

*o = unmarshalerNumber(n)

return nil
Expand Down Expand Up @@ -203,6 +207,56 @@ func (s *DecodeTestSuite) Test_Decode_Map() {
s.NoError(decodeErr)
}

func (s *DecodeTestSuite) Test_Decode_StructMap() {
s.valueContainer.On("Errors").Return([]error{})
s.valueContainer.On("Get").Return([]any{
map[string]any{
"test_map": map[string]any{
"test_map_key_1": map[string]any{
"test_int": 1,
},
"test_map_key_2": map[string]any{
"test_int": 2,
},
"test_map_key_3": map[string]any{
"test_int": 3,
},
},
},
})

loadErr := s.configSet.Load(s.valueContainer)
s.Require().NoError(loadErr)

type targetSubStruct struct {
TestInt int `cfg:"test_int"`
}

type targetStruct struct {
TestMap map[string]targetSubStruct `cfg:"test_map"`
}

var (
target targetStruct
expected = map[string]targetSubStruct{
"test_map_key_1": {
TestInt: 1,
},
"test_map_key_2": {
TestInt: 2,
},
"test_map_key_3": {
TestInt: 3,
},
}
)

decodeErr := s.configSet.Decode(&target)

s.Equal(expected, target.TestMap)
s.NoError(decodeErr)
}

func (s *DecodeTestSuite) Test_Decode_Map_FromInvalidKeyFormat() {
s.valueContainer.On("Errors").Return([]error{})
s.valueContainer.On("Get").Return([]any{
Expand Down
23 changes: 23 additions & 0 deletions primitiveDecoders_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,29 @@ func (s *PrimitiveDecodersTestSuite) Test_DecodeString() {
s.NoError(decodeErr)
}

func (s *PrimitiveDecodersTestSuite) Test_DecodeString_Ptr() {
s.valueContainer.On("Errors").Return([]error{})
s.valueContainer.On("Get").Return([]any{map[string]any{"test_string": "test"}})

loadErr := s.configSet.Load(s.valueContainer)
s.Require().NoError(loadErr)

type targetStruct struct {
TestString *string `cfg:"test_string"`
}

var (
target targetStruct
expected = "test"
)

decodeErr := s.configSet.Decode(&target)

s.Require().NotNil(target.TestString)
s.Equal(expected, *target.TestString)
s.NoError(decodeErr)
}

func (s *PrimitiveDecodersTestSuite) Test_DecodeBool() {
s.valueContainer.On("Errors").Return([]error{})
s.valueContainer.On("Get").Return([]any{map[string]any{"test_bool": true}})
Expand Down

0 comments on commit 40b6e95

Please sign in to comment.