diff --git a/go.mod b/go.mod index 6e04a968..3be917f2 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,6 @@ module github.com/golang/mock -require golang.org/x/tools v0.0.0-20190425150028-36563e24a262 +require ( + github.com/pkg/errors v0.8.1 + golang.org/x/tools v0.0.0-20190425150028-36563e24a262 +) diff --git a/go.sum b/go.sum index 9009852e..5c788287 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/gomock/call.go b/gomock/call.go index d3d195c5..96ac9eca 100644 --- a/gomock/call.go +++ b/gomock/call.go @@ -17,6 +17,7 @@ package gomock import ( "fmt" "reflect" + "runtime" "strconv" "strings" ) @@ -149,6 +150,27 @@ func (c *Call) Do(f interface{}) *Call { vargs[i] = reflect.Zero(ft.In(i)) } } + defer func() { + if r := recover(); r != nil { + errMsg, ok := r.(string) + + // We only handle a very specific panic + // If it's not that one, then we "rethrow" the panic + // This allows users to use functions that panic in their tests + if !ok { + panic(r) + } + if !strings.Contains(errMsg, "reflect: Call using") && + !strings.Contains(errMsg, "reflect.Set: value of") { + panic(r) + } + skipFrames := 8 + stackTraceStr := "\n\n" + currentStackTrace(skipFrames) + funcPC := v.Pointer() + file, line := runtime.FuncForPC(funcPC).FileLine(funcPC) + c.t.Fatalf("%s (incorrect func args at %s:%d?)%+v", errMsg, file, line, stackTraceStr) + } + }() v.Call(vargs) return nil }) @@ -239,6 +261,25 @@ func (c *Call) SetArg(n int, value interface{}) *Call { case reflect.Slice: setSlice(args[n], v) default: + defer func() { + if r := recover(); r != nil { + errMsg, ok := r.(string) + + // We only handle a very specific panic + // If it's not that one, then we "rethrow" the panic + // This allows users to use functions that panic in their tests + if !ok { + panic(r) + } + if !strings.Contains(errMsg, "reflect: Call using") && + !strings.Contains(errMsg, "reflect.Set: value of") { + panic(r) + } + skipFrames := 8 + stackTraceStr := "\n\n" + currentStackTrace(skipFrames) + c.t.Fatalf("%s%+v", errMsg, stackTraceStr) + } + }() reflect.ValueOf(args[n]).Elem().Set(v) } return nil @@ -382,7 +423,7 @@ func (c *Call) matches(args []interface{}) error { // Check that the call is not exhausted. if c.exhausted() { - return fmt.Errorf("Expected call at %s has already been called the max number of times.", c.origin) + return fmt.Errorf("Expected call at %s has already been called the max number of times (%d).", c.origin, c.maxCalls) } return nil diff --git a/gomock/stack.go b/gomock/stack.go new file mode 100644 index 00000000..24ce9c30 --- /dev/null +++ b/gomock/stack.go @@ -0,0 +1,43 @@ +package gomock + +import ( + "bytes" + "fmt" + "strings" + + "github.com/pkg/errors" +) + +const skipFrames = 2 + +type stackTracer interface { + StackTrace() errors.StackTrace +} + +func stackTraceStringFromError(err error, skipFrames int) string { + if err, ok := err.(stackTracer); ok { + return stackTraceString(err.StackTrace(), skipFrames) + } + return "" +} + +func currentStackTrace(skipFrames int) string { + err := errors.New("fake error just to get stack") + if err, ok := err.(stackTracer); ok { + return stackTraceString(err.StackTrace(), skipFrames) + } + return "" +} + +func stackTraceString(stackTrace errors.StackTrace, skipFrames int) string { + buffer := bytes.NewBufferString("") + for i := skipFrames + 1; i < len(stackTrace); i++ { + frame := stackTrace[i] + buffer.WriteString(fmt.Sprintf("%+v\n", frame)) + filename := fmt.Sprintf("%s", frame) + if strings.Contains(filename, "_test.go") { + break + } + } + return buffer.String() +}