Skip to content

Commit

Permalink
fix: Only consider pointer to structs when checking for embedded fields
Browse files Browse the repository at this point in the history
  • Loading branch information
erezrokah committed Oct 27, 2022
1 parent 5573bc2 commit 64cf625
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 13 deletions.
8 changes: 7 additions & 1 deletion funk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,14 @@ type EmbeddedStruct struct {
EmbeddedField *string
}

type RootStruct struct {
type RootStructPointer struct {
*EmbeddedStruct

RootField *string
}

type RootStructNotPointer struct {
EmbeddedStruct

RootField *string
}
13 changes: 3 additions & 10 deletions retrieve.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,11 @@ func isNilIndirection(v reflect.Value, name string) bool {
vType := v.Type()
for i := 0; i < vType.NumField(); i++ {
field := vType.Field(i)
if !isEmbeddedStructField(field) {
if !isEmbeddedStructPointerField(field) {
return false
}

fieldType := field.Type
if fieldType.Kind() == reflect.Ptr {
fieldType = field.Type.Elem()
}
fieldType := field.Type.Elem()

_, found := fieldType.FieldByName(name)
if found {
Expand All @@ -128,14 +125,10 @@ func isNilIndirection(v reflect.Value, name string) bool {
return false
}

func isEmbeddedStructField(field reflect.StructField) bool {
func isEmbeddedStructPointerField(field reflect.StructField) bool {
if !field.Anonymous {
return false
}

if field.Type.Kind() == reflect.Struct {
return true
}

return field.Type.Kind() == reflect.Ptr && field.Type.Elem().Kind() == reflect.Struct
}
11 changes: 9 additions & 2 deletions retrieve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,17 @@ func TestGetOrElse(t *testing.T) {
})
}

func TestEmbeddedStruct(t *testing.T) {
func TestEmbeddedStructPointer(t *testing.T) {
is := assert.New(t)

root := RootStruct{}
root := RootStructPointer{}
is.Equal(Get(root, "EmbeddedField"), nil)
is.Equal(Get(root, "EmbeddedStruct.EmbeddedField"), nil)
}

func TestEmbeddedStructNotPointer(t *testing.T) {
is := assert.New(t)

root := RootStructNotPointer{}
is.Equal(Get(root, "EmbeddedField"), nil)
}

0 comments on commit 64cf625

Please sign in to comment.