From 5b3d0172b0db653a97257373fac9ce9d9699c6f5 Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Thu, 21 Sep 2023 14:24:13 +0800 Subject: [PATCH 1/8] *: fix sync isolation level to default placement rule (#7122) close tikv/pd#7121 Signed-off-by: Ryan Leung Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- pkg/mcs/scheduling/server/cluster.go | 2 +- pkg/mock/mockcluster/mockcluster.go | 2 +- pkg/schedule/checker/rule_checker_test.go | 35 ++++++++++++++++++ pkg/schedule/placement/rule_manager.go | 5 ++- pkg/schedule/placement/rule_manager_test.go | 6 ++-- pkg/statistics/region_collection_test.go | 4 +-- server/api/operator_test.go | 4 ++- server/cluster/cluster.go | 2 +- server/cluster/cluster_test.go | 10 +++--- server/config/persist_options.go | 7 ++++ server/server.go | 11 +++--- tests/pdctl/config/config_test.go | 40 ++++++++++++++++++--- 12 files changed, 104 insertions(+), 24 deletions(-) diff --git a/pkg/mcs/scheduling/server/cluster.go b/pkg/mcs/scheduling/server/cluster.go index 81c82d73d33b..b2986f722df7 100644 --- a/pkg/mcs/scheduling/server/cluster.go +++ b/pkg/mcs/scheduling/server/cluster.go @@ -75,7 +75,7 @@ func NewCluster(parentCtx context.Context, persistConfig *config.PersistConfig, checkMembershipCh: checkMembershipCh, } c.coordinator = schedule.NewCoordinator(ctx, c, hbStreams) - err = c.ruleManager.Initialize(persistConfig.GetMaxReplicas(), persistConfig.GetLocationLabels()) + err = c.ruleManager.Initialize(persistConfig.GetMaxReplicas(), persistConfig.GetLocationLabels(), persistConfig.GetIsolationLevel()) if err != nil { cancel() return nil, err diff --git a/pkg/mock/mockcluster/mockcluster.go b/pkg/mock/mockcluster/mockcluster.go index 1ed7ab4eb9ff..01282b405346 100644 --- a/pkg/mock/mockcluster/mockcluster.go +++ b/pkg/mock/mockcluster/mockcluster.go @@ -213,7 +213,7 @@ func (mc *Cluster) AllocPeer(storeID uint64) (*metapb.Peer, error) { func (mc *Cluster) initRuleManager() { if mc.RuleManager == nil { mc.RuleManager = placement.NewRuleManager(mc.GetStorage(), mc, mc.GetSharedConfig()) - mc.RuleManager.Initialize(int(mc.GetReplicationConfig().MaxReplicas), mc.GetReplicationConfig().LocationLabels) + mc.RuleManager.Initialize(int(mc.GetReplicationConfig().MaxReplicas), mc.GetReplicationConfig().LocationLabels, mc.GetReplicationConfig().IsolationLevel) } } diff --git a/pkg/schedule/checker/rule_checker_test.go b/pkg/schedule/checker/rule_checker_test.go index cbd7624f3b16..ad140e91606b 100644 --- a/pkg/schedule/checker/rule_checker_test.go +++ b/pkg/schedule/checker/rule_checker_test.go @@ -112,6 +112,41 @@ func (suite *ruleCheckerTestSuite) TestAddRulePeerWithIsolationLevel() { suite.Equal(uint64(4), op.Step(0).(operator.AddLearner).ToStore) } +func (suite *ruleCheckerTestSuite) TestReplaceDownPeerWithIsolationLevel() { + suite.cluster.SetMaxStoreDownTime(100 * time.Millisecond) + suite.cluster.AddLabelsStore(1, 1, map[string]string{"zone": "z1", "host": "h1"}) + suite.cluster.AddLabelsStore(2, 1, map[string]string{"zone": "z1", "host": "h2"}) + suite.cluster.AddLabelsStore(3, 1, map[string]string{"zone": "z2", "host": "h3"}) + suite.cluster.AddLabelsStore(4, 1, map[string]string{"zone": "z2", "host": "h4"}) + suite.cluster.AddLabelsStore(5, 1, map[string]string{"zone": "z3", "host": "h5"}) + suite.cluster.AddLabelsStore(6, 1, map[string]string{"zone": "z3", "host": "h6"}) + suite.cluster.AddLeaderRegionWithRange(1, "", "", 1, 3, 5) + suite.ruleManager.DeleteRule("pd", "default") + suite.ruleManager.SetRule(&placement.Rule{ + 
GroupID: "pd", + ID: "test", + Index: 100, + Override: true, + Role: placement.Voter, + Count: 3, + LocationLabels: []string{"zone", "host"}, + IsolationLevel: "zone", + }) + op := suite.rc.Check(suite.cluster.GetRegion(1)) + suite.Nil(op) + region := suite.cluster.GetRegion(1) + downPeer := []*pdpb.PeerStats{ + {Peer: region.GetStorePeer(5), DownSeconds: 6000}, + } + region = region.Clone(core.WithDownPeers(downPeer)) + suite.cluster.PutRegion(region) + suite.cluster.SetStoreDown(5) + suite.cluster.SetStoreDown(6) + time.Sleep(200 * time.Millisecond) + op = suite.rc.Check(suite.cluster.GetRegion(1)) + suite.Nil(op) +} + func (suite *ruleCheckerTestSuite) TestFixPeer() { suite.cluster.AddLeaderStore(1, 1) suite.cluster.AddLeaderStore(2, 1) diff --git a/pkg/schedule/placement/rule_manager.go b/pkg/schedule/placement/rule_manager.go index 3bd272a00ac4..909c0fa10785 100644 --- a/pkg/schedule/placement/rule_manager.go +++ b/pkg/schedule/placement/rule_manager.go @@ -66,7 +66,7 @@ func NewRuleManager(storage endpoint.RuleStorage, storeSetInformer core.StoreSet // Initialize loads rules from storage. If Placement Rules feature is never enabled, it creates default rule that is // compatible with previous configuration. -func (m *RuleManager) Initialize(maxReplica int, locationLabels []string) error { +func (m *RuleManager) Initialize(maxReplica int, locationLabels []string, isolationLevel string) error { m.Lock() defer m.Unlock() if m.initialized { @@ -93,6 +93,7 @@ func (m *RuleManager) Initialize(maxReplica int, locationLabels []string) error Role: Voter, Count: maxReplica - witnessCount, LocationLabels: locationLabels, + IsolationLevel: isolationLevel, }, { GroupID: "pd", @@ -101,6 +102,7 @@ func (m *RuleManager) Initialize(maxReplica int, locationLabels []string) error Count: witnessCount, IsWitness: true, LocationLabels: locationLabels, + IsolationLevel: isolationLevel, }, }..., ) @@ -111,6 +113,7 @@ func (m *RuleManager) Initialize(maxReplica int, locationLabels []string) error Role: Voter, Count: maxReplica, LocationLabels: locationLabels, + IsolationLevel: isolationLevel, }) } for _, defaultRule := range defaultRules { diff --git a/pkg/schedule/placement/rule_manager_test.go b/pkg/schedule/placement/rule_manager_test.go index e5be8d74cd21..a6454337aa84 100644 --- a/pkg/schedule/placement/rule_manager_test.go +++ b/pkg/schedule/placement/rule_manager_test.go @@ -34,7 +34,7 @@ func newTestManager(t *testing.T, enableWitness bool) (endpoint.RuleStorage, *Ru var err error manager := NewRuleManager(store, nil, mockconfig.NewTestOptions()) manager.conf.SetEnableWitness(enableWitness) - err = manager.Initialize(3, []string{"zone", "rack", "host"}) + err = manager.Initialize(3, []string{"zone", "rack", "host"}, "") re.NoError(err) return store, manager } @@ -157,7 +157,7 @@ func TestSaveLoad(t *testing.T) { } m2 := NewRuleManager(store, nil, nil) - err := m2.Initialize(3, []string{"no", "labels"}) + err := m2.Initialize(3, []string{"no", "labels"}, "") re.NoError(err) re.Len(m2.GetAllRules(), 3) re.Equal(rules[0].String(), m2.GetRule("pd", "default").String()) @@ -173,7 +173,7 @@ func TestSetAfterGet(t *testing.T) { manager.SetRule(rule) m2 := NewRuleManager(store, nil, nil) - err := m2.Initialize(100, []string{}) + err := m2.Initialize(100, []string{}, "") re.NoError(err) rule = m2.GetRule("pd", "default") re.Equal(1, rule.Count) diff --git a/pkg/statistics/region_collection_test.go b/pkg/statistics/region_collection_test.go index 232fb8b73d8b..2706ffeb0436 100644 --- 
a/pkg/statistics/region_collection_test.go +++ b/pkg/statistics/region_collection_test.go @@ -30,7 +30,7 @@ func TestRegionStatistics(t *testing.T) { re := require.New(t) store := storage.NewStorageWithMemoryBackend() manager := placement.NewRuleManager(store, nil, nil) - err := manager.Initialize(3, []string{"zone", "rack", "host"}) + err := manager.Initialize(3, []string{"zone", "rack", "host"}, "") re.NoError(err) opt := mockconfig.NewTestOptions() opt.SetPlacementRuleEnabled(false) @@ -120,7 +120,7 @@ func TestRegionStatisticsWithPlacementRule(t *testing.T) { re := require.New(t) store := storage.NewStorageWithMemoryBackend() manager := placement.NewRuleManager(store, nil, nil) - err := manager.Initialize(3, []string{"zone", "rack", "host"}) + err := manager.Initialize(3, []string{"zone", "rack", "host"}, "") re.NoError(err) opt := mockconfig.NewTestOptions() opt.SetPlacementRuleEnabled(true) diff --git a/server/api/operator_test.go b/server/api/operator_test.go index ddb605c7d877..ee849552f09b 100644 --- a/server/api/operator_test.go +++ b/server/api/operator_test.go @@ -383,7 +383,9 @@ func (suite *transferRegionOperatorTestSuite) TestTransferRegionWithPlacementRul if testCase.placementRuleEnable { err := suite.svr.GetRaftCluster().GetRuleManager().Initialize( suite.svr.GetRaftCluster().GetOpts().GetMaxReplicas(), - suite.svr.GetRaftCluster().GetOpts().GetLocationLabels()) + suite.svr.GetRaftCluster().GetOpts().GetLocationLabels(), + suite.svr.GetRaftCluster().GetOpts().GetIsolationLevel(), + ) suite.NoError(err) } if len(testCase.rules) > 0 { diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go index 771fb03ac202..d42dbb21ed13 100644 --- a/server/cluster/cluster.go +++ b/server/cluster/cluster.go @@ -301,7 +301,7 @@ func (c *RaftCluster) Start(s Server) error { c.ruleManager = placement.NewRuleManager(c.storage, c, c.GetOpts()) if c.opt.IsPlacementRulesEnabled() { - err = c.ruleManager.Initialize(c.opt.GetMaxReplicas(), c.opt.GetLocationLabels()) + err = c.ruleManager.Initialize(c.opt.GetMaxReplicas(), c.opt.GetLocationLabels(), c.opt.GetIsolationLevel()) if err != nil { return err } diff --git a/server/cluster/cluster_test.go b/server/cluster/cluster_test.go index c9d4d0f8f610..aa826e344069 100644 --- a/server/cluster/cluster_test.go +++ b/server/cluster/cluster_test.go @@ -243,7 +243,7 @@ func TestSetOfflineStore(t *testing.T) { cluster.coordinator = schedule.NewCoordinator(ctx, cluster, nil) cluster.ruleManager = placement.NewRuleManager(storage.NewStorageWithMemoryBackend(), cluster, cluster.GetOpts()) if opt.IsPlacementRulesEnabled() { - err := cluster.ruleManager.Initialize(opt.GetMaxReplicas(), opt.GetLocationLabels()) + err := cluster.ruleManager.Initialize(opt.GetMaxReplicas(), opt.GetLocationLabels(), opt.GetIsolationLevel()) if err != nil { panic(err) } @@ -440,7 +440,7 @@ func TestUpStore(t *testing.T) { cluster.coordinator = schedule.NewCoordinator(ctx, cluster, nil) cluster.ruleManager = placement.NewRuleManager(storage.NewStorageWithMemoryBackend(), cluster, cluster.GetOpts()) if opt.IsPlacementRulesEnabled() { - err := cluster.ruleManager.Initialize(opt.GetMaxReplicas(), opt.GetLocationLabels()) + err := cluster.ruleManager.Initialize(opt.GetMaxReplicas(), opt.GetLocationLabels(), opt.GetIsolationLevel()) if err != nil { panic(err) } @@ -543,7 +543,7 @@ func TestDeleteStoreUpdatesClusterVersion(t *testing.T) { cluster.coordinator = schedule.NewCoordinator(ctx, cluster, nil) cluster.ruleManager = 
placement.NewRuleManager(storage.NewStorageWithMemoryBackend(), cluster, cluster.GetOpts()) if opt.IsPlacementRulesEnabled() { - err := cluster.ruleManager.Initialize(opt.GetMaxReplicas(), opt.GetLocationLabels()) + err := cluster.ruleManager.Initialize(opt.GetMaxReplicas(), opt.GetLocationLabels(), opt.GetIsolationLevel()) if err != nil { panic(err) } @@ -1270,7 +1270,7 @@ func TestOfflineAndMerge(t *testing.T) { cluster.coordinator = schedule.NewCoordinator(ctx, cluster, nil) cluster.ruleManager = placement.NewRuleManager(storage.NewStorageWithMemoryBackend(), cluster, cluster.GetOpts()) if opt.IsPlacementRulesEnabled() { - err := cluster.ruleManager.Initialize(opt.GetMaxReplicas(), opt.GetLocationLabels()) + err := cluster.ruleManager.Initialize(opt.GetMaxReplicas(), opt.GetLocationLabels(), opt.GetIsolationLevel()) if err != nil { panic(err) } @@ -2129,7 +2129,7 @@ func newTestRaftCluster( rc.InitCluster(id, opt, s, basicCluster, nil) rc.ruleManager = placement.NewRuleManager(storage.NewStorageWithMemoryBackend(), rc, opt) if opt.IsPlacementRulesEnabled() { - err := rc.ruleManager.Initialize(opt.GetMaxReplicas(), opt.GetLocationLabels()) + err := rc.ruleManager.Initialize(opt.GetMaxReplicas(), opt.GetLocationLabels(), opt.GetIsolationLevel()) if err != nil { panic(err) } diff --git a/server/config/persist_options.go b/server/config/persist_options.go index 3f1c4d4a24e8..14fdbf653aa1 100644 --- a/server/config/persist_options.go +++ b/server/config/persist_options.go @@ -330,6 +330,13 @@ func (o *PersistOptions) SetEnableWitness(enable bool) { o.SetScheduleConfig(v) } +// SetMaxStoreDownTime to set the max store down time. It's only used to test. +func (o *PersistOptions) SetMaxStoreDownTime(time time.Duration) { + v := o.GetScheduleConfig().Clone() + v.MaxStoreDownTime = typeutil.NewDuration(time) + o.SetScheduleConfig(v) +} + // SetMaxMergeRegionSize sets the max merge region size. func (o *PersistOptions) SetMaxMergeRegionSize(maxMergeRegionSize uint64) { v := o.GetScheduleConfig().Clone() diff --git a/server/server.go b/server/server.go index 03e036a968ea..2fb66387d7a6 100644 --- a/server/server.go +++ b/server/server.go @@ -1030,7 +1030,7 @@ func (s *Server) SetReplicationConfig(cfg sc.ReplicationConfig) error { } if cfg.EnablePlacementRules { // initialize rule manager. 
- if err := rc.GetRuleManager().Initialize(int(cfg.MaxReplicas), cfg.LocationLabels); err != nil { + if err := rc.GetRuleManager().Initialize(int(cfg.MaxReplicas), cfg.LocationLabels, cfg.IsolationLevel); err != nil { return err } } else { @@ -1053,19 +1053,19 @@ func (s *Server) SetReplicationConfig(cfg sc.ReplicationConfig) error { defaultRule := rc.GetRuleManager().GetRule("pd", "default") CheckInDefaultRule := func() error { - // replication config won't work when placement rule is enabled and exceeds one default rule + // replication config won't work when placement rule is enabled and exceeds one default rule if !(defaultRule != nil && len(defaultRule.StartKey) == 0 && len(defaultRule.EndKey) == 0) { - return errors.New("cannot update MaxReplicas or LocationLabels when placement rules feature is enabled and not only default rule exists, please update rule instead") + return errors.New("cannot update MaxReplicas, LocationLabels or IsolationLevel when placement rules feature is enabled and not only default rule exists, please update rule instead") } - if !(defaultRule.Count == int(old.MaxReplicas) && typeutil.AreStringSlicesEqual(defaultRule.LocationLabels, []string(old.LocationLabels))) { + if !(defaultRule.Count == int(old.MaxReplicas) && typeutil.AreStringSlicesEqual(defaultRule.LocationLabels, []string(old.LocationLabels)) && defaultRule.IsolationLevel == old.IsolationLevel) { return errors.New("cannot to update replication config, the default rules do not consistent with replication config, please update rule instead") } return nil } - if !(cfg.MaxReplicas == old.MaxReplicas && typeutil.AreStringSlicesEqual(cfg.LocationLabels, old.LocationLabels)) { + if !(cfg.MaxReplicas == old.MaxReplicas && typeutil.AreStringSlicesEqual(cfg.LocationLabels, old.LocationLabels) && cfg.IsolationLevel == old.IsolationLevel) { if err := CheckInDefaultRule(); err != nil { return err } @@ -1076,6 +1076,7 @@ func (s *Server) SetReplicationConfig(cfg sc.ReplicationConfig) error { if rule != nil { rule.Count = int(cfg.MaxReplicas) rule.LocationLabels = cfg.LocationLabels + rule.IsolationLevel = cfg.IsolationLevel rc := s.GetRaftCluster() if rc == nil { return errs.ErrNotBootstrapped.GenWithStackByArgs() diff --git a/tests/pdctl/config/config_test.go b/tests/pdctl/config/config_test.go index 3d0146589d56..f43a964b50cf 100644 --- a/tests/pdctl/config/config_test.go +++ b/tests/pdctl/config/config_test.go @@ -683,7 +683,7 @@ func TestUpdateDefaultReplicaConfig(t *testing.T) { re.Equal(expect, replicationCfg.MaxReplicas) } - checkLocaltionLabels := func(expect int) { + checkLocationLabels := func(expect int) { args := []string{"-u", pdAddr, "config", "show", "replication"} output, err := pdctl.ExecuteCommand(cmd, args...) re.NoError(err) @@ -692,6 +692,15 @@ func TestUpdateDefaultReplicaConfig(t *testing.T) { re.Len(replicationCfg.LocationLabels, expect) } + checkIsolationLevel := func(expect string) { + args := []string{"-u", pdAddr, "config", "show", "replication"} + output, err := pdctl.ExecuteCommand(cmd, args...) + re.NoError(err) + replicationCfg := sc.ReplicationConfig{} + re.NoError(json.Unmarshal(output, &replicationCfg)) + re.Equal(replicationCfg.IsolationLevel, expect) + } + checkRuleCount := func(expect int) { args := []string{"-u", pdAddr, "config", "placement-rules", "show", "--group", "pd", "--id", "default"} output, err := pdctl.ExecuteCommand(cmd, args...) 
@@ -710,6 +719,15 @@ func TestUpdateDefaultReplicaConfig(t *testing.T) { re.Len(rule.LocationLabels, expect) } + checkRuleIsolationLevel := func(expect string) { + args := []string{"-u", pdAddr, "config", "placement-rules", "show", "--group", "pd", "--id", "default"} + output, err := pdctl.ExecuteCommand(cmd, args...) + re.NoError(err) + rule := placement.Rule{} + re.NoError(json.Unmarshal(output, &rule)) + re.Equal(rule.IsolationLevel, expect) + } + // update successfully when placement rules is not enabled. output, err := pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "set", "max-replicas", "2") re.NoError(err) @@ -718,8 +736,13 @@ func TestUpdateDefaultReplicaConfig(t *testing.T) { output, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "set", "location-labels", "zone,host") re.NoError(err) re.Contains(string(output), "Success!") - checkLocaltionLabels(2) + output, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "set", "isolation-level", "zone") + re.NoError(err) + re.Contains(string(output), "Success!") + checkLocationLabels(2) checkRuleLocationLabels(2) + checkIsolationLevel("zone") + checkRuleIsolationLevel("zone") // update successfully when only one default rule exists. output, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "enable") @@ -732,11 +755,18 @@ func TestUpdateDefaultReplicaConfig(t *testing.T) { checkMaxReplicas(3) checkRuleCount(3) + // We need to change isolation first because we will validate + // if the location label contains the isolation level when setting location labels. + output, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "set", "isolation-level", "host") + re.NoError(err) + re.Contains(string(output), "Success!") output, err = pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "set", "location-labels", "host") re.NoError(err) re.Contains(string(output), "Success!") - checkLocaltionLabels(1) + checkLocationLabels(1) checkRuleLocationLabels(1) + checkIsolationLevel("host") + checkRuleIsolationLevel("host") // update unsuccessfully when many rule exists. 
fname := t.TempDir() @@ -760,8 +790,10 @@ func TestUpdateDefaultReplicaConfig(t *testing.T) { re.NoError(err) checkMaxReplicas(4) checkRuleCount(4) - checkLocaltionLabels(1) + checkLocationLabels(1) checkRuleLocationLabels(1) + checkIsolationLevel("host") + checkRuleIsolationLevel("host") } func TestPDServerConfig(t *testing.T) { From e2f12696c76adc96d43297bcc2f5df097ed21b70 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Thu, 21 Sep 2023 16:55:46 +0800 Subject: [PATCH 2/8] util: add check delete json function (#7113) ref tikv/pd#4399 Signed-off-by: lhy1024 Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- errors.toml | 30 ++++++++++++------------- pkg/autoscaling/prometheus_test.go | 4 ++-- pkg/errs/errno.go | 6 ++--- pkg/tso/keyspace_group_manager.go | 14 +++++++----- pkg/utils/apiutil/apiutil.go | 11 +++------ pkg/utils/testutil/api_check.go | 11 ++++++++- server/api/admin_test.go | 7 ++---- server/api/diagnostic_test.go | 3 +-- server/api/operator_test.go | 9 ++++---- server/api/region_label_test.go | 3 +-- server/api/rule_test.go | 12 +++++----- server/api/scheduler.go | 7 +++--- server/api/scheduler_test.go | 27 +++++++++------------- server/api/service_gc_safepoint_test.go | 4 +--- 14 files changed, 71 insertions(+), 77 deletions(-) diff --git a/errors.toml b/errors.toml index 6766da79572f..1b96de8a2098 100644 --- a/errors.toml +++ b/errors.toml @@ -531,21 +531,6 @@ error = ''' plugin is not found: %s ''' -["PD:operator:ErrRegionAbnormalPeer"] -error = ''' -region %v has abnormal peer -''' - -["PD:operator:ErrRegionNotAdjacent"] -error = ''' -two regions are not adjacent -''' - -["PD:operator:ErrRegionNotFound"] -error = ''' -region %v not found -''' - ["PD:os:ErrOSOpen"] error = ''' open error @@ -616,6 +601,21 @@ error = ''' failed to unmarshal proto ''' +["PD:region:ErrRegionAbnormalPeer"] +error = ''' +region %v has abnormal peer +''' + +["PD:region:ErrRegionNotAdjacent"] +error = ''' +two regions are not adjacent +''' + +["PD:region:ErrRegionNotFound"] +error = ''' +region %v not found +''' + ["PD:region:ErrRegionRuleContent"] error = ''' invalid region rule content, %s diff --git a/pkg/autoscaling/prometheus_test.go b/pkg/autoscaling/prometheus_test.go index 6d4a27b04119..6c30e3ead4cb 100644 --- a/pkg/autoscaling/prometheus_test.go +++ b/pkg/autoscaling/prometheus_test.go @@ -155,7 +155,7 @@ func makeJSONResponse(promResp *response) (*http.Response, []byte, error) { response := &http.Response{ Status: "200 OK", - StatusCode: 200, + StatusCode: http.StatusOK, Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, @@ -246,7 +246,7 @@ func (c *errorHTTPStatusClient) Do(_ context.Context, req *http.Request) (r *htt r, body, err = makeJSONResponse(promResp) - r.StatusCode = 500 + r.StatusCode = http.StatusInternalServerError r.Status = "500 Internal Server Error" return diff --git a/pkg/errs/errno.go b/pkg/errs/errno.go index 9eedb144f95f..181dfc9b3938 100644 --- a/pkg/errs/errno.go +++ b/pkg/errs/errno.go @@ -103,11 +103,11 @@ var ( // region errors var ( // ErrRegionNotAdjacent is error info for region not adjacent. - ErrRegionNotAdjacent = errors.Normalize("two regions are not adjacent", errors.RFCCodeText("PD:operator:ErrRegionNotAdjacent")) + ErrRegionNotAdjacent = errors.Normalize("two regions are not adjacent", errors.RFCCodeText("PD:region:ErrRegionNotAdjacent")) // ErrRegionNotFound is error info for region not found. 
- ErrRegionNotFound = errors.Normalize("region %v not found", errors.RFCCodeText("PD:operator:ErrRegionNotFound")) + ErrRegionNotFound = errors.Normalize("region %v not found", errors.RFCCodeText("PD:region:ErrRegionNotFound")) // ErrRegionAbnormalPeer is error info for region has abnormal peer. - ErrRegionAbnormalPeer = errors.Normalize("region %v has abnormal peer", errors.RFCCodeText("PD:operator:ErrRegionAbnormalPeer")) + ErrRegionAbnormalPeer = errors.Normalize("region %v has abnormal peer", errors.RFCCodeText("PD:region:ErrRegionAbnormalPeer")) ) // plugin errors diff --git a/pkg/tso/keyspace_group_manager.go b/pkg/tso/keyspace_group_manager.go index c6d2323aa4bc..3b352884eab3 100644 --- a/pkg/tso/keyspace_group_manager.go +++ b/pkg/tso/keyspace_group_manager.go @@ -1226,16 +1226,17 @@ func (kgm *KeyspaceGroupManager) finishSplitKeyspaceGroup(id uint32) error { return nil } startRequest := time.Now() - statusCode, err := apiutil.DoDelete( + resp, err := apiutil.DoDelete( kgm.httpClient, kgm.cfg.GeBackendEndpoints()+keyspaceGroupsAPIPrefix+fmt.Sprintf("/%d/split", id)) if err != nil { return err } - if statusCode != http.StatusOK { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { log.Warn("failed to finish split keyspace group", zap.Uint32("keyspace-group-id", id), - zap.Int("status-code", statusCode)) + zap.Int("status-code", resp.StatusCode)) return errs.ErrSendRequest.FastGenByArgs() } kgm.metrics.finishSplitSendDuration.Observe(time.Since(startRequest).Seconds()) @@ -1264,16 +1265,17 @@ func (kgm *KeyspaceGroupManager) finishMergeKeyspaceGroup(id uint32) error { return nil } startRequest := time.Now() - statusCode, err := apiutil.DoDelete( + resp, err := apiutil.DoDelete( kgm.httpClient, kgm.cfg.GeBackendEndpoints()+keyspaceGroupsAPIPrefix+fmt.Sprintf("/%d/merge", id)) if err != nil { return err } - if statusCode != http.StatusOK { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { log.Warn("failed to finish merging keyspace group", zap.Uint32("keyspace-group-id", id), - zap.Int("status-code", statusCode)) + zap.Int("status-code", resp.StatusCode)) return errs.ErrSendRequest.FastGenByArgs() } kgm.metrics.finishMergeSendDuration.Observe(time.Since(startRequest).Seconds()) diff --git a/pkg/utils/apiutil/apiutil.go b/pkg/utils/apiutil/apiutil.go index 0b72b9af10fe..2c476042da0f 100644 --- a/pkg/utils/apiutil/apiutil.go +++ b/pkg/utils/apiutil/apiutil.go @@ -226,17 +226,12 @@ func PostJSONIgnoreResp(client *http.Client, url string, data []byte) error { } // DoDelete is used to send delete request and return http response code. -func DoDelete(client *http.Client, url string) (int, error) { +func DoDelete(client *http.Client, url string) (*http.Response, error) { req, err := http.NewRequest(http.MethodDelete, url, nil) if err != nil { - return http.StatusBadRequest, err - } - res, err := client.Do(req) - if err != nil { - return 0, err + return nil, err } - defer res.Body.Close() - return res.StatusCode, nil + return client.Do(req) } func checkResponse(resp *http.Response, err error) error { diff --git a/pkg/utils/testutil/api_check.go b/pkg/utils/testutil/api_check.go index d11d575967d1..84af97f828de 100644 --- a/pkg/utils/testutil/api_check.go +++ b/pkg/utils/testutil/api_check.go @@ -123,9 +123,18 @@ func CheckPatchJSON(client *http.Client, url string, data []byte, checkOpts ...f return checkResp(resp, checkOpts...) } +// CheckDelete is used to do delete request and do check options. 
+func CheckDelete(client *http.Client, url string, checkOpts ...func([]byte, int, http.Header)) error { + resp, err := apiutil.DoDelete(client, url) + if err != nil { + return err + } + return checkResp(resp, checkOpts...) +} + func checkResp(resp *http.Response, checkOpts ...func([]byte, int, http.Header)) error { res, err := io.ReadAll(resp.Body) - resp.Body.Close() + defer resp.Body.Close() if err != nil { return err } diff --git a/server/api/admin_test.go b/server/api/admin_test.go index 1f2b386eb987..6a972171e1fe 100644 --- a/server/api/admin_test.go +++ b/server/api/admin_test.go @@ -26,7 +26,6 @@ import ( "github.com/pingcap/kvproto/pkg/pdpb" "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/core" - "github.com/tikv/pd/pkg/utils/apiutil" tu "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/server" ) @@ -271,9 +270,8 @@ func (suite *adminTestSuite) TestMarkSnapshotRecovering() { suite.NoError(err2) suite.True(resp.Marked) // unmark - code, err := apiutil.DoDelete(testDialClient, url) + err := tu.CheckDelete(testDialClient, url, tu.StatusOK(re)) suite.NoError(err) - suite.Equal(200, code) suite.NoError(tu.CheckGetJSON(testDialClient, url, nil, tu.StatusOK(re), tu.StringContain(re, "false"))) } @@ -310,9 +308,8 @@ func (suite *adminTestSuite) TestRecoverAllocID() { suite.NoError(err2) suite.Equal(id, uint64(99000001)) // unmark - code, err := apiutil.DoDelete(testDialClient, markRecoveringURL) + err := tu.CheckDelete(testDialClient, markRecoveringURL, tu.StatusOK(re)) suite.NoError(err) - suite.Equal(200, code) suite.NoError(tu.CheckGetJSON(testDialClient, markRecoveringURL, nil, tu.StatusOK(re), tu.StringContain(re, "false"))) suite.NoError(tu.CheckPostJSON(testDialClient, url, []byte(`{"id": "100000"}`), diff --git a/server/api/diagnostic_test.go b/server/api/diagnostic_test.go index 8a39b2e00077..1774c2215396 100644 --- a/server/api/diagnostic_test.go +++ b/server/api/diagnostic_test.go @@ -24,7 +24,6 @@ import ( "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/core" "github.com/tikv/pd/pkg/schedule/schedulers" - "github.com/tikv/pd/pkg/utils/apiutil" tu "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" @@ -129,7 +128,7 @@ func (suite *diagnosticTestSuite) TestSchedulerDiagnosticAPI() { suite.checkStatus("normal", balanceRegionURL) deleteURL := fmt.Sprintf("%s/%s", suite.schedulerPrifex, schedulers.BalanceRegionName) - _, err = apiutil.DoDelete(testDialClient, deleteURL) + err = tu.CheckDelete(testDialClient, deleteURL, tu.StatusOK(re)) suite.NoError(err) suite.checkStatus("disabled", balanceRegionURL) } diff --git a/server/api/operator_test.go b/server/api/operator_test.go index ee849552f09b..1675fdd40c74 100644 --- a/server/api/operator_test.go +++ b/server/api/operator_test.go @@ -33,7 +33,6 @@ import ( "github.com/tikv/pd/pkg/mock/mockhbstream" pdoperator "github.com/tikv/pd/pkg/schedule/operator" "github.com/tikv/pd/pkg/schedule/placement" - "github.com/tikv/pd/pkg/utils/apiutil" tu "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/pkg/versioninfo" "github.com/tikv/pd/server" @@ -99,7 +98,7 @@ func (suite *operatorTestSuite) TestAddRemovePeer() { suite.Contains(operator, "add learner peer 1 on store 3") suite.Contains(operator, "RUNNING") - _, err = apiutil.DoDelete(testDialClient, regionURL) + err = tu.CheckDelete(testDialClient, regionURL, tu.StatusOK(re)) suite.NoError(err) records = mustReadURL(re, recordURL) suite.Contains(records, "admin-add-peer {add peer: store [3]}") @@ -110,7 
+109,7 @@ func (suite *operatorTestSuite) TestAddRemovePeer() { suite.Contains(operator, "RUNNING") suite.Contains(operator, "remove peer on store 2") - _, err = apiutil.DoDelete(testDialClient, regionURL) + err = tu.CheckDelete(testDialClient, regionURL, tu.StatusOK(re)) suite.NoError(err) records = mustReadURL(re, recordURL) suite.Contains(records, "admin-remove-peer {rm peer: store [2]}") @@ -406,8 +405,10 @@ func (suite *transferRegionOperatorTestSuite) TestTransferRegionWithPlacementRul if len(testCase.expectSteps) > 0 { operator = mustReadURL(re, regionURL) suite.Contains(operator, testCase.expectSteps) + err = tu.CheckDelete(testDialClient, regionURL, tu.StatusOK(re)) + } else { + err = tu.CheckDelete(testDialClient, regionURL, tu.StatusNotOK(re)) } - _, err = apiutil.DoDelete(testDialClient, regionURL) suite.NoError(err) } } diff --git a/server/api/region_label_test.go b/server/api/region_label_test.go index 021ec7f1359e..fd7401b83e0a 100644 --- a/server/api/region_label_test.go +++ b/server/api/region_label_test.go @@ -24,7 +24,6 @@ import ( "github.com/pingcap/failpoint" "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/schedule/labeler" - "github.com/tikv/pd/pkg/utils/apiutil" tu "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/server" ) @@ -86,7 +85,7 @@ func (suite *regionLabelTestSuite) TestGetSet() { expects := []*labeler.LabelRule{rules[0], rules[2]} suite.Equal(expects, resp) - _, err = apiutil.DoDelete(testDialClient, suite.urlPrefix+"rule/"+url.QueryEscape("rule2/a/b")) + err = tu.CheckDelete(testDialClient, suite.urlPrefix+"rule/"+url.QueryEscape("rule2/a/b"), tu.StatusOK(re)) suite.NoError(err) err = tu.ReadGetJSON(re, testDialClient, suite.urlPrefix+"rules", &resp) suite.NoError(err) diff --git a/server/api/rule_test.go b/server/api/rule_test.go index 4cea15234015..d2dc50f11192 100644 --- a/server/api/rule_test.go +++ b/server/api/rule_test.go @@ -26,7 +26,6 @@ import ( "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/core" "github.com/tikv/pd/pkg/schedule/placement" - "github.com/tikv/pd/pkg/utils/apiutil" tu "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/server" "github.com/tikv/pd/server/config" @@ -202,13 +201,13 @@ func (suite *ruleTestSuite) TestGet() { name: "found", rule: rule, found: true, - code: 200, + code: http.StatusOK, }, { name: "not found", rule: placement.Rule{GroupID: "a", ID: "30", StartKeyHex: "1111", EndKeyHex: "3333", Role: "voter", Count: 1}, found: false, - code: 404, + code: http.StatusNotFound, }, } for _, testCase := range testCases { @@ -533,9 +532,8 @@ func (suite *ruleTestSuite) TestDelete() { url := fmt.Sprintf("%s/rule/%s/%s", suite.urlPrefix, testCase.groupID, testCase.id) // clear suspect keyRanges to prevent test case from others suite.svr.GetRaftCluster().ClearSuspectKeyRanges() - statusCode, err := apiutil.DoDelete(testDialClient, url) + err = tu.CheckDelete(testDialClient, url, tu.StatusOK(suite.Require())) suite.NoError(err) - suite.Equal(http.StatusOK, statusCode) if len(testCase.popKeyRange) > 0 { popKeyRangeMap := map[string]struct{}{} for i := 0; i < len(testCase.popKeyRange)/2; i++ { @@ -726,7 +724,7 @@ func (suite *ruleTestSuite) TestBundle() { suite.compareBundle(bundles[1], b2) // Delete - _, err = apiutil.DoDelete(testDialClient, suite.urlPrefix+"/placement-rule/pd") + err = tu.CheckDelete(testDialClient, suite.urlPrefix+"/placement-rule/pd", tu.StatusOK(suite.Require())) suite.NoError(err) // GetAll again @@ -753,7 +751,7 @@ func (suite *ruleTestSuite) TestBundle() { 
suite.compareBundle(bundles[2], b3) // Delete using regexp - _, err = apiutil.DoDelete(testDialClient, suite.urlPrefix+"/placement-rule/"+url.PathEscape("foo.*")+"?regexp") + err = tu.CheckDelete(testDialClient, suite.urlPrefix+"/placement-rule/"+url.PathEscape("foo.*")+"?regexp", tu.StatusOK(suite.Require())) suite.NoError(err) // GetAll again diff --git a/server/api/scheduler.go b/server/api/scheduler.go index dc7f2507141a..c2691ea98269 100644 --- a/server/api/scheduler.go +++ b/server/api/scheduler.go @@ -324,12 +324,13 @@ func (h *schedulerHandler) redirectSchedulerDelete(w http.ResponseWriter, name, h.r.JSON(w, http.StatusInternalServerError, err.Error()) return } - statusCode, err := apiutil.DoDelete(h.svr.GetHTTPClient(), deleteURL) + resp, err := apiutil.DoDelete(h.svr.GetHTTPClient(), deleteURL) if err != nil { - h.r.JSON(w, statusCode, err.Error()) + h.r.JSON(w, resp.StatusCode, err.Error()) return } - h.r.JSON(w, statusCode, nil) + defer resp.Body.Close() + h.r.JSON(w, resp.StatusCode, nil) } // FIXME: details of input json body params diff --git a/server/api/scheduler_test.go b/server/api/scheduler_test.go index 613de8e441c3..b015bbe8f524 100644 --- a/server/api/scheduler_test.go +++ b/server/api/scheduler_test.go @@ -25,7 +25,6 @@ import ( "github.com/pingcap/kvproto/pkg/metapb" "github.com/stretchr/testify/suite" sc "github.com/tikv/pd/pkg/schedule/config" - "github.com/tikv/pd/pkg/utils/apiutil" tu "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/server" ) @@ -93,7 +92,7 @@ func (suite *scheduleTestSuite) TestOriginAPI() { suite.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) suite.Len(resp["store-id-ranges"], 2) deleteURL := fmt.Sprintf("%s/%s", suite.urlPrefix, "evict-leader-scheduler-1") - _, err = apiutil.DoDelete(testDialClient, deleteURL) + err = tu.CheckDelete(testDialClient, deleteURL, tu.StatusOK(re)) suite.NoError(err) suite.Len(rc.GetSchedulers(), 1) resp1 := make(map[string]interface{}) @@ -101,18 +100,16 @@ func (suite *scheduleTestSuite) TestOriginAPI() { suite.Len(resp1["store-id-ranges"], 1) deleteURL = fmt.Sprintf("%s/%s", suite.urlPrefix, "evict-leader-scheduler-2") suite.NoError(failpoint.Enable("github.com/tikv/pd/server/config/persistFail", "return(true)")) - statusCode, err := apiutil.DoDelete(testDialClient, deleteURL) + err = tu.CheckDelete(testDialClient, deleteURL, tu.Status(re, http.StatusInternalServerError)) suite.NoError(err) - suite.Equal(500, statusCode) suite.Len(rc.GetSchedulers(), 1) suite.NoError(failpoint.Disable("github.com/tikv/pd/server/config/persistFail")) - statusCode, err = apiutil.DoDelete(testDialClient, deleteURL) + err = tu.CheckDelete(testDialClient, deleteURL, tu.StatusOK(re)) suite.NoError(err) - suite.Equal(200, statusCode) suite.Empty(rc.GetSchedulers()) - suite.NoError(tu.CheckGetJSON(testDialClient, listURL, nil, tu.Status(re, 404))) - statusCode, _ = apiutil.DoDelete(testDialClient, deleteURL) - suite.Equal(404, statusCode) + suite.NoError(tu.CheckGetJSON(testDialClient, listURL, nil, tu.Status(re, http.StatusNotFound))) + err = tu.CheckDelete(testDialClient, deleteURL, tu.Status(re, http.StatusNotFound)) + suite.NoError(err) } func (suite *scheduleTestSuite) TestAPI() { @@ -370,15 +367,14 @@ func (suite *scheduleTestSuite) TestAPI() { // using /pd/v1/schedule-config/grant-leader-scheduler/config to delete exists store from grant-leader-scheduler deleteURL := fmt.Sprintf("%s%s%s/%s/delete/%s", suite.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name, "2") - _, err = 
apiutil.DoDelete(testDialClient, deleteURL) + err = tu.CheckDelete(testDialClient, deleteURL, tu.StatusOK(re)) suite.NoError(err) resp = make(map[string]interface{}) suite.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) delete(exceptMap, "2") suite.Equal(exceptMap, resp["store-id-ranges"]) - statusCode, err := apiutil.DoDelete(testDialClient, deleteURL) + err = tu.CheckDelete(testDialClient, deleteURL, tu.Status(re, http.StatusNotFound)) suite.NoError(err) - suite.Equal(404, statusCode) }, }, { @@ -434,15 +430,14 @@ func (suite *scheduleTestSuite) TestAPI() { // using /pd/v1/schedule-config/evict-leader-scheduler/config to delete exist store from evict-leader-scheduler deleteURL := fmt.Sprintf("%s%s%s/%s/delete/%s", suite.svr.GetAddr(), apiPrefix, server.SchedulerConfigHandlerPath, name, "4") - _, err = apiutil.DoDelete(testDialClient, deleteURL) + err = tu.CheckDelete(testDialClient, deleteURL, tu.StatusOK(re)) suite.NoError(err) resp = make(map[string]interface{}) suite.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) delete(exceptMap, "4") suite.Equal(exceptMap, resp["store-id-ranges"]) - statusCode, err := apiutil.DoDelete(testDialClient, deleteURL) + err = tu.CheckDelete(testDialClient, deleteURL, tu.Status(re, http.StatusNotFound)) suite.NoError(err) - suite.Equal(404, statusCode) }, }, } @@ -591,7 +586,7 @@ func (suite *scheduleTestSuite) addScheduler(body []byte) { func (suite *scheduleTestSuite) deleteScheduler(createdName string) { deleteURL := fmt.Sprintf("%s/%s", suite.urlPrefix, createdName) - _, err := apiutil.DoDelete(testDialClient, deleteURL) + err := tu.CheckDelete(testDialClient, deleteURL, tu.StatusOK(suite.Require())) suite.NoError(err) } diff --git a/server/api/service_gc_safepoint_test.go b/server/api/service_gc_safepoint_test.go index fe52204dfb21..517a94c2e23a 100644 --- a/server/api/service_gc_safepoint_test.go +++ b/server/api/service_gc_safepoint_test.go @@ -16,7 +16,6 @@ package api import ( "fmt" - "net/http" "testing" "time" @@ -93,9 +92,8 @@ func (suite *serviceGCSafepointTestSuite) TestServiceGCSafepoint() { suite.NoError(err) suite.Equal(list, listResp) - statusCode, err := apiutil.DoDelete(testDialClient, sspURL+"/a") + err = testutil.CheckDelete(testDialClient, sspURL+"/a", testutil.StatusOK(suite.Require())) suite.NoError(err) - suite.Equal(http.StatusOK, statusCode) left, err := storage.LoadAllServiceGCSafePoints() suite.NoError(err) From 96ace89decdc0b5e0a050aa17ba4356057ec3b88 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Thu, 21 Sep 2023 17:12:15 +0800 Subject: [PATCH 3/8] tests: refactor and make pd-ctl helper support mcs (#7120) ref tikv/pd#5839 Signed-off-by: lhy1024 Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- tests/autoscaling/autoscaling_test.go | 2 +- tests/cluster.go | 20 ++++++ tests/compatibility/version_upgrade_test.go | 6 +- tests/dashboard/service_test.go | 2 +- tests/integrations/client/client_test.go | 12 ++-- .../mcs/keyspace/tso_keyspace_group_test.go | 2 +- .../resourcemanager/resource_manager_test.go | 4 +- tests/integrations/mcs/scheduling/api_test.go | 2 +- .../mcs/tso/keyspace_group_manager_test.go | 4 +- tests/integrations/tso/client_test.go | 2 +- tests/pdctl/cluster/cluster_test.go | 2 +- tests/pdctl/config/config_test.go | 36 ++++------ tests/pdctl/health/health_test.go | 2 +- tests/pdctl/helper.go | 64 ----------------- tests/pdctl/hot/hot_test.go | 36 +++++----- tests/pdctl/keyspace/keyspace_group_test.go | 16 ++--- tests/pdctl/keyspace/keyspace_test.go | 2 
+- tests/pdctl/label/label_test.go | 4 +- tests/pdctl/log/log_test.go | 9 +-- tests/pdctl/member/member_test.go | 2 +- tests/pdctl/operator/operator_test.go | 10 +-- tests/pdctl/region/region_test.go | 18 ++--- tests/pdctl/scheduler/scheduler_test.go | 16 ++--- tests/pdctl/store/store_test.go | 14 ++-- tests/pdctl/unsafe/unsafe_operation_test.go | 2 +- tests/registry/registry_test.go | 4 +- tests/server/api/api_test.go | 61 ++++++++-------- tests/server/apiv2/handlers/keyspace_test.go | 2 +- .../apiv2/handlers/tso_keyspace_group_test.go | 2 +- tests/server/cluster/cluster_test.go | 42 +++++------ tests/server/cluster/cluster_work_test.go | 6 +- tests/server/config/config_test.go | 4 +- tests/server/id/id_test.go | 12 ++-- tests/server/keyspace/keyspace_test.go | 2 +- tests/server/member/member_test.go | 2 +- .../region_syncer/region_syncer_test.go | 10 +-- .../server/storage/hot_region_storage_test.go | 29 ++++---- tests/server/tso/consistency_test.go | 10 +-- tests/server/tso/global_tso_test.go | 4 +- tests/server/tso/tso_test.go | 4 +- tests/server/watch/leader_watch_test.go | 4 +- tests/testutil.go | 72 +++++++++++++++++++ 42 files changed, 289 insertions(+), 270 deletions(-) diff --git a/tests/autoscaling/autoscaling_test.go b/tests/autoscaling/autoscaling_test.go index 55e29297dbdc..663bc92f5627 100644 --- a/tests/autoscaling/autoscaling_test.go +++ b/tests/autoscaling/autoscaling_test.go @@ -42,7 +42,7 @@ func TestAPI(t *testing.T) { re.NoError(err) cluster.WaitLeader() - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) var jsonStr = []byte(` diff --git a/tests/cluster.go b/tests/cluster.go index ce8293531cd0..c49f3cd982d7 100644 --- a/tests/cluster.go +++ b/tests/cluster.go @@ -33,6 +33,7 @@ import ( "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/id" "github.com/tikv/pd/pkg/keyspace" + scheduling "github.com/tikv/pd/pkg/mcs/scheduling/server" "github.com/tikv/pd/pkg/mcs/utils" "github.com/tikv/pd/pkg/schedule/schedulers" "github.com/tikv/pd/pkg/swaggerserver" @@ -447,6 +448,7 @@ type TestCluster struct { sync.Mutex pool map[uint64]struct{} } + schedulingCluster *TestSchedulingCluster } // ConfigOption is used to define customize settings in test. @@ -629,6 +631,11 @@ func (c *TestCluster) GetFollower() string { return "" } +// GetLeaderServer returns the leader server of all servers +func (c *TestCluster) GetLeaderServer() *TestServer { + return c.GetServer(c.GetLeader()) +} + // WaitLeader is used to get leader. // If it exceeds the maximum number of loops, it will return an empty string. func (c *TestCluster) WaitLeader(ops ...WaitOption) string { @@ -853,6 +860,19 @@ func (c *TestCluster) CheckTSOUnique(ts uint64) bool { return true } +// GetSchedulingPrimaryServer returns the scheduling primary server. +func (c *TestCluster) GetSchedulingPrimaryServer() *scheduling.Server { + if c.schedulingCluster == nil { + return nil + } + return c.schedulingCluster.GetPrimaryServer() +} + +// SetSchedulingCluster sets the scheduling cluster. 
+func (c *TestCluster) SetSchedulingCluster(cluster *TestSchedulingCluster) { + c.schedulingCluster = cluster +} + // WaitOp represent the wait configuration type WaitOp struct { retryTimes int diff --git a/tests/compatibility/version_upgrade_test.go b/tests/compatibility/version_upgrade_test.go index 11573e6da2f8..8979d85c9bbb 100644 --- a/tests/compatibility/version_upgrade_test.go +++ b/tests/compatibility/version_upgrade_test.go @@ -38,7 +38,7 @@ func TestStoreRegister(t *testing.T) { err = cluster.RunInitialServers() re.NoError(err) cluster.WaitLeader() - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) putStoreRequest := &pdpb.PutStoreRequest{ @@ -63,7 +63,7 @@ func TestStoreRegister(t *testing.T) { re.NoError(err) cluster.WaitLeader() - leaderServer = cluster.GetServer(cluster.GetLeader()) + leaderServer = cluster.GetLeaderServer() re.NotNil(leaderServer) newVersion := leaderServer.GetClusterVersion() re.Equal(version, newVersion) @@ -92,7 +92,7 @@ func TestRollingUpgrade(t *testing.T) { err = cluster.RunInitialServers() re.NoError(err) cluster.WaitLeader() - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) stores := []*pdpb.PutStoreRequest{ diff --git a/tests/dashboard/service_test.go b/tests/dashboard/service_test.go index f75e047d8f13..ab3a2c431cbc 100644 --- a/tests/dashboard/service_test.go +++ b/tests/dashboard/service_test.go @@ -134,7 +134,7 @@ func (suite *dashboardTestSuite) testDashboard(internalProxy bool) { cluster.WaitLeader() servers := cluster.GetServers() - leader := cluster.GetServer(cluster.GetLeader()) + leader := cluster.GetLeaderServer() leaderAddr := leader.GetAddr() // auto select node diff --git a/tests/integrations/client/client_test.go b/tests/integrations/client/client_test.go index d669e17af903..9cabbb030902 100644 --- a/tests/integrations/client/client_test.go +++ b/tests/integrations/client/client_test.go @@ -347,7 +347,7 @@ func TestUnavailableTimeAfterLeaderIsReady(t *testing.T) { go getTsoFunc() go func() { defer wg.Done() - leader := cluster.GetServer(cluster.GetLeader()) + leader := cluster.GetLeaderServer() leader.Stop() re.NotEmpty(cluster.WaitLeader()) leaderReadyTime = time.Now() @@ -362,7 +362,7 @@ func TestUnavailableTimeAfterLeaderIsReady(t *testing.T) { go getTsoFunc() go func() { defer wg.Done() - leader := cluster.GetServer(cluster.GetLeader()) + leader := cluster.GetLeaderServer() re.NoError(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork", "return(true)")) leader.Stop() re.NotEmpty(cluster.WaitLeader()) @@ -596,7 +596,7 @@ func TestGetTsoFromFollowerClient2(t *testing.T) { }) lastTS = checkTS(re, cli, lastTS) - re.NoError(cluster.GetServer(cluster.GetLeader()).ResignLeader()) + re.NoError(cluster.GetLeaderServer().ResignLeader()) re.NotEmpty(cluster.WaitLeader()) lastTS = checkTS(re, cli, lastTS) @@ -622,7 +622,7 @@ func runServer(re *require.Assertions, cluster *tests.TestCluster) []string { err := cluster.RunInitialServers() re.NoError(err) re.NotEmpty(cluster.WaitLeader()) - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) testServers := cluster.GetServers() @@ -1439,7 +1439,7 @@ func TestPutGet(t *testing.T) { getResp, err = client.Get(context.Background(), key) re.NoError(err) re.Equal([]byte("2"), getResp.GetKvs()[0].Value) - s := 
cluster.GetServer(cluster.GetLeader()) + s := cluster.GetLeaderServer() // use etcd client delete the key _, err = s.GetEtcdClient().Delete(context.Background(), string(key)) re.NoError(err) @@ -1459,7 +1459,7 @@ func TestClientWatchWithRevision(t *testing.T) { endpoints := runServer(re, cluster) client := setupCli(re, ctx, endpoints) defer client.Close() - s := cluster.GetServer(cluster.GetLeader()) + s := cluster.GetLeaderServer() watchPrefix := "watch_test" defer func() { _, err := s.GetEtcdClient().Delete(context.Background(), watchPrefix+"test") diff --git a/tests/integrations/mcs/keyspace/tso_keyspace_group_test.go b/tests/integrations/mcs/keyspace/tso_keyspace_group_test.go index 59aabb260aee..af7b31553b38 100644 --- a/tests/integrations/mcs/keyspace/tso_keyspace_group_test.go +++ b/tests/integrations/mcs/keyspace/tso_keyspace_group_test.go @@ -62,7 +62,7 @@ func (suite *keyspaceGroupTestSuite) SetupTest() { suite.NoError(err) suite.NoError(cluster.RunInitialServers()) suite.NotEmpty(cluster.WaitLeader()) - suite.server = cluster.GetServer(cluster.GetLeader()) + suite.server = cluster.GetLeaderServer() suite.NoError(suite.server.BootstrapCluster()) suite.backendEndpoints = suite.server.GetAddr() suite.dialClient = &http.Client{ diff --git a/tests/integrations/mcs/resourcemanager/resource_manager_test.go b/tests/integrations/mcs/resourcemanager/resource_manager_test.go index 546339bee0f9..0be18d1bbd38 100644 --- a/tests/integrations/mcs/resourcemanager/resource_manager_test.go +++ b/tests/integrations/mcs/resourcemanager/resource_manager_test.go @@ -903,7 +903,7 @@ func (suite *resourceManagerClientTestSuite) TestBasicResourceGroupCURD() { // Test Resource Group CURD via HTTP finalNum = 1 getAddr := func(i int) string { - server := suite.cluster.GetServer(suite.cluster.GetLeader()) + server := suite.cluster.GetLeaderServer() if i%2 == 1 { server = suite.cluster.GetServer(suite.cluster.GetFollower()) } @@ -1298,7 +1298,7 @@ func (suite *resourceManagerClientTestSuite) TestResourceGroupControllerConfigCh } getAddr := func() string { - server := suite.cluster.GetServer(suite.cluster.GetLeader()) + server := suite.cluster.GetLeaderServer() if rand.Intn(100)%2 == 1 { server = suite.cluster.GetServer(suite.cluster.GetFollower()) } diff --git a/tests/integrations/mcs/scheduling/api_test.go b/tests/integrations/mcs/scheduling/api_test.go index 04671d847982..311c8a3fbed5 100644 --- a/tests/integrations/mcs/scheduling/api_test.go +++ b/tests/integrations/mcs/scheduling/api_test.go @@ -45,7 +45,7 @@ func (suite *apiTestSuite) SetupSuite() { suite.NoError(err) suite.NoError(cluster.RunInitialServers()) suite.NotEmpty(cluster.WaitLeader()) - suite.server = cluster.GetServer(cluster.GetLeader()) + suite.server = cluster.GetLeaderServer() suite.NoError(suite.server.BootstrapCluster()) suite.backendEndpoints = suite.server.GetAddr() suite.dialClient = &http.Client{ diff --git a/tests/integrations/mcs/tso/keyspace_group_manager_test.go b/tests/integrations/mcs/tso/keyspace_group_manager_test.go index 3d3fe25b3729..d1a4cf35db49 100644 --- a/tests/integrations/mcs/tso/keyspace_group_manager_test.go +++ b/tests/integrations/mcs/tso/keyspace_group_manager_test.go @@ -517,7 +517,7 @@ func TestTwiceSplitKeyspaceGroup(t *testing.T) { re.NoError(err) defer tc.Destroy() tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) tsoCluster, err := tests.NewTestTSOCluster(ctx, 2, pdAddr) @@ -711,7 +711,7 @@ func 
TestGetTSOImmediately(t *testing.T) { re.NoError(err) defer tc.Destroy() tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) tsoCluster, err := tests.NewTestTSOCluster(ctx, 2, pdAddr) diff --git a/tests/integrations/tso/client_test.go b/tests/integrations/tso/client_test.go index 1d2f437e8498..63243214e816 100644 --- a/tests/integrations/tso/client_test.go +++ b/tests/integrations/tso/client_test.go @@ -389,7 +389,7 @@ func (suite *tsoClientTestSuite) TestRandomShutdown() { if !suite.legacy { suite.tsoCluster.WaitForDefaultPrimaryServing(re).Close() } else { - suite.cluster.GetServer(suite.cluster.GetLeader()).GetServer().Close() + suite.cluster.GetLeaderServer().GetServer().Close() } time.Sleep(time.Duration(n) * time.Second) } diff --git a/tests/pdctl/cluster/cluster_test.go b/tests/pdctl/cluster/cluster_test.go index 2b8b8bc8f590..cd4ec6e13914 100644 --- a/tests/pdctl/cluster/cluster_test.go +++ b/tests/pdctl/cluster/cluster_test.go @@ -39,7 +39,7 @@ func TestClusterAndPing(t *testing.T) { err = cluster.RunInitialServers() re.NoError(err) cluster.WaitLeader() - err = cluster.GetServer(cluster.GetLeader()).BootstrapCluster() + err = cluster.GetLeaderServer().BootstrapCluster() re.NoError(err) pdAddr := cluster.GetConfig().GetClientURL() i := strings.Index(pdAddr, "//") diff --git a/tests/pdctl/config/config_test.go b/tests/pdctl/config/config_test.go index f43a964b50cf..6ed0841bf74a 100644 --- a/tests/pdctl/config/config_test.go +++ b/tests/pdctl/config/config_test.go @@ -64,10 +64,10 @@ func TestConfig(t *testing.T) { Id: 1, State: metapb.StoreState_Up, } - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) svr := leaderServer.GetServer() - pdctl.MustPutStore(re, svr, store) + tests.MustPutStore(re, cluster, store) defer cluster.Destroy() // config show @@ -300,10 +300,9 @@ func TestPlacementRules(t *testing.T) { State: metapb.StoreState_Up, LastHeartbeat: time.Now().UnixNano(), } - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) - svr := leaderServer.GetServer() - pdctl.MustPutStore(re, svr, store) + tests.MustPutStore(re, cluster, store) defer cluster.Destroy() output, err := pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "enable") @@ -358,7 +357,7 @@ func TestPlacementRules(t *testing.T) { re.Equal([2]string{"pd", "test1"}, rules2[1].Key()) // test rule region detail - pdctl.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b")) + tests.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b")) fit := &placement.RegionFit{} // need clear up args, so create new a cobra.Command. Otherwise gourp still exists. 
cmd2 := pdctlCmd.GetRootCmd() @@ -398,10 +397,9 @@ func TestPlacementRuleGroups(t *testing.T) { State: metapb.StoreState_Up, LastHeartbeat: time.Now().UnixNano(), } - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) - svr := leaderServer.GetServer() - pdctl.MustPutStore(re, svr, store) + tests.MustPutStore(re, cluster, store) defer cluster.Destroy() output, err := pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "enable") @@ -473,10 +471,9 @@ func TestPlacementRuleBundle(t *testing.T) { State: metapb.StoreState_Up, LastHeartbeat: time.Now().UnixNano(), } - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) - svr := leaderServer.GetServer() - pdctl.MustPutStore(re, svr, store) + tests.MustPutStore(re, cluster, store) defer cluster.Destroy() output, err := pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "placement-rules", "enable") @@ -609,10 +606,9 @@ func TestReplicationMode(t *testing.T) { State: metapb.StoreState_Up, LastHeartbeat: time.Now().UnixNano(), } - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) - svr := leaderServer.GetServer() - pdctl.MustPutStore(re, svr, store) + tests.MustPutStore(re, cluster, store) defer cluster.Destroy() conf := config.ReplicationModeConfig{ @@ -668,10 +664,9 @@ func TestUpdateDefaultReplicaConfig(t *testing.T) { Id: 1, State: metapb.StoreState_Up, } - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) - svr := leaderServer.GetServer() - pdctl.MustPutStore(re, svr, store) + tests.MustPutStore(re, cluster, store) defer cluster.Destroy() checkMaxReplicas := func(expect uint64) { @@ -813,10 +808,9 @@ func TestPDServerConfig(t *testing.T) { State: metapb.StoreState_Up, LastHeartbeat: time.Now().UnixNano(), } - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) - svr := leaderServer.GetServer() - pdctl.MustPutStore(re, svr, store) + tests.MustPutStore(re, cluster, store) defer cluster.Destroy() output, err := pdctl.ExecuteCommand(cmd, "-u", pdAddr, "config", "show", "server") diff --git a/tests/pdctl/health/health_test.go b/tests/pdctl/health/health_test.go index bc808a367501..748250babe4e 100644 --- a/tests/pdctl/health/health_test.go +++ b/tests/pdctl/health/health_test.go @@ -36,7 +36,7 @@ func TestHealth(t *testing.T) { err = tc.RunInitialServers() re.NoError(err) tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) pdAddr := tc.GetConfig().GetClientURL() cmd := pdctlCmd.GetRootCmd() diff --git a/tests/pdctl/helper.go b/tests/pdctl/helper.go index d7d6a8584978..3912cdfef7c8 100644 --- a/tests/pdctl/helper.go +++ b/tests/pdctl/helper.go @@ -16,21 +16,13 @@ package pdctl import ( "bytes" - "context" - "fmt" "sort" - "github.com/docker/go-units" - "github.com/pingcap/kvproto/pkg/metapb" - "github.com/pingcap/kvproto/pkg/pdpb" "github.com/spf13/cobra" "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/core" "github.com/tikv/pd/pkg/utils/typeutil" - "github.com/tikv/pd/pkg/versioninfo" - "github.com/tikv/pd/server" "github.com/tikv/pd/server/api" - "github.com/tikv/pd/tests" ) // 
ExecuteCommand is used for test purpose. @@ -89,59 +81,3 @@ func CheckRegionsInfo(re *require.Assertions, output *api.RegionsInfo, expected CheckRegionInfo(re, &got[i], region) } } - -// MustPutStore is used for test purpose. -func MustPutStore(re *require.Assertions, svr *server.Server, store *metapb.Store) { - store.Address = fmt.Sprintf("tikv%d", store.GetId()) - if len(store.Version) == 0 { - store.Version = versioninfo.MinSupportedVersion(versioninfo.Version2_0).String() - } - grpcServer := &server.GrpcServer{Server: svr} - _, err := grpcServer.PutStore(context.Background(), &pdpb.PutStoreRequest{ - Header: &pdpb.RequestHeader{ClusterId: svr.ClusterID()}, - Store: store, - }) - re.NoError(err) - - storeInfo := grpcServer.GetRaftCluster().GetStore(store.GetId()) - newStore := storeInfo.Clone(core.SetStoreStats(&pdpb.StoreStats{ - Capacity: uint64(10 * units.GiB), - UsedSize: uint64(9 * units.GiB), - Available: uint64(1 * units.GiB), - })) - grpcServer.GetRaftCluster().GetBasicCluster().PutStore(newStore) -} - -// MustPutRegion is used for test purpose. -func MustPutRegion(re *require.Assertions, cluster *tests.TestCluster, regionID, storeID uint64, start, end []byte, opts ...core.RegionCreateOption) *core.RegionInfo { - leader := &metapb.Peer{ - Id: regionID, - StoreId: storeID, - } - metaRegion := &metapb.Region{ - Id: regionID, - StartKey: start, - EndKey: end, - Peers: []*metapb.Peer{leader}, - RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 1}, - } - r := core.NewRegionInfo(metaRegion, leader, opts...) - err := cluster.HandleRegionHeartbeat(r) - re.NoError(err) - return r -} - -// MustReportBuckets is used for test purpose. -func MustReportBuckets(re *require.Assertions, cluster *tests.TestCluster, regionID uint64, start, end []byte, stats *metapb.BucketStats) *metapb.Buckets { - buckets := &metapb.Buckets{ - RegionId: regionID, - Version: 1, - Keys: [][]byte{start, end}, - Stats: stats, - // report buckets interval is 10s - PeriodInMs: 10000, - } - err := cluster.HandleReportBuckets(buckets) - re.NoError(err) - return buckets -} diff --git a/tests/pdctl/hot/hot_test.go b/tests/pdctl/hot/hot_test.go index 352b891c0925..359d89199c9b 100644 --- a/tests/pdctl/hot/hot_test.go +++ b/tests/pdctl/hot/hot_test.go @@ -63,10 +63,10 @@ func TestHot(t *testing.T) { Labels: []*metapb.StoreLabel{{Key: "engine", Value: "tiflash"}}, } - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) - pdctl.MustPutStore(re, leaderServer.GetServer(), store1) - pdctl.MustPutStore(re, leaderServer.GetServer(), store2) + tests.MustPutStore(re, cluster, store1) + tests.MustPutStore(re, cluster, store2) defer cluster.Destroy() // test hot store @@ -159,7 +159,7 @@ func TestHot(t *testing.T) { } testHot(hotRegionID, hotStoreID, "read") case "write": - pdctl.MustPutRegion( + tests.MustPutRegion( re, cluster, hotRegionID, hotStoreID, []byte("c"), []byte("d"), @@ -222,16 +222,16 @@ func TestHotWithStoreID(t *testing.T) { }, } - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) for _, store := range stores { - pdctl.MustPutStore(re, leaderServer.GetServer(), store) + tests.MustPutStore(re, cluster, store) } defer cluster.Destroy() - pdctl.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(0, utils.RegionHeartBeatReportInterval)) - pdctl.MustPutRegion(re, cluster, 2, 2, 
[]byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), core.SetReportInterval(0, utils.RegionHeartBeatReportInterval)) - pdctl.MustPutRegion(re, cluster, 3, 1, []byte("e"), []byte("f"), core.SetWrittenBytes(9000000000), core.SetReportInterval(0, utils.RegionHeartBeatReportInterval)) + tests.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(0, utils.RegionHeartBeatReportInterval)) + tests.MustPutRegion(re, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), core.SetReportInterval(0, utils.RegionHeartBeatReportInterval)) + tests.MustPutRegion(re, cluster, 3, 1, []byte("e"), []byte("f"), core.SetWrittenBytes(9000000000), core.SetReportInterval(0, utils.RegionHeartBeatReportInterval)) // wait hot scheduler starts rc := leaderServer.GetRaftCluster() testutil.Eventually(re, func() bool { @@ -267,7 +267,7 @@ func TestHotWithStoreID(t *testing.T) { WriteBytes: []uint64{13 * units.MiB}, WriteQps: []uint64{0}, } - buckets := pdctl.MustReportBuckets(re, cluster, 1, []byte("a"), []byte("b"), stats) + buckets := tests.MustReportBuckets(re, cluster, 1, []byte("a"), []byte("b"), stats) args = []string{"-u", pdAddr, "hot", "buckets", "1"} output, err = pdctl.ExecuteCommand(cmd, args...) re.NoError(err) @@ -330,20 +330,20 @@ func TestHistoryHotRegions(t *testing.T) { }, } - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) for _, store := range stores { - pdctl.MustPutStore(re, leaderServer.GetServer(), store) + tests.MustPutStore(re, cluster, store) } defer cluster.Destroy() startTime := time.Now().Unix() - pdctl.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), + tests.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(uint64(startTime-utils.RegionHeartBeatReportInterval), uint64(startTime))) - pdctl.MustPutRegion(re, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), + tests.MustPutRegion(re, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), core.SetReportInterval(uint64(startTime-utils.RegionHeartBeatReportInterval), uint64(startTime))) - pdctl.MustPutRegion(re, cluster, 3, 1, []byte("e"), []byte("f"), core.SetWrittenBytes(9000000000), + tests.MustPutRegion(re, cluster, 3, 1, []byte("e"), []byte("f"), core.SetWrittenBytes(9000000000), core.SetReportInterval(uint64(startTime-utils.RegionHeartBeatReportInterval), uint64(startTime))) - pdctl.MustPutRegion(re, cluster, 4, 3, []byte("g"), []byte("h"), core.SetWrittenBytes(9000000000), + tests.MustPutRegion(re, cluster, 4, 3, []byte("g"), []byte("h"), core.SetWrittenBytes(9000000000), core.SetReportInterval(uint64(startTime-utils.RegionHeartBeatReportInterval), uint64(startTime))) // wait hot scheduler starts testutil.Eventually(re, func() bool { @@ -440,10 +440,10 @@ func TestHotWithoutHotPeer(t *testing.T) { }, } - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) for _, store := range stores { - pdctl.MustPutStore(re, leaderServer.GetServer(), store) + tests.MustPutStore(re, cluster, store) } timestamp := uint64(time.Now().UnixNano()) load := 1024.0 diff --git a/tests/pdctl/keyspace/keyspace_group_test.go b/tests/pdctl/keyspace/keyspace_group_test.go index 105e860ad173..0b09550d9676 100644 --- 
a/tests/pdctl/keyspace/keyspace_group_test.go +++ b/tests/pdctl/keyspace/keyspace_group_test.go @@ -44,7 +44,7 @@ func TestKeyspaceGroup(t *testing.T) { err = tc.RunInitialServers() re.NoError(err) tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) pdAddr := tc.GetConfig().GetClientURL() cmd := pdctlCmd.GetRootCmd() @@ -113,7 +113,7 @@ func TestSplitKeyspaceGroup(t *testing.T) { cmd := pdctlCmd.GetRootCmd() tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) // split keyspace group. @@ -164,7 +164,7 @@ func TestExternalAllocNodeWhenStart(t *testing.T) { cmd := pdctlCmd.GetRootCmd() tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) // check keyspace group information. @@ -207,7 +207,7 @@ func TestSetNodeAndPriorityKeyspaceGroup(t *testing.T) { cmd := pdctlCmd.GetRootCmd() tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) // set-node keyspace group. @@ -309,7 +309,7 @@ func TestMergeKeyspaceGroup(t *testing.T) { cmd := pdctlCmd.GetRootCmd() tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) // split keyspace group. @@ -427,7 +427,7 @@ func TestKeyspaceGroupState(t *testing.T) { cmd := pdctlCmd.GetRootCmd() tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) // split keyspace group. @@ -518,7 +518,7 @@ func TestShowKeyspaceGroupPrimary(t *testing.T) { cmd := pdctlCmd.GetRootCmd() tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) defaultKeyspaceGroupID := fmt.Sprintf("%d", utils.DefaultKeyspaceGroupID) @@ -600,7 +600,7 @@ func TestInPDMode(t *testing.T) { pdAddr := tc.GetConfig().GetClientURL() cmd := pdctlCmd.GetRootCmd() tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) argsList := [][]string{ diff --git a/tests/pdctl/keyspace/keyspace_test.go b/tests/pdctl/keyspace/keyspace_test.go index a0bab4114df8..57acdc86c707 100644 --- a/tests/pdctl/keyspace/keyspace_test.go +++ b/tests/pdctl/keyspace/keyspace_test.go @@ -58,7 +58,7 @@ func TestKeyspace(t *testing.T) { cmd := pdctlCmd.GetRootCmd() tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) defaultKeyspaceGroupID := fmt.Sprintf("%d", utils.DefaultKeyspaceGroupID) diff --git a/tests/pdctl/label/label_test.go b/tests/pdctl/label/label_test.go index ba31b1fb1d1e..9c64933a1279 100644 --- a/tests/pdctl/label/label_test.go +++ b/tests/pdctl/label/label_test.go @@ -92,11 +92,11 @@ func TestLabel(t *testing.T) { }, }, } - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) for _, store := range stores { - pdctl.MustPutStore(re, leaderServer.GetServer(), store.Store.Store) + tests.MustPutStore(re, cluster, store.Store.Store) } defer cluster.Destroy() diff --git a/tests/pdctl/log/log_test.go b/tests/pdctl/log/log_test.go 
index 7f2e4f205842..e69952313298 100644 --- a/tests/pdctl/log/log_test.go +++ b/tests/pdctl/log/log_test.go @@ -21,7 +21,6 @@ import ( "github.com/pingcap/kvproto/pkg/metapb" "github.com/stretchr/testify/suite" - "github.com/tikv/pd/server" "github.com/tikv/pd/tests" "github.com/tikv/pd/tests/pdctl" pdctlCmd "github.com/tikv/pd/tools/pd-ctl/pdctl" @@ -32,7 +31,6 @@ type logTestSuite struct { ctx context.Context cancel context.CancelFunc cluster *tests.TestCluster - svr *server.Server pdAddrs []string } @@ -54,10 +52,9 @@ func (suite *logTestSuite) SetupSuite() { State: metapb.StoreState_Up, LastHeartbeat: time.Now().UnixNano(), } - leaderServer := suite.cluster.GetServer(suite.cluster.GetLeader()) + leaderServer := suite.cluster.GetLeaderServer() suite.NoError(leaderServer.BootstrapCluster()) - suite.svr = leaderServer.GetServer() - pdctl.MustPutStore(suite.Require(), suite.svr, store) + tests.MustPutStore(suite.Require(), suite.cluster, store) } func (suite *logTestSuite) TearDownSuite() { @@ -97,7 +94,7 @@ func (suite *logTestSuite) TestLog() { for _, testCase := range testCases { _, err := pdctl.ExecuteCommand(cmd, testCase.cmd...) suite.NoError(err) - suite.Equal(testCase.expect, suite.svr.GetConfig().Log.Level) + suite.Equal(testCase.expect, suite.cluster.GetLeaderServer().GetConfig().Log.Level) } } diff --git a/tests/pdctl/member/member_test.go b/tests/pdctl/member/member_test.go index 9c7874992538..af3ee771e82b 100644 --- a/tests/pdctl/member/member_test.go +++ b/tests/pdctl/member/member_test.go @@ -38,7 +38,7 @@ func TestMember(t *testing.T) { err = cluster.RunInitialServers() re.NoError(err) cluster.WaitLeader() - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) pdAddr := cluster.GetConfig().GetClientURL() re.NoError(err) diff --git a/tests/pdctl/operator/operator_test.go b/tests/pdctl/operator/operator_test.go index 148cbc9e0815..a95c620adcfe 100644 --- a/tests/pdctl/operator/operator_test.go +++ b/tests/pdctl/operator/operator_test.go @@ -79,17 +79,17 @@ func TestOperator(t *testing.T) { }, } - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) for _, store := range stores { - pdctl.MustPutStore(re, leaderServer.GetServer(), store) + tests.MustPutStore(re, cluster, store) } - pdctl.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetPeers([]*metapb.Peer{ + tests.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetPeers([]*metapb.Peer{ {Id: 1, StoreId: 1}, {Id: 2, StoreId: 2}, })) - pdctl.MustPutRegion(re, cluster, 3, 2, []byte("b"), []byte("d"), core.SetPeers([]*metapb.Peer{ + tests.MustPutRegion(re, cluster, 3, 2, []byte("b"), []byte("d"), core.SetPeers([]*metapb.Peer{ {Id: 3, StoreId: 1}, {Id: 4, StoreId: 2}, })) @@ -261,7 +261,7 @@ func TestForwardOperatorRequest(t *testing.T) { re.NoError(err) re.NoError(cluster.RunInitialServers()) re.NotEmpty(cluster.WaitLeader()) - server := cluster.GetServer(cluster.GetLeader()) + server := cluster.GetLeaderServer() re.NoError(server.BootstrapCluster()) backendEndpoints := server.GetAddr() tc, err := tests.NewTestSchedulingCluster(ctx, 2, backendEndpoints) diff --git a/tests/pdctl/region/region_test.go b/tests/pdctl/region/region_test.go index d56463d728d7..b913f1b09239 100644 --- a/tests/pdctl/region/region_test.go +++ b/tests/pdctl/region/region_test.go @@ -45,9 +45,9 @@ func TestRegionKeyFormat(t *testing.T) { State: 
metapb.StoreState_Up, LastHeartbeat: time.Now().UnixNano(), } - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) - pdctl.MustPutStore(re, leaderServer.GetServer(), store) + tests.MustPutStore(re, cluster, store) cmd := pdctlCmd.GetRootCmd() output, err := pdctl.ExecuteCommand(cmd, "-u", url, "region", "key", "--format=raw", " ") @@ -72,12 +72,12 @@ func TestRegion(t *testing.T) { State: metapb.StoreState_Up, LastHeartbeat: time.Now().UnixNano(), } - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) - pdctl.MustPutStore(re, leaderServer.GetServer(), store) + tests.MustPutStore(re, cluster, store) downPeer := &metapb.Peer{Id: 8, StoreId: 3} - r1 := pdctl.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), + r1 := tests.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(1000), core.SetReadBytes(1000), core.SetRegionConfVer(1), core.SetRegionVersion(1), core.SetApproximateSize(1), core.SetApproximateKeys(100), core.SetPeers([]*metapb.Peer{ @@ -86,16 +86,16 @@ func TestRegion(t *testing.T) { {Id: 6, StoreId: 3}, {Id: 7, StoreId: 4}, })) - r2 := pdctl.MustPutRegion(re, cluster, 2, 1, []byte("b"), []byte("c"), + r2 := tests.MustPutRegion(re, cluster, 2, 1, []byte("b"), []byte("c"), core.SetWrittenBytes(2000), core.SetReadBytes(0), core.SetRegionConfVer(2), core.SetRegionVersion(3), core.SetApproximateSize(144), core.SetApproximateKeys(14400), ) - r3 := pdctl.MustPutRegion(re, cluster, 3, 1, []byte("c"), []byte("d"), + r3 := tests.MustPutRegion(re, cluster, 3, 1, []byte("c"), []byte("d"), core.SetWrittenBytes(500), core.SetReadBytes(800), core.SetRegionConfVer(3), core.SetRegionVersion(2), core.SetApproximateSize(30), core.SetApproximateKeys(3000), core.WithDownPeers([]*pdpb.PeerStats{{Peer: downPeer, DownSeconds: 3600}}), core.WithPendingPeers([]*metapb.Peer{downPeer}), core.WithLearners([]*metapb.Peer{{Id: 3, StoreId: 1}})) - r4 := pdctl.MustPutRegion(re, cluster, 4, 1, []byte("d"), []byte("e"), + r4 := tests.MustPutRegion(re, cluster, 4, 1, []byte("d"), []byte("e"), core.SetWrittenBytes(100), core.SetReadBytes(100), core.SetRegionConfVer(1), core.SetRegionVersion(1), core.SetApproximateSize(10), core.SetApproximateKeys(1000), ) @@ -197,7 +197,7 @@ func TestRegion(t *testing.T) { } // Test region range-holes. - r5 := pdctl.MustPutRegion(re, cluster, 5, 1, []byte("x"), []byte("z")) + r5 := tests.MustPutRegion(re, cluster, 5, 1, []byte("x"), []byte("z")) output, err := pdctl.ExecuteCommand(cmd, []string{"-u", pdAddr, "region", "range-holes"}...) re.NoError(err) rangeHoles := new([][]string) diff --git a/tests/pdctl/scheduler/scheduler_test.go b/tests/pdctl/scheduler/scheduler_test.go index 31e6270aa3bc..f2d44a589a4d 100644 --- a/tests/pdctl/scheduler/scheduler_test.go +++ b/tests/pdctl/scheduler/scheduler_test.go @@ -94,14 +94,14 @@ func TestScheduler(t *testing.T) { re.Equal(expectedConfig, configInfo) } - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) for _, store := range stores { - pdctl.MustPutStore(re, leaderServer.GetServer(), store) + tests.MustPutStore(re, cluster, store) } // note: because pdqsort is a unstable sort algorithm, set ApproximateSize for this region. 
- pdctl.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetApproximateSize(10)) + tests.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetApproximateSize(10)) time.Sleep(3 * time.Second) // scheduler show command @@ -363,7 +363,7 @@ func TestScheduler(t *testing.T) { for _, store := range stores { version := versioninfo.HotScheduleWithQuery store.Version = versioninfo.MinSupportedVersion(version).String() - pdctl.MustPutStore(re, leaderServer.GetServer(), store) + tests.MustPutStore(re, cluster, store) } re.Equal("5.2.0", leaderServer.GetClusterVersion().String()) // After upgrading, we should not use query. @@ -488,14 +488,14 @@ func TestSchedulerDiagnostic(t *testing.T) { LastHeartbeat: time.Now().UnixNano(), }, } - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) for _, store := range stores { - pdctl.MustPutStore(re, leaderServer.GetServer(), store) + tests.MustPutStore(re, cluster, store) } // note: because pdqsort is a unstable sort algorithm, set ApproximateSize for this region. - pdctl.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetApproximateSize(10)) + tests.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetApproximateSize(10)) time.Sleep(3 * time.Second) echo := mustExec(re, cmd, []string{"-u", pdAddr, "config", "set", "enable-diagnostic", "true"}, nil) @@ -539,7 +539,7 @@ func TestForwardSchedulerRequest(t *testing.T) { re.NoError(err) re.NoError(cluster.RunInitialServers()) re.NotEmpty(cluster.WaitLeader()) - server := cluster.GetServer(cluster.GetLeader()) + server := cluster.GetLeaderServer() re.NoError(server.BootstrapCluster()) backendEndpoints := server.GetAddr() tc, err := tests.NewTestSchedulingCluster(ctx, 2, backendEndpoints) diff --git a/tests/pdctl/store/store_test.go b/tests/pdctl/store/store_test.go index 0ac68e35d985..13c7350bb6f1 100644 --- a/tests/pdctl/store/store_test.go +++ b/tests/pdctl/store/store_test.go @@ -79,11 +79,11 @@ func TestStore(t *testing.T) { }, } - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) for _, store := range stores { - pdctl.MustPutStore(re, leaderServer.GetServer(), store.Store.Store) + tests.MustPutStore(re, cluster, store.Store.Store) } defer cluster.Destroy() @@ -293,7 +293,7 @@ func TestStore(t *testing.T) { NodeState: metapb.NodeState_Serving, LastHeartbeat: time.Now().UnixNano(), } - pdctl.MustPutStore(re, leaderServer.GetServer(), store2) + tests.MustPutStore(re, cluster, store2) } // store delete command @@ -506,15 +506,15 @@ func TestTombstoneStore(t *testing.T) { }, } - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) for _, store := range stores { - pdctl.MustPutStore(re, leaderServer.GetServer(), store.Store.Store) + tests.MustPutStore(re, cluster, store.Store.Store) } defer cluster.Destroy() - pdctl.MustPutRegion(re, cluster, 1, 2, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(0, utils.RegionHeartBeatReportInterval)) - pdctl.MustPutRegion(re, cluster, 2, 3, []byte("b"), []byte("c"), core.SetWrittenBytes(3000000000), core.SetReportInterval(0, utils.RegionHeartBeatReportInterval)) + tests.MustPutRegion(re, cluster, 1, 2, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(0, 
utils.RegionHeartBeatReportInterval)) + tests.MustPutRegion(re, cluster, 2, 3, []byte("b"), []byte("c"), core.SetWrittenBytes(3000000000), core.SetReportInterval(0, utils.RegionHeartBeatReportInterval)) // store remove-tombstone args := []string{"-u", pdAddr, "store", "remove-tombstone"} output, err := pdctl.ExecuteCommand(cmd, args...) diff --git a/tests/pdctl/unsafe/unsafe_operation_test.go b/tests/pdctl/unsafe/unsafe_operation_test.go index 1e4e34682253..e0fdb9835911 100644 --- a/tests/pdctl/unsafe/unsafe_operation_test.go +++ b/tests/pdctl/unsafe/unsafe_operation_test.go @@ -33,7 +33,7 @@ func TestRemoveFailedStores(t *testing.T) { err = cluster.RunInitialServers() re.NoError(err) cluster.WaitLeader() - err = cluster.GetServer(cluster.GetLeader()).BootstrapCluster() + err = cluster.GetLeaderServer().BootstrapCluster() re.NoError(err) pdAddr := cluster.GetConfig().GetClientURL() cmd := pdctlCmd.GetRootCmd() diff --git a/tests/registry/registry_test.go b/tests/registry/registry_test.go index da68bddd354c..a3aff76a1cff 100644 --- a/tests/registry/registry_test.go +++ b/tests/registry/registry_test.go @@ -76,8 +76,8 @@ func TestRegistryService(t *testing.T) { err = cluster.RunInitialServers() re.NoError(err) - leaderName := cluster.WaitLeader() - leader := cluster.GetServer(leaderName) + re.NotEmpty(cluster.WaitLeader()) + leader := cluster.GetLeaderServer() // Test registered GRPC Service cc, err := grpc.DialContext(ctx, strings.TrimPrefix(leader.GetAddr(), "http://"), grpc.WithInsecure()) diff --git a/tests/server/api/api_test.go b/tests/server/api/api_test.go index cc35d9eaab3f..ff430f1b848b 100644 --- a/tests/server/api/api_test.go +++ b/tests/server/api/api_test.go @@ -40,7 +40,6 @@ import ( "github.com/tikv/pd/server/api" "github.com/tikv/pd/server/config" "github.com/tikv/pd/tests" - "github.com/tikv/pd/tests/pdctl" "go.uber.org/goleak" ) @@ -64,6 +63,7 @@ func TestReconnect(t *testing.T) { // Make connections to followers. // Make sure they proxy requests to the leader. 
leader := cluster.WaitLeader() + re.NotEmpty(leader) for name, s := range cluster.GetServers() { if name != leader { res, err := http.Get(s.GetConfig().AdvertiseClientUrls + "/pd/api/v1/version") @@ -136,7 +136,7 @@ func (suite *middlewareTestSuite) TearDownSuite() { func (suite *middlewareTestSuite) TestRequestInfoMiddleware() { suite.NoError(failpoint.Enable("github.com/tikv/pd/server/api/addRequestInfoMiddleware", "return(true)")) - leader := suite.cluster.GetServer(suite.cluster.GetLeader()) + leader := suite.cluster.GetLeaderServer() suite.NotNil(leader) input := map[string]interface{}{ @@ -190,7 +190,7 @@ func BenchmarkDoRequestWithServiceMiddleware(b *testing.B) { cluster, _ := tests.NewTestCluster(ctx, 1) cluster.RunInitialServers() cluster.WaitLeader() - leader := cluster.GetServer(cluster.GetLeader()) + leader := cluster.GetLeaderServer() input := map[string]interface{}{ "enable-audit": "true", } @@ -207,7 +207,7 @@ func BenchmarkDoRequestWithServiceMiddleware(b *testing.B) { } func (suite *middlewareTestSuite) TestRateLimitMiddleware() { - leader := suite.cluster.GetServer(suite.cluster.GetLeader()) + leader := suite.cluster.GetLeaderServer() suite.NotNil(leader) input := map[string]interface{}{ "enable-rate-limit": "true", @@ -296,7 +296,7 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { servers = append(servers, s.GetServer()) } server.MustWaitLeader(suite.Require(), servers) - leader = suite.cluster.GetServer(suite.cluster.GetLeader()) + leader = suite.cluster.GetLeaderServer() suite.Equal(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled(), true) cfg, ok := leader.GetServer().GetRateLimitConfig().LimiterConfig["SetLogLevel"] suite.Equal(ok, true) @@ -372,7 +372,7 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { } func (suite *middlewareTestSuite) TestSwaggerUrl() { - leader := suite.cluster.GetServer(suite.cluster.GetLeader()) + leader := suite.cluster.GetLeaderServer() suite.NotNil(leader) req, _ := http.NewRequest(http.MethodGet, leader.GetAddr()+"/swagger/ui/index", nil) resp, err := dialClient.Do(req) @@ -382,7 +382,7 @@ func (suite *middlewareTestSuite) TestSwaggerUrl() { } func (suite *middlewareTestSuite) TestAuditPrometheusBackend() { - leader := suite.cluster.GetServer(suite.cluster.GetLeader()) + leader := suite.cluster.GetLeaderServer() suite.NotNil(leader) input := map[string]interface{}{ "enable-audit": "true", @@ -418,7 +418,7 @@ func (suite *middlewareTestSuite) TestAuditPrometheusBackend() { servers = append(servers, s.GetServer()) } server.MustWaitLeader(suite.Require(), servers) - leader = suite.cluster.GetServer(suite.cluster.GetLeader()) + leader = suite.cluster.GetLeaderServer() timeUnix = time.Now().Unix() - 20 req, _ = http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/api/v1/trend?from=%d", leader.GetAddr(), timeUnix), nil) @@ -451,7 +451,7 @@ func (suite *middlewareTestSuite) TestAuditPrometheusBackend() { func (suite *middlewareTestSuite) TestAuditLocalLogBackend() { fname := testutil.InitTempFileLogger("info") defer os.RemoveAll(fname) - leader := suite.cluster.GetServer(suite.cluster.GetLeader()) + leader := suite.cluster.GetLeaderServer() suite.NotNil(leader) input := map[string]interface{}{ "enable-audit": "true", @@ -481,7 +481,7 @@ func BenchmarkDoRequestWithLocalLogAudit(b *testing.B) { cluster, _ := tests.NewTestCluster(ctx, 1) cluster.RunInitialServers() cluster.WaitLeader() - leader := cluster.GetServer(cluster.GetLeader()) + leader := cluster.GetLeaderServer() input := 
map[string]interface{}{ "enable-audit": "true", } @@ -503,7 +503,7 @@ func BenchmarkDoRequestWithPrometheusAudit(b *testing.B) { cluster, _ := tests.NewTestCluster(ctx, 1) cluster.RunInitialServers() cluster.WaitLeader() - leader := cluster.GetServer(cluster.GetLeader()) + leader := cluster.GetLeaderServer() input := map[string]interface{}{ "enable-audit": "true", } @@ -525,7 +525,7 @@ func BenchmarkDoRequestWithoutServiceMiddleware(b *testing.B) { cluster, _ := tests.NewTestCluster(ctx, 1) cluster.RunInitialServers() cluster.WaitLeader() - leader := cluster.GetServer(cluster.GetLeader()) + leader := cluster.GetLeaderServer() input := map[string]interface{}{ "enable-audit": "false", } @@ -586,7 +586,7 @@ func (suite *redirectorTestSuite) TearDownSuite() { func (suite *redirectorTestSuite) TestRedirect() { re := suite.Require() - leader := suite.cluster.GetServer(suite.cluster.GetLeader()) + leader := suite.cluster.GetLeaderServer() suite.NotNil(leader) header := mustRequestSuccess(re, leader.GetServer()) header.Del("Date") @@ -602,7 +602,7 @@ func (suite *redirectorTestSuite) TestRedirect() { func (suite *redirectorTestSuite) TestAllowFollowerHandle() { // Find a follower. var follower *server.Server - leader := suite.cluster.GetServer(suite.cluster.GetLeader()) + leader := suite.cluster.GetLeaderServer() for _, svr := range suite.cluster.GetServers() { if svr != leader { follower = svr.GetServer() @@ -626,7 +626,7 @@ func (suite *redirectorTestSuite) TestAllowFollowerHandle() { func (suite *redirectorTestSuite) TestNotLeader() { // Find a follower. var follower *server.Server - leader := suite.cluster.GetServer(suite.cluster.GetLeader()) + leader := suite.cluster.GetLeaderServer() for _, svr := range suite.cluster.GetServers() { if svr != leader { follower = svr.GetServer() @@ -657,7 +657,7 @@ func (suite *redirectorTestSuite) TestNotLeader() { } func (suite *redirectorTestSuite) TestXForwardedFor() { - leader := suite.cluster.GetServer(suite.cluster.GetLeader()) + leader := suite.cluster.GetLeaderServer() suite.NoError(leader.BootstrapCluster()) fname := testutil.InitTempFileLogger("info") defer os.RemoveAll(fname) @@ -702,7 +702,7 @@ func TestRemovingProgress(t *testing.T) { re.NoError(err) cluster.WaitLeader() - leader := cluster.GetServer(cluster.GetLeader()) + leader := cluster.GetLeaderServer() grpcPDClient := testutil.MustNewGrpcClient(re, leader.GetAddr()) clusterID := leader.GetClusterID() req := &pdpb.BootstrapRequest{ @@ -735,12 +735,12 @@ func TestRemovingProgress(t *testing.T) { } for _, store := range stores { - pdctl.MustPutStore(re, leader.GetServer(), store) + tests.MustPutStore(re, cluster, store) } - pdctl.MustPutRegion(re, cluster, 1000, 1, []byte("a"), []byte("b"), core.SetApproximateSize(60)) - pdctl.MustPutRegion(re, cluster, 1001, 2, []byte("c"), []byte("d"), core.SetApproximateSize(30)) - pdctl.MustPutRegion(re, cluster, 1002, 1, []byte("e"), []byte("f"), core.SetApproximateSize(50)) - pdctl.MustPutRegion(re, cluster, 1003, 2, []byte("g"), []byte("h"), core.SetApproximateSize(40)) + tests.MustPutRegion(re, cluster, 1000, 1, []byte("a"), []byte("b"), core.SetApproximateSize(60)) + tests.MustPutRegion(re, cluster, 1001, 2, []byte("c"), []byte("d"), core.SetApproximateSize(30)) + tests.MustPutRegion(re, cluster, 1002, 1, []byte("e"), []byte("f"), core.SetApproximateSize(50)) + tests.MustPutRegion(re, cluster, 1003, 2, []byte("g"), []byte("h"), core.SetApproximateSize(40)) // no store removing output := sendRequest(re, 
leader.GetAddr()+"/pd/api/v1/stores/progress?action=removing", http.MethodGet, http.StatusNotFound) @@ -762,8 +762,8 @@ func TestRemovingProgress(t *testing.T) { re.Equal(math.MaxFloat64, p.LeftSeconds) // update size - pdctl.MustPutRegion(re, cluster, 1000, 1, []byte("a"), []byte("b"), core.SetApproximateSize(20)) - pdctl.MustPutRegion(re, cluster, 1001, 2, []byte("c"), []byte("d"), core.SetApproximateSize(10)) + tests.MustPutRegion(re, cluster, 1000, 1, []byte("a"), []byte("b"), core.SetApproximateSize(20)) + tests.MustPutRegion(re, cluster, 1001, 2, []byte("c"), []byte("d"), core.SetApproximateSize(10)) // is not prepared time.Sleep(2 * time.Second) @@ -817,7 +817,8 @@ func TestSendApiWhenRestartRaftCluster(t *testing.T) { err = cluster.RunInitialServers() re.NoError(err) - leader := cluster.GetServer(cluster.WaitLeader()) + re.NotEmpty(cluster.WaitLeader()) + leader := cluster.GetLeaderServer() grpcPDClient := testutil.MustNewGrpcClient(re, leader.GetAddr()) clusterID := leader.GetClusterID() @@ -860,7 +861,7 @@ func TestPreparingProgress(t *testing.T) { re.NoError(err) cluster.WaitLeader() - leader := cluster.GetServer(cluster.GetLeader()) + leader := cluster.GetLeaderServer() grpcPDClient := testutil.MustNewGrpcClient(re, leader.GetAddr()) clusterID := leader.GetClusterID() req := &pdpb.BootstrapRequest{ @@ -910,10 +911,10 @@ func TestPreparingProgress(t *testing.T) { } for _, store := range stores { - pdctl.MustPutStore(re, leader.GetServer(), store) + tests.MustPutStore(re, cluster, store) } for i := 0; i < 100; i++ { - pdctl.MustPutRegion(re, cluster, uint64(i+1), uint64(i)%3+1, []byte(fmt.Sprintf("p%d", i)), []byte(fmt.Sprintf("%d", i+1)), core.SetApproximateSize(10)) + tests.MustPutRegion(re, cluster, uint64(i+1), uint64(i)%3+1, []byte(fmt.Sprintf("p%d", i)), []byte(fmt.Sprintf("%d", i+1)), core.SetApproximateSize(10)) } // no store preparing output := sendRequest(re, leader.GetAddr()+"/pd/api/v1/stores/progress?action=preparing", http.MethodGet, http.StatusNotFound) @@ -940,8 +941,8 @@ func TestPreparingProgress(t *testing.T) { re.Equal(math.MaxFloat64, p.LeftSeconds) // update size - pdctl.MustPutRegion(re, cluster, 1000, 4, []byte(fmt.Sprintf("%d", 1000)), []byte(fmt.Sprintf("%d", 1001)), core.SetApproximateSize(10)) - pdctl.MustPutRegion(re, cluster, 1001, 5, []byte(fmt.Sprintf("%d", 1001)), []byte(fmt.Sprintf("%d", 1002)), core.SetApproximateSize(40)) + tests.MustPutRegion(re, cluster, 1000, 4, []byte(fmt.Sprintf("%d", 1000)), []byte(fmt.Sprintf("%d", 1001)), core.SetApproximateSize(10)) + tests.MustPutRegion(re, cluster, 1001, 5, []byte(fmt.Sprintf("%d", 1001)), []byte(fmt.Sprintf("%d", 1002)), core.SetApproximateSize(40)) time.Sleep(2 * time.Second) output = sendRequest(re, leader.GetAddr()+"/pd/api/v1/stores/progress?action=preparing", http.MethodGet, http.StatusOK) re.NoError(json.Unmarshal(output, &p)) diff --git a/tests/server/apiv2/handlers/keyspace_test.go b/tests/server/apiv2/handlers/keyspace_test.go index 7fd8de013f7e..f7b43ab194d1 100644 --- a/tests/server/apiv2/handlers/keyspace_test.go +++ b/tests/server/apiv2/handlers/keyspace_test.go @@ -53,7 +53,7 @@ func (suite *keyspaceTestSuite) SetupTest() { suite.NoError(err) suite.NoError(cluster.RunInitialServers()) suite.NotEmpty(cluster.WaitLeader()) - suite.server = cluster.GetServer(cluster.GetLeader()) + suite.server = cluster.GetLeaderServer() suite.NoError(suite.server.BootstrapCluster()) suite.NoError(failpoint.Enable("github.com/tikv/pd/pkg/keyspace/skipSplitRegion", "return(true)")) } diff --git 
a/tests/server/apiv2/handlers/tso_keyspace_group_test.go b/tests/server/apiv2/handlers/tso_keyspace_group_test.go index 1f0189c532fc..214de6e95ef5 100644 --- a/tests/server/apiv2/handlers/tso_keyspace_group_test.go +++ b/tests/server/apiv2/handlers/tso_keyspace_group_test.go @@ -45,7 +45,7 @@ func (suite *keyspaceGroupTestSuite) SetupTest() { suite.NoError(err) suite.NoError(cluster.RunInitialServers()) suite.NotEmpty(cluster.WaitLeader()) - suite.server = cluster.GetServer(cluster.GetLeader()) + suite.server = cluster.GetLeaderServer() suite.NoError(suite.server.BootstrapCluster()) } diff --git a/tests/server/cluster/cluster_test.go b/tests/server/cluster/cluster_test.go index f22a754b8bf4..e1b04c4ebc10 100644 --- a/tests/server/cluster/cluster_test.go +++ b/tests/server/cluster/cluster_test.go @@ -71,7 +71,7 @@ func TestBootstrap(t *testing.T) { re.NoError(err) tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() @@ -111,7 +111,7 @@ func TestDamagedRegion(t *testing.T) { re.NoError(err) tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) @@ -191,7 +191,7 @@ func TestStaleRegion(t *testing.T) { re.NoError(err) tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) @@ -236,7 +236,7 @@ func TestGetPutConfig(t *testing.T) { re.NoError(err) tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) @@ -465,7 +465,7 @@ func TestRaftClusterRestart(t *testing.T) { re.NoError(err) tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) @@ -495,7 +495,7 @@ func TestRaftClusterMultipleRestart(t *testing.T) { re.NoError(err) tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) @@ -538,7 +538,7 @@ func TestGetPDMembers(t *testing.T) { re.NoError(err) tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() req := &pdpb.GetMembersRequest{Header: testutil.NewRequestHeader(clusterID)} @@ -582,7 +582,7 @@ func TestStoreVersionChange(t *testing.T) { re.NoError(err) tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) @@ -620,7 +620,7 @@ func TestConcurrentHandleRegion(t *testing.T) { 
err = tc.RunInitialServers() re.NoError(err) tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) @@ -737,7 +737,7 @@ func TestSetScheduleOpt(t *testing.T) { re.NoError(err) tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) @@ -808,7 +808,7 @@ func TestLoadClusterInfo(t *testing.T) { re.NoError(err) tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() svr := leaderServer.GetServer() rc := cluster.NewRaftCluster(ctx, svr.ClusterID(), syncer.NewRegionSyncer(svr), svr.GetClient(), svr.GetHTTPClient()) @@ -896,7 +896,7 @@ func TestTiFlashWithPlacementRules(t *testing.T) { err = tc.RunInitialServers() re.NoError(err) tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) @@ -949,7 +949,7 @@ func TestReplicationModeStatus(t *testing.T) { err = tc.RunInitialServers() re.NoError(err) tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() req := newBootstrapRequest(clusterID) @@ -1049,7 +1049,7 @@ func TestOfflineStoreLimit(t *testing.T) { err = tc.RunInitialServers() re.NoError(err) tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) @@ -1141,7 +1141,7 @@ func TestUpgradeStoreLimit(t *testing.T) { err = tc.RunInitialServers() re.NoError(err) tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) @@ -1199,7 +1199,7 @@ func TestStaleTermHeartbeat(t *testing.T) { err = tc.RunInitialServers() re.NoError(err) tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) @@ -1334,7 +1334,7 @@ func TestMinResolvedTS(t *testing.T) { err = tc.RunInitialServers() re.NoError(err) tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() id := leaderServer.GetAllocator() grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() @@ -1443,7 +1443,7 @@ func TestTransferLeaderBack(t *testing.T) { err = tc.RunInitialServers() re.NoError(err) tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() svr := leaderServer.GetServer() rc := cluster.NewRaftCluster(ctx, svr.ClusterID(), syncer.NewRegionSyncer(svr), svr.GetClient(), svr.GetHTTPClient()) 
rc.InitCluster(svr.GetAllocator(), svr.GetPersistOptions(), svr.GetStorage(), svr.GetBasicCluster(), svr.GetKeyspaceGroupManager()) @@ -1470,7 +1470,7 @@ func TestTransferLeaderBack(t *testing.T) { // transfer PD leader to another PD tc.ResignLeader() tc.WaitLeader() - leaderServer = tc.GetServer(tc.GetLeader()) + leaderServer = tc.GetLeaderServer() svr1 := leaderServer.GetServer() rc1 := svr1.GetRaftCluster() re.NoError(err) @@ -1483,7 +1483,7 @@ func TestTransferLeaderBack(t *testing.T) { // transfer PD leader back to the previous PD tc.ResignLeader() tc.WaitLeader() - leaderServer = tc.GetServer(tc.GetLeader()) + leaderServer = tc.GetLeaderServer() svr = leaderServer.GetServer() rc = svr.GetRaftCluster() re.NotNil(rc) @@ -1503,7 +1503,7 @@ func TestExternalTimestamp(t *testing.T) { err = tc.RunInitialServers() re.NoError(err) tc.WaitLeader() - leaderServer := tc.GetServer(tc.GetLeader()) + leaderServer := tc.GetLeaderServer() grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) diff --git a/tests/server/cluster/cluster_work_test.go b/tests/server/cluster/cluster_work_test.go index f0f24ca67770..ef09e5223059 100644 --- a/tests/server/cluster/cluster_work_test.go +++ b/tests/server/cluster/cluster_work_test.go @@ -42,7 +42,7 @@ func TestValidRequestRegion(t *testing.T) { re.NoError(err) cluster.WaitLeader() - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) @@ -86,7 +86,7 @@ func TestAskSplit(t *testing.T) { re.NoError(err) cluster.WaitLeader() - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) @@ -143,7 +143,7 @@ func TestSuspectRegions(t *testing.T) { re.NoError(err) cluster.WaitLeader() - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() bootstrapCluster(re, clusterID, grpcPDClient) diff --git a/tests/server/config/config_test.go b/tests/server/config/config_test.go index b9a746b8bedb..1b2178bde33d 100644 --- a/tests/server/config/config_test.go +++ b/tests/server/config/config_test.go @@ -43,7 +43,7 @@ func TestRateLimitConfigReload(t *testing.T) { defer cluster.Destroy() re.NoError(cluster.RunInitialServers()) re.NotEmpty(cluster.WaitLeader()) - leader := cluster.GetServer(cluster.GetLeader()) + leader := cluster.GetLeaderServer() re.NotNil(leader) re.Empty(leader.GetServer().GetServiceMiddlewareConfig().RateLimitConfig.LimiterConfig) limitCfg := make(map[string]ratelimit.DimensionConfig) @@ -69,7 +69,7 @@ func TestRateLimitConfigReload(t *testing.T) { servers = append(servers, s.GetServer()) } server.MustWaitLeader(re, servers) - leader = cluster.GetServer(cluster.GetLeader()) + leader = cluster.GetLeaderServer() re.NotNil(leader) re.True(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled()) re.Len(leader.GetServer().GetServiceMiddlewarePersistOptions().GetRateLimitConfig().LimiterConfig, 1) diff --git a/tests/server/id/id_test.go b/tests/server/id/id_test.go index 
c4e1c8bb5de8..737aa4deac25 100644 --- a/tests/server/id/id_test.go +++ b/tests/server/id/id_test.go @@ -44,7 +44,7 @@ func TestID(t *testing.T) { re.NoError(err) cluster.WaitLeader() - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() var last uint64 for i := uint64(0); i < allocStep; i++ { id, err := leaderServer.GetAllocator().Alloc() @@ -90,7 +90,7 @@ func TestCommand(t *testing.T) { re.NoError(err) cluster.WaitLeader() - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() req := &pdpb.AllocIDRequest{Header: testutil.NewRequestHeader(leaderServer.GetClusterID())} grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) @@ -116,7 +116,7 @@ func TestMonotonicID(t *testing.T) { re.NoError(err) cluster.WaitLeader() - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() var last1 uint64 for i := uint64(0); i < 10; i++ { id, err := leaderServer.GetAllocator().Alloc() @@ -127,7 +127,7 @@ func TestMonotonicID(t *testing.T) { err = cluster.ResignLeader() re.NoError(err) cluster.WaitLeader() - leaderServer = cluster.GetServer(cluster.GetLeader()) + leaderServer = cluster.GetLeaderServer() var last2 uint64 for i := uint64(0); i < 10; i++ { id, err := leaderServer.GetAllocator().Alloc() @@ -138,7 +138,7 @@ func TestMonotonicID(t *testing.T) { err = cluster.ResignLeader() re.NoError(err) cluster.WaitLeader() - leaderServer = cluster.GetServer(cluster.GetLeader()) + leaderServer = cluster.GetLeaderServer() id, err := leaderServer.GetAllocator().Alloc() re.NoError(err) re.Greater(id, last2) @@ -162,7 +162,7 @@ func TestPDRestart(t *testing.T) { err = cluster.RunInitialServers() re.NoError(err) cluster.WaitLeader() - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() var last uint64 for i := uint64(0); i < 10; i++ { diff --git a/tests/server/keyspace/keyspace_test.go b/tests/server/keyspace/keyspace_test.go index a36a73795502..86b8f6fd37c5 100644 --- a/tests/server/keyspace/keyspace_test.go +++ b/tests/server/keyspace/keyspace_test.go @@ -59,7 +59,7 @@ func (suite *keyspaceTestSuite) SetupTest() { suite.NoError(err) suite.NoError(cluster.RunInitialServers()) suite.NotEmpty(cluster.WaitLeader()) - suite.server = cluster.GetServer(cluster.GetLeader()) + suite.server = cluster.GetLeaderServer() suite.manager = suite.server.GetKeyspaceManager() suite.NoError(suite.server.BootstrapCluster()) } diff --git a/tests/server/member/member_test.go b/tests/server/member/member_test.go index ca89e66a0415..26d4fa2a9044 100644 --- a/tests/server/member/member_test.go +++ b/tests/server/member/member_test.go @@ -63,7 +63,7 @@ func TestMemberDelete(t *testing.T) { re.NoError(err) leaderName := cluster.WaitLeader() re.NotEmpty(leaderName) - leader := cluster.GetServer(leaderName) + leader := cluster.GetLeaderServer() var members []*tests.TestServer for _, s := range cluster.GetConfig().InitialServers { if s.Name != leaderName { diff --git a/tests/server/region_syncer/region_syncer_test.go b/tests/server/region_syncer/region_syncer_test.go index afa5c87cdcc9..f672f82f1f64 100644 --- a/tests/server/region_syncer/region_syncer_test.go +++ b/tests/server/region_syncer/region_syncer_test.go @@ -57,7 +57,7 @@ func TestRegionSyncer(t *testing.T) { re.NoError(cluster.RunInitialServers()) cluster.WaitLeader() - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() 
re.NoError(leaderServer.BootstrapCluster()) rc := leaderServer.GetServer().GetRaftCluster() re.NotNil(rc) @@ -140,7 +140,7 @@ func TestRegionSyncer(t *testing.T) { err = leaderServer.Stop() re.NoError(err) cluster.WaitLeader() - leaderServer = cluster.GetServer(cluster.GetLeader()) + leaderServer = cluster.GetLeaderServer() re.NotNil(leaderServer) loadRegions := leaderServer.GetServer().GetRaftCluster().GetRegions() re.Len(loadRegions, regionLen) @@ -166,7 +166,7 @@ func TestFullSyncWithAddMember(t *testing.T) { err = cluster.RunInitialServers() re.NoError(err) cluster.WaitLeader() - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) rc := leaderServer.GetServer().GetRaftCluster() re.NotNil(rc) @@ -210,7 +210,7 @@ func TestPrepareChecker(t *testing.T) { err = cluster.RunInitialServers() re.NoError(err) cluster.WaitLeader() - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) rc := leaderServer.GetServer().GetRaftCluster() re.NotNil(rc) @@ -235,7 +235,7 @@ func TestPrepareChecker(t *testing.T) { err = cluster.ResignLeader() re.NoError(err) re.Equal("pd2", cluster.WaitLeader()) - leaderServer = cluster.GetServer(cluster.GetLeader()) + leaderServer = cluster.GetLeaderServer() rc = leaderServer.GetServer().GetRaftCluster() for _, region := range regions { err = rc.HandleRegionHeartbeat(region) diff --git a/tests/server/storage/hot_region_storage_test.go b/tests/server/storage/hot_region_storage_test.go index 21881802d7d0..00d0244a7905 100644 --- a/tests/server/storage/hot_region_storage_test.go +++ b/tests/server/storage/hot_region_storage_test.go @@ -29,7 +29,6 @@ import ( "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/server/config" "github.com/tikv/pd/tests" - "github.com/tikv/pd/tests/pdctl" ) func TestHotRegionStorage(t *testing.T) { @@ -61,20 +60,20 @@ func TestHotRegionStorage(t *testing.T) { }, } - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) for _, store := range stores { - pdctl.MustPutStore(re, leaderServer.GetServer(), store) + tests.MustPutStore(re, cluster, store) } defer cluster.Destroy() startTime := time.Now().Unix() - pdctl.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), + tests.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(uint64(startTime-utils.RegionHeartBeatReportInterval), uint64(startTime))) - pdctl.MustPutRegion(re, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), + tests.MustPutRegion(re, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), core.SetReportInterval(uint64(startTime-utils.RegionHeartBeatReportInterval), uint64(startTime))) - pdctl.MustPutRegion(re, cluster, 3, 1, []byte("e"), []byte("f"), + tests.MustPutRegion(re, cluster, 3, 1, []byte("e"), []byte("f"), core.SetReportInterval(uint64(startTime-utils.RegionHeartBeatReportInterval), uint64(startTime))) - pdctl.MustPutRegion(re, cluster, 4, 2, []byte("g"), []byte("h"), + tests.MustPutRegion(re, cluster, 4, 2, []byte("g"), []byte("h"), core.SetReportInterval(uint64(startTime-utils.RegionHeartBeatReportInterval), uint64(startTime))) storeStats := []*pdpb.StoreStats{ { @@ -169,14 +168,14 @@ func TestHotRegionStorageReservedDayConfigChange(t *testing.T) { }, } - 
leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) for _, store := range stores { - pdctl.MustPutStore(re, leaderServer.GetServer(), store) + tests.MustPutStore(re, cluster, store) } defer cluster.Destroy() startTime := time.Now().Unix() - pdctl.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), + tests.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(uint64(startTime-utils.RegionHeartBeatReportInterval), uint64(startTime))) var iter storage.HotRegionStorageIterator var next *storage.HistoryHotRegion @@ -197,7 +196,7 @@ func TestHotRegionStorageReservedDayConfigChange(t *testing.T) { schedule.HotRegionsReservedDays = 0 leaderServer.GetServer().SetScheduleConfig(schedule) time.Sleep(3 * interval) - pdctl.MustPutRegion(re, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), + tests.MustPutRegion(re, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), core.SetReportInterval(uint64(time.Now().Unix()-utils.RegionHeartBeatReportInterval), uint64(time.Now().Unix()))) time.Sleep(10 * interval) hotRegionStorage := leaderServer.GetServer().GetHistoryHotRegionStorage() @@ -261,14 +260,14 @@ func TestHotRegionStorageWriteIntervalConfigChange(t *testing.T) { }, } - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() re.NoError(leaderServer.BootstrapCluster()) for _, store := range stores { - pdctl.MustPutStore(re, leaderServer.GetServer(), store) + tests.MustPutStore(re, cluster, store) } defer cluster.Destroy() startTime := time.Now().Unix() - pdctl.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), + tests.MustPutRegion(re, cluster, 1, 1, []byte("a"), []byte("b"), core.SetWrittenBytes(3000000000), core.SetReportInterval(uint64(startTime-utils.RegionHeartBeatReportInterval), uint64(startTime))) var iter storage.HotRegionStorageIterator @@ -290,7 +289,7 @@ func TestHotRegionStorageWriteIntervalConfigChange(t *testing.T) { schedule.HotRegionsWriteInterval.Duration = 20 * interval leaderServer.GetServer().SetScheduleConfig(schedule) time.Sleep(3 * interval) - pdctl.MustPutRegion(re, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), + tests.MustPutRegion(re, cluster, 2, 2, []byte("c"), []byte("d"), core.SetWrittenBytes(6000000000), core.SetReportInterval(uint64(time.Now().Unix()-utils.RegionHeartBeatReportInterval), uint64(time.Now().Unix()))) time.Sleep(10 * interval) // it cant get new hot region because wait time smaller than hot region write interval diff --git a/tests/server/tso/consistency_test.go b/tests/server/tso/consistency_test.go index db6e2135d2b7..9cfadbf5ba30 100644 --- a/tests/server/tso/consistency_test.go +++ b/tests/server/tso/consistency_test.go @@ -79,7 +79,7 @@ func (suite *tsoConsistencyTestSuite) TestSynchronizedGlobalTSO() { re := suite.Require() cluster.WaitAllLeaders(re, dcLocationConfig) - suite.leaderServer = cluster.GetServer(cluster.GetLeader()) + suite.leaderServer = cluster.GetLeaderServer() suite.NotNil(suite.leaderServer) suite.dcClientMap[tso.GlobalDCLocation] = testutil.MustNewGrpcClient(re, suite.leaderServer.GetAddr()) for _, dcLocation := range dcLocationConfig { @@ -154,7 +154,7 @@ func (suite *tsoConsistencyTestSuite) TestSynchronizedGlobalTSOOverflow() { re := suite.Require() cluster.WaitAllLeaders(re, dcLocationConfig) - suite.leaderServer = 
cluster.GetServer(cluster.GetLeader()) + suite.leaderServer = cluster.GetLeaderServer() suite.NotNil(suite.leaderServer) suite.dcClientMap[tso.GlobalDCLocation] = testutil.MustNewGrpcClient(re, suite.leaderServer.GetAddr()) for _, dcLocation := range dcLocationConfig { @@ -186,7 +186,7 @@ func (suite *tsoConsistencyTestSuite) TestLocalAllocatorLeaderChange() { re := suite.Require() cluster.WaitAllLeaders(re, dcLocationConfig) - suite.leaderServer = cluster.GetServer(cluster.GetLeader()) + suite.leaderServer = cluster.GetLeaderServer() suite.NotNil(suite.leaderServer) suite.dcClientMap[tso.GlobalDCLocation] = testutil.MustNewGrpcClient(re, suite.leaderServer.GetAddr()) for _, dcLocation := range dcLocationConfig { @@ -248,7 +248,7 @@ func (suite *tsoConsistencyTestSuite) TestLocalTSOAfterMemberChanged() { re := suite.Require() cluster.WaitAllLeaders(re, dcLocationConfig) - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() leaderCli := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) req := &pdpb.TsoRequest{ Header: testutil.NewRequestHeader(cluster.GetCluster().GetId()), @@ -286,7 +286,7 @@ func (suite *tsoConsistencyTestSuite) TestLocalTSOAfterMemberChanged() { func (suite *tsoConsistencyTestSuite) testTSO(cluster *tests.TestCluster, dcLocationConfig map[string]string, previousTS *pdpb.Timestamp) { re := suite.Require() - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() dcClientMap := make(map[string]pdpb.PDClient) for _, dcLocation := range dcLocationConfig { pdName := leaderServer.GetAllocatorLeader(dcLocation).GetName() diff --git a/tests/server/tso/global_tso_test.go b/tests/server/tso/global_tso_test.go index a6340e2671c6..5ae2e6e0f675 100644 --- a/tests/server/tso/global_tso_test.go +++ b/tests/server/tso/global_tso_test.go @@ -97,7 +97,7 @@ func TestDelaySyncTimestamp(t *testing.T) { cluster.WaitLeader() var leaderServer, nextLeaderServer *tests.TestServer - leaderServer = cluster.GetServer(cluster.GetLeader()) + leaderServer = cluster.GetLeaderServer() re.NotNil(leaderServer) for _, s := range cluster.GetServers() { if s.GetConfig().Name != cluster.GetLeader() { @@ -145,7 +145,7 @@ func TestLogicalOverflow(t *testing.T) { re.NoError(cluster.RunInitialServers()) cluster.WaitLeader() - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() diff --git a/tests/server/tso/tso_test.go b/tests/server/tso/tso_test.go index 48df02a6c275..9eff1192e570 100644 --- a/tests/server/tso/tso_test.go +++ b/tests/server/tso/tso_test.go @@ -76,7 +76,7 @@ func TestLoadTimestamp(t *testing.T) { func requestLocalTSOs(re *require.Assertions, cluster *tests.TestCluster, dcLocationConfig map[string]string) map[string]*pdpb.Timestamp { dcClientMap := make(map[string]pdpb.PDClient) tsMap := make(map[string]*pdpb.Timestamp) - leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() for _, dcLocation := range dcLocationConfig { pdName := leaderServer.GetAllocatorLeader(dcLocation).GetName() dcClientMap[dcLocation] = testutil.MustNewGrpcClient(re, cluster.GetServer(pdName).GetAddr()) @@ -125,7 +125,7 @@ func TestDisableLocalTSOAfterEnabling(t *testing.T) { cluster.WaitLeader() // Re-request the global TSOs. 
- leaderServer := cluster.GetServer(cluster.GetLeader()) + leaderServer := cluster.GetLeaderServer() grpcPDClient := testutil.MustNewGrpcClient(re, leaderServer.GetAddr()) clusterID := leaderServer.GetClusterID() req := &pdpb.TsoRequest{ diff --git a/tests/server/watch/leader_watch_test.go b/tests/server/watch/leader_watch_test.go index 049486ba068f..f77652970236 100644 --- a/tests/server/watch/leader_watch_test.go +++ b/tests/server/watch/leader_watch_test.go @@ -42,7 +42,7 @@ func TestWatcher(t *testing.T) { err = cluster.RunInitialServers() re.NoError(err) cluster.WaitLeader() - pd1 := cluster.GetServer(cluster.GetLeader()) + pd1 := cluster.GetLeaderServer() re.NotNil(pd1) pd2, err := cluster.Join(ctx) @@ -80,7 +80,7 @@ func TestWatcherCompacted(t *testing.T) { err = cluster.RunInitialServers() re.NoError(err) cluster.WaitLeader() - pd1 := cluster.GetServer(cluster.GetLeader()) + pd1 := cluster.GetLeaderServer() re.NotNil(pd1) client := pd1.GetEtcdClient() _, err = client.Put(context.Background(), "test", "v") diff --git a/tests/testutil.go b/tests/testutil.go index 53efcff76580..3fd8e9dca351 100644 --- a/tests/testutil.go +++ b/tests/testutil.go @@ -16,19 +16,26 @@ package tests import ( "context" + "fmt" "os" "sync" "time" + "github.com/docker/go-units" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/log" "github.com/stretchr/testify/require" bs "github.com/tikv/pd/pkg/basicserver" + "github.com/tikv/pd/pkg/core" rm "github.com/tikv/pd/pkg/mcs/resourcemanager/server" scheduling "github.com/tikv/pd/pkg/mcs/scheduling/server" sc "github.com/tikv/pd/pkg/mcs/scheduling/server/config" tso "github.com/tikv/pd/pkg/mcs/tso/server" "github.com/tikv/pd/pkg/utils/logutil" "github.com/tikv/pd/pkg/utils/testutil" + "github.com/tikv/pd/pkg/versioninfo" + "github.com/tikv/pd/server" "go.uber.org/zap" ) @@ -148,3 +155,68 @@ func WaitForPrimaryServing(re *require.Assertions, serverMap map[string]bs.Serve return primary } + +// MustPutStore is used for test purpose. +func MustPutStore(re *require.Assertions, cluster *TestCluster, store *metapb.Store) { + store.Address = fmt.Sprintf("tikv%d", store.GetId()) + if len(store.Version) == 0 { + store.Version = versioninfo.MinSupportedVersion(versioninfo.Version2_0).String() + } + svr := cluster.GetLeaderServer().GetServer() + grpcServer := &server.GrpcServer{Server: svr} + _, err := grpcServer.PutStore(context.Background(), &pdpb.PutStoreRequest{ + Header: &pdpb.RequestHeader{ClusterId: svr.ClusterID()}, + Store: store, + }) + re.NoError(err) + + storeInfo := grpcServer.GetRaftCluster().GetStore(store.GetId()) + newStore := storeInfo.Clone(core.SetStoreStats(&pdpb.StoreStats{ + Capacity: uint64(10 * units.GiB), + UsedSize: uint64(9 * units.GiB), + Available: uint64(1 * units.GiB), + })) + grpcServer.GetRaftCluster().GetBasicCluster().PutStore(newStore) + if cluster.GetSchedulingPrimaryServer() != nil { + cluster.GetSchedulingPrimaryServer().GetCluster().PutStore(newStore) + } +} + +// MustPutRegion is used for test purpose. +func MustPutRegion(re *require.Assertions, cluster *TestCluster, regionID, storeID uint64, start, end []byte, opts ...core.RegionCreateOption) *core.RegionInfo { + leader := &metapb.Peer{ + Id: regionID, + StoreId: storeID, + } + metaRegion := &metapb.Region{ + Id: regionID, + StartKey: start, + EndKey: end, + Peers: []*metapb.Peer{leader}, + RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 1}, + } + r := core.NewRegionInfo(metaRegion, leader, opts...) 
+ err := cluster.HandleRegionHeartbeat(r) + re.NoError(err) + if cluster.GetSchedulingPrimaryServer() != nil { + err = cluster.GetSchedulingPrimaryServer().GetCluster().HandleRegionHeartbeat(r) + re.NoError(err) + } + return r +} + +// MustReportBuckets is used for test purpose. +func MustReportBuckets(re *require.Assertions, cluster *TestCluster, regionID uint64, start, end []byte, stats *metapb.BucketStats) *metapb.Buckets { + buckets := &metapb.Buckets{ + RegionId: regionID, + Version: 1, + Keys: [][]byte{start, end}, + Stats: stats, + // report buckets interval is 10s + PeriodInMs: 10000, + } + err := cluster.HandleReportBuckets(buckets) + re.NoError(err) + // TODO: forwards to scheduling server after it supports buckets + return buckets +} From e94b4e44ff078b69ffb30a11085f19f17862f511 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Fri, 22 Sep 2023 00:24:16 +0800 Subject: [PATCH 4/8] schedulers: fix the grant-leader-scheuler store pause/resume (#7128) ref tikv/pd#5839 The grant-leader-scheduler should also check the store pause/resume after reloading the config. Signed-off-by: JmPotato Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- pkg/schedule/schedulers/balance_witness.go | 4 ++-- pkg/schedule/schedulers/evict_leader.go | 22 +++++++++++++--------- pkg/schedule/schedulers/grant_leader.go | 1 + 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/pkg/schedule/schedulers/balance_witness.go b/pkg/schedule/schedulers/balance_witness.go index 9bd8a592ba13..e9bab6c1bc7c 100644 --- a/pkg/schedule/schedulers/balance_witness.go +++ b/pkg/schedule/schedulers/balance_witness.go @@ -118,7 +118,7 @@ type balanceWitnessHandler struct { config *balanceWitnessSchedulerConfig } -func newbalanceWitnessHandler(conf *balanceWitnessSchedulerConfig) http.Handler { +func newBalanceWitnessHandler(conf *balanceWitnessSchedulerConfig) http.Handler { handler := &balanceWitnessHandler{ config: conf, rd: render.New(render.Options{IndentJSON: true}), @@ -161,7 +161,7 @@ func newBalanceWitnessScheduler(opController *operator.Controller, conf *balance retryQuota: newRetryQuota(), name: BalanceWitnessName, conf: conf, - handler: newbalanceWitnessHandler(conf), + handler: newBalanceWitnessHandler(conf), counter: balanceWitnessCounter, filterCounter: filter.NewCounter(filter.BalanceWitness.String()), } diff --git a/pkg/schedule/schedulers/evict_leader.go b/pkg/schedule/schedulers/evict_leader.go index 2551b9ac9cbb..3c3f0603408a 100644 --- a/pkg/schedule/schedulers/evict_leader.go +++ b/pkg/schedule/schedulers/evict_leader.go @@ -218,21 +218,25 @@ func (s *evictLeaderScheduler) ReloadConfig() error { if err = DecodeConfig([]byte(cfgData), newCfg); err != nil { return err } - // Resume and pause the leader transfer for each store. - for id := range s.conf.StoreIDWithRanges { - if _, ok := newCfg.StoreIDWithRanges[id]; ok { + pauseAndResumeLeaderTransfer(s.conf.cluster, s.conf.StoreIDWithRanges, newCfg.StoreIDWithRanges) + s.conf.StoreIDWithRanges = newCfg.StoreIDWithRanges + return nil +} + +// pauseAndResumeLeaderTransfer checks the old and new store IDs, and pause or resume the leader transfer. 
+func pauseAndResumeLeaderTransfer(cluster *core.BasicCluster, old, new map[uint64][]core.KeyRange) { + for id := range old { + if _, ok := new[id]; ok { continue } - s.conf.cluster.ResumeLeaderTransfer(id) + cluster.ResumeLeaderTransfer(id) } - for id := range newCfg.StoreIDWithRanges { - if _, ok := s.conf.StoreIDWithRanges[id]; ok { + for id := range new { + if _, ok := old[id]; ok { continue } - s.conf.cluster.PauseLeaderTransfer(id) + cluster.PauseLeaderTransfer(id) } - s.conf.StoreIDWithRanges = newCfg.StoreIDWithRanges - return nil } func (s *evictLeaderScheduler) Prepare(cluster sche.SchedulerCluster) error { diff --git a/pkg/schedule/schedulers/grant_leader.go b/pkg/schedule/schedulers/grant_leader.go index 7d1ff2f616c9..f244228a10f0 100644 --- a/pkg/schedule/schedulers/grant_leader.go +++ b/pkg/schedule/schedulers/grant_leader.go @@ -192,6 +192,7 @@ func (s *grantLeaderScheduler) ReloadConfig() error { if err = DecodeConfig([]byte(cfgData), newCfg); err != nil { return err } + pauseAndResumeLeaderTransfer(s.conf.cluster, s.conf.StoreIDWithRanges, newCfg.StoreIDWithRanges) s.conf.StoreIDWithRanges = newCfg.StoreIDWithRanges return nil } From e6c884139fc90679c6121b0aed75dc9ad3cc0dca Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Fri, 22 Sep 2023 11:19:45 +0800 Subject: [PATCH 5/8] Revert "cluster: handle region after report split (#6867)" (#7134) close tikv/pd#7133 Signed-off-by: Ryan Leung --- pkg/core/region.go | 40 ++++----- pkg/core/region_test.go | 4 +- pkg/core/store.go | 39 ++++---- pkg/core/store_option.go | 7 -- pkg/mcs/scheduling/server/cluster.go | 8 +- pkg/schedule/filter/counter.go | 2 - pkg/schedule/filter/counter_test.go | 2 +- pkg/schedule/filter/filters.go | 13 +-- pkg/schedule/filter/region_filters.go | 45 ++++------ pkg/schedule/filter/status.go | 5 +- pkg/schedule/plan/status.go | 5 +- pkg/schedule/schedulers/balance_leader.go | 13 ++- pkg/schedule/schedulers/balance_test.go | 8 -- pkg/syncer/client.go | 4 +- server/cluster/cluster.go | 24 +++-- server/cluster/cluster_worker.go | 105 ++++++++-------------- server/cluster/cluster_worker_test.go | 101 ++------------------- server/grpc_service.go | 32 +------ 18 files changed, 126 insertions(+), 331 deletions(-) diff --git a/pkg/core/region.go b/pkg/core/region.go index 2fec30de1326..4540f7aafb30 100644 --- a/pkg/core/region.go +++ b/pkg/core/region.go @@ -682,14 +682,9 @@ func (r *RegionInfo) isRegionRecreated() bool { return r.GetRegionEpoch().GetVersion() == 1 && r.GetRegionEpoch().GetConfVer() == 1 && (len(r.GetStartKey()) != 0 || len(r.GetEndKey()) != 0) } -// RegionChanged is a struct that records the changes of the region. -type RegionChanged struct { - IsNew, SaveKV, SaveCache, NeedSync bool -} - // RegionGuideFunc is a function that determines which follow-up operations need to be performed based on the origin // and new region information. -type RegionGuideFunc func(region, origin *RegionInfo) *RegionChanged +type RegionGuideFunc func(region, origin *RegionInfo) (isNew, saveKV, saveCache, needSync bool) // GenerateRegionGuideFunc is used to generate a RegionGuideFunc. Control the log output by specifying the log function. // nil means do not print the log. @@ -702,19 +697,18 @@ func GenerateRegionGuideFunc(enableLog bool) RegionGuideFunc { } // Save to storage if meta is updated. // Save to cache if meta or leader is updated, or contains any down/pending peer. - // Mark IsNew if the region in cache does not have leader. 
- return func(region, origin *RegionInfo) (changed *RegionChanged) { - changed = &RegionChanged{} + // Mark isNew if the region in cache does not have leader. + return func(region, origin *RegionInfo) (isNew, saveKV, saveCache, needSync bool) { if origin == nil { if log.GetLevel() <= zap.DebugLevel { debug("insert new region", zap.Uint64("region-id", region.GetID()), logutil.ZapRedactStringer("meta-region", RegionToHexMeta(region.GetMeta()))) } - changed.SaveKV, changed.SaveCache, changed.IsNew = true, true, true + saveKV, saveCache, isNew = true, true, true } else { if !origin.IsFromHeartbeat() { - changed.IsNew = true + isNew = true } r := region.GetRegionEpoch() o := origin.GetRegionEpoch() @@ -727,7 +721,7 @@ func GenerateRegionGuideFunc(enableLog bool) RegionGuideFunc { zap.Uint64("new-version", r.GetVersion()), ) } - changed.SaveKV, changed.SaveCache = true, true + saveKV, saveCache = true, true } if r.GetConfVer() > o.GetConfVer() { if log.GetLevel() <= zap.InfoLevel { @@ -738,11 +732,11 @@ func GenerateRegionGuideFunc(enableLog bool) RegionGuideFunc { zap.Uint64("new-confver", r.GetConfVer()), ) } - changed.SaveCache, changed.SaveKV = true, true + saveKV, saveCache = true, true } if region.GetLeader().GetId() != origin.GetLeader().GetId() { if origin.GetLeader().GetId() == 0 { - changed.IsNew = true + isNew = true } else if log.GetLevel() <= zap.InfoLevel { info("leader changed", zap.Uint64("region-id", region.GetID()), @@ -751,17 +745,17 @@ func GenerateRegionGuideFunc(enableLog bool) RegionGuideFunc { ) } // We check it first and do not return because the log is important for us to investigate, - changed.SaveCache, changed.NeedSync = true, true + saveCache, needSync = true, true } if len(region.GetPeers()) != len(origin.GetPeers()) { - changed.SaveCache, changed.SaveKV = true, true + saveKV, saveCache = true, true return } if len(region.GetBuckets().GetKeys()) != len(origin.GetBuckets().GetKeys()) { if log.GetLevel() <= zap.DebugLevel { debug("bucket key changed", zap.Uint64("region-id", region.GetID())) } - changed.SaveCache, changed.SaveKV = true, true + saveKV, saveCache = true, true return } // Once flow has changed, will update the cache. 
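[Editor's aside, not part of the patch: after this revert, callers of the region guide consume four independent boolean flags instead of a RegionChanged struct. The standalone sketch below uses hypothetical names (guideFunc, handleHeartbeat) and plain strings in place of *core.RegionInfo; under those assumptions it shows how the flags typically drive the cache, storage and sync decisions mirrored by the processRegionHeartbeat changes in this diff.]

package main

import "fmt"

// guideFunc mirrors the reverted RegionGuideFunc shape: four independent
// decisions derived from comparing a reported region with the cached origin.
type guideFunc func(region, origin string) (isNew, saveKV, saveCache, needSync bool)

// handleHeartbeat shows the typical consumption pattern of the four flags.
func handleHeartbeat(guide guideFunc, region, origin string) {
	isNew, saveKV, saveCache, needSync := guide(region, origin)
	if !saveCache && !isNew {
		// Nothing the cache cares about changed; only stats would be refreshed here.
		return
	}
	if saveCache {
		fmt.Println("update in-memory cache:", region)
	}
	if saveKV {
		fmt.Println("persist region meta:", region)
	}
	if saveKV || needSync {
		fmt.Println("notify region syncer:", region)
	}
}

func main() {
	// Pretend only the leader changed: cache update and sync, but no KV write.
	leaderChanged := guideFunc(func(region, origin string) (bool, bool, bool, bool) {
		return false, false, true, true
	})
	handleHeartbeat(leaderChanged, "region-1", "region-1-old-leader")
}

[End of editorial aside.]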
@@ -769,39 +763,39 @@ func GenerateRegionGuideFunc(enableLog bool) RegionGuideFunc { if region.GetRoundBytesWritten() != origin.GetRoundBytesWritten() || region.GetRoundBytesRead() != origin.GetRoundBytesRead() || region.flowRoundDivisor < origin.flowRoundDivisor { - changed.SaveCache, changed.NeedSync = true, true + saveCache, needSync = true, true return } if !SortedPeersStatsEqual(region.GetDownPeers(), origin.GetDownPeers()) { if log.GetLevel() <= zap.DebugLevel { debug("down-peers changed", zap.Uint64("region-id", region.GetID())) } - changed.SaveCache, changed.NeedSync = true, true + saveCache, needSync = true, true return } if !SortedPeersEqual(region.GetPendingPeers(), origin.GetPendingPeers()) { if log.GetLevel() <= zap.DebugLevel { debug("pending-peers changed", zap.Uint64("region-id", region.GetID())) } - changed.SaveCache, changed.NeedSync = true, true + saveCache, needSync = true, true return } if region.GetApproximateSize() != origin.GetApproximateSize() || region.GetApproximateKeys() != origin.GetApproximateKeys() { - changed.SaveCache = true + saveCache = true return } if region.GetReplicationStatus().GetState() != replication_modepb.RegionReplicationState_UNKNOWN && (region.GetReplicationStatus().GetState() != origin.GetReplicationStatus().GetState() || region.GetReplicationStatus().GetStateId() != origin.GetReplicationStatus().GetStateId()) { - changed.SaveCache = true + saveCache = true return } // Do not save to kv, because 1) flashback will be eventually set to // false, 2) flashback changes almost all regions in a cluster. // Saving kv may downgrade PD performance when there are many regions. if region.IsFlashbackChanged(origin) { - changed.SaveCache = true + saveCache = true return } } diff --git a/pkg/core/region_test.go b/pkg/core/region_test.go index 3b58f5ee15a3..1e6b43fbf964 100644 --- a/pkg/core/region_test.go +++ b/pkg/core/region_test.go @@ -333,8 +333,8 @@ func TestNeedSync(t *testing.T) { for _, testCase := range testCases { regionA := region.Clone(testCase.optionsA...) regionB := region.Clone(testCase.optionsB...) - changed := RegionGuide(regionA, regionB) - re.Equal(testCase.needSync, changed.NeedSync) + _, _, _, needSync := RegionGuide(regionA, regionB) + re.Equal(testCase.needSync, needSync) } } diff --git a/pkg/core/store.go b/pkg/core/store.go index cafb443bb7dd..1d3362cac0e4 100644 --- a/pkg/core/store.go +++ b/pkg/core/store.go @@ -36,7 +36,6 @@ const ( initialMinSpace = 8 * units.GiB // 2^33=8GB slowStoreThreshold = 80 awakenStoreInterval = 10 * time.Minute // 2 * slowScoreRecoveryTime - splitStoreWait = time.Minute // EngineKey is the label key used to indicate engine. 
EngineKey = "engine" @@ -51,23 +50,22 @@ const ( type StoreInfo struct { meta *metapb.Store *storeStats - pauseLeaderTransfer bool // not allow to be used as source or target of transfer leader - slowStoreEvicted bool // this store has been evicted as a slow store, should not transfer leader to it - slowTrendEvicted bool // this store has been evicted as a slow store by trend, should not transfer leader to it - leaderCount int - regionCount int - learnerCount int - witnessCount int - leaderSize int64 - regionSize int64 - pendingPeerCount int - lastPersistTime time.Time - leaderWeight float64 - regionWeight float64 - limiter storelimit.StoreLimit - minResolvedTS uint64 - lastAwakenTime time.Time - recentlySplitRegionsTime time.Time + pauseLeaderTransfer bool // not allow to be used as source or target of transfer leader + slowStoreEvicted bool // this store has been evicted as a slow store, should not transfer leader to it + slowTrendEvicted bool // this store has been evicted as a slow store by trend, should not transfer leader to it + leaderCount int + regionCount int + learnerCount int + witnessCount int + leaderSize int64 + regionSize int64 + pendingPeerCount int + lastPersistTime time.Time + leaderWeight float64 + regionWeight float64 + limiter storelimit.StoreLimit + minResolvedTS uint64 + lastAwakenTime time.Time } // NewStoreInfo creates StoreInfo with meta data. @@ -541,11 +539,6 @@ func (s *StoreInfo) NeedAwakenStore() bool { return s.GetLastHeartbeatTS().Sub(s.lastAwakenTime) > awakenStoreInterval } -// HasRecentlySplitRegions checks if there are some region are splitted in this store. -func (s *StoreInfo) HasRecentlySplitRegions() bool { - return time.Since(s.recentlySplitRegionsTime) < splitStoreWait -} - var ( // If a store's last heartbeat is storeDisconnectDuration ago, the store will // be marked as disconnected state. The value should be greater than tikv's diff --git a/pkg/core/store_option.go b/pkg/core/store_option.go index 4d8864ea4788..8a2aa1ef089f 100644 --- a/pkg/core/store_option.go +++ b/pkg/core/store_option.go @@ -274,10 +274,3 @@ func SetLastAwakenTime(lastAwaken time.Time) StoreCreateOption { store.lastAwakenTime = lastAwaken } } - -// SetRecentlySplitRegionsTime sets last split time for the store. -func SetRecentlySplitRegionsTime(recentlySplitRegionsTime time.Time) StoreCreateOption { - return func(store *StoreInfo) { - store.recentlySplitRegionsTime = recentlySplitRegionsTime - } -} diff --git a/pkg/mcs/scheduling/server/cluster.go b/pkg/mcs/scheduling/server/cluster.go index b2986f722df7..20af077d241f 100644 --- a/pkg/mcs/scheduling/server/cluster.go +++ b/pkg/mcs/scheduling/server/cluster.go @@ -433,8 +433,8 @@ func (c *Cluster) processRegionHeartbeat(region *core.RegionInfo) error { // Save to storage if meta is updated, except for flashback. // Save to cache if meta or leader is updated, or contains any down/pending peer. // Mark isNew if the region in cache does not have leader. - changed := core.GenerateRegionGuideFunc(true)(region, origin) - if !changed.SaveCache && !changed.IsNew { + isNew, _, saveCache, _ := core.GenerateRegionGuideFunc(true)(region, origin) + if !saveCache && !isNew { // Due to some config changes need to update the region stats as well, // so we do some extra checks here. 
if hasRegionStats && c.regionStats.RegionStatsNeedUpdate(region) { @@ -444,7 +444,7 @@ func (c *Cluster) processRegionHeartbeat(region *core.RegionInfo) error { } var overlaps []*core.RegionInfo - if changed.SaveCache { + if saveCache { // To prevent a concurrent heartbeat of another region from overriding the up-to-date region info by a stale one, // check its validation again here. // @@ -456,7 +456,7 @@ func (c *Cluster) processRegionHeartbeat(region *core.RegionInfo) error { cluster.HandleOverlaps(c, overlaps) } - cluster.Collect(c, region, c.GetRegionStores(region), hasRegionStats, changed.IsNew, c.IsPrepared()) + cluster.Collect(c, region, c.GetRegionStores(region), hasRegionStats, isNew, c.IsPrepared()) return nil } diff --git a/pkg/schedule/filter/counter.go b/pkg/schedule/filter/counter.go index 0619bbdde29c..0120ef5b6663 100644 --- a/pkg/schedule/filter/counter.go +++ b/pkg/schedule/filter/counter.go @@ -127,7 +127,6 @@ const ( storeStateTooManyPendingPeer storeStateRejectLeader storeStateSlowTrend - storeStateRecentlySplitRegions filtersLen ) @@ -157,7 +156,6 @@ var filters = [filtersLen]string{ "store-state-too-many-pending-peers-filter", "store-state-reject-leader-filter", "store-state-slow-trend-filter", - "store-state-recently-split-regions-filter", } // String implements fmt.Stringer interface. diff --git a/pkg/schedule/filter/counter_test.go b/pkg/schedule/filter/counter_test.go index f8b6c0bcb8dd..067a07f138b4 100644 --- a/pkg/schedule/filter/counter_test.go +++ b/pkg/schedule/filter/counter_test.go @@ -27,7 +27,7 @@ func TestString(t *testing.T) { expected string }{ {int(storeStateTombstone), "store-state-tombstone-filter"}, - {int(filtersLen - 1), "store-state-recently-split-regions-filter"}, + {int(filtersLen - 1), "store-state-slow-trend-filter"}, {int(filtersLen), "unknown"}, } diff --git a/pkg/schedule/filter/filters.go b/pkg/schedule/filter/filters.go index e76969127d1c..0d188e69180a 100644 --- a/pkg/schedule/filter/filters.go +++ b/pkg/schedule/filter/filters.go @@ -332,8 +332,6 @@ type StoreStateFilter struct { // If it checks failed, the operator will be put back to the waiting queue util the limit is available. // But the scheduler should keep the same with the operator level. OperatorLevel constant.PriorityLevel - // check the store not split recently in it if set true. - ForbidRecentlySplitRegions bool // Reason is used to distinguish the reason of store state filter Reason filterType } @@ -473,15 +471,6 @@ func (f *StoreStateFilter) hasRejectLeaderProperty(conf config.SharedConfigProvi return statusOK } -func (f *StoreStateFilter) hasRecentlySplitRegions(_ config.SharedConfigProvider, store *core.StoreInfo) *plan.Status { - if f.ForbidRecentlySplitRegions && store.HasRecentlySplitRegions() { - f.Reason = storeStateRecentlySplitRegions - return statusStoreRecentlySplitRegions - } - f.Reason = storeStateOK - return statusOK -} - // The condition table. // Y: the condition is temporary (expected to become false soon). // N: the condition is expected to be true for a long time. 
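[Editor's aside, not part of the patch: the store state filter evaluates an ordered list of condition functions and reports the first one that matches, which is why the recently-split condition removed above can be dropped without touching the other checks. The standalone sketch below uses hypothetical store/status types rather than PD's core.StoreInfo and plan.Status; under those assumptions it illustrates that first-match pattern.]

package main

import "fmt"

// status is a stand-in for a filter result: ok, or a named rejection reason.
type status struct {
	ok     bool
	reason string
}

var statusOK = status{ok: true}

// store is a minimal stand-in for the store state a filter inspects.
type store struct {
	down bool
	busy bool
}

// conditionFunc is one check in the chain; a non-ok status rejects the store.
type conditionFunc func(*store) status

// anyConditionMatch returns the first non-ok status, or statusOK if every check passes.
func anyConditionMatch(s *store, conds ...conditionFunc) status {
	for _, cond := range conds {
		if st := cond(s); !st.ok {
			return st
		}
	}
	return statusOK
}

func main() {
	isDown := func(s *store) status {
		if s.down {
			return status{reason: "store-state-down-filter"}
		}
		return statusOK
	}
	isBusy := func(s *store) status {
		if s.busy {
			return status{reason: "store-state-busy-filter"}
		}
		return statusOK
	}
	// Removing a condition (e.g. a recently-split check) only shortens this list.
	st := anyConditionMatch(&store{busy: true}, isDown, isBusy)
	fmt.Println(st.ok, st.reason) // false store-state-busy-filter
}

[End of editorial aside.]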
@@ -510,7 +499,7 @@ func (f *StoreStateFilter) anyConditionMatch(typ int, conf config.SharedConfigPr var funcs []conditionFunc switch typ { case leaderSource: - funcs = []conditionFunc{f.isRemoved, f.isDown, f.pauseLeaderTransfer, f.isDisconnected, f.hasRecentlySplitRegions} + funcs = []conditionFunc{f.isRemoved, f.isDown, f.pauseLeaderTransfer, f.isDisconnected} case regionSource: funcs = []conditionFunc{f.isBusy, f.exceedRemoveLimit, f.tooManySnapshots} case witnessSource: diff --git a/pkg/schedule/filter/region_filters.go b/pkg/schedule/filter/region_filters.go index 70cdb8500b0f..799cee7d90c8 100644 --- a/pkg/schedule/filter/region_filters.go +++ b/pkg/schedule/filter/region_filters.go @@ -24,6 +24,24 @@ import ( "github.com/tikv/pd/pkg/slice" ) +// SelectRegions selects regions that be selected from the list. +func SelectRegions(regions []*core.RegionInfo, filters ...RegionFilter) []*core.RegionInfo { + return filterRegionsBy(regions, func(r *core.RegionInfo) bool { + return slice.AllOf(filters, func(i int) bool { + return filters[i].Select(r).IsOK() + }) + }) +} + +func filterRegionsBy(regions []*core.RegionInfo, keepPred func(*core.RegionInfo) bool) (selected []*core.RegionInfo) { + for _, s := range regions { + if keepPred(s) { + selected = append(selected, s) + } + } + return +} + // SelectOneRegion selects one region that be selected from the list. func SelectOneRegion(regions []*core.RegionInfo, collector *plan.Collector, filters ...RegionFilter) *core.RegionInfo { for _, r := range regions { @@ -155,7 +173,7 @@ type SnapshotSenderFilter struct { senders map[uint64]struct{} } -// NewSnapshotSendFilter returns creates a RegionFilter that filters regions whose leader has sender limit on the specific store. +// NewSnapshotSendFilter returns creates a RegionFilter that filters regions with witness peer on the specific store. // level should be set as same with the operator priority level. func NewSnapshotSendFilter(stores []*core.StoreInfo, level constant.PriorityLevel) RegionFilter { senders := make(map[uint64]struct{}) @@ -175,28 +193,3 @@ func (f *SnapshotSenderFilter) Select(region *core.RegionInfo) *plan.Status { } return statusRegionLeaderSendSnapshotThrottled } - -// StoreRecentlySplitFilter filer the region whose leader store not recently split regions. -type StoreRecentlySplitFilter struct { - recentlySplitStores map[uint64]struct{} -} - -// NewStoreRecentlySplitFilter returns creates a StoreRecentlySplitFilter. -func NewStoreRecentlySplitFilter(stores []*core.StoreInfo) RegionFilter { - recentlySplitStores := make(map[uint64]struct{}) - for _, store := range stores { - if store.HasRecentlySplitRegions() { - recentlySplitStores[store.GetID()] = struct{}{} - } - } - return &StoreRecentlySplitFilter{recentlySplitStores: recentlySplitStores} -} - -// Select returns ok if the region leader not in the recentlySplitStores. 
-func (f *StoreRecentlySplitFilter) Select(region *core.RegionInfo) *plan.Status { - leaderStoreID := region.GetLeader().GetStoreId() - if _, ok := f.recentlySplitStores[leaderStoreID]; ok { - return statusStoreRecentlySplitRegions - } - return statusOK -} diff --git a/pkg/schedule/filter/status.go b/pkg/schedule/filter/status.go index 9b6665a2fa72..930c59e3ba87 100644 --- a/pkg/schedule/filter/status.go +++ b/pkg/schedule/filter/status.go @@ -39,9 +39,8 @@ var ( // store config limitation statusStoreRejectLeader = plan.NewStatus(plan.StatusStoreRejectLeader) - statusStoreNotMatchRule = plan.NewStatus(plan.StatusStoreNotMatchRule) - statusStoreNotMatchIsolation = plan.NewStatus(plan.StatusStoreNotMatchIsolation) - statusStoreRecentlySplitRegions = plan.NewStatus(plan.StatusStoreRecentlySplitRegions) + statusStoreNotMatchRule = plan.NewStatus(plan.StatusStoreNotMatchRule) + statusStoreNotMatchIsolation = plan.NewStatus(plan.StatusStoreNotMatchIsolation) // region filter status statusRegionPendingPeer = plan.NewStatus(plan.StatusRegionUnhealthy) diff --git a/pkg/schedule/plan/status.go b/pkg/schedule/plan/status.go index 847d03a17ff3..4242b6314939 100644 --- a/pkg/schedule/plan/status.go +++ b/pkg/schedule/plan/status.go @@ -72,8 +72,6 @@ const ( StatusStoreLowSpace = iota + 500 // StatusStoreNotExisted represents the store cannot be found in PD. StatusStoreNotExisted - // StatusStoreRecentlySplitRegions represents the store cannot be selected due to the region is splitting. - StatusStoreRecentlySplitRegions ) // TODO: define region status priority @@ -129,8 +127,7 @@ var statusText = map[StatusCode]string{ StatusStoreDown: "StoreDown", StatusStoreBusy: "StoreBusy", - StatusStoreNotExisted: "StoreNotExisted", - StatusStoreRecentlySplitRegions: "StoreRecentlySplitRegions", + StatusStoreNotExisted: "StoreNotExisted", // region StatusRegionHot: "RegionHot", diff --git a/pkg/schedule/schedulers/balance_leader.go b/pkg/schedule/schedulers/balance_leader.go index 46f7fdc29cdd..e5516317f461 100644 --- a/pkg/schedule/schedulers/balance_leader.go +++ b/pkg/schedule/schedulers/balance_leader.go @@ -48,6 +48,8 @@ const ( // Default value is 4 which is subjected by scheduler-max-waiting-operator and leader-schedule-limit // If you want to increase balance speed more, please increase above-mentioned param. 
BalanceLeaderBatchSize = 4 + // MaxBalanceLeaderBatchSize is maximum of balance leader batch size + MaxBalanceLeaderBatchSize = 10 transferIn = "transfer-in" transferOut = "transfer-out" @@ -148,7 +150,7 @@ func (handler *balanceLeaderHandler) UpdateConfig(w http.ResponseWriter, r *http handler.rd.JSON(w, httpCode, v) } -func (handler *balanceLeaderHandler) ListConfig(w http.ResponseWriter, _ *http.Request) { +func (handler *balanceLeaderHandler) ListConfig(w http.ResponseWriter, r *http.Request) { conf := handler.config.Clone() handler.rd.JSON(w, http.StatusOK, conf) } @@ -160,7 +162,6 @@ type balanceLeaderScheduler struct { conf *balanceLeaderSchedulerConfig handler http.Handler filters []filter.Filter - regionFilters filter.RegionFilter filterCounter *filter.Counter } @@ -180,7 +181,7 @@ func newBalanceLeaderScheduler(opController *operator.Controller, conf *balanceL option(s) } s.filters = []filter.Filter{ - &filter.StoreStateFilter{ActionScope: s.GetName(), TransferLeader: true, ForbidRecentlySplitRegions: true, OperatorLevel: constant.High}, + &filter.StoreStateFilter{ActionScope: s.GetName(), TransferLeader: true, OperatorLevel: constant.High}, filter.NewSpecialUseFilter(s.GetName()), } return s @@ -276,7 +277,7 @@ func (cs *candidateStores) less(iID uint64, scorei float64, jID uint64, scorej f return scorei > scorej } -// hasStore returns true when there are leftover stores. +// hasStore returns returns true when there are leftover stores. func (cs *candidateStores) hasStore() bool { return cs.index < len(cs.stores) } @@ -348,7 +349,6 @@ func (l *balanceLeaderScheduler) Schedule(cluster sche.SchedulerCluster, dryRun opInfluence := l.OpController.GetOpInfluence(cluster.GetBasicCluster()) kind := constant.NewScheduleKind(constant.LeaderKind, leaderSchedulePolicy) solver := newSolver(basePlan, kind, cluster, opInfluence) - l.regionFilters = filter.NewStoreRecentlySplitFilter(cluster.GetStores()) stores := cluster.GetStores() scoreFunc := func(store *core.StoreInfo) float64 { @@ -486,7 +486,7 @@ func (l *balanceLeaderScheduler) transferLeaderOut(solver *solver, collector *pl // the worst follower peer and transfers the leader. 
func (l *balanceLeaderScheduler) transferLeaderIn(solver *solver, collector *plan.Collector) *operator.Operator { solver.Region = filter.SelectOneRegion(solver.RandFollowerRegions(solver.TargetStoreID(), l.conf.Ranges), - nil, filter.NewRegionPendingFilter(), filter.NewRegionDownFilter(), l.regionFilters) + nil, filter.NewRegionPendingFilter(), filter.NewRegionDownFilter()) if solver.Region == nil { log.Debug("store has no follower", zap.String("scheduler", l.GetName()), zap.Uint64("store-id", solver.TargetStoreID())) balanceLeaderNoFollowerRegionCounter.Inc() @@ -508,7 +508,6 @@ func (l *balanceLeaderScheduler) transferLeaderIn(solver *solver, collector *pla balanceLeaderNoLeaderRegionCounter.Inc() return nil } - finalFilters := l.filters conf := solver.GetSchedulerConfig() if leaderFilter := filter.NewPlacementLeaderSafeguard(l.GetName(), conf, solver.GetBasicCluster(), solver.GetRuleManager(), solver.Region, solver.Source, false /*allowMoveLeader*/); leaderFilter != nil { diff --git a/pkg/schedule/schedulers/balance_test.go b/pkg/schedule/schedulers/balance_test.go index 3231716c6810..54fe8ff489bc 100644 --- a/pkg/schedule/schedulers/balance_test.go +++ b/pkg/schedule/schedulers/balance_test.go @@ -20,7 +20,6 @@ import ( "math/rand" "sort" "testing" - "time" "github.com/docker/go-units" "github.com/pingcap/kvproto/pkg/metapb" @@ -295,13 +294,6 @@ func (suite *balanceLeaderSchedulerTestSuite) TestBalanceLimit() { // Region1: F F F L suite.tc.UpdateLeaderCount(4, 16) suite.NotEmpty(suite.schedule()) - - // can't balance leader from 4 to 1 when store 1 has split in it. - store := suite.tc.GetStore(4) - store = store.Clone(core.SetRecentlySplitRegionsTime(time.Now())) - suite.tc.PutStore(store) - op := suite.schedule() - suite.Empty(op) } func (suite *balanceLeaderSchedulerTestSuite) TestBalanceLeaderSchedulePolicy() { diff --git a/pkg/syncer/client.go b/pkg/syncer/client.go index b0892a6736aa..ac409f901157 100644 --- a/pkg/syncer/client.go +++ b/pkg/syncer/client.go @@ -194,7 +194,7 @@ func (s *RegionSyncer) StartSyncWithLeader(addr string) { log.Debug("region is stale", zap.Stringer("origin", origin.GetMeta()), errs.ZapError(err)) continue } - changed := regionGuide(region, origin) + _, saveKV, _, _ := regionGuide(region, origin) overlaps := bc.PutRegion(region) if hasBuckets { @@ -202,7 +202,7 @@ func (s *RegionSyncer) StartSyncWithLeader(addr string) { region.UpdateBuckets(buckets[i], old) } } - if changed.SaveKV { + if saveKV { err = regionStorage.SaveRegion(r) } if err == nil { diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go index d42dbb21ed13..22e1b16d822b 100644 --- a/server/cluster/cluster.go +++ b/server/cluster/cluster.go @@ -1113,16 +1113,12 @@ func (c *RaftCluster) processRegionHeartbeat(region *core.RegionInfo) error { cluster.HandleStatsAsync(c, region) } + hasRegionStats := c.regionStats != nil + // Save to storage if meta is updated, except for flashback. // Save to cache if meta or leader is updated, or contains any down/pending peer. // Mark isNew if the region in cache does not have leader. - changed := regionGuide(region, origin) - return c.SaveRegion(region, changed) -} - -// SaveRegion saves region info into cache and PD storage. 
-func (c *RaftCluster) SaveRegion(region *core.RegionInfo, changed *core.RegionChanged) (err error) { - hasRegionStats := c.regionStats != nil - if !c.isAPIServiceMode && !changed.SaveKV && !changed.SaveCache && !changed.IsNew { + isNew, saveKV, saveCache, needSync := regionGuide(region, origin) + if !c.isAPIServiceMode && !saveKV && !saveCache && !isNew { // Due to some config changes need to update the region stats as well, // so we do some extra checks here. if hasRegionStats && c.regionStats.RegionStatsNeedUpdate(region) { @@ -1136,15 +1132,14 @@ func (c *RaftCluster) SaveRegion(region *core.RegionInfo, changed *core.RegionCh }) var overlaps []*core.RegionInfo - - if changed.SaveCache { + if saveCache { failpoint.Inject("decEpoch", func() { region = region.Clone(core.SetRegionConfVer(2), core.SetRegionVersion(2)) }) // To prevent a concurrent heartbeat of another region from overriding the up-to-date region info by a stale one, // check its validation again here. // - // However, it can't solve the race condition of concurrent heartbeats from the same region. + // However it can't solve the race condition of concurrent heartbeats from the same region. if overlaps, err = c.core.AtomicCheckAndPutRegion(region); err != nil { return err } @@ -1155,7 +1150,7 @@ func (c *RaftCluster) SaveRegion(region *core.RegionInfo, changed *core.RegionCh } if !c.isAPIServiceMode { - cluster.Collect(c, region, c.GetRegionStores(region), hasRegionStats, changed.IsNew, c.IsPrepared()) + cluster.Collect(c, region, c.GetRegionStores(region), hasRegionStats, isNew, c.IsPrepared()) } if c.storage != nil { @@ -1171,7 +1166,7 @@ func (c *RaftCluster) SaveRegion(region *core.RegionInfo, changed *core.RegionCh errs.ZapError(err)) } } - if changed.SaveKV { + if saveKV { if err := c.storage.SaveRegion(region.GetMeta()); err != nil { log.Error("failed to save region to storage", zap.Uint64("region-id", region.GetID()), @@ -1182,12 +1177,13 @@ func (c *RaftCluster) SaveRegion(region *core.RegionInfo, changed *core.RegionCh } } - if changed.SaveKV || changed.NeedSync { + if saveKV || needSync { select { case c.changedRegions <- region: default: } } + return nil } diff --git a/server/cluster/cluster_worker.go b/server/cluster/cluster_worker.go index 3036fe95b3ea..c1da97363b53 100644 --- a/server/cluster/cluster_worker.go +++ b/server/cluster/cluster_worker.go @@ -16,8 +16,6 @@ package cluster import ( "bytes" - "fmt" - "time" "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/metapb" @@ -28,13 +26,11 @@ import ( "github.com/tikv/pd/pkg/schedule/operator" "github.com/tikv/pd/pkg/statistics/buckets" "github.com/tikv/pd/pkg/utils/logutil" + "github.com/tikv/pd/pkg/utils/typeutil" "github.com/tikv/pd/pkg/versioninfo" "go.uber.org/zap" ) -// store doesn't pick balance leader source if the split region is bigger than maxSplitThreshold. -const maxSplitThreshold = 10 - // HandleRegionHeartbeat processes RegionInfo reports from client. func (c *RaftCluster) HandleRegionHeartbeat(region *core.RegionInfo) error { if err := c.processRegionHeartbeat(region); err != nil { @@ -45,58 +41,6 @@ func (c *RaftCluster) HandleRegionHeartbeat(region *core.RegionInfo) error { return nil } -// ProcessRegionSplit to process split region into region cache. -// it's different with the region heartbeat, it's only fill some new region into the region cache. -// so it doesn't consider the leader and hot statistics. 
-func (c *RaftCluster) ProcessRegionSplit(regions []*metapb.Region) []error { - if err := c.checkSplitRegions(regions); err != nil { - return []error{err} - } - total := len(regions) - 1 - regions[0], regions[total] = regions[total], regions[0] - leaderStoreID := uint64(0) - if r := c.core.GetRegion(regions[0].GetId()); r != nil { - leaderStoreID = r.GetLeader().GetStoreId() - } - if leaderStoreID == 0 { - return []error{errors.New("origin region no leader")} - } - leaderStore := c.GetStore(leaderStoreID) - if leaderStore == nil { - return []error{errors.New("leader store not found")} - } - errList := make([]error, 0, total) - for _, region := range regions { - if len(region.GetPeers()) == 0 { - errList = append(errList, errors.New(fmt.Sprintf("region:%d has no peer", region.GetId()))) - continue - } - // region split initiator store will be leader with a high probability - leader := region.Peers[0] - if leaderStoreID > 0 { - for _, peer := range region.GetPeers() { - if peer.GetStoreId() == leaderStoreID { - leader = peer - break - } - } - } - region := core.NewRegionInfo(region, leader) - changed := &core.RegionChanged{ - IsNew: true, SaveKV: true, SaveCache: true, NeedSync: true, - } - if err := c.SaveRegion(region, changed); err != nil { - errList = append(errList, err) - } - } - // If the number of regions exceeds the threshold, update the last split time. - if len(regions) >= maxSplitThreshold { - newStore := leaderStore.Clone(core.SetRecentlySplitRegionsTime(time.Now())) - c.core.PutStore(newStore) - } - return errList -} - // HandleAskSplit handles the split request. func (c *RaftCluster) HandleAskSplit(request *pdpb.AskSplitRequest) (*pdpb.AskSplitResponse, error) { if c.isSchedulingHalted() { @@ -221,6 +165,22 @@ func (c *RaftCluster) HandleAskBatchSplit(request *pdpb.AskBatchSplitRequest) (* return resp, nil } +func (c *RaftCluster) checkSplitRegion(left *metapb.Region, right *metapb.Region) error { + if left == nil || right == nil { + return errors.New("invalid split region") + } + + if !bytes.Equal(left.GetEndKey(), right.GetStartKey()) { + return errors.New("invalid split region") + } + + if len(right.GetEndKey()) == 0 || bytes.Compare(left.GetStartKey(), right.GetEndKey()) < 0 { + return nil + } + + return errors.New("invalid split region") +} + func (c *RaftCluster) checkSplitRegions(regions []*metapb.Region) error { if len(regions) <= 1 { return errors.New("invalid split region") @@ -244,18 +204,21 @@ func (c *RaftCluster) HandleReportSplit(request *pdpb.ReportSplitRequest) (*pdpb left := request.GetLeft() right := request.GetRight() - if errs := c.ProcessRegionSplit([]*metapb.Region{left, right}); len(errs) > 0 { + err := c.checkSplitRegion(left, right) + if err != nil { log.Warn("report split region is invalid", logutil.ZapRedactStringer("left-region", core.RegionToHexMeta(left)), logutil.ZapRedactStringer("right-region", core.RegionToHexMeta(right)), - zap.Errors("errs", errs), - ) - // error[0] may be checker error, others are ignored. - return nil, errs[0] + errs.ZapError(err)) + return nil, err } + // Build origin region by using left and right. 
+ originRegion := typeutil.DeepClone(right, core.RegionFactory) + originRegion.RegionEpoch = nil + originRegion.StartKey = left.GetStartKey() log.Info("region split, generate new region", - zap.Uint64("region-id", right.GetId()), + zap.Uint64("region-id", originRegion.GetId()), logutil.ZapRedactStringer("region-meta", core.RegionToHexMeta(left))) return &pdpb.ReportSplitResponse{}, nil } @@ -263,19 +226,21 @@ func (c *RaftCluster) HandleReportSplit(request *pdpb.ReportSplitRequest) (*pdpb // HandleBatchReportSplit handles the batch report split request. func (c *RaftCluster) HandleBatchReportSplit(request *pdpb.ReportBatchSplitRequest) (*pdpb.ReportBatchSplitResponse, error) { regions := request.GetRegions() + hrm := core.RegionsToHexMeta(regions) - if errs := c.ProcessRegionSplit(regions); len(errs) > 0 { + err := c.checkSplitRegions(regions) + if err != nil { log.Warn("report batch split region is invalid", zap.Stringer("region-meta", hrm), - zap.Errors("errs", errs)) - // error[0] may be checker error, others are ignored. - return nil, errs[0] + errs.ZapError(err)) + return nil, err } last := len(regions) - 1 - originRegionID := regions[last].GetId() + originRegion := typeutil.DeepClone(regions[last], core.RegionFactory) + hrm = core.RegionsToHexMeta(regions[:last]) log.Info("region batch split, generate new regions", - zap.Uint64("region-id", originRegionID), - zap.Stringer("new-peer", hrm[:last]), + zap.Uint64("region-id", originRegion.GetId()), + zap.Stringer("origin", hrm), zap.Int("total", last)) return &pdpb.ReportBatchSplitResponse{}, nil } diff --git a/server/cluster/cluster_worker_test.go b/server/cluster/cluster_worker_test.go index 98b9b8380f12..b376b38edc3a 100644 --- a/server/cluster/cluster_worker_test.go +++ b/server/cluster/cluster_worker_test.go @@ -23,23 +23,9 @@ import ( "github.com/stretchr/testify/require" "github.com/tikv/pd/pkg/core" "github.com/tikv/pd/pkg/mock/mockid" - "github.com/tikv/pd/pkg/schedule" "github.com/tikv/pd/pkg/storage" ) -func mockRegionPeer(cluster *RaftCluster, voters []uint64) []*metapb.Peer { - rst := make([]*metapb.Peer, len(voters)) - for i, v := range voters { - id, _ := cluster.AllocID() - rst[i] = &metapb.Peer{ - Id: id, - StoreId: v, - Role: metapb.PeerRole_Voter, - } - } - return rst -} - func TestReportSplit(t *testing.T) { re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) @@ -48,56 +34,12 @@ func TestReportSplit(t *testing.T) { _, opt, err := newTestScheduleConfig() re.NoError(err) cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = schedule.NewCoordinator(cluster.ctx, cluster, nil) - right := &metapb.Region{Id: 1, StartKey: []byte("a"), EndKey: []byte("c"), Peers: mockRegionPeer(cluster, []uint64{1, 2, 3}), - RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 1}} - region := core.NewRegionInfo(right, right.Peers[0]) - cluster.putRegion(region) - store := newTestStores(1, "2.0.0") - cluster.core.PutStore(store[0]) - - // split failed, split region keys must be continuous. - left := &metapb.Region{Id: 2, StartKey: []byte("a"), EndKey: []byte("b"), Peers: mockRegionPeer(cluster, []uint64{1, 2, 3}), - RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 2}} - _, err = cluster.HandleReportSplit(&pdpb.ReportSplitRequest{Left: right, Right: left}) - re.Error(err) - - // split success with continuous region keys. 
- right = &metapb.Region{Id: 1, StartKey: []byte("b"), EndKey: []byte("c"), Peers: mockRegionPeer(cluster, []uint64{1, 2, 3}), - RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 2}} + left := &metapb.Region{Id: 1, StartKey: []byte("a"), EndKey: []byte("b")} + right := &metapb.Region{Id: 2, StartKey: []byte("b"), EndKey: []byte("c")} _, err = cluster.HandleReportSplit(&pdpb.ReportSplitRequest{Left: left, Right: right}) re.NoError(err) - // no range hole - storeID := region.GetLeader().GetStoreId() - re.Equal(storeID, cluster.GetRegionByKey([]byte("b")).GetLeader().GetStoreId()) - re.Equal(storeID, cluster.GetRegionByKey([]byte("a")).GetLeader().GetStoreId()) - re.Equal(uint64(1), cluster.GetRegionByKey([]byte("b")).GetID()) - re.Equal(uint64(2), cluster.GetRegionByKey([]byte("a")).GetID()) - - testdata := []struct { - regionID uint64 - startKey []byte - endKey []byte - }{ - { - regionID: 1, - startKey: []byte("b"), - endKey: []byte("c"), - }, { - regionID: 2, - startKey: []byte("a"), - endKey: []byte("b"), - }, - } - - for _, data := range testdata { - r := metapb.Region{} - ok, err := cluster.storage.LoadRegion(data.regionID, &r) - re.NoError(err) - re.True(ok) - re.Equal(data.startKey, r.GetStartKey()) - re.Equal(data.endKey, r.GetEndKey()) - } + _, err = cluster.HandleReportSplit(&pdpb.ReportSplitRequest{Left: right, Right: left}) + re.Error(err) } func TestReportBatchSplit(t *testing.T) { @@ -108,39 +50,12 @@ func TestReportBatchSplit(t *testing.T) { _, opt, err := newTestScheduleConfig() re.NoError(err) cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend(), core.NewBasicCluster()) - cluster.coordinator = schedule.NewCoordinator(ctx, cluster, nil) - store := newTestStores(1, "2.0.0") - cluster.core.PutStore(store[0]) - re.False(cluster.GetStore(1).HasRecentlySplitRegions()) regions := []*metapb.Region{ - {Id: 1, StartKey: []byte(""), EndKey: []byte("a"), Peers: mockRegionPeer(cluster, []uint64{1, 2, 3})}, - {Id: 2, StartKey: []byte("a"), EndKey: []byte("b"), Peers: mockRegionPeer(cluster, []uint64{1, 2, 3})}, - {Id: 3, StartKey: []byte("b"), EndKey: []byte("c"), Peers: mockRegionPeer(cluster, []uint64{1, 2, 3})}, - {Id: 4, StartKey: []byte("c"), EndKey: []byte(""), Peers: mockRegionPeer(cluster, []uint64{1, 2, 3})}, - } - _, err = cluster.HandleBatchReportSplit(&pdpb.ReportBatchSplitRequest{Regions: regions}) - re.Error(err) - - meta := &metapb.Region{Id: 1, StartKey: []byte(""), EndKey: []byte(""), Peers: mockRegionPeer(cluster, []uint64{1, 2, 3}), - RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 1}} - region := core.NewRegionInfo(meta, meta.Peers[0]) - cluster.putRegion(region) - - regions = []*metapb.Region{ - {Id: 2, StartKey: []byte(""), EndKey: []byte("a"), Peers: mockRegionPeer(cluster, []uint64{1, 2, 3}), RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 2}}, - {Id: 3, StartKey: []byte("a"), EndKey: []byte("b"), Peers: mockRegionPeer(cluster, []uint64{1, 2, 3}), RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 2}}, - {Id: 4, StartKey: []byte("b"), EndKey: []byte("c"), Peers: mockRegionPeer(cluster, []uint64{1, 2, 3}), RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 2}}, - {Id: 5, StartKey: []byte("c"), EndKey: []byte("d"), Peers: mockRegionPeer(cluster, []uint64{1, 2, 3}), RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 2}}, - {Id: 6, StartKey: []byte("d"), EndKey: []byte("e"), Peers: mockRegionPeer(cluster, []uint64{1, 2, 3}), RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 2}}, - {Id: 7, 
StartKey: []byte("e"), EndKey: []byte("f"), Peers: mockRegionPeer(cluster, []uint64{1, 2, 3}), RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 2}}, - {Id: 8, StartKey: []byte("f"), EndKey: []byte("g"), Peers: mockRegionPeer(cluster, []uint64{1, 2, 3}), RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 2}}, - {Id: 9, StartKey: []byte("g"), EndKey: []byte("h"), Peers: mockRegionPeer(cluster, []uint64{1, 2, 3}), RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 2}}, - {Id: 10, StartKey: []byte("h"), EndKey: []byte("i"), Peers: mockRegionPeer(cluster, []uint64{1, 2, 3}), RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 2}}, - - {Id: 1, StartKey: []byte("i"), EndKey: []byte(""), Peers: mockRegionPeer(cluster, []uint64{1, 2, 3}), RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 2}}, + {Id: 1, StartKey: []byte(""), EndKey: []byte("a")}, + {Id: 2, StartKey: []byte("a"), EndKey: []byte("b")}, + {Id: 3, StartKey: []byte("b"), EndKey: []byte("c")}, + {Id: 3, StartKey: []byte("c"), EndKey: []byte("")}, } _, err = cluster.HandleBatchReportSplit(&pdpb.ReportBatchSplitRequest{Regions: regions}) re.NoError(err) - - re.True(cluster.GetStore(1).HasRecentlySplitRegions()) } diff --git a/server/grpc_service.go b/server/grpc_service.go index d218c2bb0b60..5e40bc1c732c 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -1428,24 +1428,10 @@ func (s *GrpcServer) GetRegion(ctx context.Context, request *pdpb.GetRegionReque if rc == nil { return &pdpb.GetRegionResponse{Header: s.notBootstrappedHeader()}, nil } - var region *core.RegionInfo - // allow region miss temporarily if this key can't be found in the region tree. -retryLoop: - for retry := 0; retry <= 10; retry++ { - region = rc.GetRegionByKey(request.GetRegionKey()) - if region != nil { - break retryLoop - } - select { - case <-ctx.Done(): - break retryLoop - case <-time.After(10 * time.Millisecond): - } - } + region := rc.GetRegionByKey(request.GetRegionKey()) if region == nil { return &pdpb.GetRegionResponse{Header: s.header()}, nil } - var buckets *metapb.Buckets if rc.GetStoreConfig().IsEnableRegionBucket() && request.GetNeedBuckets() { buckets = region.GetBuckets() @@ -1487,21 +1473,7 @@ func (s *GrpcServer) GetPrevRegion(ctx context.Context, request *pdpb.GetRegionR return &pdpb.GetRegionResponse{Header: s.notBootstrappedHeader()}, nil } - var region *core.RegionInfo - // allow region miss temporarily if this key can't be found in the region tree. 
-retryLoop: - for retry := 0; retry <= 10; retry++ { - region = rc.GetPrevRegionByKey(request.GetRegionKey()) - if region != nil { - break retryLoop - } - select { - case <-ctx.Done(): - break retryLoop - case <-time.After(10 * time.Millisecond): - } - } - + region := rc.GetPrevRegionByKey(request.GetRegionKey()) if region == nil { return &pdpb.GetRegionResponse{Header: s.header()}, nil } From 301b917365bb071652b297cf5cbc8e2c0e0e6ddd Mon Sep 17 00:00:00 2001 From: ShuNing Date: Fri, 22 Sep 2023 12:16:44 +0800 Subject: [PATCH 6/8] resourcemanager: change the ru label name (#7135) close tikv/pd#4399 resourcemanager: change the ru label name Signed-off-by: nolouch --- pkg/mcs/resourcemanager/server/manager.go | 2 +- pkg/mcs/resourcemanager/server/metrics.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/mcs/resourcemanager/server/manager.go b/pkg/mcs/resourcemanager/server/manager.go index df237bd0feb0..1731faf8af1c 100644 --- a/pkg/mcs/resourcemanager/server/manager.go +++ b/pkg/mcs/resourcemanager/server/manager.go @@ -363,7 +363,7 @@ func (m *Manager) backgroundMetricsFlush(ctx context.Context) { if consumption == nil { continue } - ruLabelType := tidbTypeLabel + ruLabelType := defaultTypeLabel if consumptionInfo.isBackground { ruLabelType = backgroundTypeLabel } diff --git a/pkg/mcs/resourcemanager/server/metrics.go b/pkg/mcs/resourcemanager/server/metrics.go index 184eddc8ef95..25d0516d2690 100644 --- a/pkg/mcs/resourcemanager/server/metrics.go +++ b/pkg/mcs/resourcemanager/server/metrics.go @@ -26,8 +26,8 @@ const ( readTypeLabel = "read" writeTypeLabel = "write" backgroundTypeLabel = "background" - tiflashTypeLabel = "tiflash" - tidbTypeLabel = "tidb" + tiflashTypeLabel = "ap" + defaultTypeLabel = "tp" ) var ( From f1107b2dc96e1572df9a6dc1582193a8eda9a25b Mon Sep 17 00:00:00 2001 From: Hu# Date: Fri, 22 Sep 2023 18:03:14 +0800 Subject: [PATCH 7/8] resource_control: watch delete with prev and refine test (#7092) close tikv/pd#7095 Signed-off-by: husharp Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- .../resource_group/controller/controller.go | 25 ++- client/resource_manager_client.go | 47 ----- server/grpc_service.go | 2 +- .../resourcemanager/resource_manager_test.go | 166 +++++++++--------- 4 files changed, 105 insertions(+), 135 deletions(-) diff --git a/client/resource_group/controller/controller.go b/client/resource_group/controller/controller.go index 528369df229a..e3495a21ff1d 100755 --- a/client/resource_group/controller/controller.go +++ b/client/resource_group/controller/controller.go @@ -234,7 +234,8 @@ func (c *ResourceGroupsController) Start(ctx context.Context) { cfgRevision := resp.GetHeader().GetRevision() var watchMetaChannel, watchConfigChannel chan []*meta_storagepb.Event if !c.ruConfig.isSingleGroupByKeyspace { - watchMetaChannel, err = c.provider.Watch(ctx, pd.GroupSettingsPathPrefixBytes, pd.WithRev(metaRevision), pd.WithPrefix()) + // Use WithPrevKV() to get the previous key-value pair when get Delete Event. 
+ watchMetaChannel, err = c.provider.Watch(ctx, pd.GroupSettingsPathPrefixBytes, pd.WithRev(metaRevision), pd.WithPrefix(), pd.WithPrevKV()) if err != nil { log.Warn("watch resource group meta failed", zap.Error(err)) } @@ -260,7 +261,8 @@ func (c *ResourceGroupsController) Start(ctx context.Context) { } case <-watchRetryTimer.C: if !c.ruConfig.isSingleGroupByKeyspace && watchMetaChannel == nil { - watchMetaChannel, err = c.provider.Watch(ctx, pd.GroupSettingsPathPrefixBytes, pd.WithRev(metaRevision), pd.WithPrefix()) + // Use WithPrevKV() to get the previous key-value pair when get Delete Event. + watchMetaChannel, err = c.provider.Watch(ctx, pd.GroupSettingsPathPrefixBytes, pd.WithRev(metaRevision), pd.WithPrefix(), pd.WithPrevKV()) if err != nil { log.Warn("watch resource group meta failed", zap.Error(err)) watchRetryTimer.Reset(watchRetryInterval) @@ -319,18 +321,27 @@ func (c *ResourceGroupsController) Start(ctx context.Context) { for _, item := range resp { metaRevision = item.Kv.ModRevision group := &rmpb.ResourceGroup{} - if err := proto.Unmarshal(item.Kv.Value, group); err != nil { - continue - } switch item.Type { case meta_storagepb.Event_PUT: + if err = proto.Unmarshal(item.Kv.Value, group); err != nil { + continue + } if item, ok := c.groupsController.Load(group.Name); ok { gc := item.(*groupCostController) gc.modifyMeta(group) } case meta_storagepb.Event_DELETE: - if _, ok := c.groupsController.LoadAndDelete(group.Name); ok { - resourceGroupStatusGauge.DeleteLabelValues(group.Name) + if item.PrevKv != nil { + if err = proto.Unmarshal(item.PrevKv.Value, group); err != nil { + continue + } + if _, ok := c.groupsController.LoadAndDelete(group.Name); ok { + resourceGroupStatusGauge.DeleteLabelValues(group.Name) + } + } else { + // Prev-kv is compacted means there must have been a delete event before this event, + // which means that this is just a duplicated event, so we can just ignore it. + log.Info("previous key-value pair has been compacted", zap.String("required-key", string(item.Kv.Key)), zap.String("value", string(item.Kv.Value))) } } } diff --git a/client/resource_manager_client.go b/client/resource_manager_client.go index 68b2de66ae23..309443085848 100644 --- a/client/resource_manager_client.go +++ b/client/resource_manager_client.go @@ -55,7 +55,6 @@ type ResourceManagerClient interface { ModifyResourceGroup(ctx context.Context, metaGroup *rmpb.ResourceGroup) (string, error) DeleteResourceGroup(ctx context.Context, resourceGroupName string) (string, error) LoadResourceGroups(ctx context.Context) ([]*rmpb.ResourceGroup, int64, error) - WatchResourceGroup(ctx context.Context, revision int64) (chan []*rmpb.ResourceGroup, error) AcquireTokenBuckets(ctx context.Context, request *rmpb.TokenBucketsRequest) ([]*rmpb.TokenBucketResponse, error) Watch(ctx context.Context, key []byte, opts ...OpOption) (chan []*meta_storagepb.Event, error) } @@ -188,52 +187,6 @@ func (c *client) LoadResourceGroups(ctx context.Context) ([]*rmpb.ResourceGroup, return groups, resp.Header.Revision, nil } -// WatchResourceGroup [just for TEST] watches resource groups changes. -// It returns a stream of slices of resource groups. -// The first message in stream contains all current resource groups, -// all subsequent messages contains new events[PUT/DELETE] for all resource groups. 
-func (c *client) WatchResourceGroup(ctx context.Context, revision int64) (chan []*rmpb.ResourceGroup, error) { - configChan, err := c.Watch(ctx, GroupSettingsPathPrefixBytes, WithRev(revision), WithPrefix()) - if err != nil { - return nil, err - } - resourceGroupWatcherChan := make(chan []*rmpb.ResourceGroup) - go func() { - defer func() { - close(resourceGroupWatcherChan) - if r := recover(); r != nil { - log.Error("[pd] panic in ResourceManagerClient `WatchResourceGroups`", zap.Any("error", r)) - return - } - }() - for { - select { - case <-ctx.Done(): - return - case res, ok := <-configChan: - if !ok { - return - } - groups := make([]*rmpb.ResourceGroup, 0, len(res)) - for _, item := range res { - switch item.Type { - case meta_storagepb.Event_PUT: - group := &rmpb.ResourceGroup{} - if err := proto.Unmarshal(item.Kv.Value, group); err != nil { - return - } - groups = append(groups, group) - case meta_storagepb.Event_DELETE: - continue - } - } - resourceGroupWatcherChan <- groups - } - } - }() - return resourceGroupWatcherChan, err -} - func (c *client) AcquireTokenBuckets(ctx context.Context, request *rmpb.TokenBucketsRequest) ([]*rmpb.TokenBucketResponse, error) { req := &tokenRequest{ done: make(chan error, 1), diff --git a/server/grpc_service.go b/server/grpc_service.go index 5e40bc1c732c..dd53416d30d3 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -2638,7 +2638,7 @@ func (s *GrpcServer) WatchGlobalConfig(req *pdpb.WatchGlobalConfigRequest, serve } else { // Prev-kv is compacted means there must have been a delete event before this event, // which means that this is just a duplicated event, so we can just ignore it. - log.Info("previous key-value pair has been compacted", zap.String("previous key", string(e.Kv.Key))) + log.Info("previous key-value pair has been compacted", zap.String("required-key", string(e.Kv.Key))) } } } diff --git a/tests/integrations/mcs/resourcemanager/resource_manager_test.go b/tests/integrations/mcs/resourcemanager/resource_manager_test.go index 0be18d1bbd38..6da7bd3aac1f 100644 --- a/tests/integrations/mcs/resourcemanager/resource_manager_test.go +++ b/tests/integrations/mcs/resourcemanager/resource_manager_test.go @@ -205,51 +205,25 @@ func (suite *resourceManagerClientTestSuite) TestWatchResourceGroup() { }, }, } - // Mock get revision by listing - for i := 0; i < 3; i++ { - group.Name += strconv.Itoa(i) - resp, err := cli.AddResourceGroup(suite.ctx, group) - group.Name = "test" - re.NoError(err) - re.Contains(resp, "Success!") - } - lresp, revision, err := cli.LoadResourceGroups(suite.ctx) - re.NoError(err) - re.Equal(len(lresp), 4) - re.Greater(revision, int64(0)) - tcs := tokenConsumptionPerSecond{rruTokensAtATime: 100} - re.NoError(failpoint.Enable("github.com/tikv/pd/client/resource_group/controller/disableWatch", "return(true)")) - defer func() { - re.NoError(failpoint.Disable("github.com/tikv/pd/client/resource_group/controller/disableWatch")) - }() - controllerKeySpace, _ := controller.NewResourceGroupController(suite.ctx, 1, cli, nil, controller.EnableSingleGroupByKeyspace()) controller, _ := controller.NewResourceGroupController(suite.ctx, 1, cli, nil) controller.Start(suite.ctx) defer controller.Stop() - controller.OnRequestWait(suite.ctx, "test0", tcs.makeReadRequest()) - meta := controller.GetActiveResourceGroup("test0") - metaShadow, err := controller.GetResourceGroup("test0") - re.NoError(err) - re.Equal(meta.RUSettings.RU, group.RUSettings.RU) - re.Equal(metaShadow.RUSettings.RU, group.RUSettings.RU) - - 
controllerKeySpace.OnRequestWait(suite.ctx, "test0", tcs.makeReadRequest()) - metaKeySpace := controllerKeySpace.GetActiveResourceGroup("test0") - re.Equal(metaKeySpace.RUSettings.RU, group.RUSettings.RU) - controller.OnRequestWait(suite.ctx, "test1", tcs.makeReadRequest()) - meta = controller.GetActiveResourceGroup("test1") - metaShadow, err = controller.GetResourceGroup("test1") - re.NoError(err) - re.Equal(meta.RUSettings.RU, group.RUSettings.RU) - re.Equal(metaShadow.RUSettings.RU, group.RUSettings.RU) - suite.NoError(err) // Mock add resource groups - for i := 3; i < 9; i++ { + var meta *rmpb.ResourceGroup + groupsNum := 10 + for i := 0; i < groupsNum; i++ { group.Name = "test" + strconv.Itoa(i) resp, err := cli.AddResourceGroup(suite.ctx, group) re.NoError(err) re.Contains(resp, "Success!") + + // Make sure the resource group active + meta, err = controller.GetResourceGroup(group.Name) + re.NotNil(meta) + re.NoError(err) + meta = controller.GetActiveResourceGroup(group.Name) + re.NotNil(meta) } // Mock modify resource groups modifySettings := func(gs *rmpb.ResourceGroup) { @@ -261,65 +235,97 @@ func (suite *resourceManagerClientTestSuite) TestWatchResourceGroup() { }, } } - re.NoError(failpoint.Enable("github.com/tikv/pd/client/resource_group/controller/watchStreamError", "return(true)")) - for i := 0; i < 2; i++ { - if i == 1 { - testutil.Eventually(re, func() bool { - meta = controller.GetActiveResourceGroup("test0") - return meta.RUSettings.RU.Settings.FillRate == uint64(20000) - }, testutil.WithTickInterval(50*time.Millisecond)) - metaKeySpace = controllerKeySpace.GetActiveResourceGroup("test0") - re.Equal(metaKeySpace.RUSettings.RU.Settings.FillRate, uint64(10000)) - re.NoError(failpoint.Enable("github.com/tikv/pd/client/watchStreamError", "return(true)")) - } + for i := 0; i < groupsNum; i++ { group.Name = "test" + strconv.Itoa(i) modifySettings(group) resp, err := cli.ModifyResourceGroup(suite.ctx, group) re.NoError(err) re.Contains(resp, "Success!") } - time.Sleep(time.Millisecond * 50) - meta = controller.GetActiveResourceGroup("test1") - re.Equal(meta.RUSettings.RU.Settings.FillRate, uint64(10000)) - re.NoError(failpoint.Disable("github.com/tikv/pd/client/watchStreamError")) + for i := 0; i < groupsNum; i++ { + testutil.Eventually(re, func() bool { + name := "test" + strconv.Itoa(i) + meta = controller.GetActiveResourceGroup(name) + if meta != nil { + return meta.RUSettings.RU.Settings.FillRate == uint64(20000) + } + return false + }, testutil.WithTickInterval(50*time.Millisecond)) + } + + // Mock reset watch stream + re.NoError(failpoint.Enable("github.com/tikv/pd/client/resource_group/controller/watchStreamError", "return(true)")) + group.Name = "test" + strconv.Itoa(groupsNum) + resp, err := cli.AddResourceGroup(suite.ctx, group) + re.NoError(err) + re.Contains(resp, "Success!") + // Make sure the resource group active + meta, err = controller.GetResourceGroup(group.Name) + re.NotNil(meta) + re.NoError(err) + modifySettings(group) + resp, err = cli.ModifyResourceGroup(suite.ctx, group) + re.NoError(err) + re.Contains(resp, "Success!") testutil.Eventually(re, func() bool { - meta = controller.GetActiveResourceGroup("test1") + meta = controller.GetActiveResourceGroup(group.Name) return meta.RUSettings.RU.Settings.FillRate == uint64(20000) }, testutil.WithTickInterval(100*time.Millisecond)) re.NoError(failpoint.Disable("github.com/tikv/pd/client/resource_group/controller/watchStreamError")) - for i := 2; i < 9; i++ { - group.Name = "test" + strconv.Itoa(i) - 
modifySettings(group) - resp, err := cli.ModifyResourceGroup(suite.ctx, group) - re.NoError(err) - re.Contains(resp, "Success!") - } // Mock delete resource groups suite.cleanupResourceGroups() - time.Sleep(time.Second) - meta = controller.GetActiveResourceGroup(group.Name) - re.Nil(meta) + for i := 0; i < groupsNum; i++ { + testutil.Eventually(re, func() bool { + name := "test" + strconv.Itoa(i) + meta = controller.GetActiveResourceGroup(name) + return meta == nil + }, testutil.WithTickInterval(50*time.Millisecond)) + } +} - // Check watch result - watchChan, err := suite.client.WatchResourceGroup(suite.ctx, revision) - re.NoError(err) - i := 0 - for { - select { - case <-time.After(time.Second): - return - case res := <-watchChan: - for _, r := range res { - if i < 6 { - suite.Equal(uint64(10000), r.RUSettings.RU.Settings.FillRate) - } else { - suite.Equal(uint64(20000), r.RUSettings.RU.Settings.FillRate) - } - i++ - } - } +func (suite *resourceManagerClientTestSuite) TestWatchWithSingleGroupByKeyspace() { + re := suite.Require() + cli := suite.client + + // We need to disable watch stream for `isSingleGroupByKeyspace`. + re.NoError(failpoint.Enable("github.com/tikv/pd/client/resource_group/controller/disableWatch", "return(true)")) + defer func() { + re.NoError(failpoint.Disable("github.com/tikv/pd/client/resource_group/controller/disableWatch")) + }() + // Distinguish the controller with and without enabling `isSingleGroupByKeyspace`. + controllerKeySpace, _ := controller.NewResourceGroupController(suite.ctx, 1, cli, nil, controller.EnableSingleGroupByKeyspace()) + controller, _ := controller.NewResourceGroupController(suite.ctx, 2, cli, nil) + controller.Start(suite.ctx) + controllerKeySpace.Start(suite.ctx) + defer controllerKeySpace.Stop() + defer controller.Stop() + + // Mock add resource group. 
+ group := &rmpb.ResourceGroup{ + Name: "test", + Mode: rmpb.GroupMode_RUMode, + RUSettings: &rmpb.GroupRequestUnitSettings{ + RU: &rmpb.TokenBucket{ + Settings: &rmpb.TokenLimitSettings{ + FillRate: 10000, + }, + Tokens: 100000, + }, + }, } + resp, err := cli.AddResourceGroup(suite.ctx, group) + re.NoError(err) + re.Contains(resp, "Success!") + + tcs := tokenConsumptionPerSecond{rruTokensAtATime: 100} + controller.OnRequestWait(suite.ctx, group.Name, tcs.makeReadRequest()) + meta := controller.GetActiveResourceGroup(group.Name) + re.Equal(meta.RUSettings.RU, group.RUSettings.RU) + + controllerKeySpace.OnRequestWait(suite.ctx, group.Name, tcs.makeReadRequest()) + metaKeySpace := controllerKeySpace.GetActiveResourceGroup(group.Name) + re.Equal(metaKeySpace.RUSettings.RU, group.RUSettings.RU) } const buffDuration = time.Millisecond * 300 From a21fd58d9953eeb8a4ff2657c04de2f44ee29e0e Mon Sep 17 00:00:00 2001 From: Hu# Date: Fri, 22 Sep 2023 18:28:43 +0800 Subject: [PATCH 8/8] security: disable plugin in default and persist file in specified dir (#7087) close tikv/pd#7094 Signed-off-by: husharp Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- Makefile | 6 ++++- pkg/replication/replication_mode.go | 5 ++-- server/api/admin.go | 5 +++- server/api/admin_test.go | 5 ++-- server/api/plugin.go | 3 +++ server/api/plugin_disable.go | 41 +++++++++++++++++++++++++++++ server/api/server_test.go | 23 ++++++++++++++++ server/handler.go | 8 ++++++ server/server.go | 10 ++++++- server/server_test.go | 13 +++++++++ server/util.go | 15 +++++++++++ 11 files changed, 127 insertions(+), 7 deletions(-) create mode 100644 server/api/plugin_disable.go diff --git a/Makefile b/Makefile index 77ad51a1a870..54ad331aea49 100644 --- a/Makefile +++ b/Makefile @@ -42,6 +42,10 @@ ifeq ("$(WITH_RACE)", "1") BUILD_CGO_ENABLED := 1 endif +ifeq ($(PLUGIN), 1) + BUILD_TAGS += with_plugin +endif + LDFLAGS += -X "$(PD_PKG)/pkg/versioninfo.PDReleaseVersion=$(shell git describe --tags --dirty --always)" LDFLAGS += -X "$(PD_PKG)/pkg/versioninfo.PDBuildTS=$(shell date -u '+%Y-%m-%d %I:%M:%S')" LDFLAGS += -X "$(PD_PKG)/pkg/versioninfo.PDGitHash=$(shell git rev-parse HEAD)" @@ -286,4 +290,4 @@ clean-build: rm -rf $(BUILD_BIN_PATH) rm -rf $(GO_TOOLS_BIN_PATH) -.PHONY: clean clean-test clean-build \ No newline at end of file +.PHONY: clean clean-test clean-build diff --git a/pkg/replication/replication_mode.go b/pkg/replication/replication_mode.go index 5a52f562e600..609d7f646c8a 100644 --- a/pkg/replication/replication_mode.go +++ b/pkg/replication/replication_mode.go @@ -60,7 +60,8 @@ type FileReplicater interface { ReplicateFileToMember(ctx context.Context, member *pdpb.Member, name string, data []byte) error } -const drStatusFile = "DR_STATE" +// DrStatusFile is the file name that stores the dr status. 
+const DrStatusFile = "DR_STATE" const persistFileTimeout = time.Second * 3 // ModeManager is used to control how raft logs are synchronized between @@ -489,7 +490,7 @@ func (m *ModeManager) tickReplicateStatus() { stateID, ok := m.replicateState.Load(member.GetMemberId()) if !ok || stateID.(uint64) != state.StateID { ctx, cancel := context.WithTimeout(context.Background(), persistFileTimeout) - err := m.fileReplicater.ReplicateFileToMember(ctx, member, drStatusFile, data) + err := m.fileReplicater.ReplicateFileToMember(ctx, member, DrStatusFile, data) if err != nil { log.Warn("failed to switch state", zap.String("replicate-mode", modeDRAutoSync), zap.String("new-state", state.State), errs.ZapError(err)) } else { diff --git a/server/api/admin.go b/server/api/admin.go index c81193f1468d..7a1dfb0f1e82 100644 --- a/server/api/admin.go +++ b/server/api/admin.go @@ -111,7 +111,10 @@ func (h *adminHandler) DeleteAllRegionCache(w http.ResponseWriter, r *http.Reque } // Intentionally no swagger mark as it is supposed to be only used in -// server-to-server. For security reason, it only accepts JSON formatted data. +// server-to-server. +// For security reason, +// - it only accepts JSON formatted data. +// - it only accepts file name which is `DrStatusFile`. func (h *adminHandler) SavePersistFile(w http.ResponseWriter, r *http.Request) { data, err := io.ReadAll(r.Body) if err != nil { diff --git a/server/api/admin_test.go b/server/api/admin_test.go index 6a972171e1fe..09130fd83851 100644 --- a/server/api/admin_test.go +++ b/server/api/admin_test.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/kvproto/pkg/pdpb" "github.com/stretchr/testify/suite" "github.com/tikv/pd/pkg/core" + "github.com/tikv/pd/pkg/replication" tu "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/server" ) @@ -167,10 +168,10 @@ func (suite *adminTestSuite) TestDropRegions() { func (suite *adminTestSuite) TestPersistFile() { data := []byte("#!/bin/sh\nrm -rf /") re := suite.Require() - err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/admin/persist-file/fun.sh", data, tu.StatusNotOK(re)) + err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/admin/persist-file/"+replication.DrStatusFile, data, tu.StatusNotOK(re)) suite.NoError(err) data = []byte(`{"foo":"bar"}`) - err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/admin/persist-file/good.json", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/admin/persist-file/"+replication.DrStatusFile, data, tu.StatusOK(re)) suite.NoError(err) } diff --git a/server/api/plugin.go b/server/api/plugin.go index fd75cc6bb2b8..922fde531f8c 100644 --- a/server/api/plugin.go +++ b/server/api/plugin.go @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build with_plugin +// +build with_plugin + package api import ( diff --git a/server/api/plugin_disable.go b/server/api/plugin_disable.go new file mode 100644 index 000000000000..2676dbb91e20 --- /dev/null +++ b/server/api/plugin_disable.go @@ -0,0 +1,41 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !with_plugin +// +build !with_plugin + +package api + +import ( + "net/http" + + "github.com/tikv/pd/server" + "github.com/unrolled/render" +) + +type pluginHandler struct{} + +func newPluginHandler(_ *server.Handler, _ *render.Render) *pluginHandler { + return &pluginHandler{} +} + +func (h *pluginHandler) LoadPlugin(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotImplemented) + w.Write([]byte("load plugin is disabled, please `PLUGIN=1 $(MAKE) pd-server` first")) +} + +func (h *pluginHandler) UnloadPlugin(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotImplemented) + w.Write([]byte("unload plugin is disabled, please `PLUGIN=1 $(MAKE) pd-server` first")) +} diff --git a/server/api/server_test.go b/server/api/server_test.go index 88253b3a6242..2e89ad797c34 100644 --- a/server/api/server_test.go +++ b/server/api/server_test.go @@ -16,7 +16,9 @@ package api import ( "context" + "fmt" "net/http" + "net/http/httptest" "sort" "sync" "testing" @@ -210,3 +212,24 @@ func (suite *serviceTestSuite) TestServiceLabels() { apiutil.NewAccessPath("/pd/api/v1/metric/query", http.MethodGet)) suite.Equal("QueryMetric", serviceLabel) } + +func (suite *adminTestSuite) TestCleanPath() { + re := suite.Require() + // transfer path to /config + url := fmt.Sprintf("%s/admin/persist-file/../../config", suite.urlPrefix) + cfg := &config.Config{} + err := testutil.ReadGetJSON(re, testDialClient, url, cfg) + suite.NoError(err) + + // handled by router + response := httptest.NewRecorder() + r, _, _ := NewHandler(context.Background(), suite.svr) + request, err := http.NewRequest(http.MethodGet, url, nil) + re.NoError(err) + r.ServeHTTP(response, request) + // handled by `cleanPath` which is in `mux.ServeHTTP` + result := response.Result() + defer result.Body.Close() + re.NotNil(result.Header["Location"]) + re.Contains(result.Header["Location"][0], "/pd/api/v1/config") +} diff --git a/server/handler.go b/server/handler.go index ecc337b71932..ace7592cd7c5 100644 --- a/server/handler.go +++ b/server/handler.go @@ -19,6 +19,7 @@ import ( "net/http" "net/url" "path" + "path/filepath" "strconv" "time" @@ -530,6 +531,13 @@ func (h *Handler) PluginLoad(pluginPath string) error { c := cluster.GetCoordinator() ch := make(chan string) h.pluginChMap[pluginPath] = ch + + // make sure path is in data dir + filePath, err := filepath.Abs(pluginPath) + if err != nil || !isPathInDirectory(filePath, h.s.GetConfig().DataDir) { + return errs.ErrFilePathAbs.Wrap(err).FastGenWithCause() + } + c.LoadPlugin(pluginPath, ch) return nil } diff --git a/server/server.go b/server/server.go index 2fb66387d7a6..ca131debb29a 100644 --- a/server/server.go +++ b/server/server.go @@ -57,6 +57,7 @@ import ( mcs "github.com/tikv/pd/pkg/mcs/utils" "github.com/tikv/pd/pkg/member" "github.com/tikv/pd/pkg/ratelimit" + "github.com/tikv/pd/pkg/replication" sc "github.com/tikv/pd/pkg/schedule/config" "github.com/tikv/pd/pkg/schedule/hbstream" "github.com/tikv/pd/pkg/schedule/placement" @@ -1872,8 +1873,15 @@ func (s *Server) ReplicateFileToMember(ctx context.Context, member *pdpb.Member, 
// PersistFile saves a file in DataDir. func (s *Server) PersistFile(name string, data []byte) error { + if name != replication.DrStatusFile { + return errors.New("Invalid file name") + } log.Info("persist file", zap.String("name", name), zap.Binary("data", data)) - return os.WriteFile(filepath.Join(s.GetConfig().DataDir, name), data, 0644) // #nosec + path := filepath.Join(s.GetConfig().DataDir, name) + if !isPathInDirectory(path, s.GetConfig().DataDir) { + return errors.New("Invalid file path") + } + return os.WriteFile(path, data, 0644) // #nosec } // SaveTTLConfig save ttl config diff --git a/server/server_test.go b/server/server_test.go index 2d0e23c7682c..62cf5b168fc7 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -19,6 +19,7 @@ import ( "fmt" "io" "net/http" + "path/filepath" "testing" "github.com/stretchr/testify/require" @@ -307,3 +308,15 @@ func TestAPIService(t *testing.T) { MustWaitLeader(re, []*Server{svr}) re.True(svr.IsAPIServiceMode()) } + +func TestIsPathInDirectory(t *testing.T) { + re := require.New(t) + fileName := "test" + directory := "/root/project" + path := filepath.Join(directory, fileName) + re.True(isPathInDirectory(path, directory)) + + fileName = "../../test" + path = filepath.Join(directory, fileName) + re.False(isPathInDirectory(path, directory)) +} diff --git a/server/util.go b/server/util.go index 9c7a97a98066..654b424465e3 100644 --- a/server/util.go +++ b/server/util.go @@ -17,6 +17,7 @@ package server import ( "context" "net/http" + "path/filepath" "strings" "github.com/gorilla/mux" @@ -124,3 +125,17 @@ func combineBuilderServerHTTPService(ctx context.Context, svr *Server, serviceBu userHandlers[pdAPIPrefix] = apiService return userHandlers, nil } + +func isPathInDirectory(path, directory string) bool { + absPath, err := filepath.Abs(path) + if err != nil { + return false + } + + absDir, err := filepath.Abs(directory) + if err != nil { + return false + } + + return strings.HasPrefix(absPath, absDir) +}
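
Editor's note on the containment check introduced above in server/util.go: the sketch below is not part of the patch. It is a minimal, self-contained illustration (helper name and example paths are assumptions, not PD code) of the same idea — resolve both the candidate path and the data directory to absolute, cleaned form, then require the former to start with the latter. Because filepath.Join and filepath.Abs clean ".." segments before the comparison, a traversal attempt like the "../../test" case in the new TestIsPathInDirectory resolves outside the directory and is rejected.

// path_containment_sketch.go
//
// Not part of the patch: a runnable sketch of the containment check that
// server/util.go adds as isPathInDirectory. Names and paths here are
// illustrative only.
package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

// containedIn reports whether path resolves to a location under directory.
// filepath.Abs cleans the path, so ".." segments are collapsed before the
// prefix comparison is made.
func containedIn(path, directory string) bool {
	absPath, err := filepath.Abs(path)
	if err != nil {
		return false
	}
	absDir, err := filepath.Abs(directory)
	if err != nil {
		return false
	}
	return strings.HasPrefix(absPath, absDir)
}

func main() {
	dir := "/data/pd"
	// Stays inside the data dir, so the check passes.
	fmt.Println(containedIn(filepath.Join(dir, "DR_STATE"), dir)) // true
	// Join cleans the ".." segments, so the resolved path falls outside
	// the data dir and the check fails.
	fmt.Println(containedIn(filepath.Join(dir, "../../etc/passwd"), dir)) // false
}

One design note, offered tentatively: a bare prefix comparison also accepts sibling paths that merely share the directory string as a prefix (for example, /data/pd-old under /data/pd). Appending a trailing path separator to the directory before comparing is a common way to tighten this, if that distinction ever matters for the data dir layout here.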