diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala
index eda20f6fae80..d09bdeb9a675 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2981,6 +2981,35 @@ abstract class Dataset[T] extends Serializable {
    */
  def repartitionById(numPartitions: Int, partitionIdExpr: Column): Dataset[T]

+  /**
+   * Proactively optimizes the partition count of this Dataset based on its estimated size.
+   *
+   * == Best Practice: Use on Ingest ==
+   * This method is best used immediately after reading a dataset, so that the initial
+   * parallelism matches the data size. This prevents "small file" issues (too many partitions)
+   * and "giant partition" issues (too few partitions) before heavy transformations begin.
+   *
+   * {{{
+   *   val raw = spark.read.parquet("...")
+   *   val optimized = raw.optimizePartitions() // a good starting point for transformations
+   *   optimized.filter(...).groupBy(...)
+   * }}}
+   *
+   * == Warning: Avoid Immediately Before Partitioned Writes ==
+   * When increasing the partition count, this method uses round-robin partitioning (a full
+   * shuffle) to balance partition sizes. If it is called immediately before writing to a
+   * partitioned table (e.g., `write.partitionBy("city")`), it may degrade performance by
+   * breaking data locality, causing the writer to create many small files across directories.
+   *
+   * @param targetMB The target partition size in megabytes. Defaults to 128 MB.
+   * @group typedrel
+   * @since 4.2.0
+   */
+  def optimizePartitions(targetMB: Int = 128): Dataset[T] = {
+    throw new UnsupportedOperationException("This method is implemented in " +
+      "the concrete Dataset classes")
+  }
+
  /**
   * Returns a new Dataset that has exactly `numPartitions` partitions, when the fewer partitions
   * are requested. If a larger number of partitions is requested, it will stay at the current
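For the partitioned-write warning above, a common alternative is to cluster the data by the write-partition column instead of calling `optimizePartitions()` right before the write, so each task writes to only a few directories. A minimal sketch, assuming a hypothetical `city` column and illustrative paths:

{{{
import org.apache.spark.sql.functions.col

// Illustrative input; not part of the patch.
val df = spark.read.parquet("/tmp/input")

// Clustering by the write-partition column preserves locality: each task holds whole
// cities, so it produces few files per output directory.
df.repartition(col("city"))
  .write
  .partitionBy("city")
  .parquet("/tmp/output")
}}}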
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizePartitionsRule.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizePartitionsRule.scala
new file mode 100644
index 000000000000..688712dae97a
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizePartitionsRule.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OptimizePartitionsCommand, Repartition}
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * Proactively optimizes the partition count of a Dataset based on its estimated size.
+ * This rule transforms the custom OptimizePartitionsCommand into standard Spark operations.
+ */
+object OptimizePartitionsRule extends Rule[LogicalPlan] {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
+    case OptimizePartitionsCommand(child, targetMB, currentPartitions) =>
+
+      val targetBytes = targetMB.toLong * 1024L * 1024L
+
+      // Get the estimated size from Catalyst Statistics
+      val sizeInBytes = child.stats.sizeInBytes
+
+      // Calculate Optimal Partition Count (N)
+      val count = math.ceil(sizeInBytes.toDouble / targetBytes).toInt
+      val calculatedN: Int = if (count <= 1) 1 else count
+
+      // Smart Switch: Coalesce vs Repartition
+      if (calculatedN < currentPartitions) {
+        // DOWNSCALING: Use Coalesce (shuffle = false)
+        Repartition(calculatedN, shuffle = false, child)
+      } else if (calculatedN > currentPartitions) {
+        // UPSCALING: Use Repartition (shuffle = true)
+        Repartition(calculatedN, shuffle = true, child)
+      } else {
+        // OPTIMAL
+        child
+      }
+  }
+}
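The sizing arithmetic above is a ceiling division of the estimated plan size by the target partition size, clamped to at least one partition. A standalone sketch of the same calculation (the helper name is illustrative and not part of the patch), with two worked examples:

{{{
// Mirrors the rule's arithmetic: ceil(size / target), never below 1.
def targetPartitionCount(sizeInBytes: BigInt, targetMB: Int): Int = {
  val targetBytes = targetMB.toLong * 1024L * 1024L
  math.max(math.ceil(sizeInBytes.toDouble / targetBytes).toInt, 1)
}

targetPartitionCount(BigInt(1024L * 1024L * 1024L), 128) // 1 GiB at 128 MB per partition => 8
targetPartitionCount(BigInt(10L * 1024L * 1024L), 128)   // 10 MB => clamped to 1
}}}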
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/OptimizePartitionsCommand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/OptimizePartitionsCommand.scala
new file mode 100644
index 000000000000..bdbef97873fc
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/OptimizePartitionsCommand.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+
+/**
+ * A logical plan node that hints to the optimizer that the data should be
+ * repartitioned automatically based on statistics.
+ */
+case class OptimizePartitionsCommand(child: LogicalPlan,
+    targetMB: Int,
+    currentPartitions: Int) extends UnaryNode {
+
+  override def output: Seq[Attribute] = child.output
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): OptimizePartitionsCommand =
+    copy(child = newChild)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
index d02b63b49ca5..8d9db248fb82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
@@ -1562,6 +1562,14 @@ class Dataset[T] private[sql](
    }
  }

+  override def optimizePartitions(targetMB: Int): Dataset[T] = {
+    val currentPartitions = rdd.getNumPartitions
+
+    withTypedPlan {
+      OptimizePartitionsCommand(logicalPlan, targetMB, currentPartitions)
+    }
+  }
+
  /** @inheritdoc */
  def coalesce(numPartitions: Int): Dataset[T] = withSameTypedPlan {
    Repartition(numPartitions, shuffle = false, logicalPlan)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 7f3b8383f0f8..3484072fb3c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -96,7 +96,8 @@ class SparkOptimizer(
      ConstantFolding,
      EliminateLimits),
    Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*),
-    Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition)))
+    Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition),
+    Batch("Optimize Partitions", Once, OptimizePartitionsRule)))

  override def nonExcludableRules: Seq[String] = super.nonExcludableRules ++
    Seq(
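Taken together, `optimizePartitions()` records the current partition count in an `OptimizePartitionsCommand` node, and the new "Optimize Partitions" batch rewrites that node into either a coalesce or a shuffling repartition. A rough illustration of the flow in a spark-shell session (the exact plan rendering depends on the final implementation):

{{{
// For a small Dataset spread across many partitions, the command should be rewritten
// into Repartition(1, shuffle = false, ...), i.e. a coalesce rather than a full shuffle.
val df = spark.range(10000).repartition(100).optimizePartitions()

df.queryExecution.optimizedPlan  // inspect the rewritten logical plan
df.rdd.getNumPartitions         // expected: 1
}}}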
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizePartitionsSuite.scala
new file mode 100644
index 000000000000..1284ba45335c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizePartitionsSuite.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.test.SharedSparkSession
+
+class OptimizePartitionsSuite extends SparkFunSuite with SharedSparkSession {
+
+  test("TEST 1: Small Data Compaction (Coalesce)") {
+    val initialDF = spark.range(10000).repartition(100)
+    val optimizedDF = initialDF.optimizePartitions()
+    assert(optimizedDF.rdd.getNumPartitions == 1,
+      s"Expected 1 partition, got ${optimizedDF.rdd.getNumPartitions}.")
+  }
+
+  test("TEST 2: Scaling Up (Large Data Repartition)") {
+    val initialDF = spark.range(500000).repartition(1)
+    // initialDF's estimated size is about 4 MB.
+    // A 2 MB target should therefore increase the partition count from 1 to 2.
+    val optimizedDF = initialDF.optimizePartitions(2)
+    // Each resulting partition should then hold roughly 2 MB.
+    assert(optimizedDF.rdd.getNumPartitions == 2,
+      s"Expected scaling up to 2 partitions, got ${optimizedDF.rdd.getNumPartitions}.")
+  }
+}
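The "about 4 MB" figure in TEST 2 comes from Catalyst statistics, the same estimate the rule consumes. It can be inspected directly, which is handy when choosing `targetMB` (illustrative snippet; the exact number may vary by Spark version):

{{{
val initialDF = spark.range(500000).repartition(1)

// Size estimate used by OptimizePartitionsRule: roughly 500000 rows * 8 bytes per long.
val estimated = initialDF.queryExecution.optimizedPlan.stats.sizeInBytes
println(s"estimated size: $estimated bytes")
}}}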