Skip to content

Commit

Permalink
close tikv#4500: add QPS rate limiter
Browse files Browse the repository at this point in the history
Signed-off-by: Cabinfever_B <cabinfeveroier@gmail.com>
  • Loading branch information
CabinfeverB committed Dec 24, 2021
2 parents 06fc953 + d4c2179 commit a0a01db
Show file tree
Hide file tree
Showing 7 changed files with 357 additions and 4 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ require (
go.etcd.io/etcd v0.5.0-alpha.5.0.20191023171146-3cf2f69b5738
go.uber.org/goleak v1.1.12
go.uber.org/zap v1.16.0
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4
golang.org/x/tools v0.1.5
google.golang.org/grpc v1.26.0
gopkg.in/natefinch/lumberjack.v2 v2.0.0
Expand Down
30 changes: 30 additions & 0 deletions pkg/apiutil/apiutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ import (
"github.com/unrolled/render"
)

var (
// ComponentSignatureKey is used for http request header key
// to identify component signature
ComponentSignatureKey = "component"
// ComponentAnonymousValue identify anonymous request source
ComponentAnonymousValue = "anonymous"
)

// DeferClose captures the error returned from closing (if an error occurs).
// This is designed to be used in a defer statement.
func DeferClose(c io.Closer, err *error) {
Expand Down Expand Up @@ -139,3 +147,25 @@ func GetHTTPRouteName(req *http.Request) (string, bool) {
}
return "", false
}

// GetComponentNameOnHTTP return component name from Request Header
func GetComponentNameOnHTTP(r *http.Request) string {
componentName := r.Header.Get(ComponentSignatureKey)
if componentName == "" {
componentName = ComponentAnonymousValue
}
return componentName
}

type ComponentSignatureRoundTripper struct {
Proxied http.RoundTripper
Component string
}

// RoundTrip is used to implement RoundTripper
func (rt *ComponentSignatureRoundTripper) RoundTrip(req *http.Request) (resp *http.Response, err error) {
req.Header.Add(ComponentSignatureKey, rt.Component)
// Send the request, get the response and the error
resp, err = rt.Proxied.RoundTrip(req)
return
}
60 changes: 59 additions & 1 deletion server/self_protection.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ package server

import (
"net/http"
"sync"

"github.com/tikv/pd/pkg/apiutil"
"golang.org/x/time/rate"
)

// SelfProtectionManager is a framework to handle self protection mechanism
Expand Down Expand Up @@ -75,6 +77,62 @@ func (h *HTTPServiceSelfProtectionManager) Handle() bool {

// serviceSelfProtectionHandler is a handler which is independent communication mode
type serviceSelfProtectionHandler struct {
// todo APIRateLimiter
apiRateLimiter *APIRateLimiter
// todo AuditLogger
}

func (h *serviceSelfProtectionHandler) Handle(componentName string) bool {
limitAllow := h.rateLimitAllow(componentName)
// it will include other self-protection actions
return limitAllow
}

// RateLimitAllow is used to check whether the rate limit allow request process
func (h *serviceSelfProtectionHandler) rateLimitAllow(componentName string) bool {
if h.apiRateLimiter == nil {
return true
}
return h.apiRateLimiter.Allow(componentName)
}

// APIRateLimiter is used to limit unnecessary and excess request
// Currently support QPS rate limit by compoenent
// It depends on the rate.Limiter which implement a token-bucket algorithm
type APIRateLimiter struct {
mu sync.RWMutex

enableQPSLimit bool

totalQPSRateLimiter *rate.Limiter

enableComponentQPSLimit bool
componentQPSRateLimiter map[string]*rate.Limiter
}

// QPSAllow firstly check component token bucket and then check total token bucket
// if component rate limiter doesn't allow, it won't ask total limiter
func (rl *APIRateLimiter) QPSAllow(componentName string) bool {
if !rl.enableQPSLimit {
return true
}
isComponentQPSLimit := true
if rl.enableComponentQPSLimit {
componentRateLimiter, ok := rl.componentQPSRateLimiter[componentName]
// The current strategy is to ignore the component limit if it is not found
if ok {
isComponentQPSLimit = componentRateLimiter.Allow()
}
}
if !isComponentQPSLimit {
return isComponentQPSLimit
}
isTotalQPSLimit := rl.totalQPSRateLimiter.Allow()
return isTotalQPSLimit
}

// Allow currentlt only supports QPS rate limit
func (rl *APIRateLimiter) Allow(componentName string) bool {
rl.mu.RLock()
defer rl.mu.RUnlock()
return rl.QPSAllow(componentName)
}
168 changes: 168 additions & 0 deletions server/self_protection_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
// Copyright 2021 TiKV Project Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package server

import (
"sync"
"time"

. "github.com/pingcap/check"
"golang.org/x/time/rate"
)

var _ = Suite(&testSelfProtectHandler{})

type testSelfProtectHandler struct {
rateLimiterOnlyTotal *APIRateLimiter
rateLimiterDisabled *APIRateLimiter
rateLimiterZeroBucket *APIRateLimiter
rateLimiterComopnent *APIRateLimiter
rateLimiterNoComopnentConfig *APIRateLimiter
}

func (s *testSelfProtectHandler) SetUpSuite(c *C) {
s.rateLimiterOnlyTotal = &APIRateLimiter{
enableQPSLimit: true,
totalQPSRateLimiter: rate.NewLimiter(100, 100),
enableComponentQPSLimit: false,
}
s.rateLimiterDisabled = &APIRateLimiter{
enableQPSLimit: false,
}
s.rateLimiterZeroBucket = &APIRateLimiter{
enableQPSLimit: true,
totalQPSRateLimiter: rate.NewLimiter(0, 0),
}
s.rateLimiterComopnent = &APIRateLimiter{
enableQPSLimit: true,
totalQPSRateLimiter: rate.NewLimiter(100, 100),
enableComponentQPSLimit: true,
componentQPSRateLimiter: make(map[string]*rate.Limiter),
}
s.rateLimiterComopnent.componentQPSRateLimiter["pdctl"] = rate.NewLimiter(100, 100)
s.rateLimiterComopnent.componentQPSRateLimiter["anonymous"] = rate.NewLimiter(100, 100)

s.rateLimiterNoComopnentConfig = &APIRateLimiter{
enableQPSLimit: true,
totalQPSRateLimiter: rate.NewLimiter(200, 200),
enableComponentQPSLimit: true,
componentQPSRateLimiter: make(map[string]*rate.Limiter),
}
s.rateLimiterNoComopnentConfig.componentQPSRateLimiter["pdctl"] = rate.NewLimiter(10, 10)
}

func CountRateLimiterHandleResult(handler *serviceSelfProtectionHandler, component string, successCount *int,
failedCount *int, lock *sync.Mutex, wg *sync.WaitGroup) {
result := handler.Handle(component)
lock.Lock()
defer lock.Unlock()
if result {
*successCount++
} else {
*failedCount++
}
wg.Done()
}

func (s *testSelfProtectHandler) TestRateLimiterOnlyTotal(c *C) {
time.Sleep(1 * time.Second)
handler := serviceSelfProtectionHandler{
apiRateLimiter: s.rateLimiterOnlyTotal,
}
var lock sync.Mutex
successCount, failedCount := 0, 0
var wg sync.WaitGroup
wg.Add(110)
for i := 0; i < 110; i++ {
go CountRateLimiterHandleResult(&handler, "", &successCount, &failedCount, &lock, &wg)
}
wg.Wait()
c.Assert(failedCount, Equals, 10)
c.Assert(successCount, Equals, 100)
}

func (s *testSelfProtectHandler) TestRateLimiterDisabled(c *C) {
time.Sleep(1 * time.Second)
handler := serviceSelfProtectionHandler{
apiRateLimiter: s.rateLimiterDisabled,
}
c.Assert(handler.Handle(""), Equals, true)
}

func (s *testSelfProtectHandler) TestRateLimiterZeroBucket(c *C) {
time.Sleep(1 * time.Second)
handler := serviceSelfProtectionHandler{
apiRateLimiter: s.rateLimiterZeroBucket,
}
var lock sync.Mutex
successCount, failedCount := 0, 0
var wg sync.WaitGroup
wg.Add(110)
for i := 0; i < 110; i++ {
go CountRateLimiterHandleResult(&handler, "", &successCount, &failedCount, &lock, &wg)
}
wg.Wait()
c.Assert(failedCount, Equals, 110)
c.Assert(successCount, Equals, 0)
}

func (s *testSelfProtectHandler) TestRateLimiterComopnent(c *C) {
time.Sleep(1 * time.Second)
handler := serviceSelfProtectionHandler{
apiRateLimiter: s.rateLimiterComopnent,
}
var lock sync.Mutex
successCount, failedCount := 0, 0
var wg sync.WaitGroup
wg.Add(300)
for i := 0; i < 150; i++ {
go CountRateLimiterHandleResult(&handler, "anonymous", &successCount, &failedCount, &lock, &wg)
go CountRateLimiterHandleResult(&handler, "pdctl", &successCount, &failedCount, &lock, &wg)
}
wg.Wait()
c.Assert(failedCount, Equals, 200)
c.Assert(successCount, Equals, 100)

time.Sleep(2 * time.Second)
successCount, failedCount = 0, 0
wg.Add(150)
for i := 0; i < 150; i++ {
go CountRateLimiterHandleResult(&handler, "anonymous", &successCount, &failedCount, &lock, &wg)
}
wg.Wait()
c.Assert(failedCount, Equals, 50)
c.Assert(successCount, Equals, 100)
}

func (s *testSelfProtectHandler) TestRateLimiterComponentNoConfig(c *C) {
time.Sleep(1 * time.Second)
handler := serviceSelfProtectionHandler{
apiRateLimiter: s.rateLimiterNoComopnentConfig,
}
var lock sync.Mutex
successAnonymousCount, failedAnonymousCount := 0, 0
successPdctlCount, failedPdctlCount := 0, 0
var wg sync.WaitGroup
wg.Add(400)
for i := 0; i < 200; i++ {
go CountRateLimiterHandleResult(&handler, "anonymous", &successAnonymousCount, &failedAnonymousCount, &lock, &wg)
go CountRateLimiterHandleResult(&handler, "pdctl", &successPdctlCount, &failedPdctlCount, &lock, &wg)
}
wg.Wait()
c.Assert(successAnonymousCount, Equals, 190)
c.Assert(failedAnonymousCount, Equals, 10)
c.Assert(failedPdctlCount, Equals, 190)
c.Assert(successPdctlCount, Equals, 10)
}
79 changes: 79 additions & 0 deletions tests/pdctl/global_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright 2021 TiKV Project Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pdctl

import (
"context"
"fmt"
"net/http"
"testing"

. "github.com/pingcap/check"
"github.com/pingcap/log"
"github.com/tikv/pd/pkg/apiutil"
"github.com/tikv/pd/pkg/testutil"
"github.com/tikv/pd/server"
cmd "github.com/tikv/pd/tools/pd-ctl/pdctl"
"go.uber.org/zap"
)

func Test(t *testing.T) {
TestingT(t)
}

var _ = Suite(&globalTestSuite{})

type globalTestSuite struct{}

func (s *globalTestSuite) SetUpSuite(c *C) {
server.EnableZap = true
}

func (s *globalTestSuite) TestSendAndGetComponent(c *C) {
handler := func(ctx context.Context, s *server.Server) (http.Handler, server.ServiceGroup, error) {
mux := http.NewServeMux()
mux.HandleFunc("/pd/api/v1/health", func(w http.ResponseWriter, r *http.Request) {
component := apiutil.GetComponentNameOnHTTP(r)
for k := range r.Header {
log.Info("header", zap.String("key", k))
}
log.Info("component", zap.String("component", component))
c.Assert(component, Equals, "pdctl")
fmt.Fprint(w, component)
})
info := server.ServiceGroup{
IsCore: true,
}
return mux, info, nil
}
cfg := server.NewTestSingleConfig(checkerWithNilAssert(c))
ctx, cancel := context.WithCancel(context.Background())
svr, err := server.CreateServer(ctx, cfg, handler)
c.Assert(err, IsNil)
err = svr.Run()
c.Assert(err, IsNil)
pdAddr := svr.GetAddr()
defer func() {
cancel()
svr.Close()
testutil.CleanServer(svr.GetConfig().DataDir)
}()

cmd := cmd.GetRootCmd()
args := []string{"-u", pdAddr, "health"}
output, err := ExecuteCommand(cmd, args...)
c.Assert(err, IsNil)
c.Assert(string(output), Equals, "pdctl\n")
}
9 changes: 9 additions & 0 deletions tests/pdctl/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/pingcap/kvproto/pkg/metapb"
"github.com/pingcap/kvproto/pkg/pdpb"
"github.com/spf13/cobra"
"github.com/tikv/pd/pkg/assertutil"
"github.com/tikv/pd/server"
"github.com/tikv/pd/server/api"
"github.com/tikv/pd/server/core"
Expand Down Expand Up @@ -113,3 +114,11 @@ func MustPutRegion(c *check.C, cluster *tests.TestCluster, regionID, storeID uin
c.Assert(err, check.IsNil)
return r
}

func checkerWithNilAssert(c *check.C) *assertutil.Checker {
checker := assertutil.NewChecker(c.FailNow)
checker.IsNil = func(obtained interface{}) {
c.Assert(obtained, check.IsNil)
}
return checker
}
Loading

0 comments on commit a0a01db

Please sign in to comment.