diff --git a/errors.toml b/errors.toml index 43fc6a582aa..06848a79d1e 100644 --- a/errors.toml +++ b/errors.toml @@ -701,6 +701,11 @@ error = ''' leader is nil ''' +["PD:server:ErrRateLimitExceeded"] +error = ''' +rate limit exceeded +''' + ["PD:server:ErrServerNotStarted"] error = ''' server not started diff --git a/pkg/errs/errno.go b/pkg/errs/errno.go index 0bd2a57dba5..a5e05219dfa 100644 --- a/pkg/errs/errno.go +++ b/pkg/errs/errno.go @@ -182,6 +182,7 @@ var ( ErrCancelStartEtcd = errors.Normalize("etcd start canceled", errors.RFCCodeText("PD:server:ErrCancelStartEtcd")) ErrConfigItem = errors.Normalize("cannot set invalid configuration", errors.RFCCodeText("PD:server:ErrConfiguration")) ErrServerNotStarted = errors.Normalize("server not started", errors.RFCCodeText("PD:server:ErrServerNotStarted")) + ErrRateLimitExceeded = errors.Normalize("rate limit exceeded", errors.RFCCodeText("PD:server:ErrRateLimitExceeded")) ) // logutil errors diff --git a/server/api/router.go b/server/api/router.go index ce649ba9aef..f47f0e3ebf2 100644 --- a/server/api/router.go +++ b/server/api/router.go @@ -219,7 +219,7 @@ func createRouter(prefix string, svr *server.Server) *mux.Router { registerFunc(clusterRouter, "/store/{id}/limit", storeHandler.SetStoreLimit, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) storesHandler := newStoresHandler(handler, rd) - registerFunc(clusterRouter, "/stores", storesHandler.GetStores, setMethods(http.MethodGet), setAuditBackend(prometheus)) + registerFunc(clusterRouter, "/stores", storesHandler.GetAllStores, setMethods(http.MethodGet), setAuditBackend(prometheus)) registerFunc(clusterRouter, "/stores/remove-tombstone", storesHandler.RemoveTombStone, setMethods(http.MethodDelete), setAuditBackend(localLog, prometheus)) registerFunc(clusterRouter, "/stores/limit", storesHandler.GetAllStoresLimit, setMethods(http.MethodGet), setAuditBackend(prometheus)) registerFunc(clusterRouter, "/stores/limit", storesHandler.SetAllStoresLimit, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) @@ -311,7 +311,8 @@ func createRouter(prefix string, svr *server.Server) *mux.Router { serviceMiddlewareHandler := newServiceMiddlewareHandler(svr, rd) registerFunc(apiRouter, "/service-middleware/config", serviceMiddlewareHandler.GetServiceMiddlewareConfig, setMethods(http.MethodGet), setAuditBackend(prometheus)) registerFunc(apiRouter, "/service-middleware/config", serviceMiddlewareHandler.SetServiceMiddlewareConfig, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) - registerFunc(apiRouter, "/service-middleware/config/rate-limit", serviceMiddlewareHandler.SetRatelimitConfig, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus), setRateLimitAllowList()) + registerFunc(apiRouter, "/service-middleware/config/rate-limit", serviceMiddlewareHandler.SetRateLimitConfig, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus), setRateLimitAllowList()) + registerFunc(apiRouter, "/service-middleware/config/grpc-rate-limit", serviceMiddlewareHandler.SetGRPCRateLimitConfig, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus), setRateLimitAllowList()) logHandler := newLogHandler(svr, rd) registerFunc(apiRouter, "/admin/log", logHandler.SetLogLevel, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) diff --git a/server/api/service_middleware.go b/server/api/service_middleware.go index 00d69a4d902..41d2f6601f0 100644 --- a/server/api/service_middleware.go +++ b/server/api/service_middleware.go @@ -111,6 +111,8 @@ func (h *serviceMiddlewareHandler) updateServiceMiddlewareConfig(cfg *config.Ser return h.updateAudit(cfg, kp[len(kp)-1], value) case "rate-limit": return h.svr.UpdateRateLimit(&cfg.RateLimitConfig, kp[len(kp)-1], value) + case "grpc-rate-limit": + return h.svr.UpdateGRPCRateLimit(&cfg.GRPCRateLimitConfig, kp[len(kp)-1], value) } return errors.Errorf("config prefix %s not found", kp[0]) } @@ -139,7 +141,7 @@ func (h *serviceMiddlewareHandler) updateAudit(config *config.ServiceMiddlewareC // @Failure 400 {string} string "The input is invalid." // @Failure 500 {string} string "config item not found" // @Router /service-middleware/config/rate-limit [POST] -func (h *serviceMiddlewareHandler) SetRatelimitConfig(w http.ResponseWriter, r *http.Request) { +func (h *serviceMiddlewareHandler) SetRateLimitConfig(w http.ResponseWriter, r *http.Request) { var input map[string]interface{} if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &input); err != nil { return @@ -192,14 +194,14 @@ func (h *serviceMiddlewareHandler) SetRatelimitConfig(w http.ResponseWriter, r * qpsRateUpdatedFlag := "QPS rate limiter is not changed." qps, okq := input["qps"].(float64) if okq { - brust := 0 + burst := 0 if int(qps) > 1 { - brust = int(qps) + burst = int(qps) } else if qps > 0 { - brust = 1 + burst = 1 } cfg.QPS = qps - cfg.QPSBurst = brust + cfg.QPSBurst = burst } if !okc && !okq { h.rd.JSON(w, http.StatusOK, "No changed.") @@ -227,6 +229,76 @@ func (h *serviceMiddlewareHandler) SetRatelimitConfig(w http.ResponseWriter, r * } } +// @Tags service_middleware +// @Summary update gRPC ratelimit config +// @Param body body object string "json params" +// @Produce json +// @Success 200 {string} string +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "config item not found" +// @Router /service-middleware/config/grpc-rate-limit [POST] +func (h *serviceMiddlewareHandler) SetGRPCRateLimitConfig(w http.ResponseWriter, r *http.Request) { + var input map[string]interface{} + if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &input); err != nil { + return + } + + serviceLabel, ok := input["label"].(string) + if !ok || len(serviceLabel) == 0 { + h.rd.JSON(w, http.StatusBadRequest, "The label is empty.") + return + } + if !h.svr.IsGRPCServiceLabelExist(serviceLabel) { + h.rd.JSON(w, http.StatusBadRequest, "There is no label matched.") + return + } + + cfg := h.svr.GetGRPCRateLimitConfig().LimiterConfig[serviceLabel] + // update concurrency limiter + concurrencyUpdatedFlag := "Concurrency limiter is not changed." + concurrencyFloat, okc := input["concurrency"].(float64) + if okc { + cfg.ConcurrencyLimit = uint64(concurrencyFloat) + } + // update qps rate limiter + qpsRateUpdatedFlag := "QPS rate limiter is not changed." + qps, okq := input["qps"].(float64) + if okq { + burst := 0 + if int(qps) > 1 { + burst = int(qps) + } else if qps > 0 { + burst = 1 + } + cfg.QPS = qps + cfg.QPSBurst = burst + } + if !okc && !okq { + h.rd.JSON(w, http.StatusOK, "No changed.") + } else { + status := h.svr.UpdateGRPCServiceRateLimiter(serviceLabel, ratelimit.UpdateDimensionConfig(&cfg)) + switch { + case status&ratelimit.QPSChanged != 0: + qpsRateUpdatedFlag = "QPS rate limiter is changed." + case status&ratelimit.QPSDeleted != 0: + qpsRateUpdatedFlag = "QPS rate limiter is deleted." + } + switch { + case status&ratelimit.ConcurrencyChanged != 0: + concurrencyUpdatedFlag = "Concurrency limiter is changed." + case status&ratelimit.ConcurrencyDeleted != 0: + concurrencyUpdatedFlag = "Concurrency limiter is deleted." + } + err := h.svr.UpdateGRPCRateLimitConfig("grpc-limiter-config", serviceLabel, cfg) + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + } else { + result := rateLimitResult{concurrencyUpdatedFlag, qpsRateUpdatedFlag, h.svr.GetServiceMiddlewareConfig().GRPCRateLimitConfig.LimiterConfig} + h.rd.JSON(w, http.StatusOK, result) + } + } +} + type rateLimitResult struct { ConcurrencyUpdatedFlag string `json:"concurrency"` QPSRateUpdatedFlag string `json:"qps"` diff --git a/server/api/service_middleware_test.go b/server/api/service_middleware_test.go index e1a5853db16..ec8094e670e 100644 --- a/server/api/service_middleware_test.go +++ b/server/api/service_middleware_test.go @@ -62,8 +62,9 @@ func (suite *auditMiddlewareTestSuite) TestConfigAuditSwitch() { suite.True(sc.EnableAudit) ms := map[string]interface{}{ - "enable-audit": "true", - "enable-rate-limit": "true", + "enable-audit": "true", + "enable-rate-limit": "true", + "enable-grpc-rate-limit": "true", } postData, err := json.Marshal(ms) suite.NoError(err) @@ -71,10 +72,12 @@ func (suite *auditMiddlewareTestSuite) TestConfigAuditSwitch() { sc = &config.ServiceMiddlewareConfig{} suite.NoError(tu.ReadGetJSON(re, testDialClient, addr, sc)) suite.True(sc.EnableAudit) - suite.True(sc.EnableRateLimit) + suite.True(sc.RateLimitConfig.EnableRateLimit) + suite.True(sc.GRPCRateLimitConfig.EnableRateLimit) ms = map[string]interface{}{ - "audit.enable-audit": "false", - "enable-rate-limit": "false", + "audit.enable-audit": "false", + "enable-rate-limit": "false", + "enable-grpc-rate-limit": "false", } postData, err = json.Marshal(ms) suite.NoError(err) @@ -82,7 +85,8 @@ func (suite *auditMiddlewareTestSuite) TestConfigAuditSwitch() { sc = &config.ServiceMiddlewareConfig{} suite.NoError(tu.ReadGetJSON(re, testDialClient, addr, sc)) suite.False(sc.EnableAudit) - suite.False(sc.EnableRateLimit) + suite.False(sc.RateLimitConfig.EnableRateLimit) + suite.False(sc.GRPCRateLimitConfig.EnableRateLimit) // test empty ms = map[string]interface{}{} @@ -273,12 +277,12 @@ func (suite *rateLimitConfigTestSuite) TestUpdateRateLimitConfig() { suite.NoError(err) limiter := suite.svr.GetServiceRateLimiter() - limiter.Update("SetRatelimitConfig", ratelimit.AddLabelAllowList()) + limiter.Update("SetRateLimitConfig", ratelimit.AddLabelAllowList()) // Allow list input = make(map[string]interface{}) input["type"] = "label" - input["label"] = "SetRatelimitConfig" + input["label"] = "SetRateLimitConfig" input["qps"] = 100 input["concurrency"] = 100 jsonBody, err = json.Marshal(input) @@ -288,31 +292,128 @@ func (suite *rateLimitConfigTestSuite) TestUpdateRateLimitConfig() { suite.NoError(err) } +func (suite *rateLimitConfigTestSuite) TestUpdateGRPCRateLimitConfig() { + urlPrefix := fmt.Sprintf("%s%s/api/v1/service-middleware/config/grpc-rate-limit", suite.svr.GetAddr(), apiPrefix) + re := suite.Require() + + // test empty label + input := make(map[string]interface{}) + input["label"] = "" + jsonBody, err := json.Marshal(input) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, + tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "\"The label is empty.\"\n")) + suite.NoError(err) + // test no label matched + input = make(map[string]interface{}) + input["label"] = "TestLabel" + jsonBody, err = json.Marshal(input) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, + tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "\"There is no label matched.\"\n")) + suite.NoError(err) + + // no change + input = make(map[string]interface{}) + input["label"] = "StoreHeartbeat" + jsonBody, err = json.Marshal(input) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, + tu.StatusOK(re), tu.StringEqual(re, "\"No changed.\"\n")) + suite.NoError(err) + + // change concurrency + input = make(map[string]interface{}) + input["label"] = "StoreHeartbeat" + input["concurrency"] = 100 + jsonBody, err = json.Marshal(input) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, + tu.StatusOK(re), tu.StringContain(re, "Concurrency limiter is changed.")) + suite.NoError(err) + input["concurrency"] = 0 + jsonBody, err = json.Marshal(input) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, + tu.StatusOK(re), tu.StringContain(re, "Concurrency limiter is deleted.")) + suite.NoError(err) + + // change qps + input = make(map[string]interface{}) + input["label"] = "StoreHeartbeat" + input["qps"] = 100 + jsonBody, err = json.Marshal(input) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, + tu.StatusOK(re), tu.StringContain(re, "QPS rate limiter is changed.")) + suite.NoError(err) + + input = make(map[string]interface{}) + input["label"] = "StoreHeartbeat" + input["qps"] = 0.3 + jsonBody, err = json.Marshal(input) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, + tu.StatusOK(re), tu.StringContain(re, "QPS rate limiter is changed.")) + suite.NoError(err) + suite.Equal(1, suite.svr.GetGRPCRateLimitConfig().LimiterConfig["StoreHeartbeat"].QPSBurst) + + input["qps"] = -1 + jsonBody, err = json.Marshal(input) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, + tu.StatusOK(re), tu.StringContain(re, "QPS rate limiter is deleted.")) + suite.NoError(err) + + // change both + input = make(map[string]interface{}) + input["label"] = "GetStore" + input["qps"] = 100 + input["concurrency"] = 100 + jsonBody, err = json.Marshal(input) + suite.NoError(err) + result := rateLimitResult{} + err = tu.CheckPostJSON(testDialClient, urlPrefix, jsonBody, + tu.StatusOK(re), tu.StringContain(re, "Concurrency limiter is changed."), + tu.StringContain(re, "QPS rate limiter is changed."), + tu.ExtractJSON(re, &result), + ) + suite.Equal(100., result.LimiterConfig["GetStore"].QPS) + suite.Equal(100, result.LimiterConfig["GetStore"].QPSBurst) + suite.Equal(uint64(100), result.LimiterConfig["GetStore"].ConcurrencyLimit) + suite.NoError(err) +} + func (suite *rateLimitConfigTestSuite) TestConfigRateLimitSwitch() { addr := fmt.Sprintf("%s/service-middleware/config", suite.urlPrefix) sc := &config.ServiceMiddlewareConfig{} re := suite.Require() suite.NoError(tu.ReadGetJSON(re, testDialClient, addr, sc)) - suite.False(sc.EnableRateLimit) + suite.False(sc.RateLimitConfig.EnableRateLimit) + suite.False(sc.GRPCRateLimitConfig.EnableRateLimit) ms := map[string]interface{}{ - "enable-rate-limit": "true", + "enable-rate-limit": "true", + "enable-grpc-rate-limit": "true", } postData, err := json.Marshal(ms) suite.NoError(err) suite.NoError(tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re))) sc = &config.ServiceMiddlewareConfig{} suite.NoError(tu.ReadGetJSON(re, testDialClient, addr, sc)) - suite.True(sc.EnableRateLimit) + suite.True(sc.RateLimitConfig.EnableRateLimit) + suite.True(sc.GRPCRateLimitConfig.EnableRateLimit) ms = map[string]interface{}{ - "enable-rate-limit": "false", + "enable-rate-limit": "false", + "enable-grpc-rate-limit": "false", } postData, err = json.Marshal(ms) suite.NoError(err) suite.NoError(tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re))) sc = &config.ServiceMiddlewareConfig{} suite.NoError(tu.ReadGetJSON(re, testDialClient, addr, sc)) - suite.False(sc.EnableRateLimit) + suite.False(sc.RateLimitConfig.EnableRateLimit) + suite.False(sc.GRPCRateLimitConfig.EnableRateLimit) // test empty ms = map[string]interface{}{} @@ -327,7 +428,8 @@ func (suite *rateLimitConfigTestSuite) TestConfigRateLimitSwitch() { suite.NoError(tu.CheckPostJSON(testDialClient, addr, postData, tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "config item rate-limit not found"))) suite.NoError(failpoint.Enable("github.com/tikv/pd/server/config/persistServiceMiddlewareFail", "return(true)")) ms = map[string]interface{}{ - "rate-limit.enable-rate-limit": "true", + "rate-limit.enable-rate-limit": "true", + "grpc-rate-limit.enable-grpc-rate-limit": "true", } postData, err = json.Marshal(ms) suite.NoError(err) @@ -341,7 +443,7 @@ func (suite *rateLimitConfigTestSuite) TestConfigRateLimitSwitch() { suite.NoError(tu.CheckPostJSON(testDialClient, addr, postData, tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "config item rate-limit not found"))) } -func (suite *rateLimitConfigTestSuite) TestConfigLimiterConifgByOriginAPI() { +func (suite *rateLimitConfigTestSuite) TestConfigLimiterConfigByOriginAPI() { // this test case is used to test updating `limiter-config` by origin API simply addr := fmt.Sprintf("%s/service-middleware/config", suite.urlPrefix) dimensionConfig := ratelimit.DimensionConfig{QPS: 1} diff --git a/server/api/store.go b/server/api/store.go index 49384439bc4..a3e8c4518a2 100644 --- a/server/api/store.go +++ b/server/api/store.go @@ -734,13 +734,13 @@ func (h *storesHandler) GetStoresProgress(w http.ResponseWriter, r *http.Request } // @Tags store -// @Summary Get stores in the cluster. +// @Summary Get all stores in the cluster. // @Param state query array true "Specify accepted store states." // @Produce json // @Success 200 {object} StoresInfo // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /stores [get] -func (h *storesHandler) GetStores(w http.ResponseWriter, r *http.Request) { +func (h *storesHandler) GetAllStores(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) stores := rc.GetMetaStores() StoresInfo := &StoresInfo{ diff --git a/server/config/service_middleware_config.go b/server/config/service_middleware_config.go index dd497e17733..ef0b04b2abd 100644 --- a/server/config/service_middleware_config.go +++ b/server/config/service_middleware_config.go @@ -17,14 +17,16 @@ package config import "github.com/tikv/pd/pkg/ratelimit" const ( - defaultEnableAuditMiddleware = true - defaultEnableRateLimitMiddleware = false + defaultEnableAuditMiddleware = true + defaultEnableRateLimitMiddleware = false + defaultEnableGRPCRateLimitMiddleware = false ) // ServiceMiddlewareConfig is the configuration for PD Service middleware. type ServiceMiddlewareConfig struct { - AuditConfig `json:"audit"` - RateLimitConfig `json:"rate-limit"` + AuditConfig `json:"audit"` + RateLimitConfig `json:"rate-limit"` + GRPCRateLimitConfig `json:"grpc-rate-limit"` } // NewServiceMiddlewareConfig returns a new service middleware config @@ -32,13 +34,18 @@ func NewServiceMiddlewareConfig() *ServiceMiddlewareConfig { audit := AuditConfig{ EnableAudit: defaultEnableAuditMiddleware, } - ratelimit := RateLimitConfig{ + rateLimit := RateLimitConfig{ + EnableRateLimit: defaultEnableRateLimitMiddleware, + LimiterConfig: make(map[string]ratelimit.DimensionConfig), + } + grpcRateLimit := GRPCRateLimitConfig{ EnableRateLimit: defaultEnableRateLimitMiddleware, LimiterConfig: make(map[string]ratelimit.DimensionConfig), } cfg := &ServiceMiddlewareConfig{ - AuditConfig: audit, - RateLimitConfig: ratelimit, + AuditConfig: audit, + RateLimitConfig: rateLimit, + GRPCRateLimitConfig: grpcRateLimit, } return cfg } @@ -74,3 +81,17 @@ func (c *RateLimitConfig) Clone() *RateLimitConfig { cfg := *c return &cfg } + +// GRPCRateLimitConfig is the configuration for gRPC rate limit +type GRPCRateLimitConfig struct { + // EnableRateLimit controls the switch of the rate limit middleware + EnableRateLimit bool `json:"enable-grpc-rate-limit,string"` + // RateLimitConfig is the config of rate limit middleware + LimiterConfig map[string]ratelimit.DimensionConfig `json:"grpc-limiter-config"` +} + +// Clone returns a cloned rate limit config. +func (c *GRPCRateLimitConfig) Clone() *GRPCRateLimitConfig { + cfg := *c + return &cfg +} diff --git a/server/config/service_middleware_persist_options.go b/server/config/service_middleware_persist_options.go index 2dd9d245f89..cd67c9dd1ac 100644 --- a/server/config/service_middleware_persist_options.go +++ b/server/config/service_middleware_persist_options.go @@ -25,8 +25,9 @@ import ( // ServiceMiddlewarePersistOptions wraps all service middleware configurations that need to persist to storage and // allows to access them safely. type ServiceMiddlewarePersistOptions struct { - audit atomic.Value - rateLimit atomic.Value + audit atomic.Value + rateLimit atomic.Value + grpcRateLimit atomic.Value } // NewServiceMiddlewarePersistOptions creates a new ServiceMiddlewarePersistOptions instance. @@ -34,6 +35,7 @@ func NewServiceMiddlewarePersistOptions(cfg *ServiceMiddlewareConfig) *ServiceMi o := &ServiceMiddlewarePersistOptions{} o.audit.Store(&cfg.AuditConfig) o.rateLimit.Store(&cfg.RateLimitConfig) + o.grpcRateLimit.Store(&cfg.GRPCRateLimitConfig) return o } @@ -67,11 +69,27 @@ func (o *ServiceMiddlewarePersistOptions) IsRateLimitEnabled() bool { return o.GetRateLimitConfig().EnableRateLimit } +// GetGRPCRateLimitConfig returns pd service middleware configurations. +func (o *ServiceMiddlewarePersistOptions) GetGRPCRateLimitConfig() *GRPCRateLimitConfig { + return o.grpcRateLimit.Load().(*GRPCRateLimitConfig) +} + +// SetGRPCRateLimitConfig sets the PD service middleware configuration. +func (o *ServiceMiddlewarePersistOptions) SetGRPCRateLimitConfig(cfg *GRPCRateLimitConfig) { + o.grpcRateLimit.Store(cfg) +} + +// IsGRPCRateLimitEnabled returns whether rate limit middleware is enabled +func (o *ServiceMiddlewarePersistOptions) IsGRPCRateLimitEnabled() bool { + return o.GetGRPCRateLimitConfig().EnableRateLimit +} + // Persist saves the configuration to the storage. func (o *ServiceMiddlewarePersistOptions) Persist(storage endpoint.ServiceMiddlewareStorage) error { cfg := &ServiceMiddlewareConfig{ - AuditConfig: *o.GetAuditConfig(), - RateLimitConfig: *o.GetRateLimitConfig(), + AuditConfig: *o.GetAuditConfig(), + RateLimitConfig: *o.GetRateLimitConfig(), + GRPCRateLimitConfig: *o.GetGRPCRateLimitConfig(), } err := storage.SaveServiceMiddlewareConfig(cfg) failpoint.Inject("persistServiceMiddlewareFail", func() { @@ -91,6 +109,7 @@ func (o *ServiceMiddlewarePersistOptions) Reload(storage endpoint.ServiceMiddlew if isExist { o.audit.Store(&cfg.AuditConfig) o.rateLimit.Store(&cfg.RateLimitConfig) + o.grpcRateLimit.Store(&cfg.GRPCRateLimitConfig) } return nil } diff --git a/server/grpc_service.go b/server/grpc_service.go index 4bc63224401..34e34f1e34a 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -19,6 +19,7 @@ import ( "fmt" "io" "path" + "runtime" "runtime/trace" "strconv" "strings" @@ -278,6 +279,17 @@ func (s *GrpcServer) getMinTSFromSingleServer( // GetMembers implements gRPC PDServer. func (s *GrpcServer) GetMembers(context.Context, *pdpb.GetMembersRequest) (*pdpb.GetMembersResponse, error) { + if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { + fName := currentFunction() + limiter := s.GetGRPCRateLimiter() + if limiter.Allow(fName) { + defer limiter.Release(fName) + } else { + return &pdpb.GetMembersResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), + }, nil + } + } // Here we purposely do not check the cluster ID because the client does not know the correct cluster ID // at startup and needs to get the cluster ID with the first request (i.e. GetMembers). if s.IsClosed() { @@ -760,6 +772,17 @@ func (s *GrpcServer) IsSnapshotRecovering(ctx context.Context, request *pdpb.IsS // GetStore implements gRPC PDServer. func (s *GrpcServer) GetStore(ctx context.Context, request *pdpb.GetStoreRequest) (*pdpb.GetStoreResponse, error) { + if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { + fName := currentFunction() + limiter := s.GetGRPCRateLimiter() + if limiter.Allow(fName) { + defer limiter.Release(fName) + } else { + return &pdpb.GetStoreResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), + }, nil + } + } fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { return pdpb.NewPDClient(client).GetStore(ctx, request) } @@ -852,6 +875,17 @@ func (s *GrpcServer) PutStore(ctx context.Context, request *pdpb.PutStoreRequest // GetAllStores implements gRPC PDServer. func (s *GrpcServer) GetAllStores(ctx context.Context, request *pdpb.GetAllStoresRequest) (*pdpb.GetAllStoresResponse, error) { + if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { + fName := currentFunction() + limiter := s.GetGRPCRateLimiter() + if limiter.Allow(fName) { + defer limiter.Release(fName) + } else { + return &pdpb.GetAllStoresResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), + }, nil + } + } fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { return pdpb.NewPDClient(client).GetAllStores(ctx, request) } @@ -886,6 +920,17 @@ func (s *GrpcServer) GetAllStores(ctx context.Context, request *pdpb.GetAllStore // StoreHeartbeat implements gRPC PDServer. func (s *GrpcServer) StoreHeartbeat(ctx context.Context, request *pdpb.StoreHeartbeatRequest) (*pdpb.StoreHeartbeatResponse, error) { + if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { + fName := currentFunction() + limiter := s.GetGRPCRateLimiter() + if limiter.Allow(fName) { + defer limiter.Release(fName) + } else { + return &pdpb.StoreHeartbeatResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), + }, nil + } + } fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { return pdpb.NewPDClient(client).StoreHeartbeat(ctx, request) } @@ -1256,6 +1301,17 @@ func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error // GetRegion implements gRPC PDServer. func (s *GrpcServer) GetRegion(ctx context.Context, request *pdpb.GetRegionRequest) (*pdpb.GetRegionResponse, error) { + if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { + fName := currentFunction() + limiter := s.GetGRPCRateLimiter() + if limiter.Allow(fName) { + defer limiter.Release(fName) + } else { + return &pdpb.GetRegionResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), + }, nil + } + } fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { return pdpb.NewPDClient(client).GetRegion(ctx, request) } @@ -1289,6 +1345,17 @@ func (s *GrpcServer) GetRegion(ctx context.Context, request *pdpb.GetRegionReque // GetPrevRegion implements gRPC PDServer func (s *GrpcServer) GetPrevRegion(ctx context.Context, request *pdpb.GetRegionRequest) (*pdpb.GetRegionResponse, error) { + if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { + fName := currentFunction() + limiter := s.GetGRPCRateLimiter() + if limiter.Allow(fName) { + defer limiter.Release(fName) + } else { + return &pdpb.GetRegionResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), + }, nil + } + } fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { return pdpb.NewPDClient(client).GetPrevRegion(ctx, request) } @@ -1323,6 +1390,17 @@ func (s *GrpcServer) GetPrevRegion(ctx context.Context, request *pdpb.GetRegionR // GetRegionByID implements gRPC PDServer. func (s *GrpcServer) GetRegionByID(ctx context.Context, request *pdpb.GetRegionByIDRequest) (*pdpb.GetRegionResponse, error) { + if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { + fName := currentFunction() + limiter := s.GetGRPCRateLimiter() + if limiter.Allow(fName) { + defer limiter.Release(fName) + } else { + return &pdpb.GetRegionResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), + }, nil + } + } fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { return pdpb.NewPDClient(client).GetRegionByID(ctx, request) } @@ -1356,6 +1434,17 @@ func (s *GrpcServer) GetRegionByID(ctx context.Context, request *pdpb.GetRegionB // ScanRegions implements gRPC PDServer. func (s *GrpcServer) ScanRegions(ctx context.Context, request *pdpb.ScanRegionsRequest) (*pdpb.ScanRegionsResponse, error) { + if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { + fName := currentFunction() + limiter := s.GetGRPCRateLimiter() + if limiter.Allow(fName) { + defer limiter.Release(fName) + } else { + return &pdpb.ScanRegionsResponse{ + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), + }, nil + } + } fn := func(ctx context.Context, client *grpc.ClientConn) (interface{}, error) { return pdpb.NewPDClient(client).ScanRegions(ctx, request) } @@ -2530,3 +2619,9 @@ func (s *GrpcServer) GetExternalTimestamp(ctx context.Context, request *pdpb.Get Timestamp: timestamp, }, nil } + +func currentFunction() string { + counter, _, _, _ := runtime.Caller(1) + s := strings.Split(runtime.FuncForPC(counter).Name(), ".") + return s[len(s)-1] +} diff --git a/server/server.go b/server/server.go index c03ebbc17b4..2403b7be437 100644 --- a/server/server.go +++ b/server/server.go @@ -217,6 +217,10 @@ type Server struct { serviceLabels map[string][]apiutil.AccessPath apiServiceLabelMap map[apiutil.AccessPath]string + grpcServiceRateLimiter *ratelimit.Limiter + grpcServiceLabels map[string]struct{} + grpcServer *grpc.Server + serviceAuditBackendLabels map[string]*audit.BackendLabels auditBackends []audit.Backend @@ -266,8 +270,10 @@ func CreateServer(ctx context.Context, cfg *config.Config, services []string, le audit.NewPrometheusHistogramBackend(serviceAuditHistogram, false), } s.serviceRateLimiter = ratelimit.NewLimiter() + s.grpcServiceRateLimiter = ratelimit.NewLimiter() s.serviceAuditBackendLabels = make(map[string]*audit.BackendLabels) s.serviceLabels = make(map[string][]apiutil.AccessPath) + s.grpcServiceLabels = make(map[string]struct{}) s.apiServiceLabelMap = make(map[apiutil.AccessPath]string) // Adjust etcd config. @@ -299,8 +305,8 @@ func CreateServer(ctx context.Context, cfg *config.Config, services []string, le diagnosticspb.RegisterDiagnosticsServer(gs, s) // Register the micro services GRPC service. s.registry.InstallAllGRPCServices(s, gs) + s.grpcServer = gs } - s.etcdCfg = etcdCfg s.lg = cfg.Logger s.logProps = cfg.LogProps @@ -367,9 +373,18 @@ func (s *Server) startEtcd(ctx context.Context) error { time.Sleep(1500 * time.Millisecond) }) s.member = member.NewMember(etcd, s.electionClient, etcdServerID) + s.initGRPCServiceLabels() return nil } +func (s *Server) initGRPCServiceLabels() { + for _, serviceInfo := range s.grpcServer.GetServiceInfo() { + for _, methodInfo := range serviceInfo.Methods { + s.grpcServiceLabels[methodInfo.Name] = struct{}{} + } + } +} + func (s *Server) startClient() (*clientv3.Client, *http.Client, error) { tlsConfig, err := s.cfg.Security.ToTLSConfig() if err != nil { @@ -899,6 +914,7 @@ func (s *Server) GetServiceMiddlewareConfig() *config.ServiceMiddlewareConfig { cfg := s.serviceMiddlewareCfg.Clone() cfg.AuditConfig = *s.serviceMiddlewarePersistOptions.GetAuditConfig().Clone() cfg.RateLimitConfig = *s.serviceMiddlewarePersistOptions.GetRateLimitConfig().Clone() + cfg.GRPCRateLimitConfig = *s.serviceMiddlewarePersistOptions.GetGRPCRateLimitConfig().Clone() return cfg } @@ -1105,7 +1121,7 @@ func (s *Server) SetAuditConfig(cfg config.AuditConfig) error { func (s *Server) UpdateRateLimitConfig(key, label string, value ratelimit.DimensionConfig) error { cfg := s.GetServiceMiddlewareConfig() rateLimitCfg := make(map[string]ratelimit.DimensionConfig) - for label, item := range cfg.LimiterConfig { + for label, item := range cfg.RateLimitConfig.LimiterConfig { rateLimitCfg[label] = item } rateLimitCfg[label] = value @@ -1140,7 +1156,7 @@ func (s *Server) SetRateLimitConfig(cfg config.RateLimitConfig) error { s.serviceMiddlewarePersistOptions.SetRateLimitConfig(&cfg) if err := s.serviceMiddlewarePersistOptions.Persist(s.storage); err != nil { s.serviceMiddlewarePersistOptions.SetRateLimitConfig(old) - log.Error("failed to update Rate Limit config", + log.Error("failed to update rate limit config", zap.Reflect("new", cfg), zap.Reflect("old", old), errs.ZapError(err)) @@ -1150,6 +1166,55 @@ func (s *Server) SetRateLimitConfig(cfg config.RateLimitConfig) error { return nil } +// UpdateGRPCRateLimitConfig is used to update rate-limit config which will reserve old limiter-config +func (s *Server) UpdateGRPCRateLimitConfig(key, label string, value ratelimit.DimensionConfig) error { + cfg := s.GetServiceMiddlewareConfig() + rateLimitCfg := make(map[string]ratelimit.DimensionConfig) + for label, item := range cfg.GRPCRateLimitConfig.LimiterConfig { + rateLimitCfg[label] = item + } + rateLimitCfg[label] = value + return s.UpdateGRPCRateLimit(&cfg.GRPCRateLimitConfig, key, &rateLimitCfg) +} + +// UpdateGRPCRateLimit is used to update gRPC rate-limit config which will overwrite limiter-config +func (s *Server) UpdateGRPCRateLimit(cfg *config.GRPCRateLimitConfig, key string, value interface{}) error { + updated, found, err := jsonutil.AddKeyValue(cfg, key, value) + if err != nil { + return err + } + + if !found { + return errors.Errorf("config item %s not found", key) + } + + if updated { + err = s.SetGRPCRateLimitConfig(*cfg) + } + return err +} + +// GetGRPCRateLimitConfig gets the rate limit config information. +func (s *Server) GetGRPCRateLimitConfig() *config.GRPCRateLimitConfig { + return s.serviceMiddlewarePersistOptions.GetGRPCRateLimitConfig().Clone() +} + +// SetGRPCRateLimitConfig sets the rate limit config. +func (s *Server) SetGRPCRateLimitConfig(cfg config.GRPCRateLimitConfig) error { + old := s.serviceMiddlewarePersistOptions.GetGRPCRateLimitConfig() + s.serviceMiddlewarePersistOptions.SetGRPCRateLimitConfig(&cfg) + if err := s.serviceMiddlewarePersistOptions.Persist(s.storage); err != nil { + s.serviceMiddlewarePersistOptions.SetGRPCRateLimitConfig(old) + log.Error("failed to update gRPC rate limit config", + zap.Reflect("new", cfg), + zap.Reflect("old", old), + errs.ZapError(err)) + return err + } + log.Info("gRPC rate limit config is updated", zap.Reflect("new", cfg), zap.Reflect("old", old)) + return nil +} + // GetPDServerConfig gets the balance config information. func (s *Server) GetPDServerConfig() *config.PDServerConfig { return s.persistOptions.GetPDServerConfig().Clone() @@ -1330,15 +1395,21 @@ func (s *Server) GetServiceLabels(serviceLabel string) []apiutil.AccessPath { return nil } +// IsGRPCServiceLabelExist returns if the service label exists +func (s *Server) IsGRPCServiceLabelExist(serviceLabel string) bool { + _, ok := s.grpcServiceLabels[serviceLabel] + return ok +} + // GetAPIAccessServiceLabel returns service label by given access path // TODO: this function will be used for updating api rate limit config func (s *Server) GetAPIAccessServiceLabel(accessPath apiutil.AccessPath) string { - if servicelabel, ok := s.apiServiceLabelMap[accessPath]; ok { - return servicelabel + if serviceLabel, ok := s.apiServiceLabelMap[accessPath]; ok { + return serviceLabel } accessPathNoMethod := apiutil.NewAccessPath(accessPath.Path, "") - if servicelabel, ok := s.apiServiceLabelMap[accessPathNoMethod]; ok { - return servicelabel + if serviceLabel, ok := s.apiServiceLabelMap[accessPathNoMethod]; ok { + return serviceLabel } return "" } @@ -1387,6 +1458,16 @@ func (s *Server) UpdateServiceRateLimiter(serviceLabel string, opts ...ratelimit return s.serviceRateLimiter.Update(serviceLabel, opts...) } +// GetGRPCRateLimiter is used to get rate limiter +func (s *Server) GetGRPCRateLimiter() *ratelimit.Limiter { + return s.grpcServiceRateLimiter +} + +// UpdateGRPCServiceRateLimiter is used to update RateLimiter +func (s *Server) UpdateGRPCServiceRateLimiter(serviceLabel string, opts ...ratelimit.Option) ratelimit.UpdateStatus { + return s.grpcServiceRateLimiter.Update(serviceLabel, opts...) +} + // GetClusterStatus gets cluster status. func (s *Server) GetClusterStatus() (*cluster.Status, error) { s.cluster.Lock() @@ -1704,6 +1785,7 @@ func (s *Server) reloadConfigFromKV() error { return err } s.loadRateLimitConfig() + s.loadGRPCRateLimitConfig() s.loadKeyspaceConfig() useRegionStorage := s.persistOptions.IsUseRegionStorage() regionStorage := storage.TrySwitchRegionStorage(s.storage, useRegionStorage) @@ -1733,6 +1815,14 @@ func (s *Server) loadRateLimitConfig() { } } +func (s *Server) loadGRPCRateLimitConfig() { + cfg := s.serviceMiddlewarePersistOptions.GetGRPCRateLimitConfig().LimiterConfig + for key := range cfg { + value := cfg[key] + s.grpcServiceRateLimiter.Update(key, ratelimit.UpdateDimensionConfig(&value)) + } +} + // ReplicateFileToMember is used to synchronize state to a member. // Each member will write `data` to a local file named `name`. // For security reason, data should be in JSON format.