@@ -670,6 +670,15 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS =
buildConf("spark.sql.adaptive.applyFinalStageShuffleOptimizations")
Contributor: @ulysses-you is this useful even for non-table-cache queries?

Contributor: Bucketed table write? But I think people would specify the partition number explicitly if there is a shuffle on the bucket column; I can't find another case.

Contributor: Yeah, this can be a use case. People may want to fully control the partitioning of the final write stage, which can affect the number of files.

.internal()
.doc("Configures whether adaptive query execution (if enabled) should apply shuffle " +
"coalescing and local shuffle read optimization for the final query stage.")
.version("3.4.2")
.booleanConf
.createWithDefault(true)

val ADAPTIVE_EXECUTION_LOG_LEVEL = buildConf("spark.sql.adaptive.logLevel")
.internal()
.doc("Configures the log level for adaptive execution logging of plan changes. The value " +
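Note (not part of the PR): a minimal sketch of the use case discussed above, assuming a local Spark session with AQE enabled. The conf key comes from the definition above; the data, path, and column names are made up, and the comments about the resulting file layout are expectations about how the flag interacts with a repartition-by-column before a write, not guarantees.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

val spark = SparkSession.builder().master("local[*]").getOrCreate()
spark.conf.set("spark.sql.adaptive.enabled", "true")
// Internal flag introduced by this PR; set directly here only for illustration.
spark.conf.set("spark.sql.adaptive.applyFinalStageShuffleOptimizations", "false")

spark.range(0, 1000)
  .withColumn("bucket", col("id") % 16)
  // Hash shuffle by column: with the flag on, AQE could coalesce this final-stage
  // shuffle; with it off, the write below keeps the user's shuffle partitioning,
  // which in turn determines the number of output files.
  .repartition(col("bucket"))
  .write.mode("overwrite").parquet("/tmp/final_stage_shuffle_demo")
```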
@@ -402,8 +402,9 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
if (session.conf.get(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING)) {
// Bucketed scan only has one time overhead but can have multi-times benefits in cache,
// so we always do bucketed scan in a cached plan.
SparkSession.getOrCloneSessionWithConfigsOff(
session, SQLConf.AUTO_BUCKETED_SCAN_ENABLED :: Nil)
SparkSession.getOrCloneSessionWithConfigsOff(session,
SQLConf.ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS ::
SQLConf.AUTO_BUCKETED_SCAN_ENABLED :: Nil)
Contributor: I think a question here is: do we want strictly better performance, or generally better performance?

Contributor: Adding an extra shuffle can be very expensive, so I'm inclined toward this change.

Contributor (author): This approach guarantees better performance compared to having AQE disabled for the SQL cache.

Member: Thank you, @maryannxue.

} else {
SparkSession.getOrCloneSessionWithConfigsOff(session, forceDisableConfigs)
}
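For readability, here is the branch above reduced to plain conf keys. This is a sketch, not the real CacheManager API: the helper name is hypothetical, and the `else` branch assumes the pre-existing `forceDisableConfigs` turns off AQE and automatic bucketed scans for cached plans.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

// Hypothetical helper mirroring the branch above, not the actual CacheManager code.
def confsDisabledForCachedPlans(session: SparkSession): Seq[String] = {
  val canChangePartitioning =
    session.conf.get(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key).toBoolean
  if (canChangePartitioning) {
    // Keep AQE on for the cached plan, but stop it from rewriting the final-stage
    // shuffle and from dropping bucketed scans.
    Seq(
      SQLConf.ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS.key,
      SQLConf.AUTO_BUCKETED_SCAN_ENABLED.key)
  } else {
    // Assumed contents of `forceDisableConfigs`: AQE disabled entirely in the cache.
    Seq(
      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key,
      SQLConf.AUTO_BUCKETED_SCAN_ENABLED.key)
  }
}
```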
@@ -159,7 +159,13 @@ case class AdaptiveSparkPlanExec(
)

private def optimizeQueryStage(plan: SparkPlan, isFinalStage: Boolean): SparkPlan = {
val optimized = queryStageOptimizerRules.foldLeft(plan) { case (latestPlan, rule) =>
val rules = if (isFinalStage &&
!conf.getConf(SQLConf.ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS)) {
queryStageOptimizerRules.filterNot(_.isInstanceOf[AQEShuffleReadRule])
} else {
queryStageOptimizerRules
}
val optimized = rules.foldLeft(plan) { case (latestPlan, rule) =>
val applied = rule.apply(latestPlan)
val result = rule match {
case _: AQEShuffleReadRule if !applied.fastEquals(latestPlan) =>
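A small experiment (not from the PR, and assumption-laden) that makes the effect of the filtered rules observable: with the flag off, the final stage of a simple aggregation should keep roughly `spark.sql.shuffle.partitions` partitions instead of being coalesced. The partition counts in the comments are expectations, not guarantees.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.shuffle.partitions", "200")

// The aggregation's shuffle read is the final query stage of this query.
def finalStagePartitions(): Int =
  spark.range(100).groupBy("id").count().rdd.getNumPartitions

spark.conf.set("spark.sql.adaptive.applyFinalStageShuffleOptimizations", "true")
val coalesced = finalStagePartitions()      // expected: a small, coalesced number

spark.conf.set("spark.sql.adaptive.applyFinalStageShuffleOptimizations", "false")
val notCoalesced = finalStagePartitions()   // expected: close to 200

println(s"coalesced=$coalesced, notCoalesced=$notCoalesced")
```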
sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala (37 additions, 15 deletions)
@@ -29,7 +29,7 @@ import org.apache.commons.io.FileUtils
import org.apache.spark.CleanerListener
import org.apache.spark.executor.DataReadMethod._
import org.apache.spark.executor.DataReadMethod.DataReadMethod
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.TempTableAlreadyExistsException
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
@@ -39,6 +39,7 @@ import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEPropagateEmptyRelation}
import org.apache.spark.sql.execution.columnar._
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
@@ -1629,23 +1630,44 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {

withTempView("t1", "t2", "t3") {
withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "false") {
sql("CACHE TABLE t1 as SELECT /*+ REPARTITION */ * FROM values(1) as t(c)")
assert(spark.table("t1").rdd.partitions.length == 2)
var finalPlan = ""
val listener = new SparkListener {
override def onOtherEvent(event: SparkListenerEvent): Unit = {
event match {
case SparkListenerSQLAdaptiveExecutionUpdate(_, physicalPlanDesc, sparkPlanInfo) =>
if (sparkPlanInfo.simpleString.startsWith(
"AdaptiveSparkPlan isFinalPlan=true")) {
finalPlan = physicalPlanDesc
}
case _ => // ignore other events
}
}
}

withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true") {
assert(spark.table("t1").rdd.partitions.length == 2)
sql("CACHE TABLE t2 as SELECT /*+ REPARTITION */ * FROM values(2) as t(c)")
assert(spark.table("t2").rdd.partitions.length == 1)
}
withTempView("t0", "t1", "t2") {
try {
spark.range(10).write.saveAsTable("t0")
spark.sparkContext.listenerBus.waitUntilEmpty()
spark.sparkContext.addSparkListener(listener)

withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "false") {
assert(spark.table("t1").rdd.partitions.length == 2)
assert(spark.table("t2").rdd.partitions.length == 1)
sql("CACHE TABLE t3 as SELECT /*+ REPARTITION */ * FROM values(3) as t(c)")
assert(spark.table("t3").rdd.partitions.length == 2)
withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "false") {
sql("CACHE TABLE t1 as SELECT /*+ REPARTITION */ * FROM (" +
"SELECT distinct (id+1) FROM t0)")
assert(spark.table("t1").rdd.partitions.length == 2)
spark.sparkContext.listenerBus.waitUntilEmpty()
assert(finalPlan.nonEmpty && !finalPlan.contains("coalesced"))
}

finalPlan = "" // reset finalPlan
withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true") {
sql("CACHE TABLE t2 as SELECT /*+ REPARTITION */ * FROM (" +
"SELECT distinct (id-1) FROM t0)")
assert(spark.table("t2").rdd.partitions.length == 2)
spark.sparkContext.listenerBus.waitUntilEmpty()
assert(finalPlan.nonEmpty && finalPlan.contains("coalesced"))
}
} finally {
spark.sparkContext.removeSparkListener(listener)
}
}
}
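The listener pattern the updated test relies on can be reused outside CachedTableSuite; below is a self-contained sketch. The event and field names are the ones imported by the test; the query and the `contains("coalesced")` check are illustrative, and the `listenerBus.waitUntilEmpty()` call assumes the code lives inside the Spark code base (it is `private[spark]`), as the test does.

```scala
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate

val spark = SparkSession.builder().master("local[*]").getOrCreate()
spark.conf.set("spark.sql.adaptive.enabled", "true")

var finalPlan = ""
val listener = new SparkListener {
  override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
    // Keep only the update that carries the finalized adaptive plan.
    case SparkListenerSQLAdaptiveExecutionUpdate(_, planDesc, info)
        if info.simpleString.startsWith("AdaptiveSparkPlan isFinalPlan=true") =>
      finalPlan = planDesc
    case _ => // ignore other events
  }
}

spark.sparkContext.addSparkListener(listener)
try {
  spark.range(1000).groupBy("id").count().collect()
  // private[spark]; fine inside Spark's own tests, replace with a wait/poll elsewhere.
  spark.sparkContext.listenerBus.waitUntilEmpty()
  println(finalPlan.contains("coalesced")) // did AQE coalesce any shuffle read?
} finally {
  spark.sparkContext.removeSparkListener(listener)
}
```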
sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala (23 additions, 10 deletions)
@@ -2655,16 +2655,29 @@ class DatasetSuite extends QueryTest
}

test("SPARK-45592: Coaleasced shuffle read is not compatible with hash partitioning") {
Member: @maryannxue This test seems to pass on the master branch without your patch. May I ask why this PR changes it?

Contributor (author): Because this bug has been fixed by the other PR, right? This patch disables last-stage coalescing in the SQL cache, so it won't cause a perf regression. The test would still pass if you reverted the other fix, but the modified tests in CachedTableSuite verify the new behavior of this PR.

Contributor (author): I simply changed this test to make it more robust in terms of reproducing the original bug.

Member: Got it.

val ee = spark.range(0, 1000000, 1, 5).map(l => (l, l)).toDF()
.persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK)
ee.count()

val minNbrs1 = ee
.groupBy("_1").agg(min(col("_2")).as("min_number"))
.persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK)

val join = ee.join(minNbrs1, "_1")
assert(join.count() == 1000000)
withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.SHUFFLE_PARTITIONS.key -> "20",
SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "2000") {
val ee = spark.range(0, 1000, 1, 5).map(l => (l, l - 1)).toDF()
.persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK)
ee.count()

// `minNbrs1` will start with 20 partitions and without the fix would coalesce to ~10
// partitions.
val minNbrs1 = ee
.groupBy("_2").agg(min(col("_1")).as("min_number"))
.select(col("_2") as "_1", col("min_number"))
.persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK)
minNbrs1.count()

// shuffle on `ee` will start with 2 partitions, smaller than `minNbrs1`'s partition num,
// and `EnsureRequirements` will change its partition num to `minNbrs1`'s partition num.
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
val join = ee.join(minNbrs1, "_1")
assert(join.count() == 999)
}
}
}

test("SPARK-45022: exact DatasetQueryContext call site") {
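As a follow-up (not part of the PR), one way to see how `EnsureRequirements` lines up the two cached sides mentioned in the comments above is to print the formatted plan of the join and look at the Exchange and AQEShuffleRead nodes. The snippet below rebuilds a small stand-in for the test's datasets; sizes and values are made up, and `cache()` replaces the test's explicit `persist` level.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.min

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Small stand-in for the test's cached datasets.
val ee = spark.range(0, 1000, 1, 5).map(l => (l, l - 1)).toDF().cache()
ee.count()

val minNbrs1 = ee.groupBy("_2").agg(min($"_1").as("min_number"))
  .select($"_2".as("_1"), $"min_number").cache()
minNbrs1.count()

// Exchange and AQEShuffleRead nodes in this output show how the two sides'
// partitioning was aligned before the join.
ee.join(minNbrs1, "_1").explain("formatted")
```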