planner, expression: avoid exprs with side effects in column pruning and agg pushdown (#27370)
time-and-fate authored Aug 27, 2021
1 parent 5fcfd89 commit 8dcebd1
Showing 5 changed files with 151 additions and 15 deletions.
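The change addresses a correctness issue: expressions such as @rownum := @rownum + 1 compile to setvar(...) calls whose evaluation mutates session state, so the planner must not drop them during column pruning, nor drop or re-evaluate them when merging an aggregation with a projection during aggregation push-down. Both rules are now guarded by the side-effect checks ExprsHasSideEffects and exprHasSetVarOrSleep. Below is a minimal, self-contained sketch of that kind of check over a toy expression tree; the expr type and hasSideEffects function are hypothetical stand-ins for illustration, not TiDB's actual expression.Expression API.

package main

import "fmt"

// expr is a toy stand-in for a scalar expression tree; TiDB's real
// expression.Expression interface is much richer than this.
type expr struct {
	fn   string // function name, e.g. "setvar", "getvar", "plus", "const"
	args []expr
}

// hasSideEffects reports whether evaluating the expression can change
// observable state (user-variable assignment or sleep). Pruning or
// re-evaluating such an expression would change query behaviour, which is
// why the patch bails out of its rewrites when a check like this fires.
func hasSideEffects(e expr) bool {
	if e.fn == "setvar" || e.fn == "sleep" {
		return true
	}
	for _, a := range e.args {
		if hasSideEffects(a) {
			return true
		}
	}
	return false
}

func main() {
	// @rownum := @rownum + 1  ~  setvar(rownum, plus(getvar(rownum), 1))
	e := expr{fn: "setvar", args: []expr{
		{fn: "plus", args: []expr{{fn: "getvar"}, {fn: "const"}}},
	}}
	fmt.Println(hasSideEffects(e)) // true: the expression must stay in the plan
}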
28 changes: 28 additions & 0 deletions planner/core/integration_test.go
@@ -4307,3 +4307,31 @@ func (s *testIntegrationSerialSuite) TestTemporaryTableForCte(c *C) {
	rows = tk.MustQuery("WITH RECURSIVE cte(a) AS (SELECT 1 UNION SELECT a+1 FROM tmp1 WHERE a < 5) SELECT * FROM cte order by a;")
	rows.Check(testkit.Rows("1", "2", "3", "4", "5"))
}

func (s *testIntegrationSuite) TestGroupBySetVar(c *C) {
	tk := testkit.NewTestKit(c, s.store)
	tk.MustExec("use test")
	tk.MustExec("drop table if exists t1")
	tk.MustExec("create table t1(c1 int);")
	tk.MustExec("insert into t1 values(1), (2), (3), (4), (5), (6);")
	rows := tk.MustQuery("select floor(dt.rn/2) rownum, count(c1) from (select @rownum := @rownum + 1 rn, c1 from (select @rownum := -1) drn, t1) dt group by floor(dt.rn/2) order by rownum;")
	rows.Check(testkit.Rows("0 2", "1 2", "2 2"))

	tk.MustExec("create table ta(a int, b int);")
	tk.MustExec("set sql_mode='';")

	var input []string
	var output []struct {
		SQL  string
		Plan []string
	}
	s.testData.GetTestCases(c, &input, &output)
	for i, tt := range input {
		res := tk.MustQuery("explain format = 'brief' " + tt)
		s.testData.OnRecord(func() {
			output[i].SQL = tt
			output[i].Plan = s.testData.ConvertRowsToStrings(res.Rows())
		})
		res.Check(testkit.Rows(output[i].Plan...))
	}
}
45 changes: 32 additions & 13 deletions planner/core/rule_aggregation_push_down.go
@@ -428,22 +428,41 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) (_ LogicalPlan, e
 		} else if proj, ok1 := child.(*LogicalProjection); ok1 {
 			// TODO: This optimization is not always reasonable. We have not supported pushing projection to kv layer yet,
 			// so we must do this optimization.
-			for i, gbyItem := range agg.GroupByItems {
-				agg.GroupByItems[i] = expression.ColumnSubstitute(gbyItem, proj.schema, proj.Exprs)
+			noSideEffects := true
+			newGbyItems := make([]expression.Expression, 0, len(agg.GroupByItems))
+			for _, gbyItem := range agg.GroupByItems {
+				newGbyItems = append(newGbyItems, expression.ColumnSubstitute(gbyItem, proj.schema, proj.Exprs))
+				if ExprsHasSideEffects(newGbyItems) {
+					noSideEffects = false
+					break
+				}
 			}
-			for _, aggFunc := range agg.AggFuncs {
-				newArgs := make([]expression.Expression, 0, len(aggFunc.Args))
-				for _, arg := range aggFunc.Args {
-					newArgs = append(newArgs, expression.ColumnSubstitute(arg, proj.schema, proj.Exprs))
+			newAggFuncsArgs := make([][]expression.Expression, 0, len(agg.AggFuncs))
+			if noSideEffects {
+				for _, aggFunc := range agg.AggFuncs {
+					newArgs := make([]expression.Expression, 0, len(aggFunc.Args))
+					for _, arg := range aggFunc.Args {
+						newArgs = append(newArgs, expression.ColumnSubstitute(arg, proj.schema, proj.Exprs))
+					}
+					if ExprsHasSideEffects(newArgs) {
+						noSideEffects = false
+						break
+					}
+					newAggFuncsArgs = append(newAggFuncsArgs, newArgs)
 				}
-				aggFunc.Args = newArgs
 			}
-			projChild := proj.children[0]
-			agg.SetChildren(projChild)
-			// When the origin plan tree is `Aggregation->Projection->Union All->X`, we need to merge 'Aggregation' and 'Projection' first.
-			// And then push the new 'Aggregation' below the 'Union All' .
-			// The final plan tree should be 'Aggregation->Union All->Aggregation->X'.
-			child = projChild
+			if noSideEffects {
+				agg.GroupByItems = newGbyItems
+				for i, aggFunc := range agg.AggFuncs {
+					aggFunc.Args = newAggFuncsArgs[i]
+				}
+				projChild := proj.children[0]
+				agg.SetChildren(projChild)
+				// When the origin plan tree is `Aggregation->Projection->Union All->X`, we need to merge 'Aggregation' and 'Projection' first.
+				// And then push the new 'Aggregation' below the 'Union All' .
+				// The final plan tree should be 'Aggregation->Union All->Aggregation->X'.
+				child = projChild
+			}
 		}
 		if union, ok1 := child.(*LogicalUnionAll); ok1 && p.SCtx().GetSessionVars().AllowAggPushDown {
 			err := a.tryAggPushDownForUnion(union, agg)
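In the new rule body the substituted group-by items and aggregate-function arguments are staged in newGbyItems and newAggFuncsArgs, and the merge with the projection (including re-parenting the aggregation onto proj.children[0]) is committed only when every substituted expression passes the side-effect check; otherwise the plan is left unchanged. A compact sketch of that prepare-validate-commit shape, reusing the toy expr type and hasSideEffects helper from the sketch above (substitute is a hypothetical callback, not TiDB's ColumnSubstitute):

// rewriteAll stages substituted expressions and reports whether it is safe
// to commit them; any side effect in a result aborts the whole rewrite, so
// the caller keeps the original Projection below the Aggregation.
func rewriteAll(items []expr, substitute func(expr) expr) ([]expr, bool) {
	staged := make([]expr, 0, len(items))
	for _, it := range items {
		next := substitute(it)
		if hasSideEffects(next) {
			return nil, false // abort: do not merge the projection
		}
		staged = append(staged, next)
	}
	return staged, true // safe to copy back and re-parent the aggregation
}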
4 changes: 2 additions & 2 deletions planner/core/rule_column_pruning.go
@@ -96,7 +96,7 @@ func (la *LogicalAggregation) PruneColumns(parentUsedCols []*expression.Column)
 		if la.AggFuncs[i].Name != ast.AggFuncFirstRow {
 			allFirstRow = false
 		}
-		if !used[i] {
+		if !used[i] && !ExprsHasSideEffects(la.AggFuncs[i].Args) {
 			la.schema.Columns = append(la.schema.Columns[:i], la.schema.Columns[i+1:]...)
 			la.AggFuncs = append(la.AggFuncs[:i], la.AggFuncs[i+1:]...)
 		} else if la.AggFuncs[i].Name != ast.AggFuncFirstRow {
@@ -137,7 +137,7 @@ func (la *LogicalAggregation) PruneColumns(parentUsedCols []*expression.Column)
 	if len(la.GroupByItems) > 0 {
 		for i := len(la.GroupByItems) - 1; i >= 0; i-- {
 			cols := expression.ExtractColumns(la.GroupByItems[i])
-			if len(cols) == 0 {
+			if len(cols) == 0 && !exprHasSetVarOrSleep(la.GroupByItems[i]) {
 				la.GroupByItems = append(la.GroupByItems[:i], la.GroupByItems[i+1:]...)
 			} else {
 				selfUsedCols = append(selfUsedCols, cols...)
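The column-pruning change applies the same idea on two fronts: an aggregate function whose output column is unused is now removed only if its arguments are also free of side effects, and a constant group-by item (one that references no columns) is kept when it contains setvar or sleep. Previously both were silently dropped, so a @n:=@n+1 inside a subquery could vanish from the plan. A minimal sketch of that guarded pruning loop over the same toy expr type (used is a hypothetical usage bitmap computed by the caller):

// pruneUnused drops entries the parent never reads, except those whose
// evaluation has side effects; those must survive pruning even if unused.
func pruneUnused(items []expr, used []bool) []expr {
	kept := make([]expr, 0, len(items))
	for i, it := range items {
		if !used[i] && !hasSideEffects(it) {
			continue // unused and side-effect free: safe to prune
		}
		kept = append(kept, it)
	}
	return kept
}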
13 changes: 13 additions & 0 deletions planner/core/testdata/integration_suite_in.json
@@ -355,5 +355,18 @@
"cases": [
"select * from t use index (idx_b) where b = 2 limit 1"
]
},
{
"name": "TestGroupBySetVar",
"cases": [
"select floor(dt.rn/2) rownum, count(c1) from (select @rownum := @rownum + 1 rn, c1 from (select @rownum := -1) drn, t1) dt group by floor(dt.rn/2) order by rownum;",
// TODO: fix these two cases
"select @n:=@n+1 as e from ta group by e",
"select @n:=@n+a as e from ta group by e",
"select * from (select @n:=@n+1 as e from ta) tt group by e",
"select * from (select @n:=@n+a as e from ta) tt group by e",
"select a from ta group by @n:=@n+1",
"select a from ta group by @n:=@n+a"
]
}
]
76 changes: 76 additions & 0 deletions planner/core/testdata/integration_suite_out.json
@@ -1897,5 +1897,81 @@
]
}
]
},
{
"Name": "TestGroupBySetVar",
"Cases": [
{
"SQL": "select floor(dt.rn/2) rownum, count(c1) from (select @rownum := @rownum + 1 rn, c1 from (select @rownum := -1) drn, t1) dt group by floor(dt.rn/2) order by rownum;",
"Plan": [
"Sort 1.00 root Column#6",
"└─Projection 1.00 root floor(div(cast(Column#4, decimal(20,0) BINARY), 2))->Column#6, Column#5",
" └─HashAgg 1.00 root group by:Column#13, funcs:count(Column#11)->Column#5, funcs:firstrow(Column#12)->Column#4",
" └─Projection 10000.00 root test.t1.c1, Column#4, floor(div(cast(Column#4, decimal(20,0) BINARY), 2))->Column#13",
" └─Projection 10000.00 root setvar(rownum, plus(getvar(rownum), 1))->Column#4, test.t1.c1",
" └─HashJoin 10000.00 root CARTESIAN inner join",
" ├─Projection(Build) 1.00 root setvar(rownum, -1)->Column#1",
" │ └─TableDual 1.00 root rows:1",
" └─TableReader(Probe) 10000.00 root data:TableFullScan",
" └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
]
},
{
"SQL": "select @n:=@n+1 as e from ta group by e",
"Plan": [
"Projection 1.00 root setvar(n, plus(getvar(n), 1))->Column#4",
"└─HashAgg 1.00 root group by:Column#8, funcs:firstrow(1)->Column#7",
" └─Projection 10000.00 root setvar(n, plus(cast(getvar(n), double BINARY), 1))->Column#8",
" └─TableReader 10000.00 root data:TableFullScan",
" └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo"
]
},
{
"SQL": "select @n:=@n+a as e from ta group by e",
"Plan": [
"Projection 8000.00 root setvar(n, plus(getvar(n), cast(test.ta.a, double BINARY)))->Column#4",
"└─HashAgg 8000.00 root group by:Column#7, funcs:firstrow(Column#6)->test.ta.a",
" └─Projection 10000.00 root test.ta.a, setvar(n, plus(getvar(n), cast(test.ta.a, double BINARY)))->Column#7",
" └─TableReader 10000.00 root data:TableFullScan",
" └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo"
]
},
{
"SQL": "select * from (select @n:=@n+1 as e from ta) tt group by e",
"Plan": [
"HashAgg 1.00 root group by:Column#4, funcs:firstrow(Column#4)->Column#4",
"└─Projection 10000.00 root setvar(n, plus(getvar(n), 1))->Column#4",
" └─TableReader 10000.00 root data:TableFullScan",
" └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo"
]
},
{
"SQL": "select * from (select @n:=@n+a as e from ta) tt group by e",
"Plan": [
"HashAgg 8000.00 root group by:Column#4, funcs:firstrow(Column#4)->Column#4",
"└─Projection 10000.00 root setvar(n, plus(getvar(n), cast(test.ta.a, double BINARY)))->Column#4",
" └─TableReader 10000.00 root data:TableFullScan",
" └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo"
]
},
{
"SQL": "select a from ta group by @n:=@n+1",
"Plan": [
"HashAgg 1.00 root group by:Column#5, funcs:firstrow(Column#4)->test.ta.a",
"└─Projection 10000.00 root test.ta.a, setvar(n, plus(getvar(n), 1))->Column#5",
" └─TableReader 10000.00 root data:TableFullScan",
" └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo"
]
},
{
"SQL": "select a from ta group by @n:=@n+a",
"Plan": [
"HashAgg 8000.00 root group by:Column#5, funcs:firstrow(Column#4)->test.ta.a",
"└─Projection 10000.00 root test.ta.a, setvar(n, plus(getvar(n), cast(test.ta.a, double BINARY)))->Column#5",
" └─TableReader 10000.00 root data:TableFullScan",
" └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo"
]
}
]
}
]
