From 26db5909628fee0605d9e53d949b7b4d9e2b64e3 Mon Sep 17 00:00:00 2001 From: Weizhen Wang Date: Fri, 24 Nov 2023 11:53:42 +0800 Subject: [PATCH] *: fix wrong result when merging global stats concurrently (#48852) close pingcap/tidb#48713 --- .../handle/globalstats/merge_worker.go | 59 ++++++++----------- pkg/statistics/handle/globalstats/topn.go | 34 ++++------- 2 files changed, 39 insertions(+), 54 deletions(-) diff --git a/pkg/statistics/handle/globalstats/merge_worker.go b/pkg/statistics/handle/globalstats/merge_worker.go index 7a63adee65fb2..b702acaa797f7 100644 --- a/pkg/statistics/handle/globalstats/merge_worker.go +++ b/pkg/statistics/handle/globalstats/merge_worker.go @@ -43,8 +43,11 @@ type topnStatsMergeWorker struct { respCh chan<- *TopnStatsMergeResponse // the stats in the wrapper should only be read during the worker statsWrapper *StatsWrapper + // Different TopN structures may hold the same value, we have to merge them. + counter map[hack.MutableString]float64 // shardMutex is used to protect `statsWrapper.AllHg` shardMutex []sync.Mutex + mu sync.Mutex } // NewTopnStatsMergeWorker returns topn merge worker @@ -54,8 +57,9 @@ func NewTopnStatsMergeWorker( wrapper *StatsWrapper, killer *sqlkiller.SQLKiller) *topnStatsMergeWorker { worker := &topnStatsMergeWorker{ - taskCh: taskCh, - respCh: respCh, + taskCh: taskCh, + respCh: respCh, + counter: make(map[hack.MutableString]float64), } worker.statsWrapper = wrapper worker.shardMutex = make([]sync.Mutex, len(wrapper.AllHg)) @@ -79,15 +83,11 @@ func NewTopnStatsMergeTask(start, end int) *TopnStatsMergeTask { // TopnStatsMergeResponse indicates topn merge worker response type TopnStatsMergeResponse struct { - Err error - TopN *statistics.TopN - PopedTopn []statistics.TopNMeta + Err error } // Run runs topn merge like statistics.MergePartTopN2GlobalTopN -func (worker *topnStatsMergeWorker) Run(timeZone *time.Location, isIndex bool, - n uint32, - version int) { +func (worker *topnStatsMergeWorker) Run(timeZone 
*time.Location, isIndex bool, version int) { for task := range worker.taskCh { start := task.start end := task.end @@ -95,17 +95,12 @@ func (worker *topnStatsMergeWorker) Run(timeZone *time.Location, isIndex bool, allTopNs := worker.statsWrapper.AllTopN allHists := worker.statsWrapper.AllHg resp := &TopnStatsMergeResponse{} - if statistics.CheckEmptyTopNs(checkTopNs) { - worker.respCh <- resp - return - } + partNum := len(allTopNs) - // Different TopN structures may hold the same value, we have to merge them. - counter := make(map[hack.MutableString]float64) + // datumMap is used to store the mapping from the string type to datum type. // The datum is used to find the value in the histogram. datumMap := statistics.NewDatumMapCache() - for i, topN := range checkTopNs { i = i + start if err := worker.killer.HandleSignal(); err != nil { @@ -118,12 +113,15 @@ func (worker *topnStatsMergeWorker) Run(timeZone *time.Location, isIndex bool, } for _, val := range topN.TopN { encodedVal := hack.String(val.Encoded) - _, exists := counter[encodedVal] - counter[encodedVal] += float64(val.Count) + worker.mu.Lock() + _, exists := worker.counter[encodedVal] + worker.counter[encodedVal] += float64(val.Count) if exists { + worker.mu.Unlock() // We have already calculated the encodedVal from the histogram, so just continue to next topN value. continue } + worker.mu.Unlock() // We need to check whether the value corresponding to encodedVal is contained in other partition-level stats. // 1. Check the topN first. // 2. If the topN doesn't contain the value corresponding to encodedVal. We should check the histogram. @@ -147,31 +145,26 @@ func (worker *topnStatsMergeWorker) Run(timeZone *time.Location, isIndex bool, } datum = d } + worker.shardMutex[j].Lock() // Get the row count which the value is equal to the encodedVal from histogram. 
count, _ := allHists[j].EqualRowCount(nil, datum, isIndex) if count != 0 { - counter[encodedVal] += count // Remove the value corresponding to encodedVal from the histogram. - worker.shardMutex[j].Lock() worker.statsWrapper.AllHg[j].BinarySearchRemoveVal(statistics.TopNMeta{Encoded: datum.GetBytes(), Count: uint64(count)}) - worker.shardMutex[j].Unlock() + } + worker.shardMutex[j].Unlock() + if count != 0 { + worker.mu.Lock() + worker.counter[encodedVal] += count + worker.mu.Unlock() } } } } - numTop := len(counter) - if numTop == 0 { - worker.respCh <- resp - continue - } - sorted := make([]statistics.TopNMeta, 0, numTop) - for value, cnt := range counter { - data := hack.Slice(string(value)) - sorted = append(sorted, statistics.TopNMeta{Encoded: data, Count: uint64(cnt)}) - } - globalTopN, leftTopN := statistics.GetMergedTopNFromSortedSlice(sorted, n) - resp.TopN = globalTopN - resp.PopedTopn = leftTopN worker.respCh <- resp } } + +func (worker *topnStatsMergeWorker) Result() map[hack.MutableString]float64 { + return worker.counter +} diff --git a/pkg/statistics/handle/globalstats/topn.go b/pkg/statistics/handle/globalstats/topn.go index 9e9f14a068a54..171756b82357b 100644 --- a/pkg/statistics/handle/globalstats/topn.go +++ b/pkg/statistics/handle/globalstats/topn.go @@ -30,8 +30,12 @@ import ( func mergeGlobalStatsTopN(gp *gp.Pool, sc sessionctx.Context, wrapper *StatsWrapper, timeZone *time.Location, version int, n uint32, isIndex bool) (*statistics.TopN, []statistics.TopNMeta, []*statistics.Histogram, error) { + if statistics.CheckEmptyTopNs(wrapper.AllTopN) { + return nil, nil, wrapper.AllHg, nil + } mergeConcurrency := sc.GetSessionVars().AnalyzePartitionMergeConcurrency killer := &sc.GetSessionVars().SQLKiller + // use original method if concurrency equals 1 or for version1 if mergeConcurrency < 2 { return MergePartTopN2GlobalTopN(timeZone, version, wrapper.AllTopN, n, wrapper.AllHg, isIndex, killer) @@ -78,12 +82,12 @@ func 
MergeGlobalStatsTopNByConcurrency( taskNum := len(tasks) taskCh := make(chan *TopnStatsMergeTask, taskNum) respCh := make(chan *TopnStatsMergeResponse, taskNum) + worker := NewTopnStatsMergeWorker(taskCh, respCh, wrapper, killer) for i := 0; i < mergeConcurrency; i++ { - worker := NewTopnStatsMergeWorker(taskCh, respCh, wrapper, killer) wg.Add(1) gp.Go(func() { defer wg.Done() - worker.Run(timeZone, isIndex, n, version) + worker.Run(timeZone, isIndex, version) }) } for _, task := range tasks { @@ -92,8 +96,6 @@ func MergeGlobalStatsTopNByConcurrency( close(taskCh) wg.Wait() close(respCh) - resps := make([]*TopnStatsMergeResponse, 0) - // handle Error hasErr := false errMsg := make([]string, 0) @@ -102,27 +104,21 @@ func MergeGlobalStatsTopNByConcurrency( hasErr = true errMsg = append(errMsg, resp.Err.Error()) } - resps = append(resps, resp) } if hasErr { return nil, nil, nil, errors.New(strings.Join(errMsg, ",")) } // fetch the response from each worker and merge them into global topn stats - sorted := make([]statistics.TopNMeta, 0, mergeConcurrency) - leftTopn := make([]statistics.TopNMeta, 0) - for _, resp := range resps { - if resp.TopN != nil { - sorted = append(sorted, resp.TopN.TopN...) - } - leftTopn = append(leftTopn, resp.PopedTopn...) + counter := worker.Result() + numTop := len(counter) + sorted := make([]statistics.TopNMeta, 0, numTop) + for value, cnt := range counter { + data := hack.Slice(string(value)) + sorted = append(sorted, statistics.TopNMeta{Encoded: data, Count: uint64(cnt)}) } - globalTopN, popedTopn := statistics.GetMergedTopNFromSortedSlice(sorted, n) - - result := append(leftTopn, popedTopn...) - statistics.SortTopnMeta(result) - return globalTopN, result, wrapper.AllHg, nil + return globalTopN, popedTopn, wrapper.AllHg, nil } // MergePartTopN2GlobalTopN is used to merge the partition-level topN to global-level topN. 
@@ -149,10 +145,6 @@ func MergePartTopN2GlobalTopN( isIndex bool, killer *sqlkiller.SQLKiller, ) (*statistics.TopN, []statistics.TopNMeta, []*statistics.Histogram, error) { - if statistics.CheckEmptyTopNs(topNs) { - return nil, nil, hists, nil - } - partNum := len(topNs) // Different TopN structures may hold the same value, we have to merge them. counter := make(map[hack.MutableString]float64)