diff --git a/cmd/explaintest/r/cte.result b/cmd/explaintest/r/cte.result index 4f3d979c45001..077e62cf33cd4 100644 --- a/cmd/explaintest/r/cte.result +++ b/cmd/explaintest/r/cte.result @@ -607,3 +607,118 @@ c1 c1 c1 1 1 1 2 2 2 3 3 3 +// Test CTE as inner side of Apply +drop table if exists t1, t2; +create table t1(c1 int, c2 int); +insert into t1 values(2, 1); +insert into t1 values(2, 2); +create table t2(c1 int, c2 int); +insert into t2 values(1, 1); +insert into t2 values(3, 2); +explain select * from t1 where c1 > all(with cte1 as (select c1 from t2 where t2.c2 = t1.c2) select c1 from cte1); +id estRows task access object operator info +Projection_17 10000.00 root test.t1.c1, test.t1.c2 +└─Apply_19 10000.00 root CARTESIAN inner join, other cond:or(and(gt(test.t1.c1, Column#8), if(ne(Column#9, 0), NULL, 1)), or(eq(Column#10, 0), if(isnull(test.t1.c1), NULL, 0))) + ├─TableReader_21(Build) 10000.00 root data:TableFullScan_20 + │ └─TableFullScan_20 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo + └─HashAgg_22(Probe) 1.00 root funcs:max(Column#13)->Column#8, funcs:sum(Column#14)->Column#9, funcs:count(1)->Column#10 + └─Projection_26 10.00 root test.t2.c1, cast(isnull(test.t2.c1), decimal(20,0) BINARY)->Column#14 + └─CTEFullScan_24 10.00 root CTE:cte1 data:CTE_0 +CTE_0 10.00 root Non-Recursive CTE +└─Projection_12(Seed Part) 10.00 root test.t2.c1 + └─TableReader_15 10.00 root data:Selection_14 + └─Selection_14 10.00 cop[tikv] eq(test.t2.c2, test.t1.c2) + └─TableFullScan_13 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo +select * from t1 where c1 > all(with cte1 as (select c1 from t2 where t2.c2 = t1.c2) select c1 from cte1); +c1 c2 +2 1 +// Test semi apply. 
+insert into t1 values(2, 3); +explain select * from t1 where exists(with cte1 as (select c1 from t2 where t2.c2 = t1.c2) select c1 from cte1); +id estRows task access object operator info +Apply_16 10000.00 root CARTESIAN semi join +├─TableReader_18(Build) 10000.00 root data:TableFullScan_17 +│ └─TableFullScan_17 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo +└─CTEFullScan_19(Probe) 10.00 root CTE:cte1 data:CTE_0 +CTE_0 10.00 root Non-Recursive CTE +└─Projection_10(Seed Part) 10.00 root test.t2.c1 + └─TableReader_13 10.00 root data:Selection_12 + └─Selection_12 10.00 cop[tikv] eq(test.t2.c2, test.t1.c2) + └─TableFullScan_11 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo +select * from t1 where exists(with cte1 as (select c1 from t2 where t2.c2 = t1.c2) select c1 from cte1); +c1 c2 +2 1 +2 2 +// Same as above, but test recursive cte. +explain select * from t1 where c1 > all(with recursive cte1 as (select c1 from t2 where t2.c2 = t1.c2 union all select c1+1 as c1 from cte1 limit 1) select c1 from cte1); +id estRows task access object operator info +Projection_26 10000.00 root test.t1.c1, test.t1.c2 +└─Apply_28 10000.00 root CARTESIAN inner join, other cond:or(and(gt(test.t1.c1, Column#14), if(ne(Column#15, 0), NULL, 1)), or(eq(Column#16, 0), if(isnull(test.t1.c1), NULL, 0))) + ├─TableReader_30(Build) 10000.00 root data:TableFullScan_29 + │ └─TableFullScan_29 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo + └─HashAgg_31(Probe) 1.00 root funcs:max(Column#19)->Column#14, funcs:sum(Column#20)->Column#15, funcs:count(1)->Column#16 + └─Projection_35 20.00 root test.t2.c1, cast(isnull(test.t2.c1), decimal(20,0) BINARY)->Column#20 + └─CTEFullScan_33 20.00 root CTE:cte1 data:CTE_0 +CTE_0 20.00 root Recursive CTE, limit(offset:0, count:1) +├─Projection_19(Seed Part) 10.00 root test.t2.c1 +│ └─TableReader_22 10.00 root data:Selection_21 +│ └─Selection_21 10.00 cop[tikv] eq(test.t2.c2, test.t1.c2) +│ └─TableFullScan_20 10000.00 cop[tikv] 
table:t2 keep order:false, stats:pseudo +└─Projection_23(Recursive Part) 10.00 root cast(plus(test.t2.c1, 1), int(11))->test.t2.c1 + └─CTETable_24 10.00 root Scan on CTE_0 +select * from t1 where c1 > all(with recursive cte1 as (select c1 from t2 where t2.c2 = t1.c2 union all select c1+1 as c1 from cte1 limit 1) select c1 from cte1); +c1 c2 +2 1 +2 3 +explain select * from t1 where exists(with recursive cte1 as (select c1 from t2 where t2.c2 = t1.c2 union all select c1+1 as c1 from cte1 limit 10) select c1 from cte1); +id estRows task access object operator info +Apply_25 10000.00 root CARTESIAN semi join +├─TableReader_27(Build) 10000.00 root data:TableFullScan_26 +│ └─TableFullScan_26 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo +└─CTEFullScan_28(Probe) 20.00 root CTE:cte1 data:CTE_0 +CTE_0 20.00 root Recursive CTE, limit(offset:0, count:10) +├─Projection_17(Seed Part) 10.00 root test.t2.c1 +│ └─TableReader_20 10.00 root data:Selection_19 +│ └─Selection_19 10.00 cop[tikv] eq(test.t2.c2, test.t1.c2) +│ └─TableFullScan_18 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo +└─Projection_21(Recursive Part) 10.00 root cast(plus(test.t2.c1, 1), int(11))->test.t2.c1 + └─CTETable_22 10.00 root Scan on CTE_0 +select * from t1 where exists(with recursive cte1 as (select c1 from t2 where t2.c2 = t1.c2 union all select c1+1 as c1 from cte1 limit 10) select c1 from cte1); +c1 c2 +2 1 +2 2 +// Test correlated col is in recursive part. 
+explain select * from t1 where c1 > all(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1); +id estRows task access object operator info +Projection_24 10000.00 root test.t1.c1, test.t1.c2 +└─Apply_26 10000.00 root CARTESIAN inner join, other cond:or(and(gt(test.t1.c1, Column#18), if(ne(Column#19, 0), NULL, 1)), or(eq(Column#20, 0), if(isnull(test.t1.c1), NULL, 0))) + ├─TableReader_28(Build) 10000.00 root data:TableFullScan_27 + │ └─TableFullScan_27 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo + └─HashAgg_29(Probe) 1.00 root funcs:max(Column#23)->Column#18, funcs:sum(Column#24)->Column#19, funcs:count(1)->Column#20 + └─Projection_33 18000.00 root test.t2.c1, cast(isnull(test.t2.c1), decimal(20,0) BINARY)->Column#24 + └─CTEFullScan_31 18000.00 root CTE:cte1 data:CTE_0 +CTE_0 18000.00 root Recursive CTE +├─TableReader_19(Seed Part) 10000.00 root data:TableFullScan_18 +│ └─TableFullScan_18 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo +└─Projection_20(Recursive Part) 8000.00 root cast(plus(test.t2.c1, 1), int(11))->test.t2.c1, cast(plus(test.t2.c2, 1), int(11))->test.t2.c2 + └─Selection_21 8000.00 root eq(test.t2.c2, test.t1.c2) + └─CTETable_22 10000.00 root Scan on CTE_0 +select * from t1 where c1 > all(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1); +c1 c2 +explain select * from t1 where exists(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1); +id estRows task access object operator info +Apply_23 10000.00 root CARTESIAN semi join +├─TableReader_25(Build) 10000.00 root data:TableFullScan_24 +│ └─TableFullScan_24 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo +└─CTEFullScan_26(Probe) 18000.00 root CTE:cte1 data:CTE_0 +CTE_0 18000.00 root Recursive CTE 
+├─TableReader_17(Seed Part) 10000.00 root data:TableFullScan_16 +│ └─TableFullScan_16 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo +└─Projection_18(Recursive Part) 8000.00 root cast(plus(test.t2.c1, 1), int(11))->test.t2.c1, cast(plus(test.t2.c2, 1), int(11))->test.t2.c2 + └─Selection_19 8000.00 root eq(test.t2.c2, test.t1.c2) + └─CTETable_20 10000.00 root Scan on CTE_0 +select * from t1 where exists(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1); +c1 c2 +2 1 +2 2 +2 3 diff --git a/cmd/explaintest/t/cte.test b/cmd/explaintest/t/cte.test index b5fda97071cc8..df9ee6c3b8d06 100644 --- a/cmd/explaintest/t/cte.test +++ b/cmd/explaintest/t/cte.test @@ -226,3 +226,33 @@ create table tpk1(c1 int primary key); insert into tpk1 values(1), (2), (3); explain with cte1 as (select c1 from tpk) select /*+ merge_join(dt1, dt2) */ * from tpk1 dt1 inner join cte1 dt2 inner join cte1 dt3 on dt1.c1 = dt2.c1 and dt2.c1 = dt3.c1; with cte1 as (select c1 from tpk) select /*+ merge_join(dt1, dt2) */ * from tpk1 dt1 inner join cte1 dt2 inner join cte1 dt3 on dt1.c1 = dt2.c1 and dt2.c1 = dt3.c1; +#case 34 +--echo // Test CTE as inner side of Apply +drop table if exists t1, t2; +create table t1(c1 int, c2 int); +insert into t1 values(2, 1); +insert into t1 values(2, 2); +create table t2(c1 int, c2 int); +insert into t2 values(1, 1); +insert into t2 values(3, 2); +explain select * from t1 where c1 > all(with cte1 as (select c1 from t2 where t2.c2 = t1.c2) select c1 from cte1); +select * from t1 where c1 > all(with cte1 as (select c1 from t2 where t2.c2 = t1.c2) select c1 from cte1); + +--echo // Test semi apply. 
+insert into t1 values(2, 3); +explain select * from t1 where exists(with cte1 as (select c1 from t2 where t2.c2 = t1.c2) select c1 from cte1); +select * from t1 where exists(with cte1 as (select c1 from t2 where t2.c2 = t1.c2) select c1 from cte1); + +--echo // Same as above, but test recursive cte. +explain select * from t1 where c1 > all(with recursive cte1 as (select c1 from t2 where t2.c2 = t1.c2 union all select c1+1 as c1 from cte1 limit 1) select c1 from cte1); +select * from t1 where c1 > all(with recursive cte1 as (select c1 from t2 where t2.c2 = t1.c2 union all select c1+1 as c1 from cte1 limit 1) select c1 from cte1); + +explain select * from t1 where exists(with recursive cte1 as (select c1 from t2 where t2.c2 = t1.c2 union all select c1+1 as c1 from cte1 limit 10) select c1 from cte1); +select * from t1 where exists(with recursive cte1 as (select c1 from t2 where t2.c2 = t1.c2 union all select c1+1 as c1 from cte1 limit 10) select c1 from cte1); + +--echo // Test correlated col is in recursive part. +explain select * from t1 where c1 > all(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1); +select * from t1 where c1 > all(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1); + +explain select * from t1 where exists(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1); +select * from t1 where exists(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1); diff --git a/executor/builder.go b/executor/builder.go index db5c9dfb9302d..2343e18b73f7e 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -4540,19 +4540,9 @@ func (b *executorBuilder) buildCTE(v *plannercore.PhysicalCTE) Executor { // 2. 
Build tables to store intermediate results. chkSize := b.ctx.GetSessionVars().MaxChunkSize tps := seedExec.base().retFieldTypes + // iterOutTbl will be constructed in CTEExec.Open(). var resTbl cteutil.Storage var iterInTbl cteutil.Storage - var iterOutTbl cteutil.Storage - - if v.RecurPlan != nil { - // For non-recursive CTE, the result will be put into resTbl directly. - // So no need to build iterOutTbl. - iterOutTbl := cteutil.NewStorageRowContainer(tps, chkSize) - if err := iterOutTbl.OpenAndRef(); err != nil { - b.err = err - return nil - } - } storageMap, ok := b.ctx.GetSessionVars().StmtCtx.CTEStorageMap.(map[int]*CTEStorages) if !ok { @@ -4601,13 +4591,13 @@ func (b *executorBuilder) buildCTE(v *plannercore.PhysicalCTE) Executor { recursiveExec: recursiveExec, resTbl: resTbl, iterInTbl: iterInTbl, - iterOutTbl: iterOutTbl, chkIdx: 0, isDistinct: v.CTE.IsDistinct, sel: sel, hasLimit: v.CTE.HasLimit, limitBeg: v.CTE.LimitBeg, limitEnd: v.CTE.LimitEnd, + isInApply: v.CTE.IsInApply, } } diff --git a/executor/cte.go b/executor/cte.go index 3ce82c3920559..b6824b1a6e3ee 100644 --- a/executor/cte.go +++ b/executor/cte.go @@ -89,6 +89,11 @@ type CTEExec struct { memTracker *memory.Tracker diskTracker *disk.Tracker + + // isInApply indicates whether CTE is in the inner side of Apply + // and whether resTbl/iterInTbl should be reset for each outer row of Apply. + // This is needed because they are otherwise reset only when the SQL is finished, not when CTEExec.Close() is called. + isInApply bool } // Open implements the Executor interface. @@ -114,6 +119,9 @@ func (e *CTEExec) Open(ctx context.Context) (err error) { if err = e.recursiveExec.Open(ctx); err != nil { return err } + // For non-recursive CTE, the result will be put into resTbl directly. + // So no need to build iterOutTbl. + // Construct iterOutTbl in Open() instead of buildCTE(), because it is destroyed in Close(). 
recursiveTypes := e.recursiveExec.base().retFieldTypes e.iterOutTbl = cteutil.NewStorageRowContainer(recursiveTypes, e.maxChunkSize) if err = e.iterOutTbl.OpenAndRef(); err != nil { @@ -208,6 +216,11 @@ func (e *CTEExec) Close() (err error) { return err } } + if e.isInApply { + if err = e.reopenTbls(); err != nil { + return err + } + } return e.baseExecutor.Close() } @@ -396,7 +409,9 @@ func (e *CTEExec) reset() { } func (e *CTEExec) reopenTbls() (err error) { - e.hashTbl = newConcurrentMapHashTable() + if e.isDistinct { + e.hashTbl = newConcurrentMapHashTable() + } if err := e.resTbl.Reopen(); err != nil { return err } diff --git a/executor/explainfor_test.go b/executor/explainfor_test.go index e3dbfd42d3a52..c2a68e2ce3b56 100644 --- a/executor/explainfor_test.go +++ b/executor/explainfor_test.go @@ -1388,7 +1388,7 @@ func (s *testPrepareSerialSuite) TestCTE4PlanCache(c *C) { tk.MustExec("set @a=1, @b=2, @c=3, @d=4, @e=5, @f=0;") tk.MustQuery("execute stmt using @f, @a, @f").Check(testkit.Rows("1")) - tk.MustQuery("execute stmt using @a, @b, @a").Check(testkit.Rows("1")) + tk.MustQuery("execute stmt using @a, @b, @a").Sort().Check(testkit.Rows("1", "2")) tk.MustQuery("select @@last_plan_from_cache;").Check(testkit.Rows("0")) tk.MustExec("prepare stmt from 'with recursive c(p) as (select ?), cte(a, b) as (select 1, 1 union select a+?, 1 from cte, c where a < ?) select * from cte order by 1, 2;';") diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index dd1ab40650ac6..f712cb3043b91 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -4401,6 +4401,7 @@ func (b *PlanBuilder) buildProjUponView(ctx context.Context, dbName model.CIStr, // every row from outerPlan and the whole innerPlan. 
func (b *PlanBuilder) buildApplyWithJoinType(outerPlan, innerPlan LogicalPlan, tp JoinType) LogicalPlan { b.optFlag = b.optFlag | flagPredicatePushDown | flagBuildKeyInfo | flagDecorrelate + setIsInApplyForCTE(innerPlan) ap := LogicalApply{LogicalJoin: LogicalJoin{JoinType: tp}}.Init(b.ctx, b.getSelectOffset()) ap.SetChildren(outerPlan, innerPlan) ap.names = make([]*types.FieldName, outerPlan.Schema().Len()+innerPlan.Schema().Len()) @@ -4426,12 +4427,31 @@ func (b *PlanBuilder) buildSemiApply(outerPlan, innerPlan LogicalPlan, condition return nil, err } + setIsInApplyForCTE(innerPlan) ap := &LogicalApply{LogicalJoin: *join} ap.tp = plancodec.TypeApply ap.self = ap return ap, nil } +// setIsInApplyForCTE indicates CTE is in the inner side of Apply, +// so the storage of cte needs to be reset for each outer row. +// It's better to handle this in CTEExec.Close(), but cte storage is closed when SQL is finished. +func setIsInApplyForCTE(p LogicalPlan) { + switch x := p.(type) { + case *LogicalCTE: + x.cte.IsInApply = true + setIsInApplyForCTE(x.cte.seedPartLogicalPlan) + if x.cte.recursivePartLogicalPlan != nil { + setIsInApplyForCTE(x.cte.recursivePartLogicalPlan) + } + default: + for _, child := range p.Children() { + setIsInApplyForCTE(child) + } + } +} + func (b *PlanBuilder) buildMaxOneRow(p LogicalPlan) LogicalPlan { maxOneRow := LogicalMaxOneRow{}.Init(b.ctx, b.getSelectOffset()) maxOneRow.SetChildren(p) diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index 173de775558bb..8716f89ec22ae 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -1271,10 +1271,11 @@ type CTEClass struct { // storageID for this CTE. IDForStorage int // optFlag is the optFlag for the whole CTE. - optFlag uint64 - HasLimit bool - LimitBeg uint64 - LimitEnd uint64 + optFlag uint64 + HasLimit bool + LimitBeg uint64 + LimitEnd uint64 + IsInApply bool } // LogicalCTE is for CTE.