Skip to content

Commit

Permalink
api: save cluster in request context (#3756)
Browse files Browse the repository at this point in the history
Signed-off-by: disksing <i@disksing.com>

Co-authored-by: Ti Chi Robot <ti-community-prow-bot@tidb.io>
  • Loading branch information
disksing and ti-chi-bot committed Jun 15, 2021
1 parent f4d6f28 commit 69209e0
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 61 deletions.
13 changes: 2 additions & 11 deletions server/api/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func newAdminHandler(svr *server.Server, rd *render.Render) *adminHandler {
// @Failure 400 {string} string "The input is invalid."
// @Router /admin/cache/region/{id} [delete]
func (h *adminHandler) HandleDropCacheRegion(w http.ResponseWriter, r *http.Request) {
rc := h.svr.GetRaftCluster()
rc := getCluster(r)
vars := mux.Vars(r)
regionIDStr := vars["id"]
regionID, err := strconv.ParseUint(regionIDStr, 10, 64)
Expand Down Expand Up @@ -114,7 +114,6 @@ func (h *adminHandler) persistFile(w http.ResponseWriter, r *http.Request) {
// Intentionally no swagger mark as it is supposed to be only used in
// server-to-server.
func (h *adminHandler) UpdateWaitAsyncTime(w http.ResponseWriter, r *http.Request) {
handler := h.svr.GetHandler()
var input map[string]interface{}
if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &input); err != nil {
return
Expand All @@ -129,15 +128,7 @@ func (h *adminHandler) UpdateWaitAsyncTime(w http.ResponseWriter, r *http.Reques
h.rd.JSON(w, http.StatusBadRequest, "invalid member id")
return
}
cluster, err := handler.GetRaftCluster()
if err != nil {
if err == server.ErrServerNotStarted {
h.rd.JSON(w, http.StatusInternalServerError, err.Error())
} else {
h.rd.JSON(w, http.StatusForbidden, err.Error())
}
return
}
cluster := getCluster(r)
cluster.GetReplicationMode().UpdateMemberWaitAsyncTime(memberID)
h.rd.JSON(w, http.StatusOK, nil)
}
8 changes: 4 additions & 4 deletions server/api/component.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func newComponentHandler(svr *server.Server, rd *render.Render) *componentHandle
// @Failure 500 {string} string "PD server failed to proceed the request."
// @Router /component [post]
func (h *componentHandler) Register(w http.ResponseWriter, r *http.Request) {
rc := h.svr.GetRaftCluster()
rc := getCluster(r)
input := make(map[string]string)
if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &input); err != nil {
return
Expand Down Expand Up @@ -76,7 +76,7 @@ func (h *componentHandler) Register(w http.ResponseWriter, r *http.Request) {
// @Failure 400 {string} string "The input is invalid."
// @Router /component [delete]
func (h *componentHandler) UnRegister(w http.ResponseWriter, r *http.Request) {
rc := h.svr.GetRaftCluster()
rc := getCluster(r)
vars := mux.Vars(r)
component := vars["component"]
addr := vars["addr"]
Expand All @@ -93,7 +93,7 @@ func (h *componentHandler) UnRegister(w http.ResponseWriter, r *http.Request) {
// @Success 200 {object} Addresses
// @Router /component [get]
func (h *componentHandler) GetAllAddress(w http.ResponseWriter, r *http.Request) {
rc := h.svr.GetRaftCluster()
rc := getCluster(r)
addrs := rc.GetComponentManager().GetAllComponentAddrs()
h.rd.JSON(w, http.StatusOK, addrs)
}
Expand All @@ -105,7 +105,7 @@ func (h *componentHandler) GetAllAddress(w http.ResponseWriter, r *http.Request)
// @Failure 404 {string} string "The component does not exist."
// @Router /component/{type} [get]
func (h *componentHandler) GetAddress(w http.ResponseWriter, r *http.Request) {
rc := h.svr.GetRaftCluster()
rc := getCluster(r)
vars := mux.Vars(r)
component := vars["type"]
addrs := rc.GetComponentManager().GetComponentAddrs(component)
Expand Down
4 changes: 2 additions & 2 deletions server/api/label.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func newLabelsHandler(svr *server.Server, rd *render.Render) *labelsHandler {
// @Success 200 {array} metapb.StoreLabel
// @Router /labels [get]
func (h *labelsHandler) Get(w http.ResponseWriter, r *http.Request) {
rc := h.svr.GetRaftCluster()
rc := getCluster(r)
var labels []*metapb.StoreLabel
m := make(map[string]struct{})
stores := rc.GetStores()
Expand All @@ -67,7 +67,7 @@ func (h *labelsHandler) Get(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {string} string "PD server failed to proceed the request."
// @Router /labels/stores [get]
func (h *labelsHandler) GetStores(w http.ResponseWriter, r *http.Request) {
rc := h.svr.GetRaftCluster()
rc := getCluster(r)
name := r.URL.Query().Get("name")
value := r.URL.Query().Get("value")
filter, err := newStoresLabelFilter(name, value)
Expand Down
11 changes: 10 additions & 1 deletion server/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
package api

import (
"context"
"net/http"

"github.com/tikv/pd/pkg/errs"
"github.com/tikv/pd/server"
"github.com/tikv/pd/server/cluster"
"github.com/unrolled/render"
)

Expand All @@ -40,6 +42,13 @@ func (m clusterMiddleware) Middleware(h http.Handler) http.Handler {
m.rd.JSON(w, http.StatusInternalServerError, errs.ErrNotBootstrapped.FastGenByArgs().Error())
return
}
h.ServeHTTP(w, r)
ctx := context.WithValue(r.Context(), clusterCtxKey{}, rc)
h.ServeHTTP(w, r.WithContext(ctx))
})
}

type clusterCtxKey struct{}

func getCluster(r *http.Request) *cluster.RaftCluster {
return r.Context().Value(clusterCtxKey{}).(*cluster.RaftCluster)
}
26 changes: 13 additions & 13 deletions server/api/region.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ func newRegionHandler(svr *server.Server, rd *render.Render) *regionHandler {
// @Failure 400 {string} string "The input is invalid."
// @Router /region/id/{id} [get]
func (h *regionHandler) GetRegionByID(w http.ResponseWriter, r *http.Request) {
rc := h.svr.GetRaftCluster()
rc := getCluster(r)

vars := mux.Vars(r)
regionIDStr := vars["id"]
Expand All @@ -222,7 +222,7 @@ func (h *regionHandler) GetRegionByID(w http.ResponseWriter, r *http.Request) {
// @Success 200 {object} RegionInfo
// @Router /region/key/{key} [get]
func (h *regionHandler) GetRegionByKey(w http.ResponseWriter, r *http.Request) {
rc := h.svr.GetRaftCluster()
rc := getCluster(r)
vars := mux.Vars(r)
key := vars["key"]
key, err := url.QueryUnescape(key)
Expand Down Expand Up @@ -263,7 +263,7 @@ func convertToAPIRegions(regions []*core.RegionInfo) *RegionsInfo {
// @Success 200 {object} RegionsInfo
// @Router /regions [get]
func (h *regionsHandler) GetAll(w http.ResponseWriter, r *http.Request) {
rc := h.svr.GetRaftCluster()
rc := getCluster(r)
regions := rc.GetRegions()
regionsInfo := convertToAPIRegions(regions)
h.rd.JSON(w, http.StatusOK, regionsInfo)
Expand All @@ -278,7 +278,7 @@ func (h *regionsHandler) GetAll(w http.ResponseWriter, r *http.Request) {
// @Failure 400 {string} string "The input is invalid."
// @Router /regions/key [get]
func (h *regionsHandler) ScanRegions(w http.ResponseWriter, r *http.Request) {
rc := h.svr.GetRaftCluster()
rc := getCluster(r)
startKey := r.URL.Query().Get("key")

limit := defaultRegionLimit
Expand All @@ -304,7 +304,7 @@ func (h *regionsHandler) ScanRegions(w http.ResponseWriter, r *http.Request) {
// @Success 200 {object} RegionsInfo
// @Router /regions/count [get]
func (h *regionsHandler) GetRegionCount(w http.ResponseWriter, r *http.Request) {
rc := h.svr.GetRaftCluster()
rc := getCluster(r)
count := rc.GetRegionCount()
h.rd.JSON(w, http.StatusOK, &RegionsInfo{Count: count})
}
Expand All @@ -317,7 +317,7 @@ func (h *regionsHandler) GetRegionCount(w http.ResponseWriter, r *http.Request)
// @Failure 400 {string} string "The input is invalid."
// @Router /regions/store/{id} [get]
func (h *regionsHandler) GetStoreRegions(w http.ResponseWriter, r *http.Request) {
rc := h.svr.GetRaftCluster()
rc := getCluster(r)

vars := mux.Vars(r)
id, err := strconv.Atoi(vars["id"])
Expand Down Expand Up @@ -483,7 +483,7 @@ func (h *regionsHandler) GetSizeHistogram(w http.ResponseWriter, r *http.Request
h.rd.JSON(w, http.StatusBadRequest, err.Error())
return
}
rc := h.svr.GetRaftCluster()
rc := getCluster(r)
regions := rc.GetRegions()
histSizes := make([]int64, 0, len(regions))
for _, region := range regions {
Expand All @@ -507,7 +507,7 @@ func (h *regionsHandler) GetKeysHistogram(w http.ResponseWriter, r *http.Request
h.rd.JSON(w, http.StatusBadRequest, err.Error())
return
}
rc := h.svr.GetRaftCluster()
rc := getCluster(r)
regions := rc.GetRegions()
histKeys := make([]int64, 0, len(regions))
for _, region := range regions {
Expand Down Expand Up @@ -561,7 +561,7 @@ func calHist(bound int, list *[]int64) *[]*histItem {
// @Failure 404 {string} string "The region does not exist."
// @Router /regions/sibling/{id} [get]
func (h *regionsHandler) GetRegionSiblings(w http.ResponseWriter, r *http.Request) {
rc := h.svr.GetRaftCluster()
rc := getCluster(r)

vars := mux.Vars(r)
id, err := strconv.Atoi(vars["id"])
Expand Down Expand Up @@ -658,7 +658,7 @@ func (h *regionsHandler) GetTopSize(w http.ResponseWriter, r *http.Request) {
// @Failure 400 {string} string "The input is invalid."
// @Router /regions/accelerate-schedule [post]
func (h *regionsHandler) AccelerateRegionsScheduleInRange(w http.ResponseWriter, r *http.Request) {
rc := h.svr.GetRaftCluster()
rc := getCluster(r)
var input map[string]interface{}
if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &input); err != nil {
return
Expand Down Expand Up @@ -700,7 +700,7 @@ func (h *regionsHandler) AccelerateRegionsScheduleInRange(w http.ResponseWriter,
}

func (h *regionsHandler) GetTopNRegions(w http.ResponseWriter, r *http.Request, less func(a, b *core.RegionInfo) bool) {
rc := h.svr.GetRaftCluster()
rc := getCluster(r)
limit := defaultRegionLimit
if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
var err error
Expand All @@ -727,7 +727,7 @@ func (h *regionsHandler) GetTopNRegions(w http.ResponseWriter, r *http.Request,
// @Failure 400 {string} string "The input is invalid."
// @Router /regions/scatter [post]
func (h *regionsHandler) ScatterRegions(w http.ResponseWriter, r *http.Request) {
rc := h.svr.GetRaftCluster()
rc := getCluster(r)
var input map[string]interface{}
if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &input); err != nil {
return
Expand Down Expand Up @@ -803,7 +803,7 @@ func (h *regionsHandler) ScatterRegions(w http.ResponseWriter, r *http.Request)
// @Failure 400 {string} string "The input is invalid."
// @Router /regions/split [post]
func (h *regionsHandler) SplitRegions(w http.ResponseWriter, r *http.Request) {
rc := h.svr.GetRaftCluster()
rc := getCluster(r)
var input map[string]interface{}
if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &input); err != nil {
return
Expand Down
2 changes: 1 addition & 1 deletion server/api/replication_mode.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,5 @@ func newReplicationModeHandler(svr *server.Server, rd *render.Render) *replicati
// @Success 200 {object} replication.HTTPReplicationStatus
// @Router /replication_mode/status [get]
func (h *replicationModeHandler) GetStatus(w http.ResponseWriter, r *http.Request) {
h.rd.JSON(w, http.StatusOK, h.svr.GetRaftCluster().GetReplicationMode().GetReplicationStatusHTTP())
h.rd.JSON(w, http.StatusOK, getCluster(r).GetReplicationMode().GetReplicationStatusHTTP())
}
Loading

0 comments on commit 69209e0

Please sign in to comment.