diff --git a/pkg/executor/cte.go b/pkg/executor/cte.go index 1b358c0e18064..346c13bbe6aa0 100644 --- a/pkg/executor/cte.go +++ b/pkg/executor/cte.go @@ -89,16 +89,16 @@ func (e *CTEExec) Open(ctx context.Context) (err error) { defer e.producer.resTbl.Unlock() if e.producer.checkAndUpdateCorColHashCode() { - e.producer.reset() - if err = e.producer.reopenTbls(); err != nil { + err = e.producer.reset() + if err != nil { return err } } if e.producer.openErr != nil { return e.producer.openErr } - if !e.producer.opened { - if err = e.producer.openProducer(ctx, e); err != nil { + if !e.producer.hasCTEResult() && !e.producer.executorOpened { + if err = e.producer.openProducerExecutor(ctx, e); err != nil { return err } } @@ -109,8 +109,14 @@ func (e *CTEExec) Open(ctx context.Context) (err error) { func (e *CTEExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { e.producer.resTbl.Lock() defer e.producer.resTbl.Unlock() - if !e.producer.resTbl.Done() { - if err = e.producer.produce(ctx); err != nil { + if !e.producer.hasCTEResult() { + // in case that another CTEExec call close without generate CTE result. + if !e.producer.executorOpened { + if err = e.producer.openProducerExecutor(ctx, e); err != nil { + return err + } + } + if err = e.producer.genCTEResult(ctx); err != nil { return err } } @@ -132,7 +138,7 @@ func (e *CTEExec) Close() (firstErr error) { func() { e.producer.resTbl.Lock() defer e.producer.resTbl.Unlock() - if !e.producer.closed { + if e.producer.executorOpened { failpoint.Inject("mock_cte_exec_panic_avoid_deadlock", func(v failpoint.Value) { ok := v.(bool) if ok { @@ -140,12 +146,17 @@ func (e *CTEExec) Close() (firstErr error) { panic(exeerrors.ErrMemoryExceedForQuery) } }) - // closeProducer() only close seedExec and recursiveExec, will not touch resTbl. - // It means you can still read resTbl after call closeProducer(). - // You can even call all three functions(openProducer/produce/closeProducer) in CTEExec.Next(). + // closeProducerExecutor() only close seedExec and recursiveExec, will not touch resTbl. + // It means you can still read resTbl after call closeProducerExecutor(). + // You can even call all three functions(openProducerExecutor/genCTEResult/closeProducerExecutor) in CTEExec.Next(). // Separating these three function calls is only to follow the abstraction of the volcano model. - err := e.producer.closeProducer() + err := e.producer.closeProducerExecutor() firstErr = setFirstErr(firstErr, err, "close cte producer error") + if !e.producer.hasCTEResult() { + // CTE result is not generated, in this case, we reset it + err = e.producer.reset() + firstErr = setFirstErr(firstErr, err, "close cte producer error") + } } }() err := e.BaseExecutor.Close() @@ -160,10 +171,10 @@ func (e *CTEExec) reset() { } type cteProducer struct { - // opened should be false when not open or open fail(a.k.a. openErr != nil) - opened bool - produced bool - closed bool + // executorOpened is used to indicate whether the executor(seedExec/recursiveExec) is opened. + // when executorOpened is true, the executor is opened, otherwise it means the executor is + // not opened or is already closed. + executorOpened bool // cteProducer is shared by multiple operators, so if the first operator tries to open // and got error, the second should return open error directly instead of open again. @@ -202,14 +213,10 @@ type cteProducer struct { corColHashCodes [][]byte } -func (p *cteProducer) openProducer(ctx context.Context, cteExec *CTEExec) (err error) { +func (p *cteProducer) openProducerExecutor(ctx context.Context, cteExec *CTEExec) (err error) { defer func() { p.openErr = err - if err == nil { - p.opened = true - } else { - p.opened = false - } + p.executorOpened = true }() if p.seedExec == nil { return errors.New("seedExec for CTEExec is nil") @@ -252,7 +259,7 @@ func (p *cteProducer) openProducer(ctx context.Context, cteExec *CTEExec) (err e return nil } -func (p *cteProducer) closeProducer() (firstErr error) { +func (p *cteProducer) closeProducerExecutor() (firstErr error) { err := exec.Close(p.seedExec) firstErr = setFirstErr(firstErr, err, "close seedExec err") @@ -271,7 +278,7 @@ func (p *cteProducer) closeProducer() (firstErr error) { // because ExplainExec still needs tracker to get mem usage info. p.memTracker = nil p.diskTracker = nil - p.closed = true + p.executorOpened = false return } @@ -338,7 +345,13 @@ func (p *cteProducer) nextChunkLimit(cteExec *CTEExec, req *chunk.Chunk) error { return nil } -func (p *cteProducer) produce(ctx context.Context) (err error) { +func (p *cteProducer) hasCTEResult() bool { + return p.resTbl.Done() +} + +// genCTEResult generates the result of CTE, and stores the result in resTbl. +// This is a synchronous function, which means it will block until the result is generated. +func (p *cteProducer) genCTEResult(ctx context.Context) (err error) { if p.resTbl.Error() != nil { return p.resTbl.Error() } @@ -531,14 +544,18 @@ func (p *cteProducer) setupTblsForNewIteration() (err error) { return nil } -func (p *cteProducer) reset() { +func (p *cteProducer) reset() error { p.curIter = 0 p.hashTbl = nil - - p.opened = false + p.executorOpened = false p.openErr = nil - p.produced = false - p.closed = false + + // Normally we need to setup tracker after calling Reopen(), + // But reopen resTbl means we need to call genCTEResult() again, it will setup tracker. + if err := p.resTbl.Reopen(); err != nil { + return err + } + return p.iterInTbl.Reopen() } func (p *cteProducer) resetTracker() { @@ -552,18 +569,6 @@ func (p *cteProducer) resetTracker() { } } -func (p *cteProducer) reopenTbls() (err error) { - if p.isDistinct { - p.hashTbl = newConcurrentMapHashTable() - } - // Normally we need to setup tracker after calling Reopen(), - // But reopen resTbl means we need to call produce() again, it will setup tracker. - if err := p.resTbl.Reopen(); err != nil { - return err - } - return p.iterInTbl.Reopen() -} - // Check if tbl meets the requirement of limit. func (p *cteProducer) limitDone(tbl cteutil.Storage) bool { return p.hasLimit && uint64(tbl.NumRows()) >= p.limitEnd diff --git a/pkg/executor/test/issuetest/BUILD.bazel b/pkg/executor/test/issuetest/BUILD.bazel index 584acbb753be7..d825d3bd34733 100644 --- a/pkg/executor/test/issuetest/BUILD.bazel +++ b/pkg/executor/test/issuetest/BUILD.bazel @@ -8,7 +8,7 @@ go_test( "main_test.go", ], flaky = True, - shard_count = 21, + shard_count = 22, deps = [ "//pkg/autoid_service", "//pkg/config", diff --git a/pkg/executor/test/issuetest/executor_issue_test.go b/pkg/executor/test/issuetest/executor_issue_test.go index b9e5e04a0c926..5855c1ee17e4a 100644 --- a/pkg/executor/test/issuetest/executor_issue_test.go +++ b/pkg/executor/test/issuetest/executor_issue_test.go @@ -693,3 +693,22 @@ func TestIssue53867(t *testing.T) { // Need no panic tk.MustQuery("select /*+ STREAM_AGG() */ (ref_4.c_k3kss19 / ref_4.c_k3kss19) as c2 from t_bhze93f as ref_4 where (EXISTS (select ref_5.c_wp7o_0sstj as c0 from t_bhze93f as ref_5 where (207007502 < (select distinct ref_6.c_weg as c0 from t_xf1at0 as ref_6 union all (select ref_7.c_xb as c0 from t_b0t as ref_7 where (-16090 != ref_4.c_x393ej_)) limit 1)) limit 1));") } + +func TestIssue55881(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists aaa;") + tk.MustExec("drop table if exists bbb;") + tk.MustExec("create table aaa(id int, value int);") + tk.MustExec("create table bbb(id int, value int);") + tk.MustExec("insert into aaa values(1,2),(2,3)") + tk.MustExec("insert into bbb values(1,2),(2,3),(3,4)") + // set tidb_executor_concurrency to 1 to let the issue happens with high probability. + tk.MustExec("set tidb_executor_concurrency=1;") + // this is a random issue, so run it 100 times to increase the probability of the issue. + for i := 0; i < 100; i++ { + tk.MustQuery("with cte as (select * from aaa) select id, (select id from (select * from aaa where aaa.id != bbb.id union all select * from cte union all select * from cte) d limit 1)," + + "(select max(value) from (select * from cte union all select * from cte union all select * from aaa where aaa.id > bbb.id)) from bbb;") + } +}