Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

executor: fix CTE bug when used with Apply (#31256) #31382

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions cmd/explaintest/r/cte.result
Original file line number Diff line number Diff line change
Expand Up @@ -607,3 +607,118 @@ c1 c1 c1
1 1 1
2 2 2
3 3 3
// Test CTE as inner side of Apply
drop table if exists t1, t2;
create table t1(c1 int, c2 int);
insert into t1 values(2, 1);
insert into t1 values(2, 2);
create table t2(c1 int, c2 int);
insert into t2 values(1, 1);
insert into t2 values(3, 2);
explain select * from t1 where c1 > all(with cte1 as (select c1 from t2 where t2.c2 = t1.c2) select c1 from cte1);
id estRows task access object operator info
Projection_17 10000.00 root test.t1.c1, test.t1.c2
└─Apply_19 10000.00 root CARTESIAN inner join, other cond:or(and(gt(test.t1.c1, Column#8), if(ne(Column#9, 0), NULL, 1)), or(eq(Column#10, 0), if(isnull(test.t1.c1), NULL, 0)))
├─TableReader_21(Build) 10000.00 root data:TableFullScan_20
│ └─TableFullScan_20 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo
└─HashAgg_22(Probe) 1.00 root funcs:max(Column#13)->Column#8, funcs:sum(Column#14)->Column#9, funcs:count(1)->Column#10
└─Projection_26 10.00 root test.t2.c1, cast(isnull(test.t2.c1), decimal(20,0) BINARY)->Column#14
└─CTEFullScan_24 10.00 root CTE:cte1 data:CTE_0
CTE_0 10.00 root Non-Recursive CTE
└─Projection_12(Seed Part) 10.00 root test.t2.c1
└─TableReader_15 10.00 root data:Selection_14
└─Selection_14 10.00 cop[tikv] eq(test.t2.c2, test.t1.c2)
└─TableFullScan_13 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo
select * from t1 where c1 > all(with cte1 as (select c1 from t2 where t2.c2 = t1.c2) select c1 from cte1);
c1 c2
2 1
// Test semi apply.
insert into t1 values(2, 3);
explain select * from t1 where exists(with cte1 as (select c1 from t2 where t2.c2 = t1.c2) select c1 from cte1);
id estRows task access object operator info
Apply_16 10000.00 root CARTESIAN semi join
├─TableReader_18(Build) 10000.00 root data:TableFullScan_17
│ └─TableFullScan_17 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo
└─CTEFullScan_19(Probe) 10.00 root CTE:cte1 data:CTE_0
CTE_0 10.00 root Non-Recursive CTE
└─Projection_10(Seed Part) 10.00 root test.t2.c1
└─TableReader_13 10.00 root data:Selection_12
└─Selection_12 10.00 cop[tikv] eq(test.t2.c2, test.t1.c2)
└─TableFullScan_11 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo
select * from t1 where exists(with cte1 as (select c1 from t2 where t2.c2 = t1.c2) select c1 from cte1);
c1 c2
2 1
2 2
// Same as above, but test recursive cte.
explain select * from t1 where c1 > all(with recursive cte1 as (select c1 from t2 where t2.c2 = t1.c2 union all select c1+1 as c1 from cte1 limit 1) select c1 from cte1);
id estRows task access object operator info
Projection_26 10000.00 root test.t1.c1, test.t1.c2
└─Apply_28 10000.00 root CARTESIAN inner join, other cond:or(and(gt(test.t1.c1, Column#14), if(ne(Column#15, 0), NULL, 1)), or(eq(Column#16, 0), if(isnull(test.t1.c1), NULL, 0)))
├─TableReader_30(Build) 10000.00 root data:TableFullScan_29
│ └─TableFullScan_29 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo
└─HashAgg_31(Probe) 1.00 root funcs:max(Column#19)->Column#14, funcs:sum(Column#20)->Column#15, funcs:count(1)->Column#16
└─Projection_35 20.00 root test.t2.c1, cast(isnull(test.t2.c1), decimal(20,0) BINARY)->Column#20
└─CTEFullScan_33 20.00 root CTE:cte1 data:CTE_0
CTE_0 20.00 root Recursive CTE, limit(offset:0, count:1)
├─Projection_19(Seed Part) 10.00 root test.t2.c1
│ └─TableReader_22 10.00 root data:Selection_21
│ └─Selection_21 10.00 cop[tikv] eq(test.t2.c2, test.t1.c2)
│ └─TableFullScan_20 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo
└─Projection_23(Recursive Part) 10.00 root cast(plus(test.t2.c1, 1), int(11))->test.t2.c1
└─CTETable_24 10.00 root Scan on CTE_0
select * from t1 where c1 > all(with recursive cte1 as (select c1 from t2 where t2.c2 = t1.c2 union all select c1+1 as c1 from cte1 limit 1) select c1 from cte1);
c1 c2
2 1
2 3
explain select * from t1 where exists(with recursive cte1 as (select c1 from t2 where t2.c2 = t1.c2 union all select c1+1 as c1 from cte1 limit 10) select c1 from cte1);
id estRows task access object operator info
Apply_25 10000.00 root CARTESIAN semi join
├─TableReader_27(Build) 10000.00 root data:TableFullScan_26
│ └─TableFullScan_26 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo
└─CTEFullScan_28(Probe) 20.00 root CTE:cte1 data:CTE_0
CTE_0 20.00 root Recursive CTE, limit(offset:0, count:10)
├─Projection_17(Seed Part) 10.00 root test.t2.c1
│ └─TableReader_20 10.00 root data:Selection_19
│ └─Selection_19 10.00 cop[tikv] eq(test.t2.c2, test.t1.c2)
│ └─TableFullScan_18 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo
└─Projection_21(Recursive Part) 10.00 root cast(plus(test.t2.c1, 1), int(11))->test.t2.c1
└─CTETable_22 10.00 root Scan on CTE_0
select * from t1 where exists(with recursive cte1 as (select c1 from t2 where t2.c2 = t1.c2 union all select c1+1 as c1 from cte1 limit 10) select c1 from cte1);
c1 c2
2 1
2 2
// Test correlated col is in recursive part.
explain select * from t1 where c1 > all(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1);
id estRows task access object operator info
Projection_24 10000.00 root test.t1.c1, test.t1.c2
└─Apply_26 10000.00 root CARTESIAN inner join, other cond:or(and(gt(test.t1.c1, Column#18), if(ne(Column#19, 0), NULL, 1)), or(eq(Column#20, 0), if(isnull(test.t1.c1), NULL, 0)))
├─TableReader_28(Build) 10000.00 root data:TableFullScan_27
│ └─TableFullScan_27 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo
└─HashAgg_29(Probe) 1.00 root funcs:max(Column#23)->Column#18, funcs:sum(Column#24)->Column#19, funcs:count(1)->Column#20
└─Projection_33 18000.00 root test.t2.c1, cast(isnull(test.t2.c1), decimal(20,0) BINARY)->Column#24
└─CTEFullScan_31 18000.00 root CTE:cte1 data:CTE_0
CTE_0 18000.00 root Recursive CTE
├─TableReader_19(Seed Part) 10000.00 root data:TableFullScan_18
│ └─TableFullScan_18 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo
└─Projection_20(Recursive Part) 8000.00 root cast(plus(test.t2.c1, 1), int(11))->test.t2.c1, cast(plus(test.t2.c2, 1), int(11))->test.t2.c2
└─Selection_21 8000.00 root eq(test.t2.c2, test.t1.c2)
└─CTETable_22 10000.00 root Scan on CTE_0
select * from t1 where c1 > all(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1);
c1 c2
explain select * from t1 where exists(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1);
id estRows task access object operator info
Apply_23 10000.00 root CARTESIAN semi join
├─TableReader_25(Build) 10000.00 root data:TableFullScan_24
│ └─TableFullScan_24 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo
└─CTEFullScan_26(Probe) 18000.00 root CTE:cte1 data:CTE_0
CTE_0 18000.00 root Recursive CTE
├─TableReader_17(Seed Part) 10000.00 root data:TableFullScan_16
│ └─TableFullScan_16 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo
└─Projection_18(Recursive Part) 8000.00 root cast(plus(test.t2.c1, 1), int(11))->test.t2.c1, cast(plus(test.t2.c2, 1), int(11))->test.t2.c2
└─Selection_19 8000.00 root eq(test.t2.c2, test.t1.c2)
└─CTETable_20 10000.00 root Scan on CTE_0
select * from t1 where exists(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1);
c1 c2
2 1
2 2
2 3
30 changes: 30 additions & 0 deletions cmd/explaintest/t/cte.test
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,33 @@ create table tpk1(c1 int primary key);
insert into tpk1 values(1), (2), (3);
explain with cte1 as (select c1 from tpk) select /*+ merge_join(dt1, dt2) */ * from tpk1 dt1 inner join cte1 dt2 inner join cte1 dt3 on dt1.c1 = dt2.c1 and dt2.c1 = dt3.c1;
with cte1 as (select c1 from tpk) select /*+ merge_join(dt1, dt2) */ * from tpk1 dt1 inner join cte1 dt2 inner join cte1 dt3 on dt1.c1 = dt2.c1 and dt2.c1 = dt3.c1;
#case 34
--echo // Test CTE as inner side of Apply
drop table if exists t1, t2;
create table t1(c1 int, c2 int);
insert into t1 values(2, 1);
insert into t1 values(2, 2);
create table t2(c1 int, c2 int);
insert into t2 values(1, 1);
insert into t2 values(3, 2);
explain select * from t1 where c1 > all(with cte1 as (select c1 from t2 where t2.c2 = t1.c2) select c1 from cte1);
select * from t1 where c1 > all(with cte1 as (select c1 from t2 where t2.c2 = t1.c2) select c1 from cte1);

--echo // Test semi apply.
insert into t1 values(2, 3);
explain select * from t1 where exists(with cte1 as (select c1 from t2 where t2.c2 = t1.c2) select c1 from cte1);
select * from t1 where exists(with cte1 as (select c1 from t2 where t2.c2 = t1.c2) select c1 from cte1);

--echo // Same as above, but test recursive cte.
explain select * from t1 where c1 > all(with recursive cte1 as (select c1 from t2 where t2.c2 = t1.c2 union all select c1+1 as c1 from cte1 limit 1) select c1 from cte1);
select * from t1 where c1 > all(with recursive cte1 as (select c1 from t2 where t2.c2 = t1.c2 union all select c1+1 as c1 from cte1 limit 1) select c1 from cte1);

explain select * from t1 where exists(with recursive cte1 as (select c1 from t2 where t2.c2 = t1.c2 union all select c1+1 as c1 from cte1 limit 10) select c1 from cte1);
select * from t1 where exists(with recursive cte1 as (select c1 from t2 where t2.c2 = t1.c2 union all select c1+1 as c1 from cte1 limit 10) select c1 from cte1);

--echo // Test correlated col is in recursive part.
explain select * from t1 where c1 > all(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1);
select * from t1 where c1 > all(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1);

explain select * from t1 where exists(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1);
select * from t1 where exists(with recursive cte1 as (select c1, c2 from t2 union all select c1+1 as c1, c2+1 as c2 from cte1 where cte1.c2=t1.c2) select c1 from cte1);
14 changes: 2 additions & 12 deletions executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -4540,19 +4540,9 @@ func (b *executorBuilder) buildCTE(v *plannercore.PhysicalCTE) Executor {
// 2. Build tables to store intermediate results.
chkSize := b.ctx.GetSessionVars().MaxChunkSize
tps := seedExec.base().retFieldTypes
// iterOutTbl will be constructed in CTEExec.Open().
var resTbl cteutil.Storage
var iterInTbl cteutil.Storage
var iterOutTbl cteutil.Storage

if v.RecurPlan != nil {
// For non-recursive CTE, the result will be put into resTbl directly.
// So no need to build iterOutTbl.
iterOutTbl := cteutil.NewStorageRowContainer(tps, chkSize)
if err := iterOutTbl.OpenAndRef(); err != nil {
b.err = err
return nil
}
}

storageMap, ok := b.ctx.GetSessionVars().StmtCtx.CTEStorageMap.(map[int]*CTEStorages)
if !ok {
Expand Down Expand Up @@ -4601,13 +4591,13 @@ func (b *executorBuilder) buildCTE(v *plannercore.PhysicalCTE) Executor {
recursiveExec: recursiveExec,
resTbl: resTbl,
iterInTbl: iterInTbl,
iterOutTbl: iterOutTbl,
chkIdx: 0,
isDistinct: v.CTE.IsDistinct,
sel: sel,
hasLimit: v.CTE.HasLimit,
limitBeg: v.CTE.LimitBeg,
limitEnd: v.CTE.LimitEnd,
isInApply: v.CTE.IsInApply,
}
}

Expand Down
17 changes: 16 additions & 1 deletion executor/cte.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ type CTEExec struct {

memTracker *memory.Tracker
diskTracker *disk.Tracker

// isInApply indicates whether CTE is in inner side of Apply
// and should resTbl/iterInTbl be reset for each outer row of Apply.
// Because we reset them when SQL is finished instead of when CTEExec.Close() is called.
isInApply bool
}

// Open implements the Executor interface.
Expand All @@ -114,6 +119,9 @@ func (e *CTEExec) Open(ctx context.Context) (err error) {
if err = e.recursiveExec.Open(ctx); err != nil {
return err
}
// For non-recursive CTE, the result will be put into resTbl directly.
// So no need to build iterOutTbl.
// Construct iterOutTbl in Open() instead of buildCTE(), because its destruct is in Close().
recursiveTypes := e.recursiveExec.base().retFieldTypes
e.iterOutTbl = cteutil.NewStorageRowContainer(recursiveTypes, e.maxChunkSize)
if err = e.iterOutTbl.OpenAndRef(); err != nil {
Expand Down Expand Up @@ -208,6 +216,11 @@ func (e *CTEExec) Close() (err error) {
return err
}
}
if e.isInApply {
if err = e.reopenTbls(); err != nil {
return err
}
}

return e.baseExecutor.Close()
}
Expand Down Expand Up @@ -396,7 +409,9 @@ func (e *CTEExec) reset() {
}

func (e *CTEExec) reopenTbls() (err error) {
e.hashTbl = newConcurrentMapHashTable()
if e.isDistinct {
e.hashTbl = newConcurrentMapHashTable()
}
if err := e.resTbl.Reopen(); err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion executor/explainfor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1388,7 +1388,7 @@ func (s *testPrepareSerialSuite) TestCTE4PlanCache(c *C) {
tk.MustExec("set @a=1, @b=2, @c=3, @d=4, @e=5, @f=0;")

tk.MustQuery("execute stmt using @f, @a, @f").Check(testkit.Rows("1"))
tk.MustQuery("execute stmt using @a, @b, @a").Check(testkit.Rows("1"))
tk.MustQuery("execute stmt using @a, @b, @a").Sort().Check(testkit.Rows("1", "2"))
tk.MustQuery("select @@last_plan_from_cache;").Check(testkit.Rows("0"))

tk.MustExec("prepare stmt from 'with recursive c(p) as (select ?), cte(a, b) as (select 1, 1 union select a+?, 1 from cte, c where a < ?) select * from cte order by 1, 2;';")
Expand Down
20 changes: 20 additions & 0 deletions planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -4401,6 +4401,7 @@ func (b *PlanBuilder) buildProjUponView(ctx context.Context, dbName model.CIStr,
// every row from outerPlan and the whole innerPlan.
func (b *PlanBuilder) buildApplyWithJoinType(outerPlan, innerPlan LogicalPlan, tp JoinType) LogicalPlan {
b.optFlag = b.optFlag | flagPredicatePushDown | flagBuildKeyInfo | flagDecorrelate
setIsInApplyForCTE(innerPlan)
ap := LogicalApply{LogicalJoin: LogicalJoin{JoinType: tp}}.Init(b.ctx, b.getSelectOffset())
ap.SetChildren(outerPlan, innerPlan)
ap.names = make([]*types.FieldName, outerPlan.Schema().Len()+innerPlan.Schema().Len())
Expand All @@ -4426,12 +4427,31 @@ func (b *PlanBuilder) buildSemiApply(outerPlan, innerPlan LogicalPlan, condition
return nil, err
}

setIsInApplyForCTE(innerPlan)
ap := &LogicalApply{LogicalJoin: *join}
ap.tp = plancodec.TypeApply
ap.self = ap
return ap, nil
}

// setIsInApplyForCTE indicates CTE is the in inner side of Apply,
// the storage of cte needs to be reset for each outer row.
// It's better to handle this in CTEExec.Close(), but cte storage is closed when SQL is finished.
func setIsInApplyForCTE(p LogicalPlan) {
switch x := p.(type) {
case *LogicalCTE:
x.cte.IsInApply = true
setIsInApplyForCTE(x.cte.seedPartLogicalPlan)
if x.cte.recursivePartLogicalPlan != nil {
setIsInApplyForCTE(x.cte.recursivePartLogicalPlan)
}
default:
for _, child := range p.Children() {
setIsInApplyForCTE(child)
}
}
}

func (b *PlanBuilder) buildMaxOneRow(p LogicalPlan) LogicalPlan {
maxOneRow := LogicalMaxOneRow{}.Init(b.ctx, b.getSelectOffset())
maxOneRow.SetChildren(p)
Expand Down
9 changes: 5 additions & 4 deletions planner/core/logical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -1271,10 +1271,11 @@ type CTEClass struct {
// storageID for this CTE.
IDForStorage int
// optFlag is the optFlag for the whole CTE.
optFlag uint64
HasLimit bool
LimitBeg uint64
LimitEnd uint64
optFlag uint64
HasLimit bool
LimitBeg uint64
LimitEnd uint64
IsInApply bool
}

// LogicalCTE is for CTE.
Expand Down