Skip to content

Commit

Permalink
Fix tikv#4373 : add comment
Browse files Browse the repository at this point in the history
Signed-off-by: Cabinfever_B <cabinfeveroier@gmail.com>
  • Loading branch information
CabinfeverB committed Dec 20, 2021
1 parent e104967 commit 1c186dd
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 40 deletions.
2 changes: 2 additions & 0 deletions pkg/apiutil/apiutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ func GetIPAddrFromGRPCContext(ctx context.Context) string {
return ip
}

// GetPeerAddrFromGRPCContext return gRPC client real IP if gateway or proxy put real IP In
func GetRealIPAddrFromGRPCContext(ctx context.Context) (string, bool) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
Expand All @@ -155,6 +156,7 @@ func GetRealIPAddrFromGRPCContext(ctx context.Context) (string, bool) {
return realIPs[0], true
}

// GetPeerAddrFromGRPCContext return gRPC client IP which may be a proxy IP from peer info
func GetPeerAddrFromGRPCContext(ctx context.Context) string {
var addr string
if pr, ok := peer.FromContext(ctx); ok {
Expand Down
2 changes: 1 addition & 1 deletion server/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func NewHandler(ctx context.Context, svr *server.Server) (http.Handler, server.S
router.PathPrefix(apiPrefix).Handler(negroni.New(
serverapi.NewRuntimeServiceValidator(svr, group),
serverapi.NewRedirector(svr),
svr.SelfProtectionHandler,
serverapi.NewSelfProtector(svr),
negroni.Wrap(r)),
)

Expand Down
7 changes: 3 additions & 4 deletions server/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ type Config struct {

ReplicationMode ReplicationModeConfig `toml:"replication-mode" json:"replication-mode"`

SelfProtectionConfig SelfProtectionConfig `toml:"SelfProtectionConfig"`
SelfProtectionConfig SelfProtectionConfig `toml:"self-protection-config"`
}

// NewConfig creates a new config.
Expand Down Expand Up @@ -267,10 +267,9 @@ var (
DefaultStoreLimit = StoreLimit{AddPeer: 15, RemovePeer: 15}
// DefaultTiFlashStoreLimit is the default TiFlash store limit of add peer and remove peer.
DefaultTiFlashStoreLimit = StoreLimit{AddPeer: 30, RemovePeer: 30}

// ServiceSelfProtectionConfig is used for self protection mechanism
DefaultServiceSelfProtectionConfig = []ServiceSelfprotectionConfig{}

HTTPAPIServiceNames = map[string]string{}
// GRPCMethodServiceNames is used to get logic service name of the gRPC service
GRPCMethodServiceNames = map[string]string{}
)

Expand Down
79 changes: 64 additions & 15 deletions server/middleware/self_protection.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,33 @@ var (
componentAnonymousValue = "anonymous"
)

// SelfProtectionHandler a
// SelfProtectionHandler is a framework to handle self protection mechanism
// Self-protection granularity is a logical service
type SelfProtectionHandler struct {
// grpcServiceNames is used to find the service name of grpc method
GrpcServiceNames map[string]string
// ServiceHandlers a
ServiceHandlers map[string]*ServiceSelfProtectionHandler
}

// MergeSelfProtectionConfig is used for when both the user configuration and the default configuration exist
func MergeSelfProtectionConfig(handlers map[string]*ServiceSelfProtectionHandler, highPriorityConfigs []config.ServiceSelfprotectionConfig, lowPriorityConfigs []config.ServiceSelfprotectionConfig) {
for i := range highPriorityConfigs {
serviceName := highPriorityConfigs[i].ServiceName
serviceSelfProtectionHandler := NewServiceSelfProtectionHandler(&highPriorityConfigs[i])
handlers[serviceName] = serviceSelfProtectionHandler
}
for i := range lowPriorityConfigs {
serviceName := lowPriorityConfigs[i].ServiceName
if _, find := handlers[serviceName]; find {
continue
}
serviceSelfProtectionHandler := NewServiceSelfProtectionHandler(&lowPriorityConfigs[i])
handlers[serviceName] = serviceSelfProtectionHandler
}
}

// GetHTTPAPIServiceName return mux route name registered for ServiceName
func (h *SelfProtectionHandler) GetHTTPAPIServiceName(req *http.Request) (string, bool) {
route := mux.CurrentRoute(req)
if route != nil {
Expand All @@ -63,11 +82,13 @@ func (h *SelfProtectionHandler) GetHTTPAPIServiceName(req *http.Request) (string
return "", false
}

// GetGRPCServiceName return ServiceName by mapping gRPC method name
func (h *SelfProtectionHandler) GetGRPCServiceName(method string) (string, bool) {
serviceName, ok := h.GrpcServiceNames[method]
return serviceName, ok
}

// GetComponentNameOnHTTP return component name from Request Header
func (h *SelfProtectionHandler) GetComponentNameOnHTTP(r *http.Request) string {
componentName := r.Header.Get(componentSignatureKey)
if componentName == "" {
Expand All @@ -76,6 +97,7 @@ func (h *SelfProtectionHandler) GetComponentNameOnHTTP(r *http.Request) string {
return componentName
}

// GetComponentNameOnGRPC return component name from gRPC metadata
func (h *SelfProtectionHandler) GetComponentNameOnGRPC(ctx context.Context) string {
md, ok := metadata.FromIncomingContext(ctx)
if ok {
Expand All @@ -87,6 +109,7 @@ func (h *SelfProtectionHandler) GetComponentNameOnGRPC(ctx context.Context) stri
return componentAnonymousValue
}

// SelfProtectionHandleHTTP is used to handle http api self protection
func (h *SelfProtectionHandler) SelfProtectionHandleHTTP(req *http.Request) bool {
serviceName, foundName := h.GetHTTPAPIServiceName(req)
if !foundName {
Expand All @@ -112,6 +135,7 @@ func (h *SelfProtectionHandler) SelfProtectionHandleHTTP(req *http.Request) bool
return limitAllow
}

// SelfProtectionHandleGRPC is used to handle gRPC self protection
func (h *SelfProtectionHandler) SelfProtectionHandleGRPC(fullMethod string, ctx context.Context) bool {
serviceName, foundName := h.GetGRPCServiceName(fullMethod)
if !foundName {
Expand All @@ -137,17 +161,20 @@ func (h *SelfProtectionHandler) SelfProtectionHandleGRPC(fullMethod string, ctx
return limitAllow
}

// ServiceSelfProtectionHandler currently includes QPS rate limiter and audit logger
type ServiceSelfProtectionHandler struct {
apiRateLimiter *APIRateLimiter
auditLogger *AuditLogger
}

// NewServiceSelfProtectionHandler return a new ServiceSelfProtectionHandler
func NewServiceSelfProtectionHandler(config *config.ServiceSelfprotectionConfig) *ServiceSelfProtectionHandler {
handler := &ServiceSelfProtectionHandler{}
handler.Update(config)
return handler
}

// Update is used to update ServiceSelfProtectionHandler
func (h *ServiceSelfProtectionHandler) Update(config *config.ServiceSelfprotectionConfig) {
if h.apiRateLimiter == nil {
h.apiRateLimiter = NewAPIRateLimiter(config)
Expand All @@ -161,20 +188,23 @@ func (h *ServiceSelfProtectionHandler) Update(config *config.ServiceSelfprotecti
}
}

// RateLimitAllow is used to check whether the rate limit allow request process
func (h *ServiceSelfProtectionHandler) RateLimitAllow(componentName string) bool {
if h.apiRateLimiter == nil {
return true
}
return h.apiRateLimiter.Allow(componentName)
}

// EnableAudit is used to check Whether to enable the audit handle
func (h *ServiceSelfProtectionHandler) EnableAudit() bool {
if h.auditLogger == nil {
return true
}
return h.auditLogger.Enable()
}

// GetLogInfoFromHTTP return LogInfo from http.Request
func GetLogInfoFromHTTP(req *http.Request, logInfo *LogInfo) {
// Get IP
logInfo.IP = apiutil.GetIPAddrFromHTTPRequest(req)
Expand All @@ -186,16 +216,21 @@ func GetLogInfoFromHTTP(req *http.Request, logInfo *LogInfo) {
req.Body = io.NopCloser(bytes.NewBuffer(buf))
}

// GetLogInfoFromHTTP return LogInfo from Context
func GetLogInfoFromGRPC(ctx context.Context, logInfo *LogInfo) {
// Get IP
logInfo.IP = apiutil.GetIPAddrFromGRPCContext(ctx)
// gRPC can't get Param in middware
}

// AuditLog is a entrance to access AuditLoggor
func (h *ServiceSelfProtectionHandler) AuditLog(logInfo *LogInfo) {
h.auditLogger.Log(logInfo)
}

// APIRateLimiter is used to limit unnecessary and excess request
// Currently support QPS rate limit by compoenent
// It depends on the rate.Limiter which implement a token-bucket algorithm
type APIRateLimiter struct {
mu sync.RWMutex

Expand All @@ -207,12 +242,14 @@ type APIRateLimiter struct {
componentQPSRateLimiter map[string]*rate.Limiter
}

// NewAPIRateLimiter create a new api rate limiter
func NewAPIRateLimiter(config *config.ServiceSelfprotectionConfig) *APIRateLimiter {
limiter := &APIRateLimiter{}
limiter.Update(config)
return limiter
}

// Update will replace all handler by service
func (rl *APIRateLimiter) Update(config *config.ServiceSelfprotectionConfig) {
rl.mu.Lock()
defer rl.mu.Unlock()
Expand All @@ -233,6 +270,7 @@ func (rl *APIRateLimiter) Update(config *config.ServiceSelfprotectionConfig) {
}
}

// QPSAllow firstly check component token bucket and then check total token bucket
func (rl *APIRateLimiter) QPSAllow(componentName string) bool {
if !rl.enableQPSLimit {
return true
Expand All @@ -249,12 +287,14 @@ func (rl *APIRateLimiter) QPSAllow(componentName string) bool {
return isComponentQPSLimit && isTotalQPSLimit
}

// Allow currentlt only supports QPS rate limit
func (rl *APIRateLimiter) Allow(componentName string) bool {
rl.mu.RLock()
defer rl.mu.RUnlock()
return rl.QPSAllow(componentName)
}

// LogInfo stores needed api request info
type LogInfo struct {
ServiceName string
Method string
Expand All @@ -265,18 +305,24 @@ type LogInfo struct {
RateLimitAllow bool
}

// AuditLogger is used to record some information about the service for auditing when problems occur
// Currently it can be bonded two audit labels
// LoggerLabelLog("log") means AuditLogger will restore info in local file system
// LoggerLabelMonitored("Monitored") means AuditLogger will report info to promethus
type AuditLogger struct {
mu sync.RWMutex
enableAudit bool
labels map[string]bool
}

// NewAuditLogger return a new AuditLogger
func NewAuditLogger(config *config.ServiceSelfprotectionConfig) *AuditLogger {
logger := &AuditLogger{}
logger.Update(config)
return logger
}

// Update AuditLogger by config
func (logger *AuditLogger) Update(config *config.ServiceSelfprotectionConfig) {
logger.mu.Lock()
defer logger.mu.Unlock()
Expand All @@ -287,13 +333,17 @@ func (logger *AuditLogger) Update(config *config.ServiceSelfprotectionConfig) {
}
}

// Enable is used to check Whether to enable the audit handle
func (logger *AuditLogger) Enable() bool {
logger.mu.RLock()
defer logger.mu.RUnlock()
return logger.enableAudit
}

// Log is used to handle log action
func (logger *AuditLogger) Log(info *LogInfo) {
logger.mu.RLock()
defer logger.mu.RUnlock()
if isLog, ok := logger.labels[LoggerLabelLog]; ok {
if isLog {
log.Info("service_audit_detailed",
Expand All @@ -311,14 +361,7 @@ func (logger *AuditLogger) Log(info *LogInfo) {
}
}

func (h *SelfProtectionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
if h.SelfProtectionHandleHTTP(r) {
next(w, r)
} else {
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
}
}

// UnaryServerInterceptor returns a gRPC stream server interceptor to handle self protection in gRPC unary service
func (h *SelfProtectionHandler) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
return grpc.UnaryServerInterceptor(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
if h.SelfProtectionHandleGRPC(info.FullMethod, ctx) {
Expand All @@ -328,6 +371,7 @@ func (h *SelfProtectionHandler) UnaryServerInterceptor() grpc.UnaryServerInterce
})
}

// StreamServerInterceptor returns a gRPC stream server interceptor to handle self protection in gRPC stream service
func (h *SelfProtectionHandler) StreamServerInterceptor() grpc.StreamServerInterceptor {
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if h.SelfProtectionHandleGRPC(info.FullMethod, stream.Context()) {
Expand All @@ -337,16 +381,18 @@ func (h *SelfProtectionHandler) StreamServerInterceptor() grpc.StreamServerInter
}
}

// UserSignatureGRPCClientInterceptorBuilder add component user signature in gRPC
type UserSignatureGRPCClientInterceptorBuilder struct {
// ComponentSignatureGRPCClientInterceptorBuilder add component signature in gRPC
type ComponentSignatureGRPCClientInterceptorBuilder struct {
component string
}

func (builder *UserSignatureGRPCClientInterceptorBuilder) SetComponentName(component string) {
// SetComponentName set component name
func (builder *ComponentSignatureGRPCClientInterceptorBuilder) SetComponentName(component string) {
builder.component = component
}

func (builder *UserSignatureGRPCClientInterceptorBuilder) UnaryClientInterceptor() grpc.UnaryClientInterceptor {
// UnaryClientInterceptor return a ComponentSignature UnaryClientInterceptor
func (builder *ComponentSignatureGRPCClientInterceptorBuilder) UnaryClientInterceptor() grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
headerData := metadata.Pairs(componentSignatureKey, builder.component)
ctxH := metadata.NewOutgoingContext(ctx, headerData)
Expand All @@ -355,15 +401,17 @@ func (builder *UserSignatureGRPCClientInterceptorBuilder) UnaryClientInterceptor
}
}

func (builder *UserSignatureGRPCClientInterceptorBuilder) StreamClientInterceptor() grpc.StreamClientInterceptor {
// StreamClientInterceptor return a ComponentSignature StreamClientInterceptor
func (builder *ComponentSignatureGRPCClientInterceptorBuilder) StreamClientInterceptor() grpc.StreamClientInterceptor {
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
headerData := metadata.Pairs(componentSignatureKey, builder.component)
ctxH := metadata.NewOutgoingContext(ctx, headerData)
return streamer(ctxH, desc, cc, method, opts...)
}
}

func (builder *UserSignatureGRPCClientInterceptorBuilder) UserSignatureDialOptions() []grpc.DialOption {
// UserSignatureDialOptions create opts with ComponentSignature Interceptors
func (builder *ComponentSignatureGRPCClientInterceptorBuilder) UserSignatureDialOptions() []grpc.DialOption {
streamInterceptors := []grpc.StreamClientInterceptor{builder.StreamClientInterceptor()}
unaryInterceptors := []grpc.UnaryClientInterceptor{builder.UnaryClientInterceptor()}
opts := []grpc.DialOption{grpc.WithChainStreamInterceptor(streamInterceptors...), grpc.WithChainUnaryInterceptor(unaryInterceptors...)}
Expand All @@ -376,6 +424,7 @@ type UserSignatureRoundTripper struct {
Component string
}

// RoundTrip is used to implement RoundTripper
func (rt *UserSignatureRoundTripper) RoundTrip(req *http.Request) (resp *http.Response, err error) {
req.Header.Set(componentSignatureKey, rt.Component)
// Send the request, get the response and the error
Expand Down
27 changes: 7 additions & 20 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,16 +349,19 @@ func (s *Server) startEtcd(ctx context.Context) error {
return nil
}

// NewSelfProtectionHandler returns a new SelfProtectionHandler with config
func NewSelfProtectionHandler(server *Server) *middleware.SelfProtectionHandler {
handler := &middleware.SelfProtectionHandler{
GrpcServiceNames: config.GRPCMethodServiceNames,
ServiceHandlers: make(map[string]*middleware.ServiceSelfProtectionHandler),
}
UpdateServiceHandlers(handler, server)
updateServiceHandlers(handler, server)
return handler
}

func UpdateServiceHandlers(h *middleware.SelfProtectionHandler, server *Server) {
// updateServiceHandlers update ServiceHandlers
// it will make a new map and merge user-defined handlers and dafault handlers with different priority according to enableUseDefault
func updateServiceHandlers(h *middleware.SelfProtectionHandler, server *Server) {
if server == nil {
return
}
Expand All @@ -373,26 +376,10 @@ func UpdateServiceHandlers(h *middleware.SelfProtectionHandler, server *Server)
}
// if enableUseDefault is 1, config defined by users has higher priority than dafault
} else if enableUseDefault == 1 {
mergeSelfProtectionConfig(h.ServiceHandlers, server.GetConfig().SelfProtectionConfig.ServiceSelfprotectionConfig, config.DefaultServiceSelfProtectionConfig)
middleware.MergeSelfProtectionConfig(h.ServiceHandlers, server.GetConfig().SelfProtectionConfig.ServiceSelfprotectionConfig, config.DefaultServiceSelfProtectionConfig)
// if enableUseDefault is 0, dafault config has higher priority than config defined by users
} else {
mergeSelfProtectionConfig(h.ServiceHandlers, config.DefaultServiceSelfProtectionConfig, server.GetConfig().SelfProtectionConfig.ServiceSelfprotectionConfig)
}
}

func mergeSelfProtectionConfig(handlers map[string]*middleware.ServiceSelfProtectionHandler, highPriorityConfigs []config.ServiceSelfprotectionConfig, lowPriorityConfigs []config.ServiceSelfprotectionConfig) {
for i := range highPriorityConfigs {
serviceName := highPriorityConfigs[i].ServiceName
serviceSelfProtectionHandler := middleware.NewServiceSelfProtectionHandler(&highPriorityConfigs[i])
handlers[serviceName] = serviceSelfProtectionHandler
}
for i := range lowPriorityConfigs {
serviceName := lowPriorityConfigs[i].ServiceName
if _, find := handlers[serviceName]; find {
continue
}
serviceSelfProtectionHandler := middleware.NewServiceSelfProtectionHandler(&lowPriorityConfigs[i])
handlers[serviceName] = serviceSelfProtectionHandler
middleware.MergeSelfProtectionConfig(h.ServiceHandlers, config.DefaultServiceSelfProtectionConfig, server.GetConfig().SelfProtectionConfig.ServiceSelfprotectionConfig)
}
}

Expand Down

0 comments on commit 1c186dd

Please sign in to comment.