Skip to content

Commit

Permalink
planner/core: make join reorder by dp work (#8816)
Browse files Browse the repository at this point in the history
  • Loading branch information
winoros authored and zz-jason committed Apr 28, 2019
1 parent 3f89a6e commit 9630d57
Show file tree
Hide file tree
Showing 15 changed files with 265 additions and 97 deletions.
13 changes: 13 additions & 0 deletions expression/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,19 @@ func Filter(result []Expression, input []Expression, filter func(Expression) boo
return result
}

// FilterOutInPlace do the filtering out in place.
// The remained are the ones who doesn't match the filter, storing in the original slice.
// The filteredOut are the ones match the filter, storing in a new slice.
func FilterOutInPlace(input []Expression, filter func(Expression) bool) (remained, filteredOut []Expression) {
for i := len(input) - 1; i >= 0; i-- {
if filter(input[i]) {
filteredOut = append(filteredOut, input[i])
input = append(input[:i], input[i+1:]...)
}
}
return input, filteredOut
}

// ExtractColumns extracts all columns from an expression.
func ExtractColumns(expr Expression) (cols []*Column) {
// Pre-allocate a slice to reduce allocation, 8 doesn't have special meaning.
Expand Down
14 changes: 14 additions & 0 deletions expression/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,20 @@ func (s *testUtilSuite) TestFilter(c *check.C) {
c.Assert(result, check.HasLen, 1)
}

func (s *testUtilSuite) TestFilterOutInPlace(c *check.C) {
conditions := []Expression{
newFunction(ast.EQ, newColumn(0), newColumn(1)),
newFunction(ast.EQ, newColumn(1), newColumn(2)),
newFunction(ast.LogicOr, newLonglong(1), newColumn(0)),
}
remained, filtered := FilterOutInPlace(conditions, isLogicOrFunction)
c.Assert(len(remained), check.Equals, 2)
c.Assert(remained[0].(*ScalarFunction).FuncName.L, check.Equals, "eq")
c.Assert(remained[1].(*ScalarFunction).FuncName.L, check.Equals, "eq")
c.Assert(len(filtered), check.Equals, 1)
c.Assert(filtered[0].(*ScalarFunction).FuncName.L, check.Equals, "or")
}

func isLogicOrFunction(e Expression) bool {
if f, ok := e.(*ScalarFunction); ok {
return f.FuncName.L == ast.LogicOr
Expand Down
2 changes: 1 addition & 1 deletion planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ func (er *expressionRewriter) handleInSubquery(v *ast.PatternInExpr) (ast.Node,
// We need to try to eliminate the agg and the projection produced by this operation.
er.b.optFlag |= flagEliminateAgg
er.b.optFlag |= flagEliminateProjection
er.b.optFlag |= flagJoinReOrderGreedy
er.b.optFlag |= flagJoinReOrder
// Build distinct for the inner query.
agg := er.b.buildDistinct(np, np.Schema().Len())
for _, col := range agg.schema.Columns {
Expand Down
2 changes: 1 addition & 1 deletion planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ func (b *PlanBuilder) buildJoin(joinNode *ast.Join) (LogicalPlan, error) {
joinPlan.JoinType = RightOuterJoin
resetNotNullFlag(joinPlan.schema, 0, leftPlan.Schema().Len())
default:
b.optFlag = b.optFlag | flagJoinReOrderGreedy
b.optFlag = b.optFlag | flagJoinReOrder
joinPlan.JoinType = InnerJoin
}

Expand Down
2 changes: 1 addition & 1 deletion planner/core/logical_plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,7 @@ func (s *testPlanSuite) TestJoinReOrder(c *C) {

p, err := BuildLogicalPlan(s.ctx, stmt, s.is)
c.Assert(err, IsNil)
p, err = logicalOptimize(flagPredicatePushDown|flagJoinReOrderGreedy, p.(LogicalPlan))
p, err = logicalOptimize(flagPredicatePushDown|flagJoinReOrder, p.(LogicalPlan))
c.Assert(err, IsNil)
c.Assert(ToString(p), Equals, tt.best, Commentf("for %s", tt.sql))
}
Expand Down
2 changes: 1 addition & 1 deletion planner/core/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ const (
flagPartitionProcessor
flagPushDownAgg
flagPushDownTopN
flagJoinReOrderGreedy
flagJoinReOrder
)

var optRuleList = []logicalOptRule{
Expand Down
29 changes: 15 additions & 14 deletions planner/core/rule_join_reorder.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,19 @@ func (s *joinReOrderSolver) optimizeRecursive(ctx sessionctx.Context, p LogicalP
ctx: ctx,
otherConds: otherConds,
}
groupSolver := &joinReorderGreedySingleGroupSolver{
baseSingleGroupJoinOrderSolver: baseGroupSolver,
eqEdges: eqEdges,
if len(curJoinGroup) > ctx.GetSessionVars().TiDBOptJoinReorderThreshold {
groupSolver := &joinReorderGreedySolver{
baseSingleGroupJoinOrderSolver: baseGroupSolver,
eqEdges: eqEdges,
}
p, err = groupSolver.solve(curJoinGroup)
} else {
dpSolver := &joinReorderDPSolver{
baseSingleGroupJoinOrderSolver: baseGroupSolver,
}
dpSolver.newJoin = dpSolver.newJoinWithEdges
p, err = dpSolver.solve(curJoinGroup, expression.ScalarFuncs2Exprs(eqEdges))
}
p, err = groupSolver.solve(curJoinGroup)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -143,22 +151,15 @@ func (s *baseSingleGroupJoinOrderSolver) newCartesianJoin(lChild, rChild Logical
return join
}

func (s *baseSingleGroupJoinOrderSolver) newJoinWithEdges(eqEdges []*expression.ScalarFunction, remainedOtherConds []expression.Expression,
lChild, rChild LogicalPlan) (*LogicalJoin, []expression.Expression) {
func (s *baseSingleGroupJoinOrderSolver) newJoinWithEdges(lChild, rChild LogicalPlan, eqEdges []*expression.ScalarFunction, otherConds []expression.Expression) LogicalPlan {
newJoin := s.newCartesianJoin(lChild, rChild)
newJoin.EqualConditions = eqEdges
newJoin.OtherConditions = otherConds
for _, eqCond := range newJoin.EqualConditions {
newJoin.LeftJoinKeys = append(newJoin.LeftJoinKeys, eqCond.GetArgs()[0].(*expression.Column))
newJoin.RightJoinKeys = append(newJoin.RightJoinKeys, eqCond.GetArgs()[1].(*expression.Column))
}
for i := len(remainedOtherConds) - 1; i >= 0; i-- {
cols := expression.ExtractColumns(remainedOtherConds[i])
if newJoin.schema.ColumnsIndices(cols) != nil {
newJoin.OtherConditions = append(newJoin.OtherConditions, remainedOtherConds[i])
remainedOtherConds = append(remainedOtherConds[:i], remainedOtherConds[i+1:]...)
}
}
return newJoin, remainedOtherConds
return newJoin
}

// calcJoinCumCost calculates the cumulative cost of the join node.
Expand Down
170 changes: 134 additions & 36 deletions planner/core/rule_join_reorder_dp.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,32 +18,47 @@ import (

"github.com/pingcap/parser/ast"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/sessionctx"
)

type joinReorderDPSolver struct {
ctx sessionctx.Context
newJoin func(lChild, rChild LogicalPlan, eqConds []*expression.ScalarFunction) LogicalPlan
*baseSingleGroupJoinOrderSolver
newJoin func(lChild, rChild LogicalPlan, eqConds []*expression.ScalarFunction, otherConds []expression.Expression) LogicalPlan
}

type joinGroupEdge struct {
type joinGroupEqEdge struct {
nodeIDs []int
edge *expression.ScalarFunction
}

func (s *joinReorderDPSolver) solve(joinGroup []LogicalPlan, conds []expression.Expression) (LogicalPlan, error) {
adjacents := make([][]int, len(joinGroup))
totalEdges := make([]joinGroupEdge, 0, len(conds))
addEdge := func(node1, node2 int, edgeContent *expression.ScalarFunction) {
totalEdges = append(totalEdges, joinGroupEdge{
type joinGroupNonEqEdge struct {
nodeIDs []int
nodeIDMask uint
expr expression.Expression
}

func (s *joinReorderDPSolver) solve(joinGroup []LogicalPlan, eqConds []expression.Expression) (LogicalPlan, error) {
for _, node := range joinGroup {
_, err := node.recursiveDeriveStats()
if err != nil {
return nil, err
}
s.curJoinGroup = append(s.curJoinGroup, &jrNode{
p: node,
cumCost: s.baseNodeCumCost(node),
})
}
adjacents := make([][]int, len(s.curJoinGroup))
totalEqEdges := make([]joinGroupEqEdge, 0, len(eqConds))
addEqEdge := func(node1, node2 int, edgeContent *expression.ScalarFunction) {
totalEqEdges = append(totalEqEdges, joinGroupEqEdge{
nodeIDs: []int{node1, node2},
edge: edgeContent,
})
adjacents[node1] = append(adjacents[node1], node2)
adjacents[node2] = append(adjacents[node2], node1)
}
// Build Graph for join group
for _, cond := range conds {
for _, cond := range eqConds {
sf := cond.(*expression.ScalarFunction)
lCol := sf.GetArgs()[0].(*expression.Column)
rCol := sf.GetArgs()[1].(*expression.Column)
Expand All @@ -55,7 +70,26 @@ func (s *joinReorderDPSolver) solve(joinGroup []LogicalPlan, conds []expression.
if err != nil {
return nil, err
}
addEdge(lIdx, rIdx, sf)
addEqEdge(lIdx, rIdx, sf)
}
totalNonEqEdges := make([]joinGroupNonEqEdge, 0, len(s.otherConds))
for _, cond := range s.otherConds {
cols := expression.ExtractColumns(cond)
mask := uint(0)
ids := make([]int, 0, len(cols))
for _, col := range cols {
idx, err := findNodeIndexInGroup(joinGroup, col)
if err != nil {
return nil, err
}
ids = append(ids, idx)
mask |= 1 << uint(idx)
}
totalNonEqEdges = append(totalNonEqEdges, joinGroupNonEqEdge{
nodeIDs: ids,
nodeIDMask: mask,
expr: cond,
})
}
visited := make([]bool, len(joinGroup))
nodeID2VisitID := make([]int, len(joinGroup))
Expand All @@ -66,15 +100,37 @@ func (s *joinReorderDPSolver) solve(joinGroup []LogicalPlan, conds []expression.
continue
}
visitID2NodeID := s.bfsGraph(i, visited, adjacents, nodeID2VisitID)
nodeIDMask := uint(0)
for _, nodeID := range visitID2NodeID {
nodeIDMask |= 1 << uint(nodeID)
}
var subNonEqEdges []joinGroupNonEqEdge
for i := len(totalNonEqEdges) - 1; i >= 0; i-- {
// If this edge is not the subset of the current sub graph.
if totalNonEqEdges[i].nodeIDMask&nodeIDMask != totalNonEqEdges[i].nodeIDMask {
continue
}
newMask := uint(0)
for _, nodeID := range totalNonEqEdges[i].nodeIDs {
newMask |= 1 << uint(nodeID2VisitID[nodeID])
}
totalNonEqEdges[i].nodeIDMask = newMask
subNonEqEdges = append(subNonEqEdges, totalNonEqEdges[i])
totalNonEqEdges = append(totalNonEqEdges[:i], totalNonEqEdges[i+1:]...)
}
// Do DP on each sub graph.
join, err := s.dpGraph(visitID2NodeID, nodeID2VisitID, joinGroup, totalEdges)
join, err := s.dpGraph(visitID2NodeID, nodeID2VisitID, joinGroup, totalEqEdges, subNonEqEdges)
if err != nil {
return nil, err
}
joins = append(joins, join)
}
remainedOtherConds := make([]expression.Expression, 0, len(totalNonEqEdges))
for _, edge := range totalNonEqEdges {
remainedOtherConds = append(remainedOtherConds, edge.expr)
}
// Build bushy tree for cartesian joins.
return s.makeBushyJoin(joins), nil
return s.makeBushyJoin(joins, remainedOtherConds), nil
}

// bfsGraph bfs a sub graph starting at startPos. And relabel its label for future use.
Expand All @@ -98,13 +154,16 @@ func (s *joinReorderDPSolver) bfsGraph(startNode int, visited []bool, adjacents
return visitID2NodeID
}

func (s *joinReorderDPSolver) dpGraph(newPos2OldPos, oldPos2NewPos []int, joinGroup []LogicalPlan, totalEdges []joinGroupEdge) (LogicalPlan, error) {
nodeCnt := uint(len(newPos2OldPos))
bestPlan := make([]LogicalPlan, 1<<nodeCnt)
bestCost := make([]int64, 1<<nodeCnt)
// dpGraph is the core part of this algorithm.
// It implements the traditional join reorder algorithm: DP by subset using the following formula:
// bestPlan[S:set of node] = the best one among Join(bestPlan[S1:subset of S], bestPlan[S2: S/S1])
func (s *joinReorderDPSolver) dpGraph(visitID2NodeID, nodeID2VisitID []int, joinGroup []LogicalPlan,
totalEqEdges []joinGroupEqEdge, totalNonEqEdges []joinGroupNonEqEdge) (LogicalPlan, error) {
nodeCnt := uint(len(visitID2NodeID))
bestPlan := make([]*jrNode, 1<<nodeCnt)
// bestPlan[s] is nil can be treated as bestCost[s] = +inf.
for i := uint(0); i < nodeCnt; i++ {
bestPlan[1<<i] = joinGroup[newPos2OldPos[i]]
bestPlan[1<<i] = s.curJoinGroup[visitID2NodeID[i]]
}
// Enumerate the nodeBitmap from small to big, make sure that S1 must be enumerated before S2 if S1 belongs to S2.
for nodeBitmap := uint(1); nodeBitmap < (1 << nodeCnt); nodeBitmap++ {
Expand All @@ -122,38 +181,58 @@ func (s *joinReorderDPSolver) dpGraph(newPos2OldPos, oldPos2NewPos []int, joinGr
continue
}
// Get the edge connecting the two parts.
usedEdges := s.nodesAreConnected(sub, remain, oldPos2NewPos, totalEdges)
usedEdges, otherConds := s.nodesAreConnected(sub, remain, nodeID2VisitID, totalEqEdges, totalNonEqEdges)
// Here we only check equal condition currently.
if len(usedEdges) == 0 {
continue
}
join, err := s.newJoinWithEdge(bestPlan[sub], bestPlan[remain], usedEdges)
join, err := s.newJoinWithEdge(bestPlan[sub].p, bestPlan[remain].p, usedEdges, otherConds)
if err != nil {
return nil, err
}
if bestPlan[nodeBitmap] == nil || bestCost[nodeBitmap] > join.statsInfo().Count()+bestCost[remain]+bestCost[sub] {
bestPlan[nodeBitmap] = join
bestCost[nodeBitmap] = join.statsInfo().Count() + bestCost[remain] + bestCost[sub]
curCost := s.calcJoinCumCost(join, bestPlan[sub], bestPlan[remain])
if bestPlan[nodeBitmap] == nil {
bestPlan[nodeBitmap] = &jrNode{
p: join,
cumCost: curCost,
}
} else if bestPlan[nodeBitmap].cumCost > curCost {
bestPlan[nodeBitmap].p = join
bestPlan[nodeBitmap].cumCost = curCost
}
}
}
return bestPlan[(1<<nodeCnt)-1], nil
return bestPlan[(1<<nodeCnt)-1].p, nil
}

func (s *joinReorderDPSolver) nodesAreConnected(leftMask, rightMask uint, oldPos2NewPos []int, totalEdges []joinGroupEdge) []joinGroupEdge {
var usedEdges []joinGroupEdge
for _, edge := range totalEdges {
func (s *joinReorderDPSolver) nodesAreConnected(leftMask, rightMask uint, oldPos2NewPos []int,
totalEqEdges []joinGroupEqEdge, totalNonEqEdges []joinGroupNonEqEdge) ([]joinGroupEqEdge, []expression.Expression) {
var (
usedEqEdges []joinGroupEqEdge
otherConds []expression.Expression
)
for _, edge := range totalEqEdges {
lIdx := uint(oldPos2NewPos[edge.nodeIDs[0]])
rIdx := uint(oldPos2NewPos[edge.nodeIDs[1]])
if (leftMask&(1<<lIdx)) > 0 && (rightMask&(1<<rIdx)) > 0 {
usedEdges = append(usedEdges, edge)
} else if (leftMask&(1<<rIdx)) > 0 && (rightMask&(1<<lIdx)) > 0 {
usedEdges = append(usedEdges, edge)
if ((leftMask&(1<<lIdx)) > 0 && (rightMask&(1<<rIdx)) > 0) || ((leftMask&(1<<rIdx)) > 0 && (rightMask&(1<<lIdx)) > 0) {
usedEqEdges = append(usedEqEdges, edge)
}
}
return usedEdges
for _, edge := range totalNonEqEdges {
// If the result is false, means that the current group hasn't covered the columns involved in the expression.
if edge.nodeIDMask&(leftMask|rightMask) != edge.nodeIDMask {
continue
}
// Check whether this expression is only built from one side of the join.
if edge.nodeIDMask&leftMask == 0 || edge.nodeIDMask&rightMask == 0 {
continue
}
otherConds = append(otherConds, edge.expr)
}
return usedEqEdges, otherConds
}

func (s *joinReorderDPSolver) newJoinWithEdge(leftPlan, rightPlan LogicalPlan, edges []joinGroupEdge) (LogicalPlan, error) {
func (s *joinReorderDPSolver) newJoinWithEdge(leftPlan, rightPlan LogicalPlan, edges []joinGroupEqEdge, otherConds []expression.Expression) (LogicalPlan, error) {
var eqConds []*expression.ScalarFunction
for _, edge := range edges {
lCol := edge.edge.GetArgs()[0].(*expression.Column)
Expand All @@ -165,21 +244,29 @@ func (s *joinReorderDPSolver) newJoinWithEdge(leftPlan, rightPlan LogicalPlan, e
eqConds = append(eqConds, newSf)
}
}
join := s.newJoin(leftPlan, rightPlan, eqConds)
join := s.newJoin(leftPlan, rightPlan, eqConds, otherConds)
_, err := join.recursiveDeriveStats()
return join, err
}

// Make cartesian join as bushy tree.
func (s *joinReorderDPSolver) makeBushyJoin(cartesianJoinGroup []LogicalPlan) LogicalPlan {
func (s *joinReorderDPSolver) makeBushyJoin(cartesianJoinGroup []LogicalPlan, otherConds []expression.Expression) LogicalPlan {
for len(cartesianJoinGroup) > 1 {
resultJoinGroup := make([]LogicalPlan, 0, len(cartesianJoinGroup))
for i := 0; i < len(cartesianJoinGroup); i += 2 {
if i+1 == len(cartesianJoinGroup) {
resultJoinGroup = append(resultJoinGroup, cartesianJoinGroup[i])
break
}
resultJoinGroup = append(resultJoinGroup, s.newJoin(cartesianJoinGroup[i], cartesianJoinGroup[i+1], nil))
// TODO:Since the other condition may involve more than two tables, e.g. t1.a = t2.b+t3.c.
// So We'll need a extra stage to deal with it.
// Currently, we just add it when building cartesianJoinGroup.
mergedSchema := expression.MergeSchema(cartesianJoinGroup[i].Schema(), cartesianJoinGroup[i+1].Schema())
var usedOtherConds []expression.Expression
otherConds, usedOtherConds = expression.FilterOutInPlace(otherConds, func(expr expression.Expression) bool {
return expression.ExprFromSchema(expr, mergedSchema)
})
resultJoinGroup = append(resultJoinGroup, s.newJoin(cartesianJoinGroup[i], cartesianJoinGroup[i+1], nil, usedOtherConds))
}
cartesianJoinGroup = resultJoinGroup
}
Expand All @@ -194,3 +281,14 @@ func findNodeIndexInGroup(group []LogicalPlan, col *expression.Column) (int, err
}
return -1, ErrUnknownColumn.GenWithStackByArgs(col, "JOIN REORDER RULE")
}

func (s *joinReorderDPSolver) newJoinWithConds(leftPlan, rightPlan LogicalPlan, eqConds []*expression.ScalarFunction, otherConds []expression.Expression) LogicalPlan {
join := s.newCartesianJoin(leftPlan, rightPlan)
join.EqualConditions = eqConds
join.OtherConditions = otherConds
for _, eqCond := range join.EqualConditions {
join.LeftJoinKeys = append(join.LeftJoinKeys, eqCond.GetArgs()[0].(*expression.Column))
join.RightJoinKeys = append(join.RightJoinKeys, eqCond.GetArgs()[1].(*expression.Column))
}
return join
}
Loading

0 comments on commit 9630d57

Please sign in to comment.