Skip to content

Commit

Permalink
planner: Add trace for agg pushdown rule (#30262)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yisaer authored Dec 2, 2021
1 parent cbe5240 commit dae711c
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 8 deletions.
31 changes: 31 additions & 0 deletions planner/core/logical_plan_trace_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,36 @@ func (s *testPlanSuite) TestSingleRuleTraceStep(c *C) {
},
},
},
{
sql: "select count(*) from t a , t b, t c",
flags: []uint64{flagBuildKeyInfo, flagPrunColumns, flagPushDownAgg},
assertRuleName: "aggregation_push_down",
assertRuleSteps: []assertTraceStep{
{
assertAction: "agg[6] pushed down across join[5], and join right path becomes agg[8]",
assertReason: "agg[6]'s functions[count(Column#38)] are decomposable with join",
},
},
},
{
sql: "select sum(c1) from (select c c1, d c2 from t a union all select a c1, b c2 from t b) x group by c2",
flags: []uint64{flagBuildKeyInfo, flagPrunColumns, flagPushDownAgg},
assertRuleName: "aggregation_push_down",
assertRuleSteps: []assertTraceStep{
{
assertAction: "agg[8] pushed down, and union[5]'s children changed into[[id:11,tp:Aggregation],[id:12,tp:Aggregation]]",
assertReason: "agg[8] functions[sum(Column#28)] are decomposable with union",
},
{
assertAction: "proj[6] is eliminated, and agg[11]'s functions changed into[sum(test.t.c),firstrow(test.t.d)]",
assertReason: "Proj[6] is directly below an agg[11] and has no side effects",
},
{
assertAction: "proj[7] is eliminated, and agg[12]'s functions changed into[sum(test.t.a),firstrow(test.t.b)]",
assertReason: "Proj[7] is directly below an agg[12] and has no side effects",
},
},
},
}

