diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
index 10a6aaa2e1851..00331b38ccdec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
@@ -81,6 +81,11 @@ trait KeyGroupedPartitionedScan[T] {
       filteredPartitions: Seq[Seq[T]],
       partitionValueAccessor: T => InternalRow): Seq[Seq[T]] = {
     assert(spjParams.keyGroupedPartitioning.isDefined)
+
+    if (spjParams.disableGrouping) {
+      return filteredPartitions.flatten.map(Seq(_))
+    }
+
     val expressions = spjParams.keyGroupedPartitioning.get

     // Re-group the input partitions if we are projecting on a subset of join keys
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index f08b561d6ef9a..7d163518aeb25 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -630,7 +630,7 @@ object QueryExecution {
       PlanDynamicPruningFilters(sparkSession),
       PlanSubqueries(sparkSession),
       RemoveRedundantProjects,
-      EnsureRequirements(),
+      EnsureRequirements(subquery = subquery),
       // This rule must be run after `EnsureRequirements`.
       InsertSortForLimitAndOffset,
       // `ReplaceHashWithSortAgg` needs to be added after `EnsureRequirements` to guarantee the
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 4840016bf745d..152dd1505f389 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -113,7 +113,7 @@ case class AdaptiveSparkPlanExec(
     // `EnsureRequirements` to not optimize out the user-specified repartition-by-col to work
     // around this case.
     val ensureRequirements =
-      EnsureRequirements(requiredDistribution.isDefined, requiredDistribution)
+      EnsureRequirements(requiredDistribution.isDefined, requiredDistribution, isSubquery)
     // CoalesceBucketsInJoin can help eliminate shuffles and must be run before
     // EnsureRequirements
     Seq(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index e239174e40ad4..aaed1bbf4f5d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -50,7 +50,8 @@ import org.apache.spark.sql.internal.SQLConf
  */
 case class EnsureRequirements(
     optimizeOutRepartition: Boolean = true,
-    requiredDistribution: Option[Distribution] = None)
+    requiredDistribution: Option[Distribution] = None,
+    subquery: Boolean = false)
   extends Rule[SparkPlan] {

   private def ensureDistributionAndOrdering(
@@ -66,19 +67,18 @@ case class EnsureRequirements(
       case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
         ensureOrdering(child, distribution)
       case (child, BroadcastDistribution(mode)) =>
-        BroadcastExchangeExec(mode, child)
+        val newChild = disableKeyGroupingIfNotNeeded(child)
+        BroadcastExchangeExec(mode, newChild)
       case (child, distribution) =>
         val numPartitions = distribution.requiredNumPartitions
           .getOrElse(conf.numShufflePartitions)
         distribution match {
           case _: StatefulOpClusteredDistribution =>
-            ShuffleExchangeExec(
-              distribution.createPartitioning(numPartitions), child,
+            createShuffleExchangeExec(child, distribution.createPartitioning(numPartitions),
               REQUIRED_BY_STATEFUL_OPERATOR)
-
           case _ =>
-            ShuffleExchangeExec(
-              distribution.createPartitioning(numPartitions), child, shuffleOrigin)
+            createShuffleExchangeExec(child, distribution.createPartitioning(numPartitions),
+              shuffleOrigin)
         }
     }

@@ -225,7 +225,8 @@ case class EnsureRequirements(
       child match {
         case ShuffleExchangeExec(_, c, so, ps) =>
           ShuffleExchangeExec(newPartitioning, c, so, ps)
-        case _ => ShuffleExchangeExec(newPartitioning, child)
+        case _ =>
+          createShuffleExchangeExec(child, newPartitioning)
       }
     }
   }
@@ -244,6 +245,14 @@ case class EnsureRequirements(
     children
   }

+  private def createShuffleExchangeExec(
+      plan: SparkPlan,
+      partitioning: Partitioning,
+      shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS) = {
+    val newPlan = disableKeyGroupingIfNotNeeded(plan)
+    ShuffleExchangeExec(partitioning, newPlan, shuffleOrigin)
+  }
+
   private def reorder(
       leftKeys: IndexedSeq[Expression],
       rightKeys: IndexedSeq[Expression],
@@ -390,17 +399,16 @@ case class EnsureRequirements(
   }

   /**
-   * Whether partial clustering can be applied to a given child query plan. This is true if the plan
+   * Whether partial clustering can be applied to a given query plan. This is true if the plan
    * consists only of a sequence of unary nodes where each node does not use the scan's key-grouped
    * partitioning to satisfy its required distribution. Otherwise, partially clustering could be
-   * applied to a key-grouped partitioning unrelated to this join.
+   * applied to a key-grouped partitioning of the scan in the plan.
    */
   private def canApplyPartialClusteredDistribution(plan: SparkPlan): Boolean = {
     !plan.exists {
       // Unary nodes are safe as long as they don't have a required distribution (for example, a
-      // project or filter). If they have a required distribution, then we should assume that this
-      // plan can't be partially clustered (since the key-grouped partitioning may be needed to
-      // satisfy this distribution unrelated to this JOIN).
+      // project or filter). If they have a required distribution, then we should assume that the
+      // scan in this plan can't be partially clustered.
       case u if u.children.length == 1 =>
         u.requiredChildDistribution.head != UnspecifiedDistribution
       // Only allow a non-unary node if it's a leaf node - key-grouped partitionings other binary
@@ -677,39 +685,45 @@ case class EnsureRequirements(
       joinKeyPositions: Option[Seq[Int]],
       reducers: Option[Seq[Option[Reducer[_, _]]]],
       applyPartialClustering: Boolean,
-      replicatePartitions: Boolean): SparkPlan = plan match {
-    case scan: BatchScanExec =>
-      val newScan = scan.copy(
-        spjParams = scan.spjParams.copy(
-          commonPartitionValues = Some(values),
-          joinKeyPositions = joinKeyPositions,
-          reducers = reducers,
-          applyPartialClustering = applyPartialClustering,
-          replicatePartitions = replicatePartitions
+      replicatePartitions: Boolean) = {
+    plan transform {
+      case scan: BatchScanExec =>
+        scan.copy(
+          spjParams = scan.spjParams.copy(
+            commonPartitionValues = Some(values),
+            joinKeyPositions = joinKeyPositions,
+            reducers = reducers,
+            applyPartialClustering = applyPartialClustering,
+            replicatePartitions = replicatePartitions
+          )
         )
-      )
-      newScan.copyTagsFrom(scan)
-      newScan
-    case node =>
-      node.mapChildren(child => populateCommonPartitionInfo(
-        child, values, joinKeyPositions, reducers, applyPartialClustering, replicatePartitions))
+    }
+  }
+
+  /**
+   * Applies partial clustering (disables partition grouping) if it can be applied to a given
+   * query plan. If the nodes of the plan don't use the key-grouped partitioning of the scan in
+   * the plan, then unnecessary grouping only decreases parallelism.
+   */
+  private def disableKeyGroupingIfNotNeeded(plan: SparkPlan) = {
+    if (canApplyPartialClusteredDistribution(plan)) {
+      populateNoGroupingPartitionInfo(plan)
+    } else {
+      plan
+    }
   }
+  private def populateNoGroupingPartitionInfo(plan: SparkPlan) = {
+    plan.transform {
+      case scan: BatchScanExec => scan.copy(spjParams = scan.spjParams.copy(disableGrouping = true))
+    }
+  }
-
-  private def populateJoinKeyPositions(
-      plan: SparkPlan,
-      joinKeyPositions: Option[Seq[Int]]): SparkPlan = plan match {
-    case scan: BatchScanExec =>
-      val newScan = scan.copy(
-        spjParams = scan.spjParams.copy(
-          joinKeyPositions = joinKeyPositions
-        )
-      )
-      newScan.copyTagsFrom(scan)
-      newScan
-    case node =>
-      node.mapChildren(child => populateJoinKeyPositions(
-        child, joinKeyPositions))
+  private def populateJoinKeyPositions(plan: SparkPlan, joinKeyPositions: Option[Seq[Int]]) = {
+    plan.transform {
+      case scan: BatchScanExec =>
+        scan.copy(spjParams = scan.spjParams.copy(joinKeyPositions = joinKeyPositions))
+    }
   }

   private def reducePartValues(
@@ -837,6 +851,18 @@ case class EnsureRequirements(
         reordered.withNewChildren(newChildren)
     }

+    // We can't disable partition grouping of a scan in a main query if it contributes to the
+    // output partitioning of the query result, because we don't know whether the query is
+    // cached/checkpointed or how the output of the query will be used later. The output must keep
+    // `KeyGroupedPartitioning` semantics in this case.
+    // But we can disable partition grouping in subqueries when grouping is not needed for anything
+    // in the subquery plan.
+    val groupingDisabledPlan = if (subquery) {
+      disableKeyGroupingIfNotNeeded(newPlan)
+    } else {
+      newPlan
+    }
+
     if (requiredDistribution.isDefined) {
       val shuffleOrigin = if (requiredDistribution.get.requiredNumPartitions.isDefined) {
         REPARTITION_BY_NUM
@@ -845,14 +871,14 @@ case class EnsureRequirements(
       }
       val finalPlan = ensureDistributionAndOrdering(
         None,
-        newPlan :: Nil,
+        groupingDisabledPlan :: Nil,
         requiredDistribution.get :: Nil,
         Seq(Nil),
         shuffleOrigin)
       assert(finalPlan.size == 1)
       finalPlan.head
     } else {
-      newPlan
+      groupingDisabledPlan
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/StoragePartitionJoinParams.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/StoragePartitionJoinParams.scala
index a28eafc5cae5b..8322f4eecc1cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/StoragePartitionJoinParams.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/StoragePartitionJoinParams.scala
@@ -29,13 +29,15 @@ case class StoragePartitionJoinParams(
     commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
     reducers: Option[Seq[Option[Reducer[_, _]]]] = None,
     applyPartialClustering: Boolean = false,
-    replicatePartitions: Boolean = false) {
+    replicatePartitions: Boolean = false,
+    disableGrouping: Boolean = false) {
   override def equals(other: Any): Boolean = other match {
     case other: StoragePartitionJoinParams =>
       this.commonPartitionValues == other.commonPartitionValues &&
         this.replicatePartitions == other.replicatePartitions &&
         this.applyPartialClustering == other.applyPartialClustering &&
-        this.joinKeyPositions == other.joinKeyPositions
+        this.joinKeyPositions == other.joinKeyPositions &&
+        this.disableGrouping == other.disableGrouping
     case _ => false
   }

@@ -44,5 +46,6 @@ case class StoragePartitionJoinParams(
       joinKeyPositions: Option[Seq[Int]],
       commonPartitionValues: Option[Seq[(InternalRow, Int)]],
       applyPartialClustering: java.lang.Boolean,
-      replicatePartitions: java.lang.Boolean)
+      replicatePartitions: java.lang.Boolean,
+      disableGrouping: java.lang.Boolean)
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index a1b1b8444719a..e9a0be946d890 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -1580,7 +1580,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
           case (true, false, false) => assert(scannedPartitions == Seq(4, 4))
           // No SPJ
-          case _ => assert(scannedPartitions == Seq(5, 4))
+          case _ => assert(scannedPartitions == Seq(7, 7))
         }

         checkAnswer(df, Seq(
@@ -2114,7 +2114,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
             assert(scans == Seq(2, 2))
           case (_, _) =>
             assert(shuffles.nonEmpty, "SPJ should not be triggered")
-            assert(scans == Seq(3, 2))
+            assert(scans == Seq(3, 3))
         }

         checkAnswer(df, Seq(
@@ -2234,7 +2234,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
           // SPJ and not partially-clustered
           case (true, false) => assert(scans == Seq(3, 3))
           // No SPJ
-          case _ => assert(scans == Seq(4, 4))
+          case _ => assert(scans == Seq(5, 5))
         }

         checkAnswer(df,
@@ -2823,4 +2823,65 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
       checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0)))
     }
   }
+
+  test("SPARK-55092: Don't group partitions when not needed") {
+    val items_partitions = Array(identity("id"))
+    createTable(items, itemsColumns, items_partitions)
+
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+      "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+      "(4, 'bb', 10.0, cast('2021-01-01' as timestamp)), " +
+      "(4, 'cc', 15.5, cast('2021-02-01' as timestamp))")
+
+    val purchases_partitions = Array(years("time"))
+    createTable(purchases, purchasesColumns, purchases_partitions)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+      "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+      "(3, 19.5, cast('2020-02-01' as timestamp))")
+
+    withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true") {
+      val df = createJoinTestDF(Seq("id" -> "item_id"), extraColumns = Seq("year(p.time)"))
+
+      val shuffles = collectShuffles(df.queryExecution.executedPlan)
+      assert(shuffles.size == 1, "only shuffle one side not report partitioning")
+
+      val scans = collectScans(df.queryExecution.executedPlan)
+      assert(scans(0).inputRDD.partitions.length === 2,
+        "items scan should group as it is the driver of SPJ")
+      assert(scans(1).inputRDD.partitions.length === 2,
+        "purchases scan should not group as SPJ can't leverage it")
+
+      checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0, 2020)))
+    }
+
+    withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "false") {
+      val df = createJoinTestDF(Seq("id" -> "item_id"), extraColumns = Seq("year(p.time)"))
+
+      val shuffles = collectShuffles(df.queryExecution.executedPlan)
+      assert(shuffles.size == 2, "both sides should be shuffled when V2 bucketing shuffle is disabled")
+
+      val scans = collectScans(df.queryExecution.executedPlan)
+      assert(scans(0).inputRDD.partitions.length === 3,
+        "items scan should not group as it is shuffled")
+      assert(scans(1).inputRDD.partitions.length === 2,
+        "purchases scan should not group as it is shuffled")
+
+      checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0, 2020)))
+    }
+  }
+
+  test("SPARK-55092: Main query output maintains partition grouping even when it is not needed") {
+    val items_partitions = Array(identity("id"))
+    createTable(items, itemsColumns, items_partitions)
+
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+      "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+      "(4, 'bb', 10.0, cast('2021-01-01' as timestamp)), " +
+      "(4, 'cc', 15.5, cast('2021-02-01' as timestamp))")
+
+    val df = sql(s"SELECT * FROM testcat.ns.$items")
+    val scans = collectScans(df.queryExecution.executedPlan)
+    assert(scans(0).inputRDD.partitions.length === 2,
+      "items scan should group to maintain query output partitioning semantics")
+  }
 }
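Reviewer note (not part of the patch): the behavioural core of the change is the new short-circuit in KeyGroupedPartitionedScan.groupPartitions, which turns every input split into its own single-element group instead of grouping splits by partition value. Below is a minimal, self-contained Scala sketch of that difference; `Split` and `partitionValue` are hypothetical stand-ins for the connector's input split and its partition value, which the real code represents as `InternalRow`s inside `BatchScanExec`.

object DisableGroupingSketch extends App {
  // Hypothetical, simplified input split; the real scan works with connector-specific splits.
  case class Split(partitionValue: Int, file: String)

  // Mirrors the shape of groupPartitions: the input is already grouped per reported partition,
  // and the flag decides whether to keep key-grouped semantics or maximise parallelism.
  def groupPartitions(
      filteredPartitions: Seq[Seq[Split]],
      disableGrouping: Boolean): Seq[Seq[Split]] = {
    if (disableGrouping) {
      // The short-circuit added by this patch: one output group per split, so the scan RDD gets
      // one partition per split and KeyGroupedPartitioning semantics are intentionally dropped.
      filteredPartitions.flatten.map(Seq(_))
    } else {
      // Key-grouped behaviour: all splits sharing a partition value end up in one group.
      filteredPartitions.flatten.groupBy(_.partitionValue).values.toSeq
    }
  }

  val splits = Seq(Seq(Split(1, "a"), Split(1, "b")), Seq(Split(2, "c")))
  assert(groupPartitions(splits, disableGrouping = true).size == 3)  // one RDD partition per split
  assert(groupPartitions(splits, disableGrouping = false).size == 2) // one RDD partition per key
}

This is also why the expected scan partition counts in the "No SPJ" branches of KeyGroupedPartitioningSuite grow (for example Seq(4, 4) to Seq(5, 5)): once grouping is disabled for a scan that only feeds a shuffle, the scan reports one partition per input split rather than one per partition value.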