executor: make memory tracker for aggregate more accurate. #22463

Merged: 18 commits, Feb 18, 2021
62 changes: 52 additions & 10 deletions executor/aggregate.go
@@ -51,15 +51,32 @@ type baseHashAggWorker struct {
aggFuncs []aggfuncs.AggFunc
maxChunkSize int
stats *AggWorkerStat

memTracker *memory.Tracker
BInMap *int // indicates B in the Go map: the map currently holds 2^B buckets
Contributor:
This is really hard to understand for a reader who does not know the implementation of Go's map. We need a more detailed comment.

}

func newBaseHashAggWorker(ctx sessionctx.Context, finishCh <-chan struct{}, aggFuncs []aggfuncs.AggFunc, maxChunkSize int) baseHashAggWorker {
return baseHashAggWorker{
const (
// ref https://github.com/golang/go/blob/go1.15.6/src/reflect/type.go#L2162.
// defBucketMemoryUsage = bucketSize*(1+unsafe.Sizeof(string) + unsafe.Sizeof(slice))+2*ptrSize
defBucketMemoryUsage = 8*(1+16+24) + 16
Contributor:
Will the bucket size be changed by Go in the future?

Contributor (Author):
Maybe.

Contributor:
Just curious: is there any way to set it dynamically according to the Go version?

Contributor (Author):
We can check runtime.Version() to distinguish Go versions and use a different estimation for each, but then we would need to implement an estimation for every Go version (the major version would be enough, of course). Go 1.13, 1.14, 1.15, and 1.16 all use the same map implementation, so I don't think we need to distinguish them for now.
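
As an aside for readers checking the arithmetic: the following standalone sketch (not part of the PR) reproduces the constant from the layout described in the comment above, assuming string keys and slice values (as in aggPartialResultMapper) on a 64-bit platform.

package main

import (
	"fmt"
	"unsafe"
)

func main() {
	// Mirror the formula above: 8 entries per bucket, one tophash byte
	// per entry, the key/value payloads, plus pointer-sized overhead.
	const bucketSize = 8
	var k string // 16 bytes on amd64: data pointer + length
	var v []byte // 24 bytes on amd64: data pointer + length + capacity
	ptrSize := unsafe.Sizeof(uintptr(0))
	est := bucketSize*(1+unsafe.Sizeof(k)+unsafe.Sizeof(v)) + 2*ptrSize
	fmt.Println(est) // 344, i.e. 8*(1+16+24) + 16
}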

// Maximum average load of a bucket that triggers growth is 6.5.
// Represent as loadFactorNum/loadFactorDen, to allow integer math.
Contributor:
Why do we need to allow integer math?

Contributor (Author):
If we used 6.5 directly, the Go compiler would reject the condition len(mapper) > (1<<*w.BInMap)*loadFactorNum/loadFactorDen with the error "constant 6.5 truncated to integer". Making it compile would require type conversions, e.g. float64(len(mapper)) > float64(int(1<<*w.BInMap))*6.5, which I think is inefficient.

loadFactorNum = 13
loadFactorDen = 2
)
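
For illustration only (not from the PR), a standalone check that the integer form defined above agrees with the 6.5 load factor:

package main

import "fmt"

func main() {
	const loadFactorNum, loadFactorDen = 13, 2
	for b := 0; b < 30; b++ {
		buckets := 1 << b
		// Integer threshold vs. the truncated floating-point threshold.
		if buckets*loadFactorNum/loadFactorDen != int(float64(buckets)*6.5) {
			fmt.Println("mismatch at B =", b)
			return
		}
	}
	fmt.Println("integer and float forms agree for all bucket counts tested")
}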

func newBaseHashAggWorker(ctx sessionctx.Context, finishCh <-chan struct{}, aggFuncs []aggfuncs.AggFunc,
maxChunkSize int, memTrack *memory.Tracker) baseHashAggWorker {
baseWorker := baseHashAggWorker{
ctx: ctx,
finishCh: finishCh,
aggFuncs: aggFuncs,
maxChunkSize: maxChunkSize,
memTracker: memTrack,
BInMap: new(int),
}
return baseWorker
}

// HashAggPartialWorker indicates the partial workers of parallel hash agg execution,
@@ -76,8 +93,7 @@ type HashAggPartialWorker struct {
groupKey [][]byte
// chk stores the input data from child,
// and is reused by childExec and partial worker.
chk *chunk.Chunk
memTracker *memory.Tracker
chk *chunk.Chunk
}

// HashAggFinalWorker indicates the final workers of parallel hash agg execution,
@@ -296,7 +312,7 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) {
// Init partial workers.
for i := 0; i < partialConcurrency; i++ {
w := HashAggPartialWorker{
baseHashAggWorker: newBaseHashAggWorker(e.ctx, e.finishCh, e.PartialAggFuncs, e.maxChunkSize),
baseHashAggWorker: newBaseHashAggWorker(e.ctx, e.finishCh, e.PartialAggFuncs, e.maxChunkSize, e.memTracker),
inputCh: e.partialInputChs[i],
outputChs: e.partialOutputChs,
giveBackCh: e.inputCh,
@@ -305,8 +321,8 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) {
groupByItems: e.GroupByItems,
chk: newFirstChunk(e.children[0]),
groupKey: make([][]byte, 0, 8),
memTracker: e.memTracker,
}
e.memTracker.Consume(defBucketMemoryUsage * (1 << *w.BInMap))
if e.stats != nil {
w.stats = &AggWorkerStat{}
e.stats.PartialStats = append(e.stats.PartialStats, w.stats)
@@ -324,7 +340,7 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) {
// Init final workers.
for i := 0; i < finalConcurrency; i++ {
w := HashAggFinalWorker{
baseHashAggWorker: newBaseHashAggWorker(e.ctx, e.finishCh, e.FinalAggFuncs, e.maxChunkSize),
baseHashAggWorker: newBaseHashAggWorker(e.ctx, e.finishCh, e.FinalAggFuncs, e.maxChunkSize, e.memTracker),
partialResultMap: make(aggPartialResultMapper),
groupSet: set.NewStringSet(),
inputCh: e.partialOutputChs[i],
@@ -334,6 +350,7 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) {
mutableRow: chunk.MutRowFromTypes(retTypes(e)),
groupKeys: make([][]byte, 0, 8),
}
e.memTracker.Consume(defBucketMemoryUsage * (1 << *w.BInMap))
if e.stats != nil {
w.stats = &AggWorkerStat{}
e.stats.FinalStats = append(e.stats.FinalStats, w.stats)
@@ -406,8 +423,19 @@ func (w *HashAggPartialWorker) run(ctx sessionctx.Context, waitGroup *sync.WaitG
}
}

func getGroupKeyMemUsage(groupKey [][]byte) int64 {
mem := int64(0)
for _, key := range groupKey {
mem += int64(cap(key))
}
mem += 12 * int64(cap(groupKey))
return mem
}
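
The callers below apply a snapshot/mutate/charge-the-delta idiom around this helper. A self-contained sketch of the idiom (hypothetical names, simplified accounting with no fixed per-element overhead):

package main

import "fmt"

// groupKeyMem is a simplified stand-in for getGroupKeyMemUsage.
func groupKeyMem(keys [][]byte) int64 {
	mem := int64(0)
	for _, k := range keys {
		mem += int64(cap(k))
	}
	return mem
}

func main() {
	var consumed int64 // stands in for memory.Tracker.Consume
	keys := make([][]byte, 0, 4)

	before := groupKeyMem(keys)              // snapshot
	keys = append(keys, make([]byte, 0, 32)) // mutate
	consumed += groupKeyMem(keys) - before   // charge only the delta

	fmt.Println(consumed) // 32: only the newly reserved bytes are charged
}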

func (w *HashAggPartialWorker) updatePartialResult(ctx sessionctx.Context, sc *stmtctx.StatementContext, chk *chunk.Chunk, finalConcurrency int) (err error) {
memSize := getGroupKeyMemUsage(w.groupKey)
w.groupKey, err = getGroupKey(w.ctx, chk, w.groupKey, w.groupByItems)
w.memTracker.Consume(getGroupKeyMemUsage(w.groupKey) - memSize)
if err != nil {
return err
}
@@ -418,9 +446,11 @@ func (w *HashAggPartialWorker) updatePartialResult(ctx sessionctx.Context, sc *s
for i := 0; i < numRows; i++ {
for j, af := range w.aggFuncs {
rows[0] = chk.GetRow(i)
if _, err := af.UpdatePartialResult(ctx, rows, partialResults[i][j]); err != nil {
memDelta, err := af.UpdatePartialResult(ctx, rows, partialResults[i][j])
if err != nil {
return err
}
w.memTracker.Consume(memDelta)
}
}
return nil
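
Note the pattern above: UpdatePartialResult now returns a memory delta that the worker feeds to the tracker. A toy illustration (hypothetical type, not the aggfuncs API) of why the delta must come from the aggregate function itself: the caller cannot observe allocations inside the partial result.

package main

import "fmt"

// sumPartial is a toy partial result; update reports the bytes newly
// reserved by append so the caller can account for them.
type sumPartial struct{ vals []int64 }

func (p *sumPartial) update(v int64) (memDelta int64) {
	oldCap := cap(p.vals)
	p.vals = append(p.vals, v)
	if newCap := cap(p.vals); newCap > oldCap {
		memDelta = int64(newCap-oldCap) * 8 // 8 bytes per int64 element
	}
	return memDelta
}

func main() {
	var tracked int64 // stands in for memory.Tracker
	p := &sumPartial{}
	for i := int64(0); i < 10; i++ {
		tracked += p.update(i)
	}
	fmt.Println(tracked, "bytes charged across reallocations")
}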
@@ -496,10 +526,18 @@ func (w baseHashAggWorker) getPartialResult(sc *stmtctx.StatementContext, groupK
continue
}
for _, af := range w.aggFuncs {
partialResult, _ := af.AllocPartialResult()
partialResult, memDelta := af.AllocPartialResult()
partialResults[i] = append(partialResults[i], partialResult)
w.memTracker.Consume(memDelta)
}
str := string(groupKey[i])
mapper[str] = partialResults[i]
w.memTracker.Consume(int64(len(groupKey[i])))
// map will expand when count > bucketNum * loadFactor.
if len(mapper) > (1<<*w.BInMap)*loadFactorNum/loadFactorDen {
w.memTracker.Consume(defBucketMemoryUsage * (1 << *w.BInMap))
*w.BInMap++
}
mapper[string(groupKey[i])] = partialResults[i]
}
return partialResults
}
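
Putting the pieces together, here is a standalone sketch (simplified, reusing the constants above) of the incremental map accounting in getPartialResult: one bucket array is charged up front, and each time the entry count crosses the growth threshold the tracker is charged for the additional buckets before B is incremented.

package main

import "fmt"

const (
	defBucketMemoryUsage = 8*(1+16+24) + 16
	loadFactorNum        = 13
	loadFactorDen        = 2
)

func main() {
	mapper := make(map[string]int)
	bInMap := 0                             // log2 of the buckets charged so far
	consumed := int64(defBucketMemoryUsage) // initial 1<<0 bucket, charged at init

	for i := 0; i < 1000; i++ {
		// Go's map grows when count > bucketNum * loadFactor; charge the
		// doubled bucket array at the same moment.
		if len(mapper) > (1<<bInMap)*loadFactorNum/loadFactorDen {
			consumed += int64(defBucketMemoryUsage * (1 << bInMap))
			bInMap++
		}
		mapper[fmt.Sprint(i)] = i
	}
	fmt.Println(consumed, "bytes estimated for", len(mapper), "entries")
}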
@@ -541,10 +579,12 @@ func (w *HashAggFinalWorker) consumeIntermData(sctx sessionctx.Context) (err err
for reachEnd := false; !reachEnd; {
intermDataBuffer, groupKeys, reachEnd = input.getPartialResultBatch(sc, intermDataBuffer[:0], w.aggFuncs, w.maxChunkSize)
groupKeysLen := len(groupKeys)
memSize := getGroupKeyMemUsage(w.groupKeys)
w.groupKeys = w.groupKeys[:0]
for i := 0; i < groupKeysLen; i++ {
w.groupKeys = append(w.groupKeys, []byte(groupKeys[i]))
}
w.memTracker.Consume(getGroupKeyMemUsage(w.groupKeys) - memSize)
finalPartialResults := w.getPartialResult(sc, w.groupKeys, w.partialResultMap)
for i, groupKey := range groupKeys {
if !w.groupSet.Exist(groupKey) {
@@ -575,10 +615,12 @@ func (w *HashAggFinalWorker) getFinalResult(sctx sessionctx.Context) {
return
}
execStart := time.Now()
memSize := getGroupKeyMemUsage(w.groupKeys)
w.groupKeys = w.groupKeys[:0]
for groupKey := range w.groupSet {
w.groupKeys = append(w.groupKeys, []byte(groupKey))
}
w.memTracker.Consume(getGroupKeyMemUsage(w.groupKeys) - memSize)
partialResults := w.getPartialResult(sctx.GetSessionVars().StmtCtx, w.groupKeys, w.partialResultMap)
for i := 0; i < len(w.groupSet); i++ {
for j, af := range w.aggFuncs {