diff --git a/statistics/cmsketch_bench_test.go b/statistics/cmsketch_bench_test.go index 08666c4c2c3db..e68d102ca57da 100644 --- a/statistics/cmsketch_bench_test.go +++ b/statistics/cmsketch_bench_test.go @@ -123,10 +123,7 @@ func benchmarkMergeGlobalStatsTopNByConcurrencyWithHists(partitions int, b *test h.Buckets = append(h.Buckets, statistics.Bucket{Repeat: 10, Count: 40}) hists = append(hists, h) } - wrapper := &statistics.StatsWrapper{ - AllTopN: topNs, - AllHg: hists, - } + wrapper := statistics.NewStatsWrapper(hists, topNs) const mergeConcurrency = 4 batchSize := len(wrapper.AllTopN) / mergeConcurrency if batchSize < 1 { diff --git a/statistics/handle/handle.go b/statistics/handle/handle.go index a086c3315509a..8bd8ea1d7152b 100644 --- a/statistics/handle/handle.go +++ b/statistics/handle/handle.go @@ -15,7 +15,6 @@ package handle import ( - "bytes" "context" "encoding/json" "fmt" @@ -894,19 +893,15 @@ func MergeGlobalStatsTopNByConcurrency(mergeConcurrency, mergeBatchSize int, wra // handle Error hasErr := false + errMsg := make([]string, 0) for resp := range respCh { if resp.Err != nil { hasErr = true + errMsg = append(errMsg, resp.Err.Error()) } resps = append(resps, resp) } if hasErr { - errMsg := make([]string, 0) - for _, resp := range resps { - if resp.Err != nil { - errMsg = append(errMsg, resp.Err.Error()) - } - } return nil, nil, nil, errors.New(strings.Join(errMsg, ",")) } @@ -918,6 +913,7 @@ func MergeGlobalStatsTopNByConcurrency(mergeConcurrency, mergeBatchSize int, wra sorted = append(sorted, resp.TopN.TopN...) } leftTopn = append(leftTopn, resp.PopedTopn...) +<<<<<<< HEAD for i, removeTopn := range resp.RemoveVals { // Remove the value from the Hists. if len(removeTopn) > 0 { @@ -929,6 +925,8 @@ func MergeGlobalStatsTopNByConcurrency(mergeConcurrency, mergeBatchSize int, wra wrapper.AllHg[i].RemoveVals(tmp) } } +======= +>>>>>>> e9f4e31b41e (statistics: improve memory for mergeGlobalStatsTopNByConcurrency (#45993)) } globalTopN, popedTopn := statistics.GetMergedTopNFromSortedSlice(sorted, n) diff --git a/statistics/merge_worker.go b/statistics/merge_worker.go index e554da5370453..7b1e2dbbcf767 100644 --- a/statistics/merge_worker.go +++ b/statistics/merge_worker.go @@ -15,6 +15,7 @@ package statistics import ( + "sync" "sync/atomic" "time" @@ -44,6 +45,8 @@ type topnStatsMergeWorker struct { respCh chan<- *TopnStatsMergeResponse // the stats in the wrapper should only be read during the worker statsWrapper *StatsWrapper + // shardMutex is used to protect `statsWrapper.AllHg` + shardMutex []sync.Mutex } // NewTopnStatsMergeWorker returns topn merge worker @@ -57,6 +60,7 @@ func NewTopnStatsMergeWorker( respCh: respCh, } worker.statsWrapper = wrapper + worker.shardMutex = make([]sync.Mutex, len(wrapper.AllHg)) worker.killed = killed return worker } @@ -77,10 +81,16 @@ func NewTopnStatsMergeTask(start, end int) *TopnStatsMergeTask { // TopnStatsMergeResponse indicates topn merge worker response type TopnStatsMergeResponse struct { +<<<<<<< HEAD TopN *TopN PopedTopn []TopNMeta RemoveVals [][]TopNMeta Err error +======= + Err error + TopN *TopN + PopedTopn []TopNMeta +>>>>>>> e9f4e31b41e (statistics: improve memory for mergeGlobalStatsTopNByConcurrency (#45993)) } // Run runs topn merge like statistics.MergePartTopN2GlobalTopN @@ -99,7 +109,6 @@ func (worker *topnStatsMergeWorker) Run(timeZone *time.Location, isIndex bool, return } partNum := len(allTopNs) - removeVals := make([][]TopNMeta, partNum) // Different TopN structures may hold the same value, we have to merge them. counter := make(map[hack.MutableString]float64) // datumMap is used to store the mapping from the string type to datum type. @@ -168,13 +177,13 @@ func (worker *topnStatsMergeWorker) Run(timeZone *time.Location, isIndex bool, if count != 0 { counter[encodedVal] += count // Remove the value corresponding to encodedVal from the histogram. - removeVals[j] = append(removeVals[j], TopNMeta{Encoded: datum.GetBytes(), Count: uint64(count)}) + worker.shardMutex[j].Lock() + worker.statsWrapper.AllHg[j].BinarySearchRemoveVal(TopNMeta{Encoded: datum.GetBytes(), Count: uint64(count)}) + worker.shardMutex[j].Unlock() } } } } - // record remove values - resp.RemoveVals = removeVals numTop := len(counter) if numTop == 0 {