diff --git a/executor/adapter.go b/executor/adapter.go index d8610a50d2782..fb0f76ffc872d 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -179,8 +179,11 @@ func (a *recordSet) NewChunk(alloc chunk.Allocator) *chunk.Chunk { func (a *recordSet) Close() error { err := a.executor.Close() - a.stmt.CloseRecordSet(a.txnStartTS, a.lastErr) - return err + err1 := a.stmt.CloseRecordSet(a.txnStartTS, a.lastErr) + if err != nil { + return err + } + return err1 } // OnFetchReturned implements commandLifeCycle#OnFetchReturned @@ -496,6 +499,13 @@ func (a *ExecStmt) handleNoDelay(ctx context.Context, e Executor, isPessimistic if sc.DiskTracker != nil { sc.DiskTracker.Detach() } + if handled { + cteErr := resetCTEStorageMap(a.Ctx) + if err == nil { + // Only overwrite err when it's nil. + err = cteErr + } + } } }() @@ -589,8 +599,7 @@ func (c *chunkRowRecordSet) NewChunk(alloc chunk.Allocator) *chunk.Chunk { } func (c *chunkRowRecordSet) Close() error { - c.execStmt.CloseRecordSet(c.execStmt.Ctx.GetSessionVars().TxnCtx.StartTS, nil) - return nil + return c.execStmt.CloseRecordSet(c.execStmt.Ctx.GetSessionVars().TxnCtx.StartTS, nil) } func (a *ExecStmt) handlePessimisticSelectForUpdate(ctx context.Context, e Executor) (sqlexec.RecordSet, error) { @@ -990,7 +999,11 @@ func (a *ExecStmt) FinishExecuteStmt(txnTS uint64, err error, hasMoreResults boo } // CloseRecordSet will finish the execution of current statement and do some record work -func (a *ExecStmt) CloseRecordSet(txnStartTS uint64, lastErr error) { +func (a *ExecStmt) CloseRecordSet(txnStartTS uint64, lastErr error) error { + cteErr := resetCTEStorageMap(a.Ctx) + if cteErr != nil { + logutil.BgLogger().Error("got error when reset cte storage, should check if the spill disk file deleted or not", zap.Error(cteErr)) + } a.FinishExecuteStmt(txnStartTS, lastErr, false) a.logAudit() // Detach the Memory and disk tracker for the previous stmtCtx from GlobalMemoryUsageTracker and GlobalDiskUsageTracker @@ -1002,6 +1015,39 @@ func (a *ExecStmt) CloseRecordSet(txnStartTS uint64, lastErr error) { stmtCtx.MemTracker.Detach() } } + return cteErr +} + +// Clean CTE storage shared by different CTEFullScan executor within a SQL stmt. +// Will return err in two situations: +// 1. Got err when remove disk spill file. +// 2. Some logical error like ref count of CTEStorage is less than 0. +func resetCTEStorageMap(se sessionctx.Context) error { + tmp := se.GetSessionVars().StmtCtx.CTEStorageMap + if tmp == nil { + // Close() is already called, so no need to reset. Such as TraceExec. + return nil + } + storageMap, ok := tmp.(map[int]*CTEStorages) + if !ok { + return errors.New("type assertion for CTEStorageMap failed") + } + for _, v := range storageMap { + v.ResTbl.Lock() + err1 := v.ResTbl.DerefAndClose() + // Make sure we do not hold the lock for longer than necessary. + v.ResTbl.Unlock() + // No need to lock IterInTbl. + err2 := v.IterInTbl.DerefAndClose() + if err1 != nil { + return err1 + } + if err2 != nil { + return err2 + } + } + se.GetSessionVars().StmtCtx.CTEStorageMap = nil + return nil } // LogSlowQuery is used to print the slow query in the log files. diff --git a/executor/cte_test.go b/executor/cte_test.go index d7dd2c6f9ae3a..4efd07e912566 100644 --- a/executor/cte_test.go +++ b/executor/cte_test.go @@ -482,3 +482,19 @@ func TestCTEPanic(t *testing.T) { require.Contains(t, err.Error(), fp) require.NoError(t, failpoint.Disable(fpPathPrefix+fp)) } + +func TestCTEDelSpillFile(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists t1, t2;") + tk.MustExec("create table t1(c1 int, c2 int);") + tk.MustExec("create table t2(c1 int);") + tk.MustExec("set @@cte_max_recursion_depth = 1000000;") + tk.MustExec("set global tidb_mem_oom_action = 'log';") + tk.MustExec("set @@tidb_mem_quota_query = 100;") + tk.MustExec("insert into t2 values(1);") + tk.MustExec("insert into t1 (c1, c2) with recursive cte1 as (select c1 from t2 union select cte1.c1 + 1 from cte1 where cte1.c1 < 100000) select cte1.c1, cte1.c1+1 from cte1;") + require.Nil(t, tk.Session().GetSessionVars().StmtCtx.CTEStorageMap) +} diff --git a/executor/executor_test.go b/executor/executor_test.go index 97e398f62c249..19e64e53f83e3 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -3514,6 +3514,7 @@ func TestUnreasonablyClose(t *testing.T) { require.NotNil(t, p) // This for loop level traverses the plan tree to get which operators are covered. + var hasCTE bool for child := []plannercore.PhysicalPlan{p.(plannercore.PhysicalPlan)}; len(child) != 0; { newChild := make([]plannercore.PhysicalPlan, 0, len(child)) for _, ch := range child { @@ -3530,6 +3531,7 @@ func TestUnreasonablyClose(t *testing.T) { case *plannercore.PhysicalCTE: newChild = append(newChild, x.RecurPlan) newChild = append(newChild, x.SeedPlan) + hasCTE = true continue case *plannercore.PhysicalShuffle: newChild = append(newChild, x.DataSources...) @@ -3541,6 +3543,12 @@ func TestUnreasonablyClose(t *testing.T) { child = newChild } + if hasCTE { + // Normally CTEStorages will be setup in ResetContextOfStmt. + // But the following case call e.Close() directly, instead of calling session.ExecStmt(), which calls ResetContextOfStmt. + // So need to setup CTEStorages manually. + tk.Session().GetSessionVars().StmtCtx.CTEStorageMap = map[int]*executor.CTEStorages{} + } e := executorBuilder.Build(p) func() { diff --git a/executor/stale_txn_test.go b/executor/stale_txn_test.go index 674c48e1d9d30..26c827fbdc6c7 100644 --- a/executor/stale_txn_test.go +++ b/executor/stale_txn_test.go @@ -945,6 +945,7 @@ func TestSetTransactionInfoSchema(t *testing.T) { defer tk.MustExec("drop table if exists t") tk.MustExec("create table t (id int primary key);") + time.Sleep(100 * time.Millisecond) schemaVer1 := tk.Session().GetInfoSchema().SchemaMetaVersion() time1 := time.Now() time.Sleep(100 * time.Millisecond) diff --git a/session/session.go b/session/session.go index f08e39278929c..4fd489fcb1060 100644 --- a/session/session.go +++ b/session/session.go @@ -2170,40 +2170,9 @@ func (rs *execStmtResult) Close() error { if err := rs.RecordSet.Close(); err != nil { return finishStmt(context.Background(), se, err, rs.sql) } - if err := resetCTEStorageMap(se); err != nil { - return finishStmt(context.Background(), se, err, rs.sql) - } return finishStmt(context.Background(), se, nil, rs.sql) } -func resetCTEStorageMap(se *session) error { - tmp := se.GetSessionVars().StmtCtx.CTEStorageMap - if tmp == nil { - // Close() is already called, so no need to reset. Such as TraceExec. - return nil - } - storageMap, ok := tmp.(map[int]*executor.CTEStorages) - if !ok { - return errors.New("type assertion for CTEStorageMap failed") - } - for _, v := range storageMap { - v.ResTbl.Lock() - err1 := v.ResTbl.DerefAndClose() - // Make sure we do not hold the lock for longer than necessary. - v.ResTbl.Unlock() - // No need to lock IterInTbl. - err2 := v.IterInTbl.DerefAndClose() - if err1 != nil { - return err1 - } - if err2 != nil { - return err2 - } - } - se.GetSessionVars().StmtCtx.CTEStorageMap = nil - return nil -} - // rollbackOnError makes sure the next statement starts a new transaction with the latest InfoSchema. func (s *session) rollbackOnError(ctx context.Context) { if !s.sessionVars.InTxn() {