From ee8f1597809c6b0edec0f56f1c526415522f65a3 Mon Sep 17 00:00:00 2001 From: Alec Sammon Date: Wed, 3 Jul 2024 13:55:58 +0100 Subject: [PATCH] Fix pointer marshal --- sheriff.go | 11 +++++++++++ sheriff_test.go | 41 ++++++++++++++++++++++++++++++++++++++--- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/sheriff.go b/sheriff.go index c6d1995..0d4844e 100644 --- a/sheriff.go +++ b/sheriff.go @@ -282,10 +282,21 @@ func marshalValue(options *Options, v reflect.Value) (interface{}, error) { // types which are e.g. structs, slices or maps and implement one of the following interfaces should not be // marshalled by sheriff because they'll be correctly marshalled by json.Marshal instead. // Otherwise (e.g. net.IP) a byte slice may be output as a list of uints instead of as an IP string. + // This needs to be checked for both value and pointer types. switch val.(type) { case json.Marshaler, encoding.TextMarshaler, fmt.Stringer: return val, nil } + + if v.CanAddr() { + addrVal := v.Addr().Interface() + + switch addrVal.(type) { + case json.Marshaler, encoding.TextMarshaler, fmt.Stringer: + return addrVal, nil + } + } + k := v.Kind() switch k { diff --git a/sheriff_test.go b/sheriff_test.go index 72b5a62..68bb4c1 100644 --- a/sheriff_test.go +++ b/sheriff_test.go @@ -2,6 +2,7 @@ package sheriff import ( "encoding/json" + "fmt" "net" "reflect" "testing" @@ -579,14 +580,46 @@ type TestMarshal_Embedded struct { Foo string `json:"foo" groups:"test"` } +// TestMarshal_EmbeddedCustom is used to test an embedded struct with a custom marshaler that is not a pointer. +type TestMarshal_EmbeddedCustom struct { + Val int + Set bool +} + +func (t TestMarshal_EmbeddedCustom) MarshalJSON() ([]byte, error) { + if t.Set { + return []byte(fmt.Sprintf("%d", t.Val)), nil + } + + return nil, nil +} + +// TestMarshal_EmbeddedCustomPtr is used to test an embedded struct with a custom marshaler that is a pointer. +type TestMarshal_EmbeddedCustomPtr struct { + Val int + Set bool +} + +func (t *TestMarshal_EmbeddedCustomPtr) MarshalJSON() ([]byte, error) { + if t.Set { + return []byte(fmt.Sprintf("%d", t.Val)), nil + } + + return nil, nil +} + type TestMarshal_EmbeddedParent struct { *TestMarshal_Embedded - Bar string `json:"bar" groups:"test"` + *TestMarshal_EmbeddedCustom `json:"value"` + *TestMarshal_EmbeddedCustomPtr `json:"value_ptr"` + Bar string `json:"bar" groups:"test"` } func TestMarshal_EmbeddedField(t *testing.T) { v := TestMarshal_EmbeddedParent{ &TestMarshal_Embedded{"Hello"}, + &TestMarshal_EmbeddedCustom{10, true}, + &TestMarshal_EmbeddedCustomPtr{20, true}, "World", } o := &Options{Groups: []string{"test"}} @@ -598,8 +631,10 @@ func TestMarshal_EmbeddedField(t *testing.T) { assert.NoError(t, err) expected, err := json.Marshal(map[string]interface{}{ - "bar": "World", - "foo": "Hello", + "bar": "World", + "foo": "Hello", + "value": 10, + "value_ptr": 20, }) assert.NoError(t, err)