From 16bbac9d6c1328913b3bddc96fa7103de3c849e1 Mon Sep 17 00:00:00 2001 From: Cabinfever_B Date: Wed, 15 Dec 2021 18:26:09 +0800 Subject: [PATCH] Fix #4373 : change HTTP service source Signed-off-by: Cabinfever_B --- server/middleware/self_protection.go | 40 +++++++++++----------------- 1 file changed, 15 insertions(+), 25 deletions(-) diff --git a/server/middleware/self_protection.go b/server/middleware/self_protection.go index 53c51210386..b41456c6aab 100644 --- a/server/middleware/self_protection.go +++ b/server/middleware/self_protection.go @@ -23,6 +23,7 @@ import ( "sync" "time" + "github.com/gorilla/mux" "github.com/pingcap/log" "github.com/tikv/pd/pkg/apiutil" PDServer "github.com/tikv/pd/server" @@ -72,8 +73,8 @@ func (h *SelfProtectionHandler) UpdateServiceHandlers() { } enableUseDefault := h.s.GetConfig().SelfProtectionConfig.EnableUseDefault h.ServiceHandlers = make(map[string]*ServiceSelfProtectionHandler) - // if enableUseDefault is 0, only use config defined by users - if enableUseDefault == 0 { + // if enableUseDefault is 2, only use config defined by users + if enableUseDefault == 2 { for i := range h.s.GetConfig().SelfProtectionConfig.ServiceSelfprotectionConfig { serviceName := h.s.GetConfig().SelfProtectionConfig.ServiceSelfprotectionConfig[i].ServiceName serviceSelfProtectionHandler := NewServiceSelfProtectionHandler(&h.s.GetConfig().SelfProtectionConfig.ServiceSelfprotectionConfig[i]) @@ -82,7 +83,7 @@ func (h *SelfProtectionHandler) UpdateServiceHandlers() { // if enableUseDefault is 1, config defined by users has higher priority than dafault } else if enableUseDefault == 1 { mergeSelfProtectionConfig(h.ServiceHandlers, h.s.GetConfig().SelfProtectionConfig.ServiceSelfprotectionConfig, config.DefaultServiceSelfProtectionConfig) - // if enableUseDefault is 1, dafault config has higher priority than config defined by users + // if enableUseDefault is 0, dafault config has higher priority than config defined by users } else { mergeSelfProtectionConfig(h.ServiceHandlers, config.DefaultServiceSelfProtectionConfig, h.s.GetConfig().SelfProtectionConfig.ServiceSelfprotectionConfig) } @@ -104,9 +105,14 @@ func mergeSelfProtectionConfig(handlers map[string]*ServiceSelfProtectionHandler } } -func (h *SelfProtectionHandler) GetHTTPAPIServiceName(url string) (string, bool) { - serviceName, ok := h.httpAPIServiceNames[url] - return serviceName, ok +func (h *SelfProtectionHandler) GetHTTPAPIServiceName(req *http.Request) (string, bool) { + route := mux.CurrentRoute(req) + if route != nil { + if route.GetName() != "" { + return route.GetName(), true + } + } + return "", false } func (h *SelfProtectionHandler) GetGRPCServiceName(method string) (string, bool) { @@ -133,18 +139,8 @@ func (h *SelfProtectionHandler) GetComponentNameOnGRPC(ctx context.Context) stri return componentAnonymousValue } -func (h *SelfProtectionHandler) SelfProtectionHandle(componentName string, serviceName string) bool { - serviceHandler, ok := h.ServiceHandlers[serviceName] - if !ok { - return true - } - limitAllow := serviceHandler.Allow(componentName) - - return limitAllow -} - func (h *SelfProtectionHandler) SelfProtectionHandleHTTP(req *http.Request) bool { - serviceName, foundName := h.GetHTTPAPIServiceName(req.URL.Path) + serviceName, foundName := h.GetHTTPAPIServiceName(req) if !foundName { return true } @@ -157,7 +153,7 @@ func (h *SelfProtectionHandler) SelfProtectionHandleHTTP(req *http.Request) bool if serviceHandler.EnableAudit() { logInfo := &LogInfo{ ServiceName: serviceName, - Method: fmt.Sprintf("http:%s", req.URL.Path), + Method: fmt.Sprintf("HTTP-%s:%s", req.Method, req.URL.Path), Component: componentSignature, TimeStamp: time.Now().Local().String(), RateLimitAllow: limitAllow, @@ -368,13 +364,7 @@ func (logger *AuditLogger) Audit(info *LogInfo) { } func (h *SelfProtectionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - serviceName, foundName := h.GetHTTPAPIServiceName(r.URL.Path) - if !foundName { - next(w, r) - return - } - componentSignature := h.GetComponentNameOnHTTP(r) - if h.SelfProtectionHandle(componentSignature, serviceName) { + if h.SelfProtectionHandleHTTP(r) { next(w, r) } else { http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)