diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 79e4ddb8c4f5d..4e123ea626d60 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -91,9 +91,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
            joins.BuildLeft
          }
        val hashJoin = joins.ShuffledHashJoin(
-          leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
+          leftKeys, rightKeys, buildSide, Inner, condition, planLater(left), planLater(right))
        condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil
 
+      case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right) =>
+        joins.ShuffledHashJoin(
+          leftKeys, rightKeys, joins.BuildRight, LeftOuter,
+          condition, planLater(left), planLater(right)) :: Nil
+
+      case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right) =>
+        joins.ShuffledHashJoin(
+          leftKeys, rightKeys, joins.BuildLeft, RightOuter,
+          condition, planLater(left), planLater(right)) :: Nil
+
       case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
         joins.HashOuterJoin(
           leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index 418c1c23e5546..3e33f0dbc4a1c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -18,10 +18,11 @@
 package org.apache.spark.sql.execution.joins
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
+import org.apache.spark.sql.catalyst.plans.{Inner, FullOuter, JoinType, LeftOuter, RightOuter}
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
-
+import org.apache.spark.util.collection.CompactBuffer
 /**
  * :: DeveloperApi ::
  * Performs an inner hash join of two child relations by first shuffling the data using the join
@@ -32,19 +33,115 @@ case class ShuffledHashJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
     buildSide: BuildSide,
+    joinType: JoinType,
+    condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan)
   extends BinaryNode with HashJoin {
 
-  override def outputPartitioning: Partitioning = left.outputPartitioning
+  override def outputPartitioning: Partitioning = joinType match {
+    case Inner => left.outputPartitioning
+    case LeftOuter => left.outputPartitioning
+    case RightOuter => right.outputPartitioning
+    case x => throw new Exception(s"ShuffledHashJoin should not take $x as the JoinType")
+  }
 
   override def requiredChildDistribution =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
+  override def output = {
+    joinType match {
+      case Inner =>
+        left.output ++ right.output
+      case LeftOuter =>
+        left.output ++ right.output.map(_.withNullability(true))
+      case RightOuter =>
+        left.output.map(_.withNullability(true)) ++ right.output
+      case x =>
+        throw new Exception(s"ShuffledHashJoin should not take $x as the JoinType")
+    }
+  }
+
+  private[this] lazy val nullRow = joinType match {
+    case LeftOuter => new GenericRow(right.output.length)
+    case RightOuter => new GenericRow(left.output.length)
+    case _ => null
+  }
+
+  private[this] lazy val boundCondition =
+    condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
+
+  private def outerJoin(streamIter: Iterator[Row], hashedRelation: HashedRelation): Iterator[Row] = {
+    new Iterator[Row] {
+      private[this] var currentStreamedRow: Row = _
+      private[this] var currentHashMatches: CompactBuffer[Row] = _
+      private[this] var currentMatchPosition: Int = -1
+
+      // Mutable per row objects.
+      private[this] val joinRow = new JoinedRow2
+
+      private[this] val joinKeys = streamSideKeyGenerator()
+
+      override final def hasNext: Boolean =
+        (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
+        (streamIter.hasNext && fetchNext())
+
+      override final def next() = {
+        val ret = joinType match {
+          case LeftOuter =>
+            if (currentMatchPosition == -1) {
+              joinRow(currentStreamedRow, nullRow)
+            } else {
+              val rightRow = currentHashMatches(currentMatchPosition)
+              val joinedRow = joinRow(currentStreamedRow, rightRow)
+              currentMatchPosition += 1
+              if (!boundCondition(joinedRow)) {
+                joinRow(currentStreamedRow, nullRow)
+              } else {
+                joinedRow
+              }
+            }
+          case RightOuter =>
+            if (currentMatchPosition == -1) {
+              joinRow(nullRow, currentStreamedRow)
+            } else {
+              val leftRow = currentHashMatches(currentMatchPosition)
+              val joinedRow = joinRow(leftRow, currentStreamedRow)
+              currentMatchPosition += 1
+              if (!boundCondition(joinedRow)) {
+                joinRow(nullRow, currentStreamedRow)
+              } else {
+                joinedRow
+              }
+            }
+        }
+        ret
+      }
+
+      private final def fetchNext(): Boolean = {
+        currentMatchPosition = -1
+        currentHashMatches = null
+        currentStreamedRow = streamIter.next()
+        if (!joinKeys(currentStreamedRow).anyNull) {
+          currentHashMatches = hashedRelation.get(joinKeys.currentValue)
+        }
+        if (currentHashMatches != null) {
+          currentMatchPosition = 0
+        }
+        true
+      }
+    }
+  }
+
   override def execute() = {
     buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
       val hashed = HashedRelation(buildIter, buildSideKeyGenerator)
-      hashJoin(streamIter, hashed)
+      joinType match {
+        case Inner => hashJoin(streamIter, hashed)
+        case LeftOuter => outerJoin(streamIter, hashed)
+        case RightOuter => outerJoin(streamIter, hashed)
+        case x => throw new Exception(s"ShuffledHashJoin should not take $x as the JoinType")
+      }
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 8b4cf5bac0187..a8a8454c45f75 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -75,11 +75,11 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
       ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]),
       ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]),
-      ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]),
+      ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
       ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
-        classOf[HashOuterJoin]),
+        classOf[ShuffledHashJoin]),
       ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
-        classOf[HashOuterJoin]),
+        classOf[ShuffledHashJoin]),
       ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin])
       // TODO add BroadcastNestedLoopJoin
     ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
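For reviewers who want to exercise the new code path by hand, here is a minimal spark-shell-style sketch mirroring what `assertJoin(query, joinClass)` verifies in the suite above. It assumes the Spark 1.x test harness, where `org.apache.spark.sql.test.TestSQLContext` provides `sql(...)` and the `TestData` object registers the `testData`/`testData2` temp tables; the same check works against any `SQLContext` with two equi-joinable tables.

```scala
import org.apache.spark.sql.execution.joins.ShuffledHashJoin
import org.apache.spark.sql.test.TestSQLContext._

// Referencing TestData forces its temp tables (testData, testData2, ...)
// to be registered, as JoinSuite does.
org.apache.spark.sql.TestData

// With this patch, an equi-join LEFT JOIN should plan to ShuffledHashJoin
// (with buildSide = BuildRight) rather than falling through to HashOuterJoin.
val rdd = sql("SELECT * FROM testData LEFT JOIN testData2 ON key = a")
val plan = rdd.queryExecution.executedPlan
val joins = plan.collect { case j: ShuffledHashJoin => j }
assert(joins.size == 1, s"expected a ShuffledHashJoin in:\n$plan")

// Left-side keys with no match must still come back, null-padded on the right.
rdd.collect().foreach(println)
```

Running the same check for a `RIGHT JOIN` should show `buildSide = BuildLeft`, matching the planner cases added in `SparkStrategies`; `FULL OUTER JOIN` still plans to `HashOuterJoin`.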