Skip to content

Commit

Permalink
Fix tikv#4373 : change HTTP service source
Browse files Browse the repository at this point in the history
Signed-off-by: Cabinfever_B <cabinfeveroier@gmail.com>
  • Loading branch information
CabinfeverB committed Dec 15, 2021
1 parent 7bd6353 commit 16bbac9
Showing 1 changed file with 15 additions and 25 deletions.
40 changes: 15 additions & 25 deletions server/middleware/self_protection.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"sync"
"time"

"github.com/gorilla/mux"
"github.com/pingcap/log"
"github.com/tikv/pd/pkg/apiutil"
PDServer "github.com/tikv/pd/server"
Expand Down Expand Up @@ -72,8 +73,8 @@ func (h *SelfProtectionHandler) UpdateServiceHandlers() {
}
enableUseDefault := h.s.GetConfig().SelfProtectionConfig.EnableUseDefault
h.ServiceHandlers = make(map[string]*ServiceSelfProtectionHandler)
// if enableUseDefault is 0, only use config defined by users
if enableUseDefault == 0 {
// if enableUseDefault is 2, only use config defined by users
if enableUseDefault == 2 {
for i := range h.s.GetConfig().SelfProtectionConfig.ServiceSelfprotectionConfig {
serviceName := h.s.GetConfig().SelfProtectionConfig.ServiceSelfprotectionConfig[i].ServiceName
serviceSelfProtectionHandler := NewServiceSelfProtectionHandler(&h.s.GetConfig().SelfProtectionConfig.ServiceSelfprotectionConfig[i])
Expand All @@ -82,7 +83,7 @@ func (h *SelfProtectionHandler) UpdateServiceHandlers() {
// if enableUseDefault is 1, config defined by users has higher priority than dafault
} else if enableUseDefault == 1 {
mergeSelfProtectionConfig(h.ServiceHandlers, h.s.GetConfig().SelfProtectionConfig.ServiceSelfprotectionConfig, config.DefaultServiceSelfProtectionConfig)
// if enableUseDefault is 1, dafault config has higher priority than config defined by users
// if enableUseDefault is 0, dafault config has higher priority than config defined by users
} else {
mergeSelfProtectionConfig(h.ServiceHandlers, config.DefaultServiceSelfProtectionConfig, h.s.GetConfig().SelfProtectionConfig.ServiceSelfprotectionConfig)
}
Expand All @@ -104,9 +105,14 @@ func mergeSelfProtectionConfig(handlers map[string]*ServiceSelfProtectionHandler
}
}

func (h *SelfProtectionHandler) GetHTTPAPIServiceName(url string) (string, bool) {
serviceName, ok := h.httpAPIServiceNames[url]
return serviceName, ok
func (h *SelfProtectionHandler) GetHTTPAPIServiceName(req *http.Request) (string, bool) {
route := mux.CurrentRoute(req)
if route != nil {
if route.GetName() != "" {
return route.GetName(), true
}
}
return "", false
}

func (h *SelfProtectionHandler) GetGRPCServiceName(method string) (string, bool) {
Expand All @@ -133,18 +139,8 @@ func (h *SelfProtectionHandler) GetComponentNameOnGRPC(ctx context.Context) stri
return componentAnonymousValue
}

func (h *SelfProtectionHandler) SelfProtectionHandle(componentName string, serviceName string) bool {
serviceHandler, ok := h.ServiceHandlers[serviceName]
if !ok {
return true
}
limitAllow := serviceHandler.Allow(componentName)

return limitAllow
}

func (h *SelfProtectionHandler) SelfProtectionHandleHTTP(req *http.Request) bool {
serviceName, foundName := h.GetHTTPAPIServiceName(req.URL.Path)
serviceName, foundName := h.GetHTTPAPIServiceName(req)
if !foundName {
return true
}
Expand All @@ -157,7 +153,7 @@ func (h *SelfProtectionHandler) SelfProtectionHandleHTTP(req *http.Request) bool
if serviceHandler.EnableAudit() {
logInfo := &LogInfo{
ServiceName: serviceName,
Method: fmt.Sprintf("http:%s", req.URL.Path),
Method: fmt.Sprintf("HTTP-%s:%s", req.Method, req.URL.Path),
Component: componentSignature,
TimeStamp: time.Now().Local().String(),
RateLimitAllow: limitAllow,
Expand Down Expand Up @@ -368,13 +364,7 @@ func (logger *AuditLogger) Audit(info *LogInfo) {
}

func (h *SelfProtectionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
serviceName, foundName := h.GetHTTPAPIServiceName(r.URL.Path)
if !foundName {
next(w, r)
return
}
componentSignature := h.GetComponentNameOnHTTP(r)
if h.SelfProtectionHandle(componentSignature, serviceName) {
if h.SelfProtectionHandleHTTP(r) {
next(w, r)
} else {
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
Expand Down

0 comments on commit 16bbac9

Please sign in to comment.