From 1784735a1ec2e97728e93c9c80ab08871bbb8965 Mon Sep 17 00:00:00 2001 From: ayratsa <119592562+ayratsa@users.noreply.github.com> Date: Sat, 30 Nov 2024 19:58:45 +0200 Subject: [PATCH] Fix slice deep map (#1) * Implement deep mapping of nested slices of structs * Add tests for deep mapping of nested slices of structs --- mapstructure.go | 71 ++++++++++++++++++++++++++++++++++++++++---- mapstructure_test.go | 43 +++++++++++++++++++++++++++ 2 files changed, 108 insertions(+), 6 deletions(-) diff --git a/mapstructure.go b/mapstructure.go index e77e63b..8c193a5 100644 --- a/mapstructure.go +++ b/mapstructure.go @@ -254,6 +254,13 @@ type DecoderConfig struct { // } Squash bool + // Deep will map structures in slices instead of copying them + // + // type Parent struct { + // Children []Child `mapstructure:",deep"` + // } + Deep bool + // Metadata is the struct that will contain extra metadata about // the decoding. If this is nil, then no metadata will be tracked. Metadata *Metadata @@ -999,6 +1006,9 @@ func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val re // If Squash is set in the config, we squash the field down. squash := d.config.Squash && v.Kind() == reflect.Struct && f.Anonymous + // If Deep is set in the config, set as default value. + deep := d.config.Deep + v = dereferencePtrToStructIfNeeded(v, d.config.TagName) // Determine the name of the key in the map @@ -1036,6 +1046,9 @@ func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val re continue } } + + deep = deep || strings.Index(tagValue[index+1:], "deep") != -1 + if keyNameTagValue := tagValue[:index]; keyNameTagValue != "" { keyName = keyNameTagValue } @@ -1082,6 +1095,41 @@ func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val re valMap.SetMapIndex(reflect.ValueOf(keyName), vMap) } + case reflect.Slice: + if deep { + var childType reflect.Type + switch v.Type().Elem().Kind() { + case reflect.Struct: + childType = reflect.TypeOf(map[string]interface{}{}) + default: + childType = v.Type().Elem() + } + + sType := reflect.SliceOf(childType) + + addrVal := reflect.New(sType) + + vSlice := reflect.MakeSlice(sType, v.Len(), v.Cap()) + + if v.Len() > 0 { + reflect.Indirect(addrVal).Set(vSlice) + + err := d.decode(keyName, v.Interface(), reflect.Indirect(addrVal)) + if err != nil { + return err + } + } + + vSlice = reflect.Indirect(addrVal) + + valMap.SetMapIndex(reflect.ValueOf(keyName), vSlice) + + break + } + + // When deep mapping is not needed, fallthrough to normal copy + fallthrough + default: valMap.SetMapIndex(reflect.ValueOf(keyName), v) } @@ -1608,13 +1656,24 @@ func isStructTypeConvertibleToMap(typ reflect.Type, checkMapstructureTags bool, } func dereferencePtrToStructIfNeeded(v reflect.Value, tagName string) reflect.Value { - if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct { + + if v.Kind() != reflect.Ptr { return v } - deref := v.Elem() - derefT := deref.Type() - if isStructTypeConvertibleToMap(derefT, true, tagName) { - return deref + + switch v.Elem().Kind() { + case reflect.Slice: + return v.Elem() + + case reflect.Struct: + deref := v.Elem() + derefT := deref.Type() + if isStructTypeConvertibleToMap(derefT, true, tagName) { + return deref + } + return v + + default: + return v } - return v } diff --git a/mapstructure_test.go b/mapstructure_test.go index 519e722..e210929 100644 --- a/mapstructure_test.go +++ b/mapstructure_test.go @@ -3368,6 +3368,49 @@ func testArrayInput(t *testing.T, input map[string]interface{}, expected *Array) } } +func TestDecode_structArrayDeepMap(t *testing.T) { + type SourceChild struct { + String string `mapstructure:"some-string"` + } + + type SourceParent struct { + ChildrenA []SourceChild `mapstructure:"children-a,deep"` + ChildrenB *[]SourceChild `mapstructure:"children-b,deep"` + } + + var target map[string]interface{} + + source := SourceParent{ + ChildrenA: []SourceChild{ + {String: "one"}, + {String: "two"}, + }, + ChildrenB: &[]SourceChild{ + {String: "one"}, + {String: "two"}, + }, + } + + if err := Decode(source, &target); err != nil { + t.Fatalf("got error: %s", err) + } + + expected := map[string]interface{}{ + "children-a": []map[string]interface{}{ + {"some-string": "one"}, + {"some-string": "two"}, + }, + "children-b": []map[string]interface{}{ + {"some-string": "one"}, + {"some-string": "two"}, + }, + } + + if !reflect.DeepEqual(target, expected) { + t.Fatalf("failed: \nexpected: %#v\nresult: %#v", expected, target) + } +} + func stringPtr(v string) *string { return &v } func intPtr(v int) *int { return &v } func uintPtr(v uint) *uint { return &v }