diff --git a/cmd/explaintest/r/explain_easy.result b/cmd/explaintest/r/explain_easy.result index 927e25c8024f2..214e51b366de1 100644 --- a/cmd/explaintest/r/explain_easy.result +++ b/cmd/explaintest/r/explain_easy.result @@ -194,31 +194,32 @@ test t4 1 expr_idx 1 NULL NULL (`a` + `b` + 1) 2 YES NO explain format = 'brief' select count(1) from (select count(1) from (select * from t1 where c3 = 100) k) k2; id estRows task access object operator info StreamAgg 1.00 root funcs:count(1)->Column#5 -└─StreamAgg 1.00 root funcs:firstrow(Column#9)->Column#7 +└─StreamAgg 1.00 root funcs:count(Column#9)->Column#7 └─TableReader 1.00 root data:StreamAgg - └─StreamAgg 1.00 cop[tikv] funcs:firstrow(1)->Column#9 + └─StreamAgg 1.00 cop[tikv] funcs:count(1)->Column#9 └─Selection 10.00 cop[tikv] eq(test.t1.c3, 100) └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo explain format = 'brief' select 1 from (select count(c2), count(c3) from t1) k; id estRows task access object operator info Projection 1.00 root 1->Column#6 -└─StreamAgg 1.00 root funcs:firstrow(Column#14)->Column#9 +└─StreamAgg 1.00 root funcs:count(Column#14)->Column#9 └─TableReader 1.00 root data:StreamAgg - └─StreamAgg 1.00 cop[tikv] funcs:firstrow(1)->Column#14 + └─StreamAgg 1.00 cop[tikv] funcs:count(1)->Column#14 └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo explain format = 'brief' select count(1) from (select max(c2), count(c3) as m from t1) k; id estRows task access object operator info StreamAgg 1.00 root funcs:count(1)->Column#6 -└─StreamAgg 1.00 root funcs:firstrow(Column#13)->Column#8 +└─StreamAgg 1.00 root funcs:count(Column#13)->Column#8 └─TableReader 1.00 root data:StreamAgg - └─StreamAgg 1.00 cop[tikv] funcs:firstrow(1)->Column#13 + └─StreamAgg 1.00 cop[tikv] funcs:count(1)->Column#13 └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo explain format = 'brief' select count(1) from (select count(c2) from t1 group by c3) k; id estRows task access object operator info StreamAgg 1.00 root funcs:count(1)->Column#5 -└─HashAgg 8000.00 root group by:test.t1.c3, funcs:firstrow(1)->Column#7 - └─TableReader 10000.00 root data:TableFullScan - └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo +└─HashAgg 8000.00 root group by:test.t1.c3, funcs:count(Column#9)->Column#7 + └─TableReader 8000.00 root data:HashAgg + └─HashAgg 8000.00 cop[tikv] group by:test.t1.c3, funcs:count(1)->Column#9 + └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo set @@session.tidb_opt_insubq_to_join_and_agg=0; explain format = 'brief' select sum(t1.c1 in (select c1 from t2)) from t1; id estRows task access object operator info @@ -498,7 +499,7 @@ PRIMARY KEY (`id`) explain format = 'brief' SELECT COUNT(1) FROM (SELECT COALESCE(b.region_name, '不详') region_name, SUM(a.registration_num) registration_num FROM (SELECT stat_date, show_date, region_id, 0 registration_num FROM test01 WHERE period = 1 AND stat_date >= 20191202 AND stat_date <= 20191202 UNION ALL SELECT stat_date, show_date, region_id, registration_num registration_num FROM test01 WHERE period = 1 AND stat_date >= 20191202 AND stat_date <= 20191202) a LEFT JOIN test02 b ON a.region_id = b.id WHERE registration_num > 0 AND a.stat_date >= '20191202' AND a.stat_date <= '20191202' GROUP BY a.stat_date , a.show_date , COALESCE(b.region_name, '不详') ) JLS; id estRows task access object operator info StreamAgg 1.00 root funcs:count(1)->Column#22 -└─HashAgg 8000.00 root group by:Column#32, Column#33, Column#34, funcs:firstrow(1)->Column#31 +└─HashAgg 8000.00 root group by:Column#32, Column#33, Column#34, funcs:count(1)->Column#31 └─Projection 10000.01 root Column#14, Column#15, coalesce(test.test02.region_name, 不详)->Column#34 └─HashJoin 10000.01 root left outer join, equal:[eq(Column#16, test.test02.id)] ├─TableReader(Build) 10000.00 root data:TableFullScan diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index c691e5341b0a5..bbfe52a0ee710 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -180,6 +180,29 @@ func (s *testIntegrationSuite) TestPushLimitDownIndexLookUpReader(c *C) { } } +func (s *testIntegrationSuite) TestAggColumnPrune(c *C) { + tk := testkit.NewTestKit(c, s.store) + + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int)") + tk.MustExec("insert into t values(1),(2)") + + var input []string + var output []struct { + SQL string + Res []string + } + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Res = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) + }) + tk.MustQuery(tt).Check(testkit.Rows(output[i].Res...)) + } +} + func (s *testIntegrationSuite) TestIsFromUnixtimeNullRejective(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/planner/core/rule_column_pruning.go b/planner/core/rule_column_pruning.go index 4b31853c138d0..8a627792ecc7f 100644 --- a/planner/core/rule_column_pruning.go +++ b/planner/core/rule_column_pruning.go @@ -88,7 +88,11 @@ func (la *LogicalAggregation) PruneColumns(parentUsedCols []*expression.Column) child := la.children[0] used := expression.GetUsedList(parentUsedCols, la.Schema()) + allFirstRow := true for i := len(used) - 1; i >= 0; i-- { + if la.AggFuncs[i].Name != ast.AggFuncFirstRow { + allFirstRow = false + } if !used[i] { la.schema.Columns = append(la.schema.Columns[:i], la.schema.Columns[i+1:]...) la.AggFuncs = append(la.AggFuncs[:i], la.AggFuncs[i+1:]...) @@ -103,15 +107,24 @@ func (la *LogicalAggregation) PruneColumns(parentUsedCols []*expression.Column) selfUsedCols = append(selfUsedCols, cols...) } if len(la.AggFuncs) == 0 { - // If all the aggregate functions are pruned, we should add an aggregate function to keep the correctness. - one, err := aggregation.NewAggFuncDesc(la.ctx, ast.AggFuncFirstRow, []expression.Expression{expression.NewOne()}, false) + // If all the aggregate functions are pruned, we should add an aggregate function to maintain the info of row numbers. + // For all the aggregate functions except `first_row`, if we have an empty table defined as t(a,b), + // `select agg(a) from t` would always return one row, while `select agg(a) from t group by b` would return empty. + // For `first_row` which is only used internally by tidb, `first_row(a)` would always return empty for empty input now. + var err error + var newAgg *aggregation.AggFuncDesc + if allFirstRow { + newAgg, err = aggregation.NewAggFuncDesc(la.ctx, ast.AggFuncFirstRow, []expression.Expression{expression.NewOne()}, false) + } else { + newAgg, err = aggregation.NewAggFuncDesc(la.ctx, ast.AggFuncCount, []expression.Expression{expression.NewOne()}, false) + } if err != nil { return err } - la.AggFuncs = []*aggregation.AggFuncDesc{one} + la.AggFuncs = []*aggregation.AggFuncDesc{newAgg} col := &expression.Column{ UniqueID: la.ctx.GetSessionVars().AllocPlanColumnID(), - RetType: one.RetTp, + RetType: newAgg.RetTp, } la.schema.Columns = []*expression.Column{col} } diff --git a/planner/core/testdata/integration_suite_in.json b/planner/core/testdata/integration_suite_in.json index 087b32110e18f..63b866ad3badd 100644 --- a/planner/core/testdata/integration_suite_in.json +++ b/planner/core/testdata/integration_suite_in.json @@ -19,6 +19,21 @@ "explain format = 'brief' select * from t t1 left join t t2 on t1.a=t2.a where from_unixtime(t2.b);" ] }, + { + "name": "TestAggColumnPrune", + "cases": [ + "select count(1) from t join (select count(1) from t where false) as tmp", + "select count(1) from t join (select max(a) from t where false) as tmp", + "select count(1) from t join (select min(a) from t where false) as tmp", + "select count(1) from t join (select sum(a) from t where false) as tmp", + "select count(1) from t join (select avg(a) from t where false) as tmp", + "select count(1) from t join (select count(1) from t where false group by a) as tmp", + "select count(1) from t join (select max(a) from t where false group by a) as tmp", + "select count(1) from t join (select min(a) from t where false group by a) as tmp", + "select count(1) from t join (select sum(a) from t where false group by a) as tmp", + "select count(1) from t join (select avg(a) from t where false group by a) as tmp" + ] + }, { "name": "TestIndexJoinInnerIndexNDV", "cases": [ diff --git a/planner/core/testdata/integration_suite_out.json b/planner/core/testdata/integration_suite_out.json index 7c735fcb5657c..77aa5b1494da7 100644 --- a/planner/core/testdata/integration_suite_out.json +++ b/planner/core/testdata/integration_suite_out.json @@ -63,6 +63,71 @@ } ] }, + { + "Name": "TestAggColumnPrune", + "Cases": [ + { + "SQL": "select count(1) from t join (select count(1) from t where false) as tmp", + "Res": [ + "2" + ] + }, + { + "SQL": "select count(1) from t join (select max(a) from t where false) as tmp", + "Res": [ + "2" + ] + }, + { + "SQL": "select count(1) from t join (select min(a) from t where false) as tmp", + "Res": [ + "2" + ] + }, + { + "SQL": "select count(1) from t join (select sum(a) from t where false) as tmp", + "Res": [ + "2" + ] + }, + { + "SQL": "select count(1) from t join (select avg(a) from t where false) as tmp", + "Res": [ + "2" + ] + }, + { + "SQL": "select count(1) from t join (select count(1) from t where false group by a) as tmp", + "Res": [ + "0" + ] + }, + { + "SQL": "select count(1) from t join (select max(a) from t where false group by a) as tmp", + "Res": [ + "0" + ] + }, + { + "SQL": "select count(1) from t join (select min(a) from t where false group by a) as tmp", + "Res": [ + "0" + ] + }, + { + "SQL": "select count(1) from t join (select sum(a) from t where false group by a) as tmp", + "Res": [ + "0" + ] + }, + { + "SQL": "select count(1) from t join (select avg(a) from t where false group by a) as tmp", + "Res": [ + "0" + ] + } + ] + }, { "Name": "TestIndexJoinInnerIndexNDV", "Cases": [