From e57054dd59e149d99566a92a72da8c8e3e140e89 Mon Sep 17 00:00:00 2001 From: Tim Li Date: Thu, 6 Jun 2024 13:07:28 -0700 Subject: [PATCH] Add unit tests for RespondDecisionTaskFailed and RespondQueryTaskCompleted methods in frontend api handler --- service/frontend/api/handler_test.go | 286 ++++++++++++++++++++++++++- 1 file changed, 281 insertions(+), 5 deletions(-) diff --git a/service/frontend/api/handler_test.go b/service/frontend/api/handler_test.go index b9c9a06a0f4..60fa59f37e0 100644 --- a/service/frontend/api/handler_test.go +++ b/service/frontend/api/handler_test.go @@ -36,6 +36,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/uber/cadence/client/history" + "github.com/uber/cadence/client/matching" "github.com/uber/cadence/common" "github.com/uber/cadence/common/archiver" "github.com/uber/cadence/common/archiver/provider" @@ -71,11 +72,12 @@ type ( suite.Suite *require.Assertions - controller *gomock.Controller - mockResource *resource.Test - mockDomainCache *cache.MockDomainCache - mockHistoryClient *history.MockClient - domainHandler domain.Handler + controller *gomock.Controller + mockResource *resource.Test + mockDomainCache *cache.MockDomainCache + mockHistoryClient *history.MockClient + mockMatchingClient *matching.MockClient + domainHandler domain.Handler mockProducer *mocks.KafkaProducer mockMessagingClient messaging.Client @@ -115,6 +117,7 @@ func (s *workflowHandlerSuite) SetupTest() { s.mockResource = resource.NewTest(s.T(), s.controller, metrics.Frontend) s.mockDomainCache = s.mockResource.DomainCache s.mockHistoryClient = s.mockResource.HistoryClient + s.mockMatchingClient = s.mockResource.MatchingClient s.mockMetadataMgr = s.mockResource.MetadataMgr s.mockHistoryV2Mgr = s.mockResource.HistoryMgr s.mockVisibilityMgr = s.mockResource.VisibilityMgr @@ -2647,6 +2650,279 @@ func (s *workflowHandlerSuite) TestRespondDecisionTaskCompleted() { }) } +func (s *workflowHandlerSuite) TestRespondDecisionTaskFailed() { + validRequest := &types.RespondDecisionTaskFailedRequest{ + TaskToken: []byte("token"), + Cause: types.DecisionTaskFailedCauseWorkflowWorkerUnhandledFailure.Ptr(), + Identity: "identity", + Details: make([]byte, 1000), + } + config := s.newConfig(dc.NewInMemoryClient()) + config.EnableClientVersionCheck = dc.GetBoolPropertyFn(true) + wh := NewWorkflowHandler(s.mockResource, config, s.mockVersionChecker, nil) + wh.tokenSerializer = s.mockTokenSerializer + + testInput := map[string]struct { + input *types.RespondDecisionTaskFailedRequest + mockFn func() + expectError bool + expectErrorType error + }{ + "shutting down": { + input: validRequest, + mockFn: func() { + wh.shuttingDown = int32(1) + }, + expectError: true, + expectErrorType: validate.ErrShuttingDown, + }, + "nil request": { + input: nil, + mockFn: func() {}, + expectError: true, + expectErrorType: validate.ErrRequestNotSet, + }, + "nil task token": { + input: &types.RespondDecisionTaskFailedRequest{ + TaskToken: nil, + }, + mockFn: func() {}, + expectError: true, + expectErrorType: validate.ErrTaskTokenNotSet, + }, + "deserialization failure": { + input: validRequest, + mockFn: func() { + s.mockTokenSerializer.EXPECT().Deserialize(gomock.Any()).Return(nil, errors.New("failed to deserialize token")) + }, + expectError: true, + }, + "empty domain ID": { + input: validRequest, + mockFn: func() { + s.mockTokenSerializer.EXPECT().Deserialize(gomock.Any()).Return(&common.TaskToken{DomainID: ""}, nil) + }, + expectError: true, + expectErrorType: validate.ErrDomainNotSet, + }, + "cannot get domain name": { + input: validRequest, + mockFn: func() { + s.mockTokenSerializer.EXPECT().Deserialize(gomock.Any()).Return(&common.TaskToken{DomainID: s.testDomainID}, nil) + s.mockDomainCache.EXPECT().GetDomainName(s.testDomainID).Return("", errors.New("error getting domain name")) + }, + expectError: true, + }, + "exceeds id length limit": { + input: validRequest, + mockFn: func() { + s.mockTokenSerializer.EXPECT().Deserialize(gomock.Any()).Return(&common.TaskToken{DomainID: s.testDomainID}, nil) + s.mockDomainCache.EXPECT().GetDomainName(s.testDomainID).Return(s.testDomain, nil) + wh.config.MaxIDLengthWarnLimit = dc.GetIntPropertyFn(1) + wh.config.IdentityMaxLength = dc.GetIntPropertyFilteredByDomain(1) + }, + expectError: true, + expectErrorType: validate.ErrIdentityTooLong, + }, + "exceeds blob size limit": { + input: validRequest, + mockFn: func() { + s.mockTokenSerializer.EXPECT().Deserialize(gomock.Any()).Return(&common.TaskToken{DomainID: s.testDomainID}, nil) + s.mockDomainCache.EXPECT().GetDomainName(s.testDomainID).Return(s.testDomain, nil) + wh.config.BlobSizeLimitWarn = dc.GetIntPropertyFilteredByDomain(1) + wh.config.BlobSizeLimitError = dc.GetIntPropertyFilteredByDomain(1) + s.mockHistoryClient.EXPECT().RespondDecisionTaskFailed(gomock.Any(), gomock.Any()).Return(nil) + }, + expectError: false, + }, + "history client returns error": { + input: validRequest, + mockFn: func() { + s.mockTokenSerializer.EXPECT().Deserialize(gomock.Any()).Return(&common.TaskToken{DomainID: s.testDomainID}, nil) + s.mockDomainCache.EXPECT().GetDomainName(s.testDomainID).Return(s.testDomain, nil) + wh.config.BlobSizeLimitWarn = dc.GetIntPropertyFilteredByDomain(1000) + wh.config.BlobSizeLimitError = dc.GetIntPropertyFilteredByDomain(1000) + s.mockHistoryClient.EXPECT().RespondDecisionTaskFailed(gomock.Any(), gomock.Any()).Return(errors.New("error")) + }, + expectError: true, + }, + } + + for name, input := range testInput { + s.Run(name, func() { + input.mockFn() + err := wh.RespondDecisionTaskFailed(context.Background(), input.input) + if input.expectError { + s.Error(err) + if input.expectErrorType != nil { + s.ErrorIs(err, input.expectErrorType) + } + } else { + s.NoError(err) + } + wh.shuttingDown = int32(0) + wh.config.MaxIDLengthWarnLimit = dc.GetIntPropertyFn(1000) + wh.config.IdentityMaxLength = dc.GetIntPropertyFilteredByDomain(1000) + }) + } + + // test version checker + s.Run("version checker", func() { + mockCtrl := gomock.NewController(s.T()) + mockResource := resource.NewTest(s.T(), mockCtrl, metrics.Frontend) + mockVersionChecker := client.NewMockVersionChecker(mockCtrl) + + cfg := frontendcfg.NewConfig( + dc.NewCollection( + dc.NewInMemoryClient(), + mockResource.GetLogger(), + ), + numHistoryShards, + false, + "hostname", + ) + cfg.EnableClientVersionCheck = dc.GetBoolPropertyFn(true) + wh := NewWorkflowHandler(mockResource, cfg, mockVersionChecker, nil) + mockVersionChecker.EXPECT().ClientSupported(gomock.Any(), gomock.Any()).Return(errors.New("error")).Times(1) + err := wh.RespondDecisionTaskFailed(context.Background(), validRequest) + s.Error(err) + }) +} + +func (s *workflowHandlerSuite) TestRespondQueryTaskCompleted() { + config := s.newConfig(dc.NewInMemoryClient()) + config.EnableClientVersionCheck = dc.GetBoolPropertyFn(true) + wh := NewWorkflowHandler(s.mockResource, config, s.mockVersionChecker, nil) + wh.tokenSerializer = s.mockTokenSerializer + + validInput := &types.RespondQueryTaskCompletedRequest{ + TaskToken: []byte("token"), + QueryResult: []byte(`{"result": "result"}`), + } + + testInput := map[string]struct { + input *types.RespondQueryTaskCompletedRequest + mockFn func() + expectError bool + expectErrorType error + }{ + "shutting down": { + input: validInput, + mockFn: func() { + wh.shuttingDown = int32(1) + }, + expectError: true, + expectErrorType: validate.ErrShuttingDown, + }, + "nil request": { + input: nil, + mockFn: func() {}, + expectError: true, + expectErrorType: validate.ErrRequestNotSet, + }, + "empty task token": { + input: &types.RespondQueryTaskCompletedRequest{ + TaskToken: nil, + }, + mockFn: func() {}, + expectError: true, + expectErrorType: validate.ErrTaskTokenNotSet, + }, + "deserialzation failure": { + input: validInput, + mockFn: func() { + s.mockTokenSerializer.EXPECT().DeserializeQueryTaskToken(gomock.Any()).Return(nil, errors.New("failed to deserialize token")) + }, + expectError: true, + }, + "empty domain ID": { + input: validInput, + mockFn: func() { + s.mockTokenSerializer.EXPECT().DeserializeQueryTaskToken(gomock.Any()).Return(&common.QueryTaskToken{DomainID: ""}, nil) + }, + expectError: true, + expectErrorType: validate.ErrInvalidTaskToken, + }, + "cannot get domain name": { + input: validInput, + mockFn: func() { + s.mockTokenSerializer.EXPECT().DeserializeQueryTaskToken(gomock.Any()).Return(&common.QueryTaskToken{ + DomainID: s.testDomainID, + TaskList: "tasklist", + TaskID: "taskID"}, nil) + s.mockDomainCache.EXPECT().GetDomainName(s.testDomainID).Return("", errors.New("error getting domain name")) + }, + expectError: true, + }, + "exceed blob size limit and success": { + input: validInput, + mockFn: func() { + s.mockTokenSerializer.EXPECT().DeserializeQueryTaskToken(gomock.Any()).Return(&common.QueryTaskToken{ + DomainID: s.testDomainID, + TaskList: "tasklist", + TaskID: "taskID"}, nil) + s.mockDomainCache.EXPECT().GetDomainName(s.testDomainID).Return(s.testDomain, nil) + wh.config.BlobSizeLimitWarn = dc.GetIntPropertyFilteredByDomain(1) + wh.config.BlobSizeLimitError = dc.GetIntPropertyFilteredByDomain(1) + s.mockMatchingClient.EXPECT().RespondQueryTaskCompleted(gomock.Any(), gomock.Any()).Return(nil) + }, + expectError: false, + }, + "matching client returns error": { + input: validInput, + mockFn: func() { + s.mockTokenSerializer.EXPECT().DeserializeQueryTaskToken(gomock.Any()).Return(&common.QueryTaskToken{ + DomainID: s.testDomainID, + TaskList: "tasklist", + TaskID: "taskID"}, nil) + s.mockDomainCache.EXPECT().GetDomainName(s.testDomainID).Return(s.testDomain, nil) + wh.config.BlobSizeLimitWarn = dc.GetIntPropertyFilteredByDomain(1000) + wh.config.BlobSizeLimitError = dc.GetIntPropertyFilteredByDomain(1000) + s.mockMatchingClient.EXPECT().RespondQueryTaskCompleted(gomock.Any(), gomock.Any()).Return(errors.New("error")) + }, + expectError: true, + }, + } + + for name, input := range testInput { + s.Run(name, func() { + input.mockFn() + err := wh.RespondQueryTaskCompleted(context.Background(), input.input) + if input.expectError { + s.Error(err) + if input.expectErrorType != nil { + s.ErrorIs(err, input.expectErrorType) + } + } else { + s.NoError(err) + } + wh.shuttingDown = int32(0) + }) + } + + // test version checker + s.Run("version checker", func() { + mockCtrl := gomock.NewController(s.T()) + mockResource := resource.NewTest(s.T(), mockCtrl, metrics.Frontend) + mockVersionChecker := client.NewMockVersionChecker(mockCtrl) + + cfg := frontendcfg.NewConfig( + dc.NewCollection( + dc.NewInMemoryClient(), + mockResource.GetLogger(), + ), + numHistoryShards, + false, + "hostname", + ) + cfg.EnableClientVersionCheck = dc.GetBoolPropertyFn(true) + wh := NewWorkflowHandler(mockResource, cfg, mockVersionChecker, nil) + mockVersionChecker.EXPECT().ClientSupported(gomock.Any(), gomock.Any()).Return(errors.New("error")).Times(1) + err := wh.RespondQueryTaskCompleted(context.Background(), validInput) + s.Error(err) + }) +} + func updateRequest( historyArchivalURI *string, historyArchivalStatus *types.ArchivalStatus,