diff --git a/events.go b/events.go index 682e17b..b63c730 100644 --- a/events.go +++ b/events.go @@ -90,8 +90,12 @@ func (em *events) Trigger(event Event) error { func (em *events) callTypedHandler(handler reflect.Value, args []any) error { // Prepare slice of in arguments for handler. handlerInArgs := make([]reflect.Value, 0, handler.Type().NumIn()) - for index, eventArg := range args { - eventArgType := reflect.TypeOf(eventArg) + + // Fill handler args with provided event args. + maxArgsLen := min(len(args), handler.Type().NumIn()) + for index := 0; index < maxArgsLen; index++ { + eventArgType := reflect.TypeOf(args[index]) + eventArgValue := reflect.ValueOf(args[index]) handlerArgType := handler.Type().In(index) if !eventArgType.AssignableTo(handlerArgType) { return fmt.Errorf( @@ -99,7 +103,15 @@ func (em *events) callTypedHandler(handler reflect.Value, args []any) error { HandlerArgTypeMismatchError, eventArgType, handlerArgType, index, ) } - handlerInArgs = append(handlerInArgs, reflect.ValueOf(eventArg)) + handlerInArgs = append(handlerInArgs, eventArgValue) + } + + // Fill handler args with default type values. + if len(handlerInArgs) < handler.Type().NumIn() { + for index := len(handlerInArgs); index < handler.Type().NumIn(); index++ { + zeroValuePtr := reflect.New(handler.Type().In(index)) + handlerInArgs = append(handlerInArgs, zeroValuePtr.Elem()) + } } // Invoke original event handler function. diff --git a/events_test.go b/events_test.go index 6b9d264..00dda66 100644 --- a/events_test.go +++ b/events_test.go @@ -10,6 +10,8 @@ func TestEvents(t *testing.T) { testEvent1Args := [][]any(nil) testEvent2Args := [][]any(nil) testEvent3Args := [][]any(nil) + testEvent4Args := [][]any(nil) + testEvent5Args := [][]any(nil) ev := &events{events: make(map[string][]handler)} ev.Subscribe("TestEvent1", func(args ...any) { @@ -23,14 +25,26 @@ func TestEvents(t *testing.T) { testEvent3Args = append(testEvent3Args, []any{x, y, z}) return nil }) + ev.Subscribe("TestEvent4", func(x string, y int) error { + testEvent4Args = append(testEvent4Args, []any{x, y}) + return nil + }) + ev.Subscribe("TestEvent5", func(x string, y int, z bool) error { + testEvent5Args = append(testEvent5Args, []any{x, y, z}) + return nil + }) equal(t, ev.Trigger(NewEvent("TestEvent1", 1)), nil) equal(t, ev.Trigger(NewEvent("TestEvent1", "x")), nil) equal(t, ev.Trigger(NewEvent("TestEvent2", true)), nil) equal(t, ev.Trigger(NewEvent("TestEvent3", "x", 1, true)), nil) + equal(t, ev.Trigger(NewEvent("TestEvent4", "x", 1, true)), nil) + equal(t, ev.Trigger(NewEvent("TestEvent5", "x", 1)), nil) equal(t, testEvent1Args, [][]any{{1}, {"x"}}) equal(t, testEvent2Args, [][]any{{true}}) equal(t, testEvent3Args, [][]any{{"x", 1, true}}) + equal(t, testEvent4Args, [][]any{{"x", 1}}) + equal(t, testEvent5Args, [][]any{{"x", 1, false}}) } func equal(t *testing.T, a, b any) { diff --git a/go.mod b/go.mod index 5b920be..1ed5f98 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/NVIDIA/gontainer -go 1.20 \ No newline at end of file +go 1.21 \ No newline at end of file