From 96d375c7620199652996bf65c542b0a0b9dca745 Mon Sep 17 00:00:00 2001 From: Dave Wyatt Date: Sat, 11 Feb 2023 14:34:48 -0500 Subject: [PATCH 1/3] Add new failing test --- mock/mock_test.go | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/mock/mock_test.go b/mock/mock_test.go index 260bb9c4f..cae2932af 100644 --- a/mock/mock_test.go +++ b/mock/mock_test.go @@ -113,7 +113,7 @@ func (m *MockTestingT) Errorf(string, ...interface{}) { // the execution stops. // When expecting this method, the call that invokes it should use the following code: // -// assert.PanicsWithValue(t, mockTestingTFailNowCalled, func() {...}) +// assert.PanicsWithValue(t, mockTestingTFailNowCalled, func() {...}) func (m *MockTestingT) FailNow() { m.failNowCount++ @@ -1599,7 +1599,7 @@ func Test_Mock_AssertOptional(t *testing.T) { } /* - Arguments helper methods +Arguments helper methods */ func Test_Arguments_Get(t *testing.T) { @@ -2067,6 +2067,21 @@ func TestConcurrentArgumentRead(t *testing.T) { <-done // wait until Use is called or assertions will fail } +func TestAnythingInSlices(t *testing.T) { + m := &TestExampleImplementation{} + + m.On("TheExampleMethodVariadic", []interface{}{1, Anything, 3, Anything, 5}).Return(nil) + var err error + + assert.NotPanics(t, func() { + err = m.TheExampleMethodVariadic(1, 2, 3, 4, 5) + }) + + assert.NoError(t, err) + m.AssertExpectations(t) + m.AssertCalled(t, "TheExampleMethodVaridic", Anything, 2, Anything, 4, Anything) +} + type caller interface { Call() } From 9ae8abf893773e0d7aecb1476de5cc3ee67cd475 Mon Sep 17 00:00:00 2001 From: Dave Wyatt Date: Sat, 11 Feb 2023 15:16:59 -0500 Subject: [PATCH 2/3] Support Anything and AnythingOfType inside slice arguments --- mock/mock.go | 173 ++++++++++++++++++++++++++++------------------ mock/mock_test.go | 17 ++++- 2 files changed, 122 insertions(+), 68 deletions(-) diff --git a/mock/mock.go b/mock/mock.go index e6ff8dfeb..e115304ed 100644 --- a/mock/mock.go +++ b/mock/mock.go @@ -99,7 +99,7 @@ func (c *Call) unlock() { // Return specifies the return arguments for the expectation. // -// Mock.On("DoSomething").Return(errors.New("failed")) +// Mock.On("DoSomething").Return(errors.New("failed")) func (c *Call) Return(returnArguments ...interface{}) *Call { c.lock() defer c.unlock() @@ -111,7 +111,7 @@ func (c *Call) Return(returnArguments ...interface{}) *Call { // Panic specifies if the functon call should fail and the panic message // -// Mock.On("DoSomething").Panic("test panic") +// Mock.On("DoSomething").Panic("test panic") func (c *Call) Panic(msg string) *Call { c.lock() defer c.unlock() @@ -123,14 +123,14 @@ func (c *Call) Panic(msg string) *Call { // Once indicates that that the mock should only return the value once. // -// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Once() +// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Once() func (c *Call) Once() *Call { return c.Times(1) } // Twice indicates that that the mock should only return the value twice. // -// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Twice() +// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Twice() func (c *Call) Twice() *Call { return c.Times(2) } @@ -138,7 +138,7 @@ func (c *Call) Twice() *Call { // Times indicates that that the mock should only return the indicated number // of times. // -// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Times(5) +// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Times(5) func (c *Call) Times(i int) *Call { c.lock() defer c.unlock() @@ -149,7 +149,7 @@ func (c *Call) Times(i int) *Call { // WaitUntil sets the channel that will block the mock's return until its closed // or a message is received. // -// Mock.On("MyMethod", arg1, arg2).WaitUntil(time.After(time.Second)) +// Mock.On("MyMethod", arg1, arg2).WaitUntil(time.After(time.Second)) func (c *Call) WaitUntil(w <-chan time.Time) *Call { c.lock() defer c.unlock() @@ -159,7 +159,7 @@ func (c *Call) WaitUntil(w <-chan time.Time) *Call { // After sets how long to block until the call returns // -// Mock.On("MyMethod", arg1, arg2).After(time.Second) +// Mock.On("MyMethod", arg1, arg2).After(time.Second) func (c *Call) After(d time.Duration) *Call { c.lock() defer c.unlock() @@ -171,10 +171,10 @@ func (c *Call) After(d time.Duration) *Call { // mocking a method (such as an unmarshaler) that takes a pointer to a struct and // sets properties in such struct // -// Mock.On("Unmarshal", AnythingOfType("*map[string]interface{}")).Return().Run(func(args Arguments) { -// arg := args.Get(0).(*map[string]interface{}) -// arg["foo"] = "bar" -// }) +// Mock.On("Unmarshal", AnythingOfType("*map[string]interface{}")).Return().Run(func(args Arguments) { +// arg := args.Get(0).(*map[string]interface{}) +// arg["foo"] = "bar" +// }) func (c *Call) Run(fn func(args Arguments)) *Call { c.lock() defer c.unlock() @@ -194,16 +194,18 @@ func (c *Call) Maybe() *Call { // On chains a new expectation description onto the mocked interface. This // allows syntax like. // -// Mock. -// On("MyMethod", 1).Return(nil). -// On("MyOtherMethod", 'a', 'b', 'c').Return(errors.New("Some Error")) +// Mock. +// On("MyMethod", 1).Return(nil). +// On("MyOtherMethod", 'a', 'b', 'c').Return(errors.New("Some Error")) +// //go:noinline func (c *Call) On(methodName string, arguments ...interface{}) *Call { return c.Parent.On(methodName, arguments...) } // Unset removes a mock handler from being called. -// test.On("func", mock.Anything).Unset() +// +// test.On("func", mock.Anything).Unset() func (c *Call) Unset() *Call { var unlockOnce sync.Once @@ -249,9 +251,9 @@ func (c *Call) Unset() *Call { // calls have been called as expected. The referenced calls may be from the // same mock instance and/or other mock instances. // -// Mock.On("Do").Return(nil).Notbefore( -// Mock.On("Init").Return(nil) -// ) +// Mock.On("Do").Return(nil).Notbefore( +// Mock.On("Init").Return(nil) +// ) func (c *Call) NotBefore(calls ...*Call) *Call { c.lock() defer c.unlock() @@ -334,7 +336,7 @@ func (m *Mock) fail(format string, args ...interface{}) { // On starts a description of an expectation of the specified method // being called. // -// Mock.On("MyMethod", arg1, arg2) +// Mock.On("MyMethod", arg1, arg2) func (m *Mock) On(methodName string, arguments ...interface{}) *Call { for _, arg := range arguments { if v := reflect.ValueOf(arg); v.Kind() == reflect.Func { @@ -758,6 +760,7 @@ type AnythingOfTypeArgument string // name of the type to check for. Used in Diff and Assert. // // For example: +// // Assert(t, AnythingOfType("string"), AnythingOfType("int")) func AnythingOfType(t string) AnythingOfTypeArgument { return AnythingOfTypeArgument(t) @@ -862,6 +865,12 @@ func (args Arguments) Is(objects ...interface{}) bool { return true } +type missing struct{} + +func (m *missing) String() string { + return "(Missing)" +} + // Diff gets a string describing the differences between the arguments // and the specified objects. // @@ -879,66 +888,24 @@ func (args Arguments) Diff(objects []interface{}) (string, int) { for i := 0; i < maxArgCount; i++ { var actual, expected interface{} - var actualFmt, expectedFmt string if len(objects) <= i { - actual = "(Missing)" - actualFmt = "(Missing)" + actual = missing{} } else { actual = objects[i] - actualFmt = fmt.Sprintf("(%[1]T=%[1]v)", actual) } if len(args) <= i { - expected = "(Missing)" - expectedFmt = "(Missing)" + expected = missing{} } else { expected = args[i] - expectedFmt = fmt.Sprintf("(%[1]T=%[1]v)", expected) } - if matcher, ok := expected.(argumentMatcher); ok { - var matches bool - func() { - defer func() { - if r := recover(); r != nil { - actualFmt = fmt.Sprintf("panic in argument matcher: %v", r) - } - }() - matches = matcher.Matches(actual) - }() - if matches { - output = fmt.Sprintf("%s\t%d: PASS: %s matched by %s\n", output, i, actualFmt, matcher) - } else { - differences++ - output = fmt.Sprintf("%s\t%d: FAIL: %s not matched by %s\n", output, i, actualFmt, matcher) - } - } else if reflect.TypeOf(expected) == reflect.TypeOf((*AnythingOfTypeArgument)(nil)).Elem() { - // type checking - if reflect.TypeOf(actual).Name() != string(expected.(AnythingOfTypeArgument)) && reflect.TypeOf(actual).String() != string(expected.(AnythingOfTypeArgument)) { - // not match - differences++ - output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actualFmt) - } - } else if reflect.TypeOf(expected) == reflect.TypeOf((*IsTypeArgument)(nil)) { - t := expected.(*IsTypeArgument).t - if reflect.TypeOf(t) != reflect.TypeOf(actual) { - differences++ - output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, reflect.TypeOf(t).Name(), reflect.TypeOf(actual).Name(), actualFmt) - } - } else { - // normal checking - - if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) { - // match - output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, actualFmt, expectedFmt) - } else { - // not match - differences++ - output = fmt.Sprintf("%s\t%d: FAIL: %s != %s\n", output, i, actualFmt, expectedFmt) - } + equal, elementOutput := compareElements(expected, actual, i) + output += elementOutput + if !equal { + differences++ } - } if differences == 0 { @@ -948,6 +915,78 @@ func (args Arguments) Diff(objects []interface{}) (string, int) { return output, differences } +func compareElements(expected, actual interface{}, i int) (bool, string) { + var expectedFmt, actualFmt string + if m, ok := expected.(missing); ok { + expectedFmt = m.String() + } else { + expectedFmt = fmt.Sprintf("(%[1]T=%[1]v)", expected) + } + + if m, ok := actual.(missing); ok { + actualFmt = m.String() + } else { + actualFmt = fmt.Sprintf("(%[1]T=%[1]v)", actual) + } + + if matcher, ok := expected.(argumentMatcher); ok { + var matches bool + func() { + defer func() { + if r := recover(); r != nil { + actualFmt = fmt.Sprintf("panic in argument matcher: %v", r) + } + }() + matches = matcher.Matches(actual) + }() + if matches { + return true, fmt.Sprintf("\t%d: PASS: %s matched by %s\n", i, actualFmt, matcher) + } else { + return false, fmt.Sprintf("\t%d: FAIL: %s not matched by %s\n", i, actualFmt, matcher) + } + } else if reflect.TypeOf(expected) == reflect.TypeOf((*AnythingOfTypeArgument)(nil)).Elem() { + // type checking + if reflect.TypeOf(actual).Name() != string(expected.(AnythingOfTypeArgument)) && reflect.TypeOf(actual).String() != string(expected.(AnythingOfTypeArgument)) { + // not match + return false, fmt.Sprintf("\t%d: FAIL: type %s != type %s - %s\n", i, expected, reflect.TypeOf(actual).Name(), actualFmt) + } else { + return true, "" + } + } else if reflect.TypeOf(expected) == reflect.TypeOf((*IsTypeArgument)(nil)) { + t := expected.(*IsTypeArgument).t + if reflect.TypeOf(t) != reflect.TypeOf(actual) { + return false, fmt.Sprintf("\t%d: FAIL: type %s != type %s - %s\n", i, reflect.TypeOf(t).Name(), reflect.TypeOf(actual).Name(), actualFmt) + } else { + return true, "" + } + } else if ev, av := reflect.ValueOf(expected), reflect.ValueOf(actual); ev.Kind() == reflect.Slice && av.Kind() == reflect.Slice { + // Unroll slices to check for Anything / AnythingOFType + + if ev.Len() != av.Len() { + return false, fmt.Sprintf("\t%d: FAIL: %s != %s\n", i, actualFmt, expectedFmt) + } + + for e := 0; e < ev.Len(); e++ { + equal, _ := compareElements(ev.Index(e).Interface(), av.Index(e).Interface(), i) + if !equal { + return false, fmt.Sprintf("\t%d: FAIL: %s != %s\n", i, actualFmt, expectedFmt) + } + } + + return true, fmt.Sprintf("\t%d: PASS: %s == %s\n", i, actualFmt, expectedFmt) + } else { + // normal checking + + if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) { + // match + return true, fmt.Sprintf("\t%d: PASS: %s == %s\n", i, actualFmt, expectedFmt) + } else { + // not match + return false, fmt.Sprintf("\t%d: FAIL: %s != %s\n", i, actualFmt, expectedFmt) + } + } +} + // Assert compares the arguments with the specified objects and fails if // they do not exactly match. func (args Arguments) Assert(t TestingT, objects ...interface{}) bool { diff --git a/mock/mock_test.go b/mock/mock_test.go index cae2932af..ac6683764 100644 --- a/mock/mock_test.go +++ b/mock/mock_test.go @@ -2079,7 +2079,22 @@ func TestAnythingInSlices(t *testing.T) { assert.NoError(t, err) m.AssertExpectations(t) - m.AssertCalled(t, "TheExampleMethodVaridic", Anything, 2, Anything, 4, Anything) + m.AssertCalled(t, "TheExampleMethodVariadic", []interface{}{Anything, 2, Anything, 4, Anything}) +} + +func TestAnythingOfTypeInSlices(t *testing.T) { + m := &TestExampleImplementation{} + + m.On("TheExampleMethodVariadic", []interface{}{1, AnythingOfType("int"), 3, AnythingOfType("int"), 5}).Return(nil) + var err error + + assert.NotPanics(t, func() { + err = m.TheExampleMethodVariadic(1, 2, 3, 4, 5) + }) + + assert.NoError(t, err) + m.AssertExpectations(t) + m.AssertCalled(t, "TheExampleMethodVariadic", []interface{}{AnythingOfType("int"), 2, AnythingOfType("int"), 4, AnythingOfType("int")}) } type caller interface { From 89b6d398b43c76015ce9a7e0ba1be8636f3fedec Mon Sep 17 00:00:00 2001 From: Dave Wyatt Date: Sat, 11 Feb 2023 15:23:01 -0500 Subject: [PATCH 3/3] Limit slice comparisons to 1 level deep --- mock/mock.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mock/mock.go b/mock/mock.go index e115304ed..03e809e30 100644 --- a/mock/mock.go +++ b/mock/mock.go @@ -901,7 +901,7 @@ func (args Arguments) Diff(objects []interface{}) (string, int) { expected = args[i] } - equal, elementOutput := compareElements(expected, actual, i) + equal, elementOutput := compareElements(expected, actual, i, false) output += elementOutput if !equal { differences++ @@ -915,7 +915,7 @@ func (args Arguments) Diff(objects []interface{}) (string, int) { return output, differences } -func compareElements(expected, actual interface{}, i int) (bool, string) { +func compareElements(expected, actual interface{}, i int, isRecursive bool) (bool, string) { var expectedFmt, actualFmt string if m, ok := expected.(missing); ok { expectedFmt = m.String() @@ -959,7 +959,7 @@ func compareElements(expected, actual interface{}, i int) (bool, string) { } else { return true, "" } - } else if ev, av := reflect.ValueOf(expected), reflect.ValueOf(actual); ev.Kind() == reflect.Slice && av.Kind() == reflect.Slice { + } else if ev, av := reflect.ValueOf(expected), reflect.ValueOf(actual); ev.Kind() == reflect.Slice && av.Kind() == reflect.Slice && !isRecursive { // Unroll slices to check for Anything / AnythingOFType if ev.Len() != av.Len() { @@ -967,7 +967,7 @@ func compareElements(expected, actual interface{}, i int) (bool, string) { } for e := 0; e < ev.Len(); e++ { - equal, _ := compareElements(ev.Index(e).Interface(), av.Index(e).Interface(), i) + equal, _ := compareElements(ev.Index(e).Interface(), av.Index(e).Interface(), i, true) if !equal { return false, fmt.Sprintf("\t%d: FAIL: %s != %s\n", i, actualFmt, expectedFmt) }