
Commit 69bb072

Introduce NullSafeHashPartitioning and NullUnsafePartitioning.
1 parent d5b84c3 commit 69bb072

File tree

13 files changed: +130, -47 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala

Lines changed: 20 additions & 0 deletions
@@ -21,6 +21,7 @@ import java.security.{MessageDigest, NoSuchAlgorithmException}
 import java.util.zip.CRC32
 
 import org.apache.commons.codec.digest.DigestUtils
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.types._
@@ -160,3 +161,22 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp
     })
   }
 }
+
+/** An expression that returns the hashCode of the input row. */
+case object RowHashCode extends LeafExpression {
+  override def dataType: DataType = IntegerType
+
+  /** hashCode will never be null. */
+  override def nullable: Boolean = false
+
+  override def eval(input: InternalRow): Any = {
+    input.hashCode
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    s"""
+      boolean ${ev.isNull} = false;
+      ${ctx.javaType(dataType)} ${ev.primitive} = i.hashCode();
+    """
+  }
+}
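For intuition, here is a minimal plain-Scala sketch of the contract RowHashCode provides (not part of the commit; Row below is a hypothetical stand-in for Catalyst's InternalRow): a deterministic, never-null integer derived from the whole row, which the Exchange changes further down turn into a partition id.

  // Hypothetical stand-in for InternalRow, for illustration only.
  final case class Row(values: Seq[Any]) {
    override def hashCode: Int = values.hashCode
  }

  object RowHashCodeSketch {
    // Mirrors RowHashCode.eval: the result is the row's own hashCode and is never null.
    def eval(input: Row): Int = input.hashCode

    def main(args: Array[String]): Unit = {
      val a = Row(Seq(1, "x", null))
      val b = Row(Seq(1, "x", null))
      assert(eval(a) == eval(b)) // equal rows, nulls included, hash identically
    }
  }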

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 49 additions & 9 deletions
@@ -47,9 +47,23 @@ case object AllTuples extends Distribution
  * Represents data where tuples that share the same values for the `clustering`
  * [[Expression Expressions]] will be co-located. Based on the context, this
  * can mean such tuples are either co-located in the same partition or they will be contiguous
- * within a single partition.
+ * within a single partition. For two null values in two rows evaluated by `clustering`,
+ * we consider these two nulls equal.
  */
-case class ClusteredDistribution(clustering: Seq[Expression]) extends Distribution {
+case class NullSafeClusteredDistribution(clustering: Seq[Expression]) extends Distribution {
+  require(
+    clustering != Nil,
+    "The clustering expressions of a ClusteredDistribution should not be Nil. " +
+      "An AllTuples should be used to represent a distribution that only has " +
+      "a single partition.")
+}
+
+/**
+ * It is basically the same as [[NullSafeClusteredDistribution]] except that
+ * for two null values in two rows evaluated by `clustering`,
+ * we consider these two nulls not equal.
+ */
+case class NullUnsafeClusteredDistribution(clustering: Seq[Expression]) extends Distribution {
   require(
     clustering != Nil,
     "The clustering expressions of a ClusteredDistribution should not be Nil. " +
@@ -60,7 +74,7 @@ case class ClusteredDistribution(clustering: Seq[Expression]) extends Distributi
 /**
  * Represents data where tuples have been ordered according to the `ordering`
  * [[Expression Expressions]]. This is a strictly stronger guarantee than
- * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the same value for
+ * [[NullSafeClusteredDistribution]] as an ordering will ensure that tuples that share the same value for
  * the ordering expressions are contiguous and will never be split across partitions.
  */
 case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
@@ -89,7 +103,7 @@ sealed trait Partitioning {
   /**
    * Returns true iff all distribution guarantees made by this partitioning can also be made
    * for the `other` specified partitioning.
-   * For example, two [[HashPartitioning HashPartitioning]]s are
+   * For example, two [[NullSafeHashPartitioning]]s are
    * only compatible if the `numPartitions` of them is the same.
    */
   def compatibleWith(other: Partitioning): Boolean
@@ -143,7 +157,34 @@ case object BroadcastPartitioning extends Partitioning {
 * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
 * in the same partition.
 */
-case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
+case class NullSafeHashPartitioning(expressions: Seq[Expression], numPartitions: Int)
+  extends Expression with Partitioning with Unevaluable {
+
+  override def children: Seq[Expression] = expressions
+  override def nullable: Boolean = false
+  override def dataType: DataType = IntegerType
+
+  private[this] lazy val clusteringSet = expressions.toSet
+
+  override def satisfies(required: Distribution): Boolean = required match {
+    case UnspecifiedDistribution => true
+    case NullSafeClusteredDistribution(requiredClustering) =>
+      clusteringSet.subsetOf(requiredClustering.toSet)
+    case NullUnsafeClusteredDistribution(requiredClustering) =>
+      clusteringSet.subsetOf(requiredClustering.toSet)
+    case _ => false
+  }
+
+  override def compatibleWith(other: Partitioning): Boolean = other match {
+    case BroadcastPartitioning => true
+    case h: NullSafeHashPartitioning if h == this => true
+    case _ => false
+  }
+
+  override def keyExpressions: Seq[Expression] = expressions
+}
+
+case class NullUnsafeHashPartitioning(expressions: Seq[Expression], numPartitions: Int)
   extends Expression with Partitioning with Unevaluable {
 
   override def children: Seq[Expression] = expressions
@@ -154,14 +195,14 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
 
   override def satisfies(required: Distribution): Boolean = required match {
     case UnspecifiedDistribution => true
-    case ClusteredDistribution(requiredClustering) =>
+    case NullUnsafeClusteredDistribution(requiredClustering) =>
       clusteringSet.subsetOf(requiredClustering.toSet)
     case _ => false
   }
 
   override def compatibleWith(other: Partitioning): Boolean = other match {
     case BroadcastPartitioning => true
-    case h: HashPartitioning if h == this => true
+    case h: NullUnsafeHashPartitioning if h == this => true
     case _ => false
   }
 
@@ -194,14 +235,13 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
     case OrderedDistribution(requiredOrdering) =>
       val minSize = Seq(requiredOrdering.size, ordering.size).min
       requiredOrdering.take(minSize) == ordering.take(minSize)
-    case ClusteredDistribution(requiredClustering) =>
+    case NullSafeClusteredDistribution(requiredClustering) =>
       clusteringSet.subsetOf(requiredClustering.toSet)
     case _ => false
   }
 
   override def compatibleWith(other: Partitioning): Boolean = other match {
     case BroadcastPartitioning => true
-    case r: RangePartitioning if r == this => true
     case _ => false
   }
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala

Lines changed: 18 additions & 18 deletions
@@ -45,23 +45,23 @@ class DistributionSuite extends SparkFunSuite {
   test("HashPartitioning is the output partitioning") {
     // Cases which do not need an exchange between two data properties.
     checkSatisfied(
-      HashPartitioning(Seq('a, 'b, 'c), 10),
+      NullSafeHashPartitioning(Seq('a, 'b, 'c), 10),
       UnspecifiedDistribution,
       true)
 
     checkSatisfied(
-      HashPartitioning(Seq('a, 'b, 'c), 10),
-      ClusteredDistribution(Seq('a, 'b, 'c)),
+      NullSafeHashPartitioning(Seq('a, 'b, 'c), 10),
+      NullSafeClusteredDistribution(Seq('a, 'b, 'c)),
       true)
 
     checkSatisfied(
-      HashPartitioning(Seq('b, 'c), 10),
-      ClusteredDistribution(Seq('a, 'b, 'c)),
+      NullSafeHashPartitioning(Seq('b, 'c), 10),
+      NullSafeClusteredDistribution(Seq('a, 'b, 'c)),
       true)
 
     checkSatisfied(
       SinglePartition,
-      ClusteredDistribution(Seq('a, 'b, 'c)),
+      NullSafeClusteredDistribution(Seq('a, 'b, 'c)),
       true)
 
     checkSatisfied(
@@ -71,27 +71,27 @@ class DistributionSuite extends SparkFunSuite {
 
     // Cases which need an exchange between two data properties.
     checkSatisfied(
-      HashPartitioning(Seq('a, 'b, 'c), 10),
-      ClusteredDistribution(Seq('b, 'c)),
+      NullSafeHashPartitioning(Seq('a, 'b, 'c), 10),
+      NullSafeClusteredDistribution(Seq('b, 'c)),
       false)
 
     checkSatisfied(
-      HashPartitioning(Seq('a, 'b, 'c), 10),
-      ClusteredDistribution(Seq('d, 'e)),
+      NullSafeHashPartitioning(Seq('a, 'b, 'c), 10),
+      NullSafeClusteredDistribution(Seq('d, 'e)),
       false)
 
     checkSatisfied(
-      HashPartitioning(Seq('a, 'b, 'c), 10),
+      NullSafeHashPartitioning(Seq('a, 'b, 'c), 10),
       AllTuples,
       false)
 
     checkSatisfied(
-      HashPartitioning(Seq('a, 'b, 'c), 10),
+      NullSafeHashPartitioning(Seq('a, 'b, 'c), 10),
       OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)),
       false)
 
     checkSatisfied(
-      HashPartitioning(Seq('b, 'c), 10),
+      NullSafeHashPartitioning(Seq('b, 'c), 10),
       OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)),
       false)
 
@@ -128,17 +128,17 @@ class DistributionSuite extends SparkFunSuite {
 
     checkSatisfied(
       RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
-      ClusteredDistribution(Seq('a, 'b, 'c)),
+      NullSafeClusteredDistribution(Seq('a, 'b, 'c)),
       true)
 
     checkSatisfied(
       RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
-      ClusteredDistribution(Seq('c, 'b, 'a)),
+      NullSafeClusteredDistribution(Seq('c, 'b, 'a)),
       true)
 
     checkSatisfied(
       RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
-      ClusteredDistribution(Seq('b, 'c, 'a, 'd)),
+      NullSafeClusteredDistribution(Seq('b, 'c, 'a, 'd)),
       true)
 
     // Cases which need an exchange between two data properties.
@@ -158,12 +158,12 @@ class DistributionSuite extends SparkFunSuite {
 
     checkSatisfied(
       RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
-      ClusteredDistribution(Seq('a, 'b)),
+      NullSafeClusteredDistribution(Seq('a, 'b)),
       false)
 
     checkSatisfied(
       RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
-      ClusteredDistribution(Seq('c, 'd)),
+      NullSafeClusteredDistribution(Seq('c, 'd)),
       false)
 
     checkSatisfied(

sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ case class Aggregate(
     if (groupingExpressions == Nil) {
       AllTuples :: Nil
     } else {
-      ClusteredDistribution(groupingExpressions) :: Nil
+      NullSafeClusteredDistribution(groupingExpressions) :: Nil
     }
   }
 }
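Aggregate keeps the null-safe requirement because SQL's GROUP BY collapses all null grouping keys into a single group, so rows with null keys must be co-located. A quick REPL-style illustration:

  // Two rows with a null grouping key must aggregate into one group.
  val rows = Seq[(String, Int)]((null, 1), ("a", 2), (null, 3))
  val grouped = rows.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2).sum }
  // grouped == Map(null -> 4, "a" -> 2)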

sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

Lines changed: 27 additions & 4 deletions
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.errors.attachTree
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types.IntegerType
 import org.apache.spark.util.MutablePair
 import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}
 
@@ -140,10 +141,13 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
     }
   }
 
+  private val advancedSqlOptimizations = child.sqlContext.conf.advancedSqlOptimizations
+
   protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") {
     val rdd = child.execute()
     val part: Partitioner = newPartitioning match {
-      case HashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions)
+      case NullSafeHashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions)
+      case NullUnsafeHashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions)
       case RangePartitioning(sortingExpressions, numPartitions) =>
         // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
         // partition bounds. To get accurate samples, we need to copy the mutable keys.
@@ -162,7 +166,24 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
       // TODO: Handle BroadcastPartitioning.
     }
     def getPartitionKeyExtractor(): InternalRow => InternalRow = newPartitioning match {
-      case HashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)()
+      case NullSafeHashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)()
+      case NullUnsafeHashPartitioning(expressions, numPartition) if advancedSqlOptimizations =>
+        // For NullUnsafeHashPartitioning, we do not want to send all rows in which some
+        // expression in `expressions` evaluates to null to the same node.
+        val materializeExpressions = newMutableProjection(expressions, child.output)()
+        val partitionExpressionSchema = expressions.map { expr =>
+          Alias(expr, "partitionExpr")().toAttribute
+        }
+        val partitionId =
+          If(
+            AtLeastNNonNulls(partitionExpressionSchema.length, partitionExpressionSchema),
+            RowHashCode,
+            Cast(Multiply(new Rand(numPartition), Literal(numPartition.toDouble)), IntegerType))
+        val partitionIdExtractor =
+          newMutableProjection(partitionId :: Nil, partitionExpressionSchema)()
+        (row: InternalRow) => partitionIdExtractor(materializeExpressions(row))
+      case NullUnsafeHashPartitioning(expressions, numPartition) =>
+        newMutableProjection(expressions, child.output)()
       case RangePartitioning(_, _) | SinglePartition => identity
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
     }
@@ -276,8 +297,10 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
       val fixedChildren = requirements.zipped.map {
        case (AllTuples, rowOrdering, child) =>
          addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
-        case (ClusteredDistribution(clustering), rowOrdering, child) =>
-          addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
+        case (NullSafeClusteredDistribution(clustering), rowOrdering, child) =>
+          addOperatorsIfNecessary(NullSafeHashPartitioning(clustering, numPartitions), rowOrdering, child)
+        case (NullUnsafeClusteredDistribution(clustering), rowOrdering, child) =>
+          addOperatorsIfNecessary(NullUnsafeHashPartitioning(clustering, numPartitions), rowOrdering, child)
        case (OrderedDistribution(ordering), rowOrdering, child) =>
          addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child)
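The getPartitionKeyExtractor change above is the heart of the commit: when every partitioning expression is non-null the row hashes deterministically, otherwise it is scattered at random. A minimal plain-Scala sketch of that decision (illustrative only; the real code composes Catalyst expressions and feeds the resulting id to a HashPartitioner):

  import scala.util.Random

  object PartitionIdSketch {
    // Hash the whole row when every partitioning key is non-null; otherwise
    // pick a random partition so null-keyed rows do not pile up on one node.
    def partitionId(keys: Seq[Any], rowHash: Int, numPartitions: Int, rand: Random): Int =
      if (keys.forall(_ != null)) {
        // AtLeastNNonNulls(keys.length, keys) holds: deterministic, like RowHashCode.
        Math.floorMod(rowHash, numPartitions)
      } else {
        // Mirrors Cast(Multiply(Rand(seed), numPartitions), IntegerType).
        rand.nextInt(numPartitions)
      }
  }

Note the advancedSqlOptimizations guard: with the flag off, the second NullUnsafeHashPartitioning case keeps the old behavior and null keys still hash together.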

sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ case class GeneratedAggregate(
     if (groupingExpressions == Nil) {
       AllTuples :: Nil
     } else {
-      ClusteredDistribution(groupingExpressions) :: Nil
+      NullSafeClusteredDistribution(groupingExpressions) :: Nil
     }
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 1 addition & 1 deletion
@@ -403,7 +403,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.OneRowRelation =>
         execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
       case logical.RepartitionByExpression(expressions, child) =>
-        execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil
+        execution.Exchange(NullSafeHashPartitioning(expressions, numPartitions), planLater(child)) :: Nil
       case e @ EvaluatePython(udf, child, _) =>
         BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
       case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil

sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ case class Window(
       logWarning("No Partition Defined for Window operation! Moving all data to a single "
         + "partition, this can cause serious performance degradation.")
       AllTuples :: Nil
-    } else ClusteredDistribution(windowSpec.partitionSpec) :: Nil
+    } else NullSafeClusteredDistribution(windowSpec.partitionSpec) :: Nil
   }
 
   override def requiredChildOrdering: Seq[Seq[SortOrder]] =

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala

Lines changed: 3 additions & 3 deletions
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, NullSafeClusteredDistribution, Distribution, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}
 
 case class Aggregate2Sort(
@@ -49,7 +49,7 @@ case class Aggregate2Sort(
   override def requiredChildDistribution: List[Distribution] = {
     requiredChildDistributionExpressions match {
       case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
-      case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
+      case Some(exprs) if exprs.length > 0 => NullSafeClusteredDistribution(exprs) :: Nil
       case None => UnspecifiedDistribution :: Nil
     }
   }
@@ -144,7 +144,7 @@ case class FinalAndCompleteAggregate2Sort(
     if (groupingExpressions.isEmpty) {
       AllTuples :: Nil
     } else {
-      ClusteredDistribution(groupingExpressions) :: Nil
+      NullSafeClusteredDistribution(groupingExpressions) :: Nil
     }
   }
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala

Lines changed: 3 additions & 3 deletions
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution
+import org.apache.spark.sql.catalyst.plans.physical.{Distribution, NullUnsafeClusteredDistribution, NullSafeClusteredDistribution}
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 
 /**
@@ -37,8 +37,8 @@ case class LeftSemiJoinHash(
     right: SparkPlan,
     condition: Option[Expression]) extends BinaryNode with HashSemiJoin {
 
-  override def requiredChildDistribution: Seq[ClusteredDistribution] =
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+  override def requiredChildDistribution: Seq[Distribution] =
+    NullUnsafeClusteredDistribution(leftKeys) :: NullUnsafeClusteredDistribution(rightKeys) :: Nil
 
   protected override def doExecute(): RDD[InternalRow] = {
     right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
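Relaxing LeftSemiJoinHash to the null-unsafe distribution is safe because join equality never matches a null key, so a null-keyed row produces no output no matter which partition holds it; scattering those rows instead of clustering them is the skew fix this commit is after. A tiny plain-Scala sketch of the semantics:

  // A null join key matches nothing, so its placement is irrelevant to the result.
  val left = Seq[(String, String)]((null, "l1"), ("k", "l2"))
  val rightKeys = Set("k")
  val semiJoined = left.filter { case (k, _) => k != null && rightKeys.contains(k) }
  // semiJoined == List(("k", "l2")); the null-keyed row is dropped wherever it lives.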
