diff --git a/executor/aggregate.go b/executor/aggregate.go index d89a04073ee37..bcba95859299d 100644 --- a/executor/aggregate.go +++ b/executor/aggregate.go @@ -51,15 +51,33 @@ type baseHashAggWorker struct { aggFuncs []aggfuncs.AggFunc maxChunkSize int stats *AggWorkerStat + + memTracker *memory.Tracker + BInMap int // indicate there are 2^BInMap buckets in Golang Map. } -func newBaseHashAggWorker(ctx sessionctx.Context, finishCh <-chan struct{}, aggFuncs []aggfuncs.AggFunc, maxChunkSize int) baseHashAggWorker { - return baseHashAggWorker{ +const ( + // ref https://github.com/golang/go/blob/go1.15.6/src/reflect/type.go#L2162. + // defBucketMemoryUsage = bucketSize*(1+unsafe.Sizeof(string) + unsafe.Sizeof(slice))+2*ptrSize + // The bucket size may be changed by golang implement in the future. + defBucketMemoryUsage = 8*(1+16+24) + 16 + // Maximum average load of a bucket that triggers growth is 6.5. + // Represent as loadFactorNum/loadFactDen, to allow integer math. + loadFactorNum = 13 + loadFactorDen = 2 +) + +func newBaseHashAggWorker(ctx sessionctx.Context, finishCh <-chan struct{}, aggFuncs []aggfuncs.AggFunc, + maxChunkSize int, memTrack *memory.Tracker) baseHashAggWorker { + baseWorker := baseHashAggWorker{ ctx: ctx, finishCh: finishCh, aggFuncs: aggFuncs, maxChunkSize: maxChunkSize, + memTracker: memTrack, + BInMap: 0, } + return baseWorker } // HashAggPartialWorker indicates the partial workers of parallel hash agg execution, @@ -76,8 +94,7 @@ type HashAggPartialWorker struct { groupKey [][]byte // chk stores the input data from child, // and is reused by childExec and partial worker. - chk *chunk.Chunk - memTracker *memory.Tracker + chk *chunk.Chunk } // HashAggFinalWorker indicates the final workers of parallel hash agg execution, @@ -296,7 +313,7 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) { // Init partial workers. for i := 0; i < partialConcurrency; i++ { w := HashAggPartialWorker{ - baseHashAggWorker: newBaseHashAggWorker(e.ctx, e.finishCh, e.PartialAggFuncs, e.maxChunkSize), + baseHashAggWorker: newBaseHashAggWorker(e.ctx, e.finishCh, e.PartialAggFuncs, e.maxChunkSize, e.memTracker), inputCh: e.partialInputChs[i], outputChs: e.partialOutputChs, giveBackCh: e.inputCh, @@ -305,8 +322,9 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) { groupByItems: e.GroupByItems, chk: newFirstChunk(e.children[0]), groupKey: make([][]byte, 0, 8), - memTracker: e.memTracker, } + // There is a bucket in the empty partialResultsMap. + e.memTracker.Consume(defBucketMemoryUsage * (1 << w.BInMap)) if e.stats != nil { w.stats = &AggWorkerStat{} e.stats.PartialStats = append(e.stats.PartialStats, w.stats) @@ -324,7 +342,7 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) { // Init final workers. for i := 0; i < finalConcurrency; i++ { w := HashAggFinalWorker{ - baseHashAggWorker: newBaseHashAggWorker(e.ctx, e.finishCh, e.FinalAggFuncs, e.maxChunkSize), + baseHashAggWorker: newBaseHashAggWorker(e.ctx, e.finishCh, e.FinalAggFuncs, e.maxChunkSize, e.memTracker), partialResultMap: make(aggPartialResultMapper), groupSet: set.NewStringSet(), inputCh: e.partialOutputChs[i], @@ -334,6 +352,8 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) { mutableRow: chunk.MutRowFromTypes(retTypes(e)), groupKeys: make([][]byte, 0, 8), } + // There is a bucket in the empty partialResultsMap. + e.memTracker.Consume(defBucketMemoryUsage * (1 << w.BInMap)) if e.stats != nil { w.stats = &AggWorkerStat{} e.stats.FinalStats = append(e.stats.FinalStats, w.stats) @@ -406,8 +426,19 @@ func (w *HashAggPartialWorker) run(ctx sessionctx.Context, waitGroup *sync.WaitG } } +func getGroupKeyMemUsage(groupKey [][]byte) int64 { + mem := int64(0) + for _, key := range groupKey { + mem += int64(cap(key)) + } + mem += 12 * int64(cap(groupKey)) + return mem +} + func (w *HashAggPartialWorker) updatePartialResult(ctx sessionctx.Context, sc *stmtctx.StatementContext, chk *chunk.Chunk, finalConcurrency int) (err error) { + memSize := getGroupKeyMemUsage(w.groupKey) w.groupKey, err = getGroupKey(w.ctx, chk, w.groupKey, w.groupByItems) + w.memTracker.Consume(getGroupKeyMemUsage(w.groupKey) - memSize) if err != nil { return err } @@ -418,9 +449,11 @@ func (w *HashAggPartialWorker) updatePartialResult(ctx sessionctx.Context, sc *s for i := 0; i < numRows; i++ { for j, af := range w.aggFuncs { rows[0] = chk.GetRow(i) - if _, err := af.UpdatePartialResult(ctx, rows, partialResults[i][j]); err != nil { + memDelta, err := af.UpdatePartialResult(ctx, rows, partialResults[i][j]) + if err != nil { return err } + w.memTracker.Consume(memDelta) } } return nil @@ -487,7 +520,7 @@ func getGroupKey(ctx sessionctx.Context, input *chunk.Chunk, groupKey [][]byte, return groupKey, nil } -func (w baseHashAggWorker) getPartialResult(sc *stmtctx.StatementContext, groupKey [][]byte, mapper aggPartialResultMapper) [][]aggfuncs.PartialResult { +func (w *baseHashAggWorker) getPartialResult(sc *stmtctx.StatementContext, groupKey [][]byte, mapper aggPartialResultMapper) [][]aggfuncs.PartialResult { n := len(groupKey) partialResults := make([][]aggfuncs.PartialResult, n) for i := 0; i < n; i++ { @@ -496,10 +529,17 @@ func (w baseHashAggWorker) getPartialResult(sc *stmtctx.StatementContext, groupK continue } for _, af := range w.aggFuncs { - partialResult, _ := af.AllocPartialResult() + partialResult, memDelta := af.AllocPartialResult() partialResults[i] = append(partialResults[i], partialResult) + w.memTracker.Consume(memDelta) } mapper[string(groupKey[i])] = partialResults[i] + w.memTracker.Consume(int64(len(groupKey[i]))) + // Map will expand when count > bucketNum * loadFactor. The memory usage will doubled. + if len(mapper) > (1<