diff --git a/service/history/service.go b/service/history/service.go index f85918022fd..514aaed5b67 100644 --- a/service/history/service.go +++ b/service/history/service.go @@ -114,7 +114,13 @@ func (s *Service) Start() { }) rawHandler := handler.NewHandler(s.Resource, s.config, wfIDCache) - s.handler = ratelimited.NewHistoryHandler(rawHandler, wfIDCache) + s.handler = ratelimited.NewHistoryHandler( + rawHandler, + wfIDCache, + s.config.WorkflowIDExternalRateLimitEnabled, + s.Resource.GetDomainCache(), + s.Resource.GetLogger(), + ) thriftHandler := thrift.NewThriftHandler(s.handler) thriftHandler.Register(s.GetDispatcher()) diff --git a/service/history/templates/ratelimited.tmpl b/service/history/templates/ratelimited.tmpl index e5f70018db9..2bbf03f4d83 100644 --- a/service/history/templates/ratelimited.tmpl +++ b/service/history/templates/ratelimited.tmpl @@ -5,6 +5,7 @@ import ( "github.com/uber/cadence/common/quotas" "github.com/uber/cadence/common/types" "github.com/uber/cadence/service/history" + "github.com/uber/cadence/common/log" ) {{ $ratelimitTypeMap := dict "StartWorkflowExecution" ( @@ -39,19 +40,32 @@ import ( // {{$decorator}} implements {{.Interface.Type}} interface instrumented with rate limiter. type {{$decorator}} struct { - wrapped {{.Interface.Type}} - workflowIDCache workflowcache.WFCache + wrapped {{.Interface.Type}} + workflowIDCache workflowcache.WFCache + ratelimitExternalPerWorkflowID dynamicconfig.BoolPropertyFnWithDomainFilter + domainCache cache.DomainCache + logger log.Logger + allowFunc func (domainID string, workflowID string) bool } // New{{$Decorator}} creates a new instance of {{$interfaceName}} with ratelimiter. func New{{$Decorator}}( wrapped {{.Interface.Type}}, workflowIDCache workflowcache.WFCache, + ratelimitExternalPerWorkflowID dynamicconfig.BoolPropertyFnWithDomainFilter, + domainCache cache.DomainCache, + logger log.Logger, ) {{.Interface.Type}} { - return &{{$decorator}}{ + wrapper := &{{$decorator}}{ wrapped: wrapped, workflowIDCache: workflowIDCache, + ratelimitExternalPerWorkflowID: ratelimitExternalPerWorkflowID, + domainCache: domainCache, + logger: logger, } + wrapper.allowFunc = wrapper.allowWfID + + return wrapper } {{range $method := .Interface.Methods}} @@ -82,7 +96,11 @@ func (h *{{$decorator}}) {{$method.Declaration}} { return } - h.workflowIDCache.AllowExternal({{$domainID}}, {{$workflowID}}) + if !h.allowFunc({{$domainID}}, {{$workflowID}}) { + err = &types.ServiceBusyError{"Too many requests for the workflow ID"} + return + } + {{- end}} {{- end}} {{$method.Pass "h.wrapped."}} diff --git a/service/history/wrappers/ratelimited/handler_generated.go b/service/history/wrappers/ratelimited/handler_generated.go index 4f849e6bd41..2ab84d55be3 100644 --- a/service/history/wrappers/ratelimited/handler_generated.go +++ b/service/history/wrappers/ratelimited/handler_generated.go @@ -32,6 +32,9 @@ import ( "context" "time" + "github.com/uber/cadence/common/cache" + "github.com/uber/cadence/common/dynamicconfig" + "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/types" "github.com/uber/cadence/service/frontend/validate" "github.com/uber/cadence/service/history/handler" @@ -40,19 +43,32 @@ import ( // historyHandler implements handler.Handler interface instrumented with rate limiter. type historyHandler struct { - wrapped handler.Handler - workflowIDCache workflowcache.WFCache + wrapped handler.Handler + workflowIDCache workflowcache.WFCache + ratelimitExternalPerWorkflowID dynamicconfig.BoolPropertyFnWithDomainFilter + domainCache cache.DomainCache + logger log.Logger + allowFunc func(domainID string, workflowID string) bool } // NewHistoryHandler creates a new instance of Handler with ratelimiter. func NewHistoryHandler( wrapped handler.Handler, workflowIDCache workflowcache.WFCache, + ratelimitExternalPerWorkflowID dynamicconfig.BoolPropertyFnWithDomainFilter, + domainCache cache.DomainCache, + logger log.Logger, ) handler.Handler { - return &historyHandler{ - wrapped: wrapped, - workflowIDCache: workflowIDCache, + wrapper := &historyHandler{ + wrapped: wrapped, + workflowIDCache: workflowIDCache, + ratelimitExternalPerWorkflowID: ratelimitExternalPerWorkflowID, + domainCache: domainCache, + logger: logger, } + wrapper.allowFunc = wrapper.allowWfID + + return wrapper } func (h *historyHandler) CloseShard(ctx context.Context, cp1 *types.CloseShardRequest) (err error) { @@ -92,7 +108,10 @@ func (h *historyHandler) DescribeWorkflowExecution(ctx context.Context, hp1 *typ return } - h.workflowIDCache.AllowExternal(hp1.GetDomainUUID(), hp1.Request.GetExecution().GetWorkflowID()) + if !h.allowFunc(hp1.GetDomainUUID(), hp1.Request.GetExecution().GetWorkflowID()) { + err = &types.ServiceBusyError{"Too many requests for the workflow ID"} + return + } return h.wrapped.DescribeWorkflowExecution(ctx, hp1) } @@ -245,7 +264,10 @@ func (h *historyHandler) SignalWithStartWorkflowExecution(ctx context.Context, h return } - h.workflowIDCache.AllowExternal(hp1.GetDomainUUID(), hp1.SignalWithStartRequest.GetWorkflowID()) + if !h.allowFunc(hp1.GetDomainUUID(), hp1.SignalWithStartRequest.GetWorkflowID()) { + err = &types.ServiceBusyError{"Too many requests for the workflow ID"} + return + } return h.wrapped.SignalWithStartWorkflowExecution(ctx, hp1) } @@ -266,7 +288,10 @@ func (h *historyHandler) SignalWorkflowExecution(ctx context.Context, hp1 *types return } - h.workflowIDCache.AllowExternal(hp1.GetDomainUUID(), hp1.SignalRequest.GetWorkflowExecution().GetWorkflowID()) + if !h.allowFunc(hp1.GetDomainUUID(), hp1.SignalRequest.GetWorkflowExecution().GetWorkflowID()) { + err = &types.ServiceBusyError{"Too many requests for the workflow ID"} + return + } return h.wrapped.SignalWorkflowExecution(ctx, hp1) } @@ -292,7 +317,10 @@ func (h *historyHandler) StartWorkflowExecution(ctx context.Context, hp1 *types. return } - h.workflowIDCache.AllowExternal(hp1.GetDomainUUID(), hp1.StartRequest.GetWorkflowID()) + if !h.allowFunc(hp1.GetDomainUUID(), hp1.StartRequest.GetWorkflowID()) { + err = &types.ServiceBusyError{"Too many requests for the workflow ID"} + return + } return h.wrapped.StartWorkflowExecution(ctx, hp1) } diff --git a/service/history/wrappers/ratelimited/handler_generated_test.go b/service/history/wrappers/ratelimited/handler_generated_test.go index 75869dee0a2..812fe8c8b7e 100644 --- a/service/history/wrappers/ratelimited/handler_generated_test.go +++ b/service/history/wrappers/ratelimited/handler_generated_test.go @@ -24,50 +24,61 @@ package ratelimited import ( "context" + "fmt" "testing" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" + "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/types" "github.com/uber/cadence/service/history/handler" - "github.com/uber/cadence/service/history/workflowcache" ) const ( testDomainID = "test-domain-id" testWorkflowID = "test-workflow-id" + testDomainName = "test-domain-name" ) func TestRatelimitedEndpoints_Table(t *testing.T) { controller := gomock.NewController(t) - workflowIDCache := workflowcache.NewMockWFCache(controller) handlerMock := handler.NewMockHandler(controller) - - wrapper := NewHistoryHandler(handlerMock, workflowIDCache) - - tests := []struct { + var rateLimitingEnabled bool + + wrapper := NewHistoryHandler( + handlerMock, + nil, + func(domainName string) bool { return rateLimitingEnabled }, + nil, + log.NewNoop(), + ) + + // We define the calls that should be ratelimited + limitedCalls := []struct { name string - call func() (interface{}, error) - mock func() + // Defines how to call the wrapper function (correct request type, and call) + callWrapper func() (interface{}, error) + // Defines the expected call to the wrapped handler (what to call if the call is not ratelimited) + expectCallToEndpoint func() }{ { name: "StartWorkflowExecution", - call: func() (interface{}, error) { + callWrapper: func() (interface{}, error) { startRequest := &types.HistoryStartWorkflowExecutionRequest{ DomainUUID: testDomainID, StartRequest: &types.StartWorkflowExecutionRequest{WorkflowID: testWorkflowID}, } return wrapper.StartWorkflowExecution(context.Background(), startRequest) }, - mock: func() { + expectCallToEndpoint: func() { handlerMock.EXPECT().StartWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) }, }, { name: "SignalWithStartWorkflowExecution", - call: func() (interface{}, error) { + callWrapper: func() (interface{}, error) { signalWithStartRequest := &types.HistorySignalWithStartWorkflowExecutionRequest{ DomainUUID: testDomainID, SignalWithStartRequest: &types.SignalWithStartWorkflowExecutionRequest{WorkflowID: testWorkflowID}, @@ -75,13 +86,13 @@ func TestRatelimitedEndpoints_Table(t *testing.T) { return wrapper.SignalWithStartWorkflowExecution(context.Background(), signalWithStartRequest) }, - mock: func() { + expectCallToEndpoint: func() { handlerMock.EXPECT().SignalWithStartWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) }, }, { name: "SignalWorkflowExecution", - call: func() (interface{}, error) { + callWrapper: func() (interface{}, error) { signalRequest := &types.HistorySignalWorkflowExecutionRequest{ DomainUUID: testDomainID, SignalRequest: &types.SignalWorkflowExecutionRequest{ @@ -91,13 +102,13 @@ func TestRatelimitedEndpoints_Table(t *testing.T) { return nil, wrapper.SignalWorkflowExecution(context.Background(), signalRequest) }, - mock: func() { + expectCallToEndpoint: func() { handlerMock.EXPECT().SignalWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil).Times(1) }, }, { name: "DescribeWorkflowExecution", - call: func() (interface{}, error) { + callWrapper: func() (interface{}, error) { describeRequest := &types.HistoryDescribeWorkflowExecutionRequest{ DomainUUID: testDomainID, Request: &types.DescribeWorkflowExecutionRequest{ @@ -107,24 +118,26 @@ func TestRatelimitedEndpoints_Table(t *testing.T) { return wrapper.DescribeWorkflowExecution(context.Background(), describeRequest) }, - mock: func() { + expectCallToEndpoint: func() { handlerMock.EXPECT().DescribeWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) }, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // For now true and false needs to do the same as we are only shadowing - workflowIDCache.EXPECT().AllowExternal(testDomainID, testWorkflowID).Return(true).Times(1) - tt.mock() - _, err := tt.call() + for _, endpoint := range limitedCalls { + t.Run(fmt.Sprintf("%s, %s", endpoint.name, "not limited"), func(t *testing.T) { + wrapper.(*historyHandler).allowFunc = func(string, string) bool { return true } + endpoint.expectCallToEndpoint() + _, err := endpoint.callWrapper() assert.NoError(t, err) + }) - workflowIDCache.EXPECT().AllowExternal(testDomainID, testWorkflowID).Return(false).Times(1) - tt.mock() - _, err = tt.call() - assert.NoError(t, err) + t.Run(fmt.Sprintf("%s, %s", endpoint.name, "limited"), func(t *testing.T) { + wrapper.(*historyHandler).allowFunc = func(string, string) bool { return false } + _, err := endpoint.callWrapper() + var sbErr *types.ServiceBusyError + assert.ErrorAs(t, err, &sbErr) + assert.ErrorContains(t, err, "Too many requests for the workflow ID") }) } } diff --git a/service/history/wrappers/ratelimited/ratelimit.go b/service/history/wrappers/ratelimited/ratelimit.go new file mode 100644 index 00000000000..778734cacf9 --- /dev/null +++ b/service/history/wrappers/ratelimited/ratelimit.go @@ -0,0 +1,39 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package ratelimited + +import "github.com/uber/cadence/common/log/tag" + +func (h *historyHandler) allowWfID(domainUUID, workflowID string) bool { + domainName, err := h.domainCache.GetDomainName(domainUUID) + if err != nil { + h.logger.Error("Error when getting domain name", tag.Error(err)) + // Fail open + return true + } + + allow := h.workflowIDCache.AllowExternal(domainUUID, workflowID) + enabled := h.ratelimitExternalPerWorkflowID(domainName) + + return allow || !enabled +} diff --git a/service/history/wrappers/ratelimited/ratelimit_test.go b/service/history/wrappers/ratelimited/ratelimit_test.go new file mode 100644 index 00000000000..2e0d8346cb3 --- /dev/null +++ b/service/history/wrappers/ratelimited/ratelimit_test.go @@ -0,0 +1,106 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package ratelimited + +import ( + "fmt" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/uber/cadence/common/cache" + "github.com/uber/cadence/common/log" + "github.com/uber/cadence/service/history/workflowcache" +) + +func TestAllowWfID(t *testing.T) { + tests := []struct { + ratelimitEnabled bool + workflowIDCacheAllow bool + expected bool + }{ + { + ratelimitEnabled: true, + workflowIDCacheAllow: true, + expected: true, + }, + { + ratelimitEnabled: true, + workflowIDCacheAllow: false, + expected: false, + }, + { + ratelimitEnabled: false, + workflowIDCacheAllow: true, + expected: true, + }, + { + ratelimitEnabled: false, + workflowIDCacheAllow: false, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("ratelimitEnabled: %t, workflowIDCacheAllow: %t", tt.ratelimitEnabled, tt.workflowIDCacheAllow), func(t *testing.T) { + ctrl := gomock.NewController(t) + workflowIDCacheMock := workflowcache.NewMockWFCache(ctrl) + workflowIDCacheMock.EXPECT().AllowExternal(testDomainID, testWorkflowID).Return(tt.workflowIDCacheAllow).Times(1) + + domainCacheMock := cache.NewMockDomainCache(ctrl) + domainCacheMock.EXPECT().GetDomainName(testDomainID).Return(testDomainID, nil).Times(1) + + h := &historyHandler{ + workflowIDCache: workflowIDCacheMock, + domainCache: domainCacheMock, + logger: log.NewNoop(), + ratelimitExternalPerWorkflowID: func(domain string) bool { return tt.ratelimitEnabled }, + } + + got := h.allowWfID(testDomainID, testWorkflowID) + + assert.Equal(t, tt.expected, got) + }) + } +} + +func TestAllowWfID_DomainCacheError(t *testing.T) { + ctrl := gomock.NewController(t) + domainCacheMock := cache.NewMockDomainCache(ctrl) + domainCacheMock.EXPECT().GetDomainName(testDomainID).Return("", fmt.Errorf("TEST ERROR")).Times(1) + + loggerMock := &log.MockLogger{} + loggerMock.On("Error", "Error when getting domain name", mock.Anything).Return().Times(1) + + h := &historyHandler{ + domainCache: domainCacheMock, + logger: loggerMock, + } + + got := h.allowWfID(testDomainID, testWorkflowID) + + // Fail open + assert.True(t, got) +}