Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow use of mock.Anything/etc inside slices #1348

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 106 additions & 67 deletions mock/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func (c *Call) unlock() {

// Return specifies the return arguments for the expectation.
//
// Mock.On("DoSomething").Return(errors.New("failed"))
// Mock.On("DoSomething").Return(errors.New("failed"))
func (c *Call) Return(returnArguments ...interface{}) *Call {
c.lock()
defer c.unlock()
Expand All @@ -111,7 +111,7 @@ func (c *Call) Return(returnArguments ...interface{}) *Call {

// Panic specifies if the functon call should fail and the panic message
//
// Mock.On("DoSomething").Panic("test panic")
// Mock.On("DoSomething").Panic("test panic")
func (c *Call) Panic(msg string) *Call {
c.lock()
defer c.unlock()
Expand All @@ -123,22 +123,22 @@ func (c *Call) Panic(msg string) *Call {

// Once indicates that that the mock should only return the value once.
//
// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Once()
// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Once()
func (c *Call) Once() *Call {
return c.Times(1)
}

// Twice indicates that that the mock should only return the value twice.
//
// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Twice()
// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Twice()
func (c *Call) Twice() *Call {
return c.Times(2)
}

// Times indicates that that the mock should only return the indicated number
// of times.
//
// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Times(5)
// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Times(5)
func (c *Call) Times(i int) *Call {
c.lock()
defer c.unlock()
Expand All @@ -149,7 +149,7 @@ func (c *Call) Times(i int) *Call {
// WaitUntil sets the channel that will block the mock's return until its closed
// or a message is received.
//
// Mock.On("MyMethod", arg1, arg2).WaitUntil(time.After(time.Second))
// Mock.On("MyMethod", arg1, arg2).WaitUntil(time.After(time.Second))
func (c *Call) WaitUntil(w <-chan time.Time) *Call {
c.lock()
defer c.unlock()
Expand All @@ -159,7 +159,7 @@ func (c *Call) WaitUntil(w <-chan time.Time) *Call {

// After sets how long to block until the call returns
//
// Mock.On("MyMethod", arg1, arg2).After(time.Second)
// Mock.On("MyMethod", arg1, arg2).After(time.Second)
func (c *Call) After(d time.Duration) *Call {
c.lock()
defer c.unlock()
Expand All @@ -171,10 +171,10 @@ func (c *Call) After(d time.Duration) *Call {
// mocking a method (such as an unmarshaler) that takes a pointer to a struct and
// sets properties in such struct
//
// Mock.On("Unmarshal", AnythingOfType("*map[string]interface{}")).Return().Run(func(args Arguments) {
// arg := args.Get(0).(*map[string]interface{})
// arg["foo"] = "bar"
// })
// Mock.On("Unmarshal", AnythingOfType("*map[string]interface{}")).Return().Run(func(args Arguments) {
// arg := args.Get(0).(*map[string]interface{})
// arg["foo"] = "bar"
// })
func (c *Call) Run(fn func(args Arguments)) *Call {
c.lock()
defer c.unlock()
Expand All @@ -194,16 +194,18 @@ func (c *Call) Maybe() *Call {
// On chains a new expectation description onto the mocked interface. This
// allows syntax like.
//
// Mock.
// On("MyMethod", 1).Return(nil).
// On("MyOtherMethod", 'a', 'b', 'c').Return(errors.New("Some Error"))
// Mock.
// On("MyMethod", 1).Return(nil).
// On("MyOtherMethod", 'a', 'b', 'c').Return(errors.New("Some Error"))
//
//go:noinline
func (c *Call) On(methodName string, arguments ...interface{}) *Call {
return c.Parent.On(methodName, arguments...)
}

// Unset removes a mock handler from being called.
// test.On("func", mock.Anything).Unset()
//
// test.On("func", mock.Anything).Unset()
func (c *Call) Unset() *Call {
var unlockOnce sync.Once

Expand Down Expand Up @@ -249,9 +251,9 @@ func (c *Call) Unset() *Call {
// calls have been called as expected. The referenced calls may be from the
// same mock instance and/or other mock instances.
//
// Mock.On("Do").Return(nil).Notbefore(
// Mock.On("Init").Return(nil)
// )
// Mock.On("Do").Return(nil).Notbefore(
// Mock.On("Init").Return(nil)
// )
func (c *Call) NotBefore(calls ...*Call) *Call {
c.lock()
defer c.unlock()
Expand Down Expand Up @@ -334,7 +336,7 @@ func (m *Mock) fail(format string, args ...interface{}) {
// On starts a description of an expectation of the specified method
// being called.
//
// Mock.On("MyMethod", arg1, arg2)
// Mock.On("MyMethod", arg1, arg2)
func (m *Mock) On(methodName string, arguments ...interface{}) *Call {
for _, arg := range arguments {
if v := reflect.ValueOf(arg); v.Kind() == reflect.Func {
Expand Down Expand Up @@ -758,6 +760,7 @@ type AnythingOfTypeArgument string
// name of the type to check for. Used in Diff and Assert.
//
// For example:
//
// Assert(t, AnythingOfType("string"), AnythingOfType("int"))
func AnythingOfType(t string) AnythingOfTypeArgument {
return AnythingOfTypeArgument(t)
Expand Down Expand Up @@ -862,6 +865,12 @@ func (args Arguments) Is(objects ...interface{}) bool {
return true
}

type missing struct{}

func (m *missing) String() string {
return "(Missing)"
}

// Diff gets a string describing the differences between the arguments
// and the specified objects.
//
Expand All @@ -879,66 +888,24 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {

for i := 0; i < maxArgCount; i++ {
var actual, expected interface{}
var actualFmt, expectedFmt string

if len(objects) <= i {
actual = "(Missing)"
actualFmt = "(Missing)"
actual = missing{}
} else {
actual = objects[i]
actualFmt = fmt.Sprintf("(%[1]T=%[1]v)", actual)
}

if len(args) <= i {
expected = "(Missing)"
expectedFmt = "(Missing)"
expected = missing{}
} else {
expected = args[i]
expectedFmt = fmt.Sprintf("(%[1]T=%[1]v)", expected)
}

if matcher, ok := expected.(argumentMatcher); ok {
var matches bool
func() {
defer func() {
if r := recover(); r != nil {
actualFmt = fmt.Sprintf("panic in argument matcher: %v", r)
}
}()
matches = matcher.Matches(actual)
}()
if matches {
output = fmt.Sprintf("%s\t%d: PASS: %s matched by %s\n", output, i, actualFmt, matcher)
} else {
differences++
output = fmt.Sprintf("%s\t%d: FAIL: %s not matched by %s\n", output, i, actualFmt, matcher)
}
} else if reflect.TypeOf(expected) == reflect.TypeOf((*AnythingOfTypeArgument)(nil)).Elem() {
// type checking
if reflect.TypeOf(actual).Name() != string(expected.(AnythingOfTypeArgument)) && reflect.TypeOf(actual).String() != string(expected.(AnythingOfTypeArgument)) {
// not match
differences++
output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actualFmt)
}
} else if reflect.TypeOf(expected) == reflect.TypeOf((*IsTypeArgument)(nil)) {
t := expected.(*IsTypeArgument).t
if reflect.TypeOf(t) != reflect.TypeOf(actual) {
differences++
output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, reflect.TypeOf(t).Name(), reflect.TypeOf(actual).Name(), actualFmt)
}
} else {
// normal checking

if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) {
// match
output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, actualFmt, expectedFmt)
} else {
// not match
differences++
output = fmt.Sprintf("%s\t%d: FAIL: %s != %s\n", output, i, actualFmt, expectedFmt)
}
equal, elementOutput := compareElements(expected, actual, i, false)
output += elementOutput
if !equal {
differences++
}

}

if differences == 0 {
Expand All @@ -948,6 +915,78 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {
return output, differences
}

func compareElements(expected, actual interface{}, i int, isRecursive bool) (bool, string) {
var expectedFmt, actualFmt string
if m, ok := expected.(missing); ok {
expectedFmt = m.String()
} else {
expectedFmt = fmt.Sprintf("(%[1]T=%[1]v)", expected)
}

if m, ok := actual.(missing); ok {
actualFmt = m.String()
} else {
actualFmt = fmt.Sprintf("(%[1]T=%[1]v)", actual)
}

if matcher, ok := expected.(argumentMatcher); ok {
var matches bool
func() {
defer func() {
if r := recover(); r != nil {
actualFmt = fmt.Sprintf("panic in argument matcher: %v", r)
}
}()
matches = matcher.Matches(actual)
}()
if matches {
return true, fmt.Sprintf("\t%d: PASS: %s matched by %s\n", i, actualFmt, matcher)
} else {
return false, fmt.Sprintf("\t%d: FAIL: %s not matched by %s\n", i, actualFmt, matcher)
}
} else if reflect.TypeOf(expected) == reflect.TypeOf((*AnythingOfTypeArgument)(nil)).Elem() {
// type checking
if reflect.TypeOf(actual).Name() != string(expected.(AnythingOfTypeArgument)) && reflect.TypeOf(actual).String() != string(expected.(AnythingOfTypeArgument)) {
// not match
return false, fmt.Sprintf("\t%d: FAIL: type %s != type %s - %s\n", i, expected, reflect.TypeOf(actual).Name(), actualFmt)
} else {
return true, ""
}
} else if reflect.TypeOf(expected) == reflect.TypeOf((*IsTypeArgument)(nil)) {
t := expected.(*IsTypeArgument).t
if reflect.TypeOf(t) != reflect.TypeOf(actual) {
return false, fmt.Sprintf("\t%d: FAIL: type %s != type %s - %s\n", i, reflect.TypeOf(t).Name(), reflect.TypeOf(actual).Name(), actualFmt)
} else {
return true, ""
}
} else if ev, av := reflect.ValueOf(expected), reflect.ValueOf(actual); ev.Kind() == reflect.Slice && av.Kind() == reflect.Slice && !isRecursive {
// Unroll slices to check for Anything / AnythingOFType

if ev.Len() != av.Len() {
return false, fmt.Sprintf("\t%d: FAIL: %s != %s\n", i, actualFmt, expectedFmt)
}

for e := 0; e < ev.Len(); e++ {
equal, _ := compareElements(ev.Index(e).Interface(), av.Index(e).Interface(), i, true)
if !equal {
return false, fmt.Sprintf("\t%d: FAIL: %s != %s\n", i, actualFmt, expectedFmt)
}
}

return true, fmt.Sprintf("\t%d: PASS: %s == %s\n", i, actualFmt, expectedFmt)
} else {
// normal checking

if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) {
// match
return true, fmt.Sprintf("\t%d: PASS: %s == %s\n", i, actualFmt, expectedFmt)
} else {
// not match
return false, fmt.Sprintf("\t%d: FAIL: %s != %s\n", i, actualFmt, expectedFmt)
}
}
}

// Assert compares the arguments with the specified objects and fails if
// they do not exactly match.
func (args Arguments) Assert(t TestingT, objects ...interface{}) bool {
Expand Down
34 changes: 32 additions & 2 deletions mock/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func (m *MockTestingT) Errorf(string, ...interface{}) {
// the execution stops.
// When expecting this method, the call that invokes it should use the following code:
//
// assert.PanicsWithValue(t, mockTestingTFailNowCalled, func() {...})
// assert.PanicsWithValue(t, mockTestingTFailNowCalled, func() {...})
func (m *MockTestingT) FailNow() {
m.failNowCount++

Expand Down Expand Up @@ -1599,7 +1599,7 @@ func Test_Mock_AssertOptional(t *testing.T) {
}

/*
Arguments helper methods
Arguments helper methods
*/
func Test_Arguments_Get(t *testing.T) {

Expand Down Expand Up @@ -2067,6 +2067,36 @@ func TestConcurrentArgumentRead(t *testing.T) {
<-done // wait until Use is called or assertions will fail
}

func TestAnythingInSlices(t *testing.T) {
m := &TestExampleImplementation{}

m.On("TheExampleMethodVariadic", []interface{}{1, Anything, 3, Anything, 5}).Return(nil)
var err error

assert.NotPanics(t, func() {
err = m.TheExampleMethodVariadic(1, 2, 3, 4, 5)
})

assert.NoError(t, err)
m.AssertExpectations(t)
m.AssertCalled(t, "TheExampleMethodVariadic", []interface{}{Anything, 2, Anything, 4, Anything})
}

func TestAnythingOfTypeInSlices(t *testing.T) {
m := &TestExampleImplementation{}

m.On("TheExampleMethodVariadic", []interface{}{1, AnythingOfType("int"), 3, AnythingOfType("int"), 5}).Return(nil)
var err error

assert.NotPanics(t, func() {
err = m.TheExampleMethodVariadic(1, 2, 3, 4, 5)
})

assert.NoError(t, err)
m.AssertExpectations(t)
m.AssertCalled(t, "TheExampleMethodVariadic", []interface{}{AnythingOfType("int"), 2, AnythingOfType("int"), 4, AnythingOfType("int")})
}

type caller interface {
Call()
}
Expand Down