diff --git a/gomock/callset.go b/gomock/callset.go index eacfe72..f2131a1 100644 --- a/gomock/callset.go +++ b/gomock/callset.go @@ -29,6 +29,8 @@ type callSet struct { expectedMu *sync.Mutex // Calls that have been exhausted. exhausted map[callSetKey][]*Call + // when set to true, existing call expectations are overridden when new call expectations are made + allowOverride bool } // callSetKey is the key in the maps in callSet @@ -45,6 +47,15 @@ func newCallSet() *callSet { } } +func newOverridableCallSet() *callSet { + return &callSet{ + expected: make(map[callSetKey][]*Call), + expectedMu: &sync.Mutex{}, + exhausted: make(map[callSetKey][]*Call), + allowOverride: true, + } +} + // Add adds a new expected call. func (cs callSet) Add(call *Call) { key := callSetKey{call.receiver, call.method} @@ -56,6 +67,10 @@ func (cs callSet) Add(call *Call) { if call.exhausted() { m = cs.exhausted } + if cs.allowOverride { + m[key] = make([]*Call, 0) + } + m[key] = append(m[key], call) } diff --git a/gomock/callset_test.go b/gomock/callset_test.go index c69c86a..74e2ce4 100644 --- a/gomock/callset_test.go +++ b/gomock/callset_test.go @@ -42,6 +42,24 @@ func TestCallSetAdd(t *testing.T) { } } +func TestCallSetAdd_WhenOverridable_ClearsPreviousExpectedAndExhausted(t *testing.T) { + method := "TestMethod" + var receiver interface{} = "TestReceiver" + cs := newOverridableCallSet() + + cs.Add(newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func))) + numExpectedCalls := len(cs.expected[callSetKey{receiver, method}]) + if numExpectedCalls != 1 { + t.Fatalf("Expected 1 expected call in callset, got %d", numExpectedCalls) + } + + cs.Add(newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func))) + newNumExpectedCalls := len(cs.expected[callSetKey{receiver, method}]) + if newNumExpectedCalls != 1 { + t.Fatalf("Expected 1 expected call in callset, got %d", newNumExpectedCalls) + } +} + func TestCallSetRemove(t *testing.T) { method := "TestMethod" var receiver interface{} = "TestReceiver" diff --git a/gomock/controller.go b/gomock/controller.go index db03522..de904c8 100644 --- a/gomock/controller.go +++ b/gomock/controller.go @@ -108,12 +108,23 @@ func NewController(t TestReporter, opts ...ControllerOption) *Controller { return ctrl } -// ControllerOption configures how a Controller should behave. Currently -// there are no implementations of it. +// ControllerOption configures how a Controller should behave. type ControllerOption interface { apply(*Controller) } +type overridableExpectationsOption struct{} + +// WithOverridableExpectations allows for overridable call expectations +// i.e., subsequent call expectations override existing call expectations +func WithOverridableExpectations() overridableExpectationsOption { + return overridableExpectationsOption{} +} + +func (o overridableExpectationsOption) apply(ctrl *Controller) { + ctrl.expectedCalls = newOverridableCallSet() +} + type cancelReporter struct { t TestHelper cancel func() diff --git a/gomock/example_test.go b/gomock/example_test.go index fee8332..25b20a2 100644 --- a/gomock/example_test.go +++ b/gomock/example_test.go @@ -48,3 +48,21 @@ func ExampleCall_DoAndReturn_captureArguments() { fmt.Printf("%s %s", r, s) // Output: I'm sleepy foo } + +func ExampleCall_DoAndReturn_withOverridableExpectations() { + t := &testing.T{} // provided by test + ctrl := gomock.NewController(t, gomock.WithOverridableExpectations()) + mockIndex := NewMockFoo(ctrl) + var s string + + mockIndex.EXPECT().Bar(gomock.AssignableToTypeOf(s)).DoAndReturn( + func(arg string) interface{} { + s = arg + return "I'm sleepy" + }, + ) + + r := mockIndex.Bar("foo") + fmt.Printf("%s %s", r, s) + // Output: I'm sleepy foo +} diff --git a/gomock/overridable_controller_test.go b/gomock/overridable_controller_test.go new file mode 100644 index 0000000..3d75e6a --- /dev/null +++ b/gomock/overridable_controller_test.go @@ -0,0 +1,34 @@ +package gomock_test + +import ( + "testing" + + "go.uber.org/mock/gomock" +) + +func TestEcho_NoOverride(t *testing.T) { + ctrl := gomock.NewController(t, gomock.WithOverridableExpectations()) + mockIndex := NewMockFoo(ctrl) + + mockIndex.EXPECT().Bar(gomock.Any()).Return("foo") + res := mockIndex.Bar("input") + + if res != "foo" { + t.Fatalf("expected response to equal 'foo', got %s", res) + } +} + +func TestEcho_WithOverride_BaseCase(t *testing.T) { + ctrl := gomock.NewController(t, gomock.WithOverridableExpectations()) + mockIndex := NewMockFoo(ctrl) + + // initial expectation set + mockIndex.EXPECT().Bar(gomock.Any()).Return("foo") + // override + mockIndex.EXPECT().Bar(gomock.Any()).Return("bar") + res := mockIndex.Bar("input") + + if res != "bar" { + t.Fatalf("expected response to equal 'bar', got %s", res) + } +}