diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index b56fbd4284d2f..4accf54a18232 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
+import org.apache.spark.sql.execution.joins.ReorderJoinPredicates
 import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _}
 import org.apache.spark.util.Utils
 
@@ -103,6 +104,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
   protected def preparations: Seq[Rule[SparkPlan]] = Seq(
     python.ExtractPythonUDFs,
     PlanSubqueries(sparkSession),
+    new ReorderJoinPredicates,
     EnsureRequirements(sparkSession.sessionState.conf),
     CollapseCodegenStages(sparkSession.sessionState.conf),
     ReuseExchange(sparkSession.sessionState.conf),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala
new file mode 100644
index 0000000000000..534d8c5689c27
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * When physical operators are created for a JOIN, the ordering of the join keys follows the
+ * order in which they appear in the user query. That order might not match the output
+ * partitioning of the join node's children (thus leading to an extra sort / shuffle being
+ * introduced). This rule reorders the join keys to match the partitioning of the join
+ * node's children.
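+ *
+ * For example (an illustrative scenario; the table and column names are hypothetical): if
+ * table `t1` is bucketed by columns (a, b), its scan reports HashPartitioning(a, b). A query
+ * joining on `t1.b = t2.y AND t1.a = t2.x` produces left keys in the order (b, a), which does
+ * not line up with that partitioning, so an extra shuffle and/or sort may be introduced.
+ * Reordering the keys to (a, b) / (x, y) lets the existing partitioning be reused.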
+ */
+class ReorderJoinPredicates extends Rule[SparkPlan] {
+  private def reorderJoinKeys(
+      leftKeys: Seq[Expression],
+      rightKeys: Seq[Expression],
+      leftPartitioning: Partitioning,
+      rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
+
+    def reorder(
+        expectedOrderOfKeys: Seq[Expression],
+        currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
+      val leftKeysBuffer = ArrayBuffer[Expression]()
+      val rightKeysBuffer = ArrayBuffer[Expression]()
+
+      expectedOrderOfKeys.foreach(expression => {
+        val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
+        leftKeysBuffer.append(leftKeys(index))
+        rightKeysBuffer.append(rightKeys(index))
+      })
+      (leftKeysBuffer, rightKeysBuffer)
+    }
+
+    if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
+      leftPartitioning match {
+        case HashPartitioning(leftExpressions, _)
+          if leftExpressions.length == leftKeys.length &&
+            leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
+          reorder(leftExpressions, leftKeys)
+
+        case _ => rightPartitioning match {
+          case HashPartitioning(rightExpressions, _)
+            if rightExpressions.length == rightKeys.length &&
+              rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
+            reorder(rightExpressions, rightKeys)
+
+          case _ => (leftKeys, rightKeys)
+        }
+      }
+    } else {
+      (leftKeys, rightKeys)
+    }
+  }
+
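+  // Apply the reordering to each join operator that carries equi-join keys. Only the key order
+  // changes; join type, condition, build side, and children are untouched, so the rewrite is
+  // semantics-preserving and only affects the physical requirements derived from the keys.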
Seq("k", "j", "i"), + Seq("k", "i", "j") + ).foreach(joinKeys => { + testBucketing( + bucketedTableTestSpecLeft = bucketedTableTestSpec, + bucketedTableTestSpecRight = bucketedTableTestSpec, + joinCondition = joinCondition(joinKeys) + ) + }) + } + + test("SPARK-19122 No re-ordering should happen if set of join columns != set of child's " + + "partitioning columns") { + + // join predicates is a super set of child's partitioning columns + val bucketedTableTestSpec1 = + BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))), numPartitions = 1) + testBucketing( + bucketedTableTestSpecLeft = bucketedTableTestSpec1, + bucketedTableTestSpecRight = bucketedTableTestSpec1, + joinCondition = joinCondition(Seq("i", "j", "k")) + ) + + // child's partitioning columns is a super set of join predicates + val bucketedTableTestSpec2 = + BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j", "k"), Seq("i", "j", "k"))), + numPartitions = 1) + testBucketing( + bucketedTableTestSpecLeft = bucketedTableTestSpec2, + bucketedTableTestSpecRight = bucketedTableTestSpec2, + joinCondition = joinCondition(Seq("i", "j")) + ) + + // set of child's partitioning columns != set join predicates (despite the lengths of the + // sets are same) + val bucketedTableTestSpec3 = + BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))), numPartitions = 1) + testBucketing( + bucketedTableTestSpecLeft = bucketedTableTestSpec3, + bucketedTableTestSpecRight = bucketedTableTestSpec3, + joinCondition = joinCondition(Seq("j", "k")) + ) + } + test("error if there exists any malformed bucket files") { withTable("bucketed_table") { df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")