diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index c199df676ced..60d6e60b2e16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -49,6 +49,15 @@ case class BatchScanExec( } override def doCanonicalize(): BatchScanExec = { - this.copy(output = output.map(QueryPlan.normalizeExpressions(_, output))) + val canonicalizedScan = scan match { + case s: FileScan => + s.withFilters( + QueryPlan.normalizePredicates(s.partitionFilters, output), + QueryPlan.normalizePredicates(s.dataFilters, output)) + case _ => scan + } + this.copy( + output = output.map(QueryPlan.normalizeExpressions(_, output)), + scan = canonicalizedScan) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 1273a91e5510..6cd3408deedc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.datasources.{LogicalRelation, SchemaColumn import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -4065,6 +4066,29 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } } + + test("SPARK-33482: Fix FileScan canonicalization") { + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { + withTempPath { path => + spark.range(5).toDF().write.mode("overwrite").parquet(path.toString) + withTempView("t") { + spark.read.parquet(path.toString).createOrReplaceTempView("t") + val df = sql( + """ + |SELECT * + |FROM t AS t1 + |JOIN t AS t2 ON t2.id = t1.id + |JOIN t AS t3 ON t3.id = t2.id + |""".stripMargin) + df.collect() + val reusedExchanges = collect(df.queryExecution.executedPlan) { + case r: ReusedExchangeExec => r + } + assert(reusedExchanges.size == 1) + } + } + } + } } case class Foo(bar: Option[String])