diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 4957c29223fbb..5eee8f144b216 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -2095,3 +2095,31 @@ func (s *testIntegrationSuite) TestIssue26559(c *C) { tk.MustExec("insert into t values('2020-07-29 09:07:01', '2020-07-27 16:57:36');") tk.MustQuery("select greatest(a, b) from t union select null;").Sort().Check(testkit.Rows("2020-07-29 09:07:01", "")) } + +func (s *testIntegrationSuite) TestGroupBySetVar(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1(c1 int);") + tk.MustExec("insert into t1 values(1), (2), (3), (4), (5), (6);") + rows := tk.MustQuery("select floor(dt.rn/2) rownum, count(c1) from (select @rownum := @rownum + 1 rn, c1 from (select @rownum := -1) drn, t1) dt group by floor(dt.rn/2) order by rownum;") + rows.Check(testkit.Rows("0 2", "1 2", "2 2")) + + tk.MustExec("create table ta(a int, b int);") + tk.MustExec("set sql_mode='';") + + var input []string + var output []struct { + SQL string + Plan []string + } + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + res := tk.MustQuery("explain format = 'brief' " + tt) + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Plan = s.testData.ConvertRowsToStrings(res.Rows()) + }) + res.Check(testkit.Rows(output[i].Plan...)) + } +} diff --git a/planner/core/rule_aggregation_push_down.go b/planner/core/rule_aggregation_push_down.go index 0c796e4b0c3f4..0d5a916bd5f17 100644 --- a/planner/core/rule_aggregation_push_down.go +++ b/planner/core/rule_aggregation_push_down.go @@ -431,20 +431,57 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) (_ LogicalPlan, e } else if proj, ok1 := child.(*LogicalProjection); ok1 && p.SCtx().GetSessionVars().AllowAggPushDown { // TODO: This optimization is not always reasonable. We have not supported pushing projection to kv layer yet, // so we must do this optimization. - for i, gbyItem := range agg.GroupByItems { - agg.GroupByItems[i] = expression.ColumnSubstitute(gbyItem, proj.schema, proj.Exprs) + noSideEffects := true + newGbyItems := make([]expression.Expression, 0, len(agg.GroupByItems)) + for _, gbyItem := range agg.GroupByItems { + newGbyItems = append(newGbyItems, expression.ColumnSubstitute(gbyItem, proj.schema, proj.Exprs)) + if ExprsHasSideEffects(newGbyItems) { + noSideEffects = false + break + } + } + newAggFuncsArgs := make([][]expression.Expression, 0, len(agg.AggFuncs)) + if noSideEffects { + for _, aggFunc := range agg.AggFuncs { + newArgs := make([]expression.Expression, 0, len(aggFunc.Args)) + for _, arg := range aggFunc.Args { + newArgs = append(newArgs, expression.ColumnSubstitute(arg, proj.schema, proj.Exprs)) + } + if ExprsHasSideEffects(newArgs) { + noSideEffects = false + break + } + newAggFuncsArgs = append(newAggFuncsArgs, newArgs) + } } +<<<<<<< HEAD agg.collectGroupByColumns() for _, aggFunc := range agg.AggFuncs { newArgs := make([]expression.Expression, 0, len(aggFunc.Args)) for _, arg := range aggFunc.Args { newArgs = append(newArgs, expression.ColumnSubstitute(arg, proj.schema, proj.Exprs)) +======= + if noSideEffects { + agg.GroupByItems = newGbyItems + for i, aggFunc := range agg.AggFuncs { + aggFunc.Args = newAggFuncsArgs[i] +>>>>>>> 8dcebd123... planner, expression: avoid exprs with side effects in column pruning and agg pushdown (#27370) } - aggFunc.Args = newArgs + projChild := proj.children[0] + agg.SetChildren(projChild) + // When the origin plan tree is `Aggregation->Projection->Union All->X`, we need to merge 'Aggregation' and 'Projection' first. + // And then push the new 'Aggregation' below the 'Union All' . + // The final plan tree should be 'Aggregation->Union All->Aggregation->X'. + child = projChild } +<<<<<<< HEAD projChild := proj.children[0] agg.SetChildren(projChild) } else if union, ok1 := child.(*LogicalUnionAll); ok1 && p.SCtx().GetSessionVars().AllowAggPushDown { +======= + } + if union, ok1 := child.(*LogicalUnionAll); ok1 && p.SCtx().GetSessionVars().AllowAggPushDown { +>>>>>>> 8dcebd123... planner, expression: avoid exprs with side effects in column pruning and agg pushdown (#27370) err := a.tryAggPushDownForUnion(union, agg) if err != nil { return nil, err diff --git a/planner/core/rule_column_pruning.go b/planner/core/rule_column_pruning.go index 76a1e68806885..fe23bfd1bde92 100644 --- a/planner/core/rule_column_pruning.go +++ b/planner/core/rule_column_pruning.go @@ -94,7 +94,7 @@ func (la *LogicalAggregation) PruneColumns(parentUsedCols []*expression.Column) if la.AggFuncs[i].Name != ast.AggFuncFirstRow { allFirstRow = false } - if !used[i] { + if !used[i] && !ExprsHasSideEffects(la.AggFuncs[i].Args) { la.schema.Columns = append(la.schema.Columns[:i], la.schema.Columns[i+1:]...) la.AggFuncs = append(la.AggFuncs[:i], la.AggFuncs[i+1:]...) } else if la.AggFuncs[i].Name != ast.AggFuncFirstRow { @@ -135,7 +135,7 @@ func (la *LogicalAggregation) PruneColumns(parentUsedCols []*expression.Column) if len(la.GroupByItems) > 0 { for i := len(la.GroupByItems) - 1; i >= 0; i-- { cols := expression.ExtractColumns(la.GroupByItems[i]) - if len(cols) == 0 { + if len(cols) == 0 && !exprHasSetVarOrSleep(la.GroupByItems[i]) { la.GroupByItems = append(la.GroupByItems[:i], la.GroupByItems[i+1:]...) } else { selfUsedCols = append(selfUsedCols, cols...) diff --git a/planner/core/testdata/integration_suite_in.json b/planner/core/testdata/integration_suite_in.json index d3fb427b4cd64..495feca2285b5 100644 --- a/planner/core/testdata/integration_suite_in.json +++ b/planner/core/testdata/integration_suite_in.json @@ -193,5 +193,18 @@ "EXPLAIN SELECT t1.pk FROM t1 INNER JOIN t2 ON t1.col1 = t2.pk INNER JOIN t3 ON t1.col3 = t3.pk WHERE t2.col1 IN ('a' , 'b') AND t3.keycol = 'c' AND t1.col2 = 'a' AND t1.col1 != 'abcdef' AND t1.col1 != 'aaaaaa'", "EXPLAIN SELECT t1.pk FROM t1 LEFT JOIN t2 ON t1.col1 = t2.pk LEFT JOIN t3 ON t1.col3 = t3.pk WHERE t2.col1 IN ('a' , 'b') AND t3.keycol = 'c' AND t1.col2 = 'a' AND t1.col1 != 'abcdef' AND t1.col1 != 'aaaaaa'" ] + }, + { + "name": "TestGroupBySetVar", + "cases": [ + "select floor(dt.rn/2) rownum, count(c1) from (select @rownum := @rownum + 1 rn, c1 from (select @rownum := -1) drn, t1) dt group by floor(dt.rn/2) order by rownum;", + // TODO: fix these two cases + "select @n:=@n+1 as e from ta group by e", + "select @n:=@n+a as e from ta group by e", + "select * from (select @n:=@n+1 as e from ta) tt group by e", + "select * from (select @n:=@n+a as e from ta) tt group by e", + "select a from ta group by @n:=@n+1", + "select a from ta group by @n:=@n+a" + ] } ] diff --git a/planner/core/testdata/integration_suite_out.json b/planner/core/testdata/integration_suite_out.json index f673fcf8dc1f9..34ce027a9064d 100644 --- a/planner/core/testdata/integration_suite_out.json +++ b/planner/core/testdata/integration_suite_out.json @@ -1022,5 +1022,81 @@ ] } ] + }, + { + "Name": "TestGroupBySetVar", + "Cases": [ + { + "SQL": "select floor(dt.rn/2) rownum, count(c1) from (select @rownum := @rownum + 1 rn, c1 from (select @rownum := -1) drn, t1) dt group by floor(dt.rn/2) order by rownum;", + "Plan": [ + "Sort 1.00 root Column#6", + "└─Projection 1.00 root floor(div(cast(Column#4, decimal(20,0) BINARY), 2))->Column#6, Column#5", + " └─HashAgg 1.00 root group by:Column#13, funcs:count(Column#11)->Column#5, funcs:firstrow(Column#12)->Column#4", + " └─Projection 10000.00 root test.t1.c1, Column#4, floor(div(cast(Column#4, decimal(20,0) BINARY), 2))->Column#13", + " └─Projection 10000.00 root setvar(rownum, plus(getvar(rownum), 1))->Column#4, test.t1.c1", + " └─HashJoin 10000.00 root CARTESIAN inner join", + " ├─Projection(Build) 1.00 root setvar(rownum, -1)->Column#1", + " │ └─TableDual 1.00 root rows:1", + " └─TableReader(Probe) 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select @n:=@n+1 as e from ta group by e", + "Plan": [ + "Projection 1.00 root setvar(n, plus(getvar(n), 1))->Column#4", + "└─HashAgg 1.00 root group by:Column#8, funcs:firstrow(1)->Column#7", + " └─Projection 10000.00 root setvar(n, plus(cast(getvar(n), double BINARY), 1))->Column#8", + " └─TableReader 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select @n:=@n+a as e from ta group by e", + "Plan": [ + "Projection 8000.00 root setvar(n, plus(getvar(n), cast(test.ta.a, double BINARY)))->Column#4", + "└─HashAgg 8000.00 root group by:Column#7, funcs:firstrow(Column#6)->test.ta.a", + " └─Projection 10000.00 root test.ta.a, setvar(n, plus(getvar(n), cast(test.ta.a, double BINARY)))->Column#7", + " └─TableReader 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select * from (select @n:=@n+1 as e from ta) tt group by e", + "Plan": [ + "HashAgg 1.00 root group by:Column#4, funcs:firstrow(Column#4)->Column#4", + "└─Projection 10000.00 root setvar(n, plus(getvar(n), 1))->Column#4", + " └─TableReader 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select * from (select @n:=@n+a as e from ta) tt group by e", + "Plan": [ + "HashAgg 8000.00 root group by:Column#4, funcs:firstrow(Column#4)->Column#4", + "└─Projection 10000.00 root setvar(n, plus(getvar(n), cast(test.ta.a, double BINARY)))->Column#4", + " └─TableReader 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select a from ta group by @n:=@n+1", + "Plan": [ + "HashAgg 1.00 root group by:Column#5, funcs:firstrow(Column#4)->test.ta.a", + "└─Projection 10000.00 root test.ta.a, setvar(n, plus(getvar(n), 1))->Column#5", + " └─TableReader 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select a from ta group by @n:=@n+a", + "Plan": [ + "HashAgg 8000.00 root group by:Column#5, funcs:firstrow(Column#4)->test.ta.a", + "└─Projection 10000.00 root test.ta.a, setvar(n, plus(getvar(n), cast(test.ta.a, double BINARY)))->Column#5", + " └─TableReader 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo" + ] + } + ] } ]