diff --git a/pkg/apiutil/apiutil.go b/pkg/apiutil/apiutil.go index d98d6157ec8..717d8400c02 100644 --- a/pkg/apiutil/apiutil.go +++ b/pkg/apiutil/apiutil.go @@ -23,6 +23,7 @@ import ( "strconv" "strings" + "github.com/gorilla/mux" "github.com/pingcap/errcode" "github.com/pingcap/errors" "github.com/pingcap/log" @@ -190,3 +191,12 @@ func (rt *ComponentSignatureRoundTripper) RoundTrip(req *http.Request) (resp *ht resp, err = rt.proxied.RoundTrip(req) return } + +// GetRouteName return mux route name registered +func GetRouteName(req *http.Request) string { + route := mux.CurrentRoute(req) + if route != nil { + return route.GetName() + } + return "" +} diff --git a/pkg/requestutil/context.go b/pkg/requestutil/context.go new file mode 100644 index 00000000000..c2bdef9b343 --- /dev/null +++ b/pkg/requestutil/context.go @@ -0,0 +1,38 @@ +// Copyright 2022 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package requestutil + +import ( + "context" +) + +// The key type is unexported to prevent collisions +type key int + +const ( + // requestInfoKey is the context key for the request compoenent. + requestInfoKey key = iota +) + +// WithRequestInfo returns a copy of parent in which the request info value is set +func WithRequestInfo(parent context.Context, requestInfo RequestInfo) context.Context { + return context.WithValue(parent, requestInfoKey, requestInfo) +} + +// RequestInfoFrom returns the value of the request info key on the ctx +func RequestInfoFrom(ctx context.Context) (RequestInfo, bool) { + requestInfo, ok := ctx.Value(requestInfoKey).(RequestInfo) + return requestInfo, ok +} diff --git a/pkg/requestutil/context_test.go b/pkg/requestutil/context_test.go new file mode 100644 index 00000000000..60bd3fc828c --- /dev/null +++ b/pkg/requestutil/context_test.go @@ -0,0 +1,57 @@ +// Copyright 2022 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package requestutil + +import ( + "context" + "testing" + + . "github.com/pingcap/check" +) + +func Test(t *testing.T) { + TestingT(t) +} + +var _ = Suite(&testRequestContextSuite{}) + +type testRequestContextSuite struct { +} + +func (s *testRequestContextSuite) TestRequestInfo(c *C) { + ctx := context.Background() + _, ok := RequestInfoFrom(ctx) + c.Assert(ok, Equals, false) + ctx = WithRequestInfo(ctx, + RequestInfo{ + ServiceLabel: "test label", + Method: "POST", + Component: "pdctl", + IP: "localhost", + URLParam: "{\"id\"=1}", + BodyParam: "{\"state\"=\"Up\"}", + TimeStamp: "2022", + }) + result, ok := RequestInfoFrom(ctx) + c.Assert(result, NotNil) + c.Assert(ok, Equals, true) + c.Assert(result.ServiceLabel, Equals, "test label") + c.Assert(result.Method, Equals, "POST") + c.Assert(result.Component, Equals, "pdctl") + c.Assert(result.IP, Equals, "localhost") + c.Assert(result.URLParam, Equals, "{\"id\"=1}") + c.Assert(result.BodyParam, Equals, "{\"state\"=\"Up\"}") + c.Assert(result.TimeStamp, Equals, "2022") +} diff --git a/pkg/requestutil/request_info.go b/pkg/requestutil/request_info.go new file mode 100644 index 00000000000..2197081e576 --- /dev/null +++ b/pkg/requestutil/request_info.go @@ -0,0 +1,70 @@ +// Copyright 2022 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package requestutil + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/tikv/pd/pkg/apiutil" +) + +// RequestInfo holds service information from http.Request +type RequestInfo struct { + ServiceLabel string + Method string + Component string + IP string + TimeStamp string + URLParam string + BodyParam string +} + +// GetRequestInfo returns request info needed from http.Request +func GetRequestInfo(r *http.Request) RequestInfo { + return RequestInfo{ + ServiceLabel: apiutil.GetRouteName(r), + Method: fmt.Sprintf("%s/%s:%s", r.Proto, r.Method, r.URL.Path), + Component: apiutil.GetComponentNameOnHTTP(r), + IP: apiutil.GetIPAddrFromHTTPRequest(r), + TimeStamp: time.Now().Local().String(), + URLParam: getURLParam(r), + BodyParam: getBodyParam(r), + } +} + +func getURLParam(r *http.Request) string { + buf, err := json.Marshal(r.URL.Query()) + if err != nil { + return "" + } + return string(buf) +} + +func getBodyParam(r *http.Request) string { + if r.Body == nil { + return "" + } + // http request body is a io.Reader between bytes.Reader and strings.Reader, it only has EOF error + buf, _ := io.ReadAll(r.Body) + r.Body.Close() + bodyParam := string(buf) + r.Body = io.NopCloser(bytes.NewBuffer(buf)) + return bodyParam +} diff --git a/server/api/admin.go b/server/api/admin.go index fd41df8c5b8..5f83dcc65ef 100644 --- a/server/api/admin.go +++ b/server/api/admin.go @@ -138,3 +138,21 @@ func (h *adminHandler) UpdateWaitAsyncTime(w http.ResponseWriter, r *http.Reques cluster.GetReplicationMode().UpdateMemberWaitAsyncTime(memberID) h.rd.JSON(w, http.StatusOK, nil) } + +// @Tags admin +// @Summary switch Service Middlewares including ServiceInfo, Audit and rate limit +// @Param enable query string true "enable" Enums(true, false) +// @Produce json +// @Success 200 {string} string "Switching Service middleware is successful." +// @Failure 400 {string} string "The input is invalid." +// @Router /admin/service-middleware [POST] +func (h *adminHandler) HanldeServiceMiddlewareSwitch(w http.ResponseWriter, r *http.Request) { + enableStr := r.URL.Query().Get("enable") + enable, err := strconv.ParseBool(enableStr) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, "The input is invalid.") + return + } + h.svr.SetServiceMiddleware(enable) + h.rd.JSON(w, http.StatusOK, "Switching Service middleware is successful.") +} diff --git a/server/api/admin_test.go b/server/api/admin_test.go index 738ae2ba273..533f089e548 100644 --- a/server/api/admin_test.go +++ b/server/api/admin_test.go @@ -173,3 +173,45 @@ func (s *testTSOSuite) TestResetTS(c *C) { c.Assert(err, NotNil) c.Assert(err.Error(), Equals, "\"invalid tso value\"\n") } + +var _ = Suite(&testServiceSuite{}) + +type testServiceSuite struct { + svr *server.Server + cleanup cleanUpFunc +} + +func (s *testServiceSuite) SetUpSuite(c *C) { + s.svr, s.cleanup = mustNewServer(c) + mustWaitLeader(c, []*server.Server{s.svr}) + + mustBootstrapCluster(c, s.svr) + mustPutStore(c, s.svr, 1, metapb.StoreState_Up, nil) +} + +func (s *testServiceSuite) TearDownSuite(c *C) { + s.cleanup() +} + +func (s *testServiceSuite) TestSwitchServiceMiddleware(c *C) { + urlPrefix := fmt.Sprintf("%s%s/api/v1/admin/service-middleware", s.svr.GetAddr(), apiPrefix) + disableURL := fmt.Sprintf("%s?enable=false", urlPrefix) + err := postJSON(testDialClient, disableURL, nil, + func(res []byte, code int) { + c.Assert(string(res), Equals, "\"Switching Service middleware is successful.\"\n") + c.Assert(code, Equals, http.StatusOK) + }) + + c.Assert(err, IsNil) + c.Assert(s.svr.IsServiceMiddlewareEnabled(), Equals, false) + + enableURL := fmt.Sprintf("%s?enable=true", urlPrefix) + err = postJSON(testDialClient, enableURL, nil, + func(res []byte, code int) { + c.Assert(string(res), Equals, "\"Switching Service middleware is successful.\"\n") + c.Assert(code, Equals, http.StatusOK) + }) + + c.Assert(err, IsNil) + c.Assert(s.svr.IsServiceMiddlewareEnabled(), Equals, true) +} diff --git a/server/api/middleware.go b/server/api/middleware.go index 1423d81d7d9..0d5824cce66 100644 --- a/server/api/middleware.go +++ b/server/api/middleware.go @@ -18,12 +18,45 @@ import ( "context" "net/http" + "github.com/pingcap/failpoint" "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/requestutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/cluster" "github.com/unrolled/render" + "github.com/urfave/negroni" ) +// requestInfoMiddleware is used to gather info from requsetInfo +type requestInfoMiddleware struct { + svr *server.Server +} + +func newRequestInfoMiddleware(s *server.Server) negroni.Handler { + return &requestInfoMiddleware{svr: s} +} + +func (rm *requestInfoMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + if !rm.svr.IsServiceMiddlewareEnabled() { + next(w, r) + return + } + + requestInfo := requestutil.GetRequestInfo(r) + r = r.WithContext(requestutil.WithRequestInfo(r.Context(), requestInfo)) + + failpoint.Inject("addRequestInfoMiddleware", func() { + w.Header().Add("service-label", requestInfo.ServiceLabel) + w.Header().Add("body-param", requestInfo.BodyParam) + w.Header().Add("url-param", requestInfo.URLParam) + w.Header().Add("method", requestInfo.Method) + w.Header().Add("component", requestInfo.Component) + w.Header().Add("ip", requestInfo.IP) + }) + + next(w, r) +} + type clusterMiddleware struct { s *server.Server rd *render.Render diff --git a/server/api/router.go b/server/api/router.go index 3c646edc350..e2abc5e4250 100644 --- a/server/api/router.go +++ b/server/api/router.go @@ -23,8 +23,26 @@ import ( "github.com/pingcap/failpoint" "github.com/tikv/pd/server" "github.com/unrolled/render" + "github.com/urfave/negroni" ) +// createRouteOption is used to register service for mux.Route +type createRouteOption func(route *mux.Route) + +// setMethods is used to add HTTP Method matcher for mux.Route +func setMethods(method ...string) createRouteOption { + return func(route *mux.Route) { + route.Methods(method...) + } +} + +// setQueries is used to add queries for mux.Route +func setQueries(pairs ...string) createRouteOption { + return func(route *mux.Route) { + route.Queries(pairs...) + } +} + func createStreamingRender() *render.Render { return render.New(render.Options{ StreamingJSON: true, @@ -37,6 +55,63 @@ func createIndentRender() *render.Render { }) } +// middlewareBuilder is used to build service middleware for HTTP api +type serviceMiddlewareBuilder struct { + svr *server.Server + handler http.Handler +} + +func newServiceMiddlewareBuilder(s *server.Server) *serviceMiddlewareBuilder { + return &serviceMiddlewareBuilder{ + svr: s, + handler: negroni.New( + newRequestInfoMiddleware(s), + // todo: add audit and rate limit middleware + ), + } +} + +// registerRouteHandleFunc is used to registers a new route which will be registered matcher or service by opts for the URL path +func (s *serviceMiddlewareBuilder) registerRouteHandleFunc(router *mux.Router, serviceLabel, path string, + handleFunc func(http.ResponseWriter, *http.Request), opts ...createRouteOption) *mux.Route { + route := router.HandleFunc(path, s.middlewareFunc(handleFunc)).Name(serviceLabel) + for _, opt := range opts { + opt(route) + } + return route +} + +// registerRouteHandleFunc is used to registers a new route which will be registered matcher or service by opts for the URL path +func (s *serviceMiddlewareBuilder) registerRouteHandler(router *mux.Router, serviceLabel, path string, + handler http.Handler, opts ...createRouteOption) *mux.Route { + route := router.Handle(path, s.middleware(handler)).Name(serviceLabel) + for _, opt := range opts { + opt(route) + } + return route +} + +// registerRouteHandleFunc is used to registers a new route which will be registered matcher or service by opts for the URL path prefix. +func (s *serviceMiddlewareBuilder) registerPathPrefixRouteHandler(router *mux.Router, serviceLabel, prefix string, + handler http.Handler, opts ...createRouteOption) *mux.Route { + route := router.PathPrefix(prefix).Handler(s.middleware(handler)).Name(serviceLabel) + for _, opt := range opts { + opt(route) + } + return route +} + +func (s *serviceMiddlewareBuilder) middleware(handler http.Handler) http.Handler { + return negroni.New(negroni.Wrap(s.handler), negroni.Wrap(handler)) +} + +func (s *serviceMiddlewareBuilder) middlewareFunc(next func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + s.handler.ServeHTTP(w, r) + next(w, r) + } +} + // The returned function is used as a lazy router to avoid the data race problem. // @title Placement Driver Core API // @version 1.0 @@ -61,171 +136,178 @@ func createRouter(prefix string, svr *server.Server) *mux.Router { escapeRouter := clusterRouter.NewRoute().Subrouter().UseEncodedPath() + serviceBuilder := newServiceMiddlewareBuilder(svr) + register := serviceBuilder.registerRouteHandler + registerPrefix := serviceBuilder.registerPathPrefixRouteHandler + registerFunc := serviceBuilder.registerRouteHandleFunc + operatorHandler := newOperatorHandler(handler, rd) - apiRouter.HandleFunc("/operators", operatorHandler.List).Methods("GET") - apiRouter.HandleFunc("/operators", operatorHandler.Post).Methods("POST") - apiRouter.HandleFunc("/operators/{region_id}", operatorHandler.Get).Methods("GET") - apiRouter.HandleFunc("/operators/{region_id}", operatorHandler.Delete).Methods("DELETE") + registerFunc(apiRouter, "GetOperators", "/operators", operatorHandler.List, setMethods("GET")) + registerFunc(apiRouter, "SetOperators", "/operators", operatorHandler.Post, setMethods("POST")) + registerFunc(apiRouter, "GetRegionOperator", "/operators/{region_id}", operatorHandler.Get, setMethods("GET")) + registerFunc(apiRouter, "DeleteRegionOperator", "/operators/{region_id}", operatorHandler.Delete, setMethods("DELETE")) checkerHandler := newCheckerHandler(svr, rd) - apiRouter.HandleFunc("/checker/{name}", checkerHandler.PauseOrResume).Methods("POST") - apiRouter.HandleFunc("/checker/{name}", checkerHandler.GetStatus).Methods("GET") + registerFunc(apiRouter, "SetChecker", "/checker/{name}", checkerHandler.PauseOrResume, setMethods("POST")) + registerFunc(apiRouter, "GetChecker", "/checker/{name}", checkerHandler.GetStatus, setMethods("GET")) schedulerHandler := newSchedulerHandler(svr, rd) - apiRouter.HandleFunc("/schedulers", schedulerHandler.List).Methods("GET") - apiRouter.HandleFunc("/schedulers", schedulerHandler.Post).Methods("POST") - apiRouter.HandleFunc("/schedulers/{name}", schedulerHandler.Delete).Methods("DELETE") - apiRouter.HandleFunc("/schedulers/{name}", schedulerHandler.PauseOrResume).Methods("POST") + registerFunc(apiRouter, "GetSchedulers", "/schedulers", schedulerHandler.List, setMethods("GET")) + registerFunc(apiRouter, "AddScheduler", "/schedulers", schedulerHandler.Post, setMethods("POST")) + registerFunc(apiRouter, "DeleteScheduler", "/schedulers/{name}", schedulerHandler.Delete, setMethods("DELETE")) + registerFunc(apiRouter, "PauseOrResumeScheduler", "/schedulers/{name}", schedulerHandler.PauseOrResume, setMethods("POST")) schedulerConfigHandler := newSchedulerConfigHandler(svr, rd) - apiRouter.PathPrefix("/scheduler-config").Handler(schedulerConfigHandler) + registerPrefix(apiRouter, "GetSchedulerConfig", "/scheduler-config", schedulerConfigHandler) clusterHandler := newClusterHandler(svr, rd) - apiRouter.Handle("/cluster", clusterHandler).Methods("GET") - apiRouter.HandleFunc("/cluster/status", clusterHandler.GetClusterStatus).Methods("GET") + register(apiRouter, "GetCluster", "/cluster", clusterHandler, setMethods("GET")) + registerFunc(apiRouter, "GetClusterStatus", "/cluster/status", clusterHandler.GetClusterStatus) confHandler := newConfHandler(svr, rd) - apiRouter.HandleFunc("/config", confHandler.Get).Methods("GET") - apiRouter.HandleFunc("/config", confHandler.Post).Methods("POST") - apiRouter.HandleFunc("/config/default", confHandler.GetDefault).Methods("GET") - apiRouter.HandleFunc("/config/schedule", confHandler.GetSchedule).Methods("GET") - apiRouter.HandleFunc("/config/schedule", confHandler.SetSchedule).Methods("POST") - apiRouter.HandleFunc("/config/pd-server", confHandler.GetPDServer).Methods("GET") - apiRouter.HandleFunc("/config/replicate", confHandler.GetReplication).Methods("GET") - apiRouter.HandleFunc("/config/replicate", confHandler.SetReplication).Methods("POST") - apiRouter.HandleFunc("/config/label-property", confHandler.GetLabelProperty).Methods("GET") - apiRouter.HandleFunc("/config/label-property", confHandler.SetLabelProperty).Methods("POST") - apiRouter.HandleFunc("/config/cluster-version", confHandler.GetClusterVersion).Methods("GET") - apiRouter.HandleFunc("/config/cluster-version", confHandler.SetClusterVersion).Methods("POST") - apiRouter.HandleFunc("/config/replication-mode", confHandler.GetReplicationMode).Methods("GET") - apiRouter.HandleFunc("/config/replication-mode", confHandler.SetReplicationMode).Methods("POST") + registerFunc(apiRouter, "GetConfig", "/config", confHandler.Get, setMethods("GET")) + registerFunc(apiRouter, "SetConfig", "/config", confHandler.Post, setMethods("POST")) + registerFunc(apiRouter, "GetDefaultConfig", "/config/default", confHandler.GetDefault, setMethods("GET")) + registerFunc(apiRouter, "GetScheduleConfig", "/config/schedule", confHandler.GetSchedule, setMethods("GET")) + registerFunc(apiRouter, "SetScheduleConfig", "/config/schedule", confHandler.SetSchedule, setMethods("POST")) + registerFunc(apiRouter, "GetPDServerConfig", "/config/pd-server", confHandler.GetPDServer, setMethods("GET")) + registerFunc(apiRouter, "GetReplicationConfig", "/config/replicate", confHandler.GetReplication, setMethods("GET")) + registerFunc(apiRouter, "SetReplicationConfig", "/config/replicate", confHandler.SetReplication, setMethods("POST")) + registerFunc(apiRouter, "GetLabelProperty", "/config/label-property", confHandler.GetLabelProperty, setMethods("GET")) + registerFunc(apiRouter, "SetLabelProperty", "/config/label-property", confHandler.SetLabelProperty, setMethods("POST")) + registerFunc(apiRouter, "GetClusterVersion", "/config/cluster-version", confHandler.GetClusterVersion, setMethods("GET")) + registerFunc(apiRouter, "SetClusterVersion", "/config/cluster-version", confHandler.SetClusterVersion, setMethods("POST")) + registerFunc(apiRouter, "GetReplicationMode", "/config/replication-mode", confHandler.GetReplicationMode, setMethods("GET")) + registerFunc(apiRouter, "SetReplicationMode", "/config/replication-mode", confHandler.SetReplicationMode, setMethods("POST")) rulesHandler := newRulesHandler(svr, rd) - clusterRouter.HandleFunc("/config/rules", rulesHandler.GetAll).Methods("GET") - clusterRouter.HandleFunc("/config/rules", rulesHandler.SetAll).Methods("POST") - clusterRouter.HandleFunc("/config/rules/batch", rulesHandler.Batch).Methods("POST") - clusterRouter.HandleFunc("/config/rules/group/{group}", rulesHandler.GetAllByGroup).Methods("GET") - clusterRouter.HandleFunc("/config/rules/region/{region}", rulesHandler.GetAllByRegion).Methods("GET") - clusterRouter.HandleFunc("/config/rules/key/{key}", rulesHandler.GetAllByKey).Methods("GET") - clusterRouter.HandleFunc("/config/rule/{group}/{id}", rulesHandler.Get).Methods("GET") - clusterRouter.HandleFunc("/config/rule", rulesHandler.Set).Methods("POST") - clusterRouter.HandleFunc("/config/rule/{group}/{id}", rulesHandler.Delete).Methods("DELETE") + registerFunc(clusterRouter, "GetAllRules", "/config/rules", rulesHandler.GetAll, setMethods("GET")) + registerFunc(clusterRouter, "SetAllRules", "/config/rules", rulesHandler.SetAll, setMethods("POST")) + registerFunc(clusterRouter, "SetBatchRules", "/config/rules/batch", rulesHandler.Batch, setMethods("POST")) + registerFunc(clusterRouter, "GetRuleByGroup", "/config/rules/group/{group}", rulesHandler.GetAllByGroup, setMethods("GET")) + registerFunc(clusterRouter, "GetRuleByByRegion", "/config/rules/region/{region}", rulesHandler.GetAllByRegion, setMethods("GET")) + registerFunc(clusterRouter, "GetRuleByKey", "/config/rules/key/{key}", rulesHandler.GetAllByKey, setMethods("GET")) + registerFunc(clusterRouter, "GetRuleByGroupAndID", "/config/rule/{group}/{id}", rulesHandler.Get, setMethods("GET")) + registerFunc(clusterRouter, "SetRule", "/config/rule", rulesHandler.Set, setMethods("POST")) + registerFunc(clusterRouter, "DeleteRuleByGroup", "/config/rule/{group}/{id}", rulesHandler.Delete, setMethods("DELETE")) regionLabelHandler := newRegionLabelHandler(svr, rd) - clusterRouter.HandleFunc("/config/region-label/rules", regionLabelHandler.GetAllRules).Methods("GET") - clusterRouter.HandleFunc("/config/region-label/rules/ids", regionLabelHandler.GetRulesByIDs).Methods("GET") + registerFunc(clusterRouter, "GetAllRegionLabelRule", "/config/region-label/rules", regionLabelHandler.GetAllRules, setMethods("GET")) + registerFunc(clusterRouter, "GetRegionLabelRulesByIDs", "/config/region-label/rules/ids", regionLabelHandler.GetRulesByIDs, setMethods("GET")) // {id} can be a string with special characters, we should enable path encode to support it. - escapeRouter.HandleFunc("/config/region-label/rule/{id}", regionLabelHandler.GetRule).Methods("GET") - escapeRouter.HandleFunc("/config/region-label/rule/{id}", regionLabelHandler.DeleteRule).Methods("DELETE") - clusterRouter.HandleFunc("/config/region-label/rule", regionLabelHandler.SetRule).Methods("POST") - clusterRouter.HandleFunc("/config/region-label/rules", regionLabelHandler.Patch).Methods("PATCH") + registerFunc(escapeRouter, "GetRegionLabelRuleByID", "/config/region-label/rule/{id}", regionLabelHandler.GetRule, setMethods("GET")) + registerFunc(escapeRouter, "DeleteRegionLabelRule", "/config/region-label/rule/{id}", regionLabelHandler.DeleteRule, setMethods("DELETE")) + registerFunc(clusterRouter, "SetRegionLabelRule", "/config/region-label/rule", regionLabelHandler.SetRule, setMethods("POST")) + registerFunc(clusterRouter, "PatchRegionLabelRules", "/config/region-label/rules", regionLabelHandler.Patch, setMethods("PATCH")) - clusterRouter.HandleFunc("/region/id/{id}/label/{key}", regionLabelHandler.GetRegionLabel).Methods("GET") - clusterRouter.HandleFunc("/region/id/{id}/labels", regionLabelHandler.GetRegionLabels).Methods("GET") + registerFunc(clusterRouter, "GetRegionLabelByKey", "/region/id/{id}/label/{key}", regionLabelHandler.GetRegionLabel, setMethods("GET")) + registerFunc(clusterRouter, "GetAllRegionLabels", "/region/id/{id}/labels", regionLabelHandler.GetRegionLabels, setMethods("GET")) - clusterRouter.HandleFunc("/config/rule_group/{id}", rulesHandler.GetGroupConfig).Methods("GET") - clusterRouter.HandleFunc("/config/rule_group", rulesHandler.SetGroupConfig).Methods("POST") - clusterRouter.HandleFunc("/config/rule_group/{id}", rulesHandler.DeleteGroupConfig).Methods("DELETE") - clusterRouter.HandleFunc("/config/rule_groups", rulesHandler.GetAllGroupConfigs).Methods("GET") + registerFunc(clusterRouter, "GetRuleGroup", "/config/rule_group/{id}", rulesHandler.GetGroupConfig, setMethods("GET")) + registerFunc(clusterRouter, "SetRuleGroup", "/config/rule_group", rulesHandler.SetGroupConfig, setMethods("POST")) + registerFunc(clusterRouter, "DeleteRuleGroup", "/config/rule_group/{id}", rulesHandler.DeleteGroupConfig, setMethods("DELETE")) + registerFunc(clusterRouter, "GetAllRuleGroups", "/config/rule_groups", rulesHandler.GetAllGroupConfigs, setMethods("GET")) - clusterRouter.HandleFunc("/config/placement-rule", rulesHandler.GetAllGroupBundles).Methods("GET") - clusterRouter.HandleFunc("/config/placement-rule", rulesHandler.SetAllGroupBundles).Methods("POST") + registerFunc(clusterRouter, "GetAllPlacementRules", "/config/placement-rule", rulesHandler.GetAllGroupBundles, setMethods("GET")) + registerFunc(clusterRouter, "SetAllPlacementRules", "/config/placement-rule", rulesHandler.SetAllGroupBundles, setMethods("POST")) // {group} can be a regular expression, we should enable path encode to // support special characters. - clusterRouter.HandleFunc("/config/placement-rule/{group}", rulesHandler.GetGroupBundle).Methods("GET") - clusterRouter.HandleFunc("/config/placement-rule/{group}", rulesHandler.SetGroupBundle).Methods("POST") - escapeRouter.HandleFunc("/config/placement-rule/{group}", rulesHandler.DeleteGroupBundle).Methods("DELETE") + registerFunc(clusterRouter, "GetPlacementRuleByGroup", "/config/placement-rule/{group}", rulesHandler.GetGroupBundle, setMethods("GET")) + registerFunc(clusterRouter, "SetPlacementRuleByGroup", "/config/placement-rule/{group}", rulesHandler.SetGroupBundle, setMethods("POST")) + registerFunc(escapeRouter, "DeletePlacementRuleByGroup", "/config/placement-rule/{group}", rulesHandler.DeleteGroupBundle, setMethods("DELETE")) storeHandler := newStoreHandler(handler, rd) - clusterRouter.HandleFunc("/store/{id}", storeHandler.Get).Methods("GET") - clusterRouter.HandleFunc("/store/{id}", storeHandler.Delete).Methods("DELETE") - clusterRouter.HandleFunc("/store/{id}/state", storeHandler.SetState).Methods("POST") - clusterRouter.HandleFunc("/store/{id}/label", storeHandler.SetLabels).Methods("POST") - clusterRouter.HandleFunc("/store/{id}/weight", storeHandler.SetWeight).Methods("POST") - clusterRouter.HandleFunc("/store/{id}/limit", storeHandler.SetLimit).Methods("POST") + registerFunc(clusterRouter, "GetStore", "/store/{id}", storeHandler.Get, setMethods("GET")) + registerFunc(clusterRouter, "DeleteStore", "/store/{id}", storeHandler.Delete, setMethods("DELETE")) + registerFunc(clusterRouter, "SetStoreState", "/store/{id}/state", storeHandler.SetState, setMethods("POST")) + registerFunc(clusterRouter, "SetStoreLabel", "/store/{id}/label", storeHandler.SetLabels, setMethods("POST")) + registerFunc(clusterRouter, "SetStoreWeight", "/store/{id}/weight", storeHandler.SetWeight, setMethods("POST")) + registerFunc(clusterRouter, "SetStoreLimit", "/store/{id}/limit", storeHandler.SetLimit, setMethods("POST")) + storesHandler := newStoresHandler(handler, rd) - clusterRouter.Handle("/stores", storesHandler).Methods("GET") - clusterRouter.HandleFunc("/stores/remove-tombstone", storesHandler.RemoveTombStone).Methods("DELETE") - clusterRouter.HandleFunc("/stores/limit", storesHandler.GetAllLimit).Methods("GET") - clusterRouter.HandleFunc("/stores/limit", storesHandler.SetAllLimit).Methods("POST") - clusterRouter.HandleFunc("/stores/limit/scene", storesHandler.SetStoreLimitScene).Methods("POST") - clusterRouter.HandleFunc("/stores/limit/scene", storesHandler.GetStoreLimitScene).Methods("GET") + register(clusterRouter, "GetAllStores", "/stores", storesHandler, setMethods("GET")) + registerFunc(clusterRouter, "RemoveTombstone", "/stores/remove-tombstone", storesHandler.RemoveTombStone, setMethods("DELETE")) + registerFunc(clusterRouter, "GetAllStoresLimit", "/stores/limit", storesHandler.GetAllLimit, setMethods("GET")) + registerFunc(clusterRouter, "SetAllStoresLimit", "/stores/limit", storesHandler.SetAllLimit, setMethods("POST")) + registerFunc(clusterRouter, "SetStoreSceneLimit", "/stores/limit/scene", storesHandler.SetStoreLimitScene, setMethods("POST")) + registerFunc(clusterRouter, "GetStoreSceneLimit", "/stores/limit/scene", storesHandler.GetStoreLimitScene, setMethods("GET")) labelsHandler := newLabelsHandler(svr, rd) - clusterRouter.HandleFunc("/labels", labelsHandler.Get).Methods("GET") - clusterRouter.HandleFunc("/labels/stores", labelsHandler.GetStores).Methods("GET") + registerFunc(clusterRouter, "GetLabels", "/labels", labelsHandler.Get, setMethods("GET")) + registerFunc(clusterRouter, "GetStoresByLabel", "/labels/stores", labelsHandler.GetStores, setMethods("GET")) hotStatusHandler := newHotStatusHandler(handler, rd) - apiRouter.HandleFunc("/hotspot/regions/write", hotStatusHandler.GetHotWriteRegions).Methods("GET") - apiRouter.HandleFunc("/hotspot/regions/read", hotStatusHandler.GetHotReadRegions).Methods("GET") - apiRouter.HandleFunc("/hotspot/regions/history", hotStatusHandler.GetHistoryHotRegions).Methods("GET") - apiRouter.HandleFunc("/hotspot/stores", hotStatusHandler.GetHotStores).Methods("GET") + registerFunc(apiRouter, "GetHotspotWriteRegion", "/hotspot/regions/write", hotStatusHandler.GetHotWriteRegions, setMethods("GET")) + registerFunc(apiRouter, "GetHotspotReadRegion", "/hotspot/regions/read", hotStatusHandler.GetHotReadRegions, setMethods("GET")) + registerFunc(apiRouter, "GetHotspotStores", "/hotspot/regions/history", hotStatusHandler.GetHistoryHotRegions, setMethods("GET")) + registerFunc(apiRouter, "GetHistoryHotspotRegion", "/hotspot/stores", hotStatusHandler.GetHotStores, setMethods("GET")) regionHandler := newRegionHandler(svr, rd) - clusterRouter.HandleFunc("/region/id/{id}", regionHandler.GetRegionByID).Methods("GET") - clusterRouter.UseEncodedPath().HandleFunc("/region/key/{key}", regionHandler.GetRegionByKey).Methods("GET") + registerFunc(clusterRouter, "GetRegionByID", "/region/id/{id}", regionHandler.GetRegionByID, setMethods("GET")) + registerFunc(clusterRouter.UseEncodedPath(), "GetRegion", "/region/key/{key}", regionHandler.GetRegionByKey, setMethods("GET")) srd := createStreamingRender() regionsAllHandler := newRegionsHandler(svr, srd) - clusterRouter.HandleFunc("/regions", regionsAllHandler.GetAll).Methods("GET") + registerFunc(clusterRouter, "GetAllRegions", "/regions", regionsAllHandler.GetAll, setMethods("GET")) regionsHandler := newRegionsHandler(svr, rd) - clusterRouter.HandleFunc("/regions/key", regionsHandler.ScanRegions).Methods("GET") - clusterRouter.HandleFunc("/regions/count", regionsHandler.GetRegionCount).Methods("GET") - clusterRouter.HandleFunc("/regions/store/{id}", regionsHandler.GetStoreRegions).Methods("GET") - clusterRouter.HandleFunc("/regions/writeflow", regionsHandler.GetTopWriteFlow).Methods("GET") - clusterRouter.HandleFunc("/regions/readflow", regionsHandler.GetTopReadFlow).Methods("GET") - clusterRouter.HandleFunc("/regions/confver", regionsHandler.GetTopConfVer).Methods("GET") - clusterRouter.HandleFunc("/regions/version", regionsHandler.GetTopVersion).Methods("GET") - clusterRouter.HandleFunc("/regions/size", regionsHandler.GetTopSize).Methods("GET") - clusterRouter.HandleFunc("/regions/check/miss-peer", regionsHandler.GetMissPeerRegions).Methods("GET") - clusterRouter.HandleFunc("/regions/check/extra-peer", regionsHandler.GetExtraPeerRegions).Methods("GET") - clusterRouter.HandleFunc("/regions/check/pending-peer", regionsHandler.GetPendingPeerRegions).Methods("GET") - clusterRouter.HandleFunc("/regions/check/down-peer", regionsHandler.GetDownPeerRegions).Methods("GET") - clusterRouter.HandleFunc("/regions/check/learner-peer", regionsHandler.GetLearnerPeerRegions).Methods("GET") - clusterRouter.HandleFunc("/regions/check/empty-region", regionsHandler.GetEmptyRegion).Methods("GET") - clusterRouter.HandleFunc("/regions/check/offline-peer", regionsHandler.GetOfflinePeer).Methods("GET") - - clusterRouter.HandleFunc("/regions/check/hist-size", regionsHandler.GetSizeHistogram).Methods("GET") - clusterRouter.HandleFunc("/regions/check/hist-keys", regionsHandler.GetKeysHistogram).Methods("GET") - clusterRouter.HandleFunc("/regions/sibling/{id}", regionsHandler.GetRegionSiblings).Methods("GET") - clusterRouter.HandleFunc("/regions/accelerate-schedule", regionsHandler.AccelerateRegionsScheduleInRange).Methods("POST") - clusterRouter.HandleFunc("/regions/scatter", regionsHandler.ScatterRegions).Methods("POST") - clusterRouter.HandleFunc("/regions/split", regionsHandler.SplitRegions).Methods("POST") - clusterRouter.HandleFunc("/regions/range-holes", regionsHandler.GetRangeHoles).Methods("GET") - clusterRouter.HandleFunc("/regions/replicated", regionsHandler.CheckRegionsReplicated).Methods("GET").Queries("startKey", "{startKey}", "endKey", "{endKey}") - - apiRouter.Handle("/version", newVersionHandler(rd)).Methods("GET") - apiRouter.Handle("/status", newStatusHandler(svr, rd)).Methods("GET") + registerFunc(clusterRouter, "ScanRegions", "/regions/key", regionsHandler.ScanRegions, setMethods("GET")) + registerFunc(clusterRouter, "CountRegions", "/regions/count", regionsHandler.GetRegionCount, setMethods("GET")) + registerFunc(clusterRouter, "GetRegionsByStore", "/regions/store/{id}", regionsHandler.GetStoreRegions, setMethods("GET")) + registerFunc(clusterRouter, "GetTopWriteRegions", "/regions/writeflow", regionsHandler.GetTopWriteFlow, setMethods("GET")) + registerFunc(clusterRouter, "GetTopReadRegions", "/regions/readflow", regionsHandler.GetTopReadFlow, setMethods("GET")) + registerFunc(clusterRouter, "GetTopConfverRegions", "/regions/confver", regionsHandler.GetTopConfVer, setMethods("GET")) + registerFunc(clusterRouter, "GetTopVersionRegions", "/regions/version", regionsHandler.GetTopVersion, setMethods("GET")) + registerFunc(clusterRouter, "GetTopSizeRegions", "/regions/size", regionsHandler.GetTopSize, setMethods("GET")) + registerFunc(clusterRouter, "GetMissPeerRegions", "/regions/check/miss-peer", regionsHandler.GetMissPeerRegions, setMethods("GET")) + registerFunc(clusterRouter, "GetExtraPeerRegions", "/regions/check/extra-peer", regionsHandler.GetExtraPeerRegions, setMethods("GET")) + registerFunc(clusterRouter, "GetPendingPeerRegions", "/regions/check/pending-peer", regionsHandler.GetPendingPeerRegions, setMethods("GET")) + registerFunc(clusterRouter, "GetDownPeerRegions", "/regions/check/down-peer", regionsHandler.GetDownPeerRegions, setMethods("GET")) + registerFunc(clusterRouter, "GetLearnerPeerRegions", "/regions/check/learner-peer", regionsHandler.GetLearnerPeerRegions, setMethods("GET")) + registerFunc(clusterRouter, "GetEmptyRegion", "/regions/check/empty-region", regionsHandler.GetEmptyRegion, setMethods("GET")) + registerFunc(clusterRouter, "GetOfflinePeer", "/regions/check/offline-peer", regionsHandler.GetOfflinePeer, setMethods("GET")) + + registerFunc(clusterRouter, "GetSizeHistogram", "/regions/check/hist-size", regionsHandler.GetSizeHistogram, setMethods("GET")) + registerFunc(clusterRouter, "GetKeysHistogram", "/regions/check/hist-keys", regionsHandler.GetKeysHistogram, setMethods("GET")) + registerFunc(clusterRouter, "GetRegionSiblings", "/regions/sibling/{id}", regionsHandler.GetRegionSiblings, setMethods("GET")) + registerFunc(clusterRouter, "AccelerateRegionsSchedule", "/regions/accelerate-schedule", regionsHandler.AccelerateRegionsScheduleInRange, setMethods("POST")) + registerFunc(clusterRouter, "ScatterRegions", "/regions/scatter", regionsHandler.ScatterRegions, setMethods("POST")) + registerFunc(clusterRouter, "SplitRegions", "/regions/split", regionsHandler.SplitRegions, setMethods("POST")) + registerFunc(clusterRouter, "GetRangeHoles", "/regions/range-holes", regionsHandler.GetRangeHoles, setMethods("GET")) + registerFunc(clusterRouter, "CheckRegionsReplicated", "/regions/replicated", regionsHandler.CheckRegionsReplicated, setMethods("GET"), setQueries("startKey", "{startKey}", "endKey", "{endKey}")) + + register(apiRouter, "GetVersion", "/version", newVersionHandler(rd), setMethods("GET")) + register(apiRouter, "GetPDStatus", "/status", newStatusHandler(svr, rd), setMethods("GET")) memberHandler := newMemberHandler(svr, rd) - apiRouter.HandleFunc("/members", memberHandler.ListMembers).Methods("GET") - apiRouter.HandleFunc("/members/name/{name}", memberHandler.DeleteByName).Methods("DELETE") - apiRouter.HandleFunc("/members/id/{id}", memberHandler.DeleteByID).Methods("DELETE") - apiRouter.HandleFunc("/members/name/{name}", memberHandler.SetMemberPropertyByName).Methods("POST") + registerFunc(apiRouter, "GetMembers", "/members", memberHandler.ListMembers, setMethods("GET")) + registerFunc(apiRouter, "DeleteMemberByName", "/members/name/{name}", memberHandler.DeleteByName, setMethods("DELETE")) + registerFunc(apiRouter, "DeleteMemberByID", "/members/id/{id}", memberHandler.DeleteByID, setMethods("DELETE")) + registerFunc(apiRouter, "SetMemberByName", "/members/name/{name}", memberHandler.SetMemberPropertyByName, setMethods("POST")) leaderHandler := newLeaderHandler(svr, rd) - apiRouter.HandleFunc("/leader", leaderHandler.Get).Methods("GET") - apiRouter.HandleFunc("/leader/resign", leaderHandler.Resign).Methods("POST") - apiRouter.HandleFunc("/leader/transfer/{next_leader}", leaderHandler.Transfer).Methods("POST") + registerFunc(apiRouter, "GetLeader", "/leader", leaderHandler.Get, setMethods("GET")) + registerFunc(apiRouter, "ResignLeader", "/leader/resign", leaderHandler.Resign, setMethods("POST")) + registerFunc(apiRouter, "TransferLeader", "/leader/transfer/{next_leader}", leaderHandler.Transfer, setMethods("POST")) statsHandler := newStatsHandler(svr, rd) - clusterRouter.HandleFunc("/stats/region", statsHandler.Region).Methods("GET") + registerFunc(clusterRouter, "GetRegionStatus", "/stats/region", statsHandler.Region, setMethods("GET")) trendHandler := newTrendHandler(svr, rd) - apiRouter.HandleFunc("/trend", trendHandler.Handle).Methods("GET") + registerFunc(apiRouter, "GetTrend", "/trend", trendHandler.Handle, setMethods("GET")) adminHandler := newAdminHandler(svr, rd) - clusterRouter.HandleFunc("/admin/cache/region/{id}", adminHandler.HandleDropCacheRegion).Methods("DELETE") - clusterRouter.HandleFunc("/admin/reset-ts", adminHandler.ResetTS).Methods("POST") - apiRouter.HandleFunc("/admin/persist-file/{file_name}", adminHandler.persistFile).Methods("POST") - clusterRouter.HandleFunc("/admin/replication_mode/wait-async", adminHandler.UpdateWaitAsyncTime).Methods("POST") + registerFunc(clusterRouter, "DeleteRegionCache", "/admin/cache/region/{id}", adminHandler.HandleDropCacheRegion, setMethods("DELETE")) + registerFunc(clusterRouter, "ResetTS", "/admin/reset-ts", adminHandler.ResetTS, setMethods("POST")) + registerFunc(apiRouter, "SavePersistFile", "/admin/persist-file/{file_name}", adminHandler.persistFile, setMethods("POST")) + registerFunc(clusterRouter, "SetWaitAsyncTime", "/admin/replication_mode/wait-async", adminHandler.UpdateWaitAsyncTime, setMethods("POST")) + registerFunc(apiRouter, "SwitchServiceMiddleware", "/admin/service-middleware", adminHandler.HanldeServiceMiddlewareSwitch, setMethods("POST")) logHandler := newLogHandler(svr, rd) - apiRouter.HandleFunc("/admin/log", logHandler.Handle).Methods("POST") + registerFunc(apiRouter, "SetLogLevel", "/admin/log", logHandler.Handle, setMethods("POST")) replicationModeHandler := newReplicationModeHandler(svr, rd) - clusterRouter.HandleFunc("/replication_mode/status", replicationModeHandler.GetStatus) + registerFunc(clusterRouter, "GetReplicationModeStatus", "/replication_mode/status", replicationModeHandler.GetStatus) // Deprecated: component exists for historical compatibility and should not be used anymore. See https://github.com/tikv/tikv/issues/11472. componentHandler := newComponentHandler(svr, rd) @@ -235,46 +317,47 @@ func createRouter(prefix string, svr *server.Server) *mux.Router { clusterRouter.HandleFunc("/component/{type}", componentHandler.GetAddress).Methods("GET") pluginHandler := newPluginHandler(handler, rd) - apiRouter.HandleFunc("/plugin", pluginHandler.LoadPlugin).Methods("POST") - apiRouter.HandleFunc("/plugin", pluginHandler.UnloadPlugin).Methods("DELETE") + registerFunc(apiRouter, "SetPlugin", "/plugin", pluginHandler.LoadPlugin, setMethods("POST")) + registerFunc(apiRouter, "DeletePlugin", "/plugin", pluginHandler.UnloadPlugin, setMethods("DELETE")) - apiRouter.Handle("/health", newHealthHandler(svr, rd)).Methods("GET") + register(apiRouter, "GetHealthStatus", "/health", newHealthHandler(svr, rd), setMethods("GET")) // Deprecated: This API is no longer maintained anymore. apiRouter.Handle("/diagnose", newDiagnoseHandler(svr, rd)).Methods("GET") - apiRouter.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) {}).Methods("GET") + registerFunc(apiRouter, "Ping", "/ping", func(w http.ResponseWriter, r *http.Request) {}, setMethods("GET")) + // metric query use to query metric data, the protocol is compatible with prometheus. - apiRouter.Handle("/metric/query", newQueryMetric(svr)).Methods("GET", "POST") - apiRouter.Handle("/metric/query_range", newQueryMetric(svr)).Methods("GET", "POST") + register(apiRouter, "QueryMetric", "/metric/query", newQueryMetric(svr), setMethods("GET", "POST")) + register(apiRouter, "QueryMetric", "/metric/query_range", newQueryMetric(svr), setMethods("GET", "POST")) // tso API tsoHandler := newTSOHandler(svr, rd) - apiRouter.HandleFunc("/tso/allocator/transfer/{name}", tsoHandler.TransferLocalTSOAllocator).Methods("POST") + registerFunc(apiRouter, "TransferLocalTSOAllocator", "/tso/allocator/transfer/{name}", tsoHandler.TransferLocalTSOAllocator, setMethods("POST")) // profile API - apiRouter.HandleFunc("/debug/pprof/profile", pprof.Profile) - apiRouter.HandleFunc("/debug/pprof/trace", pprof.Trace) - apiRouter.HandleFunc("/debug/pprof/symbol", pprof.Symbol) - apiRouter.Handle("/debug/pprof/heap", pprof.Handler("heap")) - apiRouter.Handle("/debug/pprof/mutex", pprof.Handler("mutex")) - apiRouter.Handle("/debug/pprof/allocs", pprof.Handler("allocs")) - apiRouter.Handle("/debug/pprof/block", pprof.Handler("block")) - apiRouter.Handle("/debug/pprof/goroutine", pprof.Handler("goroutine")) - apiRouter.Handle("/debug/pprof/threadcreate", pprof.Handler("threadcreate")) - apiRouter.Handle("/debug/pprof/zip", newProfHandler(svr, rd)) + registerFunc(apiRouter, "DebugPProfProfile", "/debug/pprof/profile", pprof.Profile) + registerFunc(apiRouter, "DebugPProfTrace", "/debug/pprof/trace", pprof.Trace) + registerFunc(apiRouter, "DebugPProfSymbol", "/debug/pprof/symbol", pprof.Symbol) + register(apiRouter, "DebugPProfHeap", "/debug/pprof/heap", pprof.Handler("heap")) + register(apiRouter, "DebugPProfMutex", "/debug/pprof/mutex", pprof.Handler("mutex")) + register(apiRouter, "DebugPProfAllocs", "/debug/pprof/allocs", pprof.Handler("allocs")) + register(apiRouter, "DebugPProfBlock", "/debug/pprof/block", pprof.Handler("block")) + register(apiRouter, "DebugPProfGoroutine", "/debug/pprof/goroutine", pprof.Handler("goroutine")) + register(apiRouter, "DebugPProfThreadCreate", "/debug/pprof/threadcreate", pprof.Handler("threadcreate")) + register(apiRouter, "DebugPProfZip", "/debug/pprof/zip", newProfHandler(svr, rd)) // service GC safepoint API serviceGCSafepointHandler := newServiceGCSafepointHandler(svr, rd) - apiRouter.HandleFunc("/gc/safepoint", serviceGCSafepointHandler.List).Methods("GET") - apiRouter.HandleFunc("/gc/safepoint/{service_id}", serviceGCSafepointHandler.Delete).Methods("DELETE") + registerFunc(apiRouter, "GetGCSafePoint", "/gc/safepoint", serviceGCSafepointHandler.List, setMethods("GET")) + registerFunc(apiRouter, "DeleteGCSafePoint", "/gc/safepoint/{service_id}", serviceGCSafepointHandler.Delete, setMethods("DELETE")) // unsafe admin operation API unsafeOperationHandler := newUnsafeOperationHandler(svr, rd) - clusterRouter.HandleFunc("/admin/unsafe/remove-failed-stores", - unsafeOperationHandler.RemoveFailedStores).Methods("POST") - clusterRouter.HandleFunc("/admin/unsafe/remove-failed-stores/show", - unsafeOperationHandler.GetFailedStoresRemovalStatus).Methods("GET") - clusterRouter.HandleFunc("/admin/unsafe/remove-failed-stores/history", - unsafeOperationHandler.GetFailedStoresRemovalHistory).Methods("GET") + registerFunc(clusterRouter, "RemoveFailedStoresUnsafely", "/admin/unsafe/remove-failed-stores", + unsafeOperationHandler.RemoveFailedStores, setMethods("POST")) + registerFunc(clusterRouter, "GetOngoingFailedStoresRemoval", "/admin/unsafe/remove-failed-stores/show", + unsafeOperationHandler.GetFailedStoresRemovalStatus, setMethods("GET")) + registerFunc(clusterRouter, "GetHistoryFailedStoresRemoval", "/admin/unsafe/remove-failed-stores/history", + unsafeOperationHandler.GetFailedStoresRemovalHistory, setMethods("GET")) // API to set or unset failpoints failpoint.Inject("enableFailpointAPI", func() { diff --git a/server/config/config.go b/server/config/config.go index 2ebbfd140e4..6a088e210ea 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -162,6 +162,8 @@ type Config struct { Dashboard DashboardConfig `toml:"dashboard" json:"dashboard"` ReplicationMode ReplicationModeConfig `toml:"replication-mode" json:"replication-mode"` + + EnableServiceMiddleware bool } // NewConfig creates a new config. diff --git a/server/server.go b/server/server.go index 2c45006108e..23669b79c5b 100644 --- a/server/server.go +++ b/server/server.go @@ -726,6 +726,16 @@ func (s *Server) SetStorage(storage storage.Storage) { s.storage = storage } +// SetServiceMiddleware change EnableServiceMiddleware +func (s *Server) SetServiceMiddleware(status bool) { + s.cfg.EnableServiceMiddleware = status +} + +// IsServiceMiddlewareEnabled returns EnableServiceMiddleware status +func (s *Server) IsServiceMiddlewareEnabled() bool { + return s.cfg.EnableServiceMiddleware +} + // GetBasicCluster returns the basic cluster of server. func (s *Server) GetBasicCluster() *core.BasicCluster { return s.basicCluster diff --git a/tests/server/api/api_test.go b/tests/server/api/api_test.go index 019dc14f84d..299530f0924 100644 --- a/tests/server/api/api_test.go +++ b/tests/server/api/api_test.go @@ -15,13 +15,16 @@ package api_test import ( + "bytes" "context" + "encoding/json" "io" "net/http" "testing" "time" . "github.com/pingcap/check" + "github.com/pingcap/failpoint" "github.com/tikv/pd/pkg/apiutil/serverapi" "github.com/tikv/pd/pkg/testutil" "github.com/tikv/pd/pkg/typeutil" @@ -110,6 +113,113 @@ func (s *serverTestSuite) TestReconnect(c *C) { } } +var _ = Suite(&testMiddlewareSuite{}) + +type testMiddlewareSuite struct { + cleanup func() + cluster *tests.TestCluster +} + +func (s *testMiddlewareSuite) SetUpSuite(c *C) { + ctx, cancel := context.WithCancel(context.Background()) + server.EnableZap = true + s.cleanup = cancel + cluster, err := tests.NewTestCluster(ctx, 1) + c.Assert(err, IsNil) + c.Assert(cluster.RunInitialServers(), IsNil) + c.Assert(cluster.WaitLeader(), Not(HasLen), 0) + s.cluster = cluster +} + +func (s *testMiddlewareSuite) TearDownSuite(c *C) { + s.cleanup() + s.cluster.Destroy() +} + +func (s *testMiddlewareSuite) TestRequestInfoMiddleware(c *C) { + c.Assert(failpoint.Enable("github.com/tikv/pd/server/api/addRequestInfoMiddleware", "return(true)"), IsNil) + leader := s.cluster.GetServer(s.cluster.GetLeader()) + + req, _ := http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/admin/service-middleware?enable=true", nil) + resp, err := dialClient.Do(req) + c.Assert(err, IsNil) + resp.Body.Close() + c.Assert(leader.GetServer().IsServiceMiddlewareEnabled(), Equals, true) + + labels := make(map[string]interface{}) + labels["testkey"] = "testvalue" + data, _ := json.Marshal(labels) + resp, err = dialClient.Post(leader.GetAddr()+"/pd/api/v1/debug/pprof/profile?force=true", "application/json", bytes.NewBuffer(data)) + c.Assert(err, IsNil) + _, err = io.ReadAll(resp.Body) + resp.Body.Close() + c.Assert(err, IsNil) + c.Assert(resp.StatusCode, Equals, http.StatusOK) + + c.Assert(resp.Header.Get("service-label"), Equals, "DebugPProfProfile") + c.Assert(resp.Header.Get("url-param"), Equals, "{\"force\":[\"true\"]}") + c.Assert(resp.Header.Get("body-param"), Equals, "{\"testkey\":\"testvalue\"}") + c.Assert(resp.Header.Get("method"), Equals, "HTTP/1.1/POST:/pd/api/v1/debug/pprof/profile") + c.Assert(resp.Header.Get("component"), Equals, "anonymous") + c.Assert(resp.Header.Get("ip"), Equals, "127.0.0.1") + + req, _ = http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/admin/service-middleware?enable=false", nil) + resp, err = dialClient.Do(req) + c.Assert(err, IsNil) + resp.Body.Close() + c.Assert(leader.GetServer().IsServiceMiddlewareEnabled(), Equals, false) + + header := mustRequestSuccess(c, leader.GetServer()) + c.Assert(header.Get("service-label"), Equals, "") + + c.Assert(failpoint.Disable("github.com/tikv/pd/server/api/addRequestInfoMiddleware"), IsNil) +} + +func BenchmarkDoRequestWithServiceMiddleware(b *testing.B) { + b.StopTimer() + ctx, cancel := context.WithCancel(context.Background()) + server.EnableZap = true + cluster, _ := tests.NewTestCluster(ctx, 1) + cluster.RunInitialServers() + cluster.WaitLeader() + leader := cluster.GetServer(cluster.GetLeader()) + req, _ := http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/admin/service-middleware?enable=true", nil) + resp, _ := dialClient.Do(req) + resp.Body.Close() + b.StartTimer() + for i := 0; i < b.N; i++ { + doTestRequest(leader) + } + cancel() + cluster.Destroy() +} + +func BenchmarkDoRequestWithoutServiceMiddleware(b *testing.B) { + b.StopTimer() + ctx, cancel := context.WithCancel(context.Background()) + server.EnableZap = true + cluster, _ := tests.NewTestCluster(ctx, 1) + cluster.RunInitialServers() + cluster.WaitLeader() + leader := cluster.GetServer(cluster.GetLeader()) + req, _ := http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/admin/service-middleware?enable=false", nil) + resp, _ := dialClient.Do(req) + resp.Body.Close() + b.StartTimer() + for i := 0; i < b.N; i++ { + doTestRequest(leader) + } + cancel() + cluster.Destroy() +} + +func doTestRequest(srv *tests.TestServer) { + req, _ := http.NewRequest("GET", srv.GetAddr()+"/pd/api/v1/component/admin/unsafe/remove-failed-stores/history", nil) + req.Header.Set("component", "test") + resp, _ := dialClient.Do(req) + resp.Body.Close() +} + var _ = Suite(&testRedirectorSuite{}) type testRedirectorSuite struct {