@@ -81,6 +81,11 @@ trait KeyGroupedPartitionedScan[T] {
filteredPartitions: Seq[Seq[T]],
partitionValueAccessor: T => InternalRow): Seq[Seq[T]] = {
assert(spjParams.keyGroupedPartitioning.isDefined)

+if (spjParams.disableGrouping) {
+  return filteredPartitions.flatten.map(Seq(_))
+}

val expressions = spjParams.keyGroupedPartitioning.get

// Re-group the input partitions if we are projecting on a subset of join keys
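A minimal sketch of what the new disableGrouping early return does to the partition layout (plain Scala with toy data, not Spark types): instead of one output partition per partition value, every input split becomes its own partition, restoring full parallelism.

object DisableGroupingSketch {
  def main(args: Array[String]): Unit = {
    // Grouped layout produced by key-grouped partitioning: one inner Seq per
    // partition value ("p1" and "p2" share a partition value, "p3" has its own).
    val grouped: Seq[Seq[String]] = Seq(Seq("p1", "p2"), Seq("p3"))
    // With disableGrouping set, the groups are flattened so each split maps to
    // exactly one output partition.
    val ungrouped: Seq[Seq[String]] = grouped.flatten.map(Seq(_))
    assert(ungrouped == Seq(Seq("p1"), Seq("p2"), Seq("p3")))
  }
}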
@@ -630,7 +630,7 @@ object QueryExecution {
PlanDynamicPruningFilters(sparkSession),
PlanSubqueries(sparkSession),
RemoveRedundantProjects,
-EnsureRequirements(),
+EnsureRequirements(subquery = subquery),
// This rule must be run after `EnsureRequirements`.
InsertSortForLimitAndOffset,
// `ReplaceHashWithSortAgg` needs to be added after `EnsureRequirements` to guarantee the
@@ -113,7 +113,7 @@ case class AdaptiveSparkPlanExec(
// `EnsureRequirements` to not optimize out the user-specified repartition-by-col to work
// around this case.
val ensureRequirements =
-  EnsureRequirements(requiredDistribution.isDefined, requiredDistribution)
+  EnsureRequirements(requiredDistribution.isDefined, requiredDistribution, isSubquery)
// CoalesceBucketsInJoin can help eliminate shuffles and must be run before
// EnsureRequirements
Seq(
@@ -50,7 +50,8 @@ import org.apache.spark.sql.internal.SQLConf
*/
case class EnsureRequirements(
optimizeOutRepartition: Boolean = true,
-  requiredDistribution: Option[Distribution] = None)
+  requiredDistribution: Option[Distribution] = None,
+  subquery: Boolean = false)
extends Rule[SparkPlan] {

private def ensureDistributionAndOrdering(
@@ -66,19 +67,18 @@ case class EnsureRequirements(
case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
ensureOrdering(child, distribution)
case (child, BroadcastDistribution(mode)) =>
-BroadcastExchangeExec(mode, child)
+val newChild = disableKeyGroupingIfNotNeeded(child)
+BroadcastExchangeExec(mode, newChild)
case (child, distribution) =>
val numPartitions = distribution.requiredNumPartitions
.getOrElse(conf.numShufflePartitions)
distribution match {
case _: StatefulOpClusteredDistribution =>
-ShuffleExchangeExec(
-  distribution.createPartitioning(numPartitions), child,
+createShuffleExchangeExec(child, distribution.createPartitioning(numPartitions),
REQUIRED_BY_STATEFUL_OPERATOR)

case _ =>
-ShuffleExchangeExec(
-  distribution.createPartitioning(numPartitions), child, shuffleOrigin)
+createShuffleExchangeExec(child, distribution.createPartitioning(numPartitions),
+  shuffleOrigin)
}
}

@@ -225,7 +225,8 @@ case class EnsureRequirements(
child match {
case ShuffleExchangeExec(_, c, so, ps) =>
ShuffleExchangeExec(newPartitioning, c, so, ps)
-case _ => ShuffleExchangeExec(newPartitioning, child)
+case _ =>
+  createShuffleExchangeExec(child, newPartitioning)
}
}
}
@@ -244,6 +245,14 @@
children
}

+private def createShuffleExchangeExec(
+    plan: SparkPlan,
+    partitioning: Partitioning,
+    shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS) = {
+  val newPlan = disableKeyGroupingIfNotNeeded(plan)
+  ShuffleExchangeExec(partitioning, newPlan, shuffleOrigin)
+}

private def reorder(
leftKeys: IndexedSeq[Expression],
rightKeys: IndexedSeq[Expression],
@@ -390,17 +399,16 @@
}

/**
-* Whether partial clustering can be applied to a given child query plan. This is true if the plan
+* Whether partial clustering can be applied to a given query plan. This is true if the plan
* consists only of a sequence of unary nodes where each node does not use the scan's key-grouped
* partitioning to satisfy its required distribution. Otherwise, partial clustering could be
-* applied to a key-grouped partitioning unrelated to this join.
+* applied to a key-grouped partitioning of the scan in the plan.
*/
private def canApplyPartialClusteredDistribution(plan: SparkPlan): Boolean = {
!plan.exists {
// Unary nodes are safe as long as they don't have a required distribution (for example, a
-// project or filter). If they have a required distribution, then we should assume that this
-// plan can't be partially clustered (since the key-grouped partitioning may be needed to
-// satisfy this distribution unrelated to this JOIN).
+// project or filter). If they have a required distribution, then we should assume that the
+// scan in this plan can't be partially clustered.
case u if u.children.length == 1 =>
u.requiredChildDistribution.head != UnspecifiedDistribution
// Only allow a non-unary node if it's a leaf node - key-grouped partitionings other binary
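A hedged sketch of the check this doc comment describes, using toy node types instead of SparkPlan: a plan qualifies only if every node above the scan is a unary node with no required child distribution.

object PartialClusteringCheckSketch {
  sealed trait ToyPlan { def children: Seq[ToyPlan] }
  case object Scan extends ToyPlan { val children: Seq[ToyPlan] = Nil }
  // requiresDist stands in for requiredChildDistribution.head != UnspecifiedDistribution.
  case class Unary(child: ToyPlan, requiresDist: Boolean) extends ToyPlan {
    val children: Seq[ToyPlan] = Seq(child)
  }
  case class Join(left: ToyPlan, right: ToyPlan) extends ToyPlan {
    val children: Seq[ToyPlan] = Seq(left, right)
  }

  def canApplyPartialClustering(plan: ToyPlan): Boolean = plan match {
    case Scan => true // a leaf scan is always safe
    case Unary(child, requiresDist) => !requiresDist && canApplyPartialClustering(child)
    case _ => false // a binary node such as another join may rely on the key-grouping
  }

  def main(args: Array[String]): Unit = {
    // A filter/project-like node over a scan qualifies; a nested join does not.
    assert(canApplyPartialClustering(Unary(Scan, requiresDist = false)))
    assert(!canApplyPartialClustering(Join(Scan, Scan)))
  }
}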
@@ -677,39 +685,45 @@ case class EnsureRequirements(
joinKeyPositions: Option[Seq[Int]],
reducers: Option[Seq[Option[Reducer[_, _]]]],
applyPartialClustering: Boolean,
-    replicatePartitions: Boolean): SparkPlan = plan match {
-  case scan: BatchScanExec =>
-    val newScan = scan.copy(
-      spjParams = scan.spjParams.copy(
-        commonPartitionValues = Some(values),
-        joinKeyPositions = joinKeyPositions,
-        reducers = reducers,
-        applyPartialClustering = applyPartialClustering,
-        replicatePartitions = replicatePartitions
-      )
-    )
-    newScan.copyTagsFrom(scan)
-    newScan
-  case node =>
-    node.mapChildren(child => populateCommonPartitionInfo(
-      child, values, joinKeyPositions, reducers, applyPartialClustering, replicatePartitions))
-}
+    replicatePartitions: Boolean) = {
+  plan transform {
+    case scan: BatchScanExec =>
+      scan.copy(
+        spjParams = scan.spjParams.copy(
+          commonPartitionValues = Some(values),
+          joinKeyPositions = joinKeyPositions,
+          reducers = reducers,
+          applyPartialClustering = applyPartialClustering,
+          replicatePartitions = replicatePartitions
+        )
+      )
+  }
+}

+/**
+ * Disables partition grouping (i.e. applies partial clustering) when it can safely be applied
+ * to the given query plan. If no node in the plan uses the key-grouped partitioning of the
+ * scan, the unnecessary grouping would only decrease parallelism.
+ */
+private def disableKeyGroupingIfNotNeeded(plan: SparkPlan) = {
+  if (canApplyPartialClusteredDistribution(plan)) {
+    populateNoGroupingPartitionInfo(plan)
+  } else {
+    plan
+  }
+}
+
+private def populateNoGroupingPartitionInfo(plan: SparkPlan) = {
+  plan.transform {
+    case scan: BatchScanExec => scan.copy(spjParams = scan.spjParams.copy(disableGrouping = true))
+  }
+}
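The transform pattern above can be pictured with a self-contained toy version (illustrative stand-ins for SparkPlan/BatchScanExec, not Spark's API): the rewrite flips the flag on every scan leaf and leaves all other nodes untouched.

object NoGroupingRewriteSketch {
  sealed trait Plan
  case class ScanNode(disableGrouping: Boolean = false) extends Plan
  case class FilterNode(child: Plan) extends Plan

  def populateNoGrouping(plan: Plan): Plan = plan match {
    case s: ScanNode   => s.copy(disableGrouping = true)
    case FilterNode(c) => FilterNode(populateNoGrouping(c))
  }

  def main(args: Array[String]): Unit = {
    // Filter(Scan) becomes Filter(Scan(disableGrouping = true)).
    assert(populateNoGrouping(FilterNode(ScanNode())) ==
      FilterNode(ScanNode(disableGrouping = true)))
  }
}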

-private def populateJoinKeyPositions(
-    plan: SparkPlan,
-    joinKeyPositions: Option[Seq[Int]]): SparkPlan = plan match {
-  case scan: BatchScanExec =>
-    val newScan = scan.copy(
-      spjParams = scan.spjParams.copy(
-        joinKeyPositions = joinKeyPositions
-      )
-    )
-    newScan.copyTagsFrom(scan)
-    newScan
-  case node =>
-    node.mapChildren(child => populateJoinKeyPositions(
-      child, joinKeyPositions))
-}
+private def populateJoinKeyPositions(plan: SparkPlan, joinKeyPositions: Option[Seq[Int]]) = {
+  plan.transform {
+    case scan: BatchScanExec =>
+      scan.copy(spjParams = scan.spjParams.copy(joinKeyPositions = joinKeyPositions))
+  }
+}

private def reducePartValues(
@@ -837,6 +851,18 @@ case class EnsureRequirements(
reordered.withNewChildren(newChildren)
}

+// We can't disable partition grouping of a scan in the main query if it contributes to the
+// output partitioning of the query result, because we don't know whether the query is
+// cached/checkpointed or how its output will be consumed later. The output must keep
+// `KeyGroupedPartitioning` semantics in that case.
+// But we can disable partition grouping in subqueries when grouping is not needed for anything
+// in the subquery plan.
+val groupingDisabledPlan = if (subquery) {
+  disableKeyGroupingIfNotNeeded(newPlan)
+} else {
+  newPlan
+}
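A hedged usage illustration (assuming a SparkSession named spark and the testcat.ns tables used by the tests below; not part of the patch): only the subquery side is a candidate for ungrouped scans.

// The IN-subquery is planned with EnsureRequirements(subquery = true), so its
// scan may get spjParams.disableGrouping = true when nothing in the subquery
// plan needs the key-grouped partitioning; the outer scan keeps
// KeyGroupedPartitioning because the main query's output partitioning may be
// relied on downstream (e.g. by caching or checkpointing).
val df = spark.sql(
  """SELECT * FROM testcat.ns.items
    |WHERE id IN (SELECT item_id FROM testcat.ns.purchases)""".stripMargin)
df.explain()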

if (requiredDistribution.isDefined) {
val shuffleOrigin = if (requiredDistribution.get.requiredNumPartitions.isDefined) {
REPARTITION_BY_NUM
@@ -845,14 +871,14 @@
}
val finalPlan = ensureDistributionAndOrdering(
None,
-  newPlan :: Nil,
+  groupingDisabledPlan :: Nil,
requiredDistribution.get :: Nil,
Seq(Nil),
shuffleOrigin)
assert(finalPlan.size == 1)
finalPlan.head
} else {
-  newPlan
+  groupingDisabledPlan
}
}
}
@@ -29,13 +29,15 @@ case class StoragePartitionJoinParams(
commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
reducers: Option[Seq[Option[Reducer[_, _]]]] = None,
applyPartialClustering: Boolean = false,
-  replicatePartitions: Boolean = false) {
+  replicatePartitions: Boolean = false,
+  disableGrouping: Boolean = false) {
override def equals(other: Any): Boolean = other match {
case other: StoragePartitionJoinParams =>
this.commonPartitionValues == other.commonPartitionValues &&
this.replicatePartitions == other.replicatePartitions &&
this.applyPartialClustering == other.applyPartialClustering &&
-    this.joinKeyPositions == other.joinKeyPositions
+    this.joinKeyPositions == other.joinKeyPositions &&
+    this.disableGrouping == other.disableGrouping
case _ =>
false
}
@@ -44,5 +46,6 @@
joinKeyPositions: Option[Seq[Int]],
commonPartitionValues: Option[Seq[(InternalRow, Int)]],
applyPartialClustering: java.lang.Boolean,
-  replicatePartitions: java.lang.Boolean)
+  replicatePartitions: java.lang.Boolean,
+  disableGrouping: java.lang.Boolean)
}
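Why the new flag must participate in equality (a small sketch using the defaults shown above): plans are compared wherever Spark reuses them, so a grouped scan and an ungrouped scan must not be treated as interchangeable.

val grouped   = StoragePartitionJoinParams()
val ungrouped = StoragePartitionJoinParams(disableGrouping = true)
// The two differ only in disableGrouping, yet they describe different output
// partitioning, so they must not compare equal.
assert(grouped != ungrouped)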
@@ -1580,7 +1580,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
case (true, false, false) => assert(scannedPartitions == Seq(4, 4))

// No SPJ
-case _ => assert(scannedPartitions == Seq(5, 4))
+case _ => assert(scannedPartitions == Seq(7, 7))
}

checkAnswer(df, Seq(
@@ -2114,7 +2114,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
assert(scans == Seq(2, 2))
case (_, _) =>
assert(shuffles.nonEmpty, "SPJ should not be triggered")
-  assert(scans == Seq(3, 2))
+  assert(scans == Seq(3, 3))
}

checkAnswer(df, Seq(
@@ -2234,7 +2234,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
// SPJ and not partially-clustered
case (true, false) => assert(scans == Seq(3, 3))
// No SPJ
-case _ => assert(scans == Seq(4, 4))
+case _ => assert(scans == Seq(5, 5))
}

checkAnswer(df,
@@ -2823,4 +2823,65 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0)))
}
}

test("SPARK-55092: Don't group partitions when not needed") {
val items_partitions = Array(identity("id"))
createTable(items, itemsColumns, items_partitions)

sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(4, 'bb', 10.0, cast('2021-01-01' as timestamp)), " +
"(4, 'cc', 15.5, cast('2021-02-01' as timestamp))")

val purchases_partitions = Array(years("time"))
createTable(purchases, purchasesColumns, purchases_partitions)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(3, 19.5, cast('2020-02-01' as timestamp))")

withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true") {
val df = createJoinTestDF(Seq("id" -> "item_id"), extraColumns = Seq("year(p.time)"))

val shuffles = collectShuffles(df.queryExecution.executedPlan)
assert(shuffles.size == 1, "should only shuffle the side that does not report partitioning")

val scans = collectScans(df.queryExecution.executedPlan)
assert(scans(0).inputRDD.partitions.length === 2,
  "items scan should stay grouped as it drives the SPJ")
assert(scans(1).inputRDD.partitions.length === 2,
  "purchases scan should not be grouped as SPJ can't leverage its partitioning")

checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0, 2020)))
}

withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "false") {
val df = createJoinTestDF(Seq("id" -> "item_id"), extraColumns = Seq("year(p.time)"))

val shuffles = collectShuffles(df.queryExecution.executedPlan)
assert(shuffles.size == 2, "both sides should be shuffled when V2 bucketing shuffle is disabled")

val scans = collectScans(df.queryExecution.executedPlan)
assert(scans(0).inputRDD.partitions.length === 3,
  "items scan should not be grouped as its side is shuffled")
assert(scans(1).inputRDD.partitions.length === 2,
  "purchases scan should not be grouped as its side is shuffled")

checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0, 2020)))
}
}

test("SPARK-55092: Main query output maintains partition grouping despite it is not needed") {
val items_partitions = Array(identity("id"))
createTable(items, itemsColumns, items_partitions)

sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(4, 'bb', 10.0, cast('2021-01-01' as timestamp)), " +
"(4, 'cc', 15.5, cast('2021-02-01' as timestamp))")

val df = sql(s"SELECT * FROM testcat.ns.$items")
val scans = collectScans(df.queryExecution.executedPlan)
assert(scans(0).inputRDD.partitions.length === 2,
  "items scan should stay grouped to preserve the query output's partitioning semantics")
}
}