diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go
index 33cc9e13358af..707a8b23a0c05 100644
--- a/planner/core/integration_test.go
+++ b/planner/core/integration_test.go
@@ -3696,3 +3696,31 @@ func (s *testIntegrationSuite) TestIssue26559(c *C) {
 	tk.MustExec("insert into t values('2020-07-29 09:07:01', '2020-07-27 16:57:36');")
 	tk.MustQuery("select greatest(a, b) from t union select null;").Sort().Check(testkit.Rows("2020-07-29 09:07:01", ""))
 }
+
+func (s *testIntegrationSuite) TestGroupBySetVar(c *C) {
+	tk := testkit.NewTestKit(c, s.store)
+	tk.MustExec("use test")
+	tk.MustExec("drop table if exists t1")
+	tk.MustExec("create table t1(c1 int);")
+	tk.MustExec("insert into t1 values(1), (2), (3), (4), (5), (6);")
+	rows := tk.MustQuery("select floor(dt.rn/2) rownum, count(c1) from (select @rownum := @rownum + 1 rn, c1 from (select @rownum := -1) drn, t1) dt group by floor(dt.rn/2) order by rownum;")
+	rows.Check(testkit.Rows("0 2", "1 2", "2 2"))
+
+	tk.MustExec("create table ta(a int, b int);")
+	tk.MustExec("set sql_mode='';")
+
+	var input []string
+	var output []struct {
+		SQL  string
+		Plan []string
+	}
+	s.testData.GetTestCases(c, &input, &output)
+	for i, tt := range input {
+		res := tk.MustQuery("explain format = 'brief' " + tt)
+		s.testData.OnRecord(func() {
+			output[i].SQL = tt
+			output[i].Plan = s.testData.ConvertRowsToStrings(res.Rows())
+		})
+		res.Check(testkit.Rows(output[i].Plan...))
+	}
+}
diff --git a/planner/core/rule_aggregation_push_down.go b/planner/core/rule_aggregation_push_down.go
index 800915e345448..83e9c3ade6cab 100644
--- a/planner/core/rule_aggregation_push_down.go
+++ b/planner/core/rule_aggregation_push_down.go
@@ -427,22 +427,41 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) (_ LogicalPlan, e
 		} else if proj, ok1 := child.(*LogicalProjection); ok1 && p.SCtx().GetSessionVars().AllowAggPushDown {
 			// TODO: This optimization is not always reasonable. We have not supported pushing projection to kv layer yet,
 			// so we must do this optimization.
-			for i, gbyItem := range agg.GroupByItems {
-				agg.GroupByItems[i] = expression.ColumnSubstitute(gbyItem, proj.schema, proj.Exprs)
+			noSideEffects := true
+			newGbyItems := make([]expression.Expression, 0, len(agg.GroupByItems))
+			for _, gbyItem := range agg.GroupByItems {
+				newGbyItems = append(newGbyItems, expression.ColumnSubstitute(gbyItem, proj.schema, proj.Exprs))
+				if ExprsHasSideEffects(newGbyItems) {
+					noSideEffects = false
+					break
+				}
+			}
+			newAggFuncsArgs := make([][]expression.Expression, 0, len(agg.AggFuncs))
+			if noSideEffects {
+				for _, aggFunc := range agg.AggFuncs {
+					newArgs := make([]expression.Expression, 0, len(aggFunc.Args))
+					for _, arg := range aggFunc.Args {
+						newArgs = append(newArgs, expression.ColumnSubstitute(arg, proj.schema, proj.Exprs))
+					}
+					if ExprsHasSideEffects(newArgs) {
+						noSideEffects = false
+						break
+					}
+					newAggFuncsArgs = append(newAggFuncsArgs, newArgs)
+				}
 			}
-			for _, aggFunc := range agg.AggFuncs {
-				newArgs := make([]expression.Expression, 0, len(aggFunc.Args))
-				for _, arg := range aggFunc.Args {
-					newArgs = append(newArgs, expression.ColumnSubstitute(arg, proj.schema, proj.Exprs))
+			if noSideEffects {
+				agg.GroupByItems = newGbyItems
+				for i, aggFunc := range agg.AggFuncs {
+					aggFunc.Args = newAggFuncsArgs[i]
 				}
-				aggFunc.Args = newArgs
+				projChild := proj.children[0]
+				agg.SetChildren(projChild)
+				// When the original plan tree is `Aggregation->Projection->Union All->X`, we need to merge 'Aggregation' and 'Projection' first.
+				// And then push the new 'Aggregation' below the 'Union All'.
+				// The final plan tree should be 'Aggregation->Union All->Aggregation->X'.
+				child = projChild
 			}
-			projChild := proj.children[0]
-			agg.SetChildren(projChild)
-			// When the origin plan tree is `Aggregation->Projection->Union All->X`, we need to merge 'Aggregation' and 'Projection' first.
-			// And then push the new 'Aggregation' below the 'Union All' .
-			// The final plan tree should be 'Aggregation->Union All->Aggregation->X'.
-			child = projChild
 		}
 		if union, ok1 := child.(*LogicalUnionAll); ok1 && p.SCtx().GetSessionVars().AllowAggPushDown {
 			err := a.tryAggPushDownForUnion(union, agg)
diff --git a/planner/core/rule_column_pruning.go b/planner/core/rule_column_pruning.go
index 9eeb6feb2d086..2b980f86b6259 100644
--- a/planner/core/rule_column_pruning.go
+++ b/planner/core/rule_column_pruning.go
@@ -94,7 +94,7 @@ func (la *LogicalAggregation) PruneColumns(parentUsedCols []*expression.Column)
 		if la.AggFuncs[i].Name != ast.AggFuncFirstRow {
 			allFirstRow = false
 		}
-		if !used[i] {
+		if !used[i] && !ExprsHasSideEffects(la.AggFuncs[i].Args) {
 			la.schema.Columns = append(la.schema.Columns[:i], la.schema.Columns[i+1:]...)
 			la.AggFuncs = append(la.AggFuncs[:i], la.AggFuncs[i+1:]...)
 		} else if la.AggFuncs[i].Name != ast.AggFuncFirstRow {
@@ -135,7 +135,7 @@ func (la *LogicalAggregation) PruneColumns(parentUsedCols []*expression.Column)
 	if len(la.GroupByItems) > 0 {
 		for i := len(la.GroupByItems) - 1; i >= 0; i-- {
 			cols := expression.ExtractColumns(la.GroupByItems[i])
-			if len(cols) == 0 {
+			if len(cols) == 0 && !exprHasSetVarOrSleep(la.GroupByItems[i]) {
 				la.GroupByItems = append(la.GroupByItems[:i], la.GroupByItems[i+1:]...)
 			} else {
 				selfUsedCols = append(selfUsedCols, cols...)
diff --git a/planner/core/testdata/integration_suite_in.json b/planner/core/testdata/integration_suite_in.json
index e89c1ae3a00ce..1be0dd432914a 100644
--- a/planner/core/testdata/integration_suite_in.json
+++ b/planner/core/testdata/integration_suite_in.json
@@ -287,5 +287,18 @@
       "explain format = 'brief' select * from t where exists (select 1 from t t1 join t t2 where t1.a = t2.a and t1.a = t.a)",
       "explain format = 'brief' select * from t where exists (select 1 from t t1 join t t2 on t1.a = t2.a and t1.a = t.a)"
     ]
+  },
+  {
+    "name": "TestGroupBySetVar",
+    "cases": [
+      "select floor(dt.rn/2) rownum, count(c1) from (select @rownum := @rownum + 1 rn, c1 from (select @rownum := -1) drn, t1) dt group by floor(dt.rn/2) order by rownum;",
+      // TODO: fix the following two cases
+      "select @n:=@n+1 as e from ta group by e",
+      "select @n:=@n+a as e from ta group by e",
+      "select * from (select @n:=@n+1 as e from ta) tt group by e",
+      "select * from (select @n:=@n+a as e from ta) tt group by e",
+      "select a from ta group by @n:=@n+1",
+      "select a from ta group by @n:=@n+a"
+    ]
   }
 ]
diff --git a/planner/core/testdata/integration_suite_out.json b/planner/core/testdata/integration_suite_out.json
index a828e413b5d66..15cf14d68e871 100644
--- a/planner/core/testdata/integration_suite_out.json
+++ b/planner/core/testdata/integration_suite_out.json
@@ -1538,5 +1538,81 @@
         ]
       }
     ]
+  },
+  {
+    "Name": "TestGroupBySetVar",
+    "Cases": [
+      {
+        "SQL": "select floor(dt.rn/2) rownum, count(c1) from (select @rownum := @rownum + 1 rn, c1 from (select @rownum := -1) drn, t1) dt group by floor(dt.rn/2) order by rownum;",
+        "Plan": [
+          "Sort 1.00 root Column#6",
+          "└─Projection 1.00 root floor(div(cast(Column#4, decimal(20,0) BINARY), 2))->Column#6, Column#5",
+          "  └─HashAgg 1.00 root group by:Column#13, funcs:count(Column#11)->Column#5, funcs:firstrow(Column#12)->Column#4",
+          "    └─Projection 10000.00 root test.t1.c1, Column#4, floor(div(cast(Column#4, decimal(20,0) BINARY), 2))->Column#13",
+          "      └─Projection 10000.00 root setvar(rownum, plus(getvar(rownum), 1))->Column#4, test.t1.c1",
+          "        └─HashJoin 10000.00 root CARTESIAN inner join",
+          "          ├─Projection(Build) 1.00 root setvar(rownum, -1)->Column#1",
+          "          │ └─TableDual 1.00 root rows:1",
+          "          └─TableReader(Probe) 10000.00 root data:TableFullScan",
+          "            └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
+        ]
+      },
+      {
+        "SQL": "select @n:=@n+1 as e from ta group by e",
+        "Plan": [
+          "Projection 1.00 root setvar(n, plus(getvar(n), 1))->Column#4",
+          "└─HashAgg 1.00 root group by:Column#8, funcs:firstrow(1)->Column#7",
+          "  └─Projection 10000.00 root setvar(n, plus(cast(getvar(n), double BINARY), 1))->Column#8",
+          "    └─TableReader 10000.00 root data:TableFullScan",
+          "      └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo"
+        ]
+      },
+      {
+        "SQL": "select @n:=@n+a as e from ta group by e",
+        "Plan": [
+          "Projection 8000.00 root setvar(n, plus(getvar(n), cast(test.ta.a, double BINARY)))->Column#4",
+          "└─HashAgg 8000.00 root group by:Column#7, funcs:firstrow(Column#6)->test.ta.a",
+          "  └─Projection 10000.00 root test.ta.a, setvar(n, plus(getvar(n), cast(test.ta.a, double BINARY)))->Column#7",
+          "    └─TableReader 10000.00 root data:TableFullScan",
+          "      └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo"
+        ]
+      },
+      {
+        "SQL": "select * from (select @n:=@n+1 as e from ta) tt group by e",
+        "Plan": [
+          "HashAgg 1.00 root group by:Column#4, funcs:firstrow(Column#4)->Column#4",
+          "└─Projection 10000.00 root setvar(n, plus(getvar(n), 1))->Column#4",
+          "  └─TableReader 10000.00 root data:TableFullScan",
+          "    └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo"
+        ]
+      },
+      {
+        "SQL": "select * from (select @n:=@n+a as e from ta) tt group by e",
+        "Plan": [
+          "HashAgg 8000.00 root group by:Column#4, funcs:firstrow(Column#4)->Column#4",
+          "└─Projection 10000.00 root setvar(n, plus(getvar(n), cast(test.ta.a, double BINARY)))->Column#4",
+          "  └─TableReader 10000.00 root data:TableFullScan",
+          "    └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo"
+        ]
+      },
+      {
+        "SQL": "select a from ta group by @n:=@n+1",
+        "Plan": [
+          "HashAgg 1.00 root group by:Column#5, funcs:firstrow(Column#4)->test.ta.a",
+          "└─Projection 10000.00 root test.ta.a, setvar(n, plus(getvar(n), 1))->Column#5",
+          "  └─TableReader 10000.00 root data:TableFullScan",
+          "    └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo"
+        ]
+      },
+      {
+        "SQL": "select a from ta group by @n:=@n+a",
+        "Plan": [
+          "HashAgg 8000.00 root group by:Column#5, funcs:firstrow(Column#4)->test.ta.a",
+          "└─Projection 10000.00 root test.ta.a, setvar(n, plus(getvar(n), cast(test.ta.a, double BINARY)))->Column#5",
+          "  └─TableReader 10000.00 root data:TableFullScan",
+          "    └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo"
+        ]
+      }
+    ]
   }
 ]
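
Note (not part of the patch): the rule_aggregation_push_down.go change stages the results of expression.ColumnSubstitute in newGbyItems/newAggFuncsArgs and only commits them when none of the substituted expressions carries a side effect, so a projection that computes something like setvar(n, getvar(n)+1) is never merged into the aggregation and evaluated an extra time. The standalone Go sketch below illustrates that staged-substitution guard with simplified stand-in types; expr, substitute, and exprsHaveSideEffects are illustrative names only, not TiDB APIs, and the sketch is an assumption-level approximation of the idea rather than the actual planner code.

package main

import "fmt"

// expr is a simplified stand-in for a planner expression.
type expr struct {
	text          string
	hasSideEffect bool // e.g. contains a set-var or sleep call
}

// substitute stands in for column substitution: it replaces a column
// reference with the projection expression that produces it.
func substitute(e expr, proj map[string]expr) expr {
	if p, ok := proj[e.text]; ok {
		return p
	}
	return e
}

// exprsHaveSideEffects reports whether any staged expression has a side effect.
func exprsHaveSideEffects(es []expr) bool {
	for _, e := range es {
		if e.hasSideEffect {
			return true
		}
	}
	return false
}

func main() {
	// Projection below the aggregation: Column#4 := setvar(n, getvar(n)+1).
	proj := map[string]expr{
		"Column#4": {text: "setvar(n, getvar(n)+1)", hasSideEffect: true},
	}
	groupBy := []expr{{text: "Column#4"}}

	// Stage the substitution; commit only if nothing side-effecting leaked in.
	staged := make([]expr, 0, len(groupBy))
	noSideEffects := true
	for _, g := range groupBy {
		staged = append(staged, substitute(g, proj))
		if exprsHaveSideEffects(staged) {
			noSideEffects = false
			break
		}
	}
	if noSideEffects {
		groupBy = staged // safe to merge the projection into the aggregation
	}
	fmt.Println("projection merged into aggregation:", noSideEffects)
	fmt.Println("group-by items:", groupBy)
}

Running the sketch reports that the merge is skipped because the projected expression contains a set-var, which mirrors why the expected plans above keep the setvar Projection below the HashAgg instead of folding it into the group-by items.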