diff --git a/pkg/mcs/scheduling/server/rule/watcher.go b/pkg/mcs/scheduling/server/rule/watcher.go index 725df690ab7..2abc84b9c47 100644 --- a/pkg/mcs/scheduling/server/rule/watcher.go +++ b/pkg/mcs/scheduling/server/rule/watcher.go @@ -18,7 +18,6 @@ import ( "context" "strings" "sync" - "sync/atomic" "github.com/pingcap/log" "github.com/tikv/pd/pkg/schedule/checker" @@ -26,6 +25,7 @@ import ( "github.com/tikv/pd/pkg/schedule/placement" "github.com/tikv/pd/pkg/storage/endpoint" "github.com/tikv/pd/pkg/utils/etcdutil" + "github.com/tikv/pd/pkg/utils/syncutil" "go.etcd.io/etcd/clientv3" "go.etcd.io/etcd/mvcc/mvccpb" "go.uber.org/zap" @@ -53,12 +53,16 @@ type Watcher struct { etcdClient *clientv3.Client ruleStorage endpoint.RuleStorage - // checkerController is used to add the suspect key ranges to the checker when the rule changed. - checkerController atomic.Value - // ruleManager is used to manage the placement rules. - ruleManager atomic.Value - // regionLabeler is used to manage the region label rules. - regionLabeler atomic.Value + // components is used to store the cluster components and protect them with a RWMutex lock. + components struct { + syncutil.RWMutex + // checkerController is used to add the suspect key ranges to the checker when the rule changed. + checkerController *checker.Controller + // ruleManager is used to manage the placement rules. + ruleManager *placement.RuleManager + // regionLabeler is used to manage the region label rules. + regionLabeler *labeler.RegionLabeler + } ruleWatcher *etcdutil.LoopWatcher groupWatcher *etcdutil.LoopWatcher @@ -260,32 +264,33 @@ func (rw *Watcher) SetClusterComponents( sc *checker.Controller, rm *placement.RuleManager, rl *labeler.RegionLabeler, -) { - rw.checkerController.Store(sc) - rw.ruleManager.Store(rm) - rw.regionLabeler.Store(rl) +) error { + rw.components.Lock() + defer rw.components.Unlock() + rw.components.checkerController = sc + rw.components.ruleManager = rm + rw.components.regionLabeler = rl + // Reload the rules to make sure that the rules are consistent with the storage. + if err := rm.Reload(); err != nil { + return err + } + return rl.Reload() } func (rw *Watcher) getCheckerController() *checker.Controller { - cc := rw.checkerController.Load() - if cc == nil { - return nil - } - return cc.(*checker.Controller) + rw.components.RLock() + defer rw.components.RUnlock() + return rw.components.checkerController } func (rw *Watcher) getRuleManager() *placement.RuleManager { - rm := rw.ruleManager.Load() - if rm == nil { - return nil - } - return rm.(*placement.RuleManager) + rw.components.RLock() + defer rw.components.RUnlock() + return rw.components.ruleManager } func (rw *Watcher) getRegionLabeler() *labeler.RegionLabeler { - rl := rw.regionLabeler.Load() - if rl == nil { - return nil - } - return rl.(*labeler.RegionLabeler) + rw.components.RLock() + defer rw.components.RUnlock() + return rw.components.regionLabeler } diff --git a/pkg/schedule/labeler/labeler.go b/pkg/schedule/labeler/labeler.go index 39722b1a038..892ef097405 100644 --- a/pkg/schedule/labeler/labeler.go +++ b/pkg/schedule/labeler/labeler.go @@ -149,6 +149,14 @@ func (l *RegionLabeler) buildRangeList() { l.rangeList = builder.Build() } +// Reload loads rules from storage. +func (l *RegionLabeler) Reload() error { + l.Lock() + defer l.Unlock() + l.labelRules = make(map[string]*LabelRule) + return l.loadRules() +} + // GetSplitKeys returns all split keys in the range (start, end). func (l *RegionLabeler) GetSplitKeys(start, end []byte) [][]byte { l.RLock() diff --git a/pkg/schedule/placement/rule_manager.go b/pkg/schedule/placement/rule_manager.go index 909c0fa1078..6532a4083ed 100644 --- a/pkg/schedule/placement/rule_manager.go +++ b/pkg/schedule/placement/rule_manager.go @@ -135,8 +135,10 @@ func (m *RuleManager) Initialize(maxReplica int, locationLabels []string, isolat } func (m *RuleManager) loadRules() error { - var toSave []*Rule - var toDelete []string + var ( + toSave []*Rule + toDelete []string + ) err := m.storage.LoadRules(func(k, v string) { r, err := NewRuleFromJSON([]byte(v)) if err != nil { @@ -261,6 +263,31 @@ func (m *RuleManager) adjustRule(r *Rule, groupID string) (err error) { return nil } +// Reload reloads rules from storage. +func (m *RuleManager) Reload() error { + m.Lock() + defer m.Unlock() + // Only allow to reload when it is initialized. + if !m.initialized { + return nil + } + // Force the rule manager to reload rules from storage. + m.ruleConfig = newRuleConfig() + if err := m.loadRules(); err != nil { + return err + } + if err := m.loadGroups(); err != nil { + return err + } + m.ruleConfig.adjust() + ruleList, err := buildRuleList(m.ruleConfig) + if err != nil { + return err + } + m.ruleList = ruleList + return nil +} + // GetRule returns the Rule with the same (group, id). func (m *RuleManager) GetRule(group, id string) *Rule { m.RLock()