diff --git a/pkg/ratelimit/controller.go b/pkg/ratelimit/controller.go new file mode 100644 index 00000000000..0c95be9b11b --- /dev/null +++ b/pkg/ratelimit/controller.go @@ -0,0 +1,79 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ratelimit + +import ( + "sync" + + "golang.org/x/time/rate" +) + +var emptyFunc = func() {} + +// Controller is a controller which holds multiple limiters to manage the request rate of different objects. +type Controller struct { + limiters sync.Map + // the label which is in labelAllowList won't be limited, and only inited by hard code. + labelAllowList map[string]struct{} +} + +// NewController returns a global limiter which can be updated in the later. +func NewController() *Controller { + return &Controller{ + labelAllowList: make(map[string]struct{}), + } +} + +// Allow is used to check whether it has enough token. +func (l *Controller) Allow(label string) (DoneFunc, error) { + var ok bool + lim, ok := l.limiters.Load(label) + if ok { + return lim.(*limiter).allow() + } + return emptyFunc, nil +} + +// Update is used to update Ratelimiter with Options +func (l *Controller) Update(label string, opts ...Option) UpdateStatus { + var status UpdateStatus + for _, opt := range opts { + status |= opt(label, l) + } + return status +} + +// GetQPSLimiterStatus returns the status of a given label's QPS limiter. +func (l *Controller) GetQPSLimiterStatus(label string) (limit rate.Limit, burst int) { + if limit, exist := l.limiters.Load(label); exist { + return limit.(*limiter).getQPSLimiterStatus() + } + return 0, 0 +} + +// GetConcurrencyLimiterStatus returns the status of a given label's concurrency limiter. +func (l *Controller) GetConcurrencyLimiterStatus(label string) (limit uint64, current uint64) { + if limit, exist := l.limiters.Load(label); exist { + return limit.(*limiter).getConcurrencyLimiterStatus() + } + return 0, 0 +} + +// IsInAllowList returns whether this label is in allow list. +// If returns true, the given label won't be limited +func (l *Controller) IsInAllowList(label string) bool { + _, allow := l.labelAllowList[label] + return allow +} diff --git a/pkg/ratelimit/controller_test.go b/pkg/ratelimit/controller_test.go new file mode 100644 index 00000000000..a830217cb9f --- /dev/null +++ b/pkg/ratelimit/controller_test.go @@ -0,0 +1,426 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ratelimit + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/tikv/pd/pkg/utils/syncutil" + "golang.org/x/time/rate" +) + +type changeAndResult struct { + opt Option + checkOptionStatus func(string, Option) + totalRequest int + success int + fail int + release int + waitDuration time.Duration + checkStatusFunc func(string) +} + +type labelCase struct { + label string + round []changeAndResult +} + +func runMulitLabelLimiter(t *testing.T, limiter *Controller, testCase []labelCase) { + re := require.New(t) + var caseWG sync.WaitGroup + for _, tempCas := range testCase { + caseWG.Add(1) + cas := tempCas + go func() { + var lock syncutil.Mutex + successCount, failedCount := 0, 0 + var wg sync.WaitGroup + r := &releaseUtil{} + for _, rd := range cas.round { + rd.checkOptionStatus(cas.label, rd.opt) + time.Sleep(rd.waitDuration) + for i := 0; i < rd.totalRequest; i++ { + wg.Add(1) + go func() { + countRateLimiterHandleResult(limiter, cas.label, &successCount, &failedCount, &lock, &wg, r) + }() + } + wg.Wait() + re.Equal(rd.fail, failedCount) + re.Equal(rd.success, successCount) + for i := 0; i < rd.release; i++ { + r.release() + } + rd.checkStatusFunc(cas.label) + failedCount -= rd.fail + successCount -= rd.success + } + caseWG.Done() + }() + } + caseWG.Wait() +} + +func TestControllerWithConcurrencyLimiter(t *testing.T) { + t.Parallel() + re := require.New(t) + limiter := NewController() + testCase := []labelCase{ + { + label: "test1", + round: []changeAndResult{ + { + opt: UpdateConcurrencyLimiter(10), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&ConcurrencyChanged != 0) + }, + totalRequest: 15, + fail: 5, + success: 10, + release: 10, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, current := limiter.GetConcurrencyLimiterStatus(label) + re.Equal(uint64(10), limit) + re.Equal(uint64(0), current) + }, + }, + { + opt: UpdateConcurrencyLimiter(10), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&ConcurrencyNoChange != 0) + }, + checkStatusFunc: func(label string) {}, + }, + { + opt: UpdateConcurrencyLimiter(5), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&ConcurrencyChanged != 0) + }, + totalRequest: 15, + fail: 10, + success: 5, + release: 5, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, current := limiter.GetConcurrencyLimiterStatus(label) + re.Equal(uint64(5), limit) + re.Equal(uint64(0), current) + }, + }, + { + opt: UpdateConcurrencyLimiter(0), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&ConcurrencyDeleted != 0) + }, + totalRequest: 15, + fail: 0, + success: 15, + release: 5, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, current := limiter.GetConcurrencyLimiterStatus(label) + re.Equal(uint64(0), limit) + re.Equal(uint64(0), current) + }, + }, + }, + }, + { + label: "test2", + round: []changeAndResult{ + { + opt: UpdateConcurrencyLimiter(15), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&ConcurrencyChanged != 0) + }, + totalRequest: 10, + fail: 0, + success: 10, + release: 0, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, current := limiter.GetConcurrencyLimiterStatus(label) + re.Equal(uint64(15), limit) + re.Equal(uint64(10), current) + }, + }, + { + opt: UpdateConcurrencyLimiter(10), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&ConcurrencyChanged != 0) + }, + totalRequest: 10, + fail: 10, + success: 0, + release: 10, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, current := limiter.GetConcurrencyLimiterStatus(label) + re.Equal(uint64(10), limit) + re.Equal(uint64(0), current) + }, + }, + }, + }, + } + runMulitLabelLimiter(t, limiter, testCase) +} + +func TestBlockList(t *testing.T) { + t.Parallel() + re := require.New(t) + opts := []Option{AddLabelAllowList()} + limiter := NewController() + label := "test" + + re.False(limiter.IsInAllowList(label)) + for _, opt := range opts { + opt(label, limiter) + } + re.True(limiter.IsInAllowList(label)) + + status := UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)(label, limiter) + re.True(status&InAllowList != 0) + for i := 0; i < 10; i++ { + _, err := limiter.Allow(label) + re.NoError(err) + } +} + +func TestControllerWithQPSLimiter(t *testing.T) { + t.Parallel() + re := require.New(t) + limiter := NewController() + testCase := []labelCase{ + { + label: "test1", + round: []changeAndResult{ + { + opt: UpdateQPSLimiter(float64(rate.Every(time.Second)), 1), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSChanged != 0) + }, + totalRequest: 3, + fail: 2, + success: 1, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, burst := limiter.GetQPSLimiterStatus(label) + re.Equal(rate.Limit(1), limit) + re.Equal(1, burst) + }, + }, + { + opt: UpdateQPSLimiter(float64(rate.Every(time.Second)), 1), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSNoChange != 0) + }, + checkStatusFunc: func(label string) {}, + }, + { + opt: UpdateQPSLimiter(5, 5), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSChanged != 0) + }, + totalRequest: 10, + fail: 5, + success: 5, + waitDuration: time.Second, + checkStatusFunc: func(label string) { + limit, burst := limiter.GetQPSLimiterStatus(label) + re.Equal(rate.Limit(5), limit) + re.Equal(5, burst) + }, + }, + { + opt: UpdateQPSLimiter(0, 0), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSDeleted != 0) + }, + totalRequest: 10, + fail: 0, + success: 10, + release: 0, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, burst := limiter.GetQPSLimiterStatus(label) + re.Equal(rate.Limit(0), limit) + re.Equal(0, burst) + }, + }, + }, + }, + { + label: "test2", + round: []changeAndResult{ + { + opt: UpdateQPSLimiter(50, 5), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSChanged != 0) + }, + totalRequest: 10, + fail: 5, + success: 5, + waitDuration: time.Second, + checkStatusFunc: func(label string) { + limit, burst := limiter.GetQPSLimiterStatus(label) + re.Equal(rate.Limit(50), limit) + re.Equal(5, burst) + }, + }, + { + opt: UpdateQPSLimiter(0, 0), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSDeleted != 0) + }, + totalRequest: 10, + fail: 0, + success: 10, + release: 0, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, burst := limiter.GetQPSLimiterStatus(label) + re.Equal(rate.Limit(0), limit) + re.Equal(0, burst) + }, + }, + }, + }, + } + runMulitLabelLimiter(t, limiter, testCase) +} + +func TestControllerWithTwoLimiters(t *testing.T) { + t.Parallel() + re := require.New(t) + limiter := NewController() + testCase := []labelCase{ + { + label: "test1", + round: []changeAndResult{ + { + opt: UpdateDimensionConfig(&DimensionConfig{ + QPS: 100, + QPSBurst: 100, + ConcurrencyLimit: 100, + }), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSChanged != 0) + }, + totalRequest: 200, + fail: 100, + success: 100, + release: 100, + waitDuration: time.Second, + checkStatusFunc: func(label string) { + limit, burst := limiter.GetQPSLimiterStatus(label) + re.Equal(rate.Limit(100), limit) + re.Equal(100, burst) + climit, current := limiter.GetConcurrencyLimiterStatus(label) + re.Equal(uint64(100), climit) + re.Equal(uint64(0), current) + }, + }, + { + opt: UpdateQPSLimiter(float64(rate.Every(time.Second)), 1), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSChanged != 0) + }, + totalRequest: 200, + fail: 199, + success: 1, + release: 0, + waitDuration: time.Second, + checkStatusFunc: func(label string) { + limit, current := limiter.GetConcurrencyLimiterStatus(label) + re.Equal(uint64(100), limit) + re.Equal(uint64(1), current) + }, + }, + }, + }, + { + label: "test2", + round: []changeAndResult{ + { + opt: UpdateQPSLimiter(50, 5), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSChanged != 0) + }, + totalRequest: 10, + fail: 5, + success: 5, + waitDuration: time.Second, + checkStatusFunc: func(label string) { + limit, burst := limiter.GetQPSLimiterStatus(label) + re.Equal(rate.Limit(50), limit) + re.Equal(5, burst) + }, + }, + { + opt: UpdateQPSLimiter(0, 0), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSDeleted != 0) + }, + totalRequest: 10, + fail: 0, + success: 10, + release: 0, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, burst := limiter.GetQPSLimiterStatus(label) + re.Equal(rate.Limit(0), limit) + re.Equal(0, burst) + }, + }, + }, + }, + } + runMulitLabelLimiter(t, limiter, testCase) +} + +func countRateLimiterHandleResult(limiter *Controller, label string, successCount *int, + failedCount *int, lock *syncutil.Mutex, wg *sync.WaitGroup, r *releaseUtil) { + doneFucn, err := limiter.Allow(label) + lock.Lock() + defer lock.Unlock() + if err == nil { + *successCount++ + r.append(doneFucn) + } else { + *failedCount++ + } + wg.Done() +} diff --git a/pkg/ratelimit/limiter.go b/pkg/ratelimit/limiter.go index 4bf930ed6c5..444b5aa2481 100644 --- a/pkg/ratelimit/limiter.go +++ b/pkg/ratelimit/limiter.go @@ -1,4 +1,4 @@ -// Copyright 2022 TiKV Project Authors. +// Copyright 2023 TiKV Project Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,11 +15,16 @@ package ratelimit import ( - "sync" + "math" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/utils/syncutil" "golang.org/x/time/rate" ) +// DoneFunc is done function. +type DoneFunc func() + // DimensionConfig is the limit dimension config of one label type DimensionConfig struct { // qps conifg @@ -29,92 +34,125 @@ type DimensionConfig struct { ConcurrencyLimit uint64 } -// Limiter is a controller for the request rate. -type Limiter struct { - qpsLimiter sync.Map - concurrencyLimiter sync.Map - // the label which is in labelAllowList won't be limited - labelAllowList map[string]struct{} +type limiter struct { + mu syncutil.RWMutex + concurrency *concurrencyLimiter + rate *RateLimiter } -// NewLimiter returns a global limiter which can be updated in the later. -func NewLimiter() *Limiter { - return &Limiter{ - labelAllowList: make(map[string]struct{}), - } +func newLimiter() *limiter { + lim := &limiter{} + return lim } -// Allow is used to check whether it has enough token. -func (l *Limiter) Allow(label string) bool { - var cl *concurrencyLimiter - var ok bool - if limiter, exist := l.concurrencyLimiter.Load(label); exist { - if cl, ok = limiter.(*concurrencyLimiter); ok && !cl.allow() { - return false - } - } +func (l *limiter) getConcurrencyLimiter() *concurrencyLimiter { + l.mu.RLock() + defer l.mu.RUnlock() + return l.concurrency +} - if limiter, exist := l.qpsLimiter.Load(label); exist { - if ql, ok := limiter.(*RateLimiter); ok && !ql.Allow() { - if cl != nil { - cl.release() - } - return false - } - } +func (l *limiter) getRateLimiter() *RateLimiter { + l.mu.RLock() + defer l.mu.RUnlock() + return l.rate +} - return true +func (l *limiter) deleteRateLimiter() bool { + l.mu.Lock() + defer l.mu.Unlock() + l.rate = nil + return l.isEmpty() } -// Release is used to refill token. It may be not uesful for some limiters because they will refill automatically -func (l *Limiter) Release(label string) { - if limiter, exist := l.concurrencyLimiter.Load(label); exist { - if cl, ok := limiter.(*concurrencyLimiter); ok { - cl.release() - } - } +func (l *limiter) deleteConcurrency() bool { + l.mu.Lock() + defer l.mu.Unlock() + l.concurrency = nil + return l.isEmpty() } -// Update is used to update Ratelimiter with Options -func (l *Limiter) Update(label string, opts ...Option) UpdateStatus { - var status UpdateStatus - for _, opt := range opts { - status |= opt(label, l) - } - return status +func (l *limiter) isEmpty() bool { + return l.concurrency == nil && l.rate == nil } -// GetQPSLimiterStatus returns the status of a given label's QPS limiter. -func (l *Limiter) GetQPSLimiterStatus(label string) (limit rate.Limit, burst int) { - if limiter, exist := l.qpsLimiter.Load(label); exist { - return limiter.(*RateLimiter).Limit(), limiter.(*RateLimiter).Burst() +func (l *limiter) getQPSLimiterStatus() (limit rate.Limit, burst int) { + baseLimiter := l.getRateLimiter() + if baseLimiter != nil { + return baseLimiter.Limit(), baseLimiter.Burst() } - return 0, 0 } -// QPSUnlimit deletes QPS limiter of the given label -func (l *Limiter) QPSUnlimit(label string) { - l.qpsLimiter.Delete(label) +func (l *limiter) getConcurrencyLimiterStatus() (limit uint64, current uint64) { + baseLimiter := l.getConcurrencyLimiter() + if baseLimiter != nil { + return baseLimiter.getLimit(), baseLimiter.getCurrent() + } + return 0, 0 } -// GetConcurrencyLimiterStatus returns the status of a given label's concurrency limiter. -func (l *Limiter) GetConcurrencyLimiterStatus(label string) (limit uint64, current uint64) { - if limiter, exist := l.concurrencyLimiter.Load(label); exist { - return limiter.(*concurrencyLimiter).getLimit(), limiter.(*concurrencyLimiter).getCurrent() +func (l *limiter) updateConcurrencyConfig(limit uint64) UpdateStatus { + oldConcurrencyLimit, _ := l.getConcurrencyLimiterStatus() + if oldConcurrencyLimit == limit { + return ConcurrencyNoChange + } + if limit < 1 { + l.deleteConcurrency() + return ConcurrencyDeleted } - return 0, 0 + l.mu.Lock() + defer l.mu.Unlock() + if l.concurrency != nil { + l.concurrency.setLimit(limit) + } else { + l.concurrency = newConcurrencyLimiter(limit) + } + return ConcurrencyChanged +} + +func (l *limiter) updateQPSConfig(limit float64, burst int) UpdateStatus { + oldQPSLimit, oldBurst := l.getQPSLimiterStatus() + if math.Abs(float64(oldQPSLimit)-limit) < eps && oldBurst == burst { + return QPSNoChange + } + if limit <= eps || burst < 1 { + l.deleteRateLimiter() + return QPSDeleted + } + l.mu.Lock() + defer l.mu.Unlock() + if l.rate != nil { + l.rate.SetLimit(rate.Limit(limit)) + l.rate.SetBurst(burst) + } else { + l.rate = NewRateLimiter(limit, burst) + } + return QPSChanged } -// ConcurrencyUnlimit deletes concurrency limiter of the given label -func (l *Limiter) ConcurrencyUnlimit(label string) { - l.concurrencyLimiter.Delete(label) +func (l *limiter) updateDimensionConfig(cfg *DimensionConfig) UpdateStatus { + status := l.updateQPSConfig(cfg.QPS, cfg.QPSBurst) + status |= l.updateConcurrencyConfig(cfg.ConcurrencyLimit) + return status } -// IsInAllowList returns whether this label is in allow list. -// If returns true, the given label won't be limited -func (l *Limiter) IsInAllowList(label string) bool { - _, allow := l.labelAllowList[label] - return allow +func (l *limiter) allow() (DoneFunc, error) { + concurrency := l.getConcurrencyLimiter() + if concurrency != nil && !concurrency.allow() { + return nil, errs.ErrRateLimitExceeded + } + + rate := l.getRateLimiter() + if rate != nil && !rate.Allow() { + if concurrency != nil { + concurrency.release() + } + return nil, errs.ErrRateLimitExceeded + } + return func() { + if concurrency != nil { + concurrency.release() + } + }, nil } diff --git a/pkg/ratelimit/limiter_test.go b/pkg/ratelimit/limiter_test.go index d5d9829816a..8834495f3e9 100644 --- a/pkg/ratelimit/limiter_test.go +++ b/pkg/ratelimit/limiter_test.go @@ -1,4 +1,4 @@ -// Copyright 2022 TiKV Project Authors. +// Copyright 2023 TiKV Project Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -24,162 +24,145 @@ import ( "golang.org/x/time/rate" ) -func TestUpdateConcurrencyLimiter(t *testing.T) { +type releaseUtil struct { + dones []DoneFunc +} + +func (r *releaseUtil) release() { + if len(r.dones) > 0 { + r.dones[0]() + r.dones = r.dones[1:] + } +} + +func (r *releaseUtil) append(d DoneFunc) { + r.dones = append(r.dones, d) +} + +func TestWithConcurrencyLimiter(t *testing.T) { t.Parallel() re := require.New(t) - opts := []Option{UpdateConcurrencyLimiter(10)} - limiter := NewLimiter() - - label := "test" - status := limiter.Update(label, opts...) + limiter := newLimiter() + status := limiter.updateConcurrencyConfig(10) re.True(status&ConcurrencyChanged != 0) var lock syncutil.Mutex successCount, failedCount := 0, 0 var wg sync.WaitGroup + r := &releaseUtil{} for i := 0; i < 15; i++ { wg.Add(1) go func() { - countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) }() } wg.Wait() re.Equal(5, failedCount) re.Equal(10, successCount) for i := 0; i < 10; i++ { - limiter.Release(label) + r.release() } - limit, current := limiter.GetConcurrencyLimiterStatus(label) + limit, current := limiter.getConcurrencyLimiterStatus() re.Equal(uint64(10), limit) re.Equal(uint64(0), current) - status = limiter.Update(label, UpdateConcurrencyLimiter(10)) + status = limiter.updateConcurrencyConfig(10) re.True(status&ConcurrencyNoChange != 0) - status = limiter.Update(label, UpdateConcurrencyLimiter(5)) + status = limiter.updateConcurrencyConfig(5) re.True(status&ConcurrencyChanged != 0) failedCount = 0 successCount = 0 for i := 0; i < 15; i++ { wg.Add(1) - go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) } wg.Wait() re.Equal(10, failedCount) re.Equal(5, successCount) for i := 0; i < 5; i++ { - limiter.Release(label) + r.release() } - status = limiter.Update(label, UpdateConcurrencyLimiter(0)) + status = limiter.updateConcurrencyConfig(0) re.True(status&ConcurrencyDeleted != 0) failedCount = 0 successCount = 0 for i := 0; i < 15; i++ { wg.Add(1) - go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) } wg.Wait() re.Equal(0, failedCount) re.Equal(15, successCount) - limit, current = limiter.GetConcurrencyLimiterStatus(label) + limit, current = limiter.getConcurrencyLimiterStatus() re.Equal(uint64(0), limit) re.Equal(uint64(0), current) } -func TestBlockList(t *testing.T) { +func TestWithQPSLimiter(t *testing.T) { t.Parallel() re := require.New(t) - opts := []Option{AddLabelAllowList()} - limiter := NewLimiter() - label := "test" - - re.False(limiter.IsInAllowList(label)) - for _, opt := range opts { - opt(label, limiter) - } - re.True(limiter.IsInAllowList(label)) - - status := UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)(label, limiter) - re.True(status&InAllowList != 0) - for i := 0; i < 10; i++ { - re.True(limiter.Allow(label)) - } -} - -func TestUpdateQPSLimiter(t *testing.T) { - t.Parallel() - re := require.New(t) - opts := []Option{UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)} - limiter := NewLimiter() - - label := "test" - status := limiter.Update(label, opts...) + limiter := newLimiter() + status := limiter.updateQPSConfig(float64(rate.Every(time.Second)), 1) re.True(status&QPSChanged != 0) var lock syncutil.Mutex successCount, failedCount := 0, 0 var wg sync.WaitGroup + r := &releaseUtil{} wg.Add(3) for i := 0; i < 3; i++ { - go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) } wg.Wait() re.Equal(2, failedCount) re.Equal(1, successCount) - limit, burst := limiter.GetQPSLimiterStatus(label) + limit, burst := limiter.getQPSLimiterStatus() re.Equal(rate.Limit(1), limit) re.Equal(1, burst) - status = limiter.Update(label, UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)) + status = limiter.updateQPSConfig(float64(rate.Every(time.Second)), 1) re.True(status&QPSNoChange != 0) - status = limiter.Update(label, UpdateQPSLimiter(5, 5)) + status = limiter.updateQPSConfig(5, 5) re.True(status&QPSChanged != 0) - limit, burst = limiter.GetQPSLimiterStatus(label) + limit, burst = limiter.getQPSLimiterStatus() re.Equal(rate.Limit(5), limit) re.Equal(5, burst) time.Sleep(time.Second) for i := 0; i < 10; i++ { if i < 5 { - re.True(limiter.Allow(label)) + _, err := limiter.allow() + re.NoError(err) } else { - re.False(limiter.Allow(label)) + _, err := limiter.allow() + re.Error(err) } } time.Sleep(time.Second) - status = limiter.Update(label, UpdateQPSLimiter(0, 0)) + status = limiter.updateQPSConfig(0, 0) re.True(status&QPSDeleted != 0) for i := 0; i < 10; i++ { - re.True(limiter.Allow(label)) + _, err := limiter.allow() + re.NoError(err) } - qLimit, qCurrent := limiter.GetQPSLimiterStatus(label) + qLimit, qCurrent := limiter.getQPSLimiterStatus() re.Equal(rate.Limit(0), qLimit) re.Equal(0, qCurrent) -} -func TestQPSLimiter(t *testing.T) { - t.Parallel() - re := require.New(t) - opts := []Option{UpdateQPSLimiter(float64(rate.Every(3*time.Second)), 100)} - limiter := NewLimiter() - - label := "test" - for _, opt := range opts { - opt(label, limiter) - } - - var lock syncutil.Mutex - successCount, failedCount := 0, 0 - var wg sync.WaitGroup + successCount = 0 + failedCount = 0 + status = limiter.updateQPSConfig(float64(rate.Every(3*time.Second)), 100) + re.True(status&QPSChanged != 0) wg.Add(200) for i := 0; i < 200; i++ { - go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) } wg.Wait() re.Equal(200, failedCount+successCount) @@ -188,12 +171,12 @@ func TestQPSLimiter(t *testing.T) { time.Sleep(4 * time.Second) // 3+1 wg.Add(1) - countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) wg.Wait() re.Equal(101, successCount) } -func TestTwoLimiters(t *testing.T) { +func TestWithTwoLimiters(t *testing.T) { t.Parallel() re := require.New(t) cfg := &DimensionConfig{ @@ -201,20 +184,18 @@ func TestTwoLimiters(t *testing.T) { QPSBurst: 100, ConcurrencyLimit: 100, } - opts := []Option{UpdateDimensionConfig(cfg)} - limiter := NewLimiter() - - label := "test" - for _, opt := range opts { - opt(label, limiter) - } + limiter := newLimiter() + status := limiter.updateDimensionConfig(cfg) + re.True(status&QPSChanged != 0) + re.True(status&ConcurrencyChanged != 0) var lock syncutil.Mutex successCount, failedCount := 0, 0 var wg sync.WaitGroup + r := &releaseUtil{} wg.Add(200) for i := 0; i < 200; i++ { - go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) } wg.Wait() re.Equal(100, failedCount) @@ -223,35 +204,42 @@ func TestTwoLimiters(t *testing.T) { wg.Add(100) for i := 0; i < 100; i++ { - go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) } wg.Wait() re.Equal(200, failedCount) re.Equal(100, successCount) for i := 0; i < 100; i++ { - limiter.Release(label) + r.release() } - limiter.Update(label, UpdateQPSLimiter(float64(rate.Every(10*time.Second)), 1)) + status = limiter.updateQPSConfig(float64(rate.Every(10*time.Second)), 1) + re.True(status&QPSChanged != 0) wg.Add(100) for i := 0; i < 100; i++ { - go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) } wg.Wait() re.Equal(101, successCount) re.Equal(299, failedCount) - limit, current := limiter.GetConcurrencyLimiterStatus(label) + limit, current := limiter.getConcurrencyLimiterStatus() re.Equal(uint64(100), limit) re.Equal(uint64(1), current) + + cfg = &DimensionConfig{} + status = limiter.updateDimensionConfig(cfg) + re.True(status&ConcurrencyDeleted != 0) + re.True(status&QPSDeleted != 0) } -func countRateLimiterHandleResult(limiter *Limiter, label string, successCount *int, - failedCount *int, lock *syncutil.Mutex, wg *sync.WaitGroup) { - result := limiter.Allow(label) +func countSingleLimiterHandleResult(limiter *limiter, successCount *int, + failedCount *int, lock *syncutil.Mutex, wg *sync.WaitGroup, r *releaseUtil) { + doneFucn, err := limiter.allow() lock.Lock() defer lock.Unlock() - if result { + if err == nil { *successCount++ + r.append(doneFucn) } else { *failedCount++ } diff --git a/pkg/ratelimit/option.go b/pkg/ratelimit/option.go index 53afb9926d4..b1cc459d786 100644 --- a/pkg/ratelimit/option.go +++ b/pkg/ratelimit/option.go @@ -14,8 +14,6 @@ package ratelimit -import "golang.org/x/time/rate" - // UpdateStatus is flags for updating limiter config. type UpdateStatus uint32 @@ -40,77 +38,46 @@ const ( // Option is used to create a limiter with the optional settings. // these setting is used to add a kind of limiter for a service -type Option func(string, *Limiter) UpdateStatus +type Option func(string, *Controller) UpdateStatus // AddLabelAllowList adds a label into allow list. // It means the given label will not be limited func AddLabelAllowList() Option { - return func(label string, l *Limiter) UpdateStatus { + return func(label string, l *Controller) UpdateStatus { l.labelAllowList[label] = struct{}{} return 0 } } -func updateConcurrencyConfig(l *Limiter, label string, limit uint64) UpdateStatus { - oldConcurrencyLimit, _ := l.GetConcurrencyLimiterStatus(label) - if oldConcurrencyLimit == limit { - return ConcurrencyNoChange - } - if limit < 1 { - l.ConcurrencyUnlimit(label) - return ConcurrencyDeleted - } - if limiter, exist := l.concurrencyLimiter.LoadOrStore(label, newConcurrencyLimiter(limit)); exist { - limiter.(*concurrencyLimiter).setLimit(limit) - } - return ConcurrencyChanged -} - -func updateQPSConfig(l *Limiter, label string, limit float64, burst int) UpdateStatus { - oldQPSLimit, oldBurst := l.GetQPSLimiterStatus(label) - - if (float64(oldQPSLimit)-limit < eps && float64(oldQPSLimit)-limit > -eps) && oldBurst == burst { - return QPSNoChange - } - if limit <= eps || burst < 1 { - l.QPSUnlimit(label) - return QPSDeleted - } - if limiter, exist := l.qpsLimiter.LoadOrStore(label, NewRateLimiter(limit, burst)); exist { - limiter.(*RateLimiter).SetLimit(rate.Limit(limit)) - limiter.(*RateLimiter).SetBurst(burst) - } - return QPSChanged -} - // UpdateConcurrencyLimiter creates a concurrency limiter for a given label if it doesn't exist. func UpdateConcurrencyLimiter(limit uint64) Option { - return func(label string, l *Limiter) UpdateStatus { + return func(label string, l *Controller) UpdateStatus { if _, allow := l.labelAllowList[label]; allow { return InAllowList } - return updateConcurrencyConfig(l, label, limit) + lim, _ := l.limiters.LoadOrStore(label, newLimiter()) + return lim.(*limiter).updateConcurrencyConfig(limit) } } // UpdateQPSLimiter creates a QPS limiter for a given label if it doesn't exist. func UpdateQPSLimiter(limit float64, burst int) Option { - return func(label string, l *Limiter) UpdateStatus { + return func(label string, l *Controller) UpdateStatus { if _, allow := l.labelAllowList[label]; allow { return InAllowList } - return updateQPSConfig(l, label, limit, burst) + lim, _ := l.limiters.LoadOrStore(label, newLimiter()) + return lim.(*limiter).updateQPSConfig(limit, burst) } } // UpdateDimensionConfig creates QPS limiter and concurrency limiter for a given label by config if it doesn't exist. func UpdateDimensionConfig(cfg *DimensionConfig) Option { - return func(label string, l *Limiter) UpdateStatus { + return func(label string, l *Controller) UpdateStatus { if _, allow := l.labelAllowList[label]; allow { return InAllowList } - status := updateQPSConfig(l, label, cfg.QPS, cfg.QPSBurst) - status |= updateConcurrencyConfig(l, label, cfg.ConcurrencyLimit) - return status + lim, _ := l.limiters.LoadOrStore(label, newLimiter()) + return lim.(*limiter).updateDimensionConfig(cfg) } } diff --git a/server/api/middleware.go b/server/api/middleware.go index 4173c37b396..6536935592f 100644 --- a/server/api/middleware.go +++ b/server/api/middleware.go @@ -177,8 +177,8 @@ func (s *rateLimitMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, // There is no need to check whether rateLimiter is nil. CreateServer ensures that it is created rateLimiter := s.svr.GetServiceRateLimiter() - if rateLimiter.Allow(requestInfo.ServiceLabel) { - defer rateLimiter.Release(requestInfo.ServiceLabel) + if done, err := rateLimiter.Allow(requestInfo.ServiceLabel); err == nil { + defer done() next(w, r) } else { http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) diff --git a/server/grpc_service.go b/server/grpc_service.go index fa74f1ea8b6..24280f46437 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -431,11 +431,11 @@ func (s *GrpcServer) GetMembers(context.Context, *pdpb.GetMembersRequest) (*pdpb if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() limiter := s.GetGRPCRateLimiter() - if limiter.Allow(fName) { - defer limiter.Release(fName) + if done, err := limiter.Allow(fName); err == nil { + defer done() } else { return &pdpb.GetMembersResponse{ - Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), }, nil } } @@ -662,11 +662,11 @@ func (s *GrpcServer) GetStore(ctx context.Context, request *pdpb.GetStoreRequest if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() limiter := s.GetGRPCRateLimiter() - if limiter.Allow(fName) { - defer limiter.Release(fName) + if done, err := limiter.Allow(fName); err == nil { + defer done() } else { return &pdpb.GetStoreResponse{ - Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), }, nil } } @@ -765,11 +765,11 @@ func (s *GrpcServer) GetAllStores(ctx context.Context, request *pdpb.GetAllStore if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() limiter := s.GetGRPCRateLimiter() - if limiter.Allow(fName) { - defer limiter.Release(fName) + if done, err := limiter.Allow(fName); err == nil { + defer done() } else { return &pdpb.GetAllStoresResponse{ - Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), }, nil } } @@ -810,8 +810,8 @@ func (s *GrpcServer) StoreHeartbeat(ctx context.Context, request *pdpb.StoreHear if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() limiter := s.GetGRPCRateLimiter() - if limiter.Allow(fName) { - defer limiter.Release(fName) + if done, err := limiter.Allow(fName); err == nil { + defer done() } else { return &pdpb.StoreHeartbeatResponse{ Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), @@ -1286,8 +1286,8 @@ func (s *GrpcServer) GetRegion(ctx context.Context, request *pdpb.GetRegionReque if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() limiter := s.GetGRPCRateLimiter() - if limiter.Allow(fName) { - defer limiter.Release(fName) + if done, err := limiter.Allow(fName); err == nil { + defer done() } else { return &pdpb.GetRegionResponse{ Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), @@ -1330,8 +1330,8 @@ func (s *GrpcServer) GetPrevRegion(ctx context.Context, request *pdpb.GetRegionR if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() limiter := s.GetGRPCRateLimiter() - if limiter.Allow(fName) { - defer limiter.Release(fName) + if done, err := limiter.Allow(fName); err == nil { + defer done() } else { return &pdpb.GetRegionResponse{ Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), @@ -1375,8 +1375,8 @@ func (s *GrpcServer) GetRegionByID(ctx context.Context, request *pdpb.GetRegionB if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() limiter := s.GetGRPCRateLimiter() - if limiter.Allow(fName) { - defer limiter.Release(fName) + if done, err := limiter.Allow(fName); err == nil { + defer done() } else { return &pdpb.GetRegionResponse{ Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), @@ -1419,8 +1419,8 @@ func (s *GrpcServer) ScanRegions(ctx context.Context, request *pdpb.ScanRegionsR if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() limiter := s.GetGRPCRateLimiter() - if limiter.Allow(fName) { - defer limiter.Release(fName) + if done, err := limiter.Allow(fName); err == nil { + defer done() } else { return &pdpb.ScanRegionsResponse{ Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), diff --git a/server/server.go b/server/server.go index c815e7d50c6..187c30dbf7a 100644 --- a/server/server.go +++ b/server/server.go @@ -216,11 +216,11 @@ type Server struct { // related data structures defined in the PD grpc service pdProtoFactory *tsoutil.PDProtoFactory - serviceRateLimiter *ratelimit.Limiter + serviceRateLimiter *ratelimit.Controller serviceLabels map[string][]apiutil.AccessPath apiServiceLabelMap map[apiutil.AccessPath]string - grpcServiceRateLimiter *ratelimit.Limiter + grpcServiceRateLimiter *ratelimit.Controller grpcServiceLabels map[string]struct{} grpcServer *grpc.Server @@ -273,8 +273,8 @@ func CreateServer(ctx context.Context, cfg *config.Config, services []string, le audit.NewLocalLogBackend(true), audit.NewPrometheusHistogramBackend(serviceAuditHistogram, false), } - s.serviceRateLimiter = ratelimit.NewLimiter() - s.grpcServiceRateLimiter = ratelimit.NewLimiter() + s.serviceRateLimiter = ratelimit.NewController() + s.grpcServiceRateLimiter = ratelimit.NewController() s.serviceAuditBackendLabels = make(map[string]*audit.BackendLabels) s.serviceLabels = make(map[string][]apiutil.AccessPath) s.grpcServiceLabels = make(map[string]struct{}) @@ -1467,7 +1467,7 @@ func (s *Server) SetServiceAuditBackendLabels(serviceLabel string, labels []stri } // GetServiceRateLimiter is used to get rate limiter -func (s *Server) GetServiceRateLimiter() *ratelimit.Limiter { +func (s *Server) GetServiceRateLimiter() *ratelimit.Controller { return s.serviceRateLimiter } @@ -1482,7 +1482,7 @@ func (s *Server) UpdateServiceRateLimiter(serviceLabel string, opts ...ratelimit } // GetGRPCRateLimiter is used to get rate limiter -func (s *Server) GetGRPCRateLimiter() *ratelimit.Limiter { +func (s *Server) GetGRPCRateLimiter() *ratelimit.Controller { return s.grpcServiceRateLimiter }