Commit eb45b52
[SPARK-21865][SQL] simplify the distribution semantic of Spark SQL
## What changes were proposed in this pull request?

**The current shuffle planning logic**

1. Each operator specifies the distribution requirements for its children, via the `Distribution` interface.
2. Each operator specifies its output partitioning, via the `Partitioning` interface.
3. `Partitioning.satisfy` determines whether a `Partitioning` can satisfy a `Distribution`.
4. For each operator, check each of its children and add a shuffle node above a child if the child's partitioning cannot satisfy the required distribution.
5. For each operator, check whether its children's output partitionings are compatible with each other, via `Partitioning.compatibleWith`.
6. If the check in 5 fails, add a shuffle above each child.
7. Try to eliminate the shuffles added in 6, via `Partitioning.guarantees`.

This design has a major problem with the definition of "compatible". `Partitioning.compatibleWith` is not well defined: ideally a `Partitioning` cannot know whether it is compatible with another `Partitioning` without more information from the operator. For example, for `t1 join t2 on t1.a = t2.b`, `HashPartitioning(a, 10)` should be compatible with `HashPartitioning(b, 10)`, but the partitioning itself has no way to know that. As a result, `Partitioning.compatibleWith` currently always returns false except for literals, which makes it almost useless.

This also means that if an operator has distribution requirements for multiple children, Spark always adds shuffle nodes to all the children (although some of them could be eliminated). However, there is no guarantee that the children's output partitionings are compatible with each other after adding these shuffles; we just assume that the operator will only specify `ClusteredDistribution` for multiple children. I think it is very hard to guarantee co-partitioning of children for all kinds of operators, and we cannot even give a clear definition of co-partitioning between distributions like `ClusteredDistribution(a,b)` and `ClusteredDistribution(c)`. I think we should drop the "compatible" concept from the distribution model, and let the operator achieve the co-partition requirement through special distribution requirements.

**Proposed shuffle planning logic after this PR** (the first 4 steps are the same as before)

1. Each operator specifies the distribution requirements for its children, via the `Distribution` interface.
2. Each operator specifies its output partitioning, via the `Partitioning` interface.
3. `Partitioning.satisfy` determines whether a `Partitioning` can satisfy a `Distribution`.
4. For each operator, check each of its children and add a shuffle node above a child if the child's partitioning cannot satisfy the required distribution.
5. For each operator, check whether its children's output partitionings have the same number of partitions.
6. If the check in 5 fails, pick the maximum number of partitions among the children's output partitionings, and add a shuffle above every child whose number of partitions does not equal that maximum.

The new distribution model is very simple: we only have one kind of relationship, `Partitioning.satisfy`. For multiple children, Spark only guarantees that they have the same number of partitions, and it is the operator's responsibility to leverage this guarantee to achieve more complicated requirements. For example, non-broadcast joins can use the newly added `HashClusteredDistribution` to achieve co-partitioning.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #19080 from cloud-fan/exchange.
1 parent 2c73d2a commit eb45b52
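The co-partition property the new model relies on can be seen in isolation with a small, self-contained sketch (plain Scala, not Spark code; `partitionIndex` is a hypothetical stand-in for `HashPartitioning`'s hash function). For `t1 join t2 on t1.a = t2.b`, if both sides are hash-partitioned by their own join key with the same hash function and the same number of partitions, rows with equal key values land at the same partition index, which is what allows the join to zip the children's partitions instead of re-shuffling them.

```scala
// Minimal sketch of co-partitioning: same hash function, same partition count on both sides.
object CoPartitionSketch {
  // Stand-in for HashPartitioning: map a key to a non-negative partition index.
  def partitionIndex(key: Int, numPartitions: Int): Int =
    ((key.hashCode % numPartitions) + numPartitions) % numPartitions

  def main(args: Array[String]): Unit = {
    val numPartitions = 10
    val t1 = Seq(1 -> "x", 2 -> "y", 12 -> "z") // (a, payload)
    val t2 = Seq(1 -> "u", 2 -> "v", 12 -> "w") // (b, payload)

    // HashClusteredDistribution(a) on the left and HashClusteredDistribution(b) on the right
    // both materialize as a hash partitioning with `numPartitions` partitions.
    val leftByPartition  = t1.groupBy { case (a, _) => partitionIndex(a, numPartitions) }
    val rightByPartition = t2.groupBy { case (b, _) => partitionIndex(b, numPartitions) }

    // Equal key values share a partition index, so partition i of the left side only needs to be
    // matched against partition i of the right side.
    assert(leftByPartition.keySet == rightByPartition.keySet)
    println(leftByPartition.keys.toSeq.sorted.mkString(", ")) // prints: 1, 2
  }
}
```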

File tree

8 files changed (+194, -370 lines)


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 109 additions & 177 deletions
Large diffs are not rendered by default.
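The diff for partitioning.scala is not rendered above, so as rough orientation here is a minimal, hypothetical sketch of the reshaped interfaces the rest of this commit relies on (`requiredNumPartitions`, `createPartitioning`, `HashClusteredDistribution`). The member names follow the usages in the EnsureRequirements diff below, but the exact signatures and satisfaction rules in the real file may differ.

```scala
// Simplified, self-contained model of the new distribution/partitioning contract (not the real file).
object DistributionModelSketch {
  type Expr = String // stand-in for Catalyst's Expression

  trait Distribution {
    // Some(n) when the distribution demands exactly n partitions, None otherwise.
    def requiredNumPartitions: Option[Int]
    // The default partitioning used when a shuffle must be inserted to satisfy this distribution.
    def createPartitioning(numPartitions: Int): Partitioning
  }

  trait Partitioning {
    def numPartitions: Int
    // The single relationship kept in the new model.
    def satisfies(required: Distribution): Boolean
  }

  // Requires rows to be hash-partitioned by exactly these expressions, so that two children with
  // the same number of partitions are guaranteed to be co-partitioned.
  case class HashClusteredDistribution(expressions: Seq[Expr]) extends Distribution {
    def requiredNumPartitions: Option[Int] = None
    def createPartitioning(numPartitions: Int): Partitioning =
      HashPartitioning(expressions, numPartitions)
  }

  case class HashPartitioning(expressions: Seq[Expr], numPartitions: Int) extends Partitioning {
    def satisfies(required: Distribution): Boolean = required match {
      case HashClusteredDistribution(exprs) => expressions == exprs // exact key match
      case _                                => false
    }
  }
}
```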

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala

Lines changed: 0 additions & 55 deletions
This file was deleted.

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 15 additions & 1 deletion

@@ -94,7 +94,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   /** Specifies how data is partitioned across different nodes in the cluster. */
   def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH!
 
-  /** Specifies any partition requirements on the input data for this operator. */
+  /**
+   * Specifies the data distribution requirements of all the children for this operator. By default
+   * it's [[UnspecifiedDistribution]] for each child, which means each child can have any
+   * distribution.
+   *
+   * If an operator overwrites this method, and specifies distribution requirements(excluding
+   * [[UnspecifiedDistribution]] and [[BroadcastDistribution]]) for more than one child, Spark
+   * guarantees that the outputs of these children will have same number of partitions, so that the
+   * operator can safely zip partitions of these children's result RDDs. Some operators can leverage
+   * this guarantee to satisfy some interesting requirement, e.g., non-broadcast joins can specify
+   * HashClusteredDistribution(a,b) for its left child, and specify HashClusteredDistribution(c,d)
+   * for its right child, then it's guaranteed that left and right child are co-partitioned by
+   * a,b/c,d, which means tuples of same value are in the partitions of same index, e.g.,
+   * (a=1,b=2) and (c=1,d=2) are both in the second partition of left and right child.
+   */
   def requiredChildDistribution: Seq[Distribution] =
     Seq.fill(children.size)(UnspecifiedDistribution)

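To make the contract in the new doc comment concrete, the following is a hedged sketch of what the same-number-of-partitions guarantee buys an operator, written against the plain RDD API rather than this commit's physical operators (the dataset, key choices and `local[2]` master are illustrative assumptions). Once both children are hash-partitioned on their respective keys with an equal partition count, their RDDs can be combined with `zipPartitions` and joined per partition, with no further shuffle.

```scala
import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

object ZipCoPartitionedChildren {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("zip-sketch"))
    val numPartitions = 4

    // Left child keyed by a, right child keyed by b, both shuffled with the same partitioner.
    val left = sc.parallelize(Seq(1 -> "x", 2 -> "y"))
      .partitionBy(new HashPartitioner(numPartitions))
    val right = sc.parallelize(Seq(1 -> "u", 3 -> "v"))
      .partitionBy(new HashPartitioner(numPartitions))

    // Because the partition counts and hash function match, partition i of `left` holds exactly
    // the keys that belong to partition i of `right`, so a per-partition join is safe.
    val joined = left.zipPartitions(right) { (l, r) =>
      val rightByKey = r.toSeq.groupBy(_._1)
      l.flatMap { case (k, lv) =>
        rightByKey.getOrElse(k, Nil).map { case (_, rv) => (k, (lv, rv)) }
      }
    }
    joined.collect().foreach(println) // prints: (1,(x,u))
    sc.stop()
  }
}
```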
sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

Lines changed: 44 additions & 76 deletions

@@ -46,23 +46,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
     if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None
   }
 
-  /**
-   * Given a required distribution, returns a partitioning that satisfies that distribution.
-   * @param requiredDistribution The distribution that is required by the operator
-   * @param numPartitions Used when the distribution doesn't require a specific number of partitions
-   */
-  private def createPartitioning(
-      requiredDistribution: Distribution,
-      numPartitions: Int): Partitioning = {
-    requiredDistribution match {
-      case AllTuples => SinglePartition
-      case ClusteredDistribution(clustering, desiredPartitions) =>
-        HashPartitioning(clustering, desiredPartitions.getOrElse(numPartitions))
-      case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions)
-      case dist => sys.error(s"Do not know how to satisfy distribution $dist")
-    }
-  }
-
   /**
    * Adds [[ExchangeCoordinator]] to [[ShuffleExchangeExec]]s if adaptive query execution is enabled
    * and partitioning schemes of these [[ShuffleExchangeExec]]s support [[ExchangeCoordinator]].

@@ -88,8 +71,9 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
           // shuffle data when we have more than one children because data generated by
           // these children may not be partitioned in the same way.
           // Please see the comment in withCoordinator for more details.
-          val supportsDistribution =
-            requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution])
+          val supportsDistribution = requiredChildDistributions.forall { dist =>
+            dist.isInstanceOf[ClusteredDistribution] || dist.isInstanceOf[HashClusteredDistribution]
+          }
           children.length > 1 && supportsDistribution
         }
 
@@ -142,8 +126,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
           //
           // It will be great to introduce a new Partitioning to represent the post-shuffle
           // partitions when one post-shuffle partition includes multiple pre-shuffle partitions.
-          val targetPartitioning =
-            createPartitioning(distribution, defaultNumPreShufflePartitions)
+          val targetPartitioning = distribution.createPartitioning(defaultNumPreShufflePartitions)
           assert(targetPartitioning.isInstanceOf[HashPartitioning])
           ShuffleExchangeExec(targetPartitioning, child, Some(coordinator))
         }

@@ -162,71 +145,56 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
     assert(requiredChildDistributions.length == children.length)
     assert(requiredChildOrderings.length == children.length)
 
-    // Ensure that the operator's children satisfy their output distribution requirements:
+    // Ensure that the operator's children satisfy their output distribution requirements.
     children = children.zip(requiredChildDistributions).map {
       case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
         child
       case (child, BroadcastDistribution(mode)) =>
         BroadcastExchangeExec(mode, child)
       case (child, distribution) =>
-        ShuffleExchangeExec(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
+        val numPartitions = distribution.requiredNumPartitions
+          .getOrElse(defaultNumPreShufflePartitions)
+        ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child)
     }
 
-    // If the operator has multiple children and specifies child output distributions (e.g. join),
-    // then the children's output partitionings must be compatible:
-    def requireCompatiblePartitioning(distribution: Distribution): Boolean = distribution match {
-      case UnspecifiedDistribution => false
-      case BroadcastDistribution(_) => false
+    // Get the indexes of children which have specified distribution requirements and need to have
+    // same number of partitions.
+    val childrenIndexes = requiredChildDistributions.zipWithIndex.filter {
+      case (UnspecifiedDistribution, _) => false
+      case (_: BroadcastDistribution, _) => false
       case _ => true
-    }
-    if (children.length > 1
-        && requiredChildDistributions.exists(requireCompatiblePartitioning)
-        && !Partitioning.allCompatible(children.map(_.outputPartitioning))) {
+    }.map(_._2)
 
-      // First check if the existing partitions of the children all match. This means they are
-      // partitioned by the same partitioning into the same number of partitions. In that case,
-      // don't try to make them match `defaultPartitions`, just use the existing partitioning.
-      val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max
-      val useExistingPartitioning = children.zip(requiredChildDistributions).forall {
-        case (child, distribution) =>
-          child.outputPartitioning.guarantees(
-            createPartitioning(distribution, maxChildrenNumPartitions))
+    val childrenNumPartitions =
+      childrenIndexes.map(children(_).outputPartitioning.numPartitions).toSet
+
+    if (childrenNumPartitions.size > 1) {
+      // Get the number of partitions which is explicitly required by the distributions.
+      val requiredNumPartitions = {
+        val numPartitionsSet = childrenIndexes.flatMap {
+          index => requiredChildDistributions(index).requiredNumPartitions
+        }.toSet
+        assert(numPartitionsSet.size <= 1,
+          s"$operator have incompatible requirements of the number of partitions for its children")
+        numPartitionsSet.headOption
      }
 
-      children = if (useExistingPartitioning) {
-        // We do not need to shuffle any child's output.
-        children
-      } else {
-        // We need to shuffle at least one child's output.
-        // Now, we will determine the number of partitions that will be used by created
-        // partitioning schemes.
-        val numPartitions = {
-          // Let's see if we need to shuffle all child's outputs when we use
-          // maxChildrenNumPartitions.
-          val shufflesAllChildren = children.zip(requiredChildDistributions).forall {
-            case (child, distribution) =>
-              !child.outputPartitioning.guarantees(
-                createPartitioning(distribution, maxChildrenNumPartitions))
-          }
-          // If we need to shuffle all children, we use defaultNumPreShufflePartitions as the
-          // number of partitions. Otherwise, we use maxChildrenNumPartitions.
-          if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions
-        }
+      val targetNumPartitions = requiredNumPartitions.getOrElse(childrenNumPartitions.max)
 
-        children.zip(requiredChildDistributions).map {
-          case (child, distribution) =>
-            val targetPartitioning = createPartitioning(distribution, numPartitions)
-            if (child.outputPartitioning.guarantees(targetPartitioning)) {
-              child
-            } else {
-              child match {
-                // If child is an exchange, we replace it with
-                // a new one having targetPartitioning.
-                case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(targetPartitioning, c)
-                case _ => ShuffleExchangeExec(targetPartitioning, child)
-              }
+      children = children.zip(requiredChildDistributions).zipWithIndex.map {
+        case ((child, distribution), index) if childrenIndexes.contains(index) =>
+          if (child.outputPartitioning.numPartitions == targetNumPartitions) {
+            child
+          } else {
+            val defaultPartitioning = distribution.createPartitioning(targetNumPartitions)
+            child match {
+              // If child is an exchange, we replace it with a new one having defaultPartitioning.
+              case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(defaultPartitioning, c)
+              case _ => ShuffleExchangeExec(defaultPartitioning, child)
            }
-        }
+          }
+
+        case ((child, _), _) => child
      }
    }
 
@@ -324,10 +292,10 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
   }
 
   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
-    case operator @ ShuffleExchangeExec(partitioning, child, _) =>
-      child.children match {
-        case ShuffleExchangeExec(childPartitioning, baseChild, _)::Nil =>
-          if (childPartitioning.guarantees(partitioning)) child else operator
+    // TODO: remove this after we create a physical operator for `RepartitionByExpression`.
+    case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) =>
+      child.outputPartitioning match {
+        case lower: HashPartitioning if upper.semanticEquals(lower) => child
        case _ => operator
      }
    case operator: SparkPlan =>
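Distilled out of the rule above, the reconciliation step is small; this standalone sketch (an assumed simplification, not the actual EnsureRequirements code) shows how the target partition count is chosen and which children would get a new shuffle.

```scala
// Pick the target number of partitions for children that carry a real distribution requirement.
object PickTargetNumPartitions {
  def main(args: Array[String]): Unit = {
    // Per relevant child: (current numPartitions, the distribution's requiredNumPartitions).
    val children: Seq[(Int, Option[Int])] = Seq((8, None), (200, None), (8, None))

    // At most one explicitly required partition count may exist across the children.
    val requiredCounts = children.flatMap { case (_, required) => required }.toSet
    assert(requiredCounts.size <= 1, "incompatible requirements of the number of partitions")

    val currentCounts = children.map { case (current, _) => current }.toSet
    if (currentCounts.size > 1) {
      // Prefer the explicitly required count; otherwise fall back to the maximum existing count.
      val target = requiredCounts.headOption.getOrElse(currentCounts.max)
      val needsShuffle = children.map { case (current, _) => current != target }
      println(s"target=$target, shuffle per child=$needsShuffle")
      // prints: target=200, shuffle per child=List(true, false, true)
    }
  }
}
```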

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala

Lines changed: 1 addition & 1 deletion

@@ -46,7 +46,7 @@ case class ShuffledHashJoinExec(
     "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe"))
 
   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+    HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
 
   private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
     val buildDataSize = longMetric("buildDataSize")

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala

Lines changed: 1 addition & 1 deletion

@@ -78,7 +78,7 @@ case class SortMergeJoinExec(
   }
 
   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+    HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
 
   override def outputOrdering: Seq[SortOrder] = joinType match {
     // For inner join, orders of both sides keys should be kept.
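The switch from `ClusteredDistribution` to `HashClusteredDistribution` in both shuffled joins is what restores the co-partition guarantee under the simplified model. The toy model below (self-contained, not the real `satisfies` rules in partitioning.scala) illustrates the difference it assumes: a clustered requirement is satisfied by a hash partitioning on any subset of the clustering keys, so a satisfies-only check could accept two join children that are not co-partitioned on the full join keys, whereas the hash-clustered requirement demands an exact key match and therefore forces a shuffle when the keys differ.

```scala
object WhyHashClusteredForJoins {
  type Expr = String

  case class HashPartitioning(expressions: Seq[Expr], numPartitions: Int) {
    // ClusteredDistribution-style check: any subset of the clustering keys is enough.
    def satisfiesClustered(clustering: Seq[Expr]): Boolean =
      expressions.forall(e => clustering.contains(e))
    // HashClusteredDistribution-style check: keys must match exactly.
    def satisfiesHashClustered(keys: Seq[Expr]): Boolean =
      expressions == keys
  }

  def main(args: Array[String]): Unit = {
    // Join condition: (a, b) = (c, d). The left child happens to be partitioned on `a` only.
    val left  = HashPartitioning(Seq("a"), 200)
    val right = HashPartitioning(Seq("c", "d"), 200)

    // Clustered requirement: both children pass, yet they are not co-partitioned on the join keys.
    println(left.satisfiesClustered(Seq("a", "b")) && right.satisfiesClustered(Seq("c", "d"))) // true

    // Hash-clustered requirement: the left child fails, so a shuffle on (a, b) is inserted.
    println(left.satisfiesHashClustered(Seq("a", "b"))) // false
  }
}
```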

sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala

Lines changed: 1 addition & 1 deletion

@@ -456,7 +456,7 @@ case class CoGroupExec(
     right: SparkPlan) extends BinaryExecNode with ObjectProducerExec {
 
   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil
+    HashClusteredDistribution(leftGroup) :: HashClusteredDistribution(rightGroup) :: Nil
 
   override def requiredChildOrdering: Seq[Seq[SortOrder]] =
     leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil
