Skip to content

Commit

Permalink
Typed signatures for event handlers.
Browse files Browse the repository at this point in the history
Signed-off-by: Pavel Patrin <ppatrin@nvidia.com>
  • Loading branch information
pavelpatrin committed Aug 4, 2024
1 parent db04011 commit 00353ec
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 47 deletions.
23 changes: 22 additions & 1 deletion container.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,24 @@ import (
"sync"
)

// Events declaration.
const (
// ContainerStarting declares container starting event.
ContainerStarting = "ContainerStarting"

// ContainerStarted declares container started event.
ContainerStarted = "ContainerStarted"

// ContainerClosing declares container closing event.
ContainerClosing = "ContainerClosing"

// ContainerClosed declares container closed event.
ContainerClosed = "ContainerClosed"

// UnhandledPanic declares unhandled panic in container.
UnhandledPanic = "UnhandledPanic"
)

// New returns new container instance with a set of configured services.
// The `factories` specifies factories for services with dependency resolution.
func New(factories ...*Factory) (result Container, err error) {
Expand All @@ -43,7 +61,10 @@ func New(factories ...*Factory) (result Container, err error) {
}()

// Prepare events broker instance.
events := events{}
events := &events{
mutex: sync.RWMutex{},
events: make(map[string][]handler),
}

// Prepare services registry instance.
registry := &registry{events: events}
Expand Down
139 changes: 96 additions & 43 deletions events.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,70 +20,120 @@ package gontainer
import (
"errors"
"fmt"
"reflect"
"sync"
)

// Events declaration.
const (
// ContainerStarting declares container starting event.
ContainerStarting = "ContainerStarting"

// ContainerStarted declares container started event.
ContainerStarted = "ContainerStarted"

// ContainerClosing declares container closing event.
ContainerClosing = "ContainerClosing"

// ContainerClosed declares container closed event.
ContainerClosed = "ContainerClosed"

// UnhandledPanic declares unhandled panic in container.
UnhandledPanic = "UnhandledPanic"
)
// HandlerTypeMismatchError declares handler type mismatch error.
var HandlerTypeMismatchError = errors.New("handler type mismatch")

// Events declares event broker type.
type Events interface {
// Subscribe registers event handler.
Subscribe(name string, handler any)
Subscribe(name string, handlerFn any)

// Trigger triggers specified event handlers.
Trigger(event Event) error
}

// events implements Events interface.
type events map[string][]Handler
type events struct {
mutex sync.RWMutex
events map[string][]handler
}

// Subscribe subscribes event handler to the event.
func (e events) Subscribe(name string, handler any) {
var handlerWrapper Handler

// Infer event handler signature.
switch handler := handler.(type) {
case func():
handlerWrapper = func(...any) error { handler(); return nil }
case func(...any):
handlerWrapper = func(args ...any) error { handler(args...); return nil }
case func() error:
handlerWrapper = func(...any) error { return handler() }
case func(...any) error:
handlerWrapper = func(args ...any) error { return handler(args...) }
func (em *events) Subscribe(name string, handlerFn any) {
em.mutex.Lock()
defer em.mutex.Unlock()

// Validate event handler type.
handlerValue := reflect.ValueOf(handlerFn)
handlerType := handlerValue.Type()
if handlerType.Kind() != reflect.Func {
panic(fmt.Sprintf("unexpected event handler type: %T", handlerFn))
}

// Validate event handler output signature.
switch {
case handlerType.NumOut() == 0:
case handlerType.NumOut() == 1 && handlerType.Out(0).Implements(errorType):
default:
panic(fmt.Sprintf("unexpected event handler type: %T", handler))
panic(fmt.Sprintf("unexpected event handler signature: %T", handlerFn))
}

e[name] = append(e[name], handlerWrapper)
// Register event handler function.
if handlerType.NumIn() == 1 && handlerType.In(0) == anySliceType {
em.events[name] = append(em.events[name], func(event Event) error {
return em.callAnyVarHandler(handlerValue, event.Args())
})
} else {
em.events[name] = append(em.events[name], func(event Event) error {
return em.callTypedHandler(handlerValue, event.Args())
})
}
}

// Trigger triggers specified event handlers.
func (e events) Trigger(event Event) error {
errs := make([]error, 0, len(e[event.Name()]))
for _, handler := range e[event.Name()] {
if err := handler(event.Args()...); err != nil {
func (em *events) Trigger(event Event) error {
em.mutex.RLock()
defer em.mutex.RUnlock()

errs := make([]error, 0, len(em.events[event.Name()]))
for _, handler := range em.events[event.Name()] {
if err := handler(event); err != nil {
errs = append(errs, err)
}
}

return errors.Join(errs...)
}

// callTypedHandler calls `func(TypeA, TypeB, TypeC) [error]` event handler.
func (em *events) callTypedHandler(handler reflect.Value, args []any) error {
// Prepare slice of in arguments for handler.
handlerInArgs := make([]reflect.Value, 0, handler.Type().NumIn())
for index, eventArg := range args {
eventArgType := reflect.TypeOf(eventArg)
handlerArgType := handler.Type().In(index)
if !eventArgType.AssignableTo(handlerArgType) {
return fmt.Errorf(
"%w: type '%s' is not assignable to '%s' (index %d)",
HandlerTypeMismatchError, eventArgType, handlerArgType, index,
)
}
handlerInArgs = append(handlerInArgs, reflect.ValueOf(eventArg))
}

// Invoke original event handler function.
handlerOutArgs := handler.Call(handlerInArgs)
return em.getCallOutError(handlerOutArgs)
}

// callAnyVarHandler calls `func(...any) [error]` event handler.
func (em *events) callAnyVarHandler(handler reflect.Value, args []any) error {
// Prepare slice of in arguments for handler.
handlerInArgs := make([]reflect.Value, 0, len(args))
for _, arg := range args {
handlerInArgs = append(handlerInArgs, reflect.ValueOf(arg))
}

// Invoke original event handler function.
handlerOutArgs := handler.Call(handlerInArgs)
return em.getCallOutError(handlerOutArgs)
}

func (em *events) getCallOutError(outArgs []reflect.Value) error {
if len(outArgs) == 1 {
// Use the value as an error.
// Ignore failed cast of nil error.
err, _ := outArgs[0].Interface().(error)
return err
}

return nil
}

// Event declares service container events.
type Event interface {
// Name returns event name.
Expand All @@ -95,11 +145,11 @@ type Event interface {

// NewEvent returns new event instance.
func NewEvent(name string, args ...any) Event {
return event{name: name, args: args}
return &event{name: name, args: args}
}

// Handler declares event handler function.
type Handler func(args ...any) error
// handler declares event handler function.
type handler func(event Event) error

// event wraps string event.
type event struct {
Expand All @@ -108,7 +158,10 @@ type event struct {
}

// Name implements Event interface.
func (e event) Name() string { return e.name }
func (e *event) Name() string { return e.name }

// Args implements Event interface.
func (e event) Args() []any { return e.args }
func (e *event) Args() []any { return e.args }

// anySliceType contains reflection type for any slice variable.
var anySliceType = reflect.TypeOf((*[]any)(nil)).Elem()
9 changes: 8 additions & 1 deletion events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,28 @@ import (
func TestEvents(t *testing.T) {
testEvent1Args := [][]any(nil)
testEvent2Args := [][]any(nil)
testEvent3Args := [][]any(nil)

ev := make(events)
ev := &events{events: make(map[string][]handler)}
ev.Subscribe("TestEvent1", func(args ...any) {
testEvent1Args = append(testEvent1Args, args)
})
ev.Subscribe("TestEvent2", func(args ...any) error {
testEvent2Args = append(testEvent2Args, args)
return nil
})
ev.Subscribe("TestEvent3", func(x string, y int, z bool) error {
testEvent3Args = append(testEvent3Args, []any{x, y, z})
return nil
})

equal(t, ev.Trigger(NewEvent("TestEvent1", 1)), nil)
equal(t, ev.Trigger(NewEvent("TestEvent1", "x")), nil)
equal(t, ev.Trigger(NewEvent("TestEvent2", true)), nil)
equal(t, ev.Trigger(NewEvent("TestEvent3", "x", 1, true)), nil)
equal(t, testEvent1Args, [][]any{{1}, {"x"}})
equal(t, testEvent2Args, [][]any{{true}})
equal(t, testEvent3Args, [][]any{{"x", 1, true}})
}

func equal(t *testing.T, a, b any) {
Expand Down
4 changes: 2 additions & 2 deletions registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestRegistryRegisterFactory(t *testing.T) {
opts := WithMetadata("test", func() {})
factory := NewFactory(fun, opts)

registry := &registry{events: events{}}
registry := &registry{}
equal(t, registry.registerFactory(ctx, factory), nil)
equal(t, registry.factories, []*Factory{factory})
equal(t, factory.factoryFunc == nil, false)
Expand All @@ -32,7 +32,7 @@ func TestRegistryStartFactories(t *testing.T) {
ctx := context.Background()
factory := NewFactory(func() bool { return true })

registry := &registry{events: events{}}
registry := &registry{}
equal(t, registry.registerFactory(ctx, factory), nil)
equal(t, registry.startFactories(), nil)
equal(t, factory.factorySpawned, true)
Expand Down

0 comments on commit 00353ec

Please sign in to comment.