diff --git a/gomock/call.go b/gomock/call.go index 7345f654..694dfd44 100644 --- a/gomock/call.go +++ b/gomock/call.go @@ -19,6 +19,8 @@ import ( "reflect" "strconv" "strings" + + "github.com/golang/mock/gomock/internal/validate" ) // Call represents an expected call to a mock. @@ -106,9 +108,20 @@ func (c *Call) MaxTimes(n int) *Call { // The return values from this function are returned by the mocked function. // It takes an interface{} argument to support n-arity functions. func (c *Call) DoAndReturn(f interface{}) *Call { - // TODO: Check arity and types here, rather than dying badly elsewhere. v := reflect.ValueOf(f) + switch v.Kind() { + case reflect.Func: + mt := c.methodType + + ft := v.Type() + if err := validate.InputAndOutputSig(ft, mt); err != nil { + panic(fmt.Sprintf("DoAndReturn: %s", err)) + } + default: + panic("DoAndReturn: argument must be a function") + } + c.addAction(func(args []interface{}) []interface{} { vargs := make([]reflect.Value, len(args)) ft := v.Type() @@ -135,9 +148,20 @@ func (c *Call) DoAndReturn(f interface{}) *Call { // return values call DoAndReturn. // It takes an interface{} argument to support n-arity functions. func (c *Call) Do(f interface{}) *Call { - // TODO: Check arity and types here, rather than dying badly elsewhere. v := reflect.ValueOf(f) + switch v.Kind() { + case reflect.Func: + mt := c.methodType + + ft := v.Type() + if err := validate.InputSig(ft, mt); err != nil { + panic(fmt.Sprintf("Do: %s", err)) + } + default: + panic("Do: argument must be a function") + } + c.addAction(func(args []interface{}) []interface{} { vargs := make([]reflect.Value, len(args)) ft := v.Type() diff --git a/gomock/call_test.go b/gomock/call_test.go index 3a8315b3..99d069a9 100644 --- a/gomock/call_test.go +++ b/gomock/call_test.go @@ -1,6 +1,21 @@ +// Copyright 2020 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package gomock import ( + "reflect" "testing" ) @@ -49,3 +64,804 @@ func TestCall_After(t *testing.T) { } }) } + +func TestCall_Do(t *testing.T) { + t.Run("Do function matches Call function", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x int) bool { + if x < 20 { + return false + } + + return true + } + + callFunc := func(x int) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("Do function matches Call function and is a interface{}", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x int) bool { + if x < 20 { + return false + } + + return true + } + + callFunc := func(x interface{}) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("Do function matches Call function and is a map[interface{}]interface{}", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x map[int]string) bool { + return true + } + + callFunc := func(x map[interface{}]interface{}) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("Do function matches Call function and is variadic", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x []int) bool { + return true + } + + callFunc := func(x ...int) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("Do function matches Call function and is variadic interface{}", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x []int) bool { + return true + } + + callFunc := func(x ...interface{}) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("argument to Do is not a function", func(t *testing.T) { + tr := &mockTestReporter{} + + callFunc := func(x int, y int) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected Do to panic") + } + }() + + c.Do("meow") + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("number of args for Do func don't match Call func", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x int) bool { + if x < 20 { + return false + } + + return true + } + + callFunc := func(x int, y int) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected Do to panic") + } + }() + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("arg types for Do func don't match Call func", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x int) bool { + if x < 20 { + return false + } + + return true + } + + callFunc := func(x string) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected Do to panic") + } + }() + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("Do function does not match Call function and is a slice", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x []string) bool { + return true + } + + callFunc := func(x []int) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected Do to panic") + } + }() + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("Do function does not match Call function and is a slice interface{}", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x []string) bool { + return true + } + + callFunc := func(x []interface{}) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected Do to panic") + } + }() + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("Do function does not match Call function and is a composite struct", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x b) bool { + return true + } + + callFunc := func(x a) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected Do to panic") + } + }() + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("Do function does not match Call function and is a map", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x map[int]string) bool { + return true + } + + callFunc := func(x map[interface{}]int) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected Do to panic") + } + }() + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("number of return vals for Do func don't match Call func", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x int) bool { + if x < 20 { + return false + } + + return true + } + + callFunc := func(x int) (bool, error) { + return false, nil + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("return types for Do func don't match Call func", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x int) bool { + if x < 20 { + return false + } + + return true + } + + callFunc := func(x int) error { + return nil + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + c.Do(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) +} + +func TestCall_DoAndReturn(t *testing.T) { + t.Run("DoAndReturn function matches Call function", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x int) bool { + if x < 20 { + return false + } + + return true + } + + callFunc := func(x int) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + c.DoAndReturn(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("DoAndReturn function matches Call function and is a interface{}", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x int) bool { + if x < 20 { + return false + } + + return true + } + + callFunc := func(x interface{}) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + c.DoAndReturn(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("DoAndReturn function matches Call function and is a map[interface{}]interface{}", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x map[int]string) bool { + return true + } + + callFunc := func(x map[interface{}]interface{}) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + c.DoAndReturn(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("DoAndReturn function matches Call function and is variadic", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x []int) bool { + return true + } + + callFunc := func(x ...int) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + c.DoAndReturn(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("DoAndReturn function matches Call function and is variadic interface{}", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x []int) bool { + return true + } + + callFunc := func(x ...interface{}) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + c.DoAndReturn(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("argument to DoAndReturn is not a function", func(t *testing.T) { + tr := &mockTestReporter{} + + callFunc := func(x int, y int) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected DoAndReturn to panic") + } + }() + + c.DoAndReturn("meow") + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("number of args for DoAndReturn func don't match Call func", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x int) bool { + if x < 20 { + return false + } + + return true + } + + callFunc := func(x int, y int) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected DoAndReturn to panic") + } + }() + + c.DoAndReturn(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("arg types for DoAndReturn func don't match Call func", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x int) bool { + if x < 20 { + return false + } + + return true + } + + callFunc := func(x string) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected DoAndReturn to panic") + } + }() + + c.DoAndReturn(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("DoAndReturn function does not match Call function and is a slice", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x []string) bool { + return true + } + + callFunc := func(x []int) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected DoAndReturn to panic") + } + }() + + c.DoAndReturn(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("DoAndReturn function does not match Call function and is a slice interface{}", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x []string) bool { + return true + } + + callFunc := func(x []interface{}) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected DoAndReturn to panic") + } + }() + + c.DoAndReturn(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("DoAndReturn function does not match Call function and is a composite struct", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x b) bool { + return true + } + + callFunc := func(x a) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected DoAndReturn to panic") + } + }() + + c.DoAndReturn(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("DoAndReturn function does not match Call function and is a map", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x map[int]string) bool { + return true + } + + callFunc := func(x map[interface{}]int) bool { + return false + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected DoAndReturn to panic") + } + }() + + c.DoAndReturn(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("number of return vals for DoAndReturn func don't match Call func", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x int) bool { + if x < 20 { + return false + } + + return true + } + + callFunc := func(x int) (bool, error) { + return false, nil + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected DoAndReturn to panic") + } + }() + + c.DoAndReturn(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) + + t.Run("return types for DoAndReturn func don't match Call func", func(t *testing.T) { + tr := &mockTestReporter{} + + doFunc := func(x int) bool { + if x < 20 { + return false + } + + return true + } + + callFunc := func(x int) error { + return nil + } + + c := &Call{ + t: tr, + methodType: reflect.TypeOf(callFunc), + } + + defer func() { + if r := recover(); r == nil { + t.Error("expected DoAndReturn to panic") + } + }() + + c.DoAndReturn(doFunc) + + if len(c.actions) != 1 { + t.Errorf("expected %d actions but got %d", 1, len(c.actions)) + } + }) +} + +type a struct { + name string +} + +func (testObj a) Name() string { + return testObj.name +} + +type b struct { + a + foo string +} + +func (testObj b) Foo() string { + return testObj.foo +} diff --git a/gomock/controller_test.go b/gomock/controller_test.go index c22908b8..1f6b09d3 100644 --- a/gomock/controller_test.go +++ b/gomock/controller_test.go @@ -492,9 +492,11 @@ func TestDo(t *testing.T) { doCalled := false var argument string ctrl.RecordCall(subject, "FooMethod", "argument").Do( - func(arg string) { + func(arg string) int { doCalled = true argument = arg + + return 0 }) if doCalled { t.Error("Do() callback called too early.") diff --git a/gomock/internal/validate/validate.go b/gomock/internal/validate/validate.go new file mode 100644 index 00000000..5ef253ac --- /dev/null +++ b/gomock/internal/validate/validate.go @@ -0,0 +1,199 @@ +// Copyright 2020 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validate + +import ( + "fmt" + "reflect" +) + +// InputAndOutputSig compares the argument and return signatures of actualFunc +// against expectedFunc. It returns an error unless everything matches. +func InputAndOutputSig(actualFunc, expectedFunc reflect.Type) error { + if err := InputSig(actualFunc, expectedFunc); err != nil { + return err + } + + if err := outputSig(actualFunc, expectedFunc); err != nil { + return err + } + + return nil +} + +// InputSig compares the argument signatures of actualFunc +// against expectedFunc. It returns an error unless everything matches. +func InputSig(actualFunc, expectedFunc reflect.Type) error { + // check number of arguments and type of each argument + if actualFunc.NumIn() != expectedFunc.NumIn() { + return fmt.Errorf( + "expected function to have %d arguments not %d", + expectedFunc.NumIn(), actualFunc.NumIn()) + } + + lastIdx := expectedFunc.NumIn() + + // If the function has a variadic argument validate that one first so that + // we aren't checking for it while we iterate over the other args + if expectedFunc.IsVariadic() { + if ok := variadicArg(lastIdx, actualFunc, expectedFunc); !ok { + i := lastIdx - 1 + return fmt.Errorf( + "expected function to have"+ + " arg of type %v at position %d"+ + " not type %v", + expectedFunc.In(i), i, actualFunc.In(i), + ) + } + + lastIdx-- + } + + for i := 0; i < lastIdx; i++ { + expectedArg := expectedFunc.In(i) + actualArg := actualFunc.In(i) + + if err := arg(actualArg, expectedArg); err != nil { + return fmt.Errorf("input argument at %d: %s", i, err) + } + } + + return nil +} + +func outputSig(actualFunc, expectedFunc reflect.Type) error { + // check number of return vals and type of each val + if actualFunc.NumOut() != expectedFunc.NumOut() { + return fmt.Errorf( + "expected function to have %d return vals not %d", + expectedFunc.NumOut(), actualFunc.NumOut()) + } + + for i := 0; i < expectedFunc.NumOut(); i++ { + expectedArg := expectedFunc.Out(i) + actualArg := actualFunc.Out(i) + + if err := arg(actualArg, expectedArg); err != nil { + return fmt.Errorf("return argument at %d: %s", i, err) + } + } + + return nil +} + +func variadicArg(lastIdx int, actualFunc, expectedFunc reflect.Type) bool { + if actualFunc.In(lastIdx-1) != expectedFunc.In(lastIdx-1) { + if actualFunc.In(lastIdx-1).Kind() != reflect.Slice { + return false + } + + expectedArgT := expectedFunc.In(lastIdx - 1) + expectedElem := expectedArgT.Elem() + if expectedElem.Kind() != reflect.Interface { + return false + } + + actualArgT := actualFunc.In(lastIdx - 1) + actualElem := actualArgT.Elem() + + if ok := actualElem.ConvertibleTo(expectedElem); !ok { + return false + } + + } + + return true +} + +func interfaceArg(actualArg, expectedArg reflect.Type) error { + if !actualArg.ConvertibleTo(expectedArg) { + return fmt.Errorf( + "expected arg convertible to type %v not type %v", + expectedArg, actualArg, + ) + } + + return nil +} + +func mapArg(actualArg, expectedArg reflect.Type) error { + expectedKey := expectedArg.Key() + actualKey := actualArg.Key() + + switch expectedKey.Kind() { + case reflect.Interface: + if err := interfaceArg(actualKey, expectedKey); err != nil { + return fmt.Errorf("map key: %s", err) + } + default: + if actualKey != expectedKey { + return fmt.Errorf("expected map key of type %v not type %v", + expectedKey, actualKey) + } + } + + expectedElem := expectedArg.Elem() + actualElem := actualArg.Elem() + + switch expectedElem.Kind() { + case reflect.Interface: + if err := interfaceArg(actualElem, expectedElem); err != nil { + return fmt.Errorf("map element: %s", err) + } + default: + if actualElem != expectedElem { + return fmt.Errorf("expected map element of type %v not type %v", + expectedElem, actualElem) + } + } + + return nil +} + +func arg(actualArg, expectedArg reflect.Type) error { + switch expectedArg.Kind() { + // If the expected arg is an interface we only care if the actual arg is convertible + // to that interface + case reflect.Interface: + if err := interfaceArg(actualArg, expectedArg); err != nil { + return err + } + default: + // If the expected arg is not an interface then first check to see if + // the actual arg is even the same reflect.Kind + if expectedArg.Kind() != actualArg.Kind() { + return fmt.Errorf("expected arg of kind %v not %v", + expectedArg.Kind(), actualArg.Kind()) + } + + switch expectedArg.Kind() { + // If the expected arg is a map then we need to handle the case where + // the map key or element type is an interface + case reflect.Map: + if err := mapArg(actualArg, expectedArg); err != nil { + return err + } + default: + if actualArg != expectedArg { + return fmt.Errorf( + "Expected arg of type %v not type %v", + expectedArg, actualArg, + ) + } + } + } + + return nil +}