diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala index 6952f4bfd056..d5d969032a5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan} *
  • Analyzer Rules.
  • *
  • Check Analysis Rules.
  • *
  • Optimizer Rules.
  • + *
  • Data Source Rewrite Rules.
  • *
  • Planning Strategies.
  • *
  • Customized Parser.
  • *
  • (External) Catalog listeners.
  • @@ -199,6 +200,21 @@ class SparkSessionExtensions { optimizerRules += builder } + private[this] val dataSourceRewriteRules = mutable.Buffer.empty[RuleBuilder] + + private[sql] def buildDataSourceRewriteRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { + dataSourceRewriteRules.map(_.apply(session)).toSeq + } + + /** + * Inject an optimizer `Rule` builder that rewrites data source plans into the [[SparkSession]]. + * The injected rules will be executed after the operator optimization batch and before rules + * that depend on stats. + */ + def injectDataSourceRewriteRule(builder: RuleBuilder): Unit = { + dataSourceRewriteRules += builder + } + private[this] val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder] private[sql] def buildPlannerStrategies(session: SparkSession): Seq[Strategy] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 8101f9e291b4..f51ee11091d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -273,7 +273,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `optimizer` function. */ - protected def customDataSourceRewriteRules: Seq[Rule[LogicalPlan]] = Nil + protected def customDataSourceRewriteRules: Seq[Rule[LogicalPlan]] = { + extensions.buildDataSourceRewriteRules(session) + } /** * Planner that converts optimized logical plans to physical plans. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index 12abd31b99e9..37bb7a5450ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -88,6 +88,12 @@ class SparkSessionExtensionSuite extends SparkFunSuite { } } + test("SPARK-33621: inject data source rewrite rule") { + withSession(Seq(_.injectDataSourceRewriteRule(MyRule))) { session => + assert(session.sessionState.optimizer.dataSourceRewriteRules.contains(MyRule(session))) + } + } + test("inject spark planner strategy") { withSession(Seq(_.injectPlannerStrategy(MySparkStrategy))) { session => assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))