diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index 6952f4bfd056..d5d969032a5e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
*
Analyzer Rules.
* Check Analysis Rules.
* Optimizer Rules.
+ * Data Source Rewrite Rules.
* Planning Strategies.
* Customized Parser.
* (External) Catalog listeners.
@@ -199,6 +200,21 @@ class SparkSessionExtensions {
optimizerRules += builder
}
+ private[this] val dataSourceRewriteRules = mutable.Buffer.empty[RuleBuilder]
+
+ private[sql] def buildDataSourceRewriteRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
+ dataSourceRewriteRules.map(_.apply(session)).toSeq
+ }
+
+ /**
+ * Inject into the [[SparkSession]] an optimizer `Rule` builder that rewrites data source plans.
+ * The injected rules will be executed after the operator optimization batch and before rules
+ * that depend on stats.
+ */
+ def injectDataSourceRewriteRule(builder: RuleBuilder): Unit = {
+ dataSourceRewriteRules += builder
+ }
+
private[this] val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder]
private[sql] def buildPlannerStrategies(session: SparkSession): Seq[Strategy] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 8101f9e291b4..f51ee11091d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -273,7 +273,9 @@ abstract class BaseSessionStateBuilder(
*
* Note that this may NOT depend on the `optimizer` function.
*/
- protected def customDataSourceRewriteRules: Seq[Rule[LogicalPlan]] = Nil
+ protected def customDataSourceRewriteRules: Seq[Rule[LogicalPlan]] = {
+ extensions.buildDataSourceRewriteRules(session)
+ }
/**
* Planner that converts optimized logical plans to physical plans.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 12abd31b99e9..37bb7a5450ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -88,6 +88,12 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
}
}
+ test("SPARK-33621: inject data source rewrite rule") {
+ withSession(Seq(_.injectDataSourceRewriteRule(MyRule))) { session =>
+ assert(session.sessionState.optimizer.dataSourceRewriteRules.contains(MyRule(session)))
+ }
+ }
+
test("inject spark planner strategy") {
withSession(Seq(_.injectPlannerStrategy(MySparkStrategy))) { session =>
assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))