From c9f9ba90c04d4e9e7160dbd4327851ec41ef8ad9 Mon Sep 17 00:00:00 2001
From: ichn-hu
Date: Tue, 19 Nov 2019 11:10:35 +0800
Subject: [PATCH] add memtrack for chunks in aggregation

---
 executor/aggregate.go    | 42 ++++++++++++++++++++++++++++++++++++----
 executor/explain_test.go |  7 +++++--
 go.sum                   |  1 +
 3 files changed, 44 insertions(+), 6 deletions(-)

diff --git a/executor/aggregate.go b/executor/aggregate.go
index 2b580849e133b..868691763eb41 100644
--- a/executor/aggregate.go
+++ b/executor/aggregate.go
@@ -32,6 +32,7 @@ import (
 	"github.com/pingcap/tidb/util/codec"
 	"github.com/pingcap/tidb/util/logutil"
 	"github.com/pingcap/tidb/util/mathutil"
+	"github.com/pingcap/tidb/util/memory"
 	"github.com/pingcap/tidb/util/set"
 	"github.com/spaolacci/murmur3"
 	"go.uber.org/zap"
 )
@@ -70,7 +71,8 @@ type HashAggPartialWorker struct {
 	groupKey [][]byte
 	// chk stores the input data from child,
 	// and is reused by childExec and partial worker.
-	chk *chunk.Chunk
+	chk        *chunk.Chunk
+	memTracker *memory.Tracker
 }

 // HashAggFinalWorker indicates the final workers of parallel hash agg execution,
@@ -166,6 +168,8 @@ type HashAggExec struct {
 	isUnparallelExec bool
 	prepared         bool
 	executed         bool
+
+	memTracker *memory.Tracker // track memory usage.
 }

 // HashAggInput indicates the input of hash agg exec.
@@ -199,6 +203,7 @@ func (d *HashAggIntermData) getPartialResultBatch(sc *stmtctx.StatementContext,
 // Close implements the Executor Close interface.
 func (e *HashAggExec) Close() error {
 	if e.isUnparallelExec {
+		e.memTracker.Consume(-e.childResult.MemoryUsage())
 		e.childResult = nil
 		e.groupSet = nil
 		e.partialResultMap = nil
@@ -221,7 +226,8 @@ func (e *HashAggExec) Close() error {
 		}
 	}
 	for _, ch := range e.partialInputChs {
-		for range ch {
+		for chk := range ch {
+			e.memTracker.Consume(-chk.MemoryUsage())
 		}
 	}
 	for range e.finalOutputCh {
@@ -250,6 +256,9 @@ func (e *HashAggExec) Open(ctx context.Context) error {
 	}
 	e.prepared = false

+	e.memTracker = memory.NewTracker(e.id, -1)
+	e.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker)
+
 	if e.isUnparallelExec {
 		e.initForUnparallelExec()
 		return nil
@@ -263,6 +272,7 @@ func (e *HashAggExec) initForUnparallelExec() {
 	e.partialResultMap = make(aggPartialResultMapper)
 	e.groupKeyBuffer = make([][]byte, 0, 8)
 	e.childResult = newFirstChunk(e.children[0])
+	e.memTracker.Consume(e.childResult.MemoryUsage())
 }

 func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) {
@@ -298,13 +308,17 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) {
 			groupByItems: e.GroupByItems,
 			chk:          newFirstChunk(e.children[0]),
 			groupKey:     make([][]byte, 0, 8),
+			memTracker:   e.memTracker,
 		}
-
+		e.memTracker.Consume(w.chk.MemoryUsage())
 		e.partialWorkers[i] = w
-		e.inputCh <- &HashAggInput{
+
+		input := &HashAggInput{
 			chk:        newFirstChunk(e.children[0]),
 			giveBackCh: w.inputCh,
 		}
+		e.memTracker.Consume(input.chk.MemoryUsage())
+		e.inputCh <- input
 	}

 	// Init final workers.
@@ -356,6 +370,7 @@ func (w *HashAggPartialWorker) run(ctx sessionctx.Context, waitGroup *sync.WaitG
 		if needShuffle {
 			w.shuffleIntermData(sc, finalConcurrency)
 		}
+		w.memTracker.Consume(-w.chk.MemoryUsage())
 		waitGroup.Done()
 	}()
 	for {
@@ -606,20 +621,28 @@ func (e *HashAggExec) fetchChildData(ctx context.Context) {
 			}
 			chk = input.chk
 		}
+		mSize := chk.MemoryUsage()
 		err = Next(ctx, e.children[0], chk)
 		if err != nil {
 			e.finalOutputCh <- &AfFinalResult{err: err}
+			e.memTracker.Consume(-mSize)
 			return
 		}
 		if chk.NumRows() == 0 {
+			e.memTracker.Consume(-mSize)
 			return
 		}
+		e.memTracker.Consume(chk.MemoryUsage() - mSize)
 		input.giveBackCh <- chk
 	}
 }

 func (e *HashAggExec) waitPartialWorkerAndCloseOutputChs(waitGroup *sync.WaitGroup) {
 	waitGroup.Wait()
+	close(e.inputCh)
+	for input := range e.inputCh {
+		e.memTracker.Consume(-input.chk.MemoryUsage())
+	}
 	for _, ch := range e.partialOutputChs {
 		close(ch)
 	}
@@ -733,7 +756,9 @@ func (e *HashAggExec) unparallelExec(ctx context.Context, chk *chunk.Chunk) erro
 // execute fetches Chunks from src and update each aggregate function for each row in Chunk.
 func (e *HashAggExec) execute(ctx context.Context) (err error) {
 	for {
+		mSize := e.childResult.MemoryUsage()
 		err := Next(ctx, e.children[0], e.childResult)
+		e.memTracker.Consume(e.childResult.MemoryUsage() - mSize)
 		if err != nil {
 			return err
 		}
@@ -800,6 +825,8 @@ type StreamAggExec struct {
 	partialResults []aggfuncs.PartialResult
 	groupRows      []chunk.Row
 	childResult    *chunk.Chunk
+
+	memTracker *memory.Tracker // track memory usage.
 }

 // Open implements the Executor Open interface.
@@ -818,11 +845,16 @@ func (e *StreamAggExec) Open(ctx context.Context) error {
 		e.partialResults = append(e.partialResults, aggFunc.AllocPartialResult())
 	}

+	// bytesLimit <= 0 means no limit, for now we just track the memory footprint
+	e.memTracker = memory.NewTracker(e.id, -1)
+	e.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker)
+	e.memTracker.Consume(e.childResult.MemoryUsage())
 	return nil
 }

 // Close implements the Executor Close interface.
 func (e *StreamAggExec) Close() error {
+	e.memTracker.Consume(-e.childResult.MemoryUsage())
 	e.childResult = nil
 	e.groupChecker.reset()
 	return e.baseExecutor.Close()
 }
@@ -910,7 +942,9 @@ func (e *StreamAggExec) consumeCurGroupRowsAndFetchChild(ctx context.Context, ch
 		return err
 	}

+	mSize := e.childResult.MemoryUsage()
 	err = Next(ctx, e.children[0], e.childResult)
+	e.memTracker.Consume(e.childResult.MemoryUsage() - mSize)
 	if err != nil {
 		return err
 	}
diff --git a/executor/explain_test.go b/executor/explain_test.go
index 3af5cc2321687..96c2880beff78 100644
--- a/executor/explain_test.go
+++ b/executor/explain_test.go
@@ -129,7 +129,7 @@ func (s *testSuite1) TestExplainAnalyzeMemory(c *C) {

 func (s *testSuite1) checkMemoryInfo(c *C, tk *testkit.TestKit, sql string) {
 	memCol := 5
-	ops := []string{"Join", "Reader", "Top", "Sort", "LookUp", "Projection", "Selection"}
+	ops := []string{"Join", "Reader", "Top", "Sort", "LookUp", "Projection", "Selection", "Agg"}
 	rows := tk.MustQuery(sql).Rows()
 	for _, row := range rows {
 		strs := make([]string, len(row))
@@ -165,7 +165,10 @@ func (s *testSuite1) TestMemoryUsageAfterClose(c *C) {
 	}
 	SQLs := []string{"select v+abs(k) from t",
 		"select v from t where abs(v) > 0",
-		"select v from t order by v"}
+		"select v from t order by v",
+		"select count(v) from t",            // StreamAgg
+		"select count(v) from t group by v", // HashAgg
+	}
 	for _, sql := range SQLs {
 		tk.MustQuery(sql)
 		c.Assert(tk.Se.GetSessionVars().StmtCtx.MemTracker.BytesConsumed(), Equals, int64(0))
diff --git a/go.sum b/go.sum
index 1705f9bf602dc..a9f827474d558 100644
--- a/go.sum
+++ b/go.sum
@@ -239,6 +239,7 @@ github.com/remyoudompheng/bigfft v0.0.0-20190728182440-6a916e37a237 h1:HQagqIiBm
 github.com/remyoudompheng/bigfft v0.0.0-20190728182440-6a916e37a237/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
 github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
 github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
+github.com/sergi/go-diff v1.0.1-0.20180205163309-da645544ed44 h1:tB9NOR21++IjLyVx3/PCPhWMwqGNCMQEH96A6dMZ/gc=
 github.com/sergi/go-diff v1.0.1-0.20180205163309-da645544ed44/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
 github.com/shirou/gopsutil v2.19.10+incompatible h1:lA4Pi29JEVIQIgATSeftHSY0rMGI9CLrl2ZvDLiahto=
 github.com/shirou/gopsutil v2.19.10+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
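
The patch applies one bookkeeping pattern throughout: consume a chunk's footprint when it is first allocated, consume only the delta after Next() refills the reused chunk, and give the bytes back on Close so the statement-level tracker returns to zero (which TestMemoryUsageAfterClose asserts). Below is a minimal standalone Go sketch of that pattern, not part of the patch; tracker, chunkStub, and fill are hypothetical stand-ins rather than TiDB's memory.Tracker or chunk.Chunk APIs.

package main

import "fmt"

// tracker is a toy stand-in for a memory tracker; only a single counter is
// modeled here, with no parent/child attachment or byte limit.
type tracker struct {
	consumed int64
}

func (t *tracker) Consume(bytes int64)  { t.consumed += bytes }
func (t *tracker) BytesConsumed() int64 { return t.consumed }

// chunkStub is a toy stand-in for a column chunk; MemoryUsage reports its
// current footprint in bytes.
type chunkStub struct {
	rows []int64
}

func (c *chunkStub) MemoryUsage() int64 { return int64(len(c.rows)) * 8 }

// fill mimics Next(ctx, child, chk): the chunk is reused and may grow or
// shrink in place between calls.
func fill(c *chunkStub, n int) {
	c.rows = make([]int64, n)
}

func main() {
	t := &tracker{}
	chk := &chunkStub{}

	// The pattern from fetchChildData/execute: record the size before the
	// refill and consume only the difference afterwards.
	for _, n := range []int{4, 1024, 16} {
		before := chk.MemoryUsage()
		fill(chk, n)
		t.Consume(chk.MemoryUsage() - before)
		fmt.Println("tracked bytes:", t.BytesConsumed())
	}

	// The pattern from Close(): release the chunk's footprint so the
	// statement-level total drops back to zero.
	t.Consume(-chk.MemoryUsage())
	chk = nil
	fmt.Println("tracked bytes after close:", t.BytesConsumed())
}

Tracking the delta rather than re-consuming the full chunk size on every Next() matters because the chunk is reused across iterations; only growth or shrinkage changes the real footprint, so double counting is avoided while the tracker still reflects the current usage.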