diff --git a/pkg/p2p/client_test.go b/pkg/p2p/client_test.go index 345601fee16..471cf1bf74d 100644 --- a/pkg/p2p/client_test.go +++ b/pkg/p2p/client_test.go @@ -180,8 +180,7 @@ func TestMessageClientBasics(t *testing.T) { sender.AssertExpectations(t) // Test point 7: Interrupt the connection - grpcStream.ExpectedCalls = nil - grpcStream.Calls = nil + grpcStream.ResetMock() sender.ExpectedCalls = nil sender.Calls = nil diff --git a/pkg/p2p/mock_grpc_client.go b/pkg/p2p/mock_grpc_client.go index 7b32816875f..f9121e716d2 100644 --- a/pkg/p2p/mock_grpc_client.go +++ b/pkg/p2p/mock_grpc_client.go @@ -15,6 +15,7 @@ package p2p import ( "context" + "sync" "sync/atomic" "github.com/pingcap/tiflow/proto/p2p" @@ -22,8 +23,8 @@ import ( "google.golang.org/grpc" ) -//nolint:unused type mockSendMessageClient struct { + mu sync.Mutex mock.Mock // embeds an empty interface p2p.CDCPeerToPeer_SendMessageClient @@ -41,13 +42,24 @@ func newMockSendMessageClient(ctx context.Context) *mockSendMessageClient { } func (s *mockSendMessageClient) Send(packet *p2p.MessagePacket) error { + s.mu.Lock() + defer s.mu.Unlock() + args := s.Called(packet) atomic.AddInt32(&s.msgCount, 1) return args.Error(0) } func (s *mockSendMessageClient) Recv() (*p2p.SendMessageResponse, error) { - args := s.Called() + var args mock.Arguments + func() { + // We use a deferred Unlock in case `s.Called()` panics. + s.mu.Lock() + defer s.mu.Unlock() + + args = s.MethodCalled("Recv") + }() + if err := args.Error(1); err != nil { return nil, err } @@ -66,12 +78,18 @@ func (s *mockSendMessageClient) Context() context.Context { return s.ctx } -//nolint:unused +func (s *mockSendMessageClient) ResetMock() { + s.mu.Lock() + defer s.mu.Unlock() + + s.ExpectedCalls = nil + s.Calls = nil +} + type mockCDCPeerToPeerClient struct { mock.Mock } -//nolint:unused func (c *mockCDCPeerToPeerClient) SendMessage( ctx context.Context, opts ...grpc.CallOption, ) (p2p.CDCPeerToPeer_SendMessageClient, error) {