diff --git a/pkg/executor/cte.go b/pkg/executor/cte.go index 9928bc652b639..9b58c6687f242 100644 --- a/pkg/executor/cte.go +++ b/pkg/executor/cte.go @@ -112,15 +112,24 @@ func (e *CTEExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { // Close implements the Executor interface. func (e *CTEExec) Close() (err error) { - e.producer.resTbl.Lock() - if !e.producer.closed { - // closeProducer() only close seedExec and recursiveExec, will not touch resTbl. - // It means you can still read resTbl after call closeProducer(). - // You can even call all three functions(openProducer/produce/closeProducer) in CTEExec.Next(). - // Separating these three function calls is only to follow the abstraction of the volcano model. - err = e.producer.closeProducer() - } - e.producer.resTbl.Unlock() + func() { + e.producer.resTbl.Lock() + defer e.producer.resTbl.Unlock() + if !e.producer.closed { + failpoint.Inject("mock_cte_exec_panic_avoid_deadlock", func(v failpoint.Value) { + ok := v.(bool) + if ok { + // mock an oom panic, returning ErrMemoryExceedForQuery for error identification in recovery work. + panic(memory.PanicMemoryExceedWarnMsg) + } + }) + // closeProducer() only close seedExec and recursiveExec, will not touch resTbl. + // It means you can still read resTbl after call closeProducer(). + // You can even call all three functions(openProducer/produce/closeProducer) in CTEExec.Next(). + // Separating these three function calls is only to follow the abstraction of the volcano model. + err = e.producer.closeProducer() + } + }() if err != nil { return err } diff --git a/pkg/executor/cte_test.go b/pkg/executor/cte_test.go index 4133e4401ed97..d81f88c06647c 100644 --- a/pkg/executor/cte_test.go +++ b/pkg/executor/cte_test.go @@ -348,6 +348,37 @@ func TestCTEWithLimit(t *testing.T) { rows.Check(testkit.Rows("3", "4", "3", "4")) } +func TestCTEIssue49096(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + + tk.MustExec("use test;") + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/executor/mock_cte_exec_panic_avoid_deadlock", "return(true)")) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/executor/mock_cte_exec_panic_avoid_deadlock")) + }() + insertStr := "insert into t1 values(0)" + rowNum := 10 + vals := make([]int, rowNum) + vals[0] = 0 + for i := 1; i < rowNum; i++ { + v := rand.Intn(100) + vals[i] = v + insertStr += fmt.Sprintf(", (%d)", v) + } + tk.MustExec("drop table if exists t1, t2;") + tk.MustExec("create table t1(c1 int);") + tk.MustExec("create table t2(c1 int);") + tk.MustExec(insertStr) + // should be insert statement, otherwise it couldn't step int resetCTEStorageMap in handleNoDelay func. + sql := "insert into t2 with cte1 as ( " + + "select c1 from t1) " + + "select c1 from cte1 natural join (select * from cte1 where c1 > 0) cte2 order by c1;" + err := tk.ExecToErr(sql) + require.NotNil(t, err) + require.Equal(t, "Your query has been cancelled due to exceeding the allowed memory limit", err.Error()) +} + func TestSpillToDisk(t *testing.T) { store := testkit.CreateMockStore(t) diff --git a/pkg/util/cteutil/BUILD.bazel b/pkg/util/cteutil/BUILD.bazel index eaeb81968a4ba..cd8010f53d966 100644 --- a/pkg/util/cteutil/BUILD.bazel +++ b/pkg/util/cteutil/BUILD.bazel @@ -10,6 +10,7 @@ go_library( "//pkg/util/chunk", "//pkg/util/disk", "//pkg/util/memory", + "//pkg/util/syncutil", "@com_github_pingcap_errors//:errors", ], ) diff --git a/pkg/util/cteutil/storage.go b/pkg/util/cteutil/storage.go index a3f334794b2fc..dd172e02b13e7 100644 --- a/pkg/util/cteutil/storage.go +++ b/pkg/util/cteutil/storage.go @@ -15,13 +15,12 @@ package cteutil import ( - "sync" - "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/disk" "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/syncutil" ) var _ Storage = &StorageRC{} @@ -99,7 +98,7 @@ type StorageRC struct { refCnt int chkSize int iter int - mu sync.Mutex + mu syncutil.Mutex done bool }