diff --git a/expression/util.go b/expression/util.go index d11f3b45a54b8..881a8baa19a9b 100644 --- a/expression/util.go +++ b/expression/util.go @@ -44,6 +44,19 @@ func Filter(result []Expression, input []Expression, filter func(Expression) boo return result } +// FilterOutInPlace do the filtering out in place. +// The remained are the ones who doesn't match the filter, storing in the original slice. +// The filteredOut are the ones match the filter, storing in a new slice. +func FilterOutInPlace(input []Expression, filter func(Expression) bool) (remained, filteredOut []Expression) { + for i := len(input) - 1; i >= 0; i-- { + if filter(input[i]) { + filteredOut = append(filteredOut, input[i]) + input = append(input[:i], input[i+1:]...) + } + } + return input, filteredOut +} + // ExtractColumns extracts all columns from an expression. func ExtractColumns(expr Expression) (cols []*Column) { // Pre-allocate a slice to reduce allocation, 8 doesn't have special meaning. diff --git a/expression/util_test.go b/expression/util_test.go index 4aaef213bb483..1d9f9b936860f 100644 --- a/expression/util_test.go +++ b/expression/util_test.go @@ -83,6 +83,20 @@ func (s *testUtilSuite) TestFilter(c *check.C) { c.Assert(result, check.HasLen, 1) } +func (s *testUtilSuite) TestFilterOutInPlace(c *check.C) { + conditions := []Expression{ + newFunction(ast.EQ, newColumn(0), newColumn(1)), + newFunction(ast.EQ, newColumn(1), newColumn(2)), + newFunction(ast.LogicOr, newLonglong(1), newColumn(0)), + } + remained, filtered := FilterOutInPlace(conditions, isLogicOrFunction) + c.Assert(len(remained), check.Equals, 2) + c.Assert(remained[0].(*ScalarFunction).FuncName.L, check.Equals, "eq") + c.Assert(remained[1].(*ScalarFunction).FuncName.L, check.Equals, "eq") + c.Assert(len(filtered), check.Equals, 1) + c.Assert(filtered[0].(*ScalarFunction).FuncName.L, check.Equals, "or") +} + func isLogicOrFunction(e Expression) bool { if f, ok := e.(*ScalarFunction); ok { return f.FuncName.L == ast.LogicOr diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index d5e163d1e931b..0f36b2bc32eb2 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -710,7 +710,7 @@ func (er *expressionRewriter) handleInSubquery(v *ast.PatternInExpr) (ast.Node, // We need to try to eliminate the agg and the projection produced by this operation. er.b.optFlag |= flagEliminateAgg er.b.optFlag |= flagEliminateProjection - er.b.optFlag |= flagJoinReOrderGreedy + er.b.optFlag |= flagJoinReOrder // Build distinct for the inner query. agg := er.b.buildDistinct(np, np.Schema().Len()) for _, col := range agg.schema.Columns { diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index d00f28d561647..b9dcec36df240 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -390,7 +390,7 @@ func (b *PlanBuilder) buildJoin(joinNode *ast.Join) (LogicalPlan, error) { joinPlan.JoinType = RightOuterJoin resetNotNullFlag(joinPlan.schema, 0, leftPlan.Schema().Len()) default: - b.optFlag = b.optFlag | flagJoinReOrderGreedy + b.optFlag = b.optFlag | flagJoinReOrder joinPlan.JoinType = InnerJoin } diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index a1e795bf59ed9..578155d266563 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -916,7 +916,7 @@ func (s *testPlanSuite) TestJoinReOrder(c *C) { p, err := BuildLogicalPlan(s.ctx, stmt, s.is) c.Assert(err, IsNil) - p, err = logicalOptimize(flagPredicatePushDown|flagJoinReOrderGreedy, p.(LogicalPlan)) + p, err = logicalOptimize(flagPredicatePushDown|flagJoinReOrder, p.(LogicalPlan)) c.Assert(err, IsNil) c.Assert(ToString(p), Equals, tt.best, Commentf("for %s", tt.sql)) } diff --git a/planner/core/optimizer.go b/planner/core/optimizer.go index 1baf596854a66..cdee37b93fef2 100644 --- a/planner/core/optimizer.go +++ b/planner/core/optimizer.go @@ -44,7 +44,7 @@ const ( flagPartitionProcessor flagPushDownAgg flagPushDownTopN - flagJoinReOrderGreedy + flagJoinReOrder ) var optRuleList = []logicalOptRule{ diff --git a/planner/core/rule_join_reorder.go b/planner/core/rule_join_reorder.go index 8e5a1c4f4d7d2..fac63d725cbb5 100644 --- a/planner/core/rule_join_reorder.go +++ b/planner/core/rule_join_reorder.go @@ -71,11 +71,19 @@ func (s *joinReOrderSolver) optimizeRecursive(ctx sessionctx.Context, p LogicalP ctx: ctx, otherConds: otherConds, } - groupSolver := &joinReorderGreedySingleGroupSolver{ - baseSingleGroupJoinOrderSolver: baseGroupSolver, - eqEdges: eqEdges, + if len(curJoinGroup) > ctx.GetSessionVars().TiDBOptJoinReorderThreshold { + groupSolver := &joinReorderGreedySolver{ + baseSingleGroupJoinOrderSolver: baseGroupSolver, + eqEdges: eqEdges, + } + p, err = groupSolver.solve(curJoinGroup) + } else { + dpSolver := &joinReorderDPSolver{ + baseSingleGroupJoinOrderSolver: baseGroupSolver, + } + dpSolver.newJoin = dpSolver.newJoinWithEdges + p, err = dpSolver.solve(curJoinGroup, expression.ScalarFuncs2Exprs(eqEdges)) } - p, err = groupSolver.solve(curJoinGroup) if err != nil { return nil, err } @@ -143,22 +151,15 @@ func (s *baseSingleGroupJoinOrderSolver) newCartesianJoin(lChild, rChild Logical return join } -func (s *baseSingleGroupJoinOrderSolver) newJoinWithEdges(eqEdges []*expression.ScalarFunction, remainedOtherConds []expression.Expression, - lChild, rChild LogicalPlan) (*LogicalJoin, []expression.Expression) { +func (s *baseSingleGroupJoinOrderSolver) newJoinWithEdges(lChild, rChild LogicalPlan, eqEdges []*expression.ScalarFunction, otherConds []expression.Expression) LogicalPlan { newJoin := s.newCartesianJoin(lChild, rChild) newJoin.EqualConditions = eqEdges + newJoin.OtherConditions = otherConds for _, eqCond := range newJoin.EqualConditions { newJoin.LeftJoinKeys = append(newJoin.LeftJoinKeys, eqCond.GetArgs()[0].(*expression.Column)) newJoin.RightJoinKeys = append(newJoin.RightJoinKeys, eqCond.GetArgs()[1].(*expression.Column)) } - for i := len(remainedOtherConds) - 1; i >= 0; i-- { - cols := expression.ExtractColumns(remainedOtherConds[i]) - if newJoin.schema.ColumnsIndices(cols) != nil { - newJoin.OtherConditions = append(newJoin.OtherConditions, remainedOtherConds[i]) - remainedOtherConds = append(remainedOtherConds[:i], remainedOtherConds[i+1:]...) - } - } - return newJoin, remainedOtherConds + return newJoin } // calcJoinCumCost calculates the cumulative cost of the join node. diff --git a/planner/core/rule_join_reorder_dp.go b/planner/core/rule_join_reorder_dp.go index f4820fe7380fd..18b549d7813c8 100644 --- a/planner/core/rule_join_reorder_dp.go +++ b/planner/core/rule_join_reorder_dp.go @@ -18,24 +18,39 @@ import ( "github.com/pingcap/parser/ast" "github.com/pingcap/tidb/expression" - "github.com/pingcap/tidb/sessionctx" ) type joinReorderDPSolver struct { - ctx sessionctx.Context - newJoin func(lChild, rChild LogicalPlan, eqConds []*expression.ScalarFunction) LogicalPlan + *baseSingleGroupJoinOrderSolver + newJoin func(lChild, rChild LogicalPlan, eqConds []*expression.ScalarFunction, otherConds []expression.Expression) LogicalPlan } -type joinGroupEdge struct { +type joinGroupEqEdge struct { nodeIDs []int edge *expression.ScalarFunction } -func (s *joinReorderDPSolver) solve(joinGroup []LogicalPlan, conds []expression.Expression) (LogicalPlan, error) { - adjacents := make([][]int, len(joinGroup)) - totalEdges := make([]joinGroupEdge, 0, len(conds)) - addEdge := func(node1, node2 int, edgeContent *expression.ScalarFunction) { - totalEdges = append(totalEdges, joinGroupEdge{ +type joinGroupNonEqEdge struct { + nodeIDs []int + nodeIDMask uint + expr expression.Expression +} + +func (s *joinReorderDPSolver) solve(joinGroup []LogicalPlan, eqConds []expression.Expression) (LogicalPlan, error) { + for _, node := range joinGroup { + _, err := node.recursiveDeriveStats() + if err != nil { + return nil, err + } + s.curJoinGroup = append(s.curJoinGroup, &jrNode{ + p: node, + cumCost: s.baseNodeCumCost(node), + }) + } + adjacents := make([][]int, len(s.curJoinGroup)) + totalEqEdges := make([]joinGroupEqEdge, 0, len(eqConds)) + addEqEdge := func(node1, node2 int, edgeContent *expression.ScalarFunction) { + totalEqEdges = append(totalEqEdges, joinGroupEqEdge{ nodeIDs: []int{node1, node2}, edge: edgeContent, }) @@ -43,7 +58,7 @@ func (s *joinReorderDPSolver) solve(joinGroup []LogicalPlan, conds []expression. adjacents[node2] = append(adjacents[node2], node1) } // Build Graph for join group - for _, cond := range conds { + for _, cond := range eqConds { sf := cond.(*expression.ScalarFunction) lCol := sf.GetArgs()[0].(*expression.Column) rCol := sf.GetArgs()[1].(*expression.Column) @@ -55,7 +70,26 @@ func (s *joinReorderDPSolver) solve(joinGroup []LogicalPlan, conds []expression. if err != nil { return nil, err } - addEdge(lIdx, rIdx, sf) + addEqEdge(lIdx, rIdx, sf) + } + totalNonEqEdges := make([]joinGroupNonEqEdge, 0, len(s.otherConds)) + for _, cond := range s.otherConds { + cols := expression.ExtractColumns(cond) + mask := uint(0) + ids := make([]int, 0, len(cols)) + for _, col := range cols { + idx, err := findNodeIndexInGroup(joinGroup, col) + if err != nil { + return nil, err + } + ids = append(ids, idx) + mask |= 1 << uint(idx) + } + totalNonEqEdges = append(totalNonEqEdges, joinGroupNonEqEdge{ + nodeIDs: ids, + nodeIDMask: mask, + expr: cond, + }) } visited := make([]bool, len(joinGroup)) nodeID2VisitID := make([]int, len(joinGroup)) @@ -66,15 +100,37 @@ func (s *joinReorderDPSolver) solve(joinGroup []LogicalPlan, conds []expression. continue } visitID2NodeID := s.bfsGraph(i, visited, adjacents, nodeID2VisitID) + nodeIDMask := uint(0) + for _, nodeID := range visitID2NodeID { + nodeIDMask |= 1 << uint(nodeID) + } + var subNonEqEdges []joinGroupNonEqEdge + for i := len(totalNonEqEdges) - 1; i >= 0; i-- { + // If this edge is not the subset of the current sub graph. + if totalNonEqEdges[i].nodeIDMask&nodeIDMask != totalNonEqEdges[i].nodeIDMask { + continue + } + newMask := uint(0) + for _, nodeID := range totalNonEqEdges[i].nodeIDs { + newMask |= 1 << uint(nodeID2VisitID[nodeID]) + } + totalNonEqEdges[i].nodeIDMask = newMask + subNonEqEdges = append(subNonEqEdges, totalNonEqEdges[i]) + totalNonEqEdges = append(totalNonEqEdges[:i], totalNonEqEdges[i+1:]...) + } // Do DP on each sub graph. - join, err := s.dpGraph(visitID2NodeID, nodeID2VisitID, joinGroup, totalEdges) + join, err := s.dpGraph(visitID2NodeID, nodeID2VisitID, joinGroup, totalEqEdges, subNonEqEdges) if err != nil { return nil, err } joins = append(joins, join) } + remainedOtherConds := make([]expression.Expression, 0, len(totalNonEqEdges)) + for _, edge := range totalNonEqEdges { + remainedOtherConds = append(remainedOtherConds, edge.expr) + } // Build bushy tree for cartesian joins. - return s.makeBushyJoin(joins), nil + return s.makeBushyJoin(joins, remainedOtherConds), nil } // bfsGraph bfs a sub graph starting at startPos. And relabel its label for future use. @@ -98,13 +154,16 @@ func (s *joinReorderDPSolver) bfsGraph(startNode int, visited []bool, adjacents return visitID2NodeID } -func (s *joinReorderDPSolver) dpGraph(newPos2OldPos, oldPos2NewPos []int, joinGroup []LogicalPlan, totalEdges []joinGroupEdge) (LogicalPlan, error) { - nodeCnt := uint(len(newPos2OldPos)) - bestPlan := make([]LogicalPlan, 1< join.statsInfo().Count()+bestCost[remain]+bestCost[sub] { - bestPlan[nodeBitmap] = join - bestCost[nodeBitmap] = join.statsInfo().Count() + bestCost[remain] + bestCost[sub] + curCost := s.calcJoinCumCost(join, bestPlan[sub], bestPlan[remain]) + if bestPlan[nodeBitmap] == nil { + bestPlan[nodeBitmap] = &jrNode{ + p: join, + cumCost: curCost, + } + } else if bestPlan[nodeBitmap].cumCost > curCost { + bestPlan[nodeBitmap].p = join + bestPlan[nodeBitmap].cumCost = curCost } } } - return bestPlan[(1< 0 && (rightMask&(1< 0 { - usedEdges = append(usedEdges, edge) - } else if (leftMask&(1< 0 && (rightMask&(1< 0 { - usedEdges = append(usedEdges, edge) + if ((leftMask&(1< 0 && (rightMask&(1< 0) || ((leftMask&(1< 0 && (rightMask&(1< 0) { + usedEqEdges = append(usedEqEdges, edge) } } - return usedEdges + for _, edge := range totalNonEqEdges { + // If the result is false, means that the current group hasn't covered the columns involved in the expression. + if edge.nodeIDMask&(leftMask|rightMask) != edge.nodeIDMask { + continue + } + // Check whether this expression is only built from one side of the join. + if edge.nodeIDMask&leftMask == 0 || edge.nodeIDMask&rightMask == 0 { + continue + } + otherConds = append(otherConds, edge.expr) + } + return usedEqEdges, otherConds } -func (s *joinReorderDPSolver) newJoinWithEdge(leftPlan, rightPlan LogicalPlan, edges []joinGroupEdge) (LogicalPlan, error) { +func (s *joinReorderDPSolver) newJoinWithEdge(leftPlan, rightPlan LogicalPlan, edges []joinGroupEqEdge, otherConds []expression.Expression) (LogicalPlan, error) { var eqConds []*expression.ScalarFunction for _, edge := range edges { lCol := edge.edge.GetArgs()[0].(*expression.Column) @@ -165,13 +244,13 @@ func (s *joinReorderDPSolver) newJoinWithEdge(leftPlan, rightPlan LogicalPlan, e eqConds = append(eqConds, newSf) } } - join := s.newJoin(leftPlan, rightPlan, eqConds) + join := s.newJoin(leftPlan, rightPlan, eqConds, otherConds) _, err := join.recursiveDeriveStats() return join, err } // Make cartesian join as bushy tree. -func (s *joinReorderDPSolver) makeBushyJoin(cartesianJoinGroup []LogicalPlan) LogicalPlan { +func (s *joinReorderDPSolver) makeBushyJoin(cartesianJoinGroup []LogicalPlan, otherConds []expression.Expression) LogicalPlan { for len(cartesianJoinGroup) > 1 { resultJoinGroup := make([]LogicalPlan, 0, len(cartesianJoinGroup)) for i := 0; i < len(cartesianJoinGroup); i += 2 { @@ -179,7 +258,15 @@ func (s *joinReorderDPSolver) makeBushyJoin(cartesianJoinGroup []LogicalPlan) Lo resultJoinGroup = append(resultJoinGroup, cartesianJoinGroup[i]) break } - resultJoinGroup = append(resultJoinGroup, s.newJoin(cartesianJoinGroup[i], cartesianJoinGroup[i+1], nil)) + // TODO:Since the other condition may involve more than two tables, e.g. t1.a = t2.b+t3.c. + // So We'll need a extra stage to deal with it. + // Currently, we just add it when building cartesianJoinGroup. + mergedSchema := expression.MergeSchema(cartesianJoinGroup[i].Schema(), cartesianJoinGroup[i+1].Schema()) + var usedOtherConds []expression.Expression + otherConds, usedOtherConds = expression.FilterOutInPlace(otherConds, func(expr expression.Expression) bool { + return expression.ExprFromSchema(expr, mergedSchema) + }) + resultJoinGroup = append(resultJoinGroup, s.newJoin(cartesianJoinGroup[i], cartesianJoinGroup[i+1], nil, usedOtherConds)) } cartesianJoinGroup = resultJoinGroup } @@ -194,3 +281,14 @@ func findNodeIndexInGroup(group []LogicalPlan, col *expression.Column) (int, err } return -1, ErrUnknownColumn.GenWithStackByArgs(col, "JOIN REORDER RULE") } + +func (s *joinReorderDPSolver) newJoinWithConds(leftPlan, rightPlan LogicalPlan, eqConds []*expression.ScalarFunction, otherConds []expression.Expression) LogicalPlan { + join := s.newCartesianJoin(leftPlan, rightPlan) + join.EqualConditions = eqConds + join.OtherConditions = otherConds + for _, eqCond := range join.EqualConditions { + join.LeftJoinKeys = append(join.LeftJoinKeys, eqCond.GetArgs()[0].(*expression.Column)) + join.RightJoinKeys = append(join.RightJoinKeys, eqCond.GetArgs()[1].(*expression.Column)) + } + return join +} diff --git a/planner/core/rule_join_reorder_dp_test.go b/planner/core/rule_join_reorder_dp_test.go index c3d2790e18fb7..72e25a6507f4b 100644 --- a/planner/core/rule_join_reorder_dp_test.go +++ b/planner/core/rule_join_reorder_dp_test.go @@ -56,7 +56,7 @@ func (mj *mockLogicalJoin) recursiveDeriveStats() (*property.StatsInfo, error) { return mj.statsMap[mj.involvedNodeSet], nil } -func (s *testJoinReorderDPSuite) newMockJoin(lChild, rChild LogicalPlan, eqConds []*expression.ScalarFunction) LogicalPlan { +func (s *testJoinReorderDPSuite) newMockJoin(lChild, rChild LogicalPlan, eqConds []*expression.ScalarFunction, _ []expression.Expression) LogicalPlan { retJoin := mockLogicalJoin{}.init(s.ctx) retJoin.schema = expression.MergeSchema(lChild.Schema(), rChild.Schema()) retJoin.statsMap = s.statsMap @@ -145,7 +145,7 @@ func (s *testJoinReorderDPSuite) makeStatsMapForTPCHQ5() { } -func (s *testJoinReorderDPSuite) newDataSource(name string) LogicalPlan { +func (s *testJoinReorderDPSuite) newDataSource(name string, count int) LogicalPlan { ds := DataSource{}.Init(s.ctx) tan := model.NewCIStr(name) ds.TableAsName = &tan @@ -158,6 +158,9 @@ func (s *testJoinReorderDPSuite) newDataSource(name string) LogicalPlan { DBName: model.NewCIStr("test"), RetType: types.NewFieldType(mysql.TypeLonglong), }) + ds.stats = &property.StatsInfo{ + RowCount: float64(count), + } return ds } @@ -174,12 +177,12 @@ func (s *testJoinReorderDPSuite) planToString(plan LogicalPlan) string { func (s *testJoinReorderDPSuite) TestDPReorderTPCHQ5(c *C) { s.makeStatsMapForTPCHQ5() joinGroups := make([]LogicalPlan, 0, 6) - joinGroups = append(joinGroups, s.newDataSource("lineitem")) - joinGroups = append(joinGroups, s.newDataSource("orders")) - joinGroups = append(joinGroups, s.newDataSource("customer")) - joinGroups = append(joinGroups, s.newDataSource("supplier")) - joinGroups = append(joinGroups, s.newDataSource("nation")) - joinGroups = append(joinGroups, s.newDataSource("region")) + joinGroups = append(joinGroups, s.newDataSource("lineitem", 59986052)) + joinGroups = append(joinGroups, s.newDataSource("orders", 15000000)) + joinGroups = append(joinGroups, s.newDataSource("customer", 1500000)) + joinGroups = append(joinGroups, s.newDataSource("supplier", 100000)) + joinGroups = append(joinGroups, s.newDataSource("nation", 25)) + joinGroups = append(joinGroups, s.newDataSource("region", 5)) var eqConds []expression.Expression eqConds = append(eqConds, expression.NewFunctionInternal(s.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[0].Schema().Columns[0], joinGroups[1].Schema().Columns[0])) eqConds = append(eqConds, expression.NewFunctionInternal(s.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[1].Schema().Columns[0], joinGroups[2].Schema().Columns[0])) @@ -189,7 +192,9 @@ func (s *testJoinReorderDPSuite) TestDPReorderTPCHQ5(c *C) { eqConds = append(eqConds, expression.NewFunctionInternal(s.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[3].Schema().Columns[0], joinGroups[4].Schema().Columns[0])) eqConds = append(eqConds, expression.NewFunctionInternal(s.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[4].Schema().Columns[0], joinGroups[5].Schema().Columns[0])) solver := &joinReorderDPSolver{ - ctx: s.ctx, + baseSingleGroupJoinOrderSolver: &baseSingleGroupJoinOrderSolver{ + ctx: s.ctx, + }, newJoin: s.newMockJoin, } result, err := solver.solve(joinGroups, eqConds) @@ -199,12 +204,14 @@ func (s *testJoinReorderDPSuite) TestDPReorderTPCHQ5(c *C) { func (s *testJoinReorderDPSuite) TestDPReorderAllCartesian(c *C) { joinGroup := make([]LogicalPlan, 0, 4) - joinGroup = append(joinGroup, s.newDataSource("a")) - joinGroup = append(joinGroup, s.newDataSource("b")) - joinGroup = append(joinGroup, s.newDataSource("c")) - joinGroup = append(joinGroup, s.newDataSource("d")) + joinGroup = append(joinGroup, s.newDataSource("a", 100)) + joinGroup = append(joinGroup, s.newDataSource("b", 100)) + joinGroup = append(joinGroup, s.newDataSource("c", 100)) + joinGroup = append(joinGroup, s.newDataSource("d", 100)) solver := &joinReorderDPSolver{ - ctx: s.ctx, + baseSingleGroupJoinOrderSolver: &baseSingleGroupJoinOrderSolver{ + ctx: s.ctx, + }, newJoin: s.newMockJoin, } result, err := solver.solve(joinGroup, nil) diff --git a/planner/core/rule_join_reorder_greedy.go b/planner/core/rule_join_reorder_greedy.go index 5260fa1bde4b4..6bdf993ec54ba 100644 --- a/planner/core/rule_join_reorder_greedy.go +++ b/planner/core/rule_join_reorder_greedy.go @@ -21,7 +21,7 @@ import ( "github.com/pingcap/tidb/expression" ) -type joinReorderGreedySingleGroupSolver struct { +type joinReorderGreedySolver struct { *baseSingleGroupJoinOrderSolver eqEdges []*expression.ScalarFunction } @@ -40,7 +40,7 @@ type joinReorderGreedySingleGroupSolver struct { // // For the nodes and join trees which don't have a join equal condition to // connect them, we make a bushy join tree to do the cartesian joins finally. -func (s *joinReorderGreedySingleGroupSolver) solve(joinNodePlans []LogicalPlan) (LogicalPlan, error) { +func (s *joinReorderGreedySolver) solve(joinNodePlans []LogicalPlan) (LogicalPlan, error) { for _, node := range joinNodePlans { _, err := node.recursiveDeriveStats() if err != nil { @@ -67,7 +67,7 @@ func (s *joinReorderGreedySingleGroupSolver) solve(joinNodePlans []LogicalPlan) return s.makeBushyJoin(cartesianGroup), nil } -func (s *joinReorderGreedySingleGroupSolver) constructConnectedJoinTree() (*jrNode, error) { +func (s *joinReorderGreedySolver) constructConnectedJoinTree() (*jrNode, error) { curJoinTree := s.curJoinGroup[0] s.curJoinGroup = s.curJoinGroup[1:] for { @@ -106,7 +106,7 @@ func (s *joinReorderGreedySingleGroupSolver) constructConnectedJoinTree() (*jrNo return curJoinTree, nil } -func (s *joinReorderGreedySingleGroupSolver) checkConnectionAndMakeJoin(leftNode, rightNode LogicalPlan) (LogicalPlan, []expression.Expression) { +func (s *joinReorderGreedySolver) checkConnectionAndMakeJoin(leftNode, rightNode LogicalPlan) (LogicalPlan, []expression.Expression) { var usedEdges []*expression.ScalarFunction remainOtherConds := make([]expression.Expression, len(s.otherConds)) copy(remainOtherConds, s.otherConds) @@ -123,5 +123,10 @@ func (s *joinReorderGreedySingleGroupSolver) checkConnectionAndMakeJoin(leftNode if len(usedEdges) == 0 { return nil, nil } - return s.newJoinWithEdges(usedEdges, remainOtherConds, leftNode, rightNode) + var otherConds []expression.Expression + mergedSchema := expression.MergeSchema(leftNode.Schema(), rightNode.Schema()) + remainOtherConds, otherConds = expression.FilterOutInPlace(remainOtherConds, func(expr expression.Expression) bool { + return expression.ExprFromSchema(expr, mergedSchema) + }) + return s.newJoinWithEdges(leftNode, rightNode, usedEdges, otherConds), remainOtherConds } diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index c9e46d7b30ecc..da2a76c05d413 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -347,6 +347,10 @@ type SessionVars struct { // CommandValue indicates which command current session is doing. CommandValue uint32 + // TIDBOptJoinOrderAlgoThreshold defines the minimal number of join nodes + // to use the greedy join reorder algorithm. + TiDBOptJoinReorderThreshold int + // SlowQueryFile indicates which slow query log file for SLOW_QUERY table to parse. SlowQueryFile string @@ -377,30 +381,31 @@ type ConnectionInfo struct { // NewSessionVars creates a session vars object. func NewSessionVars() *SessionVars { vars := &SessionVars{ - Users: make(map[string]string), - systems: make(map[string]string), - PreparedStmts: make(map[uint32]*ast.Prepared), - PreparedStmtNameToID: make(map[string]uint32), - PreparedParams: make([]types.Datum, 0, 10), - TxnCtx: &TransactionContext{}, - KVVars: kv.NewVariables(), - RetryInfo: &RetryInfo{}, - ActiveRoles: make([]*auth.RoleIdentity, 0, 10), - StrictSQLMode: true, - Status: mysql.ServerStatusAutocommit, - StmtCtx: new(stmtctx.StatementContext), - AllowAggPushDown: false, - OptimizerSelectivityLevel: DefTiDBOptimizerSelectivityLevel, - RetryLimit: DefTiDBRetryLimit, - DisableTxnAutoRetry: DefTiDBDisableTxnAutoRetry, - DDLReorgPriority: kv.PriorityLow, - AllowInSubqToJoinAndAgg: DefOptInSubqToJoinAndAgg, - CorrelationThreshold: DefOptCorrelationThreshold, - CorrelationExpFactor: DefOptCorrelationExpFactor, - EnableRadixJoin: false, - L2CacheSize: cpuid.CPU.Cache.L2, - CommandValue: uint32(mysql.ComSleep), - SlowQueryFile: config.GetGlobalConfig().Log.SlowQueryFile, + Users: make(map[string]string), + systems: make(map[string]string), + PreparedStmts: make(map[uint32]*ast.Prepared), + PreparedStmtNameToID: make(map[string]uint32), + PreparedParams: make([]types.Datum, 0, 10), + TxnCtx: &TransactionContext{}, + KVVars: kv.NewVariables(), + RetryInfo: &RetryInfo{}, + ActiveRoles: make([]*auth.RoleIdentity, 0, 10), + StrictSQLMode: true, + Status: mysql.ServerStatusAutocommit, + StmtCtx: new(stmtctx.StatementContext), + AllowAggPushDown: false, + OptimizerSelectivityLevel: DefTiDBOptimizerSelectivityLevel, + RetryLimit: DefTiDBRetryLimit, + DisableTxnAutoRetry: DefTiDBDisableTxnAutoRetry, + DDLReorgPriority: kv.PriorityLow, + AllowInSubqToJoinAndAgg: DefOptInSubqToJoinAndAgg, + CorrelationThreshold: DefOptCorrelationThreshold, + CorrelationExpFactor: DefOptCorrelationExpFactor, + EnableRadixJoin: false, + L2CacheSize: cpuid.CPU.Cache.L2, + CommandValue: uint32(mysql.ComSleep), + TiDBOptJoinReorderThreshold: DefTiDBOptJoinReorderThreshold, + SlowQueryFile: config.GetGlobalConfig().Log.SlowQueryFile, } vars.Concurrency = Concurrency{ IndexLookupConcurrency: DefIndexLookupConcurrency, @@ -754,6 +759,8 @@ func (s *SessionVars) SetSystemVar(name string, val string) error { s.EnableRadixJoin = TiDBOptOn(val) case TiDBEnableWindowFunction: s.EnableWindowFunction = TiDBOptOn(val) + case TiDBOptJoinReorderThreshold: + s.TiDBOptJoinReorderThreshold = tidbOptPositiveInt32(val, DefTiDBOptJoinReorderThreshold) case TiDBCheckMb4ValueInUTF8: config.GetGlobalConfig().CheckMb4ValueInUTF8 = TiDBOptOn(val) case TiDBSlowQueryFile: diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 38b0f6754cab6..2a88b23393b71 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -689,6 +689,7 @@ var defaultSysVars = []*SysVar{ {ScopeSession, TiDBDDLReorgPriority, "PRIORITY_LOW"}, {ScopeSession, TiDBForcePriority, mysql.Priority2Str[DefTiDBForcePriority]}, {ScopeSession, TiDBEnableRadixJoin, BoolToIntStr(DefTiDBUseRadixJoin)}, + {ScopeGlobal | ScopeSession, TiDBOptJoinReorderThreshold, strconv.Itoa(DefTiDBOptJoinReorderThreshold)}, {ScopeSession, TiDBCheckMb4ValueInUTF8, BoolToIntStr(config.GetGlobalConfig().CheckMb4ValueInUTF8)}, {ScopeSession, TiDBSlowQueryFile, ""}, {ScopeSession, TiDBWaitTableSplitFinish, BoolToIntStr(DefTiDBWaitTableSplitFinish)}, diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index d40dacb2318f3..f16aec3d5cc4e 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -250,6 +250,10 @@ const ( // tidb_enable_window_function is used to control whether to enable the window function. TiDBEnableWindowFunction = "tidb_enable_window_function" + // TIDBOptJoinReorderThreshold defines the threshold less than which + // we'll choose a rather time consuming algorithm to calculate the join order. + TiDBOptJoinReorderThreshold = "tidb_opt_join_reorder_threshold" + // SlowQueryFile indicates which slow query log file for SLOW_QUERY table to parse. TiDBSlowQueryFile = "tidb_slow_query_file" @@ -309,6 +313,7 @@ const ( DefTiDBForcePriority = mysql.NoPriority DefTiDBUseRadixJoin = false DefEnableWindowFunction = false + DefTiDBOptJoinReorderThreshold = 0 DefTiDBDDLSlowOprThreshold = 300 DefTiDBUseFastAnalyze = false DefTiDBSkipIsolationLevelCheck = false diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index 6783975b1b1ef..a9e49e323daae 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -485,6 +485,14 @@ func ValidateSetSystemVar(vars *SessionVars, name string, value string) (string, return value, errors.Errorf("tidb_max_chunk_size(%d) cannot be smaller than %d", v, maxChunkSizeLowerBound) } return value, nil + case TiDBOptJoinReorderThreshold: + v, err := strconv.Atoi(value) + if err != nil { + return value, ErrWrongTypeForVar.GenWithStackByArgs(name) + } + if v < 0 || v >= 64 { + return value, errors.Errorf("tidb_join_order_algo_threshold(%d) cannot be smaller than 0 or larger than 63", v) + } } return value, nil } diff --git a/sessionctx/variable/varsutil_test.go b/sessionctx/variable/varsutil_test.go index 2b4c04992a4a9..4875582fd39e0 100644 --- a/sessionctx/variable/varsutil_test.go +++ b/sessionctx/variable/varsutil_test.go @@ -79,6 +79,7 @@ func (s *testVarsutilSuite) TestNewSessionVars(c *C) { c.Assert(vars.MemQuotaNestedLoopApply, Equals, int64(DefTiDBMemQuotaNestedLoopApply)) c.Assert(vars.EnableRadixJoin, Equals, DefTiDBUseRadixJoin) c.Assert(vars.AllowWriteRowID, Equals, DefOptWriteRowID) + c.Assert(vars.TiDBOptJoinReorderThreshold, Equals, DefTiDBOptJoinReorderThreshold) c.Assert(vars.EnableFastAnalyze, Equals, DefTiDBUseFastAnalyze) assertFieldsGreaterThanZero(c, reflect.ValueOf(vars.Concurrency)) @@ -254,6 +255,14 @@ func (s *testVarsutilSuite) TestVarsutil(c *C) { c.Assert(val, Equals, "on") c.Assert(v.EnableTablePartition, Equals, "on") + c.Assert(v.TiDBOptJoinReorderThreshold, Equals, DefTiDBOptJoinReorderThreshold) + err = SetSessionSystemVar(v, TiDBOptJoinReorderThreshold, types.NewIntDatum(5)) + c.Assert(err, IsNil) + val, err = GetSessionSystemVar(v, TiDBOptJoinReorderThreshold) + c.Assert(err, IsNil) + c.Assert(val, Equals, "5") + c.Assert(v.TiDBOptJoinReorderThreshold, Equals, 5) + err = SetSessionSystemVar(v, TiDBCheckMb4ValueInUTF8, types.NewStringDatum("1")) c.Assert(err, IsNil) val, err = GetSessionSystemVar(v, TiDBCheckMb4ValueInUTF8)