Skip to content

Commit

Permalink
hooks: Recursively search embedded fields for methods (#494)
Browse files Browse the repository at this point in the history
* hooks: Recursively search embedded fields for methods

Follow up to #493 and 840220c

Kong currently supports hooks on embedded fields of a parsed node,
but only at the first level of embedding:

```
type mainCmd struct {
    FooOptions
}

type FooOptions struct {
    BarOptions
}

func (f *FooOptions) BeforeApply() error {
    // this will be called
}

type BarOptions struct {
}

func (b *BarOptions) BeforeApply() error {
    // this will not be called
}
```

This change adds support for hooks to be defined
on embedded fields of embedded fields so that the above
example would work as expected.

Per #493, the definition of "embedded" field is adjusted to mean:

- Any anonymous (Go-embedded) field that is exported
- Any non-anonymous field that is tagged with `embed:""`

*Testing*:
Includes a test case for embedding an anonymous field in an `embed:""`
and an `embed:""` field in an anonymous field.

* Use recursion to build up the list of receivers

The 'receivers' parameter helps avoid constant memory allocation
as the backing storage for the slice is reused across recursive calls.
  • Loading branch information
abhinav authored Jan 30, 2025
1 parent 4e1757c commit 4be6ae6
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 23 deletions.
64 changes: 41 additions & 23 deletions callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,35 +77,53 @@ func getMethod(value reflect.Value, name string) reflect.Value {
return method
}

// Get methods from the given value and any embedded fields.
// getMethods gets all methods with the given name from the given value
// and any embedded fields.
//
// Returns a slice of bound methods that can be called directly.
func getMethods(value reflect.Value, name string) []reflect.Value {
// Collect all possible receivers
receivers := []reflect.Value{value}
if value.Kind() == reflect.Ptr {
value = value.Elem()
}
if value.Kind() == reflect.Struct {
t := value.Type()
for i := 0; i < value.NumField(); i++ {
field := value.Field(i)
fieldType := t.Field(i)
if !fieldType.IsExported() {
continue
}
// Traverses embedded fields of the struct
// starting from the given value to collect all possible receivers
// for the given method name.
var traverse func(value reflect.Value, receivers []reflect.Value) []reflect.Value
traverse = func(value reflect.Value, receivers []reflect.Value) []reflect.Value {
// Always consider the current value for hooks.
receivers = append(receivers, value)

// Hooks on exported embedded fields should be called.
if fieldType.Anonymous {
receivers = append(receivers, field)
continue
}
if value.Kind() == reflect.Ptr {
value = value.Elem()
}

// If the current value is a struct, also consider embedded fields.
// Two kinds of embedded fields are considered if they're exported:
//
// - standard Go embedded fields
// - fields tagged with `embed:""`
if value.Kind() == reflect.Struct {
t := value.Type()
for i := 0; i < value.NumField(); i++ {
fieldValue := value.Field(i)
field := t.Field(i)

// Hooks on exported fields that are not exported,
// but are tagged with `embed:""` should be called.
if _, ok := fieldType.Tag.Lookup("embed"); ok {
receivers = append(receivers, field)
if !field.IsExported() {
continue
}

// Consider a field embedded if it's actually embedded
// or if it's tagged with `embed:""`.
_, isEmbedded := field.Tag.Lookup("embed")
isEmbedded = isEmbedded || field.Anonymous
if isEmbedded {
receivers = traverse(fieldValue, receivers)
}
}
}

return receivers
}

receivers := traverse(value, nil /* receivers */)

// Search all receivers for methods
var methods []reflect.Value
for _, receiver := range receivers {
Expand Down
19 changes: 19 additions & 0 deletions kong_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2405,6 +2405,8 @@ func TestProviderMethods(t *testing.T) {
}

type EmbeddedCallback struct {
Nested NestedCallback `embed:""`

Embedded bool
}

Expand All @@ -2414,6 +2416,8 @@ func (e *EmbeddedCallback) AfterApply() error {
}

type taggedEmbeddedCallback struct {
NestedCallback

Tagged bool
}

Expand All @@ -2422,6 +2426,15 @@ func (e *taggedEmbeddedCallback) AfterApply() error {
return nil
}

type NestedCallback struct {
nested bool
}

func (n *NestedCallback) AfterApply() error {
n.nested = true
return nil
}

type EmbeddedRoot struct {
EmbeddedCallback
Tagged taggedEmbeddedCallback `embed:""`
Expand All @@ -2441,9 +2454,15 @@ func TestEmbeddedCallbacks(t *testing.T) {
expected := &EmbeddedRoot{
EmbeddedCallback: EmbeddedCallback{
Embedded: true,
Nested: NestedCallback{
nested: true,
},
},
Tagged: taggedEmbeddedCallback{
Tagged: true,
NestedCallback: NestedCallback{
nested: true,
},
},
Root: true,
}
Expand Down

0 comments on commit 4be6ae6

Please sign in to comment.