diff --git a/callbacks.go b/callbacks.go index b4fe3ca..4644c54 100644 --- a/callbacks.go +++ b/callbacks.go @@ -77,35 +77,53 @@ func getMethod(value reflect.Value, name string) reflect.Value { return method } -// Get methods from the given value and any embedded fields. +// getMethods gets all methods with the given name from the given value +// and any embedded fields. +// +// Returns a slice of bound methods that can be called directly. func getMethods(value reflect.Value, name string) []reflect.Value { - // Collect all possible receivers - receivers := []reflect.Value{value} - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - if value.Kind() == reflect.Struct { - t := value.Type() - for i := 0; i < value.NumField(); i++ { - field := value.Field(i) - fieldType := t.Field(i) - if !fieldType.IsExported() { - continue - } + // Traverses embedded fields of the struct + // starting from the given value to collect all possible receivers + // for the given method name. + var traverse func(value reflect.Value, receivers []reflect.Value) []reflect.Value + traverse = func(value reflect.Value, receivers []reflect.Value) []reflect.Value { + // Always consider the current value for hooks. + receivers = append(receivers, value) - // Hooks on exported embedded fields should be called. - if fieldType.Anonymous { - receivers = append(receivers, field) - continue - } + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + // If the current value is a struct, also consider embedded fields. + // Two kinds of embedded fields are considered if they're exported: + // + // - standard Go embedded fields + // - fields tagged with `embed:""` + if value.Kind() == reflect.Struct { + t := value.Type() + for i := 0; i < value.NumField(); i++ { + fieldValue := value.Field(i) + field := t.Field(i) - // Hooks on exported fields that are not exported, - // but are tagged with `embed:""` should be called. - if _, ok := fieldType.Tag.Lookup("embed"); ok { - receivers = append(receivers, field) + if !field.IsExported() { + continue + } + + // Consider a field embedded if it's actually embedded + // or if it's tagged with `embed:""`. + _, isEmbedded := field.Tag.Lookup("embed") + isEmbedded = isEmbedded || field.Anonymous + if isEmbedded { + receivers = traverse(fieldValue, receivers) + } } } + + return receivers } + + receivers := traverse(value, nil /* receivers */) + // Search all receivers for methods var methods []reflect.Value for _, receiver := range receivers { diff --git a/kong_test.go b/kong_test.go index 2ceb1b1..6b5f5d6 100644 --- a/kong_test.go +++ b/kong_test.go @@ -2405,6 +2405,8 @@ func TestProviderMethods(t *testing.T) { } type EmbeddedCallback struct { + Nested NestedCallback `embed:""` + Embedded bool } @@ -2414,6 +2416,8 @@ func (e *EmbeddedCallback) AfterApply() error { } type taggedEmbeddedCallback struct { + NestedCallback + Tagged bool } @@ -2422,6 +2426,15 @@ func (e *taggedEmbeddedCallback) AfterApply() error { return nil } +type NestedCallback struct { + nested bool +} + +func (n *NestedCallback) AfterApply() error { + n.nested = true + return nil +} + type EmbeddedRoot struct { EmbeddedCallback Tagged taggedEmbeddedCallback `embed:""` @@ -2441,9 +2454,15 @@ func TestEmbeddedCallbacks(t *testing.T) { expected := &EmbeddedRoot{ EmbeddedCallback: EmbeddedCallback{ Embedded: true, + Nested: NestedCallback{ + nested: true, + }, }, Tagged: taggedEmbeddedCallback{ Tagged: true, + NestedCallback: NestedCallback{ + nested: true, + }, }, Root: true, }