Commit 9e899c1

wakun authored and GitHub Enterprise committed
[CARMEL-6586] Ignore SinglePartition when determining expectedChildrenNumPartitions (#1252)
1 parent 25a3b90 commit 9e899c1
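
The scenario this change targets: when every non-shuffle child of a join reports SinglePartition as its output partitioning (for example the `select * from t1 limit 1000` subquery in the new test), that single partition no longer determines the expected number of partitions; EnsureRequirements falls back to the configured shuffle partition count instead. A minimal sketch of that query shape, assuming a SparkSession named spark and tables t1 and t2 like those the new test creates:

// Hedged sketch, not the committed test: the LIMIT subquery ends up with
// SinglePartition output; with this change the other join side is still
// shuffled to the configured shuffle partition count rather than being
// sized by the single-partition child.
val join = spark.sql(
  """SELECT *
    |FROM (SELECT * FROM t1 LIMIT 1000) t1
    |JOIN t2 ON t1.id = t2.id
    |""".stripMargin)
join.collect()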

2 files changed: +65 -2 lines changed


sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

Lines changed: 6 additions & 1 deletion
@@ -285,7 +285,12 @@ object EnsureRequirements extends Rule[SparkPlan] {
     val nonShuffleChildrenNumPartitions =
       childrenIndexes.map(children).filterNot(_.isInstanceOf[ShuffleExchangeExec])
         .map(_.outputPartitioning.numPartitions)
-    val expectedChildrenNumPartitions = if (nonShuffleChildrenNumPartitions.nonEmpty) {
+    val allSinglePartition =
+      childrenIndexes.map(children).filterNot(_.isInstanceOf[ShuffleExchangeExec])
+        .forall(_.outputPartitioning == SinglePartition)
+    val expectedChildrenNumPartitions = if (allSinglePartition) {
+      conf.numShufflePartitions
+    } else if (nonShuffleChildrenNumPartitions.nonEmpty) {
       if (nonShuffleChildrenNumPartitions.length == childrenIndexes.length) {
         // Here we pick the max number of partitions among these non-shuffle children.
         nonShuffleChildrenNumPartitions.max
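
A compact restatement of the selection rule in this hunk, as a standalone sketch: the partitioning types below are simplified stand-ins for Spark's physical partitioning classes, and the sketch folds the method's remaining branches (including the length check against childrenIndexes) into the single max rule, so it is illustrative rather than the committed code.

// Standalone sketch with simplified stand-ins; not the committed EnsureRequirements code.
sealed trait Partitioning { def numPartitions: Int }
case object SinglePartition extends Partitioning { val numPartitions = 1 }
final case class HashPartitioning(numPartitions: Int) extends Partitioning

object ExpectedPartitionsSketch {
  // nonShuffleChildren: output partitionings of join children that are not ShuffleExchangeExec
  def expectedChildrenNumPartitions(
      nonShuffleChildren: Seq[Partitioning],
      numShufflePartitions: Int): Int = {
    if (nonShuffleChildren.forall(_ == SinglePartition)) {
      // All non-shuffle children are SinglePartition: ignore them and fall back to the
      // configured shuffle partition count (in this sketch the empty case also lands here).
      numShufflePartitions
    } else {
      // Otherwise pick the max partition count among the non-shuffle children,
      // e.g. the larger bucket number when two bucketed tables are joined.
      nonShuffleChildren.map(_.numPartitions).max
    }
  }

  def main(args: Array[String]): Unit = {
    assert(expectedChildrenNumPartitions(Seq(SinglePartition), 100) == 100)
    assert(expectedChildrenNumPartitions(Seq(HashPartitioning(10), HashPartitioning(20)), 100) == 20)
  }
}

The two assertions mirror the cases covered by the new test in CoalesceShufflePartitionsSuite below.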

sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala

Lines changed: 59 additions & 1 deletion
@@ -22,9 +22,10 @@ import org.scalatest.BeforeAndAfterAll
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.internal.config.UI.UI_ENABLED
 import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.adaptive._
 import org.apache.spark.sql.execution.adaptive.CustomShuffleReaderExec
-import org.apache.spark.sql.execution.exchange.{REPARTITION_BY_NONE, ReusedExchangeExec}
+import org.apache.spark.sql.execution.exchange.{REPARTITION_BY_NONE, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -521,6 +522,63 @@ class CoalesceShufflePartitionsSuite
       withSparkSession(test, 10000, minPartitionNum)
     }
   }
+
+  test(s"Ignore SinglePartition when determining expectedChildrenNumPartitions") {
+    val test: SparkSession => Unit = { spark: SparkSession =>
+      try {
+        val aeqShufflePartitionNum = 100
+        spark.conf.set(SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key,
+          aeqShufflePartitionNum)
+        val df = spark.range(100).toDF("id").selectExpr("id", "id as name")
+        df.write.format("parquet").saveAsTable("t1")
+        df.write.format("parquet").saveAsTable("t2")
+        df.write.bucketBy(10, "id").format("parquet").saveAsTable("t3")
+        df.write.bucketBy(20, "id").format("parquet").saveAsTable("t4")
+
+        /* SinglePartition join with a plain parquet table */
+        val join1 = spark.sql(
+          s"""SELECT *
+             |FROM (select * from t1 limit 1000) t1
+             |JOIN t2
+             |ON t1.id = t2.id
+             |""".stripMargin)
+        join1.collect()
+        assert(collect(join1.queryExecution.executedPlan) {
+          case r @ ShuffleExchangeExec(HashPartitioning(_, 100), _, _) => r
+        }.length === 2)
+
+        /* Bucket join: only shuffle the plain parquet table, to the bucket number */
+        val join2 = spark.sql(
+          s"""SELECT *
+             |FROM t3
+             |JOIN t2
+             |ON t3.id = t2.id
+             |""".stripMargin)
+        join2.collect()
+        assert(collect(join2.queryExecution.executedPlan) {
+          case r @ ShuffleExchangeExec(HashPartitioning(_, 10), _, _) => r
+        }.length === 1)
+
+        /* Join of two bucketed tables: shuffle the smaller-bucketed side to the larger bucket number */
+        val join3 = spark.sql(
+          s"""SELECT *
+             |FROM t3
+             |JOIN t4
+             |ON t3.id = t4.id
+             |""".stripMargin)
+        join3.collect()
+        assert(collect(join3.queryExecution.executedPlan) {
+          case r @ ShuffleExchangeExec(HashPartitioning(_, 20), _, _) => r
+        }.length === 1)
+      } finally {
+        Seq("t1", "t2", "t3", "t4").foreach { name =>
+          spark.sql(s"DROP TABLE IF EXISTS $name")
+        }
+      }
+    }
+
+    withSparkSession(test, 10000, None)
+  }
 }
 
 object CoalescedShuffleReader {
