From a39be8c06d3397ff834b1defad37ce1ca1824183 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 8 Oct 2014 15:22:34 -0700 Subject: [PATCH 01/14] [SPARK-3857] Create a join package for various join operators. --- .../spark/sql/execution/SparkStrategies.scala | 42 +- .../execution/join/BroadcastHashJoin.scala | 62 ++ .../join/BroadcastNestedLoopJoin.scala | 144 ++++ .../sql/execution/join/CartesianProduct.scala | 40 ++ .../spark/sql/execution/join/HashJoin.scala | 122 ++++ .../sql/execution/join/HashOuterJoin.scala | 222 +++++++ .../sql/execution/join/LeftSemiJoinBNL.scala | 73 ++ .../sql/execution/join/LeftSemiJoinHash.scala | 67 ++ .../sql/execution/join/ShuffledHashJoin.scala | 49 ++ .../spark/sql/execution/join/package.scala | 37 ++ .../apache/spark/sql/execution/joins.scala | 624 ------------------ .../org/apache/spark/sql/JoinSuite.scala | 1 + .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 3 +- .../spark/sql/hive/StatisticsSuite.scala | 2 +- 15 files changed, 843 insertions(+), 647 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/join/BroadcastHashJoin.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/join/BroadcastNestedLoopJoin.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/join/CartesianProduct.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/join/HashJoin.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/join/HashOuterJoin.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/join/LeftSemiJoinBNL.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/join/LeftSemiJoinHash.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/join/ShuffledHashJoin.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/join/package.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5c16d0c624128..420bc11f748af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} import org.apache.spark.sql.parquet._ + private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SQLContext#SparkPlanner => @@ -34,13 +35,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // Find left semi joins where at least some predicates can be evaluated by matching join keys case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => - val semiJoin = execution.LeftSemiJoinHash( - leftKeys, rightKeys, planLater(left), planLater(right)) + val semiJoin = join.LeftSemiJoinHash(leftKeys, rightKeys, planLater(left), planLater(right)) condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil // no predicate can be evaluated by matching hash keys case logical.Join(left, right, LeftSemi, condition) => - execution.LeftSemiJoinBNL( - planLater(left), planLater(right), condition) :: Nil + join.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil case _ => Nil } } @@ -50,13 +49,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * evaluated by matching hash keys. * * This strategy applies a simple optimization based on the estimates of the physical sizes of - * the two join sides. When planning a [[execution.BroadcastHashJoin]], if one side has an + * the two join sides. When planning a [[join.BroadcastHashJoin]], if one side has an * estimated physical size smaller than the user-settable threshold * [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the * ''build'' relation and mark the other relation as the ''stream'' side. The build table will be * ''broadcasted'' to all of the executors involved in the join, as a * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they - * will instead be used to decide the build side in a [[execution.ShuffledHashJoin]]. + * will instead be used to decide the build side in a [[join.ShuffledHashJoin]]. */ object HashJoin extends Strategy with PredicateHelper { @@ -66,8 +65,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { left: LogicalPlan, right: LogicalPlan, condition: Option[Expression], - side: BuildSide) = { - val broadcastHashJoin = execution.BroadcastHashJoin( + side: join.BuildSide) = { + val broadcastHashJoin = execution.join.BroadcastHashJoin( leftKeys, rightKeys, side, planLater(left), planLater(right)) condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil } @@ -76,27 +75,26 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if sqlContext.autoBroadcastJoinThreshold > 0 && right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => - makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight) + makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, join.BuildRight) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if sqlContext.autoBroadcastJoinThreshold > 0 && left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => - makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft) + makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, join.BuildLeft) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => val buildSide = if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { - BuildRight + join.BuildRight } else { - BuildLeft + join.BuildLeft } - val hashJoin = - execution.ShuffledHashJoin( - leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) + val hashJoin = join.ShuffledHashJoin( + leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => - execution.HashOuterJoin( + join.HashOuterJoin( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil case _ => Nil @@ -164,8 +162,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join(left, right, joinType, condition) => val buildSide = - if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) BuildRight else BuildLeft - execution.BroadcastNestedLoopJoin( + if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { + join.BuildRight + } else { + join.BuildLeft + } + join.BroadcastNestedLoopJoin( planLater(left), planLater(right), buildSide, joinType, condition) :: Nil case _ => Nil } @@ -174,10 +176,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object CartesianProduct extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join(left, right, _, None) => - execution.CartesianProduct(planLater(left), planLater(right)) :: Nil + execution.join.CartesianProduct(planLater(left), planLater(right)) :: Nil case logical.Join(left, right, Inner, Some(condition)) => execution.Filter(condition, - execution.CartesianProduct(planLater(left), planLater(right))) :: Nil + execution.join.CartesianProduct(planLater(left), planLater(right))) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/join/BroadcastHashJoin.scala new file mode 100644 index 0000000000000..1e7626905f3b2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/join/BroadcastHashJoin.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.join + +import scala.concurrent._ +import scala.concurrent.duration._ +import scala.concurrent.ExecutionContext.Implicits.global + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnspecifiedDistribution} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Performs an inner hash join of two child relations. When the output RDD of this operator is + * being constructed, a Spark job is asynchronously started to calculate the values for the + * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed + * relation is not shuffled. + */ +@DeveloperApi +case class BroadcastHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + buildSide: BuildSide, + left: SparkPlan, + right: SparkPlan) + extends BinaryNode with HashJoin { + + override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning + + override def requiredChildDistribution = + UnspecifiedDistribution :: UnspecifiedDistribution :: Nil + + @transient + private val broadcastFuture = future { + sparkContext.broadcast(buildPlan.executeCollect()) + } + + override def execute() = { + val broadcastRelation = Await.result(broadcastFuture, 5.minute) + + streamedPlan.execute().mapPartitions { streamedIter => + joinIterators(broadcastRelation.value.iterator, streamedIter) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/join/BroadcastNestedLoopJoin.scala new file mode 100644 index 0000000000000..b29eaafd7f2fe --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/join/BroadcastNestedLoopJoin.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.join + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.collection.CompactBuffer + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class BroadcastNestedLoopJoin( + left: SparkPlan, + right: SparkPlan, + buildSide: BuildSide, + joinType: JoinType, + condition: Option[Expression]) extends BinaryNode { + // TODO: Override requiredChildDistribution. + + /** BuildRight means the right relation <=> the broadcast relation. */ + private val (streamed, broadcast) = buildSide match { + case BuildRight => (left, right) + case BuildLeft => (right, left) + } + + override def outputPartitioning: Partitioning = streamed.outputPartitioning + + override def output = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case _ => + left.output ++ right.output + } + } + + @transient private lazy val boundCondition = + InterpretedPredicate( + condition + .map(c => BindReferences.bindReference(c, left.output ++ right.output)) + .getOrElse(Literal(true))) + + override def execute() = { + val broadcastedRelation = + sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + + /** All rows that either match both-way, or rows from streamed joined with nulls. */ + val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => + val matchedRows = new CompactBuffer[Row] + // TODO: Use Spark's BitSet. + val includedBroadcastTuples = + new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + val joinedRow = new JoinedRow + val leftNulls = new GenericMutableRow(left.output.size) + val rightNulls = new GenericMutableRow(right.output.size) + + streamedIter.foreach { streamedRow => + var i = 0 + var streamRowMatched = false + + while (i < broadcastedRelation.value.size) { + // TODO: One bitset per partition instead of per row. + val broadcastedRow = broadcastedRelation.value(i) + buildSide match { + case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => + matchedRows += joinedRow(streamedRow, broadcastedRow).copy() + streamRowMatched = true + includedBroadcastTuples += i + case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => + matchedRows += joinedRow(broadcastedRow, streamedRow).copy() + streamRowMatched = true + includedBroadcastTuples += i + case _ => + } + i += 1 + } + + (streamRowMatched, joinType, buildSide) match { + case (false, LeftOuter | FullOuter, BuildRight) => + matchedRows += joinedRow(streamedRow, rightNulls).copy() + case (false, RightOuter | FullOuter, BuildLeft) => + matchedRows += joinedRow(leftNulls, streamedRow).copy() + case _ => + } + } + Iterator((matchedRows, includedBroadcastTuples)) + } + + val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) + val allIncludedBroadcastTuples = + if (includedBroadcastTuples.count == 0) { + new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + } else { + includedBroadcastTuples.reduce(_ ++ _) + } + + val leftNulls = new GenericMutableRow(left.output.size) + val rightNulls = new GenericMutableRow(right.output.size) + /** Rows from broadcasted joined with nulls. */ + val broadcastRowsWithNulls: Seq[Row] = { + val buf: CompactBuffer[Row] = new CompactBuffer() + var i = 0 + val rel = broadcastedRelation.value + while (i < rel.length) { + if (!allIncludedBroadcastTuples.contains(i)) { + (joinType, buildSide) match { + case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i)) + case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls) + case _ => + } + } + i += 1 + } + buf.toSeq + } + + // TODO: Breaks lineage. + sparkContext.union( + matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/join/CartesianProduct.scala new file mode 100644 index 0000000000000..ddc459edca6fb --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/join/CartesianProduct.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.join + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.JoinedRow +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { + override def output = left.output ++ right.output + + override def execute() = { + val leftResults = left.execute().map(_.copy()) + val rightResults = right.execute().map(_.copy()) + + leftResults.cartesian(rightResults).mapPartitions { iter => + val joinedRow = new JoinedRow + iter.map(r => joinedRow(r._1, r._2)) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/join/HashJoin.scala new file mode 100644 index 0000000000000..211076c098db3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/join/HashJoin.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.join + +import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow2, Row} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.util.collection.CompactBuffer + + +trait HashJoin { + self: SparkPlan => + + val leftKeys: Seq[Expression] + val rightKeys: Seq[Expression] + val buildSide: BuildSide + val left: SparkPlan + val right: SparkPlan + + protected lazy val (buildPlan, streamedPlan) = buildSide match { + case BuildLeft => (left, right) + case BuildRight => (right, left) + } + + protected lazy val (buildKeys, streamedKeys) = buildSide match { + case BuildLeft => (leftKeys, rightKeys) + case BuildRight => (rightKeys, leftKeys) + } + + override def output = left.output ++ right.output + + @transient protected lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output) + @transient protected lazy val streamSideKeyGenerator = + newMutableProjection(streamedKeys, streamedPlan.output) + + protected def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = { + // TODO: Use Spark's HashMap implementation. + + val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]() + var currentRow: Row = null + + // Create a mapping of buildKeys -> rows + while (buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = buildSideKeyGenerator(currentRow) + if (!rowKey.anyNull) { + val existingMatchList = hashTable.get(rowKey) + val matchList = if (existingMatchList == null) { + val newMatchList = new CompactBuffer[Row]() + hashTable.put(rowKey, newMatchList) + newMatchList + } else { + existingMatchList + } + matchList += currentRow.copy() + } + } + + new Iterator[Row] { + private[this] var currentStreamedRow: Row = _ + private[this] var currentHashMatches: CompactBuffer[Row] = _ + private[this] var currentMatchPosition: Int = -1 + + // Mutable per row objects. + private[this] val joinRow = new JoinedRow2 + + private[this] val joinKeys = streamSideKeyGenerator() + + override final def hasNext: Boolean = + (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || + (streamIter.hasNext && fetchNext()) + + override final def next() = { + val ret = buildSide match { + case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) + case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) + } + currentMatchPosition += 1 + ret + } + + /** + * Searches the streamed iterator for the next row that has at least one match in hashtable. + * + * @return true if the search is successful, and false if the streamed iterator runs out of + * tuples. + */ + private final def fetchNext(): Boolean = { + currentHashMatches = null + currentMatchPosition = -1 + + while (currentHashMatches == null && streamIter.hasNext) { + currentStreamedRow = streamIter.next() + if (!joinKeys(currentStreamedRow).anyNull) { + currentHashMatches = hashTable.get(joinKeys.currentValue) + } + } + + if (currentHashMatches == null) { + false + } else { + currentMatchPosition = 0 + true + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/join/HashOuterJoin.scala new file mode 100644 index 0000000000000..caacb0ffca462 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/join/HashOuterJoin.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.join + +import java.util.{HashMap => JavaHashMap} + +import scala.collection.JavaConversions._ + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.collection.CompactBuffer + +/** + * :: DeveloperApi :: + * Performs a hash based outer join for two child relations by shuffling the data using + * the join keys. This operator requires loading the associated partition in both side into memory. + */ +@DeveloperApi +case class HashOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def outputPartitioning: Partitioning = joinType match { + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) + case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + override def output = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case x => + throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } + } + + @transient private[this] lazy val DUMMY_LIST = Seq[Row](null) + @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row] + + // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala + // iterator for performance purpose. + + private[this] def leftOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val rightNullRow = new GenericRow(right.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + leftIter.iterator.flatMap { l => + joinedRow.withLeft(l) + var matched = false + (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => + matched = true + joinedRow.copy + } else { + Nil + }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all of the + // records in right side. + // If we didn't get any proper row, then append a single row with empty right + joinedRow.withRight(rightNullRow).copy + }) + } + } + + private[this] def rightOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val leftNullRow = new GenericRow(left.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + rightIter.iterator.flatMap { r => + joinedRow.withRight(r) + var matched = false + (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => + matched = true + joinedRow.copy + } else { + Nil + }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all of the + // records in left side. + // If we didn't get any proper row, then append a single row with empty left. + joinedRow.withLeft(leftNullRow).copy + }) + } + } + + private[this] def fullOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val leftNullRow = new GenericRow(left.output.length) + val rightNullRow = new GenericRow(right.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + if (!key.anyNull) { + // Store the positions of records in right, if one of its associated row satisfy + // the join condition. + val rightMatchedSet = scala.collection.mutable.Set[Int]() + leftIter.iterator.flatMap[Row] { l => + joinedRow.withLeft(l) + var matched = false + rightIter.zipWithIndex.collect { + // 1. For those matched (satisfy the join condition) records with both sides filled, + // append them directly + + case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { + matched = true + // if the row satisfy the join condition, add its index into the matched set + rightMatchedSet.add(idx) + joinedRow.copy + } + } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // 2. For those unmatched records in left, append additional records with empty right. + + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all + // of the records in right side. + // If we didn't get any proper row, then append a single row with empty right. + joinedRow.withRight(rightNullRow).copy + }) + } ++ rightIter.zipWithIndex.collect { + // 3. For those unmatched records in right, append additional records with empty left. + + // Re-visiting the records in right, and append additional row with empty left, if its not + // in the matched set. + case (r, idx) if (!rightMatchedSet.contains(idx)) => { + joinedRow(leftNullRow, r).copy + } + } + } else { + leftIter.iterator.map[Row] { l => + joinedRow(l, rightNullRow).copy + } ++ rightIter.iterator.map[Row] { r => + joinedRow(leftNullRow, r).copy + } + } + } + + private[this] def buildHashTable( + iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = { + val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]() + while (iter.hasNext) { + val currentRow = iter.next() + val rowKey = keyGenerator(currentRow) + + var existingMatchList = hashTable.get(rowKey) + if (existingMatchList == null) { + existingMatchList = new CompactBuffer[Row]() + hashTable.put(rowKey, existingMatchList) + } + + existingMatchList += currentRow.copy() + } + + hashTable + } + + override def execute() = { + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + // TODO this probably can be replaced by external sort (sort merged join?) + // Build HashMap for current partition in left relation + val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) + // Build HashMap for current partition in right relation + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + joinType match { + case LeftOuter => leftHashTable.keysIterator.flatMap { key => + leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) + } + case RightOuter => rightHashTable.keysIterator.flatMap { key => + rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) + } + case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => + fullOuterIterator(key, + leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) + } + case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/join/LeftSemiJoinBNL.scala new file mode 100644 index 0000000000000..d4401c8ce3ff4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/join/LeftSemiJoinBNL.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.join + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys + * for hash join. + */ +@DeveloperApi +case class LeftSemiJoinBNL( + streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) + extends BinaryNode { + // TODO: Override requiredChildDistribution. + + override def outputPartitioning: Partitioning = streamed.outputPartitioning + + override def output = left.output + + /** The Streamed Relation */ + override def left = streamed + /** The Broadcast relation */ + override def right = broadcast + + @transient private lazy val boundCondition = + InterpretedPredicate( + condition + .map(c => BindReferences.bindReference(c, left.output ++ right.output)) + .getOrElse(Literal(true))) + + override def execute() = { + val broadcastedRelation = + sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + + streamed.execute().mapPartitions { streamedIter => + val joinedRow = new JoinedRow + + streamedIter.filter(streamedRow => { + var i = 0 + var matched = false + + while (i < broadcastedRelation.value.size && !matched) { + val broadcastedRow = broadcastedRelation.value(i) + if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { + matched = true + } + i += 1 + } + matched + }) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/join/LeftSemiJoinHash.scala new file mode 100644 index 0000000000000..e383e00725a7e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/join/LeftSemiJoinHash.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.join + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.{Expression, Row} +import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Build the right table's join keys into a HashSet, and iteratively go through the left + * table, to find the if join keys are in the Hash set. + */ +@DeveloperApi +case class LeftSemiJoinHash( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with HashJoin { + + override val buildSide = BuildRight + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + override def output = left.output + + override def execute() = { + buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => + val hashSet = new java.util.HashSet[Row]() + var currentRow: Row = null + + // Create a Hash set of buildKeys + while (buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = buildSideKeyGenerator(currentRow) + if (!rowKey.anyNull) { + val keyExists = hashSet.contains(rowKey) + if (!keyExists) { + hashSet.add(rowKey) + } + } + } + + val joinKeys = streamSideKeyGenerator() + streamIter.filter(current => { + !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue) + }) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/join/ShuffledHashJoin.scala new file mode 100644 index 0000000000000..ad248401f2e54 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/join/ShuffledHashJoin.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.join + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Performs an inner hash join of two child relations by first shuffling the data using the join + * keys. + */ +@DeveloperApi +case class ShuffledHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + buildSide: BuildSide, + left: SparkPlan, + right: SparkPlan) + extends BinaryNode with HashJoin { + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + override def execute() = { + buildPlan.execute().zipPartitions(streamedPlan.execute()) { + (buildIter, streamIter) => joinIterators(buildIter, streamIter) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/join/package.scala new file mode 100644 index 0000000000000..3ee9709137728 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/join/package.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Physical execution operators for join operations. + */ +package object join { + + @DeveloperApi + sealed abstract class BuildSide + + @DeveloperApi + case object BuildRight extends BuildSide + + @DeveloperApi + case object BuildLeft extends BuildSide + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala deleted file mode 100644 index 2890a563bed48..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ /dev/null @@ -1,624 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.util.{HashMap => JavaHashMap} - -import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent._ -import scala.concurrent.duration._ - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.util.collection.CompactBuffer - -@DeveloperApi -sealed abstract class BuildSide - -@DeveloperApi -case object BuildLeft extends BuildSide - -@DeveloperApi -case object BuildRight extends BuildSide - -trait HashJoin { - self: SparkPlan => - - val leftKeys: Seq[Expression] - val rightKeys: Seq[Expression] - val buildSide: BuildSide - val left: SparkPlan - val right: SparkPlan - - lazy val (buildPlan, streamedPlan) = buildSide match { - case BuildLeft => (left, right) - case BuildRight => (right, left) - } - - lazy val (buildKeys, streamedKeys) = buildSide match { - case BuildLeft => (leftKeys, rightKeys) - case BuildRight => (rightKeys, leftKeys) - } - - def output = left.output ++ right.output - - @transient lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output) - @transient lazy val streamSideKeyGenerator = - newMutableProjection(streamedKeys, streamedPlan.output) - - def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = { - // TODO: Use Spark's HashMap implementation. - - val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]() - var currentRow: Row = null - - // Create a mapping of buildKeys -> rows - while (buildIter.hasNext) { - currentRow = buildIter.next() - val rowKey = buildSideKeyGenerator(currentRow) - if (!rowKey.anyNull) { - val existingMatchList = hashTable.get(rowKey) - val matchList = if (existingMatchList == null) { - val newMatchList = new CompactBuffer[Row]() - hashTable.put(rowKey, newMatchList) - newMatchList - } else { - existingMatchList - } - matchList += currentRow.copy() - } - } - - new Iterator[Row] { - private[this] var currentStreamedRow: Row = _ - private[this] var currentHashMatches: CompactBuffer[Row] = _ - private[this] var currentMatchPosition: Int = -1 - - // Mutable per row objects. - private[this] val joinRow = new JoinedRow2 - - private[this] val joinKeys = streamSideKeyGenerator() - - override final def hasNext: Boolean = - (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || - (streamIter.hasNext && fetchNext()) - - override final def next() = { - val ret = buildSide match { - case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) - case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) - } - currentMatchPosition += 1 - ret - } - - /** - * Searches the streamed iterator for the next row that has at least one match in hashtable. - * - * @return true if the search is successful, and false if the streamed iterator runs out of - * tuples. - */ - private final def fetchNext(): Boolean = { - currentHashMatches = null - currentMatchPosition = -1 - - while (currentHashMatches == null && streamIter.hasNext) { - currentStreamedRow = streamIter.next() - if (!joinKeys(currentStreamedRow).anyNull) { - currentHashMatches = hashTable.get(joinKeys.currentValue) - } - } - - if (currentHashMatches == null) { - false - } else { - currentMatchPosition = 0 - true - } - } - } - } -} - -/** - * :: DeveloperApi :: - * Performs a hash based outer join for two child relations by shuffling the data using - * the join keys. This operator requires loading the associated partition in both side into memory. - */ -@DeveloperApi -case class HashOuterJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode { - - override def outputPartitioning: Partitioning = joinType match { - case LeftOuter => left.outputPartitioning - case RightOuter => right.outputPartitioning - case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) - case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") - } - - override def requiredChildDistribution = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - override def output = { - joinType match { - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case x => - throw new Exception(s"HashOuterJoin should not take $x as the JoinType") - } - } - - @transient private[this] lazy val DUMMY_LIST = Seq[Row](null) - @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row] - - // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala - // iterator for performance purpose. - - private[this] def leftOuterIterator( - key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { - val joinedRow = new JoinedRow() - val rightNullRow = new GenericRow(right.output.length) - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - - leftIter.iterator.flatMap { l => - joinedRow.withLeft(l) - var matched = false - (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => - matched = true - joinedRow.copy - } else { - Nil - }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { - // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all of the - // records in right side. - // If we didn't get any proper row, then append a single row with empty right - joinedRow.withRight(rightNullRow).copy - }) - } - } - - private[this] def rightOuterIterator( - key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { - val joinedRow = new JoinedRow() - val leftNullRow = new GenericRow(left.output.length) - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - - rightIter.iterator.flatMap { r => - joinedRow.withRight(r) - var matched = false - (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => - matched = true - joinedRow.copy - } else { - Nil - }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { - // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all of the - // records in left side. - // If we didn't get any proper row, then append a single row with empty left. - joinedRow.withLeft(leftNullRow).copy - }) - } - } - - private[this] def fullOuterIterator( - key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { - val joinedRow = new JoinedRow() - val leftNullRow = new GenericRow(left.output.length) - val rightNullRow = new GenericRow(right.output.length) - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - - if (!key.anyNull) { - // Store the positions of records in right, if one of its associated row satisfy - // the join condition. - val rightMatchedSet = scala.collection.mutable.Set[Int]() - leftIter.iterator.flatMap[Row] { l => - joinedRow.withLeft(l) - var matched = false - rightIter.zipWithIndex.collect { - // 1. For those matched (satisfy the join condition) records with both sides filled, - // append them directly - - case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { - matched = true - // if the row satisfy the join condition, add its index into the matched set - rightMatchedSet.add(idx) - joinedRow.copy - } - } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { - // 2. For those unmatched records in left, append additional records with empty right. - - // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all - // of the records in right side. - // If we didn't get any proper row, then append a single row with empty right. - joinedRow.withRight(rightNullRow).copy - }) - } ++ rightIter.zipWithIndex.collect { - // 3. For those unmatched records in right, append additional records with empty left. - - // Re-visiting the records in right, and append additional row with empty left, if its not - // in the matched set. - case (r, idx) if (!rightMatchedSet.contains(idx)) => { - joinedRow(leftNullRow, r).copy - } - } - } else { - leftIter.iterator.map[Row] { l => - joinedRow(l, rightNullRow).copy - } ++ rightIter.iterator.map[Row] { r => - joinedRow(leftNullRow, r).copy - } - } - } - - private[this] def buildHashTable( - iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = { - val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]() - while (iter.hasNext) { - val currentRow = iter.next() - val rowKey = keyGenerator(currentRow) - - var existingMatchList = hashTable.get(rowKey) - if (existingMatchList == null) { - existingMatchList = new CompactBuffer[Row]() - hashTable.put(rowKey, existingMatchList) - } - - existingMatchList += currentRow.copy() - } - - hashTable - } - - def execute() = { - left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => - // TODO this probably can be replaced by external sort (sort merged join?) - // Build HashMap for current partition in left relation - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - // Build HashMap for current partition in right relation - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - - import scala.collection.JavaConversions._ - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - joinType match { - case LeftOuter => leftHashTable.keysIterator.flatMap { key => - leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST)) - } - case RightOuter => rightHashTable.keysIterator.flatMap { key => - rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST)) - } - case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => - fullOuterIterator(key, - leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST)) - } - case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") - } - } - } -} - -/** - * :: DeveloperApi :: - * Performs an inner hash join of two child relations by first shuffling the data using the join - * keys. - */ -@DeveloperApi -case class ShuffledHashJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - buildSide: BuildSide, - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashJoin { - - override def outputPartitioning: Partitioning = left.outputPartitioning - - override def requiredChildDistribution = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - def execute() = { - buildPlan.execute().zipPartitions(streamedPlan.execute()) { - (buildIter, streamIter) => joinIterators(buildIter, streamIter) - } - } -} - -/** - * :: DeveloperApi :: - * Build the right table's join keys into a HashSet, and iteratively go through the left - * table, to find the if join keys are in the Hash set. - */ -@DeveloperApi -case class LeftSemiJoinHash( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashJoin { - - val buildSide = BuildRight - - override def requiredChildDistribution = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - override def output = left.output - - def execute() = { - buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashSet = new java.util.HashSet[Row]() - var currentRow: Row = null - - // Create a Hash set of buildKeys - while (buildIter.hasNext) { - currentRow = buildIter.next() - val rowKey = buildSideKeyGenerator(currentRow) - if (!rowKey.anyNull) { - val keyExists = hashSet.contains(rowKey) - if (!keyExists) { - hashSet.add(rowKey) - } - } - } - - val joinKeys = streamSideKeyGenerator() - streamIter.filter(current => { - !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue) - }) - } - } -} - - -/** - * :: DeveloperApi :: - * Performs an inner hash join of two child relations. When the output RDD of this operator is - * being constructed, a Spark job is asynchronously started to calculate the values for the - * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed - * relation is not shuffled. - */ -@DeveloperApi -case class BroadcastHashJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - buildSide: BuildSide, - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashJoin { - - override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning - - override def requiredChildDistribution = - UnspecifiedDistribution :: UnspecifiedDistribution :: Nil - - @transient - val broadcastFuture = future { - sparkContext.broadcast(buildPlan.executeCollect()) - } - - def execute() = { - val broadcastRelation = Await.result(broadcastFuture, 5.minute) - - streamedPlan.execute().mapPartitions { streamedIter => - joinIterators(broadcastRelation.value.iterator, streamedIter) - } - } -} - -/** - * :: DeveloperApi :: - * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys - * for hash join. - */ -@DeveloperApi -case class LeftSemiJoinBNL( - streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) - extends BinaryNode { - // TODO: Override requiredChildDistribution. - - override def outputPartitioning: Partitioning = streamed.outputPartitioning - - def output = left.output - - /** The Streamed Relation */ - def left = streamed - /** The Broadcast relation */ - def right = broadcast - - @transient lazy val boundCondition = - InterpretedPredicate( - condition - .map(c => BindReferences.bindReference(c, left.output ++ right.output)) - .getOrElse(Literal(true))) - - def execute() = { - val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) - - streamed.execute().mapPartitions { streamedIter => - val joinedRow = new JoinedRow - - streamedIter.filter(streamedRow => { - var i = 0 - var matched = false - - while (i < broadcastedRelation.value.size && !matched) { - val broadcastedRow = broadcastedRelation.value(i) - if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { - matched = true - } - i += 1 - } - matched - }) - } - } -} - -/** - * :: DeveloperApi :: - */ -@DeveloperApi -case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { - def output = left.output ++ right.output - - def execute() = { - val leftResults = left.execute().map(_.copy()) - val rightResults = right.execute().map(_.copy()) - - leftResults.cartesian(rightResults).mapPartitions { iter => - val joinedRow = new JoinedRow - iter.map(r => joinedRow(r._1, r._2)) - } - } -} - -/** - * :: DeveloperApi :: - */ -@DeveloperApi -case class BroadcastNestedLoopJoin( - left: SparkPlan, - right: SparkPlan, - buildSide: BuildSide, - joinType: JoinType, - condition: Option[Expression]) extends BinaryNode { - // TODO: Override requiredChildDistribution. - - /** BuildRight means the right relation <=> the broadcast relation. */ - val (streamed, broadcast) = buildSide match { - case BuildRight => (left, right) - case BuildLeft => (right, left) - } - - override def outputPartitioning: Partitioning = streamed.outputPartitioning - - override def output = { - joinType match { - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case _ => - left.output ++ right.output - } - } - - @transient lazy val boundCondition = - InterpretedPredicate( - condition - .map(c => BindReferences.bindReference(c, left.output ++ right.output)) - .getOrElse(Literal(true))) - - def execute() = { - val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) - - /** All rows that either match both-way, or rows from streamed joined with nulls. */ - val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => - val matchedRows = new CompactBuffer[Row] - // TODO: Use Spark's BitSet. - val includedBroadcastTuples = - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) - val joinedRow = new JoinedRow - val leftNulls = new GenericMutableRow(left.output.size) - val rightNulls = new GenericMutableRow(right.output.size) - - streamedIter.foreach { streamedRow => - var i = 0 - var streamRowMatched = false - - while (i < broadcastedRelation.value.size) { - // TODO: One bitset per partition instead of per row. - val broadcastedRow = broadcastedRelation.value(i) - buildSide match { - case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => - matchedRows += joinedRow(streamedRow, broadcastedRow).copy() - streamRowMatched = true - includedBroadcastTuples += i - case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => - matchedRows += joinedRow(broadcastedRow, streamedRow).copy() - streamRowMatched = true - includedBroadcastTuples += i - case _ => - } - i += 1 - } - - (streamRowMatched, joinType, buildSide) match { - case (false, LeftOuter | FullOuter, BuildRight) => - matchedRows += joinedRow(streamedRow, rightNulls).copy() - case (false, RightOuter | FullOuter, BuildLeft) => - matchedRows += joinedRow(leftNulls, streamedRow).copy() - case _ => - } - } - Iterator((matchedRows, includedBroadcastTuples)) - } - - val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) - val allIncludedBroadcastTuples = - if (includedBroadcastTuples.count == 0) { - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) - } else { - includedBroadcastTuples.reduce(_ ++ _) - } - - val leftNulls = new GenericMutableRow(left.output.size) - val rightNulls = new GenericMutableRow(right.output.size) - /** Rows from broadcasted joined with nulls. */ - val broadcastRowsWithNulls: Seq[Row] = { - val buf: CompactBuffer[Row] = new CompactBuffer() - var i = 0 - val rel = broadcastedRelation.value - while (i < rel.length) { - if (!allIncludedBroadcastTuples.contains(i)) { - (joinType, buildSide) match { - case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i)) - case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls) - case _ => - } - } - i += 1 - } - buf.toSeq - } - - // TODO: Breaks lineage. - sparkContext.union( - matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls)) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 6c7697ece8c56..d62ceba863002 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner, LeftSemi} import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.join._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6fb6cb8db0c8f..45a9933d77efb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.{ShuffledHashJoin, BroadcastHashJoin} +import org.apache.spark.sql.execution.join.BroadcastHashJoin import org.apache.spark.sql.test._ import org.scalatest.BeforeAndAfterAll import java.util.TimeZone diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index bfbf431a11913..85cf5a844edf2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.execution import org.scalatest.FunSuite +import org.apache.spark.sql.{SQLConf, execution} import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.{SQLConf, execution} +import org.apache.spark.sql.execution.join.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.planner._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index a35c40efdc207..06d69ca60be6f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -24,7 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark.sql.{SQLConf, QueryTest} import org.apache.spark.sql.catalyst.plans.logical.NativeCommand -import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.execution.join.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ From a070d44aa31a6af4cd8d45fc2c02adef61bb03b9 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 8 Oct 2014 15:26:52 -0700 Subject: [PATCH 02/14] Fix line length in HashJoin --- .../scala/org/apache/spark/sql/execution/join/HashJoin.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/join/HashJoin.scala index 211076c098db3..9e379f2fb1fdc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/join/HashJoin.scala @@ -47,7 +47,8 @@ trait HashJoin { @transient protected lazy val streamSideKeyGenerator = newMutableProjection(streamedKeys, streamedPlan.output) - protected def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = { + protected def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = + { // TODO: Use Spark's HashMap implementation. val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]() From cbc664c87c2b0e6437990ca09c8771e34d9816e3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 8 Oct 2014 16:52:11 -0700 Subject: [PATCH 03/14] Rename join -> joins package. --- .../spark/sql/execution/SparkStrategies.scala | 34 +++++++++---------- .../{join => joins}/BroadcastHashJoin.scala | 2 +- .../BroadcastNestedLoopJoin.scala | 2 +- .../{join => joins}/CartesianProduct.scala | 2 +- .../execution/{join => joins}/HashJoin.scala | 2 +- .../{join => joins}/HashOuterJoin.scala | 2 +- .../{join => joins}/LeftSemiJoinBNL.scala | 2 +- .../{join => joins}/LeftSemiJoinHash.scala | 2 +- .../{join => joins}/ShuffledHashJoin.scala | 2 +- .../execution/{join => joins}/package.scala | 2 +- .../org/apache/spark/sql/JoinSuite.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 2 +- .../spark/sql/hive/StatisticsSuite.scala | 2 +- 14 files changed, 30 insertions(+), 30 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/{join => joins}/BroadcastHashJoin.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/{join => joins}/BroadcastNestedLoopJoin.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/{join => joins}/CartesianProduct.scala (96%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/{join => joins}/HashJoin.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/{join => joins}/HashOuterJoin.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/{join => joins}/LeftSemiJoinBNL.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/{join => joins}/LeftSemiJoinHash.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/{join => joins}/ShuffledHashJoin.scala (97%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/{join => joins}/package.scala (98%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 420bc11f748af..5585c1c1f65e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -35,11 +35,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // Find left semi joins where at least some predicates can be evaluated by matching join keys case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => - val semiJoin = join.LeftSemiJoinHash(leftKeys, rightKeys, planLater(left), planLater(right)) + val semiJoin = joins.LeftSemiJoinHash(leftKeys, rightKeys, planLater(left), planLater(right)) condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil // no predicate can be evaluated by matching hash keys case logical.Join(left, right, LeftSemi, condition) => - join.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil + joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil case _ => Nil } } @@ -49,13 +49,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * evaluated by matching hash keys. * * This strategy applies a simple optimization based on the estimates of the physical sizes of - * the two join sides. When planning a [[join.BroadcastHashJoin]], if one side has an + * the two join sides. When planning a [[joins.BroadcastHashJoin]], if one side has an * estimated physical size smaller than the user-settable threshold * [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the * ''build'' relation and mark the other relation as the ''stream'' side. The build table will be * ''broadcasted'' to all of the executors involved in the join, as a * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they - * will instead be used to decide the build side in a [[join.ShuffledHashJoin]]. + * will instead be used to decide the build side in a [[joins.ShuffledHashJoin]]. */ object HashJoin extends Strategy with PredicateHelper { @@ -65,8 +65,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { left: LogicalPlan, right: LogicalPlan, condition: Option[Expression], - side: join.BuildSide) = { - val broadcastHashJoin = execution.join.BroadcastHashJoin( + side: joins.BuildSide) = { + val broadcastHashJoin = execution.joins.BroadcastHashJoin( leftKeys, rightKeys, side, planLater(left), planLater(right)) condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil } @@ -75,26 +75,26 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if sqlContext.autoBroadcastJoinThreshold > 0 && right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => - makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, join.BuildRight) + makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if sqlContext.autoBroadcastJoinThreshold > 0 && left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => - makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, join.BuildLeft) + makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => val buildSide = if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { - join.BuildRight + joins.BuildRight } else { - join.BuildLeft + joins.BuildLeft } - val hashJoin = join.ShuffledHashJoin( + val hashJoin = joins.ShuffledHashJoin( leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => - join.HashOuterJoin( + joins.HashOuterJoin( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil case _ => Nil @@ -163,11 +163,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Join(left, right, joinType, condition) => val buildSide = if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { - join.BuildRight + joins.BuildRight } else { - join.BuildLeft + joins.BuildLeft } - join.BroadcastNestedLoopJoin( + joins.BroadcastNestedLoopJoin( planLater(left), planLater(right), buildSide, joinType, condition) :: Nil case _ => Nil } @@ -176,10 +176,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object CartesianProduct extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join(left, right, _, None) => - execution.join.CartesianProduct(planLater(left), planLater(right)) :: Nil + execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil case logical.Join(left, right, Inner, Some(condition)) => execution.Filter(condition, - execution.join.CartesianProduct(planLater(left), planLater(right))) :: Nil + execution.joins.CartesianProduct(planLater(left), planLater(right))) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/join/BroadcastHashJoin.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 1e7626905f3b2..d88ab6367a1b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.join +package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/join/BroadcastNestedLoopJoin.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index b29eaafd7f2fe..36aad13778bd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.join +package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/join/CartesianProduct.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index ddc459edca6fb..76c14c02aab34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.join +package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.JoinedRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/join/HashJoin.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 9e379f2fb1fdc..472b2e6ca6b4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.join +package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow2, Row} import org.apache.spark.sql.execution.SparkPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/join/HashOuterJoin.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index caacb0ffca462..b73041d306b36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.join +package org.apache.spark.sql.execution.joins import java.util.{HashMap => JavaHashMap} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/join/LeftSemiJoinBNL.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index d4401c8ce3ff4..60003d1900d85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.join +package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/join/LeftSemiJoinHash.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index e383e00725a7e..ea7babf3be948 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.join +package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.{Expression, Row} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/join/ShuffledHashJoin.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index ad248401f2e54..8247304c1dc2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.join +package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.Expression diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/join/package.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala index 3ee9709137728..7f2ab1765b28f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/join/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi * :: DeveloperApi :: * Physical execution operators for join operations. */ -package object join { +package object joins { @DeveloperApi sealed abstract class BuildSide diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index d62ceba863002..07f4d2946c1b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner, LeftSemi} import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.join._ +import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 45a9933d77efb..b9b196ea5a46a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.join.BroadcastHashJoin +import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.test._ import org.scalatest.BeforeAndAfterAll import java.util.TimeZone diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 85cf5a844edf2..f14ffca0e4d35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.{SQLConf, execution} import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.execution.join.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.planner._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 06d69ca60be6f..14e791fe0f0ee 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -24,7 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark.sql.{SQLConf, QueryTest} import org.apache.spark.sql.catalyst.plans.logical.NativeCommand -import org.apache.spark.sql.execution.join.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ From 0c0082b5d656a57dee41d97f69a212d36a3c3533 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 8 Oct 2014 16:55:39 -0700 Subject: [PATCH 04/14] Fix line length. --- .../scala/org/apache/spark/sql/execution/SparkStrategies.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5585c1c1f65e6..c71dffde32671 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -35,7 +35,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // Find left semi joins where at least some predicates can be evaluated by matching join keys case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => - val semiJoin = joins.LeftSemiJoinHash(leftKeys, rightKeys, planLater(left), planLater(right)) + val semiJoin = joins.LeftSemiJoinHash( + leftKeys, rightKeys, planLater(left), planLater(right)) condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil // no predicate can be evaluated by matching hash keys case logical.Join(left, right, LeftSemi, condition) => From 90b58c0aed328a329d96bed32b4293f3ac3a208b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 8 Oct 2014 16:56:54 -0700 Subject: [PATCH 05/14] [SPARK-3861] Avoid rebuilding hash tables on each partition BroadcastHashJoin builds a new hash table for each partition. We can build it once per node and reuse the hash table. --- .../execution/joins/BroadcastHashJoin.scala | 8 +- .../spark/sql/execution/joins/HashJoin.scala | 34 ++---- .../sql/execution/joins/HashedRelation.scala | 106 ++++++++++++++++++ .../execution/joins/ShuffledHashJoin.scala | 5 +- 4 files changed, 121 insertions(+), 32 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index d88ab6367a1b3..8fd35880eedfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -22,7 +22,7 @@ import scala.concurrent.duration._ import scala.concurrent.ExecutionContext.Implicits.global import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Row, Expression} import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -49,14 +49,16 @@ case class BroadcastHashJoin( @transient private val broadcastFuture = future { - sparkContext.broadcast(buildPlan.executeCollect()) + val input: Array[Row] = buildPlan.executeCollect() + val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length) + sparkContext.broadcast(hashed) } override def execute() = { val broadcastRelation = Await.result(broadcastFuture, 5.minute) streamedPlan.execute().mapPartitions { streamedIter => - joinIterators(broadcastRelation.value.iterator, streamedIter) + hashJoin(streamedIter, broadcastRelation.value) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 472b2e6ca6b4a..4012d757d5f9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow2, Row} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.util.collection.CompactBuffer @@ -43,34 +43,14 @@ trait HashJoin { override def output = left.output ++ right.output - @transient protected lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output) - @transient protected lazy val streamSideKeyGenerator = + @transient protected lazy val buildSideKeyGenerator: Projection = + newProjection(buildKeys, buildPlan.output) + + @transient protected lazy val streamSideKeyGenerator: () => MutableProjection = newMutableProjection(streamedKeys, streamedPlan.output) - protected def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = + protected def hashJoin(streamIter: Iterator[Row], hashedRelation: HashedRelation): Iterator[Row] = { - // TODO: Use Spark's HashMap implementation. - - val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]() - var currentRow: Row = null - - // Create a mapping of buildKeys -> rows - while (buildIter.hasNext) { - currentRow = buildIter.next() - val rowKey = buildSideKeyGenerator(currentRow) - if (!rowKey.anyNull) { - val existingMatchList = hashTable.get(rowKey) - val matchList = if (existingMatchList == null) { - val newMatchList = new CompactBuffer[Row]() - hashTable.put(rowKey, newMatchList) - newMatchList - } else { - existingMatchList - } - matchList += currentRow.copy() - } - } - new Iterator[Row] { private[this] var currentStreamedRow: Row = _ private[this] var currentHashMatches: CompactBuffer[Row] = _ @@ -107,7 +87,7 @@ trait HashJoin { while (currentHashMatches == null && streamIter.hasNext) { currentStreamedRow = streamIter.next() if (!joinKeys(currentStreamedRow).anyNull) { - currentHashMatches = hashTable.get(joinKeys.currentValue) + currentHashMatches = hashedRelation.get(joinKeys.currentValue) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala new file mode 100644 index 0000000000000..cd43357f846cf --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import java.util.{HashMap => JavaHashMap} + +import org.apache.spark.sql.catalyst.expressions.{Projection, Row} +import org.apache.spark.util.collection.CompactBuffer + + +/** + * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete + * object. + */ +private[joins] sealed trait HashedRelation { + def get(key: Row): CompactBuffer[Row] +} + + +/** + * A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values. + */ +private[joins] final class GeneralHashedRelation(hashTable: JavaHashMap[Row, CompactBuffer[Row]]) + extends HashedRelation with Serializable { + + override def get(key: Row) = hashTable.get(key) +} + + +/** + * A specialized [[HashedRelation]] that maps key into a single value. This implementation + * assumes the key is unique. + */ +final class UniqueKeyHashedRelation(hashTable: JavaHashMap[Row, Row]) + extends HashedRelation with Serializable { + + override def get(key: Row) = CompactBuffer(hashTable.get(key)) + + def getValue(key: Row): Row = hashTable.get(key) +} + + +// TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys. + + +private[joins] object HashedRelation { + + def apply( + input: Iterator[Row], + keyGenerator: Projection, + sizeEstimate: Int = 64): HashedRelation = { + + // TODO: Use Spark's HashMap implementation. + val hashTable = new JavaHashMap[Row, CompactBuffer[Row]](sizeEstimate) + var currentRow: Row = null + + // Whether the join key is unique. If the key is unique, we can convert the underlying + // hash map into one specialized for this. + var keyIsUnique = true + + // Create a mapping of buildKeys -> rows + while (input.hasNext) { + currentRow = input.next() + val rowKey = keyGenerator(currentRow) + if (!rowKey.anyNull) { + val existingMatchList = hashTable.get(rowKey) + val matchList = if (existingMatchList == null) { + val newMatchList = new CompactBuffer[Row]() + hashTable.put(rowKey, newMatchList) + newMatchList + } else { + keyIsUnique = false + existingMatchList + } + matchList += currentRow.copy() + } + } + + if (keyIsUnique) { + val uniqHashTable = new JavaHashMap[Row, Row](hashTable.size) + val iter = hashTable.entrySet().iterator() + while (iter.hasNext) { + val entry = iter.next() + uniqHashTable.put(entry.getKey, entry.getValue()(0)) + } + new UniqueKeyHashedRelation(uniqHashTable) + } else { + new GeneralHashedRelation(hashTable) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 8247304c1dc2c..418c1c23e5546 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -42,8 +42,9 @@ case class ShuffledHashJoin( ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil override def execute() = { - buildPlan.execute().zipPartitions(streamedPlan.execute()) { - (buildIter, streamIter) => joinIterators(buildIter, streamIter) + buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => + val hashed = HashedRelation(buildIter, buildSideKeyGenerator) + hashJoin(streamIter, hashed) } } } From a85f24accd3266e0f97ee04d03c22b593d99c062 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 8 Oct 2014 17:03:47 -0700 Subject: [PATCH 06/14] [SPARK-3831] [SQL] Filter rule Improvement and bool expression optimization. If we write the filter which is always FALSE like SELECT * from person WHERE FALSE; 200 tasks will run. I think, 1 task is enough. And current optimizer cannot optimize the case NOT is duplicated like SELECT * from person WHERE NOT ( NOT (age > 30)); The filter rule above should be simplified Author: Kousuke Saruta Closes #2692 from sarutak/SPARK-3831 and squashes the following commits: 25f3e20 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-3831 23c750c [Kousuke Saruta] Improved unsupported predicate test case a11b9f3 [Kousuke Saruta] Modified NOT predicate test case in PartitionBatchPruningSuite 8ea872b [Kousuke Saruta] Fixed the number of tasks when the data of LocalRelation is empty. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 12 ++++++++++++ .../apache/spark/sql/execution/SparkStrategies.scala | 3 ++- .../sql/columnar/PartitionBatchPruningSuite.scala | 3 ++- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a4133feae8166..636d0b95583e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -299,6 +299,18 @@ object BooleanSimplification extends Rule[LogicalPlan] { case (_, _) => or } + case not @ Not(exp) => + exp match { + case Literal(true, BooleanType) => Literal(false) + case Literal(false, BooleanType) => Literal(true) + case GreaterThan(l, r) => LessThanOrEqual(l, r) + case GreaterThanOrEqual(l, r) => LessThan(l, r) + case LessThan(l, r) => GreaterThanOrEqual(l, r) + case LessThanOrEqual(l, r) => GreaterThan(l, r) + case Not(e) => e + case _ => not + } + // Turn "if (true) a else b" into "a", and if (false) a else b" into "b". case e @ If(Literal(v, _), trueValue, falseValue) => if (v == true) trueValue else falseValue } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5c16d0c624128..883f2ff521e20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -274,9 +274,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil case SparkLogicalPlan(alreadyPlanned) => alreadyPlanned :: Nil case logical.LocalRelation(output, data) => + val nPartitions = if (data.isEmpty) 1 else numPartitions PhysicalRDD( output, - RDDConversions.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil + RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions))) :: Nil case logical.Limit(IntegerLiteral(limit), child) => execution.Limit(limit, planLater(child)) :: Nil case Unions(unionChildren) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 69e0adbd3ee0d..f53acc8c9f718 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -67,10 +67,11 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be checkBatchPruning("i > 8 AND i <= 21", 9 to 21, 2, 3) checkBatchPruning("i < 2 OR i > 99", Seq(1, 100), 2, 2) checkBatchPruning("i < 2 OR (i > 78 AND i < 92)", Seq(1) ++ (79 to 91), 3, 4) + checkBatchPruning("NOT (i < 88)", 88 to 100, 1, 2) // With unsupported predicate checkBatchPruning("i < 12 AND i IS NOT NULL", 1 to 11, 1, 2) - checkBatchPruning("NOT (i < 88)", 88 to 100, 5, 10) + checkBatchPruning(s"NOT (i in (${(1 to 30).mkString(",")}))", 31 to 100, 5, 10) def checkBatchPruning( filter: String, From a42cc08d219c579019f613faa8d310e6069c06fe Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 8 Oct 2014 17:04:49 -0700 Subject: [PATCH 07/14] [SPARK-3713][SQL] Uses JSON to serialize DataType objects This PR uses JSON instead of `toString` to serialize `DataType`s. The latter is not only hard to parse but also flaky in many cases. Since we already write schema information to Parquet metadata in the old style, we have to reserve the old `DataType` parser and ensure downward compatibility. The old parser is now renamed to `CaseClassStringParser` and moved into `object DataType`. JoshRosen davies Please help review PySpark related changes, thanks! Author: Cheng Lian Closes #2563 from liancheng/datatype-to-json and squashes the following commits: fc92eb3 [Cheng Lian] Reverts debugging code, simplifies primitive type JSON representation 438c75f [Cheng Lian] Refactors PySpark DataType JSON SerDe per comments 6b6387b [Cheng Lian] Removes debugging code 6a3ee3a [Cheng Lian] Addresses per review comments dc158b5 [Cheng Lian] Addresses PEP8 issues 99ab4ee [Cheng Lian] Adds compatibility est case for Parquet type conversion a983a6c [Cheng Lian] Adds PySpark support f608c6e [Cheng Lian] De/serializes DataType objects from/to JSON --- python/pyspark/sql.py | 153 ++++++------ .../catalyst/expressions/WrapDynamic.scala | 4 +- .../spark/sql/catalyst/types/dataTypes.scala | 229 ++++++++++++------ .../org/apache/spark/sql/SQLContext.scala | 9 +- .../spark/sql/parquet/ParquetTypes.scala | 6 +- .../org/apache/spark/sql/DataTypeSuite.scala | 28 +++ .../spark/sql/parquet/ParquetQuerySuite.scala | 16 +- 7 files changed, 277 insertions(+), 168 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 3d5a281239d66..d3d36eb995ab6 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -34,6 +34,7 @@ import datetime import keyword import warnings +import json from array import array from operator import itemgetter from itertools import imap @@ -71,6 +72,18 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + @classmethod + def typeName(cls): + return cls.__name__[:-4].lower() + + def jsonValue(self): + return self.typeName() + + def json(self): + return json.dumps(self.jsonValue(), + separators=(',', ':'), + sort_keys=True) + class PrimitiveTypeSingleton(type): @@ -214,6 +227,16 @@ def __repr__(self): return "ArrayType(%s,%s)" % (self.elementType, str(self.containsNull).lower()) + def jsonValue(self): + return {"type": self.typeName(), + "elementType": self.elementType.jsonValue(), + "containsNull": self.containsNull} + + @classmethod + def fromJson(cls, json): + return ArrayType(_parse_datatype_json_value(json["elementType"]), + json["containsNull"]) + class MapType(DataType): @@ -254,6 +277,18 @@ def __repr__(self): return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, str(self.valueContainsNull).lower()) + def jsonValue(self): + return {"type": self.typeName(), + "keyType": self.keyType.jsonValue(), + "valueType": self.valueType.jsonValue(), + "valueContainsNull": self.valueContainsNull} + + @classmethod + def fromJson(cls, json): + return MapType(_parse_datatype_json_value(json["keyType"]), + _parse_datatype_json_value(json["valueType"]), + json["valueContainsNull"]) + class StructField(DataType): @@ -292,6 +327,17 @@ def __repr__(self): return "StructField(%s,%s,%s)" % (self.name, self.dataType, str(self.nullable).lower()) + def jsonValue(self): + return {"name": self.name, + "type": self.dataType.jsonValue(), + "nullable": self.nullable} + + @classmethod + def fromJson(cls, json): + return StructField(json["name"], + _parse_datatype_json_value(json["type"]), + json["nullable"]) + class StructType(DataType): @@ -321,42 +367,30 @@ def __repr__(self): return ("StructType(List(%s))" % ",".join(str(field) for field in self.fields)) + def jsonValue(self): + return {"type": self.typeName(), + "fields": [f.jsonValue() for f in self.fields]} -def _parse_datatype_list(datatype_list_string): - """Parses a list of comma separated data types.""" - index = 0 - datatype_list = [] - start = 0 - depth = 0 - while index < len(datatype_list_string): - if depth == 0 and datatype_list_string[index] == ",": - datatype_string = datatype_list_string[start:index].strip() - datatype_list.append(_parse_datatype_string(datatype_string)) - start = index + 1 - elif datatype_list_string[index] == "(": - depth += 1 - elif datatype_list_string[index] == ")": - depth -= 1 + @classmethod + def fromJson(cls, json): + return StructType([StructField.fromJson(f) for f in json["fields"]]) - index += 1 - # Handle the last data type - datatype_string = datatype_list_string[start:index].strip() - datatype_list.append(_parse_datatype_string(datatype_string)) - return datatype_list +_all_primitive_types = dict((v.typeName(), v) + for v in globals().itervalues() + if type(v) is PrimitiveTypeSingleton and + v.__base__ == PrimitiveType) -_all_primitive_types = dict((k, v) for k, v in globals().iteritems() - if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType) +_all_complex_types = dict((v.typeName(), v) + for v in [ArrayType, MapType, StructType]) -def _parse_datatype_string(datatype_string): - """Parses the given data type string. - +def _parse_datatype_json_string(json_string): + """Parses the given data type JSON string. >>> def check_datatype(datatype): - ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(str(datatype)) - ... python_datatype = _parse_datatype_string( - ... scala_datatype.toString()) + ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json()) + ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) ... return datatype == python_datatype >>> all(check_datatype(cls()) for cls in _all_primitive_types.values()) True @@ -394,51 +428,14 @@ def _parse_datatype_string(datatype_string): >>> check_datatype(complex_maptype) True """ - index = datatype_string.find("(") - if index == -1: - # It is a primitive type. - index = len(datatype_string) - type_or_field = datatype_string[:index] - rest_part = datatype_string[index + 1:len(datatype_string) - 1].strip() - - if type_or_field in _all_primitive_types: - return _all_primitive_types[type_or_field]() - - elif type_or_field == "ArrayType": - last_comma_index = rest_part.rfind(",") - containsNull = True - if rest_part[last_comma_index + 1:].strip().lower() == "false": - containsNull = False - elementType = _parse_datatype_string( - rest_part[:last_comma_index].strip()) - return ArrayType(elementType, containsNull) - - elif type_or_field == "MapType": - last_comma_index = rest_part.rfind(",") - valueContainsNull = True - if rest_part[last_comma_index + 1:].strip().lower() == "false": - valueContainsNull = False - keyType, valueType = _parse_datatype_list( - rest_part[:last_comma_index].strip()) - return MapType(keyType, valueType, valueContainsNull) - - elif type_or_field == "StructField": - first_comma_index = rest_part.find(",") - name = rest_part[:first_comma_index].strip() - last_comma_index = rest_part.rfind(",") - nullable = True - if rest_part[last_comma_index + 1:].strip().lower() == "false": - nullable = False - dataType = _parse_datatype_string( - rest_part[first_comma_index + 1:last_comma_index].strip()) - return StructField(name, dataType, nullable) - - elif type_or_field == "StructType": - # rest_part should be in the format like - # List(StructField(field1,IntegerType,false)). - field_list_string = rest_part[rest_part.find("(") + 1:-1] - fields = _parse_datatype_list(field_list_string) - return StructType(fields) + return _parse_datatype_json_value(json.loads(json_string)) + + +def _parse_datatype_json_value(json_value): + if type(json_value) is unicode and json_value in _all_primitive_types.keys(): + return _all_primitive_types[json_value]() + else: + return _all_complex_types[json_value["type"]].fromJson(json_value) # Mapping Python types to Spark SQL DateType @@ -992,7 +989,7 @@ def registerFunction(self, name, f, returnType=StringType()): self._sc.pythonExec, broadcast_vars, self._sc._javaAccumulator, - str(returnType)) + returnType.json()) def inferSchema(self, rdd): """Infer and apply a schema to an RDD of L{Row}. @@ -1128,7 +1125,7 @@ def applySchema(self, rdd, schema): batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) jrdd = self._pythonToJava(rdd._jrdd, batched) - srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema)) + srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) return SchemaRDD(srdd.toJavaSchemaRDD(), self) def registerRDDAsTable(self, rdd, tableName): @@ -1218,7 +1215,7 @@ def jsonFile(self, path, schema=None): if schema is None: srdd = self._ssql_ctx.jsonFile(path) else: - scala_datatype = self._ssql_ctx.parseDataType(str(schema)) + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) srdd = self._ssql_ctx.jsonFile(path, scala_datatype) return SchemaRDD(srdd.toJavaSchemaRDD(), self) @@ -1288,7 +1285,7 @@ def func(iterator): if schema is None: srdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) else: - scala_datatype = self._ssql_ctx.parseDataType(str(schema)) + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) return SchemaRDD(srdd.toJavaSchemaRDD(), self) @@ -1623,7 +1620,7 @@ def saveAsTable(self, tableName): def schema(self): """Returns the schema of this SchemaRDD (represented by a L{StructType}).""" - return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString()) + return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json()) def schemaString(self): """Returns the output schema in the tree format.""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala index 1eb55715794a7..1a4ac06c7a79d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala @@ -24,9 +24,7 @@ import org.apache.spark.sql.catalyst.types.DataType /** * The data type representing [[DynamicRow]] values. */ -case object DynamicType extends DataType { - def simpleString: String = "dynamic" -} +case object DynamicType extends DataType /** * Wrap a [[Row]] as a [[DynamicRow]]. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index ac043d4dd8eb9..1d375b8754182 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -19,71 +19,125 @@ package org.apache.spark.sql.catalyst.types import java.sql.Timestamp -import scala.math.Numeric.{FloatAsIfIntegral, BigDecimalAsIfIntegral, DoubleAsIfIntegral} +import scala.math.Numeric.{BigDecimalAsIfIntegral, DoubleAsIfIntegral, FloatAsIfIntegral} import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror} +import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag} import scala.util.parsing.combinator.RegexParsers +import org.json4s.JsonAST.JValue +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.util.Utils -/** - * Utility functions for working with DataTypes. - */ -object DataType extends RegexParsers { - protected lazy val primitiveType: Parser[DataType] = - "StringType" ^^^ StringType | - "FloatType" ^^^ FloatType | - "IntegerType" ^^^ IntegerType | - "ByteType" ^^^ ByteType | - "ShortType" ^^^ ShortType | - "DoubleType" ^^^ DoubleType | - "LongType" ^^^ LongType | - "BinaryType" ^^^ BinaryType | - "BooleanType" ^^^ BooleanType | - "DecimalType" ^^^ DecimalType | - "TimestampType" ^^^ TimestampType - - protected lazy val arrayType: Parser[DataType] = - "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { - case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) - } - protected lazy val mapType: Parser[DataType] = - "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { - case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) +object DataType { + def fromJson(json: String): DataType = parseDataType(parse(json)) + + private object JSortedObject { + def unapplySeq(value: JValue): Option[List[(String, JValue)]] = value match { + case JObject(seq) => Some(seq.toList.sortBy(_._1)) + case _ => None } + } + + // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side. + private def parseDataType(json: JValue): DataType = json match { + case JString(name) => + PrimitiveType.nameToType(name) + + case JSortedObject( + ("containsNull", JBool(n)), + ("elementType", t: JValue), + ("type", JString("array"))) => + ArrayType(parseDataType(t), n) + + case JSortedObject( + ("keyType", k: JValue), + ("type", JString("map")), + ("valueContainsNull", JBool(n)), + ("valueType", v: JValue)) => + MapType(parseDataType(k), parseDataType(v), n) + + case JSortedObject( + ("fields", JArray(fields)), + ("type", JString("struct"))) => + StructType(fields.map(parseStructField)) + } - protected lazy val structField: Parser[StructField] = - ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { - case name ~ tpe ~ nullable => + private def parseStructField(json: JValue): StructField = json match { + case JSortedObject( + ("name", JString(name)), + ("nullable", JBool(nullable)), + ("type", dataType: JValue)) => + StructField(name, parseDataType(dataType), nullable) + } + + @deprecated("Use DataType.fromJson instead") + def fromCaseClassString(string: String): DataType = CaseClassStringParser(string) + + private object CaseClassStringParser extends RegexParsers { + protected lazy val primitiveType: Parser[DataType] = + ( "StringType" ^^^ StringType + | "FloatType" ^^^ FloatType + | "IntegerType" ^^^ IntegerType + | "ByteType" ^^^ ByteType + | "ShortType" ^^^ ShortType + | "DoubleType" ^^^ DoubleType + | "LongType" ^^^ LongType + | "BinaryType" ^^^ BinaryType + | "BooleanType" ^^^ BooleanType + | "DecimalType" ^^^ DecimalType + | "TimestampType" ^^^ TimestampType + ) + + protected lazy val arrayType: Parser[DataType] = + "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { + case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) + } + + protected lazy val mapType: Parser[DataType] = + "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { + case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) + } + + protected lazy val structField: Parser[StructField] = + ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { + case name ~ tpe ~ nullable => StructField(name, tpe, nullable = nullable) - } + } - protected lazy val boolVal: Parser[Boolean] = - "true" ^^^ true | - "false" ^^^ false + protected lazy val boolVal: Parser[Boolean] = + ( "true" ^^^ true + | "false" ^^^ false + ) - protected lazy val structType: Parser[DataType] = - "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { - case fields => new StructType(fields) - } + protected lazy val structType: Parser[DataType] = + "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { + case fields => new StructType(fields) + } - protected lazy val dataType: Parser[DataType] = - arrayType | - mapType | - structType | - primitiveType + protected lazy val dataType: Parser[DataType] = + ( arrayType + | mapType + | structType + | primitiveType + ) + + /** + * Parses a string representation of a DataType. + * + * TODO: Generate parser as pickler... + */ + def apply(asString: String): DataType = parseAll(dataType, asString) match { + case Success(result, _) => result + case failure: NoSuccess => + throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure") + } - /** - * Parses a string representation of a DataType. - * - * TODO: Generate parser as pickler... - */ - def apply(asString: String): DataType = parseAll(dataType, asString) match { - case Success(result, _) => result - case failure: NoSuccess => sys.error(s"Unsupported dataType: $asString, $failure") } protected[types] def buildFormattedString( @@ -111,15 +165,19 @@ abstract class DataType { def isPrimitive: Boolean = false - def simpleString: String -} + def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase + + private[sql] def jsonValue: JValue = typeName -case object NullType extends DataType { - def simpleString: String = "null" + def json: String = compact(render(jsonValue)) + + def prettyJson: String = pretty(render(jsonValue)) } +case object NullType extends DataType + object NativeType { - def all = Seq( + val all = Seq( IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) def unapply(dt: DataType): Boolean = all.contains(dt) @@ -139,6 +197,12 @@ trait PrimitiveType extends DataType { override def isPrimitive = true } +object PrimitiveType { + private[sql] val all = Seq(DecimalType, TimestampType, BinaryType) ++ NativeType.all + + private[sql] val nameToType = all.map(t => t.typeName -> t).toMap +} + abstract class NativeType extends DataType { private[sql] type JvmType @transient private[sql] val tag: TypeTag[JvmType] @@ -154,7 +218,6 @@ case object StringType extends NativeType with PrimitiveType { private[sql] type JvmType = String @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "string" } case object BinaryType extends NativeType with PrimitiveType { @@ -166,17 +229,15 @@ case object BinaryType extends NativeType with PrimitiveType { val res = x(i).compareTo(y(i)) if (res != 0) return res } - return x.length - y.length + x.length - y.length } } - def simpleString: String = "binary" } case object BooleanType extends NativeType with PrimitiveType { private[sql] type JvmType = Boolean @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "boolean" } case object TimestampType extends NativeType { @@ -187,8 +248,6 @@ case object TimestampType extends NativeType { private[sql] val ordering = new Ordering[JvmType] { def compare(x: Timestamp, y: Timestamp) = x.compareTo(y) } - - def simpleString: String = "timestamp" } abstract class NumericType extends NativeType with PrimitiveType { @@ -222,7 +281,6 @@ case object LongType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Long]] private[sql] val integral = implicitly[Integral[Long]] private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "long" } case object IntegerType extends IntegralType { @@ -231,7 +289,6 @@ case object IntegerType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Int]] private[sql] val integral = implicitly[Integral[Int]] private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "integer" } case object ShortType extends IntegralType { @@ -240,7 +297,6 @@ case object ShortType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Short]] private[sql] val integral = implicitly[Integral[Short]] private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "short" } case object ByteType extends IntegralType { @@ -249,7 +305,6 @@ case object ByteType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Byte]] private[sql] val integral = implicitly[Integral[Byte]] private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "byte" } /** Matcher for any expressions that evaluate to [[FractionalType]]s */ @@ -271,7 +326,6 @@ case object DecimalType extends FractionalType { private[sql] val fractional = implicitly[Fractional[BigDecimal]] private[sql] val ordering = implicitly[Ordering[JvmType]] private[sql] val asIntegral = BigDecimalAsIfIntegral - def simpleString: String = "decimal" } case object DoubleType extends FractionalType { @@ -281,7 +335,6 @@ case object DoubleType extends FractionalType { private[sql] val fractional = implicitly[Fractional[Double]] private[sql] val ordering = implicitly[Ordering[JvmType]] private[sql] val asIntegral = DoubleAsIfIntegral - def simpleString: String = "double" } case object FloatType extends FractionalType { @@ -291,12 +344,12 @@ case object FloatType extends FractionalType { private[sql] val fractional = implicitly[Fractional[Float]] private[sql] val ordering = implicitly[Ordering[JvmType]] private[sql] val asIntegral = FloatAsIfIntegral - def simpleString: String = "float" } object ArrayType { /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ def apply(elementType: DataType): ArrayType = ArrayType(elementType, true) + def typeName: String = "array" } /** @@ -309,11 +362,14 @@ object ArrayType { case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append( - s"${prefix}-- element: ${elementType.simpleString} (containsNull = ${containsNull})\n") + s"$prefix-- element: ${elementType.typeName} (containsNull = $containsNull)\n") DataType.buildFormattedString(elementType, s"$prefix |", builder) } - def simpleString: String = "array" + override private[sql] def jsonValue = + ("type" -> typeName) ~ + ("elementType" -> elementType.jsonValue) ~ + ("containsNull" -> containsNull) } /** @@ -325,14 +381,22 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT case class StructField(name: String, dataType: DataType, nullable: Boolean) { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - builder.append(s"${prefix}-- ${name}: ${dataType.simpleString} (nullable = ${nullable})\n") + builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n") DataType.buildFormattedString(dataType, s"$prefix |", builder) } + + private[sql] def jsonValue: JValue = { + ("name" -> name) ~ + ("type" -> dataType.jsonValue) ~ + ("nullable" -> nullable) + } } object StructType { protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable))) + + def typeName = "struct" } case class StructType(fields: Seq[StructField]) extends DataType { @@ -348,8 +412,7 @@ case class StructType(fields: Seq[StructField]) extends DataType { * have a name matching the given name, `null` will be returned. */ def apply(name: String): StructField = { - nameToField.get(name).getOrElse( - throw new IllegalArgumentException(s"Field ${name} does not exist.")) + nameToField.getOrElse(name, throw new IllegalArgumentException(s"Field $name does not exist.")) } /** @@ -358,7 +421,7 @@ case class StructType(fields: Seq[StructField]) extends DataType { */ def apply(names: Set[String]): StructType = { val nonExistFields = names -- fieldNamesSet - if (!nonExistFields.isEmpty) { + if (nonExistFields.nonEmpty) { throw new IllegalArgumentException( s"Field ${nonExistFields.mkString(",")} does not exist.") } @@ -384,7 +447,9 @@ case class StructType(fields: Seq[StructField]) extends DataType { fields.foreach(field => field.buildFormattedString(prefix, builder)) } - def simpleString: String = "struct" + override private[sql] def jsonValue = + ("type" -> typeName) ~ + ("fields" -> fields.map(_.jsonValue)) } object MapType { @@ -394,6 +459,8 @@ object MapType { */ def apply(keyType: DataType, valueType: DataType): MapType = MapType(keyType: DataType, valueType: DataType, true) + + def simpleName = "map" } /** @@ -407,12 +474,16 @@ case class MapType( valueType: DataType, valueContainsNull: Boolean) extends DataType { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - builder.append(s"${prefix}-- key: ${keyType.simpleString}\n") - builder.append(s"${prefix}-- value: ${valueType.simpleString} " + - s"(valueContainsNull = ${valueContainsNull})\n") + builder.append(s"$prefix-- key: ${keyType.typeName}\n") + builder.append(s"$prefix-- value: ${valueType.typeName} " + + s"(valueContainsNull = $valueContainsNull)\n") DataType.buildFormattedString(keyType, s"$prefix |", builder) DataType.buildFormattedString(valueType, s"$prefix |", builder) } - def simpleString: String = "map" + override private[sql] def jsonValue: JValue = + ("type" -> typeName) ~ + ("keyType" -> keyType.jsonValue) ~ + ("valueType" -> valueType.jsonValue) ~ + ("valueContainsNull" -> valueContainsNull) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 7a55c5bf97a71..35561cac3e5e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration +import org.apache.spark.SparkContext import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.ScalaReflection @@ -31,12 +32,11 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.types.DataType import org.apache.spark.sql.columnar.InMemoryRelation -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.SparkStrategies +import org.apache.spark.sql.execution.{SparkStrategies, _} import org.apache.spark.sql.json._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.{Logging, SparkContext} /** * :: AlphaComponent :: @@ -409,8 +409,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * It is only used by PySpark. */ private[sql] def parseDataType(dataTypeString: String): DataType = { - val parser = org.apache.spark.sql.catalyst.types.DataType - parser(dataTypeString) + DataType.fromJson(dataTypeString) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 2941b9793597f..e6389cf77a4c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.parquet import java.io.IOException +import scala.util.Try + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job @@ -323,14 +325,14 @@ private[parquet] object ParquetTypesConverter extends Logging { } def convertFromString(string: String): Seq[Attribute] = { - DataType(string) match { + Try(DataType.fromJson(string)).getOrElse(DataType.fromCaseClassString(string)) match { case s: StructType => s.toAttributes case other => sys.error(s"Can convert $string to row") } } def convertToString(schema: Seq[Attribute]): String = { - StructType.fromAttributes(schema).toString + StructType.fromAttributes(schema).json } def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala index 8fb59c5830f6d..100ecb45e9e88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.types.DataType + class DataTypeSuite extends FunSuite { test("construct an ArrayType") { @@ -55,4 +57,30 @@ class DataTypeSuite extends FunSuite { struct(Set("b", "d", "e", "f")) } } + + def checkDataTypeJsonRepr(dataType: DataType): Unit = { + test(s"JSON - $dataType") { + assert(DataType.fromJson(dataType.json) === dataType) + } + } + + checkDataTypeJsonRepr(BooleanType) + checkDataTypeJsonRepr(ByteType) + checkDataTypeJsonRepr(ShortType) + checkDataTypeJsonRepr(IntegerType) + checkDataTypeJsonRepr(LongType) + checkDataTypeJsonRepr(FloatType) + checkDataTypeJsonRepr(DoubleType) + checkDataTypeJsonRepr(DecimalType) + checkDataTypeJsonRepr(TimestampType) + checkDataTypeJsonRepr(StringType) + checkDataTypeJsonRepr(BinaryType) + checkDataTypeJsonRepr(ArrayType(DoubleType, true)) + checkDataTypeJsonRepr(ArrayType(StringType, false)) + checkDataTypeJsonRepr(MapType(IntegerType, StringType, true)) + checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false)) + checkDataTypeJsonRepr( + StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", ArrayType(DoubleType), nullable = false)))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 07adf731405af..25e41ecf28e2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -789,7 +789,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(result3(0)(1) === "the answer") Utils.deleteRecursively(tmpdir) } - + test("Querying on empty parquet throws exception (SPARK-3536)") { val tmpdir = Utils.createTempDir() Utils.deleteRecursively(tmpdir) @@ -798,4 +798,18 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(result1.size === 0) Utils.deleteRecursively(tmpdir) } + + test("DataType string parser compatibility") { + val schema = StructType(List( + StructField("c1", IntegerType, false), + StructField("c2", BinaryType, false))) + + val fromCaseClassString = ParquetTypesConverter.convertFromString(schema.toString) + val fromJson = ParquetTypesConverter.convertFromString(schema.json) + + (fromCaseClassString, fromJson).zipped.foreach { (a, b) => + assert(a.name == b.name) + assert(a.dataType === b.dataType) + } + } } From e0ebdd12be6a3f97d2a1ef17ac17f9220e1d99c2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 8 Oct 2014 17:13:39 -0700 Subject: [PATCH 08/14] Added a test case. --- .../execution/joins/HashedRelationSuite.scala | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala new file mode 100644 index 0000000000000..871b7dcc6353e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.expressions.{Projection, Row} +import org.apache.spark.util.collection.CompactBuffer + + +class HashedRelationSuite extends FunSuite { + + // Key is simply the record itself + private val keyProjection = new Projection { + override def apply(row: Row): Row = row + } + + test("GeneralHashedRelation") { + val data = Array(Row(0), Row(1), Row(2), Row(2)) + val hashed = HashedRelation(data.iterator, keyProjection) + assert(hashed.isInstanceOf[GeneralHashedRelation]) + + assert(hashed.get(data(0)) == CompactBuffer[Row](data(0))) + assert(hashed.get(data(1)) == CompactBuffer[Row](data(1))) + + val data2 = CompactBuffer[Row](data(2)) + data2 += data(2) + assert(hashed.get(data(2)) == data2) + } + + test("UniqueKeyHashedRelation") { + val data = Array(Row(0), Row(1), Row(2)) + val hashed = HashedRelation(data.iterator, keyProjection) + assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) + + assert(hashed.get(data(0)) == CompactBuffer[Row](data(0))) + assert(hashed.get(data(1)) == CompactBuffer[Row](data(1))) + assert(hashed.get(data(2)) == CompactBuffer[Row](data(2))) + + val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation] + assert(uniqHashed.getValue(data(0)) == data(0)) + assert(uniqHashed.getValue(data(1)) == data(1)) + assert(uniqHashed.getValue(data(2)) == data(2)) + } +} From 00b7791720e50119a98084b2e8755e1b593ca55f Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Wed, 8 Oct 2014 17:16:54 -0700 Subject: [PATCH 09/14] [SQL][Doc] Keep Spark SQL README.md up to date marmbrus Update README.md to be consistent with Spark 1.1 Author: Liquan Pei Closes #2706 from Ishiihara/SparkSQL-readme and squashes the following commits: 33b9d4b [Liquan Pei] keep README.md up to date --- sql/README.md | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/sql/README.md b/sql/README.md index 31f9152344086..c84534da9a3d3 100644 --- a/sql/README.md +++ b/sql/README.md @@ -44,38 +44,37 @@ Type in expressions to have them evaluated. Type :help for more information. scala> val query = sql("SELECT * FROM (SELECT * FROM src) a") -query: org.apache.spark.sql.ExecutedQuery = -SELECT * FROM (SELECT * FROM src) a -=== Query Plan === -Project [key#6:0.0,value#7:0.1] - HiveTableScan [key#6,value#7], (MetastoreRelation default, src, None), None +query: org.apache.spark.sql.SchemaRDD = +== Query Plan == +== Physical Plan == +HiveTableScan [key#10,value#11], (MetastoreRelation default, src, None), None ``` Query results are RDDs and can be operated as such. ``` scala> query.collect() -res8: Array[org.apache.spark.sql.execution.Row] = Array([238,val_238], [86,val_86], [311,val_311]... +res2: Array[org.apache.spark.sql.Row] = Array([238,val_238], [86,val_86], [311,val_311], [27,val_27]... ``` You can also build further queries on top of these RDDs using the query DSL. ``` -scala> query.where('key === 100).toRdd.collect() -res11: Array[org.apache.spark.sql.execution.Row] = Array([100,val_100], [100,val_100]) +scala> query.where('key === 100).collect() +res3: Array[org.apache.spark.sql.Row] = Array([100,val_100], [100,val_100]) ``` -From the console you can even write rules that transform query plans. For example, the above query has redundant project operators that aren't doing anything. This redundancy can be eliminated using the `transform` function that is available on all [`TreeNode`](http://databricks.github.io/catalyst/latest/api/#catalyst.trees.TreeNode) objects. +From the console you can even write rules that transform query plans. For example, the above query has redundant project operators that aren't doing anything. This redundancy can be eliminated using the `transform` function that is available on all [`TreeNode`](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala) objects. ```scala -scala> query.logicalPlan -res1: catalyst.plans.logical.LogicalPlan = -Project {key#0,value#1} - Project {key#0,value#1} +scala> query.queryExecution.analyzed +res4: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan = +Project [key#10,value#11] + Project [key#10,value#11] MetastoreRelation default, src, None -scala> query.logicalPlan transform { +scala> query.queryExecution.analyzed transform { | case Project(projectList, child) if projectList == child.output => child | } -res2: catalyst.plans.logical.LogicalPlan = -Project {key#0,value#1} +res5: res17: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan = +Project [key#10,value#11] MetastoreRelation default, src, None ``` From 4b9d0c9a6674b195feeda753a80920fcb5a5727c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 8 Oct 2014 17:46:33 -0700 Subject: [PATCH 10/14] UniqueKeyHashedRelation.get should return null if the value is null. --- .../apache/spark/sql/execution/joins/HashedRelation.scala | 5 ++++- .../spark/sql/execution/joins/HashedRelationSuite.scala | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index cd43357f846cf..34d284cbefca0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -49,7 +49,10 @@ private[joins] final class GeneralHashedRelation(hashTable: JavaHashMap[Row, Com final class UniqueKeyHashedRelation(hashTable: JavaHashMap[Row, Row]) extends HashedRelation with Serializable { - override def get(key: Row) = CompactBuffer(hashTable.get(key)) + override def get(key: Row) = { + val v = hashTable.get(key) + if (v eq null) null else CompactBuffer(v) + } def getValue(key: Row): Row = hashTable.get(key) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 871b7dcc6353e..2aad01ded1acf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -37,6 +37,7 @@ class HashedRelationSuite extends FunSuite { assert(hashed.get(data(0)) == CompactBuffer[Row](data(0))) assert(hashed.get(data(1)) == CompactBuffer[Row](data(1))) + assert(hashed.get(Row(10)) === null) val data2 = CompactBuffer[Row](data(2)) data2 += data(2) @@ -51,10 +52,12 @@ class HashedRelationSuite extends FunSuite { assert(hashed.get(data(0)) == CompactBuffer[Row](data(0))) assert(hashed.get(data(1)) == CompactBuffer[Row](data(1))) assert(hashed.get(data(2)) == CompactBuffer[Row](data(2))) + assert(hashed.get(Row(10)) === null) val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation] assert(uniqHashed.getValue(data(0)) == data(0)) assert(uniqHashed.getValue(data(1)) == data(1)) assert(uniqHashed.getValue(data(2)) == data(2)) + assert(uniqHashed.getValue(Row(10)) == null) } } From 4ec931951fea4efbfe5db39cf581704df7d2775b Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 8 Oct 2014 17:52:27 -0700 Subject: [PATCH 11/14] [SPARK-3707] [SQL] Fix bug of type coercion in DIV Calling `BinaryArithmetic.dataType` will throws exception until it's resolved, but in type coercion rule `Division`, seems doesn't follow this. Author: Cheng Hao Closes #2559 from chenghao-intel/type_coercion and squashes the following commits: 199a85d [Cheng Hao] Simplify the divide rule dc55218 [Cheng Hao] fix bug of type coercion in div --- .../catalyst/analysis/HiveTypeCoercion.scala | 7 +++- .../sql/catalyst/analysis/AnalysisSuite.scala | 40 +++++++++++++++++-- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 79e5283e86a37..64881854df7a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -348,8 +348,11 @@ trait HiveTypeCoercion { case e if !e.childrenResolved => e // Decimal and Double remain the same - case d: Divide if d.dataType == DoubleType => d - case d: Divide if d.dataType == DecimalType => d + case d: Divide if d.resolved && d.dataType == DoubleType => d + case d: Divide if d.resolved && d.dataType == DecimalType => d + + case Divide(l, r) if l.dataType == DecimalType => Divide(l, Cast(r, DecimalType)) + case Divide(l, r) if r.dataType == DecimalType => Divide(Cast(l, DecimalType), r) case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 5809a108ff62e..7b45738c4fc95 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.analysis import org.scalatest.{BeforeAndAfter, FunSuite} -import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.types.IntegerType +import org.apache.spark.sql.catalyst.types._ class AnalysisSuite extends FunSuite with BeforeAndAfter { val caseSensitiveCatalog = new SimpleCatalog(true) @@ -33,6 +34,12 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false) val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) + val testRelation2 = LocalRelation( + AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", DoubleType)(), + AttributeReference("d", DecimalType)(), + AttributeReference("e", ShortType)()) before { caseSensitiveCatalog.registerTable(None, "TaBlE", testRelation) @@ -74,7 +81,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { val e = intercept[RuntimeException] { caseSensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None)) } - assert(e.getMessage === "Table Not Found: tAbLe") + assert(e.getMessage == "Table Not Found: tAbLe") assert( caseSensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) === @@ -106,4 +113,31 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { } assert(e.getMessage().toLowerCase.contains("unresolved plan")) } + + test("divide should be casted into fractional types") { + val testRelation2 = LocalRelation( + AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", DoubleType)(), + AttributeReference("d", DecimalType)(), + AttributeReference("e", ShortType)()) + + val expr0 = 'a / 2 + val expr1 = 'a / 'b + val expr2 = 'a / 'c + val expr3 = 'a / 'd + val expr4 = 'e / 'e + val plan = caseInsensitiveAnalyze(Project( + Alias(expr0, s"Analyzer($expr0)")() :: + Alias(expr1, s"Analyzer($expr1)")() :: + Alias(expr2, s"Analyzer($expr2)")() :: + Alias(expr3, s"Analyzer($expr3)")() :: + Alias(expr4, s"Analyzer($expr4)")() :: Nil, testRelation2)) + val pl = plan.asInstanceOf[Project].projectList + assert(pl(0).dataType == DoubleType) + assert(pl(1).dataType == DoubleType) + assert(pl(2).dataType == DoubleType) + assert(pl(3).dataType == DecimalType) + assert(pl(4).dataType == DoubleType) + } } From e7033572330bd48b2438f218b0d2cd3fccdeb362 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 8 Oct 2014 18:11:18 -0700 Subject: [PATCH 12/14] [SPARK-3810][SQL] Makes PreInsertionCasts handle partitions properly Includes partition keys into account when applying `PreInsertionCasts` rule. Author: Cheng Lian Closes #2672 from liancheng/fix-pre-insert-casts and squashes the following commits: def1a1a [Cheng Lian] Makes PreInsertionCasts handle partitions properly --- .../spark/sql/hive/HiveMetastoreCatalog.scala | 15 +++----- .../sql/hive/execution/HiveQuerySuite.scala | 36 +++++++++++++++++++ 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index cc0605b0adb35..addd5bed8426d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -19,31 +19,28 @@ package org.apache.spark.sql.hive import scala.util.parsing.combinator.RegexParsers -import org.apache.hadoop.hive.metastore.api.{FieldSchema, StorageDescriptor, SerDeInfo} -import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition} +import org.apache.hadoop.hive.metastore.api.{FieldSchema, SerDeInfo, StorageDescriptor, Partition => TPartition, Table => TTable} import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.ql.stats.StatsSetupConst import org.apache.hadoop.hive.serde2.Deserializer -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, Catalog} +import org.apache.spark.sql.catalyst.analysis.Catalog import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.columnar.InMemoryRelation -import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.util.Utils /* Implicit conversions */ import scala.collection.JavaConversions._ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Logging { - import HiveMetastoreTypes._ + import org.apache.spark.sql.hive.HiveMetastoreTypes._ /** Connection to hive metastore. Usages should lock on `this`. */ protected[hive] val client = Hive.get(hive.hiveconf) @@ -137,10 +134,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) = { val childOutputDataTypes = child.output.map(_.dataType) - // Only check attributes, not partitionKeys since they are always strings. - // TODO: Fully support inserting into partitioned tables. val tableOutputDataTypes = - table.attributes.map(_.dataType) ++ table.partitionKeys.map(_.dataType) + (table.attributes ++ table.partitionKeys).take(child.output.length).map(_.dataType) if (childOutputDataTypes == tableOutputDataTypes) { p diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2e282a9ade40c..2829105f43716 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -22,6 +22,7 @@ import scala.util.Try import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ @@ -675,6 +676,41 @@ class HiveQuerySuite extends HiveComparisonTest { sql("SELECT * FROM boom").queryExecution.analyzed } + test("SPARK-3810: PreInsertionCasts static partitioning support") { + val analyzedPlan = { + loadTestTable("srcpart") + sql("DROP TABLE IF EXISTS withparts") + sql("CREATE TABLE withparts LIKE srcpart") + sql("INSERT INTO TABLE withparts PARTITION(ds='1', hr='2') SELECT key, value FROM src") + .queryExecution.analyzed + } + + assertResult(1, "Duplicated project detected\n" + analyzedPlan) { + analyzedPlan.collect { + case _: Project => () + }.size + } + } + + test("SPARK-3810: PreInsertionCasts dynamic partitioning support") { + val analyzedPlan = { + loadTestTable("srcpart") + sql("DROP TABLE IF EXISTS withparts") + sql("CREATE TABLE withparts LIKE srcpart") + sql("SET hive.exec.dynamic.partition.mode=nonstrict") + + sql("CREATE TABLE IF NOT EXISTS withparts LIKE srcpart") + sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value FROM src") + .queryExecution.analyzed + } + + assertResult(1, "Duplicated project detected\n" + analyzedPlan) { + analyzedPlan.collect { + case _: Project => () + }.size + } + } + test("parse HQL set commands") { // Adapted from its SQL counterpart. val testKey = "spark.sql.key.usedfortestonly" From 3e4f09d2fce9dcf45eaaca827f2cf15c9d4a6c75 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 8 Oct 2014 18:13:22 -0700 Subject: [PATCH 13/14] [SQL] Prevents per row dynamic dispatching and pattern matching when inserting Hive values Builds all wrappers at first according to object inspector types to avoid per row costs. Author: Cheng Lian Closes #2592 from liancheng/hive-value-wrapper and squashes the following commits: 9696559 [Cheng Lian] Passes all tests 4998666 [Cheng Lian] Prevents per row dynamic dispatching and pattern matching when inserting Hive values --- .../hive/execution/InsertIntoHiveTable.scala | 64 ++++++++++--------- 1 file changed, 34 insertions(+), 30 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index f8b4e898ec41d..f0785d8882636 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -69,33 +69,36 @@ case class InsertIntoHiveTable( * Wraps with Hive types based on object inspector. * TODO: Consolidate all hive OI/data interface code. */ - protected def wrap(a: (Any, ObjectInspector)): Any = a match { - case (s: String, oi: JavaHiveVarcharObjectInspector) => - new HiveVarchar(s, s.size) - - case (bd: BigDecimal, oi: JavaHiveDecimalObjectInspector) => - new HiveDecimal(bd.underlying()) - - case (row: Row, oi: StandardStructObjectInspector) => - val struct = oi.create() - row.zip(oi.getAllStructFieldRefs: Seq[StructField]).foreach { - case (data, field) => - oi.setStructFieldData(struct, field, wrap(data, field.getFieldObjectInspector)) + protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match { + case _: JavaHiveVarcharObjectInspector => + (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size) + + case _: JavaHiveDecimalObjectInspector => + (o: Any) => new HiveDecimal(o.asInstanceOf[BigDecimal].underlying()) + + case soi: StandardStructObjectInspector => + val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector)) + (o: Any) => { + val struct = soi.create() + (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach { + (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) + } + struct } - struct - case (s: Seq[_], oi: ListObjectInspector) => - val wrappedSeq = s.map(wrap(_, oi.getListElementObjectInspector)) - seqAsJavaList(wrappedSeq) + case loi: ListObjectInspector => + val wrapper = wrapperFor(loi.getListElementObjectInspector) + (o: Any) => seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper)) - case (m: Map[_, _], oi: MapObjectInspector) => - val keyOi = oi.getMapKeyObjectInspector - val valueOi = oi.getMapValueObjectInspector - val wrappedMap = m.map { case (key, value) => wrap(key, keyOi) -> wrap(value, valueOi) } - mapAsJavaMap(wrappedMap) + case moi: MapObjectInspector => + val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector) + val valueWrapper = wrapperFor(moi.getMapValueObjectInspector) + (o: Any) => mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) => + keyWrapper(key) -> valueWrapper(value) + }) - case (obj, _) => - obj + case _ => + identity[Any] } def saveAsHiveFile( @@ -103,7 +106,7 @@ case class InsertIntoHiveTable( valueClass: Class[_], fileSinkConf: FileSinkDesc, conf: SerializableWritable[JobConf], - writerContainer: SparkHiveWriterContainer) { + writerContainer: SparkHiveWriterContainer): Unit = { assert(valueClass != null, "Output value class not set") conf.value.setOutputValueClass(valueClass) @@ -122,7 +125,7 @@ case class InsertIntoHiveTable( writerContainer.commitJob() // Note that this function is executed on executor side - def writeToFile(context: TaskContext, iterator: Iterator[Row]) { + def writeToFile(context: TaskContext, iterator: Iterator[Row]): Unit = { val serializer = newSerializer(fileSinkConf.getTableInfo) val standardOI = ObjectInspectorUtils .getStandardObjectInspector( @@ -131,6 +134,7 @@ case class InsertIntoHiveTable( .asInstanceOf[StructObjectInspector] val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray + val wrappers = fieldOIs.map(wrapperFor) val outputData = new Array[Any](fieldOIs.length) // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it @@ -141,13 +145,13 @@ case class InsertIntoHiveTable( iterator.foreach { row => var i = 0 while (i < fieldOIs.length) { - // TODO (lian) avoid per row dynamic dispatching and pattern matching cost in `wrap` - outputData(i) = wrap(row(i), fieldOIs(i)) + outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row(i)) i += 1 } - val writer = writerContainer.getLocalFileWriter(row) - writer.write(serializer.serialize(outputData, standardOI)) + writerContainer + .getLocalFileWriter(row) + .write(serializer.serialize(outputData, standardOI)) } writerContainer.close() @@ -207,7 +211,7 @@ case class InsertIntoHiveTable( // Report error if any static partition appears after a dynamic partition val isDynamic = partitionColumnNames.map(partitionSpec(_).isEmpty) - isDynamic.init.zip(isDynamic.tail).find(_ == (true, false)).foreach { _ => + if (isDynamic.init.zip(isDynamic.tail).contains((true, false))) { throw new SparkException(ErrorMsg.PARTITION_DYN_STA_ORDER.getMsg) } } From bcb1ae049b447c37418747e0a262f54f9fc1664a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 8 Oct 2014 18:17:01 -0700 Subject: [PATCH 14/14] [SPARK-3857] Create joins package for various join operators. Author: Reynold Xin Closes #2719 from rxin/sql-join-break and squashes the following commits: 0c0082b [Reynold Xin] Fix line length. cbc664c [Reynold Xin] Rename join -> joins package. a070d44 [Reynold Xin] Fix line length in HashJoin a39be8c [Reynold Xin] [SPARK-3857] Create a join package for various join operators. --- .../spark/sql/execution/SparkStrategies.scala | 41 +- .../apache/spark/sql/execution/joins.scala | 624 ------------------ .../execution/joins/BroadcastHashJoin.scala | 62 ++ .../joins/BroadcastNestedLoopJoin.scala | 144 ++++ .../execution/joins/CartesianProduct.scala | 40 ++ .../spark/sql/execution/joins/HashJoin.scala | 123 ++++ .../sql/execution/joins/HashOuterJoin.scala | 222 +++++++ .../sql/execution/joins/LeftSemiJoinBNL.scala | 73 ++ .../execution/joins/LeftSemiJoinHash.scala | 67 ++ .../execution/joins/ShuffledHashJoin.scala | 49 ++ .../spark/sql/execution/joins/package.scala | 37 ++ .../org/apache/spark/sql/JoinSuite.scala | 1 + .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 3 +- .../spark/sql/hive/StatisticsSuite.scala | 2 +- 15 files changed, 844 insertions(+), 646 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 883f2ff521e20..bbf17b9fadf86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} import org.apache.spark.sql.parquet._ + private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SQLContext#SparkPlanner => @@ -34,13 +35,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // Find left semi joins where at least some predicates can be evaluated by matching join keys case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => - val semiJoin = execution.LeftSemiJoinHash( + val semiJoin = joins.LeftSemiJoinHash( leftKeys, rightKeys, planLater(left), planLater(right)) condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil // no predicate can be evaluated by matching hash keys case logical.Join(left, right, LeftSemi, condition) => - execution.LeftSemiJoinBNL( - planLater(left), planLater(right), condition) :: Nil + joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil case _ => Nil } } @@ -50,13 +50,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * evaluated by matching hash keys. * * This strategy applies a simple optimization based on the estimates of the physical sizes of - * the two join sides. When planning a [[execution.BroadcastHashJoin]], if one side has an + * the two join sides. When planning a [[joins.BroadcastHashJoin]], if one side has an * estimated physical size smaller than the user-settable threshold * [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the * ''build'' relation and mark the other relation as the ''stream'' side. The build table will be * ''broadcasted'' to all of the executors involved in the join, as a * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they - * will instead be used to decide the build side in a [[execution.ShuffledHashJoin]]. + * will instead be used to decide the build side in a [[joins.ShuffledHashJoin]]. */ object HashJoin extends Strategy with PredicateHelper { @@ -66,8 +66,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { left: LogicalPlan, right: LogicalPlan, condition: Option[Expression], - side: BuildSide) = { - val broadcastHashJoin = execution.BroadcastHashJoin( + side: joins.BuildSide) = { + val broadcastHashJoin = execution.joins.BroadcastHashJoin( leftKeys, rightKeys, side, planLater(left), planLater(right)) condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil } @@ -76,27 +76,26 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if sqlContext.autoBroadcastJoinThreshold > 0 && right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => - makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight) + makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if sqlContext.autoBroadcastJoinThreshold > 0 && left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => - makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft) + makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => val buildSide = if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { - BuildRight + joins.BuildRight } else { - BuildLeft + joins.BuildLeft } - val hashJoin = - execution.ShuffledHashJoin( - leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) + val hashJoin = joins.ShuffledHashJoin( + leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => - execution.HashOuterJoin( + joins.HashOuterJoin( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil case _ => Nil @@ -164,8 +163,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join(left, right, joinType, condition) => val buildSide = - if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) BuildRight else BuildLeft - execution.BroadcastNestedLoopJoin( + if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { + joins.BuildRight + } else { + joins.BuildLeft + } + joins.BroadcastNestedLoopJoin( planLater(left), planLater(right), buildSide, joinType, condition) :: Nil case _ => Nil } @@ -174,10 +177,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object CartesianProduct extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join(left, right, _, None) => - execution.CartesianProduct(planLater(left), planLater(right)) :: Nil + execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil case logical.Join(left, right, Inner, Some(condition)) => execution.Filter(condition, - execution.CartesianProduct(planLater(left), planLater(right))) :: Nil + execution.joins.CartesianProduct(planLater(left), planLater(right))) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala deleted file mode 100644 index 2890a563bed48..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ /dev/null @@ -1,624 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.util.{HashMap => JavaHashMap} - -import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent._ -import scala.concurrent.duration._ - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.util.collection.CompactBuffer - -@DeveloperApi -sealed abstract class BuildSide - -@DeveloperApi -case object BuildLeft extends BuildSide - -@DeveloperApi -case object BuildRight extends BuildSide - -trait HashJoin { - self: SparkPlan => - - val leftKeys: Seq[Expression] - val rightKeys: Seq[Expression] - val buildSide: BuildSide - val left: SparkPlan - val right: SparkPlan - - lazy val (buildPlan, streamedPlan) = buildSide match { - case BuildLeft => (left, right) - case BuildRight => (right, left) - } - - lazy val (buildKeys, streamedKeys) = buildSide match { - case BuildLeft => (leftKeys, rightKeys) - case BuildRight => (rightKeys, leftKeys) - } - - def output = left.output ++ right.output - - @transient lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output) - @transient lazy val streamSideKeyGenerator = - newMutableProjection(streamedKeys, streamedPlan.output) - - def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = { - // TODO: Use Spark's HashMap implementation. - - val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]() - var currentRow: Row = null - - // Create a mapping of buildKeys -> rows - while (buildIter.hasNext) { - currentRow = buildIter.next() - val rowKey = buildSideKeyGenerator(currentRow) - if (!rowKey.anyNull) { - val existingMatchList = hashTable.get(rowKey) - val matchList = if (existingMatchList == null) { - val newMatchList = new CompactBuffer[Row]() - hashTable.put(rowKey, newMatchList) - newMatchList - } else { - existingMatchList - } - matchList += currentRow.copy() - } - } - - new Iterator[Row] { - private[this] var currentStreamedRow: Row = _ - private[this] var currentHashMatches: CompactBuffer[Row] = _ - private[this] var currentMatchPosition: Int = -1 - - // Mutable per row objects. - private[this] val joinRow = new JoinedRow2 - - private[this] val joinKeys = streamSideKeyGenerator() - - override final def hasNext: Boolean = - (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || - (streamIter.hasNext && fetchNext()) - - override final def next() = { - val ret = buildSide match { - case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) - case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) - } - currentMatchPosition += 1 - ret - } - - /** - * Searches the streamed iterator for the next row that has at least one match in hashtable. - * - * @return true if the search is successful, and false if the streamed iterator runs out of - * tuples. - */ - private final def fetchNext(): Boolean = { - currentHashMatches = null - currentMatchPosition = -1 - - while (currentHashMatches == null && streamIter.hasNext) { - currentStreamedRow = streamIter.next() - if (!joinKeys(currentStreamedRow).anyNull) { - currentHashMatches = hashTable.get(joinKeys.currentValue) - } - } - - if (currentHashMatches == null) { - false - } else { - currentMatchPosition = 0 - true - } - } - } - } -} - -/** - * :: DeveloperApi :: - * Performs a hash based outer join for two child relations by shuffling the data using - * the join keys. This operator requires loading the associated partition in both side into memory. - */ -@DeveloperApi -case class HashOuterJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode { - - override def outputPartitioning: Partitioning = joinType match { - case LeftOuter => left.outputPartitioning - case RightOuter => right.outputPartitioning - case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) - case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") - } - - override def requiredChildDistribution = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - override def output = { - joinType match { - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case x => - throw new Exception(s"HashOuterJoin should not take $x as the JoinType") - } - } - - @transient private[this] lazy val DUMMY_LIST = Seq[Row](null) - @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row] - - // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala - // iterator for performance purpose. - - private[this] def leftOuterIterator( - key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { - val joinedRow = new JoinedRow() - val rightNullRow = new GenericRow(right.output.length) - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - - leftIter.iterator.flatMap { l => - joinedRow.withLeft(l) - var matched = false - (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => - matched = true - joinedRow.copy - } else { - Nil - }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { - // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all of the - // records in right side. - // If we didn't get any proper row, then append a single row with empty right - joinedRow.withRight(rightNullRow).copy - }) - } - } - - private[this] def rightOuterIterator( - key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { - val joinedRow = new JoinedRow() - val leftNullRow = new GenericRow(left.output.length) - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - - rightIter.iterator.flatMap { r => - joinedRow.withRight(r) - var matched = false - (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => - matched = true - joinedRow.copy - } else { - Nil - }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { - // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all of the - // records in left side. - // If we didn't get any proper row, then append a single row with empty left. - joinedRow.withLeft(leftNullRow).copy - }) - } - } - - private[this] def fullOuterIterator( - key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { - val joinedRow = new JoinedRow() - val leftNullRow = new GenericRow(left.output.length) - val rightNullRow = new GenericRow(right.output.length) - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - - if (!key.anyNull) { - // Store the positions of records in right, if one of its associated row satisfy - // the join condition. - val rightMatchedSet = scala.collection.mutable.Set[Int]() - leftIter.iterator.flatMap[Row] { l => - joinedRow.withLeft(l) - var matched = false - rightIter.zipWithIndex.collect { - // 1. For those matched (satisfy the join condition) records with both sides filled, - // append them directly - - case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { - matched = true - // if the row satisfy the join condition, add its index into the matched set - rightMatchedSet.add(idx) - joinedRow.copy - } - } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { - // 2. For those unmatched records in left, append additional records with empty right. - - // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all - // of the records in right side. - // If we didn't get any proper row, then append a single row with empty right. - joinedRow.withRight(rightNullRow).copy - }) - } ++ rightIter.zipWithIndex.collect { - // 3. For those unmatched records in right, append additional records with empty left. - - // Re-visiting the records in right, and append additional row with empty left, if its not - // in the matched set. - case (r, idx) if (!rightMatchedSet.contains(idx)) => { - joinedRow(leftNullRow, r).copy - } - } - } else { - leftIter.iterator.map[Row] { l => - joinedRow(l, rightNullRow).copy - } ++ rightIter.iterator.map[Row] { r => - joinedRow(leftNullRow, r).copy - } - } - } - - private[this] def buildHashTable( - iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = { - val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]() - while (iter.hasNext) { - val currentRow = iter.next() - val rowKey = keyGenerator(currentRow) - - var existingMatchList = hashTable.get(rowKey) - if (existingMatchList == null) { - existingMatchList = new CompactBuffer[Row]() - hashTable.put(rowKey, existingMatchList) - } - - existingMatchList += currentRow.copy() - } - - hashTable - } - - def execute() = { - left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => - // TODO this probably can be replaced by external sort (sort merged join?) - // Build HashMap for current partition in left relation - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - // Build HashMap for current partition in right relation - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - - import scala.collection.JavaConversions._ - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - joinType match { - case LeftOuter => leftHashTable.keysIterator.flatMap { key => - leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST)) - } - case RightOuter => rightHashTable.keysIterator.flatMap { key => - rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST)) - } - case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => - fullOuterIterator(key, - leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST)) - } - case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") - } - } - } -} - -/** - * :: DeveloperApi :: - * Performs an inner hash join of two child relations by first shuffling the data using the join - * keys. - */ -@DeveloperApi -case class ShuffledHashJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - buildSide: BuildSide, - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashJoin { - - override def outputPartitioning: Partitioning = left.outputPartitioning - - override def requiredChildDistribution = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - def execute() = { - buildPlan.execute().zipPartitions(streamedPlan.execute()) { - (buildIter, streamIter) => joinIterators(buildIter, streamIter) - } - } -} - -/** - * :: DeveloperApi :: - * Build the right table's join keys into a HashSet, and iteratively go through the left - * table, to find the if join keys are in the Hash set. - */ -@DeveloperApi -case class LeftSemiJoinHash( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashJoin { - - val buildSide = BuildRight - - override def requiredChildDistribution = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - override def output = left.output - - def execute() = { - buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashSet = new java.util.HashSet[Row]() - var currentRow: Row = null - - // Create a Hash set of buildKeys - while (buildIter.hasNext) { - currentRow = buildIter.next() - val rowKey = buildSideKeyGenerator(currentRow) - if (!rowKey.anyNull) { - val keyExists = hashSet.contains(rowKey) - if (!keyExists) { - hashSet.add(rowKey) - } - } - } - - val joinKeys = streamSideKeyGenerator() - streamIter.filter(current => { - !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue) - }) - } - } -} - - -/** - * :: DeveloperApi :: - * Performs an inner hash join of two child relations. When the output RDD of this operator is - * being constructed, a Spark job is asynchronously started to calculate the values for the - * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed - * relation is not shuffled. - */ -@DeveloperApi -case class BroadcastHashJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - buildSide: BuildSide, - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashJoin { - - override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning - - override def requiredChildDistribution = - UnspecifiedDistribution :: UnspecifiedDistribution :: Nil - - @transient - val broadcastFuture = future { - sparkContext.broadcast(buildPlan.executeCollect()) - } - - def execute() = { - val broadcastRelation = Await.result(broadcastFuture, 5.minute) - - streamedPlan.execute().mapPartitions { streamedIter => - joinIterators(broadcastRelation.value.iterator, streamedIter) - } - } -} - -/** - * :: DeveloperApi :: - * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys - * for hash join. - */ -@DeveloperApi -case class LeftSemiJoinBNL( - streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) - extends BinaryNode { - // TODO: Override requiredChildDistribution. - - override def outputPartitioning: Partitioning = streamed.outputPartitioning - - def output = left.output - - /** The Streamed Relation */ - def left = streamed - /** The Broadcast relation */ - def right = broadcast - - @transient lazy val boundCondition = - InterpretedPredicate( - condition - .map(c => BindReferences.bindReference(c, left.output ++ right.output)) - .getOrElse(Literal(true))) - - def execute() = { - val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) - - streamed.execute().mapPartitions { streamedIter => - val joinedRow = new JoinedRow - - streamedIter.filter(streamedRow => { - var i = 0 - var matched = false - - while (i < broadcastedRelation.value.size && !matched) { - val broadcastedRow = broadcastedRelation.value(i) - if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { - matched = true - } - i += 1 - } - matched - }) - } - } -} - -/** - * :: DeveloperApi :: - */ -@DeveloperApi -case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { - def output = left.output ++ right.output - - def execute() = { - val leftResults = left.execute().map(_.copy()) - val rightResults = right.execute().map(_.copy()) - - leftResults.cartesian(rightResults).mapPartitions { iter => - val joinedRow = new JoinedRow - iter.map(r => joinedRow(r._1, r._2)) - } - } -} - -/** - * :: DeveloperApi :: - */ -@DeveloperApi -case class BroadcastNestedLoopJoin( - left: SparkPlan, - right: SparkPlan, - buildSide: BuildSide, - joinType: JoinType, - condition: Option[Expression]) extends BinaryNode { - // TODO: Override requiredChildDistribution. - - /** BuildRight means the right relation <=> the broadcast relation. */ - val (streamed, broadcast) = buildSide match { - case BuildRight => (left, right) - case BuildLeft => (right, left) - } - - override def outputPartitioning: Partitioning = streamed.outputPartitioning - - override def output = { - joinType match { - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case _ => - left.output ++ right.output - } - } - - @transient lazy val boundCondition = - InterpretedPredicate( - condition - .map(c => BindReferences.bindReference(c, left.output ++ right.output)) - .getOrElse(Literal(true))) - - def execute() = { - val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) - - /** All rows that either match both-way, or rows from streamed joined with nulls. */ - val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => - val matchedRows = new CompactBuffer[Row] - // TODO: Use Spark's BitSet. - val includedBroadcastTuples = - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) - val joinedRow = new JoinedRow - val leftNulls = new GenericMutableRow(left.output.size) - val rightNulls = new GenericMutableRow(right.output.size) - - streamedIter.foreach { streamedRow => - var i = 0 - var streamRowMatched = false - - while (i < broadcastedRelation.value.size) { - // TODO: One bitset per partition instead of per row. - val broadcastedRow = broadcastedRelation.value(i) - buildSide match { - case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => - matchedRows += joinedRow(streamedRow, broadcastedRow).copy() - streamRowMatched = true - includedBroadcastTuples += i - case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => - matchedRows += joinedRow(broadcastedRow, streamedRow).copy() - streamRowMatched = true - includedBroadcastTuples += i - case _ => - } - i += 1 - } - - (streamRowMatched, joinType, buildSide) match { - case (false, LeftOuter | FullOuter, BuildRight) => - matchedRows += joinedRow(streamedRow, rightNulls).copy() - case (false, RightOuter | FullOuter, BuildLeft) => - matchedRows += joinedRow(leftNulls, streamedRow).copy() - case _ => - } - } - Iterator((matchedRows, includedBroadcastTuples)) - } - - val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) - val allIncludedBroadcastTuples = - if (includedBroadcastTuples.count == 0) { - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) - } else { - includedBroadcastTuples.reduce(_ ++ _) - } - - val leftNulls = new GenericMutableRow(left.output.size) - val rightNulls = new GenericMutableRow(right.output.size) - /** Rows from broadcasted joined with nulls. */ - val broadcastRowsWithNulls: Seq[Row] = { - val buf: CompactBuffer[Row] = new CompactBuffer() - var i = 0 - val rel = broadcastedRelation.value - while (i < rel.length) { - if (!allIncludedBroadcastTuples.contains(i)) { - (joinType, buildSide) match { - case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i)) - case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls) - case _ => - } - } - i += 1 - } - buf.toSeq - } - - // TODO: Breaks lineage. - sparkContext.union( - matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls)) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala new file mode 100644 index 0000000000000..d88ab6367a1b3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import scala.concurrent._ +import scala.concurrent.duration._ +import scala.concurrent.ExecutionContext.Implicits.global + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnspecifiedDistribution} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Performs an inner hash join of two child relations. When the output RDD of this operator is + * being constructed, a Spark job is asynchronously started to calculate the values for the + * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed + * relation is not shuffled. + */ +@DeveloperApi +case class BroadcastHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + buildSide: BuildSide, + left: SparkPlan, + right: SparkPlan) + extends BinaryNode with HashJoin { + + override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning + + override def requiredChildDistribution = + UnspecifiedDistribution :: UnspecifiedDistribution :: Nil + + @transient + private val broadcastFuture = future { + sparkContext.broadcast(buildPlan.executeCollect()) + } + + override def execute() = { + val broadcastRelation = Await.result(broadcastFuture, 5.minute) + + streamedPlan.execute().mapPartitions { streamedIter => + joinIterators(broadcastRelation.value.iterator, streamedIter) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala new file mode 100644 index 0000000000000..36aad13778bd2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.collection.CompactBuffer + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class BroadcastNestedLoopJoin( + left: SparkPlan, + right: SparkPlan, + buildSide: BuildSide, + joinType: JoinType, + condition: Option[Expression]) extends BinaryNode { + // TODO: Override requiredChildDistribution. + + /** BuildRight means the right relation <=> the broadcast relation. */ + private val (streamed, broadcast) = buildSide match { + case BuildRight => (left, right) + case BuildLeft => (right, left) + } + + override def outputPartitioning: Partitioning = streamed.outputPartitioning + + override def output = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case _ => + left.output ++ right.output + } + } + + @transient private lazy val boundCondition = + InterpretedPredicate( + condition + .map(c => BindReferences.bindReference(c, left.output ++ right.output)) + .getOrElse(Literal(true))) + + override def execute() = { + val broadcastedRelation = + sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + + /** All rows that either match both-way, or rows from streamed joined with nulls. */ + val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => + val matchedRows = new CompactBuffer[Row] + // TODO: Use Spark's BitSet. + val includedBroadcastTuples = + new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + val joinedRow = new JoinedRow + val leftNulls = new GenericMutableRow(left.output.size) + val rightNulls = new GenericMutableRow(right.output.size) + + streamedIter.foreach { streamedRow => + var i = 0 + var streamRowMatched = false + + while (i < broadcastedRelation.value.size) { + // TODO: One bitset per partition instead of per row. + val broadcastedRow = broadcastedRelation.value(i) + buildSide match { + case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => + matchedRows += joinedRow(streamedRow, broadcastedRow).copy() + streamRowMatched = true + includedBroadcastTuples += i + case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => + matchedRows += joinedRow(broadcastedRow, streamedRow).copy() + streamRowMatched = true + includedBroadcastTuples += i + case _ => + } + i += 1 + } + + (streamRowMatched, joinType, buildSide) match { + case (false, LeftOuter | FullOuter, BuildRight) => + matchedRows += joinedRow(streamedRow, rightNulls).copy() + case (false, RightOuter | FullOuter, BuildLeft) => + matchedRows += joinedRow(leftNulls, streamedRow).copy() + case _ => + } + } + Iterator((matchedRows, includedBroadcastTuples)) + } + + val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) + val allIncludedBroadcastTuples = + if (includedBroadcastTuples.count == 0) { + new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + } else { + includedBroadcastTuples.reduce(_ ++ _) + } + + val leftNulls = new GenericMutableRow(left.output.size) + val rightNulls = new GenericMutableRow(right.output.size) + /** Rows from broadcasted joined with nulls. */ + val broadcastRowsWithNulls: Seq[Row] = { + val buf: CompactBuffer[Row] = new CompactBuffer() + var i = 0 + val rel = broadcastedRelation.value + while (i < rel.length) { + if (!allIncludedBroadcastTuples.contains(i)) { + (joinType, buildSide) match { + case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i)) + case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls) + case _ => + } + } + i += 1 + } + buf.toSeq + } + + // TODO: Breaks lineage. + sparkContext.union( + matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala new file mode 100644 index 0000000000000..76c14c02aab34 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.JoinedRow +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { + override def output = left.output ++ right.output + + override def execute() = { + val leftResults = left.execute().map(_.copy()) + val rightResults = right.execute().map(_.copy()) + + leftResults.cartesian(rightResults).mapPartitions { iter => + val joinedRow = new JoinedRow + iter.map(r => joinedRow(r._1, r._2)) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala new file mode 100644 index 0000000000000..472b2e6ca6b4a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow2, Row} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.util.collection.CompactBuffer + + +trait HashJoin { + self: SparkPlan => + + val leftKeys: Seq[Expression] + val rightKeys: Seq[Expression] + val buildSide: BuildSide + val left: SparkPlan + val right: SparkPlan + + protected lazy val (buildPlan, streamedPlan) = buildSide match { + case BuildLeft => (left, right) + case BuildRight => (right, left) + } + + protected lazy val (buildKeys, streamedKeys) = buildSide match { + case BuildLeft => (leftKeys, rightKeys) + case BuildRight => (rightKeys, leftKeys) + } + + override def output = left.output ++ right.output + + @transient protected lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output) + @transient protected lazy val streamSideKeyGenerator = + newMutableProjection(streamedKeys, streamedPlan.output) + + protected def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = + { + // TODO: Use Spark's HashMap implementation. + + val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]() + var currentRow: Row = null + + // Create a mapping of buildKeys -> rows + while (buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = buildSideKeyGenerator(currentRow) + if (!rowKey.anyNull) { + val existingMatchList = hashTable.get(rowKey) + val matchList = if (existingMatchList == null) { + val newMatchList = new CompactBuffer[Row]() + hashTable.put(rowKey, newMatchList) + newMatchList + } else { + existingMatchList + } + matchList += currentRow.copy() + } + } + + new Iterator[Row] { + private[this] var currentStreamedRow: Row = _ + private[this] var currentHashMatches: CompactBuffer[Row] = _ + private[this] var currentMatchPosition: Int = -1 + + // Mutable per row objects. + private[this] val joinRow = new JoinedRow2 + + private[this] val joinKeys = streamSideKeyGenerator() + + override final def hasNext: Boolean = + (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || + (streamIter.hasNext && fetchNext()) + + override final def next() = { + val ret = buildSide match { + case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) + case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) + } + currentMatchPosition += 1 + ret + } + + /** + * Searches the streamed iterator for the next row that has at least one match in hashtable. + * + * @return true if the search is successful, and false if the streamed iterator runs out of + * tuples. + */ + private final def fetchNext(): Boolean = { + currentHashMatches = null + currentMatchPosition = -1 + + while (currentHashMatches == null && streamIter.hasNext) { + currentStreamedRow = streamIter.next() + if (!joinKeys(currentStreamedRow).anyNull) { + currentHashMatches = hashTable.get(joinKeys.currentValue) + } + } + + if (currentHashMatches == null) { + false + } else { + currentMatchPosition = 0 + true + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala new file mode 100644 index 0000000000000..b73041d306b36 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import java.util.{HashMap => JavaHashMap} + +import scala.collection.JavaConversions._ + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.collection.CompactBuffer + +/** + * :: DeveloperApi :: + * Performs a hash based outer join for two child relations by shuffling the data using + * the join keys. This operator requires loading the associated partition in both side into memory. + */ +@DeveloperApi +case class HashOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def outputPartitioning: Partitioning = joinType match { + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) + case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + override def output = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case x => + throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } + } + + @transient private[this] lazy val DUMMY_LIST = Seq[Row](null) + @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row] + + // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala + // iterator for performance purpose. + + private[this] def leftOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val rightNullRow = new GenericRow(right.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + leftIter.iterator.flatMap { l => + joinedRow.withLeft(l) + var matched = false + (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => + matched = true + joinedRow.copy + } else { + Nil + }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all of the + // records in right side. + // If we didn't get any proper row, then append a single row with empty right + joinedRow.withRight(rightNullRow).copy + }) + } + } + + private[this] def rightOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val leftNullRow = new GenericRow(left.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + rightIter.iterator.flatMap { r => + joinedRow.withRight(r) + var matched = false + (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => + matched = true + joinedRow.copy + } else { + Nil + }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all of the + // records in left side. + // If we didn't get any proper row, then append a single row with empty left. + joinedRow.withLeft(leftNullRow).copy + }) + } + } + + private[this] def fullOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val leftNullRow = new GenericRow(left.output.length) + val rightNullRow = new GenericRow(right.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + if (!key.anyNull) { + // Store the positions of records in right, if one of its associated row satisfy + // the join condition. + val rightMatchedSet = scala.collection.mutable.Set[Int]() + leftIter.iterator.flatMap[Row] { l => + joinedRow.withLeft(l) + var matched = false + rightIter.zipWithIndex.collect { + // 1. For those matched (satisfy the join condition) records with both sides filled, + // append them directly + + case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { + matched = true + // if the row satisfy the join condition, add its index into the matched set + rightMatchedSet.add(idx) + joinedRow.copy + } + } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // 2. For those unmatched records in left, append additional records with empty right. + + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all + // of the records in right side. + // If we didn't get any proper row, then append a single row with empty right. + joinedRow.withRight(rightNullRow).copy + }) + } ++ rightIter.zipWithIndex.collect { + // 3. For those unmatched records in right, append additional records with empty left. + + // Re-visiting the records in right, and append additional row with empty left, if its not + // in the matched set. + case (r, idx) if (!rightMatchedSet.contains(idx)) => { + joinedRow(leftNullRow, r).copy + } + } + } else { + leftIter.iterator.map[Row] { l => + joinedRow(l, rightNullRow).copy + } ++ rightIter.iterator.map[Row] { r => + joinedRow(leftNullRow, r).copy + } + } + } + + private[this] def buildHashTable( + iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = { + val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]() + while (iter.hasNext) { + val currentRow = iter.next() + val rowKey = keyGenerator(currentRow) + + var existingMatchList = hashTable.get(rowKey) + if (existingMatchList == null) { + existingMatchList = new CompactBuffer[Row]() + hashTable.put(rowKey, existingMatchList) + } + + existingMatchList += currentRow.copy() + } + + hashTable + } + + override def execute() = { + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + // TODO this probably can be replaced by external sort (sort merged join?) + // Build HashMap for current partition in left relation + val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) + // Build HashMap for current partition in right relation + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + joinType match { + case LeftOuter => leftHashTable.keysIterator.flatMap { key => + leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) + } + case RightOuter => rightHashTable.keysIterator.flatMap { key => + rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) + } + case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => + fullOuterIterator(key, + leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) + } + case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala new file mode 100644 index 0000000000000..60003d1900d85 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys + * for hash join. + */ +@DeveloperApi +case class LeftSemiJoinBNL( + streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) + extends BinaryNode { + // TODO: Override requiredChildDistribution. + + override def outputPartitioning: Partitioning = streamed.outputPartitioning + + override def output = left.output + + /** The Streamed Relation */ + override def left = streamed + /** The Broadcast relation */ + override def right = broadcast + + @transient private lazy val boundCondition = + InterpretedPredicate( + condition + .map(c => BindReferences.bindReference(c, left.output ++ right.output)) + .getOrElse(Literal(true))) + + override def execute() = { + val broadcastedRelation = + sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + + streamed.execute().mapPartitions { streamedIter => + val joinedRow = new JoinedRow + + streamedIter.filter(streamedRow => { + var i = 0 + var matched = false + + while (i < broadcastedRelation.value.size && !matched) { + val broadcastedRow = broadcastedRelation.value(i) + if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { + matched = true + } + i += 1 + } + matched + }) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala new file mode 100644 index 0000000000000..ea7babf3be948 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.{Expression, Row} +import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Build the right table's join keys into a HashSet, and iteratively go through the left + * table, to find the if join keys are in the Hash set. + */ +@DeveloperApi +case class LeftSemiJoinHash( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with HashJoin { + + override val buildSide = BuildRight + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + override def output = left.output + + override def execute() = { + buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => + val hashSet = new java.util.HashSet[Row]() + var currentRow: Row = null + + // Create a Hash set of buildKeys + while (buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = buildSideKeyGenerator(currentRow) + if (!rowKey.anyNull) { + val keyExists = hashSet.contains(rowKey) + if (!keyExists) { + hashSet.add(rowKey) + } + } + } + + val joinKeys = streamSideKeyGenerator() + streamIter.filter(current => { + !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue) + }) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala new file mode 100644 index 0000000000000..8247304c1dc2c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Performs an inner hash join of two child relations by first shuffling the data using the join + * keys. + */ +@DeveloperApi +case class ShuffledHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + buildSide: BuildSide, + left: SparkPlan, + right: SparkPlan) + extends BinaryNode with HashJoin { + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + override def execute() = { + buildPlan.execute().zipPartitions(streamedPlan.execute()) { + (buildIter, streamIter) => joinIterators(buildIter, streamIter) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala new file mode 100644 index 0000000000000..7f2ab1765b28f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Physical execution operators for join operations. + */ +package object joins { + + @DeveloperApi + sealed abstract class BuildSide + + @DeveloperApi + case object BuildRight extends BuildSide + + @DeveloperApi + case object BuildLeft extends BuildSide + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 6c7697ece8c56..07f4d2946c1b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner, LeftSemi} import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6fb6cb8db0c8f..b9b196ea5a46a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.{ShuffledHashJoin, BroadcastHashJoin} +import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.test._ import org.scalatest.BeforeAndAfterAll import java.util.TimeZone diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index bfbf431a11913..f14ffca0e4d35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.execution import org.scalatest.FunSuite +import org.apache.spark.sql.{SQLConf, execution} import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.{SQLConf, execution} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.planner._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index a35c40efdc207..14e791fe0f0ee 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -24,7 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark.sql.{SQLConf, QueryTest} import org.apache.spark.sql.catalyst.plans.logical.NativeCommand -import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._