Commit f65ebe3

normalize filters in FileScan.equals()
1 parent e1a9722 commit f65ebe3

File tree

1 file changed: +17 -3 lines

  • sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala

Lines changed: 17 additions & 3 deletions
@@ -24,8 +24,9 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.IO_WARNING_LARGEFILETHRESHOLD
 import org.apache.spark.sql.{AnalysisException, SparkSession}
-import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet}
+import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression, ExpressionSet}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.connector.read.{Batch, InputPartition, Scan, Statistics, SupportsReportStatistics}
 import org.apache.spark.sql.execution.PartitionedFileUtil
 import org.apache.spark.sql.execution.datasources._
@@ -84,11 +85,24 @@ trait FileScan extends Scan
 
   protected def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")
 
+  private lazy val (normalizedPartitionFilters, normalizedDataFilters) = {
+    val output = readSchema().toAttributes
+    val partitionFilterAttributes = AttributeSet(partitionFilters).map(a => a.name -> a).toMap
+    val dataFiltersAttributes = AttributeSet(dataFilters).map(a => a.name -> a).toMap
+    val normalizedPartitionFilters = ExpressionSet(partitionFilters.map(
+      QueryPlan.normalizeExpressions(_, output.map(a =>
+        partitionFilterAttributes.getOrElse(a.name, a)))))
+    val normalizedDataFilters = ExpressionSet(dataFilters.map(
+      QueryPlan.normalizeExpressions(_, output.map(a =>
+        dataFiltersAttributes.getOrElse(a.name, a)))))
+    (normalizedPartitionFilters, normalizedDataFilters)
+  }
+
   override def equals(obj: Any): Boolean = obj match {
     case f: FileScan =>
       fileIndex == f.fileIndex && readSchema == f.readSchema &&
-        ExpressionSet(partitionFilters) == ExpressionSet(f.partitionFilters) &&
-        ExpressionSet(dataFilters) == ExpressionSet(f.dataFilters)
+        normalizedPartitionFilters == f.normalizedPartitionFilters &&
+        normalizedDataFilters == f.normalizedDataFilters
 
     case _ => false
   }
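
Note on the change (added context, not part of the original commit message): ExpressionSet equality compares canonicalized expressions, and canonicalized attribute references still carry their exprIds. So two FileScan instances whose filters reference the same columns through attributes with different exprIds used to compare unequal even when the filters were semantically identical. The commit normalizes both filter sets against the scan's readSchema() attributes (matched by name) before comparing, and caches the result in a lazy val so the normalization runs at most once per instance. The sketch below is a minimal, hypothetical illustration of that effect; it is not part of the commit, the object and value names are made up, and it only assumes Spark's catalyst/sql modules on the classpath.

// Minimal sketch (illustrative only): two filters that differ solely in
// attribute exprIds compare unequal as raw ExpressionSets, but compare
// equal after QueryPlan.normalizeExpressions, which is what the new
// normalizedPartitionFilters / normalizedDataFilters values rely on.
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, ExpressionSet, Literal}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.types.IntegerType

object FilterNormalizationSketch {
  def main(args: Array[String]): Unit = {
    // Two references to the same column name get distinct auto-generated exprIds.
    val a1 = AttributeReference("p", IntegerType)()
    val a2 = AttributeReference("p", IntegerType)()

    val filter1 = EqualTo(a1, Literal(1))
    val filter2 = EqualTo(a2, Literal(1))

    // Raw ExpressionSets still see the differing exprIds.
    println(ExpressionSet(Seq(filter1)) == ExpressionSet(Seq(filter2))) // false

    // Normalizing each filter against its own output rewrites exprIds by
    // ordinal position, so the normalized sets compare equal.
    val n1 = QueryPlan.normalizeExpressions(filter1, Seq(a1))
    val n2 = QueryPlan.normalizeExpressions(filter2, Seq(a2))
    println(ExpressionSet(Seq(n1)) == ExpressionSet(Seq(n2))) // true
  }
}

The commit's version differs from this sketch only in what it normalizes against: each scan rewrites its filters against its own readSchema().toAttributes, and since equals() already requires readSchema == f.readSchema, both sides end up normalized against the same attribute ordering, making the position-based exprIds directly comparable.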
