@@ -22,9 +22,10 @@ import org.scalatest.BeforeAndAfterAll
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.internal.config.UI.UI_ENABLED
 import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.adaptive._
 import org.apache.spark.sql.execution.adaptive.CustomShuffleReaderExec
-import org.apache.spark.sql.execution.exchange.{REPARTITION_BY_NONE, ReusedExchangeExec}
+import org.apache.spark.sql.execution.exchange.{REPARTITION_BY_NONE, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -521,6 +522,63 @@ class CoalesceShufflePartitionsSuite
       withSparkSession(test, 10000, minPartitionNum)
     }
   }
+
+  test(s"Ignore SinglePartition when determining expectedChildrenNumPartitions") {
+    val test: SparkSession => Unit = { spark: SparkSession =>
+      try {
+        val aeqShufflePartitionNum = 100
+        spark.conf.set(SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key,
+          aeqShufflePartitionNum)
+        val df = spark.range(100).toDF("id").selectExpr("id", "id as name")
+        df.write.format("parquet").saveAsTable("t1")
+        df.write.format("parquet").saveAsTable("t2")
+        df.write.bucketBy(10, "id").format("parquet").saveAsTable("t3")
+        df.write.bucketBy(20, "id").format("parquet").saveAsTable("t4")
+
+        /* SinglePartition Join with parquet table */
+        val join1 = spark.sql(
+          s"""SELECT *
+             |FROM (select * from t1 limit 1000) t1
+             |JOIN t2
+             |ON t1.id = t2.id
+             |""".stripMargin)
+        join1.collect()
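+        // The LIMIT side produces a single output partition; it should not drive
+        // the target partition count, so both join children are shuffled to the
+        // initial partition number (100).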
+        assert(collect(join1.queryExecution.executedPlan) {
+          case r @ ShuffleExchangeExec(HashPartitioning(_, 100), _, _) => r
+        }.length === 2)
+
+        /* Bucket join: Only shuffle parquet table to bucket number */
+        val join2 = spark.sql(
+          s"""SELECT *
+             |FROM t3
+             |JOIN t2
+             |ON t3.id = t2.id
+             |""".stripMargin)
+        join2.collect()
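+        // Expect a single shuffle: only the non-bucketed table (t2) is shuffled,
+        // and to t3's bucket count (10) rather than the initial partition number.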
+        assert(collect(join2.queryExecution.executedPlan) {
+          case r @ ShuffleExchangeExec(HashPartitioning(_, 10), _, _) => r
+        }.length === 1)
+
+        /* Two bucket table join: shuffle the table with the smaller bucket number */
+        val join3 = spark.sql(
+          s"""SELECT *
+             |FROM t3
+             |JOIN t4
+             |ON t3.id = t4.id
+             |""".stripMargin)
+        join3.collect()
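+        // Expect a single shuffle: the 10-bucket table (t3) is re-shuffled to
+        // match t4's larger bucket count (20).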
+        assert(collect(join3.queryExecution.executedPlan) {
+          case r @ ShuffleExchangeExec(HashPartitioning(_, 20), _, _) => r
+        }.length === 1)
+      } finally {
+        Seq("t1", "t2", "t3", "t4").foreach { name =>
+          spark.sql(s"DROP TABLE IF EXISTS $name")
+        }
+      }
+    }
+
+    withSparkSession(test, 10000, None)
+  }
 }
 
 object CoalescedShuffleReader {