diff --git a/mock/mock.go b/mock/mock.go index eca55f6a1..0c95f95dc 100644 --- a/mock/mock.go +++ b/mock/mock.go @@ -78,6 +78,9 @@ type Call struct { // Calls which must be satisfied before this call can be requires []*Call + + // Calls that depend on this call to be satisfied so succeed + requiredBy []*Call } func newCall(parent *Mock, methodName string, callerInfo []string, methodArguments ...interface{}) *Call { @@ -242,6 +245,21 @@ func (c *Call) Unset() *Call { // trim slice up to last copied index c.Parent.ExpectedCalls = c.Parent.ExpectedCalls[:index] + // in-place filter slice for dependent calls to be cleaned - iterate from 0'th to last skipping unnecessary ones + for _, dependentCall := range c.requiredBy { + var index int + for _, requiredByDependent := range dependentCall.requires { + if requiredByDependent == c { + // Remove from the required calls of the dependent call + continue + } + dependentCall.requires[index] = requiredByDependent + index++ + } + dependentCall.requires = dependentCall.requires[:index] + } + c.requiredBy = []*Call{} + if !foundMatchingCall { unlockOnce.Do(c.unlock) c.Parent.fail("\n\nmock: Could not find expected call\n-----------------------------\n\n%s\n\n", @@ -267,6 +285,8 @@ func (c *Call) NotBefore(calls ...*Call) *Call { if call.Parent == nil { panic("not before calls must be created with Mock.On()") } + + call.requiredBy = append(call.requiredBy, c) } c.requires = append(c.requires, calls...) diff --git a/mock/mock_test.go b/mock/mock_test.go index b80a8a75b..785612527 100644 --- a/mock/mock_test.go +++ b/mock/mock_test.go @@ -592,6 +592,41 @@ func Test_Mock_UnsetIfAlreadyUnsetFails(t *testing.T) { assert.Equal(t, 0, len(mockedService.ExpectedCalls)) } +func Test_Mock_UnsetOfCallRequiredByNotBefore(t *testing.T) { + // make a test impl object + var mockedServiceA = new(TestExampleImplementation) + var mockedServiceB = new(TestExampleImplementation) + var mockedServiceC = new(TestExampleImplementation) + + mock1 := mockedServiceA. + On("TheExampleMethod", 1, 1, 1). + Return(1). + Once() + + mock2 := mockedServiceB. + On("TheExampleMethod", 2, 2, 2). + Return(2). + NotBefore(mock1) + + mock3 := mockedServiceC. + On("TheExampleMethod", 3, 3, 3). + Return(3). + NotBefore(mock1). + NotBefore(mock2) + + assert.Equal(t, 2, len(mock1.requiredBy)) + assert.Equal(t, 1, len(mock2.requires)) + assert.Equal(t, 1, len(mock2.requiredBy)) + assert.Equal(t, 2, len(mock3.requires)) + + mock1.Unset() + + assert.Equal(t, 0, len(mock1.requiredBy)) + assert.Equal(t, 0, len(mock2.requires)) + assert.Equal(t, 1, len(mock2.requiredBy)) + assert.Equal(t, 1, len(mock3.requires)) +} + func Test_Mock_UnsetByOnMethodSpec(t *testing.T) { // make a test impl object var mockedService = new(TestExampleImplementation)