diff --git a/Makefile b/Makefile index 01558556ee4..77ad51a1a87 100644 --- a/Makefile +++ b/Makefile @@ -33,6 +33,10 @@ else BUILD_CGO_ENABLED := 1 endif +ifeq ($(FAILPOINT), 1) + BUILD_TAGS += with_fail +endif + ifeq ("$(WITH_RACE)", "1") BUILD_FLAGS += -race BUILD_CGO_ENABLED := 1 @@ -73,6 +77,11 @@ PD_SERVER_DEP += dashboard-ui pd-server: ${PD_SERVER_DEP} CGO_ENABLED=$(BUILD_CGO_ENABLED) go build $(BUILD_FLAGS) -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -tags "$(BUILD_TAGS)" -o $(BUILD_BIN_PATH)/pd-server cmd/pd-server/main.go +pd-server-failpoint: + @$(FAILPOINT_ENABLE) + FAILPOINT=1 $(MAKE) pd-server || { $(FAILPOINT_DISABLE); exit 1; } + @$(FAILPOINT_DISABLE) + pd-server-basic: SWAGGER=0 DASHBOARD=0 $(MAKE) pd-server diff --git a/client/client.go b/client/client.go index d3d3805fc4d..74cb7adf2a5 100644 --- a/client/client.go +++ b/client/client.go @@ -87,7 +87,7 @@ type Client interface { // GetRegion gets a region and its leader Peer from PD by key. // The region may expire after split. Caller is responsible for caching and // taking care of region change. - // Also it may return nil if PD finds no Region for the key temporarily, + // Also, it may return nil if PD finds no Region for the key temporarily, // client should retry later. GetRegion(ctx context.Context, key []byte, opts ...GetRegionOption) (*Region, error) // GetRegionFromMember gets a region from certain members. @@ -96,7 +96,7 @@ type Client interface { GetPrevRegion(ctx context.Context, key []byte, opts ...GetRegionOption) (*Region, error) // GetRegionByID gets a region and its leader Peer from PD by id. GetRegionByID(ctx context.Context, regionID uint64, opts ...GetRegionOption) (*Region, error) - // ScanRegion gets a list of regions, starts from the region that contains key. + // ScanRegions gets a list of regions, starts from the region that contains key. // Limit limits the maximum number of regions returned. // If a region has no leader, corresponding leader will be placed by a peer // with empty value (PeerID is 0). @@ -109,7 +109,7 @@ type Client interface { // The store may expire later. Caller is responsible for caching and taking care // of store change. GetAllStores(ctx context.Context, opts ...GetStoreOption) ([]*metapb.Store, error) - // Update GC safe point. TiKV will check it and do GC themselves if necessary. + // UpdateGCSafePoint TiKV will check it and do GC themselves if necessary. // If the given safePoint is less than the current one, it will not be updated. // Returns the new safePoint after updating. UpdateGCSafePoint(ctx context.Context, safePoint uint64) (uint64, error) diff --git a/pkg/audit/audit_test.go b/pkg/audit/audit_test.go index 3bac22fcd80..3cb43ceead0 100644 --- a/pkg/audit/audit_test.go +++ b/pkg/audit/audit_test.go @@ -103,7 +103,7 @@ func TestLocalLogBackendUsingFile(t *testing.T) { b, _ := os.ReadFile(fname) output := strings.SplitN(string(b), "]", 4) re.Equal( - fmt.Sprintf(" [\"audit log\"] [service-info=\"{ServiceLabel:, Method:HTTP/1.1/GET:/test, Component:anonymous, IP:, "+ + fmt.Sprintf(" [\"audit log\"] [service-info=\"{ServiceLabel:, Method:HTTP/1.1/GET:/test, Component:anonymous, IP:, Port:, "+ "StartTime:%s, URLParam:{\\\"test\\\":[\\\"test\\\"]}, BodyParam:testBody}\"]\n", time.Unix(info.StartTimeStamp, 0).String()), output[3], diff --git a/pkg/mcs/scheduling/server/config/config.go b/pkg/mcs/scheduling/server/config/config.go index 7839ec7f274..ab26e1add29 100644 --- a/pkg/mcs/scheduling/server/config/config.go +++ b/pkg/mcs/scheduling/server/config/config.go @@ -19,6 +19,7 @@ import ( "os" "path/filepath" "strings" + "sync" "sync/atomic" "time" "unsafe" @@ -192,10 +193,13 @@ func (c *Config) validate() error { // PersistConfig wraps all configurations that need to persist to storage and // allows to access them safely. type PersistConfig struct { + // Store the global configuration that is related to the scheduling. clusterVersion unsafe.Pointer schedule atomic.Value replication atomic.Value storeConfig atomic.Value + // Store the respective configurations for different schedulers. + schedulerConfig sync.Map } // NewPersistConfig creates a new PersistConfig instance. @@ -253,6 +257,24 @@ func (o *PersistConfig) GetStoreConfig() *sc.StoreConfig { return o.storeConfig.Load().(*sc.StoreConfig) } +// SetSchedulerConfig sets the scheduler configuration with the given name. +func (o *PersistConfig) SetSchedulerConfig(name, data string) { + o.schedulerConfig.Store(name, data) +} + +// RemoveSchedulerConfig removes the scheduler configuration with the given name. +func (o *PersistConfig) RemoveSchedulerConfig(name string) { + o.schedulerConfig.Delete(name) +} + +// GetSchedulerConfig returns the scheduler configuration with the given name. +func (o *PersistConfig) GetSchedulerConfig(name string) string { + if v, ok := o.schedulerConfig.Load(name); ok { + return v.(string) + } + return "" +} + // GetMaxReplicas returns the max replicas. func (o *PersistConfig) GetMaxReplicas() int { return int(o.GetReplicationConfig().MaxReplicas) diff --git a/pkg/mcs/scheduling/server/config/watcher.go b/pkg/mcs/scheduling/server/config/watcher.go index 81ec4b62f0c..c9010db69a3 100644 --- a/pkg/mcs/scheduling/server/config/watcher.go +++ b/pkg/mcs/scheduling/server/config/watcher.go @@ -17,11 +17,13 @@ package config import ( "context" "encoding/json" + "strings" "sync" "github.com/coreos/go-semver/semver" "github.com/pingcap/log" sc "github.com/tikv/pd/pkg/schedule/config" + "github.com/tikv/pd/pkg/storage/endpoint" "github.com/tikv/pd/pkg/utils/etcdutil" "go.etcd.io/etcd/clientv3" "go.etcd.io/etcd/mvcc/mvccpb" @@ -34,11 +36,20 @@ type Watcher struct { ctx context.Context cancel context.CancelFunc - etcdClient *clientv3.Client - watcher *etcdutil.LoopWatcher + // configPath is the path of the configuration in etcd: + // - Key: /pd/{cluster_id}/config + // - Value: configuration JSON. + configPath string + // schedulerConfigPathPrefix is the path prefix of the scheduler configuration in etcd: + // - Key: /pd/{cluster_id}/scheduler_config/{scheduler_name} + // - Value: configuration JSON. + schedulerConfigPathPrefix string + + etcdClient *clientv3.Client + configWatcher *etcdutil.LoopWatcher + schedulerConfigWatcher *etcdutil.LoopWatcher *PersistConfig - // TODO: watch the scheduler config change. } type persistedConfig struct { @@ -52,19 +63,30 @@ type persistedConfig struct { func NewWatcher( ctx context.Context, etcdClient *clientv3.Client, - // configPath is the path of the configuration in etcd: - // - Key: /pd/{cluster_id}/config - // - Value: configuration JSON. - configPath string, + clusterID uint64, persistConfig *PersistConfig, ) (*Watcher, error) { ctx, cancel := context.WithCancel(ctx) cw := &Watcher{ - ctx: ctx, - cancel: cancel, - etcdClient: etcdClient, - PersistConfig: persistConfig, + ctx: ctx, + cancel: cancel, + configPath: endpoint.ConfigPath(clusterID), + schedulerConfigPathPrefix: endpoint.SchedulerConfigPathPrefix(clusterID), + etcdClient: etcdClient, + PersistConfig: persistConfig, + } + err := cw.initializeConfigWatcher() + if err != nil { + return nil, err } + err = cw.initializeSchedulerConfigWatcher() + if err != nil { + return nil, err + } + return cw, nil +} + +func (cw *Watcher) initializeConfigWatcher() error { putFn := func(kv *mvccpb.KeyValue) error { cfg := &persistedConfig{} if err := json.Unmarshal(kv.Value, cfg); err != nil { @@ -84,21 +106,41 @@ func NewWatcher( postEventFn := func() error { return nil } - cw.watcher = etcdutil.NewLoopWatcher( - ctx, - &cw.wg, - etcdClient, - "scheduling-config-watcher", - configPath, - putFn, - deleteFn, - postEventFn, + cw.configWatcher = etcdutil.NewLoopWatcher( + cw.ctx, &cw.wg, + cw.etcdClient, + "scheduling-config-watcher", cw.configPath, + putFn, deleteFn, postEventFn, ) - cw.watcher.StartWatchLoop() - if err := cw.watcher.WaitLoad(); err != nil { - return nil, err + cw.configWatcher.StartWatchLoop() + return cw.configWatcher.WaitLoad() +} + +func (cw *Watcher) initializeSchedulerConfigWatcher() error { + prefixToTrim := cw.schedulerConfigPathPrefix + "/" + putFn := func(kv *mvccpb.KeyValue) error { + cw.SetSchedulerConfig( + strings.TrimPrefix(string(kv.Key), prefixToTrim), + string(kv.Value), + ) + return nil } - return cw, nil + deleteFn := func(kv *mvccpb.KeyValue) error { + cw.RemoveSchedulerConfig(strings.TrimPrefix(string(kv.Key), prefixToTrim)) + return nil + } + postEventFn := func() error { + return nil + } + cw.schedulerConfigWatcher = etcdutil.NewLoopWatcher( + cw.ctx, &cw.wg, + cw.etcdClient, + "scheduling-scheduler-config-watcher", cw.schedulerConfigPathPrefix, + putFn, deleteFn, postEventFn, + clientv3.WithPrefix(), + ) + cw.schedulerConfigWatcher.StartWatchLoop() + return cw.schedulerConfigWatcher.WaitLoad() } // Close closes the watcher. diff --git a/pkg/mcs/scheduling/server/rule/watcher.go b/pkg/mcs/scheduling/server/rule/watcher.go index c85644ff14f..cf0e1cd8ba1 100644 --- a/pkg/mcs/scheduling/server/rule/watcher.go +++ b/pkg/mcs/scheduling/server/rule/watcher.go @@ -16,6 +16,7 @@ package rule import ( "context" + "strings" "sync" "github.com/tikv/pd/pkg/storage/endpoint" @@ -44,7 +45,7 @@ func (rs *ruleStorage) LoadRules(f func(k, v string)) error { return nil } -// SaveRule stores a rule cfg to the rulesPath. +// SaveRule stores a rule cfg to the rulesPathPrefix. func (rs *ruleStorage) SaveRule(ruleKey string, rule interface{}) error { rs.rules.Store(ruleKey, rule) return nil @@ -104,6 +105,19 @@ type Watcher struct { cancel context.CancelFunc wg sync.WaitGroup + // rulesPathPrefix: + // - Key: /pd/{cluster_id}/rules/{group_id}-{rule_id} + // - Value: placement.Rule + rulesPathPrefix string + // ruleGroupPathPrefix: + // - Key: /pd/{cluster_id}/rule_group/{group_id} + // - Value: placement.RuleGroup + ruleGroupPathPrefix string + // regionLabelPathPrefix: + // - Key: /pd/{cluster_id}/region_label/{rule_id} + // - Value: labeler.LabelRule + regionLabelPathPrefix string + etcdClient *clientv3.Client ruleStore *ruleStorage @@ -117,47 +131,45 @@ type Watcher struct { func NewWatcher( ctx context.Context, etcdClient *clientv3.Client, - // rulePath: - // - Key: /pd/{cluster_id}/rules/{group_id}-{rule_id} - // - Value: placement.Rule - // ruleGroupPath: - // - Key: /pd/{cluster_id}/rule_group/{group_id} - // - Value: placement.RuleGroup - // regionLabelPath: - // - Key: /pd/{cluster_id}/region_label/{rule_id} - // - Value: labeler.LabelRule - rulesPath, ruleGroupPath, regionLabelPath string, + clusterID uint64, ) (*Watcher, error) { ctx, cancel := context.WithCancel(ctx) rw := &Watcher{ - ctx: ctx, - cancel: cancel, - etcdClient: etcdClient, - ruleStore: &ruleStorage{}, + ctx: ctx, + cancel: cancel, + rulesPathPrefix: endpoint.RulesPathPrefix(clusterID), + ruleGroupPathPrefix: endpoint.RuleGroupPathPrefix(clusterID), + regionLabelPathPrefix: endpoint.RegionLabelPathPrefix(clusterID), + etcdClient: etcdClient, + ruleStore: &ruleStorage{}, } - err := rw.initializeRuleWatcher(rulesPath) + err := rw.initializeRuleWatcher() if err != nil { return nil, err } - err = rw.initializeGroupWatcher(ruleGroupPath) + err = rw.initializeGroupWatcher() if err != nil { return nil, err } - err = rw.initializeRegionLabelWatcher(regionLabelPath) + err = rw.initializeRegionLabelWatcher() if err != nil { return nil, err } return rw, nil } -func (rw *Watcher) initializeRuleWatcher(rulePath string) error { +func (rw *Watcher) initializeRuleWatcher() error { + prefixToTrim := rw.rulesPathPrefix + "/" putFn := func(kv *mvccpb.KeyValue) error { // Since the PD API server will validate the rule before saving it to etcd, // so we could directly save the string rule in JSON to the storage here. - return rw.ruleStore.SaveRule(string(kv.Key), string(kv.Value)) + return rw.ruleStore.SaveRule( + strings.TrimPrefix(string(kv.Key), prefixToTrim), + string(kv.Value), + ) } deleteFn := func(kv *mvccpb.KeyValue) error { - return rw.ruleStore.DeleteRule(string(kv.Key)) + return rw.ruleStore.DeleteRule(strings.TrimPrefix(string(kv.Key), prefixToTrim)) } postEventFn := func() error { return nil @@ -165,7 +177,7 @@ func (rw *Watcher) initializeRuleWatcher(rulePath string) error { rw.ruleWatcher = etcdutil.NewLoopWatcher( rw.ctx, &rw.wg, rw.etcdClient, - "scheduling-rule-watcher", rulePath, + "scheduling-rule-watcher", rw.rulesPathPrefix, putFn, deleteFn, postEventFn, clientv3.WithPrefix(), ) @@ -173,12 +185,16 @@ func (rw *Watcher) initializeRuleWatcher(rulePath string) error { return rw.ruleWatcher.WaitLoad() } -func (rw *Watcher) initializeGroupWatcher(ruleGroupPath string) error { +func (rw *Watcher) initializeGroupWatcher() error { + prefixToTrim := rw.ruleGroupPathPrefix + "/" putFn := func(kv *mvccpb.KeyValue) error { - return rw.ruleStore.SaveRuleGroup(string(kv.Key), string(kv.Value)) + return rw.ruleStore.SaveRuleGroup( + strings.TrimPrefix(string(kv.Key), prefixToTrim), + string(kv.Value), + ) } deleteFn := func(kv *mvccpb.KeyValue) error { - return rw.ruleStore.DeleteRuleGroup(string(kv.Key)) + return rw.ruleStore.DeleteRuleGroup(strings.TrimPrefix(string(kv.Key), prefixToTrim)) } postEventFn := func() error { return nil @@ -186,7 +202,7 @@ func (rw *Watcher) initializeGroupWatcher(ruleGroupPath string) error { rw.groupWatcher = etcdutil.NewLoopWatcher( rw.ctx, &rw.wg, rw.etcdClient, - "scheduling-rule-group-watcher", ruleGroupPath, + "scheduling-rule-group-watcher", rw.ruleGroupPathPrefix, putFn, deleteFn, postEventFn, clientv3.WithPrefix(), ) @@ -194,12 +210,16 @@ func (rw *Watcher) initializeGroupWatcher(ruleGroupPath string) error { return rw.groupWatcher.WaitLoad() } -func (rw *Watcher) initializeRegionLabelWatcher(regionLabelPath string) error { +func (rw *Watcher) initializeRegionLabelWatcher() error { + prefixToTrim := rw.regionLabelPathPrefix + "/" putFn := func(kv *mvccpb.KeyValue) error { - return rw.ruleStore.SaveRegionRule(string(kv.Key), string(kv.Value)) + return rw.ruleStore.SaveRegionRule( + strings.TrimPrefix(string(kv.Key), prefixToTrim), + string(kv.Value), + ) } deleteFn := func(kv *mvccpb.KeyValue) error { - return rw.ruleStore.DeleteRegionRule(string(kv.Key)) + return rw.ruleStore.DeleteRegionRule(strings.TrimPrefix(string(kv.Key), prefixToTrim)) } postEventFn := func() error { return nil @@ -207,7 +227,7 @@ func (rw *Watcher) initializeRegionLabelWatcher(regionLabelPath string) error { rw.labelWatcher = etcdutil.NewLoopWatcher( rw.ctx, &rw.wg, rw.etcdClient, - "scheduling-region-label-watcher", regionLabelPath, + "scheduling-region-label-watcher", rw.regionLabelPathPrefix, putFn, deleteFn, postEventFn, clientv3.WithPrefix(), ) diff --git a/pkg/mcs/scheduling/server/server.go b/pkg/mcs/scheduling/server/server.go index 5ad011e2b4c..991d513e9b1 100644 --- a/pkg/mcs/scheduling/server/server.go +++ b/pkg/mcs/scheduling/server/server.go @@ -555,18 +555,13 @@ func (s *Server) startServer() (err error) { func (s *Server) startWatcher() (err error) { s.configWatcher, err = config.NewWatcher( - s.ctx, s.etcdClient, - endpoint.ConfigPath(s.clusterID), - s.persistConfig, + s.ctx, s.etcdClient, s.clusterID, s.persistConfig, ) if err != nil { return err } s.ruleWatcher, err = rule.NewWatcher( - s.ctx, s.etcdClient, - endpoint.RulesPath(s.clusterID), - endpoint.RuleGroupPath(s.clusterID), - endpoint.RegionLabelPath(s.clusterID), + s.ctx, s.etcdClient, s.clusterID, ) return err } diff --git a/pkg/storage/endpoint/key_path.go b/pkg/storage/endpoint/key_path.go index a463fc8acf6..0e99431044a 100644 --- a/pkg/storage/endpoint/key_path.go +++ b/pkg/storage/endpoint/key_path.go @@ -92,18 +92,23 @@ func ConfigPath(clusterID uint64) string { return path.Join(PDRootPath(clusterID), configPath) } -// RulesPath returns the path to save the placement rules. -func RulesPath(clusterID uint64) string { +// SchedulerConfigPathPrefix returns the path prefix to save the scheduler config. +func SchedulerConfigPathPrefix(clusterID uint64) string { + return path.Join(PDRootPath(clusterID), customScheduleConfigPath) +} + +// RulesPathPrefix returns the path prefix to save the placement rules. +func RulesPathPrefix(clusterID uint64) string { return path.Join(PDRootPath(clusterID), rulesPath) } -// RuleGroupPath returns the path to save the placement rule groups. -func RuleGroupPath(clusterID uint64) string { +// RuleGroupPathPrefix returns the path prefix to save the placement rule groups. +func RuleGroupPathPrefix(clusterID uint64) string { return path.Join(PDRootPath(clusterID), ruleGroupPath) } -// RegionLabelPath returns the path to save the region label. -func RegionLabelPath(clusterID uint64) string { +// RegionLabelPathPrefix returns the path prefix to save the region label. +func RegionLabelPathPrefix(clusterID uint64) string { return path.Join(PDRootPath(clusterID), regionLabelPath) } diff --git a/pkg/tso/keyspace_group_manager.go b/pkg/tso/keyspace_group_manager.go index b4f513eaa60..9ba3afb2a1f 100644 --- a/pkg/tso/keyspace_group_manager.go +++ b/pkg/tso/keyspace_group_manager.go @@ -137,6 +137,13 @@ func (s *state) getDeletedGroups() []uint32 { return groups } +// getDeletedGroupNum returns the number of the deleted keyspace groups. +func (s *state) getDeletedGroupNum() int { + s.RLock() + defer s.RUnlock() + return len(s.deletedGroups) +} + func (s *state) checkTSOSplit( targetGroupID uint32, ) (splitTargetAM, splitSourceAM *AllocatorManager, err error) { @@ -1391,6 +1398,11 @@ func (kgm *KeyspaceGroupManager) deletedGroupCleaner() { defer ticker.Stop() log.Info("deleted group cleaner is started", zap.Duration("patrol-interval", patrolInterval)) + var ( + empty = true + lastDeletedGroupID uint32 + lastDeletedGroupNum int + ) for { select { case <-kgm.ctx.Done(): @@ -1403,6 +1415,7 @@ func (kgm *KeyspaceGroupManager) deletedGroupCleaner() { if groupID == mcsutils.DefaultKeyspaceGroupID { continue } + empty = false // Make sure the allocator and group meta are not in use anymore. am, _ := kgm.getKeyspaceGroupMeta(groupID) if am != nil { @@ -1428,6 +1441,18 @@ func (kgm *KeyspaceGroupManager) deletedGroupCleaner() { kgm.Lock() delete(kgm.deletedGroups, groupID) kgm.Unlock() + lastDeletedGroupID = groupID + lastDeletedGroupNum += 1 + } + // This log would be helpful to check if the deleted groups are all gone. + if !empty && kgm.getDeletedGroupNum() == 0 { + log.Info("all the deleted keyspace groups have been cleaned up", + zap.Uint32("last-deleted-group-id", lastDeletedGroupID), + zap.Int("last-deleted-group-num", lastDeletedGroupNum)) + // Reset the state to make sure the log won't be printed again + // until we have new deleted groups. + empty = true + lastDeletedGroupNum = 0 } } } diff --git a/pkg/utils/apiutil/apiutil.go b/pkg/utils/apiutil/apiutil.go index dce063a99f9..269a256cff3 100644 --- a/pkg/utils/apiutil/apiutil.go +++ b/pkg/utils/apiutil/apiutil.go @@ -47,6 +47,17 @@ var ( ) const ( + // PDRedirectorHeader is used to mark which PD redirected this request. + PDRedirectorHeader = "PD-Redirector" + // PDAllowFollowerHandleHeader is used to mark whether this request is allowed to be handled by the follower PD. + PDAllowFollowerHandleHeader = "PD-Allow-follower-handle" + // XForwardedForHeader is used to mark the client IP. + XForwardedForHeader = "X-Forwarded-For" + // XForwardedPortHeader is used to mark the client port. + XForwardedPortHeader = "X-Forwarded-Port" + // XRealIPHeader is used to mark the real client IP. + XRealIPHeader = "X-Real-Ip" + // ErrRedirectFailed is the error message for redirect failed. ErrRedirectFailed = "redirect failed" // ErrRedirectToNotLeader is the error message for redirect to not leader. @@ -101,26 +112,30 @@ func ErrorResp(rd *render.Render, w http.ResponseWriter, err error) { } } -// GetIPAddrFromHTTPRequest returns http client IP from context. +// GetIPPortFromHTTPRequest returns http client host IP and port from context. // Because `X-Forwarded-For ` header has been written into RFC 7239(Forwarded HTTP Extension), // so `X-Forwarded-For` has the higher priority than `X-Real-IP`. // And both of them have the higher priority than `RemoteAddr` -func GetIPAddrFromHTTPRequest(r *http.Request) string { - ips := strings.Split(r.Header.Get("X-Forwarded-For"), ",") - if len(strings.Trim(ips[0], " ")) > 0 { - return ips[0] - } - - ip := r.Header.Get("X-Real-Ip") - if ip != "" { - return ip +func GetIPPortFromHTTPRequest(r *http.Request) (ip, port string) { + forwardedIPs := strings.Split(r.Header.Get(XForwardedForHeader), ",") + if forwardedIP := strings.Trim(forwardedIPs[0], " "); len(forwardedIP) > 0 { + ip = forwardedIP + // Try to get the port from "X-Forwarded-Port" header. + forwardedPorts := strings.Split(r.Header.Get(XForwardedPortHeader), ",") + if forwardedPort := strings.Trim(forwardedPorts[0], " "); len(forwardedPort) > 0 { + port = forwardedPort + } + } else if realIP := r.Header.Get(XRealIPHeader); len(realIP) > 0 { + ip = realIP + } else { + ip = r.RemoteAddr } - - ip, _, err := net.SplitHostPort(r.RemoteAddr) + splitIP, splitPort, err := net.SplitHostPort(ip) if err != nil { - return "" + // Ensure we could get an IP address at least. + return ip, port } - return ip + return splitIP, splitPort } // GetComponentNameOnHTTP returns component name from Request Header diff --git a/pkg/utils/apiutil/apiutil_test.go b/pkg/utils/apiutil/apiutil_test.go index bbbb3b860fb..a4e7b97aa4d 100644 --- a/pkg/utils/apiutil/apiutil_test.go +++ b/pkg/utils/apiutil/apiutil_test.go @@ -17,6 +17,7 @@ package apiutil import ( "bytes" "io" + "net/http" "net/http/httptest" "testing" @@ -68,3 +69,141 @@ func TestJsonRespondErrorBadInput(t *testing.T) { re.Equal(400, result.StatusCode) } } + +func TestGetIPPortFromHTTPRequest(t *testing.T) { + t.Parallel() + re := require.New(t) + + testCases := []struct { + r *http.Request + ip string + port string + err error + }{ + // IPv4 "X-Forwarded-For" with port + { + r: &http.Request{ + Header: map[string][]string{ + XForwardedForHeader: {"127.0.0.1:5299"}, + }, + }, + ip: "127.0.0.1", + port: "5299", + }, + // IPv4 "X-Forwarded-For" without port + { + r: &http.Request{ + Header: map[string][]string{ + XForwardedForHeader: {"127.0.0.1"}, + XForwardedPortHeader: {"5299"}, + }, + }, + ip: "127.0.0.1", + port: "5299", + }, + // IPv4 "X-Real-IP" with port + { + r: &http.Request{ + Header: map[string][]string{ + XRealIPHeader: {"127.0.0.1:5299"}, + }, + }, + ip: "127.0.0.1", + port: "5299", + }, + // IPv4 "X-Real-IP" without port + { + r: &http.Request{ + Header: map[string][]string{ + XForwardedForHeader: {"127.0.0.1"}, + XForwardedPortHeader: {"5299"}, + }, + }, + ip: "127.0.0.1", + port: "5299", + }, + // IPv4 RemoteAddr with port + { + r: &http.Request{ + RemoteAddr: "127.0.0.1:5299", + }, + ip: "127.0.0.1", + port: "5299", + }, + // IPv4 RemoteAddr without port + { + r: &http.Request{ + RemoteAddr: "127.0.0.1", + }, + ip: "127.0.0.1", + port: "", + }, + // IPv6 "X-Forwarded-For" with port + { + r: &http.Request{ + Header: map[string][]string{ + XForwardedForHeader: {"[::1]:5299"}, + }, + }, + ip: "::1", + port: "5299", + }, + // IPv6 "X-Forwarded-For" without port + { + r: &http.Request{ + Header: map[string][]string{ + XForwardedForHeader: {"::1"}, + }, + }, + ip: "::1", + port: "", + }, + // IPv6 "X-Real-IP" with port + { + r: &http.Request{ + Header: map[string][]string{ + XRealIPHeader: {"[::1]:5299"}, + }, + }, + ip: "::1", + port: "5299", + }, + // IPv6 "X-Real-IP" without port + { + r: &http.Request{ + Header: map[string][]string{ + XForwardedForHeader: {"::1"}, + }, + }, + ip: "::1", + port: "", + }, + // IPv6 RemoteAddr with port + { + r: &http.Request{ + RemoteAddr: "[::1]:5299", + }, + ip: "::1", + port: "5299", + }, + // IPv6 RemoteAddr without port + { + r: &http.Request{ + RemoteAddr: "::1", + }, + ip: "::1", + port: "", + }, + // Abnormal case + { + r: &http.Request{}, + ip: "", + port: "", + }, + } + for idx, testCase := range testCases { + ip, port := GetIPPortFromHTTPRequest(testCase.r) + re.Equal(testCase.ip, ip, "case %d", idx) + re.Equal(testCase.port, port, "case %d", idx) + } +} diff --git a/pkg/utils/apiutil/serverapi/middleware.go b/pkg/utils/apiutil/serverapi/middleware.go index 653ede75e7a..7d403ecef13 100644 --- a/pkg/utils/apiutil/serverapi/middleware.go +++ b/pkg/utils/apiutil/serverapi/middleware.go @@ -26,13 +26,6 @@ import ( "go.uber.org/zap" ) -// HTTP headers. -const ( - PDRedirectorHeader = "PD-Redirector" - PDAllowFollowerHandle = "PD-Allow-follower-handle" - ForwardedForHeader = "X-Forwarded-For" -) - type runtimeServiceValidator struct { s *server.Server group apiutil.APIServiceGroup @@ -130,7 +123,7 @@ func (h *redirector) matchMicroServiceRedirectRules(r *http.Request) (bool, stri func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { matchedFlag, targetAddr := h.matchMicroServiceRedirectRules(r) - allowFollowerHandle := len(r.Header.Get(PDAllowFollowerHandle)) > 0 + allowFollowerHandle := len(r.Header.Get(apiutil.PDAllowFollowerHandleHeader)) > 0 isLeader := h.s.GetMember().IsLeader() if !h.s.IsClosed() && (allowFollowerHandle || isLeader) && !matchedFlag { next(w, r) @@ -138,14 +131,23 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http } // Prevent more than one redirection. - if name := r.Header.Get(PDRedirectorHeader); len(name) != 0 { + if name := r.Header.Get(apiutil.PDRedirectorHeader); len(name) != 0 { log.Error("redirect but server is not leader", zap.String("from", name), zap.String("server", h.s.Name()), errs.ZapError(errs.ErrRedirect)) http.Error(w, apiutil.ErrRedirectToNotLeader, http.StatusInternalServerError) return } - r.Header.Set(PDRedirectorHeader, h.s.Name()) - r.Header.Add(ForwardedForHeader, r.RemoteAddr) + r.Header.Set(apiutil.PDRedirectorHeader, h.s.Name()) + forwardedIP, forwardedPort := apiutil.GetIPPortFromHTTPRequest(r) + if len(forwardedIP) > 0 { + r.Header.Add(apiutil.XForwardedForHeader, forwardedIP) + } else { + // Fallback if GetIPPortFromHTTPRequest failed to get the IP. + r.Header.Add(apiutil.XForwardedForHeader, r.RemoteAddr) + } + if len(forwardedPort) > 0 { + r.Header.Add(apiutil.XForwardedPortHeader, forwardedPort) + } var clientUrls []string if matchedFlag { diff --git a/pkg/utils/etcdutil/etcdutil.go b/pkg/utils/etcdutil/etcdutil.go index b59a9581996..9d7fe7bfeca 100644 --- a/pkg/utils/etcdutil/etcdutil.go +++ b/pkg/utils/etcdutil/etcdutil.go @@ -106,6 +106,10 @@ func AddEtcdMember(client *clientv3.Client, urls []string) (*clientv3.MemberAddR // ListEtcdMembers returns a list of internal etcd members. func ListEtcdMembers(client *clientv3.Client) (*clientv3.MemberListResponse, error) { + failpoint.Inject("SlowEtcdMemberList", func(val failpoint.Value) { + d := val.(int) + time.Sleep(time.Duration(d) * time.Second) + }) ctx, cancel := context.WithTimeout(client.Ctx(), DefaultRequestTimeout) listResp, err := client.MemberList(ctx) cancel() @@ -132,6 +136,10 @@ func EtcdKVGet(c *clientv3.Client, key string, opts ...clientv3.OpOption) (*clie defer cancel() start := time.Now() + failpoint.Inject("SlowEtcdKVGet", func(val failpoint.Value) { + d := val.(int) + time.Sleep(time.Duration(d) * time.Second) + }) resp, err := clientv3.NewKV(c).Get(ctx, key, opts...) if cost := time.Since(start); cost > DefaultSlowRequestTime { log.Warn("kv gets too slow", zap.String("request-key", key), zap.Duration("cost", cost), errs.ZapError(err)) diff --git a/pkg/utils/requestutil/request_info.go b/pkg/utils/requestutil/request_info.go index 73a7e299e16..40724bb790f 100644 --- a/pkg/utils/requestutil/request_info.go +++ b/pkg/utils/requestutil/request_info.go @@ -31,24 +31,27 @@ type RequestInfo struct { Method string Component string IP string + Port string URLParam string BodyParam string StartTimeStamp int64 } func (info *RequestInfo) String() string { - s := fmt.Sprintf("{ServiceLabel:%s, Method:%s, Component:%s, IP:%s, StartTime:%s, URLParam:%s, BodyParam:%s}", - info.ServiceLabel, info.Method, info.Component, info.IP, time.Unix(info.StartTimeStamp, 0), info.URLParam, info.BodyParam) + s := fmt.Sprintf("{ServiceLabel:%s, Method:%s, Component:%s, IP:%s, Port:%s, StartTime:%s, URLParam:%s, BodyParam:%s}", + info.ServiceLabel, info.Method, info.Component, info.IP, info.Port, time.Unix(info.StartTimeStamp, 0), info.URLParam, info.BodyParam) return s } // GetRequestInfo returns request info needed from http.Request func GetRequestInfo(r *http.Request) RequestInfo { + ip, port := apiutil.GetIPPortFromHTTPRequest(r) return RequestInfo{ ServiceLabel: apiutil.GetRouteName(r), Method: fmt.Sprintf("%s/%s:%s", r.Proto, r.Method, r.URL.Path), Component: apiutil.GetComponentNameOnHTTP(r), - IP: apiutil.GetIPAddrFromHTTPRequest(r), + IP: ip, + Port: port, URLParam: getURLParam(r), BodyParam: getBodyParam(r), StartTimeStamp: time.Now().Unix(), diff --git a/server/api/failpoint.go b/server/api/failpoint.go new file mode 100644 index 00000000000..05c3850fcd8 --- /dev/null +++ b/server/api/failpoint.go @@ -0,0 +1,22 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build with_fail +// +build with_fail + +package api + +func init() { + enableFailPointAPI = true +} diff --git a/server/api/min_resolved_ts.go b/server/api/min_resolved_ts.go index 0d30ea3395e..ef05e91b9f7 100644 --- a/server/api/min_resolved_ts.go +++ b/server/api/min_resolved_ts.go @@ -17,6 +17,7 @@ package api import ( "net/http" "strconv" + "strings" "github.com/gorilla/mux" "github.com/tikv/pd/pkg/utils/typeutil" @@ -38,17 +39,18 @@ func newMinResolvedTSHandler(svr *server.Server, rd *render.Render) *minResolved // NOTE: This type is exported by HTTP API. Please pay more attention when modifying it. type minResolvedTS struct { - IsRealTime bool `json:"is_real_time,omitempty"` - MinResolvedTS uint64 `json:"min_resolved_ts"` - PersistInterval typeutil.Duration `json:"persist_interval,omitempty"` + IsRealTime bool `json:"is_real_time,omitempty"` + MinResolvedTS uint64 `json:"min_resolved_ts"` + PersistInterval typeutil.Duration `json:"persist_interval,omitempty"` + StoresMinResolvedTS map[uint64]uint64 `json:"stores_min_resolved_ts"` } // @Tags min_store_resolved_ts // @Summary Get store-level min resolved ts. -// @Produce json -// @Success 200 {array} minResolvedTS +// @Produce json +// @Success 200 {array} minResolvedTS // @Failure 400 {string} string "The input is invalid." -// @Failure 500 {string} string "PD server failed to proceed the request." +// @Failure 500 {string} string "PD server failed to proceed the request." // @Router /min-resolved-ts/{store_id} [get] func (h *minResolvedTSHandler) GetStoreMinResolvedTS(w http.ResponseWriter, r *http.Request) { c := h.svr.GetRaftCluster() @@ -67,19 +69,59 @@ func (h *minResolvedTSHandler) GetStoreMinResolvedTS(w http.ResponseWriter, r *h }) } -// @Tags min_resolved_ts -// @Summary Get cluster-level min resolved ts. +// @Tags min_resolved_ts +// @Summary Get cluster-level min resolved ts and optionally store-level min resolved ts. +// @Description Optionally, we support a query parameter `scope` +// to get store-level min resolved ts by specifying a list of store IDs. +// - When no scope is given, cluster-level's min_resolved_ts will be returned and storesMinResolvedTS will be nil. +// - When scope is `cluster`, cluster-level's min_resolved_ts will be returned and storesMinResolvedTS will be filled. +// - When scope given a list of stores, min_resolved_ts will be provided for each store +// and the scope-specific min_resolved_ts will be returned. +// // @Produce json +// @Param scope query string false "Scope of the min resolved ts: comma-separated list of store IDs (e.g., '1,2,3')." default(cluster) // @Success 200 {array} minResolvedTS // @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /min-resolved-ts [get] +// @Router /min-resolved-ts [get] func (h *minResolvedTSHandler) GetMinResolvedTS(w http.ResponseWriter, r *http.Request) { c := h.svr.GetRaftCluster() - value := c.GetMinResolvedTS() + scopeMinResolvedTS := c.GetMinResolvedTS() persistInterval := c.GetPDServerConfig().MinResolvedTSPersistenceInterval + + var storesMinResolvedTS map[uint64]uint64 + if scopeStr := r.URL.Query().Get("scope"); len(scopeStr) > 0 { + // scope is an optional parameter, it can be `cluster` or specified store IDs. + // - When no scope is given, cluster-level's min_resolved_ts will be returned and storesMinResolvedTS will be nil. + // - When scope is `cluster`, cluster-level's min_resolved_ts will be returned and storesMinResolvedTS will be filled. + // - When scope given a list of stores, min_resolved_ts will be provided for each store + // and the scope-specific min_resolved_ts will be returned. + if scopeStr == "cluster" { + stores := c.GetMetaStores() + ids := make([]uint64, len(stores)) + for i, store := range stores { + ids[i] = store.GetId() + } + // use cluster-level min_resolved_ts as the scope-specific min_resolved_ts. + _, storesMinResolvedTS = c.GetMinResolvedTSByStoreIDs(ids) + } else { + scopeIDs := strings.Split(scopeStr, ",") + ids := make([]uint64, len(scopeIDs)) + for i, idStr := range scopeIDs { + id, err := strconv.ParseUint(idStr, 10, 64) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + ids[i] = id + } + scopeMinResolvedTS, storesMinResolvedTS = c.GetMinResolvedTSByStoreIDs(ids) + } + } + h.rd.JSON(w, http.StatusOK, minResolvedTS{ - MinResolvedTS: value, - PersistInterval: persistInterval, - IsRealTime: persistInterval.Duration != 0, + MinResolvedTS: scopeMinResolvedTS, + PersistInterval: persistInterval, + IsRealTime: persistInterval.Duration != 0, + StoresMinResolvedTS: storesMinResolvedTS, }) } diff --git a/server/api/min_resolved_ts_test.go b/server/api/min_resolved_ts_test.go index 79ab71e2be1..3abc7555919 100644 --- a/server/api/min_resolved_ts_test.go +++ b/server/api/min_resolved_ts_test.go @@ -17,6 +17,8 @@ package api import ( "fmt" "reflect" + "strconv" + "strings" "testing" "time" @@ -36,6 +38,7 @@ type minResolvedTSTestSuite struct { cleanup testutil.CleanupFunc url string defaultInterval time.Duration + storesNum int } func TestMinResolvedTSTestSuite(t *testing.T) { @@ -53,11 +56,13 @@ func (suite *minResolvedTSTestSuite) SetupSuite() { suite.url = fmt.Sprintf("%s%s/api/v1/min-resolved-ts", addr, apiPrefix) mustBootstrapCluster(re, suite.svr) - mustPutStore(re, suite.svr, 1, metapb.StoreState_Up, metapb.NodeState_Serving, nil) - r1 := core.NewTestRegionInfo(7, 1, []byte("a"), []byte("b")) - mustRegionHeartbeat(re, suite.svr, r1) - r2 := core.NewTestRegionInfo(8, 1, []byte("b"), []byte("c")) - mustRegionHeartbeat(re, suite.svr, r2) + suite.storesNum = 3 + for i := 1; i <= suite.storesNum; i++ { + id := uint64(i) + mustPutStore(re, suite.svr, id, metapb.StoreState_Up, metapb.NodeState_Serving, nil) + r := core.NewTestRegionInfo(id, id, []byte(fmt.Sprintf("%da", id)), []byte(fmt.Sprintf("%db", id))) + mustRegionHeartbeat(re, suite.svr, r) + } } func (suite *minResolvedTSTestSuite) TearDownSuite() { @@ -92,9 +97,8 @@ func (suite *minResolvedTSTestSuite) TestMinResolvedTS() { PersistInterval: interval, }) // case4: set min resolved ts - rc := suite.svr.GetRaftCluster() ts := uint64(233) - rc.SetMinResolvedTS(1, ts) + suite.setAllStoresMinResolvedTS(ts) suite.checkMinResolvedTS(&minResolvedTS{ MinResolvedTS: ts, IsRealTime: true, @@ -108,7 +112,7 @@ func (suite *minResolvedTSTestSuite) TestMinResolvedTS() { IsRealTime: false, PersistInterval: interval, }) - rc.SetMinResolvedTS(1, ts+1) + suite.setAllStoresMinResolvedTS(ts) suite.checkMinResolvedTS(&minResolvedTS{ MinResolvedTS: ts, // last persist value IsRealTime: false, @@ -116,12 +120,69 @@ func (suite *minResolvedTSTestSuite) TestMinResolvedTS() { }) } +func (suite *minResolvedTSTestSuite) TestMinResolvedTSByStores() { + // run job. + interval := typeutil.Duration{Duration: suite.defaultInterval} + suite.setMinResolvedTSPersistenceInterval(interval) + suite.Eventually(func() bool { + return interval == suite.svr.GetRaftCluster().GetPDServerConfig().MinResolvedTSPersistenceInterval + }, time.Second*10, time.Millisecond*20) + // set min resolved ts. + rc := suite.svr.GetRaftCluster() + ts := uint64(233) + + // scope is `cluster` + testStoresID := make([]string, 0) + testMap := make(map[uint64]uint64) + for i := 1; i <= suite.storesNum; i++ { + storeID := uint64(i) + testTS := ts + storeID + testMap[storeID] = testTS + rc.SetMinResolvedTS(storeID, testTS) + + testStoresID = append(testStoresID, strconv.Itoa(i)) + } + suite.checkMinResolvedTSByStores(&minResolvedTS{ + MinResolvedTS: 234, + IsRealTime: true, + PersistInterval: interval, + StoresMinResolvedTS: testMap, + }, "cluster") + + // set all stores min resolved ts. + testStoresIDStr := strings.Join(testStoresID, ",") + suite.checkMinResolvedTSByStores(&minResolvedTS{ + MinResolvedTS: 234, + IsRealTime: true, + PersistInterval: interval, + StoresMinResolvedTS: testMap, + }, testStoresIDStr) + + // remove last store for test. + testStoresID = testStoresID[:len(testStoresID)-1] + testStoresIDStr = strings.Join(testStoresID, ",") + delete(testMap, uint64(suite.storesNum)) + suite.checkMinResolvedTSByStores(&minResolvedTS{ + MinResolvedTS: 234, + IsRealTime: true, + PersistInterval: interval, + StoresMinResolvedTS: testMap, + }, testStoresIDStr) +} + func (suite *minResolvedTSTestSuite) setMinResolvedTSPersistenceInterval(duration typeutil.Duration) { cfg := suite.svr.GetRaftCluster().GetPDServerConfig().Clone() cfg.MinResolvedTSPersistenceInterval = duration suite.svr.GetRaftCluster().SetPDServerConfig(cfg) } +func (suite *minResolvedTSTestSuite) setAllStoresMinResolvedTS(ts uint64) { + rc := suite.svr.GetRaftCluster() + for i := 1; i <= suite.storesNum; i++ { + rc.SetMinResolvedTS(uint64(i), ts) + } +} + func (suite *minResolvedTSTestSuite) checkMinResolvedTS(expect *minResolvedTS) { suite.Eventually(func() bool { res, err := testDialClient.Get(suite.url) @@ -130,6 +191,20 @@ func (suite *minResolvedTSTestSuite) checkMinResolvedTS(expect *minResolvedTS) { listResp := &minResolvedTS{} err = apiutil.ReadJSON(res.Body, listResp) suite.NoError(err) + suite.Nil(listResp.StoresMinResolvedTS) + return reflect.DeepEqual(expect, listResp) + }, time.Second*10, time.Millisecond*20) +} + +func (suite *minResolvedTSTestSuite) checkMinResolvedTSByStores(expect *minResolvedTS, scope string) { + suite.Eventually(func() bool { + url := fmt.Sprintf("%s?scope=%s", suite.url, scope) + res, err := testDialClient.Get(url) + suite.NoError(err) + defer res.Body.Close() + listResp := &minResolvedTS{} + err = apiutil.ReadJSON(res.Body, listResp) + suite.NoError(err) return reflect.DeepEqual(expect, listResp) }, time.Second*10, time.Millisecond*20) } diff --git a/server/api/router.go b/server/api/router.go index ea99e82cdd8..ce649ba9aef 100644 --- a/server/api/router.go +++ b/server/api/router.go @@ -31,6 +31,9 @@ import ( "github.com/unrolled/render" ) +// enableFailPointAPI enable fail point API handler. +var enableFailPointAPI bool + // createRouteOption is used to register service for mux.Route type createRouteOption func(route *mux.Route) @@ -365,15 +368,13 @@ func createRouter(prefix string, svr *server.Server) *mux.Router { registerFunc(apiRouter, "/admin/reset-ts", tsoAdminHandler.ResetTS, setMethods(http.MethodPost), setAuditBackend(localLog, prometheus)) // API to set or unset failpoints - failpoint.Inject("enableFailpointAPI", func() { - // this function will be named to "func2". It may be used in test + if enableFailPointAPI { registerPrefix(apiRouter, "/fail", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // The HTTP handler of failpoint requires the full path to be the failpoint path. r.URL.Path = strings.TrimPrefix(r.URL.Path, prefix+apiPrefix+"/fail") new(failpoint.HttpHandler).ServeHTTP(w, r) - }), setAuditBackend("test")) - }) - + }), setAuditBackend(localLog)) + } // Deprecated: use /pd/api/v1/health instead. rootRouter.HandleFunc("/health", healthHandler.GetHealthStatus).Methods(http.MethodGet) // Deprecated: use /pd/api/v1/ping instead. diff --git a/server/apiv2/middlewares/redirector.go b/server/apiv2/middlewares/redirector.go index 5539dd089dc..285f096e823 100644 --- a/server/apiv2/middlewares/redirector.go +++ b/server/apiv2/middlewares/redirector.go @@ -22,7 +22,6 @@ import ( "github.com/pingcap/log" "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/utils/apiutil" - "github.com/tikv/pd/pkg/utils/apiutil/serverapi" "github.com/tikv/pd/server" "go.uber.org/zap" ) @@ -31,7 +30,7 @@ import ( func Redirector() gin.HandlerFunc { return func(c *gin.Context) { svr := c.MustGet(ServerContextKey).(*server.Server) - allowFollowerHandle := len(c.Request.Header.Get(serverapi.PDAllowFollowerHandle)) > 0 + allowFollowerHandle := len(c.Request.Header.Get(apiutil.PDAllowFollowerHandleHeader)) > 0 isLeader := svr.GetMember().IsLeader() if !svr.IsClosed() && (allowFollowerHandle || isLeader) { c.Next() @@ -39,13 +38,13 @@ func Redirector() gin.HandlerFunc { } // Prevent more than one redirection. - if name := c.Request.Header.Get(serverapi.PDRedirectorHeader); len(name) != 0 { + if name := c.Request.Header.Get(apiutil.PDRedirectorHeader); len(name) != 0 { log.Error("redirect but server is not leader", zap.String("from", name), zap.String("server", svr.Name()), errs.ZapError(errs.ErrRedirect)) c.AbortWithStatusJSON(http.StatusInternalServerError, errs.ErrRedirect.FastGenByArgs().Error()) return } - c.Request.Header.Set(serverapi.PDRedirectorHeader, svr.Name()) + c.Request.Header.Set(apiutil.PDRedirectorHeader, svr.Name()) leader := svr.GetMember().GetLeader() if leader == nil { diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go index 06de6f9a56e..35bf14b617a 100644 --- a/server/cluster/cluster.go +++ b/server/cluster/cluster.go @@ -2548,10 +2548,29 @@ func (c *RaftCluster) GetMinResolvedTS() uint64 { func (c *RaftCluster) GetStoreMinResolvedTS(storeID uint64) uint64 { c.RLock() defer c.RUnlock() - if !c.isInitialized() || !core.IsAvailableForMinResolvedTS(c.GetStore(storeID)) { + store := c.GetStore(storeID) + if store == nil { + return math.MaxUint64 + } + if !c.isInitialized() || !core.IsAvailableForMinResolvedTS(store) { return math.MaxUint64 } - return c.GetStore(storeID).GetMinResolvedTS() + return store.GetMinResolvedTS() +} + +// GetMinResolvedTSByStoreIDs returns the min_resolved_ts for each store +// and returns the min_resolved_ts for all given store lists. +func (c *RaftCluster) GetMinResolvedTSByStoreIDs(ids []uint64) (uint64, map[uint64]uint64) { + minResolvedTS := uint64(math.MaxUint64) + storesMinResolvedTS := make(map[uint64]uint64) + for _, storeID := range ids { + storeTS := c.GetStoreMinResolvedTS(storeID) + storesMinResolvedTS[storeID] = storeTS + if minResolvedTS > storeTS { + minResolvedTS = storeTS + } + } + return minResolvedTS, storesMinResolvedTS } // GetExternalTS returns the external timestamp. diff --git a/server/server.go b/server/server.go index 1fdcd8497f0..c03ebbc17b4 100644 --- a/server/server.go +++ b/server/server.go @@ -1744,7 +1744,7 @@ func (s *Server) ReplicateFileToMember(ctx context.Context, member *pdpb.Member, } url := clientUrls[0] + filepath.Join("/pd/api/v1/admin/persist-file", name) req, _ := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(data)) - req.Header.Set("PD-Allow-follower-handle", "true") + req.Header.Set(apiutil.PDAllowFollowerHandleHeader, "true") res, err := s.httpClient.Do(req) if err != nil { log.Warn("failed to replicate file", zap.String("name", name), zap.String("member", member.GetName()), errs.ZapError(err)) diff --git a/server/server_test.go b/server/server_test.go index 47ec2dd735c..2d0e23c7682 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -24,6 +24,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/mcs/utils" + "github.com/tikv/pd/pkg/utils/apiutil" "github.com/tikv/pd/pkg/utils/assertutil" "github.com/tikv/pd/pkg/utils/etcdutil" "github.com/tikv/pd/pkg/utils/testutil" @@ -218,7 +219,7 @@ func (suite *leaderServerTestSuite) TestSourceIpForHeaderForwarded() { req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/apis/mock/v1/hello", svr.GetAddr()), nil) suite.NoError(err) - req.Header.Add("X-Forwarded-For", "127.0.0.2") + req.Header.Add(apiutil.XForwardedForHeader, "127.0.0.2") resp, err := http.DefaultClient.Do(req) suite.NoError(err) suite.Equal(http.StatusOK, resp.StatusCode) @@ -248,7 +249,7 @@ func (suite *leaderServerTestSuite) TestSourceIpForHeaderXReal() { req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/apis/mock/v1/hello", svr.GetAddr()), nil) suite.NoError(err) - req.Header.Add("X-Real-Ip", "127.0.0.2") + req.Header.Add(apiutil.XRealIPHeader, "127.0.0.2") resp, err := http.DefaultClient.Do(req) suite.NoError(err) suite.Equal(http.StatusOK, resp.StatusCode) @@ -278,8 +279,8 @@ func (suite *leaderServerTestSuite) TestSourceIpForHeaderBoth() { req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/apis/mock/v1/hello", svr.GetAddr()), nil) suite.NoError(err) - req.Header.Add("X-Forwarded-For", "127.0.0.2") - req.Header.Add("X-Real-Ip", "127.0.0.3") + req.Header.Add(apiutil.XForwardedForHeader, "127.0.0.2") + req.Header.Add(apiutil.XRealIPHeader, "127.0.0.3") resp, err := http.DefaultClient.Do(req) suite.NoError(err) suite.Equal(http.StatusOK, resp.StatusCode) diff --git a/server/testutil.go b/server/testutil.go index 506139e20f1..cc1a380bfb8 100644 --- a/server/testutil.go +++ b/server/testutil.go @@ -143,7 +143,7 @@ func CreateMockHandler(re *require.Assertions, ip string) HandlerBuilder { mux.HandleFunc("/pd/apis/mock/v1/hello", func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "Hello World") // test getting ip - clientIP := apiutil.GetIPAddrFromHTTPRequest(r) + clientIP, _ := apiutil.GetIPPortFromHTTPRequest(r) re.Equal(ip, clientIP) }) info := apiutil.APIServiceGroup{ diff --git a/tests/integrations/mcs/scheduling/config_test.go b/tests/integrations/mcs/scheduling/config_test.go index e2f2eeacfd1..e1c124b965f 100644 --- a/tests/integrations/mcs/scheduling/config_test.go +++ b/tests/integrations/mcs/scheduling/config_test.go @@ -22,7 +22,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/mcs/scheduling/server/config" sc "github.com/tikv/pd/pkg/schedule/config" - "github.com/tikv/pd/pkg/storage/endpoint" + "github.com/tikv/pd/pkg/schedule/schedulers" "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/pkg/versioninfo" "github.com/tikv/pd/tests" @@ -47,6 +47,7 @@ func TestConfig(t *testing.T) { func (suite *configTestSuite) SetupSuite() { re := suite.Require() + schedulers.Register() var err error suite.ctx, suite.cancel = context.WithCancel(context.Background()) suite.cluster, err = tests.NewTestAPICluster(suite.ctx, 1) @@ -56,6 +57,8 @@ func (suite *configTestSuite) SetupSuite() { leaderName := suite.cluster.WaitLeader() suite.pdLeaderServer = suite.cluster.GetServer(leaderName) re.NoError(suite.pdLeaderServer.BootstrapCluster()) + // Force the coordinator to be prepared to initialize the schedulers. + suite.pdLeaderServer.GetRaftCluster().GetCoordinator().GetPrepareChecker().SetPrepared() } func (suite *configTestSuite) TearDownSuite() { @@ -72,7 +75,7 @@ func (suite *configTestSuite) TestConfigWatch() { watcher, err := config.NewWatcher( suite.ctx, suite.pdLeaderServer.GetEtcdClient(), - endpoint.ConfigPath(suite.cluster.GetCluster().GetId()), + suite.cluster.GetCluster().GetId(), config.NewPersistConfig(config.NewConfig()), ) re.NoError(err) @@ -118,3 +121,48 @@ func persistConfig(re *require.Assertions, pdLeaderServer *tests.TestServer) { err := pdLeaderServer.GetPersistOptions().Persist(pdLeaderServer.GetServer().GetStorage()) re.NoError(err) } + +func (suite *configTestSuite) TestSchedulerConfigWatch() { + re := suite.Require() + + // Make sure the config is persisted before the watcher is created. + persistConfig(re, suite.pdLeaderServer) + // Create a config watcher. + watcher, err := config.NewWatcher( + suite.ctx, + suite.pdLeaderServer.GetEtcdClient(), + suite.cluster.GetCluster().GetId(), + config.NewPersistConfig(config.NewConfig()), + ) + re.NoError(err) + // Get all default scheduler names. + var ( + schedulerNames []string + schedulerController = suite.pdLeaderServer.GetRaftCluster().GetCoordinator().GetSchedulersController() + ) + testutil.Eventually(re, func() bool { + schedulerNames = schedulerController.GetSchedulerNames() + return len(schedulerNames) == len(sc.DefaultSchedulers) + }) + // Check all default schedulers' configs. + for _, schedulerName := range schedulerNames { + testutil.Eventually(re, func() bool { + return len(watcher.GetSchedulerConfig(schedulerName)) > 0 + }) + } + // Add a new scheduler. + err = suite.pdLeaderServer.GetServer().GetHandler().AddEvictLeaderScheduler(1) + re.NoError(err) + // Check the new scheduler's config. + testutil.Eventually(re, func() bool { + return len(watcher.GetSchedulerConfig(schedulers.EvictLeaderName)) > 0 + }) + // Remove the scheduler. + err = suite.pdLeaderServer.GetServer().GetHandler().RemoveScheduler(schedulers.EvictLeaderName) + re.NoError(err) + // Check the removed scheduler's config. + testutil.Eventually(re, func() bool { + return len(watcher.GetSchedulerConfig(schedulers.EvictLeaderName)) == 0 + }) + watcher.Close() +} diff --git a/tests/integrations/mcs/scheduling/rule_test.go b/tests/integrations/mcs/scheduling/rule_test.go index bdffb6b2bb9..68347366378 100644 --- a/tests/integrations/mcs/scheduling/rule_test.go +++ b/tests/integrations/mcs/scheduling/rule_test.go @@ -100,13 +100,10 @@ func (suite *ruleTestSuite) TestRuleWatch() { re := suite.Require() // Create a rule watcher. - clusterID := suite.cluster.GetCluster().GetId() watcher, err := rule.NewWatcher( suite.ctx, suite.pdLeaderServer.GetEtcdClient(), - endpoint.RulesPath(clusterID), - endpoint.RuleGroupPath(clusterID), - endpoint.RegionLabelPath(clusterID), + suite.cluster.GetCluster().GetId(), ) re.NoError(err) ruleStorage := watcher.GetRuleStorage() diff --git a/tests/server/api/api_test.go b/tests/server/api/api_test.go index 080eeb44b1b..375a0cf7c80 100644 --- a/tests/server/api/api_test.go +++ b/tests/server/api/api_test.go @@ -34,7 +34,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/core" - "github.com/tikv/pd/pkg/utils/apiutil/serverapi" + "github.com/tikv/pd/pkg/utils/apiutil" "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/pkg/utils/typeutil" "github.com/tikv/pd/server" @@ -621,10 +621,10 @@ func (suite *redirectorTestSuite) TestAllowFollowerHandle() { addr := follower.GetAddr() + "/pd/api/v1/version" request, err := http.NewRequest(http.MethodGet, addr, nil) suite.NoError(err) - request.Header.Add(serverapi.PDAllowFollowerHandle, "true") + request.Header.Add(apiutil.PDAllowFollowerHandleHeader, "true") resp, err := dialClient.Do(request) suite.NoError(err) - suite.Equal("", resp.Header.Get(serverapi.PDRedirectorHeader)) + suite.Equal("", resp.Header.Get(apiutil.PDRedirectorHeader)) defer resp.Body.Close() suite.Equal(http.StatusOK, resp.StatusCode) _, err = io.ReadAll(resp.Body) @@ -655,7 +655,7 @@ func (suite *redirectorTestSuite) TestNotLeader() { // Request to follower with redirectorHeader will fail. request.RequestURI = "" - request.Header.Set(serverapi.PDRedirectorHeader, "pd") + request.Header.Set(apiutil.PDRedirectorHeader, "pd") resp1, err := dialClient.Do(request) suite.NoError(err) defer resp1.Body.Close() diff --git a/tests/server/cluster/cluster_test.go b/tests/server/cluster/cluster_test.go index c15520aca3c..87acdf897fd 100644 --- a/tests/server/cluster/cluster_test.go +++ b/tests/server/cluster/cluster_test.go @@ -1301,6 +1301,13 @@ func checkMinResolvedTS(re *require.Assertions, rc *cluster.RaftCluster, expect }, time.Second*10, time.Millisecond*50) } +func checkStoreMinResolvedTS(re *require.Assertions, rc *cluster.RaftCluster, expectTS, storeID uint64) { + re.Eventually(func() bool { + ts := rc.GetStoreMinResolvedTS(storeID) + return expectTS == ts + }, time.Second*10, time.Millisecond*50) +} + func checkMinResolvedTSFromStorage(re *require.Assertions, rc *cluster.RaftCluster, expect uint64) { re.Eventually(func() bool { ts2, err := rc.GetStorage().LoadMinResolvedTS() @@ -1400,6 +1407,9 @@ func TestMinResolvedTS(t *testing.T) { resetStoreState(re, rc, store1, metapb.StoreState_Tombstone) checkMinResolvedTS(re, rc, store3TS) checkMinResolvedTSFromStorage(re, rc, store3TS) + checkStoreMinResolvedTS(re, rc, store3TS, store3) + // check no-exist store + checkStoreMinResolvedTS(re, rc, math.MaxUint64, 100) // case7: add a store with leader peer but no report min resolved ts // min resolved ts should be no change @@ -1419,6 +1429,7 @@ func TestMinResolvedTS(t *testing.T) { checkMinResolvedTS(re, rc, store3TS) setMinResolvedTSPersistenceInterval(re, rc, svr, time.Millisecond) checkMinResolvedTS(re, rc, store5TS) + checkStoreMinResolvedTS(re, rc, store5TS, store5) } // See https://github.com/tikv/pd/issues/4941 diff --git a/tools/pd-api-bench/README.md b/tools/pd-api-bench/README.md index 13b7feb6b25..0ab4ea6463b 100644 --- a/tools/pd-api-bench/README.md +++ b/tools/pd-api-bench/README.md @@ -59,11 +59,22 @@ The api bench cases we support are as follows: -debug > print the output of api response for debug +### Run Shell + You can run shell as follows. ```shell go run main.go -http-cases GetRegionStatus-1+1,GetMinResolvedTS-1+1 -client 1 -debug ``` +### HTTP params + +You can use the following command to set the params of HTTP request: +```shell +go run main.go -http-cases GetMinResolvedTS-1+1 -params 'scope=cluster' -client 1 -debug +``` +for more params, can use like `-params 'A=1&B=2&C=3'` + + ### TLS You can use the following command to generate a certificate for testing TLS: @@ -74,4 +85,4 @@ mkdir cert go run main.go -http-cases GetRegionStatus-1+1,GetMinResolvedTS-1+1 -client 1 -debug -cacert ./cert/ca.pem -cert ./cert/pd-server.pem -key ./cert/pd-server-key.pem ./cert_opt.sh cleanup cert rm -rf cert -``` \ No newline at end of file +``` diff --git a/tools/pd-api-bench/cases/cases.go b/tools/pd-api-bench/cases/cases.go index a3154f1462b..2b770805cd8 100644 --- a/tools/pd-api-bench/cases/cases.go +++ b/tools/pd-api-bench/cases/cases.go @@ -116,6 +116,7 @@ var GRPCCaseMap = map[string]GRPCCase{ type HTTPCase interface { Case Do(context.Context, *http.Client) error + Params(string) } var HTTPCaseMap = map[string]HTTPCase{ @@ -125,7 +126,8 @@ var HTTPCaseMap = map[string]HTTPCase{ type minResolvedTS struct { *baseCase - path string + path string + params string } func newMinResolvedTS() *minResolvedTS { @@ -140,14 +142,15 @@ func newMinResolvedTS() *minResolvedTS { } type minResolvedTSStruct struct { - IsRealTime bool `json:"is_real_time,omitempty"` - MinResolvedTS uint64 `json:"min_resolved_ts"` - PersistInterval typeutil.Duration `json:"persist_interval,omitempty"` + IsRealTime bool `json:"is_real_time,omitempty"` + MinResolvedTS uint64 `json:"min_resolved_ts"` + PersistInterval typeutil.Duration `json:"persist_interval,omitempty"` + StoresMinResolvedTS map[uint64]uint64 `json:"stores_min_resolved_ts"` } func (c *minResolvedTS) Do(ctx context.Context, cli *http.Client) error { - storeIdx := rand.Intn(int(totalStore)) - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s/%d", PDAddress, c.path, storesID[storeIdx]), nil) + url := fmt.Sprintf("%s%s", PDAddress, c.path) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) res, err := cli.Do(req) if err != nil { return err @@ -155,7 +158,7 @@ func (c *minResolvedTS) Do(ctx context.Context, cli *http.Client) error { listResp := &minResolvedTSStruct{} err = apiutil.ReadJSON(res.Body, listResp) if Debug { - log.Printf("Do %s: %v %v", c.name, listResp, err) + log.Printf("Do %s: url: %s resp: %v err: %v", c.name, url, listResp, err) } if err != nil { return err @@ -164,6 +167,11 @@ func (c *minResolvedTS) Do(ctx context.Context, cli *http.Client) error { return nil } +func (c *minResolvedTS) Params(param string) { + c.params = param + c.path = fmt.Sprintf("%s?%s", c.path, c.params) +} + type regionsStats struct { *baseCase regionSample int @@ -183,20 +191,20 @@ func newRegionStats() *regionsStats { } func (c *regionsStats) Do(ctx context.Context, cli *http.Client) error { - upperBound := int(totalRegion) / c.regionSample + upperBound := totalRegion / c.regionSample if upperBound < 1 { upperBound = 1 } random := rand.Intn(upperBound) startID := c.regionSample*random*4 + 1 endID := c.regionSample*(random+1)*4 + 1 - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s?start_key=%s&end_key=%s&%s", + url := fmt.Sprintf("%s%s?start_key=%s&end_key=%s&%s", PDAddress, c.path, url.QueryEscape(string(generateKeyForSimulator(startID, 56))), url.QueryEscape(string(generateKeyForSimulator(endID, 56))), - "", - ), nil) + "") + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) res, err := cli.Do(req) if err != nil { return err @@ -204,7 +212,7 @@ func (c *regionsStats) Do(ctx context.Context, cli *http.Client) error { statsResp := &statistics.RegionStats{} err = apiutil.ReadJSON(res.Body, statsResp) if Debug { - log.Printf("Do %s: %v %v", c.name, statsResp, err) + log.Printf("Do %s: url: %s resp: %v err: %v", c.name, url, statsResp, err) } if err != nil { return err @@ -213,6 +221,8 @@ func (c *regionsStats) Do(ctx context.Context, cli *http.Client) error { return nil } +func (c *regionsStats) Params(_ string) {} + type getRegion struct { *baseCase } @@ -228,7 +238,7 @@ func newGetRegion() *getRegion { } func (c *getRegion) Unary(ctx context.Context, cli pd.Client) error { - id := rand.Intn(int(totalRegion))*4 + 1 + id := rand.Intn(totalRegion)*4 + 1 _, err := cli.GetRegion(ctx, generateKeyForSimulator(id, 56)) if err != nil { return err @@ -253,7 +263,7 @@ func newScanRegions() *scanRegions { } func (c *scanRegions) Unary(ctx context.Context, cli pd.Client) error { - upperBound := int(totalRegion) / c.regionSample + upperBound := totalRegion / c.regionSample random := rand.Intn(upperBound) startID := c.regionSample*random*4 + 1 endID := c.regionSample*(random+1)*4 + 1 @@ -279,7 +289,7 @@ func newGetStore() *getStore { } func (c *getStore) Unary(ctx context.Context, cli pd.Client) error { - storeIdx := rand.Intn(int(totalStore)) + storeIdx := rand.Intn(totalStore) _, err := cli.GetStore(ctx, storesID[storeIdx]) if err != nil { return err diff --git a/tools/pd-api-bench/main.go b/tools/pd-api-bench/main.go index 7032ef1df00..a891f7d2318 100644 --- a/tools/pd-api-bench/main.go +++ b/tools/pd-api-bench/main.go @@ -48,13 +48,16 @@ var ( qps = flag.Int64("qps", 1000, "qps") burst = flag.Int64("burst", 1, "burst") + // http params + httpParams = flag.String("params", "", "http params") + // tls caPath = flag.String("cacert", "", "path of file that contains list of trusted SSL CAs") certPath = flag.String("cert", "", "path of file that contains X509 certificate in PEM format") keyPath = flag.String("key", "", "path of file that contains X509 key in PEM format") ) -var base int64 = int64(time.Second) / int64(time.Microsecond) +var base = int64(time.Second) / int64(time.Microsecond) func main() { flag.Parse() @@ -216,6 +219,9 @@ func handleHTTPCase(ctx context.Context, hcase cases.HTTPCase, httpClis []*http. burst := hcase.GetBurst() tt := time.Duration(base/qps*burst*int64(*client)) * time.Microsecond log.Printf("begin to run http case %s, with qps = %d and burst = %d, interval is %v", hcase.Name(), qps, burst, tt) + if *httpParams != "" { + hcase.Params(*httpParams) + } for _, hCli := range httpClis { go func(hCli *http.Client) { var ticker = time.NewTicker(tt) diff --git a/tools/pd-ctl/pdctl/command/log_command.go b/tools/pd-ctl/pdctl/command/log_command.go index ec22884ecad..56c4438a6c3 100644 --- a/tools/pd-ctl/pdctl/command/log_command.go +++ b/tools/pd-ctl/pdctl/command/log_command.go @@ -20,6 +20,7 @@ import ( "net/http" "github.com/spf13/cobra" + "github.com/tikv/pd/pkg/utils/apiutil" ) var ( @@ -55,7 +56,7 @@ func logCommandFunc(cmd *cobra.Command, args []string) { cmd.Printf("Failed to parse address %v: %s\n", args[1], err) return } - _, err = doRequestSingleEndpoint(cmd, url, logPrefix, http.MethodPost, http.Header{"Content-Type": {"application/json"}, "PD-Allow-follower-handle": {"true"}}, + _, err = doRequestSingleEndpoint(cmd, url, logPrefix, http.MethodPost, http.Header{"Content-Type": {"application/json"}, apiutil.PDAllowFollowerHandleHeader: {"true"}}, WithBody(bytes.NewBuffer(data))) if err != nil { cmd.Printf("Failed to set %v log level: %s\n", args[1], err)