From e4beaccd13657a4119c2a61950f2e46615557be9 Mon Sep 17 00:00:00 2001 From: Rohith BCS Date: Thu, 2 Nov 2023 14:10:08 +0530 Subject: [PATCH 1/2] fix : set gw ratelimits at event level --- gateway/gateway_test.go | 4 ++-- gateway/handle.go | 2 +- gateway/throttler/throttler.go | 10 +++++----- gateway/throttler/throttler_test.go | 8 ++++---- mocks/gateway/throttler.go | 8 ++++---- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/gateway/gateway_test.go b/gateway/gateway_test.go index 61e05b41a3..b88237079d 100644 --- a/gateway/gateway_test.go +++ b/gateway/gateway_test.go @@ -844,7 +844,7 @@ var _ = Describe("Gateway", func() { }) It("should store messages successfully if rate limit is not reached for workspace", func() { - c.mockRateLimiter.EXPECT().CheckLimitReached(gomock.Any(), gomock.Any()).Return(false, nil).Times(1) + c.mockRateLimiter.EXPECT().CheckLimitReached(gomock.Any(), gomock.Any(), gomock.Any()).Return(false, nil).Times(1) c.mockJobsDB.EXPECT().WithStoreSafeTx(gomock.Any(), gomock.Any()).Times(1).Do(func(ctx context.Context, f func(tx jobsdb.StoreSafeTx) error) { _ = f(jobsdb.EmptyStoreSafeTx()) }).Return(nil) @@ -884,7 +884,7 @@ var _ = Describe("Gateway", func() { }) It("should reject messages if rate limit is reached for workspace", func() { - c.mockRateLimiter.EXPECT().CheckLimitReached(gomock.Any(), gomock.Any()).Return(true, nil).Times(1) + c.mockRateLimiter.EXPECT().CheckLimitReached(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil).Times(1) expectHandlerResponse( gateway.webAliasHandler(), authorizedRequest(WriteKeyEnabled, bytes.NewBufferString("{}")), diff --git a/gateway/handle.go b/gateway/handle.go index 60e71e0e55..712c9fc13f 100644 --- a/gateway/handle.go +++ b/gateway/handle.go @@ -297,7 +297,7 @@ func (gw *Handle) getJobDataFromRequest(req *webRequestT) (jobData *jobFromReq, if gw.conf.enableRateLimit.Load() { // In case of "batch" requests, if rate-limiter returns true for LimitReached, just drop the event 
batch and continue. - ok, errCheck := gw.rateLimiter.CheckLimitReached(context.TODO(), workspaceId) + ok, errCheck := gw.rateLimiter.CheckLimitReached(context.TODO(), workspaceId, int64(len(eventsBatch))) if errCheck != nil { gw.stats.NewTaggedStat("gateway.rate_limiter_error", stats.CountType, stats.Tags{"workspaceId": workspaceId}).Increment() gw.logger.Errorf("Rate limiter error: %v Allowing the request", errCheck) diff --git a/gateway/throttler/throttler.go b/gateway/throttler/throttler.go index 5e959b0b46..cb8f82a7a4 100644 --- a/gateway/throttler/throttler.go +++ b/gateway/throttler/throttler.go @@ -23,7 +23,7 @@ type Limiter interface { } type Throttler interface { - CheckLimitReached(context context.Context, workspaceId string) (bool, error) + CheckLimitReached(context context.Context, workspaceId string, eventCount int64) (bool, error) } type Factory struct { @@ -45,9 +45,9 @@ func New(stats stats.Stats) (*Factory, error) { return &f, nil } -func (f *Factory) CheckLimitReached(context context.Context, workspaceId string) (bool, error) { +func (f *Factory) CheckLimitReached(context context.Context, workspaceId string, eventCount int64) (bool, error) { t := f.get(workspaceId) - return t.checkLimitReached(context, workspaceId) + return t.checkLimitReached(context, workspaceId, eventCount) } func (f *Factory) get(workspaceId string) *throttler { @@ -98,8 +98,8 @@ type throttler struct { } // checkLimitReached returns true if we're not allowed to process the number of event -func (t *throttler) checkLimitReached(ctx context.Context, key string) (limited bool, retErr error) { - allowed, _, err := t.limiter.Allow(ctx, 1, t.config.limit, getWindowInSecs(t.config.window), key) +func (t *throttler) checkLimitReached(ctx context.Context, key string, count int64) (limited bool, retErr error) { + allowed, _, err := t.limiter.Allow(ctx, count, t.config.limit, getWindowInSecs(t.config.window), key) if err != nil { return false, fmt.Errorf("could not limit: %w", err) } 
diff --git a/gateway/throttler/throttler_test.go b/gateway/throttler/throttler_test.go index 397be6b481..0891fe2829 100644 --- a/gateway/throttler/throttler_test.go +++ b/gateway/throttler/throttler_test.go @@ -34,14 +34,14 @@ func TestGateway_Throttler(t *testing.T) { } for i := 0; i < eventLimit; i++ { - _, err := testThrottler.checkLimitReached(context.TODO(), workspaceId) + _, err := testThrottler.checkLimitReached(context.TODO(), workspaceId, 1) require.NoError(t, err) } startTime := time.Now() var passed int for i := 0; i < 2*eventLimit; i++ { - allowed, err := testThrottler.checkLimitReached(context.TODO(), workspaceId) + allowed, err := testThrottler.checkLimitReached(context.TODO(), workspaceId, 1) require.NoError(t, err) if allowed { passed++ @@ -69,14 +69,14 @@ func TestGateway_Factory(t *testing.T) { require.NotNil(t, rateLimiter) for i := 0; i < eventLimit; i++ { - _, err := rateLimiter.CheckLimitReached(context.TODO(), workspaceId) + _, err := rateLimiter.CheckLimitReached(context.TODO(), workspaceId, 1) require.NoError(t, err) } startTime := time.Now() var passed int for i := 0; i < 2*eventLimit; i++ { - allowed, err := rateLimiter.CheckLimitReached(context.TODO(), workspaceId) + allowed, err := rateLimiter.CheckLimitReached(context.TODO(), workspaceId, 1) require.NoError(t, err) if allowed { passed++ diff --git a/mocks/gateway/throttler.go b/mocks/gateway/throttler.go index 5e546b5731..ae469fd5dc 100644 --- a/mocks/gateway/throttler.go +++ b/mocks/gateway/throttler.go @@ -35,16 +35,16 @@ func (m *MockThrottler) EXPECT() *MockThrottlerMockRecorder { } // CheckLimitReached mocks base method. 
-func (m *MockThrottler) CheckLimitReached(arg0 context.Context, arg1 string) (bool, error) { +func (m *MockThrottler) CheckLimitReached(arg0 context.Context, arg1 string, arg2 int64) (bool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CheckLimitReached", arg0, arg1) + ret := m.ctrl.Call(m, "CheckLimitReached", arg0, arg1, arg2) ret0, _ := ret[0].(bool) ret1, _ := ret[1].(error) return ret0, ret1 } // CheckLimitReached indicates an expected call of CheckLimitReached. -func (mr *MockThrottlerMockRecorder) CheckLimitReached(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockThrottlerMockRecorder) CheckLimitReached(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckLimitReached", reflect.TypeOf((*MockThrottler)(nil).CheckLimitReached), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckLimitReached", reflect.TypeOf((*MockThrottler)(nil).CheckLimitReached), arg0, arg1, arg2) } From fdda3a5604e3150e718647617a842a2feeff13f3 Mon Sep 17 00:00:00 2001 From: Rohith BCS Date: Thu, 2 Nov 2023 15:14:58 +0530 Subject: [PATCH 2/2] chore: skip rETL ETL from rate limiting --- gateway/gateway_test.go | 3 ++- gateway/handle.go | 29 +++++++++++++++-------------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/gateway/gateway_test.go b/gateway/gateway_test.go index b88237079d..6ac21d3133 100644 --- a/gateway/gateway_test.go +++ b/gateway/gateway_test.go @@ -884,10 +884,11 @@ var _ = Describe("Gateway", func() { }) It("should reject messages if rate limit is reached for workspace", func() { + conf.Set("Gateway.allowReqsWithoutUserIDAndAnonymousID", true) c.mockRateLimiter.EXPECT().CheckLimitReached(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil).Times(1) expectHandlerResponse( gateway.webAliasHandler(), - authorizedRequest(WriteKeyEnabled, bytes.NewBufferString("{}")), + authorizedRequest(WriteKeyEnabled, bytes.NewBufferString(`{"data": 
"valid-json"}`)), http.StatusTooManyRequests, response.TooManyRequests+"\n", "alias", diff --git a/gateway/handle.go b/gateway/handle.go index 712c9fc13f..3836ec620a 100644 --- a/gateway/handle.go +++ b/gateway/handle.go @@ -267,6 +267,9 @@ func (gw *Handle) getJobDataFromRequest(req *webRequestT) (jobData *jobFromReq, userIDHeader = req.userIDHeader ipAddr = req.ipAddr body = req.requestPayload + + // values retrieved from first event in batch + sourcesJobRunID, sourcesTaskRunID = req.authContext.SourceJobRunID, req.authContext.SourceTaskRunID ) fillMessageID := func(event map[string]interface{}) { @@ -295,18 +298,6 @@ func (gw *Handle) getJobDataFromRequest(req *webRequestT) (jobData *jobFromReq, eventsBatch := gjson.GetBytes(body, "batch").Array() jobData.numEvents = len(eventsBatch) - if gw.conf.enableRateLimit.Load() { - // In case of "batch" requests, if rate-limiter returns true for LimitReached, just drop the event batch and continue. - ok, errCheck := gw.rateLimiter.CheckLimitReached(context.TODO(), workspaceId, int64(len(eventsBatch))) - if errCheck != nil { - gw.stats.NewTaggedStat("gateway.rate_limiter_error", stats.CountType, stats.Tags{"workspaceId": workspaceId}).Increment() - gw.logger.Errorf("Rate limiter error: %v Allowing the request", errCheck) - } - if ok { - return jobData, errRequestDropped - } - } - type jobObject struct { userID string events []map[string]interface{} @@ -317,8 +308,6 @@ func (gw *Handle) getJobDataFromRequest(req *webRequestT) (jobData *jobFromReq, out []jobObject marshalledParams []byte - // values retrieved from first event in batch - sourcesJobRunID, sourcesTaskRunID = req.authContext.SourceJobRunID, req.authContext.SourceTaskRunID // facts about the batch populated as we iterate over events containsAudienceList, suppressed bool @@ -408,6 +397,18 @@ func (gw *Handle) getJobDataFromRequest(req *webRequestT) (jobData *jobFromReq, }) } + if gw.conf.enableRateLimit.Load() && sourcesJobRunID == "" && sourcesTaskRunID == "" { 
+ // Rate-limit only requests that carry no sources job/task run ID; when the limiter reports the limit as reached, drop the whole event batch by returning errRequestDropped. + ok, errCheck := gw.rateLimiter.CheckLimitReached(context.TODO(), workspaceId, int64(len(eventsBatch))) + if errCheck != nil { + gw.stats.NewTaggedStat("gateway.rate_limiter_error", stats.CountType, stats.Tags{"workspaceId": workspaceId}).Increment() + gw.logger.Errorf("Rate limiter error: %v Allowing the request", errCheck) + } + if ok { + return jobData, errRequestDropped + } + } + if len(out) == 0 && suppressed { err = errRequestSuppressed return