SparkSessionExtensions.scala
@@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
* <ul>
* <li>Analyzer Rules.</li>
* <li>Check Analysis Rules.</li>
* <li>Cache Plan Normalization Rules.</li>
* <li>Optimizer Rules.</li>
* <li>Pre CBO Rules.</li>
* <li>Planning Strategies.</li>
@@ -217,6 +218,22 @@ class SparkSessionExtensions {
checkRuleBuilders += builder
}

private[this] val planNormalizationRules = mutable.Buffer.empty[RuleBuilder]

def buildPlanNormalizationRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
planNormalizationRules.map(_.apply(session)).toSeq
Contributor:
nit: isn't .apply(...) redundant with just (...) ?

Contributor Author:
I'm following the existing code style in this file. I assume the reason is that people who are not familiar with Scala may be confused when reading the code .map(_(session)).
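For reference, a quick sketch of the equivalence being discussed (identifiers as in the surrounding code; the two forms compile to the same call):

planNormalizationRules.map(_.apply(session)).toSeq  // explicit apply, as written here
planNormalizationRules.map(_(session)).toSeq        // equivalent, but arguably more cryptic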

}

/**
* Inject a plan normalization `Rule` builder into the [[SparkSession]]. The injected rules will
* be executed just before query caching decisions are made. Such rules can be used to improve the
* cache hit rate by normalizing different plans to the same form. These rules should never modify
* the result of the LogicalPlan.
*/
def injectPlanNormalizationRules(builder: RuleBuilder): Unit = {
planNormalizationRules += builder
}

private[this] val optimizerRules = mutable.Buffer.empty[RuleBuilder]

private[sql] def buildOptimizerRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
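To illustrate the injectPlanNormalizationRules API added above: a minimal sketch, not part of this change, of how an application might register a normalization rule when building a session. It reuses the built-in PushDownPredicates rule, mirroring the test added later in this PR; the master and app name are placeholders.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.optimizer.PushDownPredicates

val spark = SparkSession.builder()
  .master("local[*]")                      // placeholder
  .appName("plan-normalization-example")   // placeholder
  .withExtensions { extensions =>
    // Runs just before cache lookup/registration, so plans that differ only in
    // filter/projection ordering can normalize to the same cached form.
    extensions.injectPlanNormalizationRules(_ => PushDownPredicates)
  }
  .getOrCreate()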
CacheManager.scala
@@ -89,7 +89,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
query: Dataset[_],
tableName: Option[String] = None,
storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = {
cacheQuery(query.sparkSession, query.logicalPlan, tableName, storageLevel)
cacheQuery(query.sparkSession, query.queryExecution.normalized, tableName, storageLevel)
}

/**
@@ -143,7 +143,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
def uncacheQuery(
query: Dataset[_],
cascade: Boolean): Unit = {
uncacheQuery(query.sparkSession, query.logicalPlan, cascade)
uncacheQuery(query.sparkSession, query.queryExecution.normalized, cascade)
}

/**
@@ -281,7 +281,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {

/** Optionally returns cached data for the given [[Dataset]] */
def lookupCachedData(query: Dataset[_]): Option[CachedData] = {
lookupCachedData(query.logicalPlan)
lookupCachedData(query.queryExecution.normalized)
}

/** Optionally returns cached data for the given [[LogicalPlan]]. */
QueryExecution.scala
@@ -105,12 +105,29 @@ class QueryExecution(
case other => other
}

// The plan that has been normalized by custom rules, so that it's more likely to hit cache.
lazy val normalized: LogicalPlan = {
Member:
Perhaps adding some comments for this.

val normalizationRules = sparkSession.sessionState.planNormalizationRules
if (normalizationRules.isEmpty) {
commandExecuted
} else {
val planChangeLogger = new PlanChangeLogger[LogicalPlan]()
val normalized = normalizationRules.foldLeft(commandExecuted) { (p, rule) =>
val result = rule.apply(p)
planChangeLogger.logRule(rule.ruleName, p, result)
result
}
planChangeLogger.logBatch("Plan Normalization", commandExecuted, normalized)
normalized
}
}

lazy val withCachedData: LogicalPlan = sparkSession.withActive {
assertAnalyzed()
assertSupported()
// clone the plan to avoid sharing the plan instance between different stages like analyzing,
// optimizing and planning.
sparkSession.sharedState.cacheManager.useCachedData(commandExecuted.clone())
sparkSession.sharedState.cacheManager.useCachedData(normalized.clone())
}

def assertCommandExecuted(): Unit = commandExecuted
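Since the new normalization step runs through PlanChangeLogger, its effect can be surfaced in the driver logs. A hedged sketch, assuming the existing plan-change-log SQL configs cover this new "Plan Normalization" batch the same way they cover optimizer batches:

// Log each normalization rule's before/after plan, and the whole batch, at WARN.
spark.conf.set("spark.sql.planChangeLog.level", "WARN")
spark.conf.set("spark.sql.planChangeLog.batches", "Plan Normalization")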
BaseSessionStateBuilder.scala
@@ -317,6 +317,10 @@ abstract class BaseSessionStateBuilder(
extensions.buildRuntimeOptimizerRules(session))
}

protected def planNormalizationRules: Seq[Rule[LogicalPlan]] = {
extensions.buildPlanNormalizationRules(session)
}

/**
* Create a query execution object.
*/
@@ -371,7 +375,8 @@ abstract class BaseSessionStateBuilder(
createQueryExecution,
createClone,
columnarRules,
adaptiveRulesHolder)
adaptiveRulesHolder,
planNormalizationRules)
}
}

SessionState.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.AdaptiveRulesHolder
@@ -79,7 +80,8 @@ private[sql] class SessionState(
createQueryExecution: (LogicalPlan, CommandExecutionMode.Value) => QueryExecution,
createClone: (SparkSession, SessionState) => SessionState,
val columnarRules: Seq[ColumnarRule],
val adaptiveRulesHolder: AdaptiveRulesHolder) {
val adaptiveRulesHolder: AdaptiveRulesHolder,
val planNormalizationRules: Seq[Rule[LogicalPlan]]) {

// The following fields are lazy to avoid creating the Hive client when creating SessionState.
lazy val catalog: SessionCatalog = catalogBuilder()
SparkSessionExtensionSuite.scala
@@ -192,6 +192,23 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper {
testInjectColumnar(false)
}

test("inject plan normalization rules") {
val extensions = create { extensions =>
extensions.injectPlanNormalizationRules { session =>
org.apache.spark.sql.catalyst.optimizer.PushDownPredicates
}
}
withSession(extensions) { session =>
import session.implicits._
val df = Seq((1, "a"), (2, "b")).toDF("i", "s")
df.select("i").filter($"i" > 1).cache()
assert(df.filter($"i" > 1).select("i").queryExecution.executedPlan.find {
case _: org.apache.spark.sql.execution.columnar.InMemoryTableScanExec => true
case _ => false
Member (on lines +204 to +207):
So without the added rule, caching is unable to apply here, right?

Contributor Author:
yup

Contributor:
Should we add a negative test that verifies this? Might be overkill...
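For illustration, a hedged sketch of what such a negative test might look like — not part of this PR, and it assumes the suite's withSession helper accepts an empty list of extension builders:

withSession(Seq.empty) { session =>
  import session.implicits._
  val df = Seq((1, "a"), (2, "b")).toDF("i", "s")
  df.select("i").filter($"i" > 1).cache()
  // Without the normalization rule the reordered plan should not match the cached
  // one, so no InMemoryTableScanExec is expected in the executed plan.
  assert(df.filter($"i" > 1).select("i").queryExecution.executedPlan.find {
    case _: org.apache.spark.sql.execution.columnar.InMemoryTableScanExec => true
    case _ => false
  }.isEmpty)
}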

}.isDefined)
}
}

test("SPARK-39991: AQE should retain column statistics from completed query stages") {
val extensions = create { extensions =>
extensions.injectColumnar(_ =>