diff --git a/executor/cte.go b/executor/cte.go
index 06490e218b3ee..e0f3ca8664cc2 100644
--- a/executor/cte.go
+++ b/executor/cte.go
@@ -107,11 +107,11 @@ func (e *CTEExec) Next(ctx context.Context, req *chunk.Chunk) (err error) {
 	e.producer.resTbl.Lock()
 	defer e.producer.resTbl.Unlock()
 	if !e.producer.resTbl.Done() {
-		if err = e.producer.produce(ctx, e); err != nil {
+		if err = e.producer.produce(ctx); err != nil {
 			return err
 		}
 	}
-	return e.producer.getChunk(ctx, e, req)
+	return e.producer.getChunk(e, req)
 }
 
 func setFirstErr(firstErr error, newErr error, msg string) error {
@@ -271,7 +271,7 @@ func (p *cteProducer) closeProducer() (firstErr error) {
 	return
 }
 
-func (p *cteProducer) getChunk(ctx context.Context, cteExec *CTEExec, req *chunk.Chunk) (err error) {
+func (p *cteProducer) getChunk(cteExec *CTEExec, req *chunk.Chunk) (err error) {
 	req.Reset()
 	if p.hasLimit {
 		return p.nextChunkLimit(cteExec, req)
@@ -334,15 +334,15 @@ func (p *cteProducer) nextChunkLimit(cteExec *CTEExec, req *chunk.Chunk) error {
 	return nil
 }
 
-func (p *cteProducer) produce(ctx context.Context, cteExec *CTEExec) (err error) {
+func (p *cteProducer) produce(ctx context.Context) (err error) {
 	if p.resTbl.Error() != nil {
 		return p.resTbl.Error()
 	}
-	resAction := setupCTEStorageTracker(p.resTbl, cteExec.ctx, p.memTracker, p.diskTracker)
-	iterInAction := setupCTEStorageTracker(p.iterInTbl, cteExec.ctx, p.memTracker, p.diskTracker)
+	resAction := setupCTEStorageTracker(p.resTbl, p.ctx, p.memTracker, p.diskTracker)
+	iterInAction := setupCTEStorageTracker(p.iterInTbl, p.ctx, p.memTracker, p.diskTracker)
 	var iterOutAction *chunk.SpillDiskAction
 	if p.iterOutTbl != nil {
-		iterOutAction = setupCTEStorageTracker(p.iterOutTbl, cteExec.ctx, p.memTracker, p.diskTracker)
+		iterOutAction = setupCTEStorageTracker(p.iterOutTbl, p.ctx, p.memTracker, p.diskTracker)
 	}
 
 	failpoint.Inject("testCTEStorageSpill", func(val failpoint.Value) {
@@ -425,12 +425,29 @@ func (p *cteProducer) computeRecursivePart(ctx context.Context) (err error) {
 		return
 	}
 
+	var iterNum uint64
 	for {
 		chk := tryNewCacheChunk(p.recursiveExec)
 		if err = Next(ctx, p.recursiveExec, chk); err != nil {
 			return
 		}
 		if chk.NumRows() == 0 {
+			if iterNum%1000 == 0 {
+				// To avoid too many logs.
+				p.logTbls(ctx, err, iterNum)
+			}
+			iterNum++
+			failpoint.Inject("assertIterTableSpillToDisk", func(maxIter failpoint.Value) {
+				if iterNum > 0 && iterNum < uint64(maxIter.(int)) && err == nil {
+					if p.iterInTbl.GetMemBytes() != 0 || p.iterInTbl.GetDiskBytes() == 0 ||
+						p.iterOutTbl.GetMemBytes() != 0 || p.iterOutTbl.GetDiskBytes() == 0 ||
+						p.resTbl.GetMemBytes() != 0 || p.resTbl.GetDiskBytes() == 0 {
+						p.logTbls(ctx, err, iterNum)
+						panic("assert row container spill disk failed")
+					}
+				}
+			})
+
 			if err = p.setupTblsForNewIteration(); err != nil {
 				return
 			}
@@ -489,6 +506,8 @@ func (p *cteProducer) setupTblsForNewIteration() (err error) {
 	if err = p.iterInTbl.Reopen(); err != nil {
 		return err
 	}
+	setupCTEStorageTracker(p.iterInTbl, p.ctx, p.memTracker, p.diskTracker)
+
 	if p.isDistinct {
 		// Already deduplicated by resTbl, adding directly is ok.
 		for _, chk := range chks {
@@ -503,7 +522,11 @@ func (p *cteProducer) setupTblsForNewIteration() (err error) {
 	}
 
 	// Clear data in iterOutTbl.
-	return p.iterOutTbl.Reopen()
+	if err = p.iterOutTbl.Reopen(); err != nil {
+		return err
+	}
+	setupCTEStorageTracker(p.iterOutTbl, p.ctx, p.memTracker, p.diskTracker)
+	return nil
 }
 
 func (p *cteProducer) reset() {
@@ -531,6 +554,8 @@ func (p *cteProducer) reopenTbls() (err error) {
 	if p.isDistinct {
 		p.hashTbl = newConcurrentMapHashTable()
 	}
+	// Normally we would need to set up the tracker again after calling Reopen(),
+	// but reopening resTbl means produce() will run again, and produce() sets up the trackers itself.
 	if err := p.resTbl.Reopen(); err != nil {
 		return err
 	}
@@ -735,3 +760,11 @@ func (p *cteProducer) checkAndUpdateCorColHashCode() bool {
 	}
 	return changed
 }
+
+func (p *cteProducer) logTbls(ctx context.Context, err error, iterNum uint64) {
+	logutil.Logger(ctx).Debug("cte iteration info",
+		zap.Any("iterInTbl mem usage", p.iterInTbl.GetMemBytes()), zap.Any("iterInTbl disk usage", p.iterInTbl.GetDiskBytes()),
+		zap.Any("iterOutTbl mem usage", p.iterOutTbl.GetMemBytes()), zap.Any("iterOutTbl disk usage", p.iterOutTbl.GetDiskBytes()),
+		zap.Any("resTbl mem usage", p.resTbl.GetMemBytes()), zap.Any("resTbl disk usage", p.resTbl.GetDiskBytes()),
+		zap.Any("resTbl rows", p.resTbl.NumRows()), zap.Any("iteration num", iterNum), zap.Error(err))
+}
diff --git a/executor/cte_test.go b/executor/cte_test.go
index 6b49a369be0be..421bb52cfe60c 100644
--- a/executor/cte_test.go
+++ b/executor/cte_test.go
@@ -572,3 +572,35 @@ func TestIssue46522(t *testing.T) {
 
 	tk.MustExec("commit;")
 }
+
+func TestCTEIterationMemTracker(t *testing.T) {
+	store := testkit.CreateMockStore(t)
+	tk := testkit.NewTestKit(t, store)
+
+	insertStr := "insert into t1 values(0)"
+	rowNum := 1000
+	vals := make([]int, rowNum)
+	vals[0] = 0
+	for i := 1; i < rowNum; i++ {
+		v := rand.Intn(100)
+		vals[i] = v
+		insertStr += fmt.Sprintf(", (%d)", v)
+	}
+	tk.MustExec("use test;")
+	tk.MustExec("drop table if exists t1;")
+	tk.MustExec("create table t1(c1 int);")
+	tk.MustExec(insertStr)
+
+	tk.MustExec("set @@cte_max_recursion_depth=1000000")
+	tk.MustExec("set global tidb_mem_oom_action = 'log';")
+	defer func() {
+		tk.MustExec("set global tidb_mem_oom_action = default;")
+	}()
+	tk.MustExec("set @@tidb_mem_quota_query=10;")
+	maxIter := 5000
+	require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/assertIterTableSpillToDisk", fmt.Sprintf("return(%d)", maxIter)))
+	defer func() {
+		require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/assertIterTableSpillToDisk"))
+	}()
+	tk.MustQuery(fmt.Sprintf("explain analyze with recursive cte1 as (select c1 from t1 union all select c1 + 1 c1 from cte1 where c1 < %d) select * from cte1", maxIter))
+}
diff --git a/util/cteutil/storage.go b/util/cteutil/storage.go
index dea6fd632e42b..89f445dc23681 100644
--- a/util/cteutil/storage.go
+++ b/util/cteutil/storage.go
@@ -89,6 +89,9 @@ type Storage interface {
 	GetMemTracker() *memory.Tracker
 	GetDiskTracker() *disk.Tracker
 	ActionSpill() *chunk.SpillDiskAction
+
+	GetMemBytes() int64
+	GetDiskBytes() int64
 }
 
 // StorageRC implements Storage interface using RowContainer.
@@ -269,3 +272,13 @@ func (s *StorageRC) ActionSpillForTest() *chunk.SpillDiskAction {
 func (s *StorageRC) valid() bool {
 	return s.refCnt > 0 && s.rc != nil
 }
+
+// GetMemBytes returns the memory bytes used by the row container.
+func (s *StorageRC) GetMemBytes() int64 {
+	return s.rc.GetMemTracker().BytesConsumed()
+}
+
+// GetDiskBytes returns the disk bytes used by the row container.
+func (s *StorageRC) GetDiskBytes() int64 {
+	return s.rc.GetDiskTracker().BytesConsumed()
+}
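
Why the patch must call setupCTEStorageTracker again after every Reopen(): reopening a storage resets its underlying row container, which replaces the storage's memory/disk trackers and detaches them from the session-level trackers that the OOM action watches. Rows added in later iterations would then neither count against tidb_mem_quota_query nor ever trigger a spill to disk, which is exactly what the assertIterTableSpillToDisk failpoint asserts against. The standalone program below sketches that failure mode; tracker and storage here are simplified stand-ins for memory.Tracker and cteutil.Storage, not TiDB's real APIs.

package main

import "fmt"

// tracker is a stand-in for memory.Tracker: consumption propagates to the parent.
type tracker struct {
	parent   *tracker
	consumed int64
}

func (t *tracker) consume(b int64) {
	t.consumed += b
	if t.parent != nil {
		t.parent.consume(b)
	}
}

// storage is a stand-in for a row-container-backed cteutil.Storage.
type storage struct{ memTracker *tracker }

// reopen drops the data and, like a row container reset, replaces the tracker.
func (s *storage) reopen() { s.memTracker = &tracker{} }

// attach re-parents the storage tracker, as setupCTEStorageTracker does.
func (s *storage) attach(parent *tracker) { s.memTracker.parent = parent }

func main() {
	session := &tracker{} // the per-query tracker the OOM action watches
	iterIn := &storage{memTracker: &tracker{}}
	iterIn.attach(session)

	iterIn.memTracker.consume(100) // iteration 0: tracked
	fmt.Println("after iteration 0:", session.consumed) // 100

	iterIn.reopen()                // next iteration: fresh, detached tracker
	iterIn.memTracker.consume(100) // invisible to the session without re-attach
	fmt.Println("reopened, no re-attach:", session.consumed) // still 100

	iterIn.attach(session)         // what the patch adds after every Reopen()
	iterIn.memTracker.consume(100)
	fmt.Println("after re-attach:", session.consumed) // 200
}

Re-attaching in produce() also recreates the SpillDiskAction for each storage, which is why setupCTEStorageTracker returns the actions consumed by the testCTEStorageSpill failpoint there.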