Skip to content

Commit

Permalink
planner: generate correct number of rows when all agg funcs are pruned (
Browse files Browse the repository at this point in the history
  • Loading branch information
eurekaka committed Jun 2, 2021
1 parent 1f79bfe commit ad7102c
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 14 deletions.
21 changes: 11 additions & 10 deletions cmd/explaintest/r/explain_easy.result
Original file line number Diff line number Diff line change
Expand Up @@ -194,31 +194,32 @@ test t4 1 expr_idx 1 NULL NULL (`a` + `b` + 1) 2 YES NO
explain format = 'brief' select count(1) from (select count(1) from (select * from t1 where c3 = 100) k) k2;
id estRows task access object operator info
StreamAgg 1.00 root funcs:count(1)->Column#5
└─StreamAgg 1.00 root funcs:firstrow(Column#9)->Column#7
└─StreamAgg 1.00 root funcs:count(Column#9)->Column#7
└─TableReader 1.00 root data:StreamAgg
└─StreamAgg 1.00 cop[tikv] funcs:firstrow(1)->Column#9
└─StreamAgg 1.00 cop[tikv] funcs:count(1)->Column#9
└─Selection 10.00 cop[tikv] eq(test.t1.c3, 100)
└─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo
explain format = 'brief' select 1 from (select count(c2), count(c3) from t1) k;
id estRows task access object operator info
Projection 1.00 root 1->Column#6
└─StreamAgg 1.00 root funcs:firstrow(Column#14)->Column#9
└─StreamAgg 1.00 root funcs:count(Column#14)->Column#9
└─TableReader 1.00 root data:StreamAgg
└─StreamAgg 1.00 cop[tikv] funcs:firstrow(1)->Column#14
└─StreamAgg 1.00 cop[tikv] funcs:count(1)->Column#14
└─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo
explain format = 'brief' select count(1) from (select max(c2), count(c3) as m from t1) k;
id estRows task access object operator info
StreamAgg 1.00 root funcs:count(1)->Column#6
└─StreamAgg 1.00 root funcs:firstrow(Column#13)->Column#8
└─StreamAgg 1.00 root funcs:count(Column#13)->Column#8
└─TableReader 1.00 root data:StreamAgg
└─StreamAgg 1.00 cop[tikv] funcs:firstrow(1)->Column#13
└─StreamAgg 1.00 cop[tikv] funcs:count(1)->Column#13
└─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo
explain format = 'brief' select count(1) from (select count(c2) from t1 group by c3) k;
id estRows task access object operator info
StreamAgg 1.00 root funcs:count(1)->Column#5
└─HashAgg 8000.00 root group by:test.t1.c3, funcs:firstrow(1)->Column#7
└─TableReader 10000.00 root data:TableFullScan
└─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo
└─HashAgg 8000.00 root group by:test.t1.c3, funcs:count(Column#9)->Column#7
└─TableReader 8000.00 root data:HashAgg
└─HashAgg 8000.00 cop[tikv] group by:test.t1.c3, funcs:count(1)->Column#9
└─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo
set @@session.tidb_opt_insubq_to_join_and_agg=0;
explain format = 'brief' select sum(t1.c1 in (select c1 from t2)) from t1;
id estRows task access object operator info
Expand Down Expand Up @@ -498,7 +499,7 @@ PRIMARY KEY (`id`)
explain format = 'brief' SELECT COUNT(1) FROM (SELECT COALESCE(b.region_name, '不详') region_name, SUM(a.registration_num) registration_num FROM (SELECT stat_date, show_date, region_id, 0 registration_num FROM test01 WHERE period = 1 AND stat_date >= 20191202 AND stat_date <= 20191202 UNION ALL SELECT stat_date, show_date, region_id, registration_num registration_num FROM test01 WHERE period = 1 AND stat_date >= 20191202 AND stat_date <= 20191202) a LEFT JOIN test02 b ON a.region_id = b.id WHERE registration_num > 0 AND a.stat_date >= '20191202' AND a.stat_date <= '20191202' GROUP BY a.stat_date , a.show_date , COALESCE(b.region_name, '不详') ) JLS;
id estRows task access object operator info
StreamAgg 1.00 root funcs:count(1)->Column#22
└─HashAgg 8000.00 root group by:Column#32, Column#33, Column#34, funcs:firstrow(1)->Column#31
└─HashAgg 8000.00 root group by:Column#32, Column#33, Column#34, funcs:count(1)->Column#31
└─Projection 10000.01 root Column#14, Column#15, coalesce(test.test02.region_name, 不详)->Column#34
└─HashJoin 10000.01 root left outer join, equal:[eq(Column#16, test.test02.id)]
├─TableReader(Build) 10000.00 root data:TableFullScan
Expand Down
23 changes: 23 additions & 0 deletions planner/core/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,29 @@ func (s *testIntegrationSuite) TestPushLimitDownIndexLookUpReader(c *C) {
}
}

// TestAggColumnPrune runs the queries registered for this test in the suite's
// test data against a two-row table `t(a int)` and checks each result set
// against the stored expectation.
func (s *testIntegrationSuite) TestAggColumnPrune(c *C) {
	tk := testkit.NewTestKit(c, s.store)

	// Prepare the fixture: a fresh table `t` holding the rows 1 and 2.
	for _, stmt := range []string{
		"use test",
		"drop table if exists t",
		"create table t(a int)",
		"insert into t values(1),(2)",
	} {
		tk.MustExec(stmt)
	}

	var (
		input  []string
		output []struct {
			SQL string
			Res []string
		}
	)
	s.testData.GetTestCases(c, &input, &output)

	for idx := range input {
		query := input[idx]
		// NOTE(review): OnRecord presumably invokes the callback only when the
		// suite is regenerating expected output — confirm against testData.
		s.testData.OnRecord(func() {
			output[idx].SQL = query
			output[idx].Res = s.testData.ConvertRowsToStrings(tk.MustQuery(query).Rows())
		})
		// Execute the query and compare its rows with the expected ones.
		tk.MustQuery(query).Check(testkit.Rows(output[idx].Res...))
	}
}

func (s *testIntegrationSuite) TestIsFromUnixtimeNullRejective(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down
21 changes: 17 additions & 4 deletions planner/core/rule_column_pruning.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,11 @@ func (la *LogicalAggregation) PruneColumns(parentUsedCols []*expression.Column)
child := la.children[0]
used := expression.GetUsedList(parentUsedCols, la.Schema())

allFirstRow := true
for i := len(used) - 1; i >= 0; i-- {
if la.AggFuncs[i].Name != ast.AggFuncFirstRow {
allFirstRow = false
}
if !used[i] {
la.schema.Columns = append(la.schema.Columns[:i], la.schema.Columns[i+1:]...)
la.AggFuncs = append(la.AggFuncs[:i], la.AggFuncs[i+1:]...)
Expand All @@ -103,15 +107,24 @@ func (la *LogicalAggregation) PruneColumns(parentUsedCols []*expression.Column)
selfUsedCols = append(selfUsedCols, cols...)
}
if len(la.AggFuncs) == 0 {
// If all the aggregate functions are pruned, we should add an aggregate function to keep the correctness.
one, err := aggregation.NewAggFuncDesc(la.ctx, ast.AggFuncFirstRow, []expression.Expression{expression.NewOne()}, false)
// If all the aggregate functions are pruned, we should add an aggregate function to preserve the number of output rows.
// For every aggregate function except `first_row`, given an empty table defined as t(a,b),
// `select agg(a) from t` always returns one row, while `select agg(a) from t group by b` returns no rows.
// `first_row`, which is only used internally by TiDB, currently always returns no rows for empty input.
var err error
var newAgg *aggregation.AggFuncDesc
if allFirstRow {
newAgg, err = aggregation.NewAggFuncDesc(la.ctx, ast.AggFuncFirstRow, []expression.Expression{expression.NewOne()}, false)
} else {
newAgg, err = aggregation.NewAggFuncDesc(la.ctx, ast.AggFuncCount, []expression.Expression{expression.NewOne()}, false)
}
if err != nil {
return err
}
la.AggFuncs = []*aggregation.AggFuncDesc{one}
la.AggFuncs = []*aggregation.AggFuncDesc{newAgg}
col := &expression.Column{
UniqueID: la.ctx.GetSessionVars().AllocPlanColumnID(),
RetType: one.RetTp,
RetType: newAgg.RetTp,
}
la.schema.Columns = []*expression.Column{col}
}
Expand Down
15 changes: 15 additions & 0 deletions planner/core/testdata/integration_suite_in.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,21 @@
"explain format = 'brief' select * from t t1 left join t t2 on t1.a=t2.a where from_unixtime(t2.b);"
]
},
{
"name": "TestAggColumnPrune",
"cases": [
"select count(1) from t join (select count(1) from t where false) as tmp",
"select count(1) from t join (select max(a) from t where false) as tmp",
"select count(1) from t join (select min(a) from t where false) as tmp",
"select count(1) from t join (select sum(a) from t where false) as tmp",
"select count(1) from t join (select avg(a) from t where false) as tmp",
"select count(1) from t join (select count(1) from t where false group by a) as tmp",
"select count(1) from t join (select max(a) from t where false group by a) as tmp",
"select count(1) from t join (select min(a) from t where false group by a) as tmp",
"select count(1) from t join (select sum(a) from t where false group by a) as tmp",
"select count(1) from t join (select avg(a) from t where false group by a) as tmp"
]
},
{
"name": "TestIndexJoinInnerIndexNDV",
"cases": [
Expand Down
65 changes: 65 additions & 0 deletions planner/core/testdata/integration_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,71 @@
}
]
},
{
"Name": "TestAggColumnPrune",
"Cases": [
{
"SQL": "select count(1) from t join (select count(1) from t where false) as tmp",
"Res": [
"2"
]
},
{
"SQL": "select count(1) from t join (select max(a) from t where false) as tmp",
"Res": [
"2"
]
},
{
"SQL": "select count(1) from t join (select min(a) from t where false) as tmp",
"Res": [
"2"
]
},
{
"SQL": "select count(1) from t join (select sum(a) from t where false) as tmp",
"Res": [
"2"
]
},
{
"SQL": "select count(1) from t join (select avg(a) from t where false) as tmp",
"Res": [
"2"
]
},
{
"SQL": "select count(1) from t join (select count(1) from t where false group by a) as tmp",
"Res": [
"0"
]
},
{
"SQL": "select count(1) from t join (select max(a) from t where false group by a) as tmp",
"Res": [
"0"
]
},
{
"SQL": "select count(1) from t join (select min(a) from t where false group by a) as tmp",
"Res": [
"0"
]
},
{
"SQL": "select count(1) from t join (select sum(a) from t where false group by a) as tmp",
"Res": [
"0"
]
},
{
"SQL": "select count(1) from t join (select avg(a) from t where false group by a) as tmp",
"Res": [
"0"
]
}
]
},
{
"Name": "TestIndexJoinInnerIndexNDV",
"Cases": [
Expand Down

0 comments on commit ad7102c

Please sign in to comment.