diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 79a8380826ab3..039fd9382000a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning} import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.random.RandomSampler @@ -953,16 +954,18 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan) } /** - * This method repartitions data using [[Expression]]s into `numPartitions`, and receives + * This method repartitions data using [[Expression]]s into `optNumPartitions`, and receives * information about the number of partitions during execution. Used when a specific ordering or * distribution is expected by the consumer of the query result. Use [[Repartition]] for RDD-like - * `coalesce` and `repartition`. + * `coalesce` and `repartition`. If no `optNumPartitions` is given, by default it partitions data + * into `numShufflePartitions` defined in `SQLConf`, and could be coalesced by AQE. */ case class RepartitionByExpression( partitionExpressions: Seq[Expression], child: LogicalPlan, - numPartitions: Int) extends RepartitionOperation { + optNumPartitions: Option[Int]) extends RepartitionOperation { + val numPartitions = optNumPartitions.getOrElse(SQLConf.get.numShufflePartitions) require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.") val partitioning: Partitioning = { @@ -990,6 +993,15 @@ case class RepartitionByExpression( override def shuffle: Boolean = true } +object RepartitionByExpression { + def apply( + partitionExpressions: Seq[Expression], + child: LogicalPlan, + numPartitions: Int): RepartitionByExpression = { + RepartitionByExpression(partitionExpressions, child, Some(numPartitions)) + } +} + /** * A relation with one row. This is used in "SELECT ..." without a from clause. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 524e231eb7eb9..6f97121d88ede 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2991,17 +2991,9 @@ class Dataset[T] private[sql]( Repartition(numPartitions, shuffle = true, logicalPlan) } - /** - * Returns a new Dataset partitioned by the given partitioning expressions into - * `numPartitions`. The resulting Dataset is hash partitioned. - * - * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL). - * - * @group typedrel - * @since 2.0.0 - */ - @scala.annotation.varargs - def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = { + private def repartitionByExpression( + numPartitions: Option[Int], + partitionExprs: Seq[Column]): Dataset[T] = { // The underlying `LogicalPlan` operator special-cases all-`SortOrder` arguments. // However, we don't want to complicate the semantics of this API method. // Instead, let's give users a friendly error message, pointing them to the new method. @@ -3015,6 +3007,20 @@ class Dataset[T] private[sql]( } } + /** + * Returns a new Dataset partitioned by the given partitioning expressions into + * `numPartitions`. The resulting Dataset is hash partitioned. + * + * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL). + * + * @group typedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = { + repartitionByExpression(Some(numPartitions), partitionExprs) + } + /** * Returns a new Dataset partitioned by the given partitioning expressions, using * `spark.sql.shuffle.partitions` as number of partitions. @@ -3027,7 +3033,20 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def repartition(partitionExprs: Column*): Dataset[T] = { - repartition(sparkSession.sessionState.conf.numShufflePartitions, partitionExprs: _*) + repartitionByExpression(None, partitionExprs) + } + + private def repartitionByRange( + numPartitions: Option[Int], + partitionExprs: Seq[Column]): Dataset[T] = { + require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.") + val sortOrder: Seq[SortOrder] = partitionExprs.map(_.expr match { + case expr: SortOrder => expr + case expr: Expression => SortOrder(expr, Ascending) + }) + withTypedPlan { + RepartitionByExpression(sortOrder, logicalPlan, numPartitions) + } } /** @@ -3049,14 +3068,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def repartitionByRange(numPartitions: Int, partitionExprs: Column*): Dataset[T] = { - require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.") - val sortOrder: Seq[SortOrder] = partitionExprs.map(_.expr match { - case expr: SortOrder => expr - case expr: Expression => SortOrder(expr, Ascending) - }) - withTypedPlan { - RepartitionByExpression(sortOrder, logicalPlan, numPartitions) - } + repartitionByRange(Some(numPartitions), partitionExprs) } /** @@ -3078,7 +3090,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def repartitionByRange(partitionExprs: Column*): Dataset[T] = { - repartitionByRange(sparkSession.sessionState.conf.numShufflePartitions, partitionExprs: _*) + repartitionByRange(None, partitionExprs) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 4d23e5e8a65b5..5341b4778c539 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -685,8 +685,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case r: logical.Range => execution.RangeExec(r) :: Nil case r: logical.RepartitionByExpression => + val canChangeNumParts = r.optNumPartitions.isEmpty exchange.ShuffleExchangeExec( - r.partitioning, planLater(r.child), canChangeNumPartitions = false) :: Nil + r.partitioning, planLater(r.child), canChangeNumParts) :: Nil case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil case r: LogicalRDD => RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 9fa97bffa8910..27d9748476c98 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan} import org.apache.spark.sql.execution.command.DataWritingCommandExec -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate import org.apache.spark.sql.functions._ @@ -1022,18 +1022,81 @@ class AdaptiveQueryExecSuite } } - test("SPARK-31220 repartition obeys initialPartitionNum when adaptiveExecutionEnabled") { + test("SPARK-31220, SPARK-32056: repartition by expression with AQE") { Seq(true, false).foreach { enableAQE => withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString, - SQLConf.SHUFFLE_PARTITIONS.key -> "6", - SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "7") { - val partitionsNum = spark.range(10).repartition($"id").rdd.collectPartitions().length + SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", + SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10", + SQLConf.SHUFFLE_PARTITIONS.key -> "10") { + + val df1 = spark.range(10).repartition($"id") + val df2 = spark.range(10).repartition($"id" + 1) + + val partitionsNum1 = df1.rdd.collectPartitions().length + val partitionsNum2 = df2.rdd.collectPartitions().length + if (enableAQE) { - assert(partitionsNum === 7) + assert(partitionsNum1 < 10) + assert(partitionsNum2 < 10) + + // repartition obeys initialPartitionNum when adaptiveExecutionEnabled + val plan = df1.queryExecution.executedPlan + assert(plan.isInstanceOf[AdaptiveSparkPlanExec]) + val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect { + case s: ShuffleExchangeExec => s + } + assert(shuffle.size == 1) + assert(shuffle(0).outputPartitioning.numPartitions == 10) } else { - assert(partitionsNum === 6) + assert(partitionsNum1 === 10) + assert(partitionsNum2 === 10) } + + + // Don't coalesce partitions if the number of partitions is specified. + val df3 = spark.range(10).repartition(10, $"id") + val df4 = spark.range(10).repartition(10) + assert(df3.rdd.collectPartitions().length == 10) + assert(df4.rdd.collectPartitions().length == 10) + } + } + } + + test("SPARK-31220, SPARK-32056: repartition by range with AQE") { + Seq(true, false).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString, + SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", + SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10", + SQLConf.SHUFFLE_PARTITIONS.key -> "10") { + + val df1 = spark.range(10).toDF.repartitionByRange($"id".asc) + val df2 = spark.range(10).toDF.repartitionByRange(($"id" + 1).asc) + + val partitionsNum1 = df1.rdd.collectPartitions().length + val partitionsNum2 = df2.rdd.collectPartitions().length + + if (enableAQE) { + assert(partitionsNum1 < 10) + assert(partitionsNum2 < 10) + + // repartition obeys initialPartitionNum when adaptiveExecutionEnabled + val plan = df1.queryExecution.executedPlan + assert(plan.isInstanceOf[AdaptiveSparkPlanExec]) + val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect { + case s: ShuffleExchangeExec => s + } + assert(shuffle.size == 1) + assert(shuffle(0).outputPartitioning.numPartitions == 10) + } else { + assert(partitionsNum1 === 10) + assert(partitionsNum2 === 10) + } + + // Don't coalesce partitions if the number of partitions is specified. + val df3 = spark.range(10).repartitionByRange(10, $"id".asc) + assert(df3.rdd.collectPartitions().length == 10) } } }