From d0d91814845efbe2e4b55176f171cacfbcc784ad Mon Sep 17 00:00:00 2001 From: Maximilian Pachl Date: Sat, 14 Oct 2017 11:31:44 +0200 Subject: [PATCH] properly handle embedded (anonymous) fields --- sheriff.go | 65 ++++++++++++++++++++++++++++++++----------------- sheriff_test.go | 61 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 22 deletions(-) diff --git a/sheriff.go b/sheriff.go index 95ecfc6..843101a 100644 --- a/sheriff.go +++ b/sheriff.go @@ -80,32 +80,43 @@ func Marshal(options *Options, data interface{}) (interface{}, error) { continue } - if checkGroups { - groups := strings.Split(field.Tag.Get("groups"), ",") - - shouldShow := listContains(groups, options.Groups) - if !shouldShow || len(groups) == 0 { - continue - } + // if there is an anonymous field which is a struct + // we want the childs exposed at the toplevel to be + // consistent with the embedded json marshaller + if val.Kind() == reflect.Ptr { + val = val.Elem() } - if since := field.Tag.Get("since"); since != "" { - sinceVersion, err := version.NewVersion(since) - if err != nil { - return nil, err - } - if options.ApiVersion.LessThan(sinceVersion) { - continue + // we can skip the group checkif if the field is a composition field + isEmbeddedField := field.Anonymous && val.Kind() == reflect.Struct + if !isEmbeddedField { + if checkGroups { + groups := strings.Split(field.Tag.Get("groups"), ",") + + shouldShow := listContains(groups, options.Groups) + if !shouldShow || len(groups) == 0 { + continue + } } - } - if until := field.Tag.Get("until"); until != "" { - untilVersion, err := version.NewVersion(until) - if err != nil { - return nil, err + if since := field.Tag.Get("since"); since != "" { + sinceVersion, err := version.NewVersion(since) + if err != nil { + return nil, err + } + if options.ApiVersion.LessThan(sinceVersion) { + continue + } } - if options.ApiVersion.GreaterThan(untilVersion) { - continue + + if until := field.Tag.Get("until"); until != "" { + untilVersion, err := version.NewVersion(until) + if err != nil { + return nil, err + } + if options.ApiVersion.GreaterThan(untilVersion) { + continue + } } } @@ -113,7 +124,17 @@ func Marshal(options *Options, data interface{}) (interface{}, error) { if err != nil { return nil, err } - dest[jsonTag] = v + + // when a composition field we want to bring the child + // nodes to the top + nestedVal, ok := v.(map[string]interface{}) + if isEmbeddedField && ok { + for key, value := range nestedVal { + dest[key] = value + } + } else { + dest[jsonTag] = v + } } return dest, nil diff --git a/sheriff_test.go b/sheriff_test.go index 7fbda78..3cd3d5c 100644 --- a/sheriff_test.go +++ b/sheriff_test.go @@ -403,3 +403,64 @@ func TestMarshal_EmptyMap(t *testing.T) { assert.Equal(t, string(expected), string(actual)) } + +type TestMarshal_Embedded struct { + Foo string `json:"foo" groups:"test"` +} + +type TestMarshal_EmbeddedParent struct { + *TestMarshal_Embedded + Bar string `json:"bar" groups:"test"` +} + +func TestMarshal_EmbeddedField(t *testing.T) { + v := TestMarshal_EmbeddedParent{ + &TestMarshal_Embedded{"Hello"}, + "World", + } + o := &Options{Groups: []string{"test"}} + + actualMap, err := Marshal(o, v) + assert.NoError(t, err) + + actual, err := json.Marshal(actualMap) + assert.NoError(t, err) + + expected, err := json.Marshal(map[string]interface{}{ + "bar": "World", + "foo": "Hello", + }) + assert.NoError(t, err) + + assert.Equal(t, string(expected), string(actual)) +} + +type TestMarshal_EmbeddedEmpty struct { + Foo string +} + +type TestMarshal_EmbeddedParentEmpty struct { + *TestMarshal_EmbeddedEmpty + Bar string `json:"bar" groups:"test"` +} + +func TestMarshal_EmbeddedFieldEmpty(t *testing.T) { + v := TestMarshal_EmbeddedParentEmpty{ + &TestMarshal_EmbeddedEmpty{"Hello"}, + "World", + } + o := &Options{Groups: []string{"test"}} + + actualMap, err := Marshal(o, v) + assert.NoError(t, err) + + actual, err := json.Marshal(actualMap) + assert.NoError(t, err) + + expected, err := json.Marshal(map[string]interface{}{ + "bar": "World", + }) + assert.NoError(t, err) + + assert.Equal(t, string(expected), string(actual)) +}