diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index a8ccc39ac478f..6b3744fe02d2e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
*
* - Analyzer Rules.
* - Check Analysis Rules.
+ * - Cache Plan Normalization Rules.
* - Optimizer Rules.
* - Pre CBO Rules.
* - Planning Strategies.
@@ -217,6 +218,22 @@ class SparkSessionExtensions {
checkRuleBuilders += builder
}
+ private[this] val planNormalizationRules = mutable.Buffer.empty[RuleBuilder]
+
+ private[sql] def buildPlanNormalizationRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
+ planNormalizationRules.map(_.apply(session)).toSeq
+ }
+
+ /**
+ * Inject a plan normalization `Rule` builder into the [[SparkSession]]. The injected rules will
+ * be executed just before query caching decisions are made. Such rules can be used to improve the
+ * cache hit rate by normalizing different plans to the same form. These rules must never change
+ * the result of the query.
+ */
+ def injectPlanNormalizationRules(builder: RuleBuilder): Unit = {
+ planNormalizationRules += builder
+ }
+
private[this] val optimizerRules = mutable.Buffer.empty[RuleBuilder]
private[sql] def buildOptimizerRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
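
For reference, a rule registered through `injectPlanNormalizationRules` is wired up like any other extension point: via `SparkSession.Builder.withExtensions` or the `spark.sql.extensions` configuration. A minimal sketch follows; `MyExtensions`, `IdentityNormalization`, and the wiring are illustrative, not part of this patch.

```scala
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// Placeholder rule: a real normalization rule would rewrite equivalent plans
// into one canonical shape without changing query results.
object IdentityNormalization extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

// Extension class usable with `spark.sql.extensions` or `withExtensions`.
class MyExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    // The builder receives the SparkSession and returns the rule instance.
    extensions.injectPlanNormalizationRules { _ => IdentityNormalization }
  }
}

object Demo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .withExtensions(new MyExtensions)
      .getOrCreate()
    spark.stop()
  }
}
```
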
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index e9bbbc717d1e4..d41611439f0ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -89,7 +89,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
query: Dataset[_],
tableName: Option[String] = None,
storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = {
- cacheQuery(query.sparkSession, query.logicalPlan, tableName, storageLevel)
+ cacheQuery(query.sparkSession, query.queryExecution.normalized, tableName, storageLevel)
}
/**
@@ -143,7 +143,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
def uncacheQuery(
query: Dataset[_],
cascade: Boolean): Unit = {
- uncacheQuery(query.sparkSession, query.logicalPlan, cascade)
+ uncacheQuery(query.sparkSession, query.queryExecution.normalized, cascade)
}
/**
@@ -281,7 +281,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
/** Optionally returns cached data for the given [[Dataset]] */
def lookupCachedData(query: Dataset[_]): Option[CachedData] = {
- lookupCachedData(query.logicalPlan)
+ lookupCachedData(query.queryExecution.normalized)
}
/** Optionally returns cached data for the given [[LogicalPlan]]. */
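
The effect of keying the cache on `queryExecution.normalized` can be sketched as follows. This is not part of the patch: it assumes a normalization rule such as `PushDownPredicates` has been injected (as in the test at the end of this diff), and it reaches through `sharedState.cacheManager`, an internal API, only to show the lookup path.

```scala
import org.apache.spark.sql.SparkSession

def cacheHitDemo(spark: SparkSession): Unit = {
  import spark.implicits._
  val df = Seq((1, "a"), (2, "b")).toDF("i", "s")

  // Cache project-then-filter.
  df.select("i").filter($"i" > 1).cache()

  // Filter-then-project normalizes to the same plan, so the lookup (which now
  // goes through the normalized plan) finds the entry cached above.
  val hit = spark.sharedState.cacheManager
    .lookupCachedData(df.filter($"i" > 1).select("i"))
  assert(hit.isDefined, "expected a cache hit on the normalized plan")
}
```
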
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 3706d5a1e3d4c..796ec41ab51c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -105,12 +105,29 @@ class QueryExecution(
case other => other
}
+ // The plan that has been normalized by custom rules, so that it's more likely to hit the cache.
+ lazy val normalized: LogicalPlan = {
+ val normalizationRules = sparkSession.sessionState.planNormalizationRules
+ if (normalizationRules.isEmpty) {
+ commandExecuted
+ } else {
+ val planChangeLogger = new PlanChangeLogger[LogicalPlan]()
+ val normalized = normalizationRules.foldLeft(commandExecuted) { (p, rule) =>
+ val result = rule.apply(p)
+ planChangeLogger.logRule(rule.ruleName, p, result)
+ result
+ }
+ planChangeLogger.logBatch("Plan Normalization", commandExecuted, normalized)
+ normalized
+ }
+ }
+
lazy val withCachedData: LogicalPlan = sparkSession.withActive {
assertAnalyzed()
assertSupported()
// clone the plan to avoid sharing the plan instance between different stages like analyzing,
// optimizing and planning.
- sparkSession.sharedState.cacheManager.useCachedData(commandExecuted.clone())
+ sparkSession.sharedState.cacheManager.useCachedData(normalized.clone())
}
def assertCommandExecuted(): Unit = commandExecuted
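
The `normalized` value above applies each injected rule exactly once, in registration order, rather than running them to a fixed point like an optimizer batch. Stripped of the `PlanChangeLogger` bookkeeping, the core is a single fold; `normalizeOnce` below is an illustrative standalone helper, not code from this patch.

```scala
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

object PlanNormalization {
  // One pass over the rules: thread the plan through each rule in order,
  // with no fixed-point iteration.
  def normalizeOnce(plan: LogicalPlan, rules: Seq[Rule[LogicalPlan]]): LogicalPlan =
    rules.foldLeft(plan)((p, rule) => rule.apply(p))
}
```
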
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index b12a86c08d18b..f81b12796ce97 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -317,6 +317,10 @@ abstract class BaseSessionStateBuilder(
extensions.buildRuntimeOptimizerRules(session))
}
+ protected def planNormalizationRules: Seq[Rule[LogicalPlan]] = {
+ extensions.buildPlanNormalizationRules(session)
+ }
+
/**
* Create a query execution object.
*/
@@ -371,7 +375,8 @@ abstract class BaseSessionStateBuilder(
createQueryExecution,
createClone,
columnarRules,
- adaptiveRulesHolder)
+ adaptiveRulesHolder,
+ planNormalizationRules)
}
}
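
Because `planNormalizationRules` is a `protected def` on the builder, a custom session state builder can also layer rules on top of whatever the extensions provide. A hedged sketch, assuming code compiled inside the `org.apache.spark.sql` package tree (both `BaseSessionStateBuilder` and `SessionState` are internal); `MySessionStateBuilder` and `ExtraNormalization` are illustrative names.

```scala
package org.apache.spark.sql.internal

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// Illustrative rule appended by the builder below.
object ExtraNormalization extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

class MySessionStateBuilder(
    session: SparkSession,
    parentState: Option[SessionState])
  extends BaseSessionStateBuilder(session, parentState) {

  // Keep the extension-provided rules and add one more on top.
  override protected def planNormalizationRules: Seq[Rule[LogicalPlan]] =
    super.planNormalizationRules :+ ExtraNormalization

  override protected def newBuilder: NewBuilder = new MySessionStateBuilder(_, _)
}
```
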
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index 1d5e61aab269c..eb0b71d155bab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.AdaptiveRulesHolder
@@ -79,7 +80,8 @@ private[sql] class SessionState(
createQueryExecution: (LogicalPlan, CommandExecutionMode.Value) => QueryExecution,
createClone: (SparkSession, SessionState) => SessionState,
val columnarRules: Seq[ColumnarRule],
- val adaptiveRulesHolder: AdaptiveRulesHolder) {
+ val adaptiveRulesHolder: AdaptiveRulesHolder,
+ val planNormalizationRules: Seq[Rule[LogicalPlan]]) {
// The following fields are lazy to avoid creating the Hive client when creating SessionState.
lazy val catalog: SessionCatalog = catalogBuilder()
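
Once built, the rules travel with the session state, which is where `QueryExecution.normalized` picks them up. From code that can see the `private[sql]` `SessionState` (for example a suite in the `org.apache.spark.sql` package), the registration can be checked directly; this snippet is only an illustrative assertion, not part of the patch.

```scala
// Inside the org.apache.spark.sql package (SessionState is private[sql]):
val rules = spark.sessionState.planNormalizationRules
assert(rules.nonEmpty, "expected the injected plan normalization rules")
```
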
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 10d2227324f18..f5f04eabec036 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -192,6 +192,23 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper {
testInjectColumnar(false)
}
+ test("inject plan normalization rules") {
+ val extensions = create { extensions =>
+ extensions.injectPlanNormalizationRules { session =>
+ org.apache.spark.sql.catalyst.optimizer.PushDownPredicates
+ }
+ }
+ withSession(extensions) { session =>
+ import session.implicits._
+ val df = Seq((1, "a"), (2, "b")).toDF("i", "s")
+ df.select("i").filter($"i" > 1).cache()
+ assert(df.filter($"i" > 1).select("i").queryExecution.executedPlan.find {
+ case _: org.apache.spark.sql.execution.columnar.InMemoryTableScanExec => true
+ case _ => false
+ }.isDefined)
+ }
+ }
+
test("SPARK-39991: AQE should retain column statistics from completed query stages") {
val extensions = create { extensions =>
extensions.injectColumnar(_ =>
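
Beyond reusing an existing optimizer rule as in the new test above, a dedicated normalization rule would typically canonicalize plan shapes the optimizer leaves alone. The sketch below is not part of this patch (`OrderConjunctivePredicates` is an illustrative name): it sorts the conjuncts of a `Filter` by their SQL text so that semantically identical filters written in different orders normalize to the same plan. Like any rule injected here, it must not change query results.

```scala
import org.apache.spark.sql.catalyst.expressions.{And, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule

object OrderConjunctivePredicates extends Rule[LogicalPlan] {
  // Flatten nested ANDs into their individual conjuncts.
  private def conjuncts(e: Expression): Seq[Expression] = e match {
    case And(left, right) => conjuncts(left) ++ conjuncts(right)
    case other => Seq(other)
  }

  override def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
    case f @ Filter(condition, _) =>
      // Ordering by generated SQL text is a simplification, but it is enough
      // to make `a > 1 AND b < 2` and `b < 2 AND a > 1` normalize identically.
      val ordered = conjuncts(condition).sortBy(_.sql).reduceLeft(And(_, _))
      f.copy(condition = ordered)
  }
}
```

Such a rule would be registered the same way as in the test: `extensions.injectPlanNormalizationRules { _ => OrderConjunctivePredicates }`.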