diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index ec659ce789c2..c9dfb10e8840 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -49,7 +49,9 @@ case object AllTuples extends Distribution
  * can mean such tuples are either co-located in the same partition or they will be contiguous
  * within a single partition.
  */
-case class ClusteredDistribution(clustering: Seq[Expression]) extends Distribution {
+case class ClusteredDistribution(
+    clustering: Seq[Expression],
+    nullSafe: Boolean) extends Distribution {
   require(
     clustering != Nil,
     "The clustering expressions of a ClusteredDistribution should not be Nil. " +
@@ -57,6 +59,11 @@ case class ClusteredDistribution(clustering: Seq[Expression]) extends Distributi
       "a single partition.")
 }
 
+object ClusteredDistribution {
+  def apply(clustering: Seq[Expression]): ClusteredDistribution =
+    ClusteredDistribution(clustering, nullSafe = true)
+}
+
 /**
  * Represents data where tuples have been ordered according to the `ordering`
  * [[Expression Expressions]]. This is a strictly stronger guarantee than
@@ -90,9 +97,22 @@ sealed trait Partitioning {
   /**
    * Returns true iff we can say that the partitioning scheme of this [[Partitioning]]
    * guarantees the same partitioning scheme described by `other`.
+   *
+   * If a [[Partitioning]] supports the `nullSafe` setting, the nullSafe version of this
+   * [[Partitioning]] should always guarantee its nullUnsafe version.
+   * For example, HashPartitioning(expressions = 'a, numPartitions = 10, nullSafe = true)
+   * guarantees HashPartitioning(expressions = 'a, numPartitions = 10, nullSafe = false).
+   * However, HashPartitioning(expressions = 'a, numPartitions = 10, nullSafe = false) does not
+   * guarantee HashPartitioning(expressions = 'a, numPartitions = 10, nullSafe = true).
    */
-  // TODO: Add an example once we have the `nullSafe` concept.
   def guarantees(other: Partitioning): Boolean
+
+  /**
+   * If a [[Partitioning]] supports the `nullSafe` setting, returns a new instance of this
+   * [[Partitioning]] with the given nullSafe setting. Otherwise, returns this
+   * [[Partitioning]].
+   */
+  def withNullSafeSetting(newNullSafe: Boolean): Partitioning
 }
 
 case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
@@ -102,6 +122,8 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
   }
 
   override def guarantees(other: Partitioning): Boolean = false
+
+  override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = this
 }
 
 case object SinglePartition extends Partitioning {
@@ -113,6 +135,8 @@ case object SinglePartition extends Partitioning {
     case SinglePartition => true
     case _ => false
   }
+
+  override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = this
 }
 
 case object BroadcastPartitioning extends Partitioning {
@@ -124,6 +148,8 @@ case object BroadcastPartitioning extends Partitioning {
     case BroadcastPartitioning => true
     case _ => false
   }
+
+  override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = this
 }
 
 /**
@@ -131,7 +157,10 @@ case object BroadcastPartitioning extends Partitioning {
  * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
  * in the same partition.
  */
-case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
+case class HashPartitioning(
+    expressions: Seq[Expression],
+    numPartitions: Int,
+    nullSafe: Boolean)
   extends Expression with Partitioning with Unevaluable {
 
   override def children: Seq[Expression] = expressions
@@ -142,16 +171,30 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
 
   override def satisfies(required: Distribution): Boolean = required match {
     case UnspecifiedDistribution => true
-    case ClusteredDistribution(requiredClustering) =>
+    case ClusteredDistribution(requiredClustering, _) if nullSafe =>
+      clusteringSet.subsetOf(requiredClustering.toSet)
+    case ClusteredDistribution(requiredClustering, false) if !nullSafe =>
       clusteringSet.subsetOf(requiredClustering.toSet)
     case _ => false
   }
 
   override def guarantees(other: Partitioning): Boolean = other match {
-    case o: HashPartitioning =>
+    case o: HashPartitioning if (nullSafe || (!nullSafe && !o.nullSafe)) =>
       this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions
     case _ => false
   }
+
+  override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = {
+    HashPartitioning(expressions, numPartitions, nullSafe = newNullSafe)
+  }
+
+  override def toString: String =
+    s"${super.toString} numPartitions=$numPartitions nullSafe=$nullSafe"
+}
+
+object HashPartitioning {
+  def apply(expressions: Seq[Expression], numPartitions: Int): HashPartitioning =
+    HashPartitioning(expressions, numPartitions, nullSafe = true)
 }
 
 /**
@@ -180,7 +223,7 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
     case OrderedDistribution(requiredOrdering) =>
       val minSize = Seq(requiredOrdering.size, ordering.size).min
       requiredOrdering.take(minSize) == ordering.take(minSize)
-    case ClusteredDistribution(requiredClustering) =>
+    case ClusteredDistribution(requiredClustering, _) =>
       clusteringSet.subsetOf(requiredClustering.toSet)
     case _ => false
   }
@@ -189,6 +232,10 @@
     case o: RangePartitioning => this == o
     case _ => false
   }
+
+  override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = this
+
+  override def toString: String = s"${super.toString} numPartitions=$numPartitions"
 }
 
 /**
@@ -235,6 +282,10 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
   override def guarantees(other: Partitioning): Boolean =
     partitionings.exists(_.guarantees(other))
 
+  override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = {
+    PartitioningCollection(partitionings.map(_.withNullSafeSetting(newNullSafe)))
+  }
+
   override def toString: String = {
     partitionings.map(_.toString).mkString("(", " or ", ")")
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
index 827f7ce69271..10d71a8261d4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
@@ -104,6 +104,80 @@ class DistributionSuite extends SparkFunSuite {
     */
   }
 
+  test("HashPartitioning (with nullSafe = false) is the output partitioning") {
+    // Cases which do not need an exchange between two data properties.
+    checkSatisfied(
+      HashPartitioning(Seq('a, 'b, 'c), 10, false),
+      UnspecifiedDistribution,
+      true)
+
+    checkSatisfied(
+      HashPartitioning(Seq('a, 'b, 'c), 10, false),
+      ClusteredDistribution(Seq('a, 'b, 'c), false),
+      true)
+
+    checkSatisfied(
+      HashPartitioning(Seq('b, 'c), 10, false),
+      ClusteredDistribution(Seq('a, 'b, 'c), false),
+      true)
+
+    checkSatisfied(
+      SinglePartition,
+      ClusteredDistribution(Seq('a, 'b, 'c), false),
+      true)
+
+    checkSatisfied(
+      SinglePartition,
+      OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)),
+      true)
+
+    // Cases which need an exchange between two data properties.
+    checkSatisfied(
+      HashPartitioning(Seq('a, 'b, 'c), 10, false),
+      ClusteredDistribution(Seq('a, 'b, 'c)),
+      false)
+
+    checkSatisfied(
+      HashPartitioning(Seq('b, 'c), 10, false),
+      ClusteredDistribution(Seq('a, 'b, 'c)),
+      false)
+
+    checkSatisfied(
+      HashPartitioning(Seq('a, 'b, 'c), 10, false),
+      ClusteredDistribution(Seq('b, 'c)),
+      false)
+
+    checkSatisfied(
+      HashPartitioning(Seq('a, 'b, 'c), 10, false),
+      ClusteredDistribution(Seq('d, 'e)),
+      false)
+
+    checkSatisfied(
+      HashPartitioning(Seq('a, 'b, 'c), 10, false),
+      ClusteredDistribution(Seq('b, 'c), false),
+      false)
+
+    checkSatisfied(
+      HashPartitioning(Seq('a, 'b, 'c), 10, false),
+      ClusteredDistribution(Seq('d, 'e), false),
+      false)
+
+    checkSatisfied(
+      HashPartitioning(Seq('a, 'b, 'c), 10, false),
+      AllTuples,
+      false)
+
+    checkSatisfied(
+      HashPartitioning(Seq('a, 'b, 'c), 10, false),
+      OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)),
+      false)
+
+    checkSatisfied(
+      HashPartitioning(Seq('b, 'c), 10, false),
+      OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)),
+      false)
+  }
+
   test("RangePartitioning is the output partitioning") {
     // Cases which do not need an exchange between two data properties.
     checkSatisfied(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 05b009d1935b..2e8a3fa2218c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -148,7 +148,8 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
   protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") {
     val rdd = child.execute()
     val part: Partitioner = newPartitioning match {
-      case HashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions)
+      case HashPartitioning(expressions, numPartitions, nullSafe) =>
+        new HashPartitioner(numPartitions)
       case RangePartitioning(sortingExpressions, numPartitions) =>
         // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
         // partition bounds. To get accurate samples, we need to copy the mutable keys.
@@ -167,7 +168,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
       // TODO: Handle BroadcastPartitioning.
     }
     def getPartitionKeyExtractor(): InternalRow => InternalRow = newPartitioning match {
-      case HashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)()
+      // TODO: If nullSafe is false, we can randomly distribute rows having any null in
+      // clustering.
+      case HashPartitioning(expressions, _, _) => newMutableProjection(expressions, child.output)()
       case RangePartitioning(_, _) | SinglePartition => identity
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
     }
@@ -210,7 +213,12 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
 
       def addShuffleIfNecessary(child: SparkPlan): SparkPlan = {
         if (!child.outputPartitioning.guarantees(partitioning)) {
-          Exchange(partitioning, child)
+          // If the child's outputPartitioning does not guarantee the required
+          // partitioning, we need to add an Exchange operator. Here, we always use
+          // the nullSafe version of the given partitioning because the nullSafe
+          // version always guarantees the nullUnsafe version of the partitioning and
+          // we do not have any special handling for nullUnsafe partitioning for now.
+          Exchange(partitioning.withNullSafeSetting(newNullSafe = true), child)
         } else {
           child
         }
@@ -240,8 +248,9 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
       val fixedChildren = requirements.zipped.map {
         case (AllTuples, rowOrdering, child) =>
           addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
-        case (ClusteredDistribution(clustering), rowOrdering, child) =>
-          addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
+        case (ClusteredDistribution(clustering, nullSafe), rowOrdering, child) =>
+          val hashPartitioning = HashPartitioning(clustering, numPartitions, nullSafe)
+          addOperatorsIfNecessary(hashPartitioning, rowOrdering, child)
         case (OrderedDistribution(ordering), rowOrdering, child) =>
           addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index fc6efe87bceb..46dbdeffd9df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -42,7 +42,8 @@ case class ShuffledHashJoin(
     PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
 
   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+    ClusteredDistribution(leftKeys, nullSafe = false) ::
+      ClusteredDistribution(rightKeys, nullSafe = false) :: Nil
 
   protected override def doExecute(): RDD[InternalRow] = {
     buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
index eee8ad800f98..951bee1ca92e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
@@ -42,12 +42,23 @@ case class ShuffledHashOuterJoin(
     right: SparkPlan) extends BinaryNode with HashOuterJoin {
 
   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+    ClusteredDistribution(leftKeys, nullSafe = false) ::
+      ClusteredDistribution(rightKeys, nullSafe = false) :: Nil
 
   override def outputPartitioning: Partitioning = joinType match {
-    case LeftOuter => left.outputPartitioning
-    case RightOuter => right.outputPartitioning
-    case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+    case LeftOuter =>
+      val partitions =
+        left.outputPartitioning :: right.outputPartitioning.withNullSafeSetting(false) :: Nil
+      PartitioningCollection(partitions)
+    case RightOuter =>
+      val partitions =
+        Seq(right.outputPartitioning, left.outputPartitioning.withNullSafeSetting(false))
+      PartitioningCollection(partitions)
+    case FullOuter =>
+      val partitions =
+        left.outputPartitioning.withNullSafeSetting(false) ::
+          right.outputPartitioning.withNullSafeSetting(false) :: Nil
+      PartitioningCollection(partitions)
     case x =>
       throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 41be78afd37e..df7b8eec15fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -44,7 +44,8 @@ case class SortMergeJoin(
     PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
 
   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+    ClusteredDistribution(leftKeys, nullSafe = false) ::
+      ClusteredDistribution(rightKeys, nullSafe = false) :: Nil
 
   // this is to manually construct an ordering that can be used to compare keys from both sides
   private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 18b0e54dc7c5..2d5668da5b32 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -170,14 +170,57 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
 
     // Disable broadcast join
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      val joins = Array("JOIN", "LEFT OUTER JOIN", "RIGHT OUTER JOIN", "FULL OUTER JOIN")
+      var i = 0
+      while (i < joins.length) {
+        var j = 0
+        while (j < joins.length) {
+          val firstJoin: String = joins(i)
+          val secondJoin: String = joins(j)
+
+          {
+            val numExchanges: Int = sql(
+              s"""
+                |SELECT *
+                |FROM
+                |  normal $firstJoin small ON (normal.key = small.key)
+                |  $secondJoin tiny ON (small.key = tiny.key)
+              """.stripMargin
+            ).queryExecution.executedPlan.collect {
+              case exchange: Exchange => exchange
+            }.length
+            assert(numExchanges === 3)
+          }
+
+          {
+            val numExchanges: Int = sql(
+              s"""
+                |SELECT *
+                |FROM
+                |  normal $firstJoin small ON (normal.key = small.key)
+                |  $secondJoin tiny ON (normal.key = tiny.key)
+              """.stripMargin
+            ).queryExecution.executedPlan.collect {
+              case exchange: Exchange => exchange
+            }.length
+            assert(numExchanges === 3)
+          }
+
+          j += 1
+        }
+        i += 1
+      }
+
       {
-        val numExchanges = sql(
-          """
-            |SELECT *
-            |FROM
-            |  normal JOIN small ON (normal.key = small.key)
-            |  JOIN tiny ON (small.key = tiny.key)
-          """.stripMargin
+        val numExchanges: Int = sql(
+          s"""
+            |SELECT small.key, count(*)
+            |FROM
+            |  normal JOIN small ON (normal.key = small.key)
+            |  JOIN tiny ON (small.key = tiny.key)
+            |GROUP BY
+            |  small.key
+          """.stripMargin
         ).queryExecution.executedPlan.collect {
           case exchange: Exchange => exchange
         }.length
@@ -185,20 +228,36 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
       }
 
       {
-        // This second query joins on different keys:
-        val numExchanges = sql(
-          """
-            |SELECT *
-            |FROM
-            |  normal JOIN small ON (normal.key = small.key)
-            |  JOIN tiny ON (normal.key = tiny.key)
-          """.stripMargin
+        val numExchanges: Int = sql(
+          s"""
+            |SELECT normal.key, count(*)
+            |FROM
+            |  normal LEFT OUTER JOIN small ON (normal.key = small.key)
+            |  JOIN tiny ON (small.key = tiny.key)
+            |GROUP BY
+            |  normal.key
+          """.stripMargin
         ).queryExecution.executedPlan.collect {
           case exchange: Exchange => exchange
         }.length
         assert(numExchanges === 3)
       }
+      {
+        val numExchanges: Int = sql(
+          s"""
+            |SELECT small.key, count(*)
+            |FROM
+            |  normal LEFT OUTER JOIN small ON (normal.key = small.key)
+            |  JOIN tiny ON (small.key = tiny.key)
+            |GROUP BY
+            |  small.key
+          """.stripMargin
+        ).queryExecution.executedPlan.collect {
+          case exchange: Exchange => exchange
+        }.length
+        assert(numExchanges === 4)
+      }
     }
   }
 }
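Below is a minimal sketch (not part of the patch above) of how the asymmetric `guarantees` relation and `withNullSafeSetting` described in the partitioning.scala Scaladoc are intended to behave. It assumes the patched spark-catalyst module is on the classpath; the object name `NullSafeGuaranteeSketch` is only illustrative, and the DSL import supplies the Symbol-to-attribute conversion that DistributionSuite also relies on.

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}

object NullSafeGuaranteeSketch {
  def main(args: Array[String]): Unit = {
    // Two hash partitionings over the same expressions and partition count,
    // differing only in the nullSafe flag.
    val nullSafeHash: Partitioning = HashPartitioning(Seq('a), 10, nullSafe = true)
    val nullUnsafeHash: Partitioning = HashPartitioning(Seq('a), 10, nullSafe = false)

    // The nullSafe version guarantees its nullUnsafe counterpart...
    assert(nullSafeHash.guarantees(nullUnsafeHash))
    // ...but not the other way around.
    assert(!nullUnsafeHash.guarantees(nullSafeHash))

    // When a shuffle is unavoidable, EnsureRequirements upgrades the requested
    // partitioning to its nullSafe form, which then satisfies both kinds of consumer.
    assert(nullUnsafeHash.withNullSafeSetting(true).guarantees(nullSafeHash))
  }
}

The point of the asymmetry is exchange reuse: operators such as ShuffledHashJoin, SortMergeJoin, and ShuffledHashOuterJoin now only require a nullUnsafe ClusteredDistribution, so a child that is already hash-partitioned with nullSafe = true can be consumed directly instead of being shuffled again, which is what the updated PlannerSuite exchange counts exercise.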