diff --git a/pkg/mcs/scheduling/server/rule/watcher.go b/pkg/mcs/scheduling/server/rule/watcher.go index 4cad6fdcbae..912fb9c01e5 100644 --- a/pkg/mcs/scheduling/server/rule/watcher.go +++ b/pkg/mcs/scheduling/server/rule/watcher.go @@ -20,6 +20,9 @@ import ( "sync" "github.com/pingcap/log" + "github.com/tikv/pd/pkg/schedule/checker" + "github.com/tikv/pd/pkg/schedule/labeler" + "github.com/tikv/pd/pkg/schedule/placement" "github.com/tikv/pd/pkg/storage/endpoint" "github.com/tikv/pd/pkg/utils/etcdutil" "go.etcd.io/etcd/clientv3" @@ -49,18 +52,27 @@ type Watcher struct { etcdClient *clientv3.Client ruleStorage endpoint.RuleStorage + // checkerController is used to add the suspect key ranges to the checker when the rule changed. + checkerController *checker.Controller + // ruleManager is used to manage the placement rules. + ruleManager *placement.RuleManager + // regionLabeler is used to manage the region label rules. + regionLabeler *labeler.RegionLabeler + ruleWatcher *etcdutil.LoopWatcher groupWatcher *etcdutil.LoopWatcher labelWatcher *etcdutil.LoopWatcher } // NewWatcher creates a new watcher to watch the Placement Rule change from PD API server. -// Please use `GetRuleStorage` to get the underlying storage to access the Placement Rules. 
func NewWatcher( ctx context.Context, etcdClient *clientv3.Client, clusterID uint64, ruleStorage endpoint.RuleStorage, + checkerController *checker.Controller, + ruleManager *placement.RuleManager, + regionLabeler *labeler.RegionLabeler, ) (*Watcher, error) { ctx, cancel := context.WithCancel(ctx) rw := &Watcher{ @@ -71,6 +83,9 @@ func NewWatcher( regionLabelPathPrefix: endpoint.RegionLabelPathPrefix(clusterID), etcdClient: etcdClient, ruleStorage: ruleStorage, + checkerController: checkerController, + ruleManager: ruleManager, + regionLabeler: regionLabeler, } err := rw.initializeRuleWatcher() if err != nil { @@ -90,17 +105,31 @@ func NewWatcher( func (rw *Watcher) initializeRuleWatcher() error { prefixToTrim := rw.rulesPathPrefix + "/" putFn := func(kv *mvccpb.KeyValue) error { - // Since the PD API server will validate the rule before saving it to etcd, - // so we could directly save the string rule in JSON to the storage here. log.Info("update placement rule", zap.String("key", string(kv.Key)), zap.String("value", string(kv.Value))) - return rw.ruleStorage.SaveRuleJSON( - strings.TrimPrefix(string(kv.Key), prefixToTrim), - string(kv.Value), - ) + rule, err := placement.NewRuleFromJSON(kv.Value) + if err != nil { + return err + } + // Update the suspect key ranges in the checker. 
+ rw.checkerController.AddSuspectKeyRange(rule.StartKey, rule.EndKey) + if oldRule := rw.ruleManager.GetRule(rule.GroupID, rule.ID); oldRule != nil { + rw.checkerController.AddSuspectKeyRange(oldRule.StartKey, oldRule.EndKey) + } + return rw.ruleManager.SetRule(rule) } deleteFn := func(kv *mvccpb.KeyValue) error { - log.Info("delete placement rule", zap.String("key", string(kv.Key))) - return rw.ruleStorage.DeleteRule(strings.TrimPrefix(string(kv.Key), prefixToTrim)) + key := string(kv.Key) + log.Info("delete placement rule", zap.String("key", key)) + ruleJSON, err := rw.ruleStorage.LoadRule(strings.TrimPrefix(key, prefixToTrim)) + if err != nil { + return err + } + rule, err := placement.NewRuleFromJSON([]byte(ruleJSON)) + if err != nil { + return err + } + rw.checkerController.AddSuspectKeyRange(rule.StartKey, rule.EndKey) + return rw.ruleManager.DeleteRule(rule.GroupID, rule.ID) } postEventFn := func() error { return nil @@ -120,14 +149,24 @@ func (rw *Watcher) initializeGroupWatcher() error { prefixToTrim := rw.ruleGroupPathPrefix + "/" putFn := func(kv *mvccpb.KeyValue) error { log.Info("update placement rule group", zap.String("key", string(kv.Key)), zap.String("value", string(kv.Value))) - return rw.ruleStorage.SaveRuleGroupJSON( - strings.TrimPrefix(string(kv.Key), prefixToTrim), - string(kv.Value), - ) + ruleGroup, err := placement.NewRuleGroupFromJSON(kv.Value) + if err != nil { + return err + } + // Add all rule key ranges within the group to the suspect key ranges. 
+ for _, rule := range rw.ruleManager.GetRulesByGroup(ruleGroup.ID) { + rw.checkerController.AddSuspectKeyRange(rule.StartKey, rule.EndKey) + } + return rw.ruleManager.SetRuleGroup(ruleGroup) } deleteFn := func(kv *mvccpb.KeyValue) error { - log.Info("delete placement rule group", zap.String("key", string(kv.Key))) - return rw.ruleStorage.DeleteRuleGroup(strings.TrimPrefix(string(kv.Key), prefixToTrim)) + key := string(kv.Key) + log.Info("delete placement rule group", zap.String("key", key)) + trimmedKey := strings.TrimPrefix(key, prefixToTrim) + for _, rule := range rw.ruleManager.GetRulesByGroup(trimmedKey) { + rw.checkerController.AddSuspectKeyRange(rule.StartKey, rule.EndKey) + } + return rw.ruleManager.DeleteRuleGroup(trimmedKey) } postEventFn := func() error { return nil @@ -147,14 +186,16 @@ func (rw *Watcher) initializeRegionLabelWatcher() error { prefixToTrim := rw.regionLabelPathPrefix + "/" putFn := func(kv *mvccpb.KeyValue) error { log.Info("update region label rule", zap.String("key", string(kv.Key)), zap.String("value", string(kv.Value))) - return rw.ruleStorage.SaveRegionRuleJSON( - strings.TrimPrefix(string(kv.Key), prefixToTrim), - string(kv.Value), - ) + rule, err := labeler.NewLabelRuleFromJSON(kv.Value) + if err != nil { + return err + } + return rw.regionLabeler.SetLabelRule(rule) } deleteFn := func(kv *mvccpb.KeyValue) error { - log.Info("delete region label rule", zap.String("key", string(kv.Key))) - return rw.ruleStorage.DeleteRegionRule(strings.TrimPrefix(string(kv.Key), prefixToTrim)) + key := string(kv.Key) + log.Info("delete region label rule", zap.String("key", key)) + return rw.regionLabeler.DeleteLabelRule(strings.TrimPrefix(key, prefixToTrim)) } postEventFn := func() error { return nil diff --git a/pkg/mcs/scheduling/server/server.go b/pkg/mcs/scheduling/server/server.go index 02cb1ba3c70..9caae932037 100644 --- a/pkg/mcs/scheduling/server/server.go +++ b/pkg/mcs/scheduling/server/server.go @@ -455,7 +455,7 @@ func (s *Server) 
startServer() (err error) { func (s *Server) startCluster(context.Context) error { s.basicCluster = core.NewBasicCluster() s.storage = endpoint.NewStorageEndpoint(kv.NewMemoryKV(), nil) - err := s.startWatcher() + err := s.startMetaConfWatcher() if err != nil { return err } @@ -464,7 +464,13 @@ func (s *Server) startCluster(context.Context) error { if err != nil { return err } + // Inject the cluster components into the config watcher after the scheduler controller is created. s.configWatcher.SetSchedulersController(s.cluster.GetCoordinator().GetSchedulersController()) + // Start the rule watcher after the cluster is created. + err = s.startRuleWatcher() + if err != nil { + return err + } s.cluster.StartBackgroundJobs() return nil } @@ -474,7 +480,7 @@ func (s *Server) stopCluster() { s.stopWatcher() } -func (s *Server) startWatcher() (err error) { +func (s *Server) startMetaConfWatcher() (err error) { s.metaWatcher, err = meta.NewWatcher(s.Context(), s.GetClient(), s.clusterID, s.basicCluster) if err != nil { return err @@ -483,7 +489,12 @@ func (s *Server) startWatcher() (err error) { if err != nil { return err } - s.ruleWatcher, err = rule.NewWatcher(s.Context(), s.GetClient(), s.clusterID, s.storage) + return err +} + +func (s *Server) startRuleWatcher() (err error) { + s.ruleWatcher, err = rule.NewWatcher(s.Context(), s.GetClient(), s.clusterID, s.storage, + s.cluster.GetCoordinator().GetCheckerController(), s.cluster.GetRuleManager(), s.cluster.GetRegionLabeler()) return err } diff --git a/pkg/schedule/labeler/labeler.go b/pkg/schedule/labeler/labeler.go index c525ac5c44f..39722b1a038 100644 --- a/pkg/schedule/labeler/labeler.go +++ b/pkg/schedule/labeler/labeler.go @@ -254,7 +254,7 @@ func (l *RegionLabeler) Patch(patch LabelRulePatch) error { } } - // update inmemory states. + // update in-memory states. 
l.Lock() defer l.Unlock() diff --git a/pkg/schedule/placement/rule_manager.go b/pkg/schedule/placement/rule_manager.go index 909c0fa1078..bdca4cc1b19 100644 --- a/pkg/schedule/placement/rule_manager.go +++ b/pkg/schedule/placement/rule_manager.go @@ -135,8 +135,10 @@ func (m *RuleManager) Initialize(maxReplica int, locationLabels []string, isolat } func (m *RuleManager) loadRules() error { - var toSave []*Rule - var toDelete []string + var ( + toSave []*Rule + toDelete []string + ) err := m.storage.LoadRules(func(k, v string) { r, err := NewRuleFromJSON([]byte(v)) if err != nil { diff --git a/pkg/storage/endpoint/rule.go b/pkg/storage/endpoint/rule.go index 7e2813c23bd..125c5bc31eb 100644 --- a/pkg/storage/endpoint/rule.go +++ b/pkg/storage/endpoint/rule.go @@ -22,6 +22,7 @@ import ( // RuleStorage defines the storage operations on the rule. type RuleStorage interface { + LoadRule(ruleKey string) (string, error) LoadRules(f func(k, v string)) error SaveRule(ruleKey string, rule interface{}) error SaveRuleJSON(ruleKey, rule string) error @@ -93,6 +94,11 @@ func (se *StorageEndpoint) DeleteRegionRule(ruleKey string) error { return se.Remove(regionLabelKeyPath(ruleKey)) } +// LoadRule loads a placement rule from storage. +func (se *StorageEndpoint) LoadRule(ruleKey string) (string, error) { + return se.Load(ruleKeyPath(ruleKey)) +} + // LoadRules loads placement rules from storage. 
func (se *StorageEndpoint) LoadRules(f func(k, v string)) error { return se.loadRangeByPrefix(rulesPath+"/", f) diff --git a/tests/integrations/mcs/scheduling/rule_test.go b/tests/integrations/mcs/scheduling/rule_test.go index 104204dd625..bffa58d0fe6 100644 --- a/tests/integrations/mcs/scheduling/rule_test.go +++ b/tests/integrations/mcs/scheduling/rule_test.go @@ -19,15 +19,9 @@ import ( "sort" "testing" - "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - "github.com/tikv/pd/pkg/keyspace" - "github.com/tikv/pd/pkg/mcs/scheduling/server/rule" - "github.com/tikv/pd/pkg/mcs/utils" "github.com/tikv/pd/pkg/schedule/labeler" "github.com/tikv/pd/pkg/schedule/placement" - "github.com/tikv/pd/pkg/storage/endpoint" - "github.com/tikv/pd/pkg/storage/kv" "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/tests" ) @@ -41,7 +35,8 @@ type ruleTestSuite struct { // The PD cluster. cluster *tests.TestCluster // pdLeaderServer is the leader server of the PD cluster. - pdLeaderServer *tests.TestServer + pdLeaderServer *tests.TestServer + backendEndpoint string } func TestRule(t *testing.T) { @@ -59,6 +54,7 @@ func (suite *ruleTestSuite) SetupSuite() { re.NoError(err) leaderName := suite.cluster.WaitLeader() suite.pdLeaderServer = suite.cluster.GetServer(leaderName) + suite.backendEndpoint = suite.pdLeaderServer.GetAddr() re.NoError(suite.pdLeaderServer.BootstrapCluster()) } @@ -67,50 +63,18 @@ func (suite *ruleTestSuite) TearDownSuite() { suite.cluster.Destroy() } -func loadRules(re *require.Assertions, ruleStorage endpoint.RuleStorage) (rules []*placement.Rule) { - err := ruleStorage.LoadRules(func(_, v string) { - r, err := placement.NewRuleFromJSON([]byte(v)) - re.NoError(err) - rules = append(rules, r) - }) - re.NoError(err) - return -} - -func loadRuleGroups(re *require.Assertions, ruleStorage endpoint.RuleStorage) (groups []*placement.RuleGroup) { - err := ruleStorage.LoadRuleGroups(func(_, v string) { - rg, err := 
placement.NewRuleGroupFromJSON([]byte(v)) - re.NoError(err) - groups = append(groups, rg) - }) - re.NoError(err) - return -} - -func loadRegionRules(re *require.Assertions, ruleStorage endpoint.RuleStorage) (rules []*labeler.LabelRule) { - err := ruleStorage.LoadRegionRules(func(_, v string) { - lr, err := labeler.NewLabelRuleFromJSON([]byte(v)) - re.NoError(err) - rules = append(rules, lr) - }) - re.NoError(err) - return -} - func (suite *ruleTestSuite) TestRuleWatch() { re := suite.Require() - ruleStorage := endpoint.NewStorageEndpoint(kv.NewMemoryKV(), nil) - // Create a rule watcher. - _, err := rule.NewWatcher( - suite.ctx, - suite.pdLeaderServer.GetEtcdClient(), - suite.cluster.GetCluster().GetId(), - ruleStorage, - ) + tc, err := tests.NewTestSchedulingCluster(suite.ctx, 1, suite.backendEndpoint) re.NoError(err) - // Check the default rule. - rules := loadRules(re, ruleStorage) + defer tc.Destroy() + + tc.WaitForPrimaryServing(re) + cluster := tc.GetPrimaryServer().GetCluster() + ruleManager := cluster.GetRuleManager() + // Check the default rule and rule group. + rules := ruleManager.GetAllRules() re.Len(rules, 1) re.Equal("pd", rules[0].GroupID) re.Equal("default", rules[0].ID) @@ -119,12 +83,13 @@ func (suite *ruleTestSuite) TestRuleWatch() { re.Empty(rules[0].EndKey) re.Equal(placement.Voter, rules[0].Role) re.Empty(rules[0].LocationLabels) - // Check the empty rule group. - ruleGroups := loadRuleGroups(re, ruleStorage) - re.NoError(err) - re.Empty(ruleGroups) + ruleGroups := ruleManager.GetRuleGroups() + re.Len(ruleGroups, 1) + re.Equal("pd", ruleGroups[0].ID) + re.Equal(0, ruleGroups[0].Index) + re.False(ruleGroups[0].Override) // Set a new rule via the PD API server. 
- ruleManager := suite.pdLeaderServer.GetRaftCluster().GetRuleManager() + apiRuleManager := suite.pdLeaderServer.GetRaftCluster().GetRuleManager() rule := &placement.Rule{ GroupID: "2", ID: "3", @@ -133,10 +98,10 @@ func (suite *ruleTestSuite) TestRuleWatch() { StartKeyHex: "22", EndKeyHex: "dd", } - err = ruleManager.SetRule(rule) + err = apiRuleManager.SetRule(rule) re.NoError(err) testutil.Eventually(re, func() bool { - rules = loadRules(re, ruleStorage) + rules = ruleManager.GetAllRules() return len(rules) == 2 }) sort.Slice(rules, func(i, j int) bool { @@ -150,10 +115,10 @@ func (suite *ruleTestSuite) TestRuleWatch() { re.Equal(rule.StartKeyHex, rules[1].StartKeyHex) re.Equal(rule.EndKeyHex, rules[1].EndKeyHex) // Delete the rule. - err = ruleManager.DeleteRule(rule.GroupID, rule.ID) + err = apiRuleManager.DeleteRule(rule.GroupID, rule.ID) re.NoError(err) testutil.Eventually(re, func() bool { - rules = loadRules(re, ruleStorage) + rules = ruleManager.GetAllRules() return len(rules) == 1 }) re.Len(rules, 1) @@ -164,30 +129,35 @@ func (suite *ruleTestSuite) TestRuleWatch() { Index: 100, Override: true, } - err = ruleManager.SetRuleGroup(ruleGroup) + err = apiRuleManager.SetRuleGroup(ruleGroup) re.NoError(err) testutil.Eventually(re, func() bool { - ruleGroups = loadRuleGroups(re, ruleStorage) - return len(ruleGroups) == 1 + ruleGroups = ruleManager.GetRuleGroups() + return len(ruleGroups) == 2 }) - re.Len(ruleGroups, 1) - re.Equal(ruleGroup.ID, ruleGroups[0].ID) - re.Equal(ruleGroup.Index, ruleGroups[0].Index) - re.Equal(ruleGroup.Override, ruleGroups[0].Override) + re.Len(ruleGroups, 2) + re.Equal(ruleGroup.ID, ruleGroups[1].ID) + re.Equal(ruleGroup.Index, ruleGroups[1].Index) + re.Equal(ruleGroup.Override, ruleGroups[1].Override) // Delete the rule group. 
- err = ruleManager.DeleteRuleGroup(ruleGroup.ID) + err = apiRuleManager.DeleteRuleGroup(ruleGroup.ID) re.NoError(err) testutil.Eventually(re, func() bool { - ruleGroups = loadRuleGroups(re, ruleStorage) - return len(ruleGroups) == 0 + ruleGroups = ruleManager.GetRuleGroups() + return len(ruleGroups) == 1 }) - re.Empty(ruleGroups) + re.Len(ruleGroups, 1) // Test the region label rule watch. - labelRules := loadRegionRules(re, ruleStorage) - re.Len(labelRules, 1) - defaultKeyspaceRule := keyspace.MakeLabelRule(utils.DefaultKeyspaceID) - re.Equal(defaultKeyspaceRule, labelRules[0]) + regionLabeler := cluster.GetRegionLabeler() + labelRules := regionLabeler.GetAllLabelRules() + apiRegionLabeler := suite.pdLeaderServer.GetRaftCluster().GetRegionLabeler() + apiLabelRules := apiRegionLabeler.GetAllLabelRules() + re.Len(labelRules, len(apiLabelRules)) + re.Equal(apiLabelRules[0].ID, labelRules[0].ID) + re.Equal(apiLabelRules[0].Index, labelRules[0].Index) + re.Equal(apiLabelRules[0].Labels, labelRules[0].Labels) + re.Equal(apiLabelRules[0].RuleType, labelRules[0].RuleType) // Set a new region label rule. 
labelRule := &labeler.LabelRule{ ID: "rule1", @@ -195,11 +165,10 @@ func (suite *ruleTestSuite) TestRuleWatch() { RuleType: "key-range", Data: labeler.MakeKeyRanges("1234", "5678"), } - regionLabeler := suite.pdLeaderServer.GetRaftCluster().GetRegionLabeler() - err = regionLabeler.SetLabelRule(labelRule) + err = apiRegionLabeler.SetLabelRule(labelRule) re.NoError(err) testutil.Eventually(re, func() bool { - labelRules = loadRegionRules(re, ruleStorage) + labelRules = regionLabeler.GetAllLabelRules() return len(labelRules) == 2 }) sort.Slice(labelRules, func(i, j int) bool { @@ -220,17 +189,16 @@ func (suite *ruleTestSuite) TestRuleWatch() { SetRules: []*labeler.LabelRule{labelRule}, DeleteRules: []string{"rule1"}, } - err = regionLabeler.Patch(patch) + err = apiRegionLabeler.Patch(patch) re.NoError(err) testutil.Eventually(re, func() bool { - labelRules = loadRegionRules(re, ruleStorage) + labelRules = regionLabeler.GetAllLabelRules() return len(labelRules) == 2 }) sort.Slice(labelRules, func(i, j int) bool { return labelRules[i].ID < labelRules[j].ID }) re.Len(labelRules, 2) - re.Equal(defaultKeyspaceRule, labelRules[0]) re.Equal(labelRule.ID, labelRules[1].ID) re.Equal(labelRule.Labels, labelRules[1].Labels) re.Equal(labelRule.RuleType, labelRules[1].RuleType)