From 4b2951de79d6e47693fbb7ebae299b734b5e7604 Mon Sep 17 00:00:00 2001 From: Yiding Cui Date: Wed, 12 Sep 2018 17:36:00 +0800 Subject: [PATCH 01/16] plan: split `aggPrune` out of `aggPushDown` --- plan/logical_plan_builder.go | 1 + plan/logical_plan_test.go | 12 ++- plan/optimizer.go | 4 +- plan/rule_aggregation_elimination.go | 129 +++++++++++++++++++++++++++ plan/rule_aggregation_push_down.go | 115 ++++-------------------- 5 files changed, 155 insertions(+), 106 deletions(-) create mode 100644 plan/rule_aggregation_elimination.go diff --git a/plan/logical_plan_builder.go b/plan/logical_plan_builder.go index 1d1711fd331ed..03ccd993dd337 100644 --- a/plan/logical_plan_builder.go +++ b/plan/logical_plan_builder.go @@ -75,6 +75,7 @@ func (b *planBuilder) buildAggregation(p LogicalPlan, aggFuncList []*ast.Aggrega b.optFlag = b.optFlag | flagPushDownTopN // when we eliminate the max and min we may add `is not null` filter. b.optFlag = b.optFlag | flagPredicatePushDown + b.optFlag = b.optFlag | flagEliminateAgg plan4Agg := LogicalAggregation{AggFuncs: make([]*aggregation.AggFuncDesc, 0, len(aggFuncList))}.init(b.ctx) schema4Agg := expression.NewSchema(make([]*expression.Column, 0, len(aggFuncList)+p.Schema().Len())...) diff --git a/plan/logical_plan_test.go b/plan/logical_plan_test.go index 2f70772b0544f..743359a6d42fe 100644 --- a/plan/logical_plan_test.go +++ b/plan/logical_plan_test.go @@ -875,6 +875,10 @@ func (s *testPlanSuite) TestEagerAggregation(c *C) { sql: "select max(a.c) from t a join t b on a.a=b.a and a.b=b.b group by a.b", best: "Join{DataScan(a)->DataScan(b)}(a.a,b.a)(a.b,b.b)->Aggr(max(a.c))->Projection", }, + { + sql: "select t1.a, count(t2.b) from t t1, t t2 where t1.a = t2.a group by t1.a", + best: "Join{DataScan(t1)->DataScan(t2)}(t1.a,t2.a)->Projection->Projection", + }, } s.ctx.GetSessionVars().AllowAggPushDown = true for _, tt := range tests { @@ -1315,10 +1319,6 @@ func (s *testPlanSuite) TestAggPrune(c *C) { sql: "select sum(b) from t group by c, d, e", best: "DataScan(t)->Aggr(sum(test.t.b))->Projection", }, - { - sql: "select t1.a, count(t2.b) from t t1, t t2 where t1.a = t2.a group by t1.a", - best: "Join{DataScan(t1)->DataScan(t2)}(t1.a,t2.a)->Projection->Projection", - }, { sql: "select tt.a, sum(tt.b) from (select a, b from t) tt group by tt.a", best: "DataScan(t)->Projection->Projection->Projection", @@ -1328,7 +1328,6 @@ func (s *testPlanSuite) TestAggPrune(c *C) { best: "DataScan(t)->Projection->Projection->Projection->Projection", }, } - s.ctx.GetSessionVars().AllowAggPushDown = true for _, tt := range tests { comment := Commentf("for %s", tt.sql) stmt, err := s.ParseOneStmt(tt.sql, "", "") @@ -1337,11 +1336,10 @@ func (s *testPlanSuite) TestAggPrune(c *C) { p, err := BuildLogicalPlan(s.ctx, stmt, s.is) c.Assert(err, IsNil) - p, err = logicalOptimize(flagPredicatePushDown|flagPrunColumns|flagBuildKeyInfo|flagAggregationOptimize, p.(LogicalPlan)) + p, err = logicalOptimize(flagPredicatePushDown|flagPrunColumns|flagBuildKeyInfo|flagEliminateAgg, p.(LogicalPlan)) c.Assert(err, IsNil) c.Assert(ToString(p), Equals, tt.best, comment) } - s.ctx.GetSessionVars().AllowAggPushDown = false } func (s *testPlanSuite) TestVisitInfo(c *C) { diff --git a/plan/optimizer.go b/plan/optimizer.go index 4a1c2b5b7b5b3..d45046d451284 100644 --- a/plan/optimizer.go +++ b/plan/optimizer.go @@ -32,6 +32,7 @@ const ( flagPrunColumns uint64 = 1 << iota flagEliminateProjection flagBuildKeyInfo + flagEliminateAgg flagDecorrelate flagMaxMinEliminate flagPredicatePushDown @@ -44,11 +45,12 @@ 
var optRuleList = []logicalOptRule{ &columnPruner{}, &projectionEliminater{}, &buildKeySolver{}, + &aggregationRecursiveEliminater{}, &decorrelateSolver{}, &maxMinEliminator{}, &ppdSolver{}, &partitionProcessor{}, - &aggregationOptimizer{}, + &aggregationPushDownSolver{}, &pushDownTopNOptimizer{}, } diff --git a/plan/rule_aggregation_elimination.go b/plan/rule_aggregation_elimination.go new file mode 100644 index 0000000000000..9ba17f83903ba --- /dev/null +++ b/plan/rule_aggregation_elimination.go @@ -0,0 +1,129 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package plan + +import ( + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/expression/aggregation" + "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/types" +) + +type aggregationRecursiveEliminater struct { + aggregationEliminateChecker +} + +type aggregationEliminateChecker struct { +} + +// tryToEliminateAggregation will eliminate aggregation grouped by unique key. +// e.g. select min(b) from t group by a. If a is a unique key, then this sql is equal to `select b from t group by a`. +// For count(expr), sum(expr), avg(expr), count(distinct expr, [expr...]) we may need to rewrite the expr. Details are shown below. +// If we can eliminate agg successful, we return a projection. Else we return a nil pointer. +func (a *aggregationEliminateChecker) tryToEliminateAggregation(agg *LogicalAggregation) *LogicalProjection { + schemaByGroupby := expression.NewSchema(agg.groupByCols...) + coveredByUniqueKey := false + for _, key := range agg.children[0].Schema().Keys { + if schemaByGroupby.ColumnsIndices(key) != nil { + coveredByUniqueKey = true + break + } + } + if coveredByUniqueKey { + // GroupByCols has unique key, so this aggregation can be removed. + proj := a.convertAggToProj(agg) + proj.SetChildren(agg.children[0]) + return proj + } + return nil +} + +func (a *aggregationEliminateChecker) convertAggToProj(agg *LogicalAggregation) *LogicalProjection { + proj := LogicalProjection{ + Exprs: make([]expression.Expression, 0, len(agg.AggFuncs)), + }.init(agg.ctx) + for _, fun := range agg.AggFuncs { + expr := a.rewriteExpr(agg.ctx, fun) + proj.Exprs = append(proj.Exprs, expr) + } + proj.SetSchema(agg.schema.Clone()) + return proj +} + +// rewriteExpr will rewrite the aggregate function to expression doesn't contain aggregate function. +func (a *aggregationEliminateChecker) rewriteExpr(ctx sessionctx.Context, aggFunc *aggregation.AggFuncDesc) expression.Expression { + switch aggFunc.Name { + case ast.AggFuncCount: + if aggFunc.Mode == aggregation.FinalMode { + return a.rewriteSumOrAvg(ctx, aggFunc.Args) + } + return a.rewriteCount(ctx, aggFunc.Args) + case ast.AggFuncSum, ast.AggFuncAvg: + return a.rewriteSumOrAvg(ctx, aggFunc.Args) + default: + // Default we do nothing about expr. 
+ return aggFunc.Args[0] + } +} + +func (a *aggregationEliminateChecker) rewriteCount(ctx sessionctx.Context, exprs []expression.Expression) expression.Expression { + // If is count(expr), we will change it to if(isnull(expr), 0, 1). + // If is count(distinct x, y, z) we will change it to if(isnull(x) or isnull(y) or isnull(z), 0, 1). + isNullExprs := make([]expression.Expression, 0, len(exprs)) + for _, expr := range exprs { + isNullExpr := expression.NewFunctionInternal(ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), expr) + isNullExprs = append(isNullExprs, isNullExpr) + } + innerExpr := expression.ComposeDNFCondition(ctx, isNullExprs...) + newExpr := expression.NewFunctionInternal(ctx, ast.If, types.NewFieldType(mysql.TypeLonglong), innerExpr, expression.Zero, expression.One) + return newExpr +} + +// See https://dev.mysql.com/doc/refman/5.7/en/group-by-functions.html +// The SUM() and AVG() functions return a DECIMAL value for exact-value arguments (integer or DECIMAL), +// and a DOUBLE value for approximate-value arguments (FLOAT or DOUBLE). +func (a *aggregationEliminateChecker) rewriteSumOrAvg(ctx sessionctx.Context, exprs []expression.Expression) expression.Expression { + // FIXME: Consider the case that avg is final mode. + expr := exprs[0] + switch expr.GetType().Tp { + // Integer type should be cast to decimal. + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: + return expression.BuildCastFunction(ctx, expr, types.NewFieldType(mysql.TypeNewDecimal)) + // Double and Decimal doesn't need to be cast. + case mysql.TypeDouble, mysql.TypeNewDecimal: + return expr + // Float should be cast to double. And other non-numeric type should be cast to double too. + default: + return expression.BuildCastFunction(ctx, expr, types.NewFieldType(mysql.TypeDouble)) + } +} + +func (a *aggregationRecursiveEliminater) optimize(p LogicalPlan) (LogicalPlan, error) { + newChildren := make([]LogicalPlan, 0, len(p.Children())) + for _, child := range p.Children() { + newChild, _ := a.optimize(child) + newChildren = append(newChildren, newChild) + } + p.SetChildren(newChildren...) + agg, ok := p.(*LogicalAggregation) + if !ok { + return p, nil + } + if proj := a.tryToEliminateAggregation(agg); proj != nil { + return proj, nil + } + return p, nil +} diff --git a/plan/rule_aggregation_push_down.go b/plan/rule_aggregation_push_down.go index 3089f1d5c9cfd..174048f30f693 100644 --- a/plan/rule_aggregation_push_down.go +++ b/plan/rule_aggregation_push_down.go @@ -24,7 +24,8 @@ import ( "github.com/pingcap/tidb/types" ) -type aggregationOptimizer struct { +type aggregationPushDownSolver struct { + aggregationEliminateChecker } // isDecomposable checks if an aggregate function is decomposable. An aggregation function $F$ is decomposable @@ -33,7 +34,7 @@ type aggregationOptimizer struct { // It's easy to see that max, min, first row is decomposable, no matter whether it's distinct, but sum(distinct) and // count(distinct) is not. // Currently we don't support avg and concat. -func (a *aggregationOptimizer) isDecomposable(fun *aggregation.AggFuncDesc) bool { +func (a *aggregationPushDownSolver) isDecomposable(fun *aggregation.AggFuncDesc) bool { switch fun.Name { case ast.AggFuncAvg, ast.AggFuncGroupConcat: // TODO: Support avg push down. @@ -48,7 +49,7 @@ func (a *aggregationOptimizer) isDecomposable(fun *aggregation.AggFuncDesc) bool } // getAggFuncChildIdx gets which children it belongs to, 0 stands for left, 1 stands for right, -1 stands for both. 
-func (a *aggregationOptimizer) getAggFuncChildIdx(aggFunc *aggregation.AggFuncDesc, schema *expression.Schema) int { +func (a *aggregationPushDownSolver) getAggFuncChildIdx(aggFunc *aggregation.AggFuncDesc, schema *expression.Schema) int { fromLeft, fromRight := false, false var cols []*expression.Column cols = expression.ExtractColumnsFromExpressions(cols, aggFunc.Args, nil) @@ -70,7 +71,7 @@ func (a *aggregationOptimizer) getAggFuncChildIdx(aggFunc *aggregation.AggFuncDe // collectAggFuncs collects all aggregate functions and splits them into two parts: "leftAggFuncs" and "rightAggFuncs" whose // arguments are all from left child or right child separately. If some aggregate functions have the arguments that have // columns both from left and right children, the whole aggregation is forbidden to push down. -func (a *aggregationOptimizer) collectAggFuncs(agg *LogicalAggregation, join *LogicalJoin) (valid bool, leftAggFuncs, rightAggFuncs []*aggregation.AggFuncDesc) { +func (a *aggregationPushDownSolver) collectAggFuncs(agg *LogicalAggregation, join *LogicalJoin) (valid bool, leftAggFuncs, rightAggFuncs []*aggregation.AggFuncDesc) { valid = true leftChild := join.children[0] for _, aggFunc := range agg.AggFuncs { @@ -95,7 +96,7 @@ func (a *aggregationOptimizer) collectAggFuncs(agg *LogicalAggregation, join *Lo // query should be "SELECT SUM(B.agg) FROM A, (SELECT SUM(id) as agg, c1, c2, c3 FROM B GROUP BY id, c1, c2, c3) as B // WHERE A.c1 = B.c1 AND A.c2 != B.c2 GROUP BY B.c3". As you see, all the columns appearing in join-conditions should be // treated as group by columns in join subquery. -func (a *aggregationOptimizer) collectGbyCols(agg *LogicalAggregation, join *LogicalJoin) (leftGbyCols, rightGbyCols []*expression.Column) { +func (a *aggregationPushDownSolver) collectGbyCols(agg *LogicalAggregation, join *LogicalJoin) (leftGbyCols, rightGbyCols []*expression.Column) { leftChild := join.children[0] ctx := agg.ctx for _, gbyExpr := range agg.GroupByItems { @@ -134,7 +135,7 @@ func (a *aggregationOptimizer) collectGbyCols(agg *LogicalAggregation, join *Log return } -func (a *aggregationOptimizer) splitAggFuncsAndGbyCols(agg *LogicalAggregation, join *LogicalJoin) (valid bool, +func (a *aggregationPushDownSolver) splitAggFuncsAndGbyCols(agg *LogicalAggregation, join *LogicalJoin) (valid bool, leftAggFuncs, rightAggFuncs []*aggregation.AggFuncDesc, leftGbyCols, rightGbyCols []*expression.Column) { valid, leftAggFuncs, rightAggFuncs = a.collectAggFuncs(agg, join) @@ -146,7 +147,7 @@ func (a *aggregationOptimizer) splitAggFuncsAndGbyCols(agg *LogicalAggregation, } // addGbyCol adds a column to gbyCols. If a group by column has existed, it will not be added repeatedly. -func (a *aggregationOptimizer) addGbyCol(ctx sessionctx.Context, gbyCols []*expression.Column, cols ...*expression.Column) []*expression.Column { +func (a *aggregationPushDownSolver) addGbyCol(ctx sessionctx.Context, gbyCols []*expression.Column, cols ...*expression.Column) []*expression.Column { for _, c := range cols { duplicate := false for _, gbyCol := range gbyCols { @@ -163,13 +164,13 @@ func (a *aggregationOptimizer) addGbyCol(ctx sessionctx.Context, gbyCols []*expr } // checkValidJoin checks if this join should be pushed across. 
-func (a *aggregationOptimizer) checkValidJoin(join *LogicalJoin) bool { +func (a *aggregationPushDownSolver) checkValidJoin(join *LogicalJoin) bool { return join.JoinType == InnerJoin || join.JoinType == LeftOuterJoin || join.JoinType == RightOuterJoin } // decompose splits an aggregate function to two parts: a final mode function and a partial mode function. Currently // there are no differences between partial mode and complete mode, so we can confuse them. -func (a *aggregationOptimizer) decompose(ctx sessionctx.Context, aggFunc *aggregation.AggFuncDesc, schema *expression.Schema) ([]*aggregation.AggFuncDesc, *expression.Schema) { +func (a *aggregationPushDownSolver) decompose(ctx sessionctx.Context, aggFunc *aggregation.AggFuncDesc, schema *expression.Schema) ([]*aggregation.AggFuncDesc, *expression.Schema) { // Result is a slice because avg should be decomposed to sum and count. Currently we don't process this case. result := []*aggregation.AggFuncDesc{aggFunc.Clone()} for _, aggFunc := range result { @@ -187,7 +188,7 @@ func (a *aggregationOptimizer) decompose(ctx sessionctx.Context, aggFunc *aggreg // tryToPushDownAgg tries to push down an aggregate function into a join path. If all aggFuncs are first row, we won't // process it temporarily. If not, We will add additional group by columns and first row functions. We make a new aggregation operator. // If the pushed aggregation is grouped by unique key, it's no need to push it down. -func (a *aggregationOptimizer) tryToPushDownAgg(aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, join *LogicalJoin, childIdx int) LogicalPlan { +func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, join *LogicalJoin, childIdx int) LogicalPlan { child := join.children[childIdx] if aggregation.IsAllFirstRow(aggFuncs) { return child @@ -221,7 +222,7 @@ func (a *aggregationOptimizer) tryToPushDownAgg(aggFuncs []*aggregation.AggFuncD return agg } -func (a *aggregationOptimizer) getDefaultValues(agg *LogicalAggregation) ([]types.Datum, bool) { +func (a *aggregationPushDownSolver) getDefaultValues(agg *LogicalAggregation) ([]types.Datum, bool) { defaultValues := make([]types.Datum, 0, agg.Schema().Len()) for _, aggFunc := range agg.AggFuncs { value, existsDefaultValue := aggFunc.EvalNullValueInOuterJoin(agg.ctx, agg.children[0].Schema()) @@ -233,7 +234,7 @@ func (a *aggregationOptimizer) getDefaultValues(agg *LogicalAggregation) ([]type return defaultValues, true } -func (a *aggregationOptimizer) checkAnyCountAndSum(aggFuncs []*aggregation.AggFuncDesc) bool { +func (a *aggregationPushDownSolver) checkAnyCountAndSum(aggFuncs []*aggregation.AggFuncDesc) bool { for _, fun := range aggFuncs { if fun.Name == ast.AggFuncSum || fun.Name == ast.AggFuncCount { return true @@ -242,7 +243,7 @@ func (a *aggregationOptimizer) checkAnyCountAndSum(aggFuncs []*aggregation.AggFu return false } -func (a *aggregationOptimizer) makeNewAgg(ctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column) *LogicalAggregation { +func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column) *LogicalAggregation { agg := LogicalAggregation{ GroupByItems: expression.Column2Exprs(gbyCols), groupByCols: gbyCols, @@ -269,7 +270,7 @@ func (a *aggregationOptimizer) makeNewAgg(ctx sessionctx.Context, aggFuncs []*ag // pushAggCrossUnion will try to push the agg down to the union. 
If the new aggregation's group-by columns doesn't contain unique key. // We will return the new aggregation. Otherwise we will transform the aggregation to projection. -func (a *aggregationOptimizer) pushAggCrossUnion(agg *LogicalAggregation, unionSchema *expression.Schema, unionChild LogicalPlan) LogicalPlan { +func (a *aggregationPushDownSolver) pushAggCrossUnion(agg *LogicalAggregation, unionSchema *expression.Schema, unionChild LogicalPlan) LogicalPlan { ctx := agg.ctx newAgg := LogicalAggregation{ AggFuncs: make([]*aggregation.AggFuncDesc, 0, len(agg.AggFuncs)), @@ -305,7 +306,7 @@ func (a *aggregationOptimizer) pushAggCrossUnion(agg *LogicalAggregation, unionS return newAgg } -func (a *aggregationOptimizer) optimize(p LogicalPlan) (LogicalPlan, error) { +func (a *aggregationPushDownSolver) optimize(p LogicalPlan) (LogicalPlan, error) { if !p.context().GetSessionVars().AllowAggPushDown { return p, nil } @@ -314,7 +315,7 @@ func (a *aggregationOptimizer) optimize(p LogicalPlan) (LogicalPlan, error) { } // aggPushDown tries to push down aggregate functions to join paths. -func (a *aggregationOptimizer) aggPushDown(p LogicalPlan) LogicalPlan { +func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) LogicalPlan { if agg, ok := p.(*LogicalAggregation); ok { proj := a.tryToEliminateAggregation(agg) if proj != nil { @@ -383,85 +384,3 @@ func (a *aggregationOptimizer) aggPushDown(p LogicalPlan) LogicalPlan { p.SetChildren(newChildren...) return p } - -// tryToEliminateAggregation will eliminate aggregation grouped by unique key. -// e.g. select min(b) from t group by a. If a is a unique key, then this sql is equal to `select b from t group by a`. -// For count(expr), sum(expr), avg(expr), count(distinct expr, [expr...]) we may need to rewrite the expr. Details are shown below. -// If we can eliminate agg successful, we return a projection. Else we return a nil pointer. -func (a *aggregationOptimizer) tryToEliminateAggregation(agg *LogicalAggregation) *LogicalProjection { - schemaByGroupby := expression.NewSchema(agg.groupByCols...) - coveredByUniqueKey := false - for _, key := range agg.children[0].Schema().Keys { - if schemaByGroupby.ColumnsIndices(key) != nil { - coveredByUniqueKey = true - break - } - } - if coveredByUniqueKey { - // GroupByCols has unique key, so this aggregation can be removed. - proj := a.convertAggToProj(agg) - proj.SetChildren(agg.children[0]) - return proj - } - return nil -} - -func (a *aggregationOptimizer) convertAggToProj(agg *LogicalAggregation) *LogicalProjection { - proj := LogicalProjection{ - Exprs: make([]expression.Expression, 0, len(agg.AggFuncs)), - }.init(agg.ctx) - for _, fun := range agg.AggFuncs { - expr := a.rewriteExpr(agg.ctx, fun) - proj.Exprs = append(proj.Exprs, expr) - } - proj.SetSchema(agg.schema.Clone()) - return proj -} - -func (a *aggregationOptimizer) rewriteCount(ctx sessionctx.Context, exprs []expression.Expression) expression.Expression { - // If is count(expr), we will change it to if(isnull(expr), 0, 1). - // If is count(distinct x, y, z) we will change it to if(isnull(x) or isnull(y) or isnull(z), 0, 1). - isNullExprs := make([]expression.Expression, 0, len(exprs)) - for _, expr := range exprs { - isNullExpr := expression.NewFunctionInternal(ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), expr) - isNullExprs = append(isNullExprs, isNullExpr) - } - innerExpr := expression.ComposeDNFCondition(ctx, isNullExprs...) 
- newExpr := expression.NewFunctionInternal(ctx, ast.If, types.NewFieldType(mysql.TypeLonglong), innerExpr, expression.Zero, expression.One) - return newExpr -} - -// See https://dev.mysql.com/doc/refman/5.7/en/group-by-functions.html -// The SUM() and AVG() functions return a DECIMAL value for exact-value arguments (integer or DECIMAL), -// and a DOUBLE value for approximate-value arguments (FLOAT or DOUBLE). -func (a *aggregationOptimizer) rewriteSumOrAvg(ctx sessionctx.Context, exprs []expression.Expression) expression.Expression { - // FIXME: Consider the case that avg is final mode. - expr := exprs[0] - switch expr.GetType().Tp { - // Integer type should be cast to decimal. - case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: - return expression.BuildCastFunction(ctx, expr, types.NewFieldType(mysql.TypeNewDecimal)) - // Double and Decimal doesn't need to be cast. - case mysql.TypeDouble, mysql.TypeNewDecimal: - return expr - // Float should be cast to double. And other non-numeric type should be cast to double too. - default: - return expression.BuildCastFunction(ctx, expr, types.NewFieldType(mysql.TypeDouble)) - } -} - -// rewriteExpr will rewrite the aggregate function to expression doesn't contain aggregate function. -func (a *aggregationOptimizer) rewriteExpr(ctx sessionctx.Context, aggFunc *aggregation.AggFuncDesc) expression.Expression { - switch aggFunc.Name { - case ast.AggFuncCount: - if aggFunc.Mode == aggregation.FinalMode { - return a.rewriteSumOrAvg(ctx, aggFunc.Args) - } - return a.rewriteCount(ctx, aggFunc.Args) - case ast.AggFuncSum, ast.AggFuncAvg: - return a.rewriteSumOrAvg(ctx, aggFunc.Args) - default: - // Default we do nothing about expr. - return aggFunc.Args[0] - } -} From 3ee0ecf9536d76e52b903ee2fa26174e1dac3f1a Mon Sep 17 00:00:00 2001 From: Yiding Cui Date: Wed, 12 Sep 2018 20:15:21 +0800 Subject: [PATCH 02/16] rename things to address comments --- plan/logical_plan_builder.go | 4 ++-- plan/logical_plan_test.go | 2 +- plan/optimizer.go | 4 ++-- plan/rule_aggregation_elimination.go | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/plan/logical_plan_builder.go b/plan/logical_plan_builder.go index 03ccd993dd337..f989bfda8ffa7 100644 --- a/plan/logical_plan_builder.go +++ b/plan/logical_plan_builder.go @@ -68,7 +68,7 @@ func (la *LogicalAggregation) collectGroupByColumns() { func (b *planBuilder) buildAggregation(p LogicalPlan, aggFuncList []*ast.AggregateFuncExpr, gbyItems []expression.Expression) (LogicalPlan, map[int]int, error) { b.optFlag = b.optFlag | flagBuildKeyInfo - b.optFlag = b.optFlag | flagAggregationOptimize + b.optFlag = b.optFlag | flagPushDownAgg // We may apply aggregation eliminate optimization. // So we add the flagMaxMinEliminate to try to convert max/min to topn and flagPushDownTopN to handle the newly added topn operator. 
b.optFlag = b.optFlag | flagMaxMinEliminate @@ -586,7 +586,7 @@ func (b *planBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField, func (b *planBuilder) buildDistinct(child LogicalPlan, length int) *LogicalAggregation { b.optFlag = b.optFlag | flagBuildKeyInfo - b.optFlag = b.optFlag | flagAggregationOptimize + b.optFlag = b.optFlag | flagPushDownAgg plan4Agg := LogicalAggregation{ AggFuncs: make([]*aggregation.AggFuncDesc, 0, child.Schema().Len()), GroupByItems: expression.Column2Exprs(child.Schema().Clone().Columns[:length]), diff --git a/plan/logical_plan_test.go b/plan/logical_plan_test.go index 743359a6d42fe..b4b7ffa23656c 100644 --- a/plan/logical_plan_test.go +++ b/plan/logical_plan_test.go @@ -888,7 +888,7 @@ func (s *testPlanSuite) TestEagerAggregation(c *C) { p, err := BuildLogicalPlan(s.ctx, stmt, s.is) c.Assert(err, IsNil) - p, err = logicalOptimize(flagBuildKeyInfo|flagPredicatePushDown|flagPrunColumns|flagAggregationOptimize, p.(LogicalPlan)) + p, err = logicalOptimize(flagBuildKeyInfo|flagPredicatePushDown|flagPrunColumns|flagPushDownAgg, p.(LogicalPlan)) c.Assert(err, IsNil) c.Assert(ToString(p), Equals, tt.best, Commentf("for %s", tt.sql)) } diff --git a/plan/optimizer.go b/plan/optimizer.go index d45046d451284..1c8f2c01c17dd 100644 --- a/plan/optimizer.go +++ b/plan/optimizer.go @@ -37,7 +37,7 @@ const ( flagMaxMinEliminate flagPredicatePushDown flagPartitionProcessor - flagAggregationOptimize + flagPushDownAgg flagPushDownTopN ) @@ -45,7 +45,7 @@ var optRuleList = []logicalOptRule{ &columnPruner{}, &projectionEliminater{}, &buildKeySolver{}, - &aggregationRecursiveEliminater{}, + &aggregationEliminater{}, &decorrelateSolver{}, &maxMinEliminator{}, &ppdSolver{}, diff --git a/plan/rule_aggregation_elimination.go b/plan/rule_aggregation_elimination.go index 9ba17f83903ba..3972942ee54b2 100644 --- a/plan/rule_aggregation_elimination.go +++ b/plan/rule_aggregation_elimination.go @@ -22,7 +22,7 @@ import ( "github.com/pingcap/tidb/types" ) -type aggregationRecursiveEliminater struct { +type aggregationEliminater struct { aggregationEliminateChecker } @@ -111,7 +111,7 @@ func (a *aggregationEliminateChecker) rewriteSumOrAvg(ctx sessionctx.Context, ex } } -func (a *aggregationRecursiveEliminater) optimize(p LogicalPlan) (LogicalPlan, error) { +func (a *aggregationEliminater) optimize(p LogicalPlan) (LogicalPlan, error) { newChildren := make([]LogicalPlan, 0, len(p.Children())) for _, child := range p.Children() { newChild, _ := a.optimize(child) From 3e59451afbf7c8614e09057fe382595e874ca77b Mon Sep 17 00:00:00 2001 From: Yiding Cui Date: Thu, 20 Sep 2018 19:47:55 +0800 Subject: [PATCH 03/16] address comment --- plan/optimizer.go | 2 +- plan/rule_aggregation_elimination.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/plan/optimizer.go b/plan/optimizer.go index 1c8f2c01c17dd..a824dbbbd7964 100644 --- a/plan/optimizer.go +++ b/plan/optimizer.go @@ -45,7 +45,7 @@ var optRuleList = []logicalOptRule{ &columnPruner{}, &projectionEliminater{}, &buildKeySolver{}, - &aggregationEliminater{}, + &aggregationEliminator{}, &decorrelateSolver{}, &maxMinEliminator{}, &ppdSolver{}, diff --git a/plan/rule_aggregation_elimination.go b/plan/rule_aggregation_elimination.go index 3972942ee54b2..2f3653c04ff71 100644 --- a/plan/rule_aggregation_elimination.go +++ b/plan/rule_aggregation_elimination.go @@ -22,7 +22,7 @@ import ( "github.com/pingcap/tidb/types" ) -type aggregationEliminater struct { +type aggregationEliminator struct { aggregationEliminateChecker } 
@@ -111,7 +111,7 @@ func (a *aggregationEliminateChecker) rewriteSumOrAvg(ctx sessionctx.Context, ex } } -func (a *aggregationEliminater) optimize(p LogicalPlan) (LogicalPlan, error) { +func (a *aggregationEliminator) optimize(p LogicalPlan) (LogicalPlan, error) { newChildren := make([]LogicalPlan, 0, len(p.Children())) for _, child := range p.Children() { newChild, _ := a.optimize(child) From ec1d9e9aa76e7bb9caf15afe1e562c3dcaea3f1c Mon Sep 17 00:00:00 2001 From: Yiding Cui Date: Fri, 21 Sep 2018 14:30:53 +0800 Subject: [PATCH 04/16] extract struct. --- expression/aggregation/descriptor.go | 130 +++++++++++++++------------ 1 file changed, 73 insertions(+), 57 deletions(-) diff --git a/expression/aggregation/descriptor.go b/expression/aggregation/descriptor.go index 7cbf71b46b5b0..faaf97eccbdaf 100644 --- a/expression/aggregation/descriptor.go +++ b/expression/aggregation/descriptor.go @@ -33,6 +33,7 @@ import ( // AggFuncDesc describes an aggregation function signature, only used in planner. type AggFuncDesc struct { + typeInferer AggFuncTypeInferer // Name represents the aggregation function name. Name string // Args represents the arguments of the aggregation function. @@ -45,6 +46,10 @@ type AggFuncDesc struct { HasDistinct bool } +// AggFuncTypeInferer infers the type of aggregate functions. +type AggFuncTypeInferer struct { +} + // NewAggFuncDesc creates an aggregation function signature descriptor. func NewAggFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression, hasDistinct bool) *AggFuncDesc { a := &AggFuncDesc{ @@ -145,17 +150,21 @@ func (a *AggFuncDesc) String() string { func (a *AggFuncDesc) typeInfer(ctx sessionctx.Context) { switch a.Name { case ast.AggFuncCount: - a.typeInfer4Count(ctx) + a.RetTp = a.typeInferer.InferCount(ctx) case ast.AggFuncSum: - a.typeInfer4Sum(ctx) + //TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0]) + a.RetTp = a.typeInferer.InferSum(ctx, a.Args[0]) case ast.AggFuncAvg: - a.typeInfer4Avg(ctx) + a.RetTp = a.typeInferer.InferAvg(ctx, a.Args[0]) + // TODO: a.Args[0] = expression.WrapWithCastAsDecimal(ctx, a.Args[0]) case ast.AggFuncGroupConcat: - a.typeInfer4GroupConcat(ctx) + a.RetTp = a.typeInferer.InferGroupConcat(ctx) + // TODO: a.Args[i] = expression.WrapWithCastAsString(ctx, a.Args[i]) case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstRow: - a.typeInfer4MaxMin(ctx) + a.RetTp = a.typeInferer.InferMaxMin(ctx, a.Args[0]) case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor: - a.typeInfer4BitFuncs(ctx) + a.RetTp = a.typeInferer.InferBitFuncs(ctx) + // TODO: a.Args[0] = expression.WrapWithCastAsInt(ctx, a.Args[0]) default: panic("unsupported agg function: " + a.Name) } @@ -281,83 +290,90 @@ func (a *AggFuncDesc) GetAggFunc(ctx sessionctx.Context) Aggregation { } } -func (a *AggFuncDesc) typeInfer4Count(ctx sessionctx.Context) { - a.RetTp = types.NewFieldType(mysql.TypeLonglong) - a.RetTp.Flen = 21 - types.SetBinChsClnFlag(a.RetTp) +// InferCount infers the type of COUNT function. +func (a *AggFuncTypeInferer) InferCount(ctx sessionctx.Context) (retTp *types.FieldType) { + retTp = types.NewFieldType(mysql.TypeLonglong) + retTp.Flen = 21 + types.SetBinChsClnFlag(retTp) + return } -// typeInfer4Sum should returns a "decimal", otherwise it returns a "double". -// Because child returns integer or decimal type. -func (a *AggFuncDesc) typeInfer4Sum(ctx sessionctx.Context) { - switch a.Args[0].GetType().Tp { +// InferSum infers the type of SUM function. 
It should returns a "decimal" for exact numeric values, otherwise it returns a "double". +func (a *AggFuncTypeInferer) InferSum(ctx sessionctx.Context, arg expression.Expression) (retTp *types.FieldType) { + switch arg.GetType().Tp { case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeNewDecimal: - a.RetTp = types.NewFieldType(mysql.TypeNewDecimal) - a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxDecimalWidth, a.Args[0].GetType().Decimal - if a.RetTp.Decimal < 0 || a.RetTp.Decimal > mysql.MaxDecimalScale { - a.RetTp.Decimal = mysql.MaxDecimalScale + retTp = types.NewFieldType(mysql.TypeNewDecimal) + retTp.Flen, retTp.Decimal = mysql.MaxDecimalWidth, arg.GetType().Decimal + if retTp.Decimal < 0 || retTp.Decimal > mysql.MaxDecimalScale { + retTp.Decimal = mysql.MaxDecimalScale } - // TODO: a.Args[0] = expression.WrapWithCastAsDecimal(ctx, a.Args[0]) + // TODO: cast arg as expression.WrapWithCastAsDecimal(ctx, arg) default: - a.RetTp = types.NewFieldType(mysql.TypeDouble) - a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, a.Args[0].GetType().Decimal - //TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0]) + retTp = types.NewFieldType(mysql.TypeDouble) + retTp.Flen, retTp.Decimal = mysql.MaxRealWidth, arg.GetType().Decimal + // TODO: cast arg as expression.WrapWithCastAsDecimal(ctx, arg) } - types.SetBinChsClnFlag(a.RetTp) + types.SetBinChsClnFlag(retTp) + return retTp } -// typeInfer4Avg should returns a "decimal", otherwise it returns a "double". -// Because child returns integer or decimal type. -func (a *AggFuncDesc) typeInfer4Avg(ctx sessionctx.Context) { - switch a.Args[0].GetType().Tp { +// InferAvg infers the type of AVG function. It should returns a "decimal" for exact numeric values, otherwise it returns a "double". 
+func (a *AggFuncTypeInferer) InferAvg(ctx sessionctx.Context, arg expression.Expression) (retTp *types.FieldType) { + switch arg.GetType().Tp { case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeNewDecimal: - a.RetTp = types.NewFieldType(mysql.TypeNewDecimal) - if a.Args[0].GetType().Decimal < 0 { - a.RetTp.Decimal = mysql.MaxDecimalScale + retTp = types.NewFieldType(mysql.TypeNewDecimal) + if arg.GetType().Decimal < 0 { + retTp.Decimal = mysql.MaxDecimalScale } else { - a.RetTp.Decimal = mathutil.Min(a.Args[0].GetType().Decimal+types.DivFracIncr, mysql.MaxDecimalScale) + retTp.Decimal = mathutil.Min(arg.GetType().Decimal+types.DivFracIncr, mysql.MaxDecimalScale) } - a.RetTp.Flen = mysql.MaxDecimalWidth - // TODO: a.Args[0] = expression.WrapWithCastAsDecimal(ctx, a.Args[0]) + retTp.Flen = mysql.MaxDecimalWidth + // TODO: arg = expression.WrapWithCastAsDecimal(ctx, arg) default: - a.RetTp = types.NewFieldType(mysql.TypeDouble) - a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, a.Args[0].GetType().Decimal - // TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0]) + retTp = types.NewFieldType(mysql.TypeDouble) + retTp.Flen, retTp.Decimal = mysql.MaxRealWidth, arg.GetType().Decimal + // TODO: arg = expression.WrapWithCastAsReal(ctx, arg) } - types.SetBinChsClnFlag(a.RetTp) + types.SetBinChsClnFlag(retTp) + return } -func (a *AggFuncDesc) typeInfer4GroupConcat(ctx sessionctx.Context) { - a.RetTp = types.NewFieldType(mysql.TypeVarString) - a.RetTp.Charset = charset.CharsetUTF8 - a.RetTp.Collate = charset.CollationUTF8 - a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxBlobWidth, 0 - // TODO: a.Args[i] = expression.WrapWithCastAsString(ctx, a.Args[i]) +// InferGroupConcat infers type of GROUP_CONCAT function. +func (a *AggFuncTypeInferer) InferGroupConcat(ctx sessionctx.Context) (retTp *types.FieldType) { + retTp = types.NewFieldType(mysql.TypeVarString) + retTp.Charset = charset.CharsetUTF8 + retTp.Collate = charset.CollationUTF8 + retTp.Flen, retTp.Decimal = mysql.MaxBlobWidth, 0 + return } -func (a *AggFuncDesc) typeInfer4MaxMin(ctx sessionctx.Context) { - _, argIsScalaFunc := a.Args[0].(*expression.ScalarFunction) - if argIsScalaFunc && a.Args[0].GetType().Tp == mysql.TypeFloat { +// InferMaxMin infers type of MAX/MIN/FIRST_ROW function. +func (a *AggFuncTypeInferer) InferMaxMin(ctx sessionctx.Context, arg expression.Expression) (retTp *types.FieldType) { + _, argIsScalaFunc := arg.(*expression.ScalarFunction) + if argIsScalaFunc && arg.GetType().Tp == mysql.TypeFloat { // For scalar function, the result of "float32" is set to the "float64" - // field in the "Datum". If we do not wrap a cast-as-double function on a.Args[0], - // error would happen when extracting the evaluation of a.Args[0] to a ProjectionExec. + // field in the "Datum". If we do not wrap a cast-as-double function on arg, + // error would happen when extracting the evaluation of arg to a ProjectionExec. 
tp := types.NewFieldType(mysql.TypeDouble) tp.Flen, tp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength types.SetBinChsClnFlag(tp) - a.Args[0] = expression.BuildCastFunction(ctx, a.Args[0], tp) + arg = expression.BuildCastFunction(ctx, arg, tp) } - a.RetTp = a.Args[0].GetType() - if a.RetTp.Tp == mysql.TypeEnum || a.RetTp.Tp == mysql.TypeSet { - a.RetTp = &types.FieldType{Tp: mysql.TypeString, Flen: mysql.MaxFieldCharLength} + retTp = arg.GetType() + if retTp.Tp == mysql.TypeEnum || retTp.Tp == mysql.TypeSet { + retTp = &types.FieldType{Tp: mysql.TypeString, Flen: mysql.MaxFieldCharLength} } + return retTp } -func (a *AggFuncDesc) typeInfer4BitFuncs(ctx sessionctx.Context) { - a.RetTp = types.NewFieldType(mysql.TypeLonglong) - a.RetTp.Flen = 21 - types.SetBinChsClnFlag(a.RetTp) - a.RetTp.Flag |= mysql.UnsignedFlag | mysql.NotNullFlag +// InferBitFuncs infers type of bit functions, such as BIT_XOR, BIT_OR ... +func (a *AggFuncTypeInferer) InferBitFuncs(ctx sessionctx.Context) (retTp *types.FieldType) { + retTp = types.NewFieldType(mysql.TypeLonglong) + retTp.Flen = 21 + types.SetBinChsClnFlag(retTp) + retTp.Flag |= mysql.UnsignedFlag | mysql.NotNullFlag // TODO: a.Args[0] = expression.WrapWithCastAsInt(ctx, a.Args[0]) + return } func (a *AggFuncDesc) evalNullValueInOuterJoin4Count(ctx sessionctx.Context, schema *expression.Schema) (types.Datum, bool) { From 083896827ec6afa9fd6f18063659cfab55d7c801 Mon Sep 17 00:00:00 2001 From: Yiding Cui Date: Fri, 21 Sep 2018 15:05:30 +0800 Subject: [PATCH 05/16] fix the typeinfer in aggregate elimination. --- plan/rule_aggregation_elimination.go | 40 +++++++++------------------- 1 file changed, 13 insertions(+), 27 deletions(-) diff --git a/plan/rule_aggregation_elimination.go b/plan/rule_aggregation_elimination.go index 2f3653c04ff71..fe41fdb71ad01 100644 --- a/plan/rule_aggregation_elimination.go +++ b/plan/rule_aggregation_elimination.go @@ -27,6 +27,7 @@ type aggregationEliminator struct { } type aggregationEliminateChecker struct { + aggregation.AggFuncTypeInferer } // tryToEliminateAggregation will eliminate aggregation grouped by unique key. @@ -67,15 +68,19 @@ func (a *aggregationEliminateChecker) convertAggToProj(agg *LogicalAggregation) func (a *aggregationEliminateChecker) rewriteExpr(ctx sessionctx.Context, aggFunc *aggregation.AggFuncDesc) expression.Expression { switch aggFunc.Name { case ast.AggFuncCount: - if aggFunc.Mode == aggregation.FinalMode { - return a.rewriteSumOrAvg(ctx, aggFunc.Args) - } return a.rewriteCount(ctx, aggFunc.Args) - case ast.AggFuncSum, ast.AggFuncAvg: - return a.rewriteSumOrAvg(ctx, aggFunc.Args) + case ast.AggFuncSum: + return expression.BuildCastFunction(ctx, aggFunc.Args[0], a.InferSum(ctx, aggFunc.Args[0])) + case ast.AggFuncAvg: + return expression.BuildCastFunction(ctx, aggFunc.Args[0], a.InferAvg(ctx, aggFunc.Args[0])) + case ast.AggFuncFirstRow, ast.AggFuncMax, ast.AggFuncMin: + return expression.BuildCastFunction(ctx, aggFunc.Args[0], a.InferMaxMin(ctx, aggFunc.Args[0])) + case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor: + return expression.BuildCastFunction(ctx, aggFunc.Args[0], a.InferBitFuncs(ctx)) + case ast.AggFuncGroupConcat: + return expression.BuildCastFunction(ctx, aggFunc.Args[0], a.InferGroupConcat(ctx)) default: - // Default we do nothing about expr. 
- return aggFunc.Args[0] + panic("Unsupported function") } } @@ -88,29 +93,10 @@ func (a *aggregationEliminateChecker) rewriteCount(ctx sessionctx.Context, exprs isNullExprs = append(isNullExprs, isNullExpr) } innerExpr := expression.ComposeDNFCondition(ctx, isNullExprs...) - newExpr := expression.NewFunctionInternal(ctx, ast.If, types.NewFieldType(mysql.TypeLonglong), innerExpr, expression.Zero, expression.One) + newExpr := expression.NewFunctionInternal(ctx, ast.If, a.InferCount(ctx), innerExpr, expression.Zero, expression.One) return newExpr } -// See https://dev.mysql.com/doc/refman/5.7/en/group-by-functions.html -// The SUM() and AVG() functions return a DECIMAL value for exact-value arguments (integer or DECIMAL), -// and a DOUBLE value for approximate-value arguments (FLOAT or DOUBLE). -func (a *aggregationEliminateChecker) rewriteSumOrAvg(ctx sessionctx.Context, exprs []expression.Expression) expression.Expression { - // FIXME: Consider the case that avg is final mode. - expr := exprs[0] - switch expr.GetType().Tp { - // Integer type should be cast to decimal. - case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: - return expression.BuildCastFunction(ctx, expr, types.NewFieldType(mysql.TypeNewDecimal)) - // Double and Decimal doesn't need to be cast. - case mysql.TypeDouble, mysql.TypeNewDecimal: - return expr - // Float should be cast to double. And other non-numeric type should be cast to double too. - default: - return expression.BuildCastFunction(ctx, expr, types.NewFieldType(mysql.TypeDouble)) - } -} - func (a *aggregationEliminator) optimize(p LogicalPlan) (LogicalPlan, error) { newChildren := make([]LogicalPlan, 0, len(p.Children())) for _, child := range p.Children() { From 5a3a128c179c6f16e88103be7e85d2a1ef896305 Mon Sep 17 00:00:00 2001 From: Yiding Cui Date: Tue, 25 Sep 2018 20:12:59 +0800 Subject: [PATCH 06/16] fix behavior. 
--- plan/rule_aggregation_elimination.go | 35 ++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/plan/rule_aggregation_elimination.go b/plan/rule_aggregation_elimination.go index fe41fdb71ad01..d307e784441a2 100644 --- a/plan/rule_aggregation_elimination.go +++ b/plan/rule_aggregation_elimination.go @@ -14,6 +14,8 @@ package plan import ( + "math" + "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/expression/aggregation" @@ -68,23 +70,23 @@ func (a *aggregationEliminateChecker) convertAggToProj(agg *LogicalAggregation) func (a *aggregationEliminateChecker) rewriteExpr(ctx sessionctx.Context, aggFunc *aggregation.AggFuncDesc) expression.Expression { switch aggFunc.Name { case ast.AggFuncCount: - return a.rewriteCount(ctx, aggFunc.Args) + return a.rewriteCount(ctx, aggFunc.Args, aggFunc.RetTp) case ast.AggFuncSum: - return expression.BuildCastFunction(ctx, aggFunc.Args[0], a.InferSum(ctx, aggFunc.Args[0])) + return a.wrapCastFunction(ctx, aggFunc.Args[0], aggFunc.RetTp) case ast.AggFuncAvg: - return expression.BuildCastFunction(ctx, aggFunc.Args[0], a.InferAvg(ctx, aggFunc.Args[0])) + return a.wrapCastFunction(ctx, aggFunc.Args[0], aggFunc.RetTp) case ast.AggFuncFirstRow, ast.AggFuncMax, ast.AggFuncMin: - return expression.BuildCastFunction(ctx, aggFunc.Args[0], a.InferMaxMin(ctx, aggFunc.Args[0])) + return a.wrapCastFunction(ctx, aggFunc.Args[0], aggFunc.RetTp) case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor: - return expression.BuildCastFunction(ctx, aggFunc.Args[0], a.InferBitFuncs(ctx)) + return a.rewriteBitFunc(ctx, aggFunc.Name, aggFunc.Args[0], aggFunc.RetTp) case ast.AggFuncGroupConcat: - return expression.BuildCastFunction(ctx, aggFunc.Args[0], a.InferGroupConcat(ctx)) + return a.wrapCastFunction(ctx, aggFunc.Args[0], aggFunc.RetTp) default: panic("Unsupported function") } } -func (a *aggregationEliminateChecker) rewriteCount(ctx sessionctx.Context, exprs []expression.Expression) expression.Expression { +func (a *aggregationEliminateChecker) rewriteCount(ctx sessionctx.Context, exprs []expression.Expression, targetTp *types.FieldType) expression.Expression { // If is count(expr), we will change it to if(isnull(expr), 0, 1). // If is count(distinct x, y, z) we will change it to if(isnull(x) or isnull(y) or isnull(z), 0, 1). 
isNullExprs := make([]expression.Expression, 0, len(exprs)) @@ -97,6 +99,25 @@ func (a *aggregationEliminateChecker) rewriteCount(ctx sessionctx.Context, exprs return newExpr } +func (a *aggregationEliminateChecker) rewriteBitFunc(ctx sessionctx.Context, funcType string, arg expression.Expression, targetTp *types.FieldType) expression.Expression { + innerCast := expression.WrapWithCastAsInt(ctx, arg) + outerCast := a.wrapCastFunction(ctx, innerCast, targetTp) + var finalExpr expression.Expression + if funcType != ast.AggFuncBitAnd { + finalExpr = expression.NewFunctionInternal(ctx, ast.Ifnull, targetTp, outerCast, expression.Zero.Clone()) + } else { + finalExpr = expression.NewFunctionInternal(ctx, ast.Ifnull, outerCast.GetType(), outerCast, &expression.Constant{Value: types.NewUintDatum(math.MaxUint64), RetType: targetTp}) + } + return finalExpr +} + +func (a *aggregationEliminateChecker) wrapCastFunction(ctx sessionctx.Context, arg expression.Expression, targetTp *types.FieldType) expression.Expression { + if arg.GetType() == targetTp { + return arg + } + return expression.BuildCastFunction(ctx, arg, targetTp) +} + func (a *aggregationEliminator) optimize(p LogicalPlan) (LogicalPlan, error) { newChildren := make([]LogicalPlan, 0, len(p.Children())) for _, child := range p.Children() { From b9701aef14ec35e053ca73e112718fd77741e75f Mon Sep 17 00:00:00 2001 From: Yiding Cui Date: Tue, 25 Sep 2018 20:16:18 +0800 Subject: [PATCH 07/16] undo unnecessary change --- expression/aggregation/descriptor.go | 130 ++++++++++++--------------- 1 file changed, 57 insertions(+), 73 deletions(-) diff --git a/expression/aggregation/descriptor.go b/expression/aggregation/descriptor.go index faaf97eccbdaf..7cbf71b46b5b0 100644 --- a/expression/aggregation/descriptor.go +++ b/expression/aggregation/descriptor.go @@ -33,7 +33,6 @@ import ( // AggFuncDesc describes an aggregation function signature, only used in planner. type AggFuncDesc struct { - typeInferer AggFuncTypeInferer // Name represents the aggregation function name. Name string // Args represents the arguments of the aggregation function. @@ -46,10 +45,6 @@ type AggFuncDesc struct { HasDistinct bool } -// AggFuncTypeInferer infers the type of aggregate functions. -type AggFuncTypeInferer struct { -} - // NewAggFuncDesc creates an aggregation function signature descriptor. 
func NewAggFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression, hasDistinct bool) *AggFuncDesc { a := &AggFuncDesc{ @@ -150,21 +145,17 @@ func (a *AggFuncDesc) String() string { func (a *AggFuncDesc) typeInfer(ctx sessionctx.Context) { switch a.Name { case ast.AggFuncCount: - a.RetTp = a.typeInferer.InferCount(ctx) + a.typeInfer4Count(ctx) case ast.AggFuncSum: - //TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0]) - a.RetTp = a.typeInferer.InferSum(ctx, a.Args[0]) + a.typeInfer4Sum(ctx) case ast.AggFuncAvg: - a.RetTp = a.typeInferer.InferAvg(ctx, a.Args[0]) - // TODO: a.Args[0] = expression.WrapWithCastAsDecimal(ctx, a.Args[0]) + a.typeInfer4Avg(ctx) case ast.AggFuncGroupConcat: - a.RetTp = a.typeInferer.InferGroupConcat(ctx) - // TODO: a.Args[i] = expression.WrapWithCastAsString(ctx, a.Args[i]) + a.typeInfer4GroupConcat(ctx) case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstRow: - a.RetTp = a.typeInferer.InferMaxMin(ctx, a.Args[0]) + a.typeInfer4MaxMin(ctx) case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor: - a.RetTp = a.typeInferer.InferBitFuncs(ctx) - // TODO: a.Args[0] = expression.WrapWithCastAsInt(ctx, a.Args[0]) + a.typeInfer4BitFuncs(ctx) default: panic("unsupported agg function: " + a.Name) } @@ -290,90 +281,83 @@ func (a *AggFuncDesc) GetAggFunc(ctx sessionctx.Context) Aggregation { } } -// InferCount infers the type of COUNT function. -func (a *AggFuncTypeInferer) InferCount(ctx sessionctx.Context) (retTp *types.FieldType) { - retTp = types.NewFieldType(mysql.TypeLonglong) - retTp.Flen = 21 - types.SetBinChsClnFlag(retTp) - return +func (a *AggFuncDesc) typeInfer4Count(ctx sessionctx.Context) { + a.RetTp = types.NewFieldType(mysql.TypeLonglong) + a.RetTp.Flen = 21 + types.SetBinChsClnFlag(a.RetTp) } -// InferSum infers the type of SUM function. It should returns a "decimal" for exact numeric values, otherwise it returns a "double". -func (a *AggFuncTypeInferer) InferSum(ctx sessionctx.Context, arg expression.Expression) (retTp *types.FieldType) { - switch arg.GetType().Tp { +// typeInfer4Sum should returns a "decimal", otherwise it returns a "double". +// Because child returns integer or decimal type. +func (a *AggFuncDesc) typeInfer4Sum(ctx sessionctx.Context) { + switch a.Args[0].GetType().Tp { case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeNewDecimal: - retTp = types.NewFieldType(mysql.TypeNewDecimal) - retTp.Flen, retTp.Decimal = mysql.MaxDecimalWidth, arg.GetType().Decimal - if retTp.Decimal < 0 || retTp.Decimal > mysql.MaxDecimalScale { - retTp.Decimal = mysql.MaxDecimalScale + a.RetTp = types.NewFieldType(mysql.TypeNewDecimal) + a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxDecimalWidth, a.Args[0].GetType().Decimal + if a.RetTp.Decimal < 0 || a.RetTp.Decimal > mysql.MaxDecimalScale { + a.RetTp.Decimal = mysql.MaxDecimalScale } - // TODO: cast arg as expression.WrapWithCastAsDecimal(ctx, arg) + // TODO: a.Args[0] = expression.WrapWithCastAsDecimal(ctx, a.Args[0]) default: - retTp = types.NewFieldType(mysql.TypeDouble) - retTp.Flen, retTp.Decimal = mysql.MaxRealWidth, arg.GetType().Decimal - // TODO: cast arg as expression.WrapWithCastAsDecimal(ctx, arg) + a.RetTp = types.NewFieldType(mysql.TypeDouble) + a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, a.Args[0].GetType().Decimal + //TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0]) } - types.SetBinChsClnFlag(retTp) - return retTp + types.SetBinChsClnFlag(a.RetTp) } -// InferAvg infers the type of AVG function. 
It should returns a "decimal" for exact numeric values, otherwise it returns a "double". -func (a *AggFuncTypeInferer) InferAvg(ctx sessionctx.Context, arg expression.Expression) (retTp *types.FieldType) { - switch arg.GetType().Tp { +// typeInfer4Avg should returns a "decimal", otherwise it returns a "double". +// Because child returns integer or decimal type. +func (a *AggFuncDesc) typeInfer4Avg(ctx sessionctx.Context) { + switch a.Args[0].GetType().Tp { case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeNewDecimal: - retTp = types.NewFieldType(mysql.TypeNewDecimal) - if arg.GetType().Decimal < 0 { - retTp.Decimal = mysql.MaxDecimalScale + a.RetTp = types.NewFieldType(mysql.TypeNewDecimal) + if a.Args[0].GetType().Decimal < 0 { + a.RetTp.Decimal = mysql.MaxDecimalScale } else { - retTp.Decimal = mathutil.Min(arg.GetType().Decimal+types.DivFracIncr, mysql.MaxDecimalScale) + a.RetTp.Decimal = mathutil.Min(a.Args[0].GetType().Decimal+types.DivFracIncr, mysql.MaxDecimalScale) } - retTp.Flen = mysql.MaxDecimalWidth - // TODO: arg = expression.WrapWithCastAsDecimal(ctx, arg) + a.RetTp.Flen = mysql.MaxDecimalWidth + // TODO: a.Args[0] = expression.WrapWithCastAsDecimal(ctx, a.Args[0]) default: - retTp = types.NewFieldType(mysql.TypeDouble) - retTp.Flen, retTp.Decimal = mysql.MaxRealWidth, arg.GetType().Decimal - // TODO: arg = expression.WrapWithCastAsReal(ctx, arg) + a.RetTp = types.NewFieldType(mysql.TypeDouble) + a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, a.Args[0].GetType().Decimal + // TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0]) } - types.SetBinChsClnFlag(retTp) - return + types.SetBinChsClnFlag(a.RetTp) } -// InferGroupConcat infers type of GROUP_CONCAT function. -func (a *AggFuncTypeInferer) InferGroupConcat(ctx sessionctx.Context) (retTp *types.FieldType) { - retTp = types.NewFieldType(mysql.TypeVarString) - retTp.Charset = charset.CharsetUTF8 - retTp.Collate = charset.CollationUTF8 - retTp.Flen, retTp.Decimal = mysql.MaxBlobWidth, 0 - return +func (a *AggFuncDesc) typeInfer4GroupConcat(ctx sessionctx.Context) { + a.RetTp = types.NewFieldType(mysql.TypeVarString) + a.RetTp.Charset = charset.CharsetUTF8 + a.RetTp.Collate = charset.CollationUTF8 + a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxBlobWidth, 0 + // TODO: a.Args[i] = expression.WrapWithCastAsString(ctx, a.Args[i]) } -// InferMaxMin infers type of MAX/MIN/FIRST_ROW function. -func (a *AggFuncTypeInferer) InferMaxMin(ctx sessionctx.Context, arg expression.Expression) (retTp *types.FieldType) { - _, argIsScalaFunc := arg.(*expression.ScalarFunction) - if argIsScalaFunc && arg.GetType().Tp == mysql.TypeFloat { +func (a *AggFuncDesc) typeInfer4MaxMin(ctx sessionctx.Context) { + _, argIsScalaFunc := a.Args[0].(*expression.ScalarFunction) + if argIsScalaFunc && a.Args[0].GetType().Tp == mysql.TypeFloat { // For scalar function, the result of "float32" is set to the "float64" - // field in the "Datum". If we do not wrap a cast-as-double function on arg, - // error would happen when extracting the evaluation of arg to a ProjectionExec. + // field in the "Datum". If we do not wrap a cast-as-double function on a.Args[0], + // error would happen when extracting the evaluation of a.Args[0] to a ProjectionExec. 
tp := types.NewFieldType(mysql.TypeDouble) tp.Flen, tp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength types.SetBinChsClnFlag(tp) - arg = expression.BuildCastFunction(ctx, arg, tp) + a.Args[0] = expression.BuildCastFunction(ctx, a.Args[0], tp) } - retTp = arg.GetType() - if retTp.Tp == mysql.TypeEnum || retTp.Tp == mysql.TypeSet { - retTp = &types.FieldType{Tp: mysql.TypeString, Flen: mysql.MaxFieldCharLength} + a.RetTp = a.Args[0].GetType() + if a.RetTp.Tp == mysql.TypeEnum || a.RetTp.Tp == mysql.TypeSet { + a.RetTp = &types.FieldType{Tp: mysql.TypeString, Flen: mysql.MaxFieldCharLength} } - return retTp } -// InferBitFuncs infers type of bit functions, such as BIT_XOR, BIT_OR ... -func (a *AggFuncTypeInferer) InferBitFuncs(ctx sessionctx.Context) (retTp *types.FieldType) { - retTp = types.NewFieldType(mysql.TypeLonglong) - retTp.Flen = 21 - types.SetBinChsClnFlag(retTp) - retTp.Flag |= mysql.UnsignedFlag | mysql.NotNullFlag +func (a *AggFuncDesc) typeInfer4BitFuncs(ctx sessionctx.Context) { + a.RetTp = types.NewFieldType(mysql.TypeLonglong) + a.RetTp.Flen = 21 + types.SetBinChsClnFlag(a.RetTp) + a.RetTp.Flag |= mysql.UnsignedFlag | mysql.NotNullFlag // TODO: a.Args[0] = expression.WrapWithCastAsInt(ctx, a.Args[0]) - return } func (a *AggFuncDesc) evalNullValueInOuterJoin4Count(ctx sessionctx.Context, schema *expression.Schema) (types.Datum, bool) { From 02eb0248753dc942bd60c13234a6fab40cbd0466 Mon Sep 17 00:00:00 2001 From: Yiding Cui Date: Tue, 25 Sep 2018 20:32:41 +0800 Subject: [PATCH 08/16] fix test. --- expression/aggregation/descriptor.go | 14 ++++++++++++-- plan/rule_aggregation_elimination.go | 3 +-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/expression/aggregation/descriptor.go b/expression/aggregation/descriptor.go index 7cbf71b46b5b0..acefd580ecfe8 100644 --- a/expression/aggregation/descriptor.go +++ b/expression/aggregation/descriptor.go @@ -297,11 +297,16 @@ func (a *AggFuncDesc) typeInfer4Sum(ctx sessionctx.Context) { if a.RetTp.Decimal < 0 || a.RetTp.Decimal > mysql.MaxDecimalScale { a.RetTp.Decimal = mysql.MaxDecimalScale } + a.RetTp.Flag = a.Args[0].GetType().Flag & mysql.UnsignedFlag // TODO: a.Args[0] = expression.WrapWithCastAsDecimal(ctx, a.Args[0]) - default: + case mysql.TypeDouble, mysql.TypeFloat: a.RetTp = types.NewFieldType(mysql.TypeDouble) a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, a.Args[0].GetType().Decimal //TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0]) + default: + a.RetTp = types.NewFieldType(mysql.TypeDouble) + a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength + // TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0]) } types.SetBinChsClnFlag(a.RetTp) } @@ -317,12 +322,17 @@ func (a *AggFuncDesc) typeInfer4Avg(ctx sessionctx.Context) { } else { a.RetTp.Decimal = mathutil.Min(a.Args[0].GetType().Decimal+types.DivFracIncr, mysql.MaxDecimalScale) } + a.RetTp.Flag = a.Args[0].GetType().Flag & mysql.UnsignedFlag a.RetTp.Flen = mysql.MaxDecimalWidth // TODO: a.Args[0] = expression.WrapWithCastAsDecimal(ctx, a.Args[0]) - default: + case mysql.TypeDouble, mysql.TypeFloat: a.RetTp = types.NewFieldType(mysql.TypeDouble) a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, a.Args[0].GetType().Decimal // TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0]) + default: + a.RetTp = types.NewFieldType(mysql.TypeDouble) + a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength + // TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0]) } 
types.SetBinChsClnFlag(a.RetTp) } diff --git a/plan/rule_aggregation_elimination.go b/plan/rule_aggregation_elimination.go index d307e784441a2..d0574d749b73d 100644 --- a/plan/rule_aggregation_elimination.go +++ b/plan/rule_aggregation_elimination.go @@ -29,7 +29,6 @@ type aggregationEliminator struct { } type aggregationEliminateChecker struct { - aggregation.AggFuncTypeInferer } // tryToEliminateAggregation will eliminate aggregation grouped by unique key. @@ -95,7 +94,7 @@ func (a *aggregationEliminateChecker) rewriteCount(ctx sessionctx.Context, exprs isNullExprs = append(isNullExprs, isNullExpr) } innerExpr := expression.ComposeDNFCondition(ctx, isNullExprs...) - newExpr := expression.NewFunctionInternal(ctx, ast.If, a.InferCount(ctx), innerExpr, expression.Zero, expression.One) + newExpr := expression.NewFunctionInternal(ctx, ast.If, targetTp, innerExpr, expression.Zero, expression.One) return newExpr } From cad8b4d7c11759e8b79980ff7257eccb3865f3a0 Mon Sep 17 00:00:00 2001 From: Yiding Cui Date: Tue, 25 Sep 2018 20:35:30 +0800 Subject: [PATCH 09/16] add comment --- plan/rule_aggregation_elimination.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/plan/rule_aggregation_elimination.go b/plan/rule_aggregation_elimination.go index d0574d749b73d..63175e648ebc0 100644 --- a/plan/rule_aggregation_elimination.go +++ b/plan/rule_aggregation_elimination.go @@ -99,6 +99,7 @@ func (a *aggregationEliminateChecker) rewriteCount(ctx sessionctx.Context, exprs } func (a *aggregationEliminateChecker) rewriteBitFunc(ctx sessionctx.Context, funcType string, arg expression.Expression, targetTp *types.FieldType) expression.Expression { + // For not integer type. We need to cast(cast(arg as signed) as unsigned) to make the bit function work. innerCast := expression.WrapWithCastAsInt(ctx, arg) outerCast := a.wrapCastFunction(ctx, innerCast, targetTp) var finalExpr expression.Expression @@ -110,6 +111,7 @@ func (a *aggregationEliminateChecker) rewriteBitFunc(ctx sessionctx.Context, fun return finalExpr } +// wrapCastFunction will wrap a cast if the targetTp is not equal to the arg's. func (a *aggregationEliminateChecker) wrapCastFunction(ctx sessionctx.Context, arg expression.Expression, targetTp *types.FieldType) expression.Expression { if arg.GetType() == targetTp { return arg From b6745baaa54c348f7f0b8d116514232cbf38f8a1 Mon Sep 17 00:00:00 2001 From: Yiding Cui Date: Tue, 25 Sep 2018 20:38:52 +0800 Subject: [PATCH 10/16] fix merge error --- planner/core/rule_aggregation_elimination.go | 137 +++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 planner/core/rule_aggregation_elimination.go diff --git a/planner/core/rule_aggregation_elimination.go b/planner/core/rule_aggregation_elimination.go new file mode 100644 index 0000000000000..8e92e117931d4 --- /dev/null +++ b/planner/core/rule_aggregation_elimination.go @@ -0,0 +1,137 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package core
+
+import (
+	"math"
+
+	"github.com/pingcap/tidb/ast"
+	"github.com/pingcap/tidb/expression"
+	"github.com/pingcap/tidb/expression/aggregation"
+	"github.com/pingcap/tidb/mysql"
+	"github.com/pingcap/tidb/sessionctx"
+	"github.com/pingcap/tidb/types"
+)
+
+type aggregationEliminator struct {
+	aggregationEliminateChecker
+}
+
+type aggregationEliminateChecker struct {
+}
+
+// tryToEliminateAggregation will eliminate an aggregation grouped by a unique key.
+// e.g. select min(b) from t group by a. If a is a unique key, this SQL is equivalent to `select b from t group by a`.
+// For count(expr), sum(expr), avg(expr), count(distinct expr, [expr...]) we may need to rewrite the expr. Details are shown below.
+// If we can eliminate the aggregation successfully, we return a projection. Otherwise we return a nil pointer.
+func (a *aggregationEliminateChecker) tryToEliminateAggregation(agg *LogicalAggregation) *LogicalProjection {
+	schemaByGroupby := expression.NewSchema(agg.groupByCols...)
+	coveredByUniqueKey := false
+	for _, key := range agg.children[0].Schema().Keys {
+		if schemaByGroupby.ColumnsIndices(key) != nil {
+			coveredByUniqueKey = true
+			break
+		}
+	}
+	if coveredByUniqueKey {
+		// The group-by columns cover a unique key, so this aggregation can be removed.
+		proj := a.convertAggToProj(agg)
+		proj.SetChildren(agg.children[0])
+		return proj
+	}
+	return nil
+}
+
+func (a *aggregationEliminateChecker) convertAggToProj(agg *LogicalAggregation) *LogicalProjection {
+	proj := LogicalProjection{
+		Exprs: make([]expression.Expression, 0, len(agg.AggFuncs)),
+	}.init(agg.ctx)
+	for _, fun := range agg.AggFuncs {
+		expr := a.rewriteExpr(agg.ctx, fun)
+		proj.Exprs = append(proj.Exprs, expr)
+	}
+	proj.SetSchema(agg.schema.Clone())
+	return proj
+}
+
+// rewriteExpr will rewrite the aggregate function to an expression that doesn't contain any aggregate function.
+func (a *aggregationEliminateChecker) rewriteExpr(ctx sessionctx.Context, aggFunc *aggregation.AggFuncDesc) expression.Expression {
+	switch aggFunc.Name {
+	case ast.AggFuncCount:
+		return a.rewriteCount(ctx, aggFunc.Args, aggFunc.RetTp)
+	case ast.AggFuncSum:
+		return a.wrapCastFunction(ctx, aggFunc.Args[0], aggFunc.RetTp)
+	case ast.AggFuncAvg:
+		return a.wrapCastFunction(ctx, aggFunc.Args[0], aggFunc.RetTp)
+	case ast.AggFuncFirstRow, ast.AggFuncMax, ast.AggFuncMin:
+		return a.wrapCastFunction(ctx, aggFunc.Args[0], aggFunc.RetTp)
+	case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor:
+		return a.rewriteBitFunc(ctx, aggFunc.Name, aggFunc.Args[0], aggFunc.RetTp)
+	case ast.AggFuncGroupConcat:
+		return a.wrapCastFunction(ctx, aggFunc.Args[0], aggFunc.RetTp)
+	default:
+		panic("Unsupported function")
+	}
+}
+
+func (a *aggregationEliminateChecker) rewriteCount(ctx sessionctx.Context, exprs []expression.Expression, targetTp *types.FieldType) expression.Expression {
+	// If it is count(expr), we will change it to if(isnull(expr), 0, 1).
+	// If it is count(distinct x, y, z), we will change it to if(isnull(x) or isnull(y) or isnull(z), 0, 1).
+	isNullExprs := make([]expression.Expression, 0, len(exprs))
+	for _, expr := range exprs {
+		isNullExpr := expression.NewFunctionInternal(ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), expr)
+		isNullExprs = append(isNullExprs, isNullExpr)
+	}
+	innerExpr := expression.ComposeDNFCondition(ctx, isNullExprs...)
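+	// innerExpr is true only when at least one argument is NULL; the IF below maps such rows to 0 and every other row to 1.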
+ newExpr := expression.NewFunctionInternal(ctx, ast.If, targetTp, innerExpr, expression.Zero, expression.One) + return newExpr +} + +func (a *aggregationEliminateChecker) rewriteBitFunc(ctx sessionctx.Context, funcType string, arg expression.Expression, targetTp *types.FieldType) expression.Expression { + // For not integer type. We need to cast(cast(arg as signed) as unsigned) to make the bit function work. + innerCast := expression.WrapWithCastAsInt(ctx, arg) + outerCast := a.wrapCastFunction(ctx, innerCast, targetTp) + var finalExpr expression.Expression + if funcType != ast.AggFuncBitAnd { + finalExpr = expression.NewFunctionInternal(ctx, ast.Ifnull, targetTp, outerCast, expression.Zero.Clone()) + } else { + finalExpr = expression.NewFunctionInternal(ctx, ast.Ifnull, outerCast.GetType(), outerCast, &expression.Constant{Value: types.NewUintDatum(math.MaxUint64), RetType: targetTp}) + } + return finalExpr +} + +// wrapCastFunction will wrap a cast if the targetTp is not equal to the arg's. +func (a *aggregationEliminateChecker) wrapCastFunction(ctx sessionctx.Context, arg expression.Expression, targetTp *types.FieldType) expression.Expression { + if arg.GetType() == targetTp { + return arg + } + return expression.BuildCastFunction(ctx, arg, targetTp) +} + +func (a *aggregationEliminator) optimize(p LogicalPlan) (LogicalPlan, error) { + newChildren := make([]LogicalPlan, 0, len(p.Children())) + for _, child := range p.Children() { + newChild, _ := a.optimize(child) + newChildren = append(newChildren, newChild) + } + p.SetChildren(newChildren...) + agg, ok := p.(*LogicalAggregation) + if !ok { + return p, nil + } + if proj := a.tryToEliminateAggregation(agg); proj != nil { + return proj, nil + } + return p, nil +} From 9aceaf4254b74f82f262e9c3daf451757324211c Mon Sep 17 00:00:00 2001 From: Yiding Cui Date: Tue, 25 Sep 2018 20:40:53 +0800 Subject: [PATCH 11/16] delete file. --- plan/rule_aggregation_elimination.go | 137 --------------------------- 1 file changed, 137 deletions(-) delete mode 100644 plan/rule_aggregation_elimination.go diff --git a/plan/rule_aggregation_elimination.go b/plan/rule_aggregation_elimination.go deleted file mode 100644 index 63175e648ebc0..0000000000000 --- a/plan/rule_aggregation_elimination.go +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2018 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package plan - -import ( - "math" - - "github.com/pingcap/tidb/ast" - "github.com/pingcap/tidb/expression" - "github.com/pingcap/tidb/expression/aggregation" - "github.com/pingcap/tidb/mysql" - "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/types" -) - -type aggregationEliminator struct { - aggregationEliminateChecker -} - -type aggregationEliminateChecker struct { -} - -// tryToEliminateAggregation will eliminate aggregation grouped by unique key. -// e.g. select min(b) from t group by a. If a is a unique key, then this sql is equal to `select b from t group by a`. -// For count(expr), sum(expr), avg(expr), count(distinct expr, [expr...]) we may need to rewrite the expr. Details are shown below. 
-// If we can eliminate agg successful, we return a projection. Else we return a nil pointer. -func (a *aggregationEliminateChecker) tryToEliminateAggregation(agg *LogicalAggregation) *LogicalProjection { - schemaByGroupby := expression.NewSchema(agg.groupByCols...) - coveredByUniqueKey := false - for _, key := range agg.children[0].Schema().Keys { - if schemaByGroupby.ColumnsIndices(key) != nil { - coveredByUniqueKey = true - break - } - } - if coveredByUniqueKey { - // GroupByCols has unique key, so this aggregation can be removed. - proj := a.convertAggToProj(agg) - proj.SetChildren(agg.children[0]) - return proj - } - return nil -} - -func (a *aggregationEliminateChecker) convertAggToProj(agg *LogicalAggregation) *LogicalProjection { - proj := LogicalProjection{ - Exprs: make([]expression.Expression, 0, len(agg.AggFuncs)), - }.init(agg.ctx) - for _, fun := range agg.AggFuncs { - expr := a.rewriteExpr(agg.ctx, fun) - proj.Exprs = append(proj.Exprs, expr) - } - proj.SetSchema(agg.schema.Clone()) - return proj -} - -// rewriteExpr will rewrite the aggregate function to expression doesn't contain aggregate function. -func (a *aggregationEliminateChecker) rewriteExpr(ctx sessionctx.Context, aggFunc *aggregation.AggFuncDesc) expression.Expression { - switch aggFunc.Name { - case ast.AggFuncCount: - return a.rewriteCount(ctx, aggFunc.Args, aggFunc.RetTp) - case ast.AggFuncSum: - return a.wrapCastFunction(ctx, aggFunc.Args[0], aggFunc.RetTp) - case ast.AggFuncAvg: - return a.wrapCastFunction(ctx, aggFunc.Args[0], aggFunc.RetTp) - case ast.AggFuncFirstRow, ast.AggFuncMax, ast.AggFuncMin: - return a.wrapCastFunction(ctx, aggFunc.Args[0], aggFunc.RetTp) - case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor: - return a.rewriteBitFunc(ctx, aggFunc.Name, aggFunc.Args[0], aggFunc.RetTp) - case ast.AggFuncGroupConcat: - return a.wrapCastFunction(ctx, aggFunc.Args[0], aggFunc.RetTp) - default: - panic("Unsupported function") - } -} - -func (a *aggregationEliminateChecker) rewriteCount(ctx sessionctx.Context, exprs []expression.Expression, targetTp *types.FieldType) expression.Expression { - // If is count(expr), we will change it to if(isnull(expr), 0, 1). - // If is count(distinct x, y, z) we will change it to if(isnull(x) or isnull(y) or isnull(z), 0, 1). - isNullExprs := make([]expression.Expression, 0, len(exprs)) - for _, expr := range exprs { - isNullExpr := expression.NewFunctionInternal(ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), expr) - isNullExprs = append(isNullExprs, isNullExpr) - } - innerExpr := expression.ComposeDNFCondition(ctx, isNullExprs...) - newExpr := expression.NewFunctionInternal(ctx, ast.If, targetTp, innerExpr, expression.Zero, expression.One) - return newExpr -} - -func (a *aggregationEliminateChecker) rewriteBitFunc(ctx sessionctx.Context, funcType string, arg expression.Expression, targetTp *types.FieldType) expression.Expression { - // For not integer type. We need to cast(cast(arg as signed) as unsigned) to make the bit function work. 
- innerCast := expression.WrapWithCastAsInt(ctx, arg) - outerCast := a.wrapCastFunction(ctx, innerCast, targetTp) - var finalExpr expression.Expression - if funcType != ast.AggFuncBitAnd { - finalExpr = expression.NewFunctionInternal(ctx, ast.Ifnull, targetTp, outerCast, expression.Zero.Clone()) - } else { - finalExpr = expression.NewFunctionInternal(ctx, ast.Ifnull, outerCast.GetType(), outerCast, &expression.Constant{Value: types.NewUintDatum(math.MaxUint64), RetType: targetTp}) - } - return finalExpr -} - -// wrapCastFunction will wrap a cast if the targetTp is not equal to the arg's. -func (a *aggregationEliminateChecker) wrapCastFunction(ctx sessionctx.Context, arg expression.Expression, targetTp *types.FieldType) expression.Expression { - if arg.GetType() == targetTp { - return arg - } - return expression.BuildCastFunction(ctx, arg, targetTp) -} - -func (a *aggregationEliminator) optimize(p LogicalPlan) (LogicalPlan, error) { - newChildren := make([]LogicalPlan, 0, len(p.Children())) - for _, child := range p.Children() { - newChild, _ := a.optimize(child) - newChildren = append(newChildren, newChild) - } - p.SetChildren(newChildren...) - agg, ok := p.(*LogicalAggregation) - if !ok { - return p, nil - } - if proj := a.tryToEliminateAggregation(agg); proj != nil { - return proj, nil - } - return p, nil -} From 2fa94d652dec8252cb94df09525a811114533561 Mon Sep 17 00:00:00 2001 From: Yiding Cui Date: Tue, 25 Sep 2018 22:22:47 +0800 Subject: [PATCH 12/16] fix unit-test --- expression/typeinfer_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/expression/typeinfer_test.go b/expression/typeinfer_test.go index 142c0c0c9bf41..bb86b140d5caf 100644 --- a/expression/typeinfer_test.go +++ b/expression/typeinfer_test.go @@ -822,14 +822,14 @@ func (s *testInferTypeSuite) createTestCase4Aggregations() []typeInferTestCase { {"sum(c_decimal)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, 3}, {"sum(1.0)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, 1}, {"sum(1.2e2)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength}, - {"sum(c_char)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0}, + {"sum(c_char)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength}, {"avg(c_int_d)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, 4}, {"avg(c_float_d)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength}, {"avg(c_double_d)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength}, {"avg(c_decimal)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, 7}, {"avg(1.0)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, 5}, {"avg(1.2e2)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength}, - {"avg(c_char)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0}, + {"avg(c_char)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength}, {"group_concat(c_int_d)", mysql.TypeVarString, charset.CharsetUTF8, 0, mysql.MaxBlobWidth, 0}, } } From 1ffeb5385c29333f2e15ec93f8e3861874ca3f57 Mon Sep 17 00:00:00 2001 From: Yiding Cui Date: Wed, 26 Sep 2018 13:30:17 +0800 Subject: [PATCH 13/16] fix 
behavior when opening push down. --- planner/core/rule_aggregation_elimination.go | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/planner/core/rule_aggregation_elimination.go b/planner/core/rule_aggregation_elimination.go index 8e92e117931d4..f062e74a2f9cd 100644 --- a/planner/core/rule_aggregation_elimination.go +++ b/planner/core/rule_aggregation_elimination.go @@ -69,17 +69,14 @@ func (a *aggregationEliminateChecker) convertAggToProj(agg *LogicalAggregation) func (a *aggregationEliminateChecker) rewriteExpr(ctx sessionctx.Context, aggFunc *aggregation.AggFuncDesc) expression.Expression { switch aggFunc.Name { case ast.AggFuncCount: + if aggFunc.Mode == aggregation.FinalMode { + return a.wrapCastFunction(ctx, aggFunc.Args[0], aggFunc.RetTp) + } return a.rewriteCount(ctx, aggFunc.Args, aggFunc.RetTp) - case ast.AggFuncSum: - return a.wrapCastFunction(ctx, aggFunc.Args[0], aggFunc.RetTp) - case ast.AggFuncAvg: - return a.wrapCastFunction(ctx, aggFunc.Args[0], aggFunc.RetTp) - case ast.AggFuncFirstRow, ast.AggFuncMax, ast.AggFuncMin: + case ast.AggFuncSum, ast.AggFuncAvg, ast.AggFuncFirstRow, ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncGroupConcat: return a.wrapCastFunction(ctx, aggFunc.Args[0], aggFunc.RetTp) case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor: return a.rewriteBitFunc(ctx, aggFunc.Name, aggFunc.Args[0], aggFunc.RetTp) - case ast.AggFuncGroupConcat: - return a.wrapCastFunction(ctx, aggFunc.Args[0], aggFunc.RetTp) default: panic("Unsupported function") } From c45d8ee22052551b48ebff2f9f93c303e4eb8314 Mon Sep 17 00:00:00 2001 From: Yiding Cui Date: Thu, 27 Sep 2018 12:19:43 +0800 Subject: [PATCH 14/16] change the order of the rule. --- planner/core/optimizer.go | 4 ++-- planner/core/physical_plan_test.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/planner/core/optimizer.go b/planner/core/optimizer.go index 531193e4b7098..67e93412597aa 100644 --- a/planner/core/optimizer.go +++ b/planner/core/optimizer.go @@ -32,8 +32,8 @@ const ( flagPrunColumns uint64 = 1 << iota flagEliminateProjection flagBuildKeyInfo - flagEliminateAgg flagDecorrelate + flagEliminateAgg flagMaxMinEliminate flagPredicatePushDown flagPartitionProcessor @@ -45,8 +45,8 @@ var optRuleList = []logicalOptRule{ &columnPruner{}, &projectionEliminater{}, &buildKeySolver{}, - &aggregationEliminator{}, &decorrelateSolver{}, + &aggregationEliminator{}, &maxMinEliminator{}, &ppdSolver{}, &partitionProcessor{}, diff --git a/planner/core/physical_plan_test.go b/planner/core/physical_plan_test.go index 2a82e8bac589d..332171a49f6f8 100644 --- a/planner/core/physical_plan_test.go +++ b/planner/core/physical_plan_test.go @@ -851,7 +851,7 @@ func (s *testPlanSuite) TestDAGPlanBuilderAgg(c *C) { }, { sql: "select (select count(1) k from t s where s.a = t.a having k != 0) from t", - best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t)->StreamAgg)->StreamAgg->Sel([ne(k, 0)])}(test.t.a,s.a)->Projection->Projection", + best: "MergeLeftOuterJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t.a,s.a)->Projection->Projection", }, // Test stream agg with multi group by columns. { From 7d5dd0082d41adf120d223dcb0bac899343ae1ea Mon Sep 17 00:00:00 2001 From: Yiding Cui Date: Thu, 27 Sep 2018 13:30:33 +0800 Subject: [PATCH 15/16] fix explain test. 
--- cmd/explaintest/r/explain_easy.result | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/cmd/explaintest/r/explain_easy.result b/cmd/explaintest/r/explain_easy.result index 3ff9b16527833..839af28e3f448 100644 --- a/cmd/explaintest/r/explain_easy.result +++ b/cmd/explaintest/r/explain_easy.result @@ -96,15 +96,14 @@ TableReader_11 2.00 root data:TableScan_10 └─TableScan_10 2.00 cop table:t1, range:[0,0], [1,1], keep order:false, stats:pseudo explain select (select count(1) k from t1 s where s.c1 = t1.c1 having k != 0) from t1; id count task operator info -Projection_13 10000.00 root k -└─Projection_14 10000.00 root test.t1.c1, ifnull(5_col_0, 0) - └─MergeJoin_15 10000.00 root left outer join, left key:test.t1.c1, right key:s.c1 - ├─TableReader_18 10000.00 root data:TableScan_17 - │ └─TableScan_17 10000.00 cop table:t1, range:[-inf,+inf], keep order:true, stats:pseudo - └─Selection_20 8000.00 root ne(k, 0) - └─Projection_21 10000.00 root 1, s.c1 - └─TableReader_23 10000.00 root data:TableScan_22 - └─TableScan_22 10000.00 cop table:s, range:[-inf,+inf], keep order:true, stats:pseudo +Projection_12 10000.00 root k +└─Projection_13 10000.00 root test.t1.c1, ifnull(5_col_0, 0) + └─MergeJoin_14 10000.00 root left outer join, left key:test.t1.c1, right key:s.c1 + ├─TableReader_17 10000.00 root data:TableScan_16 + │ └─TableScan_16 10000.00 cop table:t1, range:[-inf,+inf], keep order:true, stats:pseudo + └─Projection_19 8000.00 root 1, s.c1 + └─TableReader_21 10000.00 root data:TableScan_20 + └─TableScan_20 10000.00 cop table:s, range:[-inf,+inf], keep order:true, stats:pseudo explain select * from information_schema.columns; id count task operator info MemTableScan_4 10000.00 root From a50d975f7b174a9b0924f8ebda5b7297714b8f49 Mon Sep 17 00:00:00 2001 From: Yiding Cui Date: Thu, 27 Sep 2018 17:28:48 +0800 Subject: [PATCH 16/16] remove unnecessary change after #7792 --- expression/aggregation/descriptor.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/expression/aggregation/descriptor.go b/expression/aggregation/descriptor.go index 5370a6248c850..2d40426374518 100644 --- a/expression/aggregation/descriptor.go +++ b/expression/aggregation/descriptor.go @@ -297,7 +297,6 @@ func (a *AggFuncDesc) typeInfer4Sum(ctx sessionctx.Context) { if a.RetTp.Decimal < 0 || a.RetTp.Decimal > mysql.MaxDecimalScale { a.RetTp.Decimal = mysql.MaxDecimalScale } - a.RetTp.Flag |= a.Args[0].GetType().Flag & mysql.UnsignedFlag // TODO: a.Args[0] = expression.WrapWithCastAsDecimal(ctx, a.Args[0]) case mysql.TypeDouble, mysql.TypeFloat: a.RetTp = types.NewFieldType(mysql.TypeDouble) @@ -322,7 +321,6 @@ func (a *AggFuncDesc) typeInfer4Avg(ctx sessionctx.Context) { } else { a.RetTp.Decimal = mathutil.Min(a.Args[0].GetType().Decimal+types.DivFracIncr, mysql.MaxDecimalScale) } - a.RetTp.Flag |= a.Args[0].GetType().Flag & mysql.UnsignedFlag a.RetTp.Flen = mysql.MaxDecimalWidth // TODO: a.Args[0] = expression.WrapWithCastAsDecimal(ctx, a.Args[0]) case mysql.TypeDouble, mysql.TypeFloat: