QueryExecution.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand}
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
import org.apache.spark.sql.execution.joins.ReorderJoinPredicates
import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _}
import org.apache.spark.util.Utils

@@ -103,6 +104,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
protected def preparations: Seq[Rule[SparkPlan]] = Seq(
python.ExtractPythonUDFs,
PlanSubqueries(sparkSession),
new ReorderJoinPredicates,
EnsureRequirements(sparkSession.sessionState.conf),
CollapseCodegenStages(sparkSession.sessionState.conf),
ReuseExchange(sparkSession.sessionState.conf),
ReorderJoinPredicates.scala (new file)
@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.joins

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

/**
 * When the physical operators are created for a join, the ordering of the join keys follows
 * the order in which they appear in the user query. That order might not match the output
 * partitioning of the join node's children (thus leading to an extra sort / shuffle being
 * introduced). This rule reorders the join keys to match the partitioning of the join
 * node's children.
 */
class ReorderJoinPredicates extends Rule[SparkPlan] {
private def reorderJoinKeys(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
leftPartitioning: Partitioning,
rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {

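// Rebuilds the left/right key sequences in the expected (partitioning) order by locating
// each expected expression among the current join keys and emitting the matching pair.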
def reorder(
expectedOrderOfKeys: Seq[Expression],
currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
val leftKeysBuffer = ArrayBuffer[Expression]()
val rightKeysBuffer = ArrayBuffer[Expression]()

expectedOrderOfKeys.foreach(expression => {
val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
leftKeysBuffer.append(leftKeys(index))
rightKeysBuffer.append(rightKeys(index))
})
(leftKeysBuffer, rightKeysBuffer)
}

if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
leftPartitioning match {
case HashPartitioning(leftExpressions, _)
if leftExpressions.length == leftKeys.length &&
Contributor:

Why do we need the same length? Let's say the child partitioning is (a, b, c, d) and the join keys are (b, a) — we can reorder the join keys to avoid a shuffle, right?

Contributor Author (@tejasapatil, May 12, 2017):

I don't think that would be the right thing to do. If the child is partitioned on (a, b, c, d), it basically means rows are distributed by the hash of (a, b, c, d). Let's say we have two rows with these values of a, b, c, d:

  • row1: 1, 1, 1, 1 ==> hash(1, 1, 1, 1) = x
  • row2: 1, 1, 2, 2 ==> hash(1, 1, 2, 2) = y

If the join keys (b, a) are reordered as (a, b) and we want to avoid the shuffle, that would mean we expect the child to have the same values of (a, b) in the same partition. But looking at row1 and row2 above, even though their values of a and b are the same, there is no guarantee they belong to the same partition, since partitioning is based on the hash of all of a, b, c, d.

So if the join keys are a subset of the partitioning columns, a shuffle is still needed. There is only one exception to this (more of a corner case): https://issues.apache.org/jira/browse/SPARK-18067

Contributor:

Oh sorry, I made a mistake.

If the child partitioning is (a, b) and the join keys are (b, a, c, d), does it make sense to reorder them as (a, b, c, d)?

Contributor Author:

EnsureRequirements would still add a shuffle in either case, even if we reorder.

The join would expect data to be distributed over (b, a, c, d) — or (a, b, c, d) if reordered — which maps to HashPartitioning(a, b, c, d). But the child nodes won't have a matching partitioning, i.e. they will have HashPartitioning(a, b).

Contributor Author:

The contract for reordering is that the set of join keys must equal the set of the child's partitioning columns (enforced by the length check plus the semanticEquals guard in this file). Thus no reordering happens for the case you pointed out. I have added a test case for this.

leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
reorder(leftExpressions, leftKeys)

case _ => rightPartitioning match {
case HashPartitioning(rightExpressions, _)
if rightExpressions.length == rightKeys.length &&
rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
reorder(rightExpressions, rightKeys)

case _ => (leftKeys, rightKeys)
}
}
} else {
(leftKeys, rightKeys)
}
}

def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
val (reorderedLeftKeys, reorderedRightKeys) =
reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
left, right)

case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
val (reorderedLeftKeys, reorderedRightKeys) =
reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
left, right)

case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) =>
val (reorderedLeftKeys, reorderedRightKeys) =
reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right)
}
}
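As a usage-level illustration of the rule's effect (a hedged sketch, not part of this patch; the t1/t2 table names and the local SparkSession setup are assumptions that mirror the bucketing used in the tests below): with both sides bucketed and sorted by (i, j, k), a join whose key set equals {i, j, k} in any order should produce a plan with no extra Exchange or Sort, while a join on a strict subset of the columns should still shuffle.

// Illustrative sketch only. Table names (t1, t2) and the local master are made up.
import org.apache.spark.sql.SparkSession

object ReorderJoinPredicatesSketch extends App {
  val spark = SparkSession.builder()
    .master("local[2]")
    .appName("reorder-sketch")
    // Disable broadcast joins so the shuffle behavior is actually observable.
    .config("spark.sql.autoBroadcastJoinThreshold", "-1")
    .getOrCreate()
  import spark.implicits._

  Seq((1, 2, 3)).toDF("i", "j", "k").write
    .bucketBy(8, "i", "j", "k").sortBy("i", "j", "k").saveAsTable("t1")
  Seq((1, 2, 3)).toDF("i", "j", "k").write
    .bucketBy(8, "i", "j", "k").sortBy("i", "j", "k").saveAsTable("t2")

  // Key set equals the bucketing set, just permuted: the rule reorders the keys so the
  // children's output partitioning is reused and no extra Exchange/Sort should appear.
  spark.sql(
    "SELECT * FROM t1 JOIN t2 ON t1.k = t2.k AND t1.j = t2.j AND t1.i = t2.i").explain()

  // Key set is a strict subset of the bucketing columns: no reordering applies and
  // EnsureRequirements still inserts a shuffle.
  spark.sql("SELECT * FROM t1 JOIN t2 ON t1.j = t2.j AND t1.k = t2.k").explain()
}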
BucketedReadSuite.scala
@@ -543,6 +543,65 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
)
}

test("SPARK-19122 Re-order join predicates if they match with the child's output partitioning") {
val bucketedTableTestSpec = BucketedTableTestSpec(
Some(BucketSpec(8, Seq("i", "j", "k"), Seq("i", "j", "k"))),
numPartitions = 1,
expectedShuffle = false,
expectedSort = false)

// If the set of join columns equals the set of bucketing + sort columns, the order of the
// join keys in the query should not matter, and no shuffle or sort should be added to
// the query plan.
Seq(
Seq("i", "j", "k"),
Seq("i", "k", "j"),
Seq("j", "k", "i"),
Seq("j", "i", "k"),
Seq("k", "j", "i"),
Seq("k", "i", "j")
).foreach(joinKeys => {
testBucketing(
bucketedTableTestSpecLeft = bucketedTableTestSpec,
bucketedTableTestSpecRight = bucketedTableTestSpec,
joinCondition = joinCondition(joinKeys)
)
})
}

test("SPARK-19122 No re-ordering should happen if set of join columns != set of child's " +
"partitioning columns") {

// the join predicates are a superset of the child's partitioning columns
val bucketedTableTestSpec1 =
BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))), numPartitions = 1)
testBucketing(
bucketedTableTestSpecLeft = bucketedTableTestSpec1,
bucketedTableTestSpecRight = bucketedTableTestSpec1,
joinCondition = joinCondition(Seq("i", "j", "k"))
)

// the child's partitioning columns are a superset of the join predicates
val bucketedTableTestSpec2 =
BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j", "k"), Seq("i", "j", "k"))),
numPartitions = 1)
testBucketing(
bucketedTableTestSpecLeft = bucketedTableTestSpec2,
bucketedTableTestSpecRight = bucketedTableTestSpec2,
joinCondition = joinCondition(Seq("i", "j"))
)

// set of child's partitioning columns != set of join predicates (even though both sets
// have the same length)
val bucketedTableTestSpec3 =
BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))), numPartitions = 1)
testBucketing(
bucketedTableTestSpecLeft = bucketedTableTestSpec3,
bucketedTableTestSpecRight = bucketedTableTestSpec3,
joinCondition = joinCondition(Seq("j", "k"))
)
}

test("error if there exists any malformed bucket files") {
withTable("bucketed_table") {
df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")