From c97971d2ae39a0b7498b19fdd652c01af391b13c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miroslav=20=C5=A0ediv=C3=BD?= Date: Mon, 12 Aug 2024 13:58:11 +0200 Subject: [PATCH] Adding support for squash: interface. (#17) * Adding support for squash: interface. * fix lint issues. --- mapstructure.go | 11 ++- mapstructure_test.go | 195 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 203 insertions(+), 3 deletions(-) diff --git a/mapstructure.go b/mapstructure.go index d627088..1cd6204 100644 --- a/mapstructure.go +++ b/mapstructure.go @@ -1375,10 +1375,15 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e } if squash { - if fieldVal.Kind() != reflect.Struct { - errs = append(errs, fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldVal.Kind())) - } else { + switch fieldVal.Kind() { + case reflect.Struct: structs = append(structs, fieldVal) + case reflect.Interface: + if !fieldVal.IsNil() { + structs = append(structs, fieldVal.Elem().Elem()) + } + default: + errs = append(errs, fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldVal.Kind())) } continue } diff --git a/mapstructure_test.go b/mapstructure_test.go index a7c3d9e..4d246cc 100644 --- a/mapstructure_test.go +++ b/mapstructure_test.go @@ -113,6 +113,60 @@ type SquashOnNonStructType struct { InvalidSquashType int `mapstructure:",squash"` } +type TestInterface interface { + GetVfoo() string + GetVbarfoo() string + GetVfoobar() string +} + +type TestInterfaceImpl struct { + Vfoo string +} + +func (t *TestInterfaceImpl) GetVfoo() string { + return t.Vfoo +} + +func (t *TestInterfaceImpl) GetVbarfoo() string { + return "" +} + +func (t *TestInterfaceImpl) GetVfoobar() string { + return "" +} + +type TestNestedInterfaceImpl struct { + SquashOnNestedInterfaceType `mapstructure:",squash"` + Vfoo string +} + +func (t *TestNestedInterfaceImpl) GetVfoo() string { + return t.Vfoo +} + +func (t *TestNestedInterfaceImpl) GetVbarfoo() string { + return t.Vbarfoo +} + +func (t *TestNestedInterfaceImpl) GetVfoobar() string { + return t.NestedSquash.Vfoobar +} + +type SquashOnInterfaceType struct { + TestInterface `mapstructure:",squash"` + Vbar string +} + +type NestedSquash struct { + SquashOnInterfaceType `mapstructure:",squash"` + Vfoobar string +} + +type SquashOnNestedInterfaceType struct { + NestedSquash NestedSquash `mapstructure:",squash"` + Vbarfoo string +} + type Map struct { Vfoo string Vother map[string]string @@ -1051,6 +1105,147 @@ func TestDecode_SquashOnNonStructType(t *testing.T) { } } +func TestDecode_SquashOnInterfaceType(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "VFoo": "42", + "VBar": "43", + } + + result := SquashOnInterfaceType{ + TestInterface: &TestInterfaceImpl{}, + } + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an err: %s", err) + } + + res := result.GetVfoo() + if res != "42" { + t.Errorf("unexpected value for VFoo: %s", res) + } + + res = result.Vbar + if res != "43" { + t.Errorf("unexpected value for Vbar: %s", res) + } +} + +func TestDecode_SquashOnOuterNestedInterfaceType(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "VFoo": "42", + "VBar": "43", + "Vfoobar": "44", + "Vbarfoo": "45", + } + + result := SquashOnNestedInterfaceType{ + NestedSquash: NestedSquash{ + SquashOnInterfaceType: SquashOnInterfaceType{ + TestInterface: &TestInterfaceImpl{}, + }, + }, + } + + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an err: %s", err) + } + + res := result.NestedSquash.GetVfoo() + if res != "42" { + t.Errorf("unexpected value for VFoo: %s", res) + } + + res = result.NestedSquash.Vbar + if res != "43" { + t.Errorf("unexpected value for Vbar: %s", res) + } + + res = result.NestedSquash.Vfoobar + if res != "44" { + t.Errorf("unexpected value for Vfoobar: %s", res) + } + + res = result.Vbarfoo + if res != "45" { + t.Errorf("unexpected value for Vbarfoo: %s", res) + } +} + +func TestDecode_SquashOnInnerNestedInterfaceType(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "VFoo": "42", + "VBar": "43", + "Vfoobar": "44", + "Vbarfoo": "45", + } + + result := SquashOnInterfaceType{ + TestInterface: &TestNestedInterfaceImpl{ + SquashOnNestedInterfaceType: SquashOnNestedInterfaceType{ + NestedSquash: NestedSquash{ + SquashOnInterfaceType: SquashOnInterfaceType{ + TestInterface: &TestInterfaceImpl{}, + }, + }, + }, + }, + } + + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an err: %s", err) + } + + res := result.GetVfoo() + if res != "42" { + t.Errorf("unexpected value for VFoo: %s", res) + } + + res = result.Vbar + if res != "43" { + t.Errorf("unexpected value for Vbar: %s", res) + } + + res = result.GetVfoobar() + if res != "44" { + t.Errorf("unexpected value for Vfoobar: %s", res) + } + + res = result.GetVbarfoo() + if res != "45" { + t.Errorf("unexpected value for Vbarfoo: %s", res) + } +} + +func TestDecode_SquashOnNilInterfaceType(t *testing.T) { + t.Parallel() + + input := map[string]interface{}{ + "VFoo": "42", + "VBar": "43", + } + + result := SquashOnInterfaceType{ + TestInterface: nil, + } + err := Decode(input, &result) + if err != nil { + t.Fatalf("got an err: %s", err) + } + + res := result.Vbar + if res != "43" { + t.Errorf("unexpected value for Vbar: %s", res) + } +} + func TestDecode_DecodeHook(t *testing.T) { t.Parallel()