Skip to content

Commit

Permalink
fix(dispatch): When registering, compare methods to impl
Browse files Browse the repository at this point in the history
Registering MethodSets compares the MethodSet against the implementation to
make sure they match each other. Add tests for dispatch to ensure common
failure cases actually work.
  • Loading branch information
dustmop committed Mar 5, 2021
1 parent f171fd8 commit c004812
Show file tree
Hide file tree
Showing 3 changed files with 261 additions and 30 deletions.
99 changes: 83 additions & 16 deletions lib/dispatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,60 +151,110 @@ type callable struct {
// RegisterMethods iterates the methods provided by the lib API, and makes them visible to dispatch
func (inst *Instance) RegisterMethods() {
reg := make(map[string]callable)
// TODO(dustmop): Change registerOne to take both the MethodSet and the Impl, validate
// that their signatures agree.
inst.registerOne("fsi", &FSIImpl{}, reg)
inst.registerOne("access", accessImpl{}, reg)
inst.registerOne("fsi", inst.Filesys(), fsiImpl{}, reg)
inst.registerOne("access", inst.Access(), accessImpl{}, reg)
inst.regMethods = &regMethodSet{reg: reg}
}

func (inst *Instance) registerOne(ourName string, impl interface{}, reg map[string]callable) {
func (inst *Instance) registerOne(ourName string, methods MethodSet, impl interface{}, reg map[string]callable) {
implType := reflect.TypeOf(impl)
msetType := reflect.TypeOf(methods)
methodMap := inst.buildMethodMap(methods)
// Iterate methods on the implementation, register those that have the right signature
num := implType.NumMethod()
for k := 0; k < num; k++ {
m := implType.Method(k)
lowerName := strings.ToLower(m.Name)
i := implType.Method(k)
lowerName := strings.ToLower(i.Name)
funcName := fmt.Sprintf("%s.%s", ourName, lowerName)

// Validate the parameters to the method
// Validate the parameters to the implementation
// should have 3 input parameters: (receiver, scope, input struct)
// should have 2 output parametres: (output value, error)
// TODO(dustmop): allow variadic returns: error only, cursor for pagination
f := m.Type
f := i.Type
if f.NumIn() != 3 {
log.Fatalf("%s: bad number of inputs: %d", funcName, f.NumIn())
panic(fmt.Sprintf("%s: bad number of inputs: %d", funcName, f.NumIn()))
}
if f.NumOut() != 2 {
log.Fatalf("%s: bad number of outputs: %d", funcName, f.NumOut())
panic(fmt.Sprintf("%s: bad number of outputs: %d", funcName, f.NumOut()))
}
// First input must be the receiver
inType := f.In(0)
if inType != implType {
log.Fatalf("%s: first input param should be impl, got %v", funcName, inType)
panic(fmt.Sprintf("%s: first input param should be impl, got %v", funcName, inType))
}
// Second input must be a scope
inType = f.In(1)
if inType.Name() != "scope" {
log.Fatalf("%s: second input param should be scope, got %v", funcName, inType)
panic(fmt.Sprintf("%s: second input param should be scope, got %v", funcName, inType))
}
// Third input is a pointer to the input struct
inType = f.In(2)
if inType.Kind() != reflect.Ptr {
log.Fatalf("%s: third input param must be a struct pointer, got %v", funcName, inType)
panic(fmt.Sprintf("%s: third input param must be a struct pointer, got %v", funcName, inType))
}
inType = inType.Elem()
if inType.Kind() != reflect.Struct {
log.Fatalf("%s: third input param must be a struct pointer, got %v", funcName, inType)
panic(fmt.Sprintf("%s: third input param must be a struct pointer, got %v", funcName, inType))
}
// First output is anything
outType := f.Out(0)
// Second output must be an error
outErrType := f.Out(1)
if outErrType.Name() != "error" {
log.Fatalf("%s: second output param should be error, got %v", funcName, outErrType)
panic(fmt.Sprintf("%s: second output param should be error, got %v", funcName, outErrType))
}

// Validate the parameters to the method that matches the implementation
// should have 3 input parameters: (receiver, context.Context, input struct [same as impl])
// should have 2 output parametres: (output value [same as impl], error)
m, ok := methodMap[i.Name]
if !ok {
panic(fmt.Sprintf("method %s not found on MethodSet", i.Name))
}
f = m.Type
if f.NumIn() != 3 {
panic(fmt.Sprintf("%s: bad number of inputs: %d", funcName, f.NumIn()))
}
msetNumMethods := f.NumOut()
if msetNumMethods < 1 && msetNumMethods > 2 {
panic(fmt.Sprintf("%s: bad number of outputs: %d", funcName, f.NumOut()))
}
// First input must be the receiver
mType := f.In(0)
if mType.Name() != msetType.Name() {
panic(fmt.Sprintf("%s: first input param should be impl, got %v", funcName, mType))
}
// Second input must be a context
mType = f.In(1)
if mType.Name() != "Context" {
panic(fmt.Sprintf("%s: second input param should be context.Context, got %v", funcName, mType))
}
// Third input is a pointer to the input struct
mType = f.In(2)
if mType.Kind() != reflect.Ptr {
panic(fmt.Sprintf("%s: third input param must be a pointer, got %v", funcName, mType))
}
mType = mType.Elem()
if mType != inType {
panic(fmt.Sprintf("%s: third input param must match impl, expect %v, got %v", funcName, inType, mType))
}
// First output, if there's more than 1, matches the impl output
if msetNumMethods == 2 {
mType = f.Out(0)
if mType != outType {
panic(fmt.Sprintf("%s: first output param must match impl, expect %v, got %v", funcName, outType, mType))
}
}
// Last output must be an error
mType = f.Out(msetNumMethods - 1)
if mType.Name() != "error" {
panic(fmt.Sprintf("%s: last output param should be error, got %v", funcName, mType))
}

// Remove this method from the methodSetMap now that it has been processed
delete(methodMap, i.Name)

// Save the method to the registration table
reg[funcName] = callable{
Impl: impl,
Expand All @@ -214,6 +264,23 @@ func (inst *Instance) registerOne(ourName string, impl interface{}, reg map[stri
}
log.Debugf("%d: registered %s(*%s) %v", k, funcName, inType, outType)
}

for k := range methodMap {
if k != "Name" {
panic(fmt.Sprintf("%s: did not find implementation for method %s", msetType, k))
}
}
}

func (inst *Instance) buildMethodMap(impl interface{}) map[string]reflect.Method {
result := make(map[string]reflect.Method)
implType := reflect.TypeOf(impl)
num := implType.NumMethod()
for k := 0; k < num; k++ {
m := implType.Method(k)
result[m.Name] = m
}
return result
}

// MethodSet represents a set of methods to be registered
Expand Down
166 changes: 166 additions & 0 deletions lib/dispatch_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
package lib

import (
"context"
"fmt"
"testing"
)

func TestRegisterMethods(t *testing.T) {
ctx := context.Background()

inst, cleanup := NewMemTestInstance(ctx, t)
defer cleanup()
m := &animalMethods{d: inst}

reg := make(map[string]callable)
inst.registerOne("animal", m, animalImpl{}, reg)

expectToPanic(t, func() {
reg = make(map[string]callable)
inst.registerOne("animal", m, badAnimalOneImpl{}, reg)
}, "animal.cat: bad number of outputs: 1")

expectToPanic(t, func() {
reg = make(map[string]callable)
inst.registerOne("animal", m, badAnimalTwoImpl{}, reg)
}, "method Doggie not found on MethodSet")

expectToPanic(t, func() {
reg = make(map[string]callable)
inst.registerOne("animal", m, badAnimalThreeImpl{}, reg)
}, "animal.cat: second input param should be scope, got context.Context")

expectToPanic(t, func() {
reg = make(map[string]callable)
inst.registerOne("animal", m, badAnimalFourImpl{}, reg)
}, "animal.dog: third input param must be a struct pointer, got string")

expectToPanic(t, func() {
reg = make(map[string]callable)
inst.registerOne("animal", m, badAnimalFiveImpl{}, reg)
}, "*lib.animalMethods: did not find implementation for method Dog")
}

func expectToPanic(t *testing.T, regFunc func(), expectMessage string) {
t.Helper()

doneCh := make(chan error)
panicMessage := ""

go func() {
defer func() {
if r := recover(); r != nil {
panicMessage = fmt.Sprintf("%s", r)
}
doneCh <- nil
}()
regFunc()
}()
// Block until the goroutine is done
_ = <- doneCh

if panicMessage == "" {
t.Errorf("expected a panic, did not get one")
} else if panicMessage != expectMessage {
t.Errorf("error mismatch, expect: %q, got: %q", expectMessage, panicMessage)
}
}

// Test data: methodSet and implementation

type animalMethods struct {
d dispatcher
}

func (m *animalMethods) Name() string {
return "animal"
}

type animalParams struct {
Name string
}

func (m *animalMethods) Cat(ctx context.Context, p *animalParams) (string, error) {
got, err := m.d.Dispatch(ctx, dispatchMethodName(m, "cat"), p)
if res, ok := got.(string); ok {
return res, err
}
return "", dispatchReturnError(got, err)
}

func (m *animalMethods) Dog(ctx context.Context, p *animalParams) (string, error) {
got, err := m.d.Dispatch(ctx, dispatchMethodName(m, "dog"), p)
if res, ok := got.(string); ok {
return res, err
}
return "", dispatchReturnError(got, err)
}

// Good implementation

type animalImpl struct {}

func (animalImpl) Cat(scp scope, p *animalParams) (string, error) {
return fmt.Sprintf("%s says meow", p.Name), nil
}

func (animalImpl) Dog(scp scope, p *animalParams) (string, error) {
return fmt.Sprintf("%s says bark", p.Name), nil
}

// Bad implementation #1 (cat doesn't return an error)

type badAnimalOneImpl struct {}

func (badAnimalOneImpl) Cat(scp scope, p *animalParams) string {
return fmt.Sprintf("%s says meow", p.Name)
}

func (badAnimalOneImpl) Dog(scp scope, p *animalParams) (string, error) {
return fmt.Sprintf("%s says bark", p.Name), nil
}

// Bad implementation #2 (dog method name doesn't match)

type badAnimalTwoImpl struct {}

func (badAnimalTwoImpl) Cat(scp scope, p *animalParams) (string, error) {
return fmt.Sprintf("%s says meow", p.Name), nil
}

func (badAnimalTwoImpl) Doggie(scp scope, p *animalParams) (string, error) {
return fmt.Sprintf("%s says bark", p.Name), nil
}

// Bad implementation #3 (cat doesn't accept a scope)

type badAnimalThreeImpl struct {}

func (badAnimalThreeImpl) Cat(ctx context.Context, p *animalParams) (string, error) {
return fmt.Sprintf("%s says meow", p.Name), nil
}

func (badAnimalThreeImpl) Dog(scp scope, p *animalParams) (string, error) {
return fmt.Sprintf("%s says bark", p.Name), nil
}

// Bad implementation #4 (dog input struct doesn't match)

type badAnimalFourImpl struct {}

func (badAnimalFourImpl) Cat(scp scope, p *animalParams) (string, error) {
return fmt.Sprintf("%s says meow", p.Name), nil
}

func (badAnimalFourImpl) Dog(scp scope, name string) (string, error) {
return fmt.Sprintf("%s says bark", name), nil
}

// Bad implementation #5 (dog method is missing)

type badAnimalFiveImpl struct {}

func (badAnimalFiveImpl) Cat(scp scope, p *animalParams) (string, error) {
return fmt.Sprintf("%s says meow", p.Name), nil
}
Loading

0 comments on commit c004812

Please sign in to comment.