for i, tc := range tt {
Expand All @@ -123,6 +153,7 @@ func (s *testPlanSuite) TestSingleRuleTraceStep(c *C) {
c.Assert(err, IsNil, comment)
sctx := MockContext()
sctx.GetSessionVars().StmtCtx.EnableOptimizeTrace = true
sctx.GetSessionVars().AllowAggPushDown = true
builder, _ := NewPlanBuilder().Init(sctx, s.is, &hint.BlockHintProcessor{})
domain.GetDomain(sctx).MockInfoCacheAndLoadInfoSchema(s.is)
ctx := context.TODO()
Expand Down
93 changes: 87 additions & 6 deletions planner/core/rule_aggregation_push_down.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
package core

import (
"bytes"
"context"
"fmt"

"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/expression/aggregation"
Expand All @@ -32,6 +34,7 @@ type aggregationPushDownSolver struct {
// isDecomposable checks if an aggregate function is decomposable. An aggregation function $F$ is decomposable
// if there exist aggregation functions F_1 and F_2 such that F(S_1 union all S_2) = F_2(F_1(S_1),F_1(S_2)),
// where S_1 and S_2 are two sets of values. We call S_1 and S_2 partial groups.
// For example, Max(S_1 union S_2) = Max(Max(S_1) union Max(S_2)), thus we think Max is decomposable.
// It's easy to see that max, min, first row is decomposable, no matter whether it's distinct, but sum(distinct) and
// count(distinct) is not.
// Currently we don't support avg and concat.
Expand Down Expand Up @@ -207,7 +210,8 @@ func (a *aggregationPushDownSolver) decompose(ctx sessionctx.Context, aggFunc *a
// tryToPushDownAgg tries to push down an aggregate function into a join path. If all aggFuncs are first row, we won't
// process it temporarily. If not, We will add additional group by columns and first row functions. We make a new aggregation operator.
// If the pushed aggregation is grouped by unique key, it's no need to push it down.
func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, join *LogicalJoin, childIdx int, aggHints aggHintInfo, blockOffset int) (_ LogicalPlan, err error) {
func (a *aggregationPushDownSolver) tryToPushDownAgg(oldAgg *LogicalAggregation, aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column,
join *LogicalJoin, childIdx int, aggHints aggHintInfo, blockOffset int, opt *logicalOptimizeOp) (_ LogicalPlan, err error) {
child := join.children[childIdx]
if aggregation.IsAllFirstRow(aggFuncs) {
return child, nil
Expand Down Expand Up @@ -241,6 +245,7 @@ func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.Agg
return child, nil
}
}
appendAggPushDownAcrossJoinTraceStep(oldAgg, agg, aggFuncs, join, childIdx, opt)
return agg, nil
}

Expand Down Expand Up @@ -371,7 +376,7 @@ func (a *aggregationPushDownSolver) optimize(ctx context.Context, p LogicalPlan,
return a.aggPushDown(p, opt)
}

func (a *aggregationPushDownSolver) tryAggPushDownForUnion(union *LogicalUnionAll, agg *LogicalAggregation) error {
func (a *aggregationPushDownSolver) tryAggPushDownForUnion(union *LogicalUnionAll, agg *LogicalAggregation, opt *logicalOptimizeOp) error {
for _, aggFunc := range agg.AggFuncs {
if !a.isDecomposableWithUnion(aggFunc) {
return nil
Expand All @@ -391,6 +396,7 @@ func (a *aggregationPushDownSolver) tryAggPushDownForUnion(union *LogicalUnionAl
}
union.SetSchema(expression.NewSchema(newChildren[0].Schema().Clone().Columns...))
union.SetChildren(newChildren...)
appendAggPushDownAcrossUnionTraceStep(union, agg, opt)
return nil
}

Expand All @@ -402,6 +408,9 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan, opt *logicalOptim
p = proj
} else {
child := agg.children[0]
// For example, we can optimize 'select sum(a.id) from t as a,t as b where a.id = b.id;' as
// 'select sum(agg) from (select sum(id) as agg,id from t group by id) as a, t as b where a.id = b.id;'
// by pushing down sum aggregation functions.
if join, ok1 := child.(*LogicalJoin); ok1 && a.checkValidJoin(join) && p.SCtx().GetSessionVars().AllowAggPushDown {
if valid, leftAggFuncs, rightAggFuncs, leftGbyCols, rightGbyCols := a.splitAggFuncsAndGbyCols(agg, join); valid {
var lChild, rChild LogicalPlan
Expand All @@ -412,15 +421,15 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan, opt *logicalOptim
if rightInvalid {
rChild = join.children[1]
} else {
rChild, err = a.tryToPushDownAgg(rightAggFuncs, rightGbyCols, join, 1, agg.aggHints, agg.blockOffset)
rChild, err = a.tryToPushDownAgg(agg, rightAggFuncs, rightGbyCols, join, 1, agg.aggHints, agg.blockOffset, opt)
if err != nil {
return nil, err
}
}
if leftInvalid {
lChild = join.children[0]
} else {
lChild, err = a.tryToPushDownAgg(leftAggFuncs, leftGbyCols, join, 0, agg.aggHints, agg.blockOffset)
lChild, err = a.tryToPushDownAgg(agg, leftAggFuncs, leftGbyCols, join, 0, agg.aggHints, agg.blockOffset, opt)
if err != nil {
return nil, err
}
Expand All @@ -433,6 +442,7 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan, opt *logicalOptim
p = proj
}
}
// push aggregation across projection
} else if proj, ok1 := child.(*LogicalProjection); ok1 {
// TODO: This optimization is not always reasonable. We have not supported pushing projection to kv layer yet,
// so we must do this optimization.
Expand All @@ -445,9 +455,11 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan, opt *logicalOptim
break
}
}
oldAggFuncsArgs := make([][]expression.Expression, 0, len(agg.AggFuncs))
newAggFuncsArgs := make([][]expression.Expression, 0, len(agg.AggFuncs))
if noSideEffects {
for _, aggFunc := range agg.AggFuncs {
oldAggFuncsArgs = append(oldAggFuncsArgs, aggFunc.Args)
newArgs := make([]expression.Expression, 0, len(aggFunc.Args))
for _, arg := range aggFunc.Args {
newArgs = append(newArgs, expression.ColumnSubstitute(arg, proj.schema, proj.Exprs))
Expand All @@ -470,15 +482,16 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan, opt *logicalOptim
// And then push the new 'Aggregation' below the 'Union All' .
// The final plan tree should be 'Aggregation->Union All->Aggregation->X'.
child = projChild
appendAggPushDownAcrossProjTraceStep(agg, proj, opt)
}
}
if union, ok1 := child.(*LogicalUnionAll); ok1 && p.SCtx().GetSessionVars().AllowAggPushDown {
err := a.tryAggPushDownForUnion(union, agg)
err := a.tryAggPushDownForUnion(union, agg, opt)
if err != nil {
return nil, err
}
} else if union, ok1 := child.(*LogicalPartitionUnionAll); ok1 {
err := a.tryAggPushDownForUnion(&union.LogicalUnionAll, agg)
err := a.tryAggPushDownForUnion(&union.LogicalUnionAll, agg, opt)
if err != nil {
return nil, err
}
Expand All @@ -500,3 +513,71 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan, opt *logicalOptim
func (*aggregationPushDownSolver) name() string {
return "aggregation_push_down"
}

func appendAggPushDownAcrossJoinTraceStep(oldAgg, newAgg *LogicalAggregation, aggFuncs []*aggregation.AggFuncDesc, join *LogicalJoin,
childIdx int, opt *logicalOptimizeOp) {
reason := func() string {
buffer := bytes.NewBufferString(fmt.Sprintf("agg[%v]'s functions[", oldAgg.ID()))
for i, aggFunc := range aggFuncs {
if i > 0 {
buffer.WriteString(",")
}
buffer.WriteString(aggFunc.String())
}
buffer.WriteString("] are decomposable with join")
return buffer.String()
}()
action := func() string {
buffer := bytes.NewBufferString(fmt.Sprintf("agg[%v] pushed down across join[%v], ", oldAgg.ID(), join.ID()))
buffer.WriteString(fmt.Sprintf("and join %v path becomes agg[%v]", func() string {
if childIdx == 0 {
return "left"
}
return "right"
}(), newAgg.ID()))
return buffer.String()
}()
opt.appendStepToCurrent(join.ID(), join.TP(), reason, action)
}

func appendAggPushDownAcrossProjTraceStep(agg *LogicalAggregation, proj *LogicalProjection, opt *logicalOptimizeOp) {
action := func() string {
buffer := bytes.NewBufferString(fmt.Sprintf("proj[%v] is eliminated, and agg[%v]'s functions changed into[", proj.ID(), agg.ID()))
for i, aggFunc := range agg.AggFuncs {
if i > 0 {
buffer.WriteString(",")
}
buffer.WriteString(aggFunc.String())
}
buffer.WriteString("]")
return buffer.String()
}()
reason := fmt.Sprintf("Proj[%v] is directly below an agg[%v] and has no side effects", proj.ID(), agg.ID())
opt.appendStepToCurrent(agg.ID(), agg.TP(), reason, action)
}

func appendAggPushDownAcrossUnionTraceStep(union *LogicalUnionAll, agg *LogicalAggregation, opt *logicalOptimizeOp) {
reason := func() string {
buffer := bytes.NewBufferString(fmt.Sprintf("agg[%v] functions[", agg.ID()))
for i, aggFunc := range agg.AggFuncs {
if i > 0 {
buffer.WriteString(",")
}
buffer.WriteString(aggFunc.String())
}
buffer.WriteString("] are decomposable with union")
return buffer.String()
}()
action := func() string {
buffer := bytes.NewBufferString(fmt.Sprintf("agg[%v] pushed down, and union[%v]'s children changed into[", agg.ID(), union.ID()))
for i, child := range union.Children() {
if i > 0 {
buffer.WriteString(",")
}
buffer.WriteString(fmt.Sprintf("[id:%v,tp:%s]", child.ID(), child.TP()))
}
buffer.WriteString("]")
return buffer.String()
}()
opt.appendStepToCurrent(union.ID(), union.TP(), reason, action)
}
11 changes: 9 additions & 2 deletions planner/core/rule_eliminate_projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ func (pe *projectionEliminator) eliminate(p LogicalPlan, replace map[string]*exp
proj.Exprs[i] = foldedExpr
}
p.Children()[0] = child.Children()[0]
appendProjEliminateTraceStep(proj, child, opt)
appendDupProjEliminateTraceStep(proj, child, opt)
}
}

Expand All @@ -199,6 +199,7 @@ func (pe *projectionEliminator) eliminate(p LogicalPlan, replace map[string]*exp
for i, col := range proj.Schema().Columns {
replace[string(col.HashCode(nil))] = exprs[i].(*expression.Column)
}
appendProjEliminateTraceStep(proj, opt)
return p.Children()[0]
}

Expand Down Expand Up @@ -296,7 +297,7 @@ func (*projectionEliminator) name() string {
return "projection_eliminate"
}

func appendProjEliminateTraceStep(parent, child *LogicalProjection, opt *logicalOptimizeOp) {
func appendDupProjEliminateTraceStep(parent, child *LogicalProjection, opt *logicalOptimizeOp) {
action := func() string {
buffer := bytes.NewBufferString(
fmt.Sprintf("Proj[%v] is eliminated, Proj[%v]'s expressions changed into[", child.ID(), parent.ID()))
Expand All @@ -312,3 +313,9 @@ func appendProjEliminateTraceStep(parent, child *LogicalProjection, opt *logical
reason := fmt.Sprintf("Proj[%v]'s child proj[%v] is redundant", parent.ID(), child.ID())
opt.appendStepToCurrent(child.ID(), child.TP(), reason, action)
}

func appendProjEliminateTraceStep(proj *LogicalProjection, opt *logicalOptimizeOp) {
reason := fmt.Sprintf("Proj[%v]'s Exprs are all Columns", proj.ID())
action := fmt.Sprintf("Proj[%v] is eliminated", proj.ID())
opt.appendStepToCurrent(proj.ID(), proj.TP(), reason, action)
}

0 comments on commit dae711c

Please sign in to comment.