
Commit 2d59ca4

guykhazma authored and gengliangwang committed
[SPARK-30475][SQL] File source V2: Push data filters for file listing
### What changes were proposed in this pull request?

Follow-up on [SPARK-30428](#27112), which added support for partition pruning in File source V2. This PR implements the changes needed to pass the `dataFilters` to `listFiles`. This enables `FileIndex` implementations that use the `dataFilters` to further prune the file listing (see the discussion [here](#27112 (comment))).

### Why are the changes needed?

Datasources such as `csv` and `json` do not implement the `SupportsPushDownFilters` trait. To support data skipping uniformly for all file-based data sources, one can override the `listFiles` method in a `FileIndex` implementation, which consults external metadata and prunes the list of files.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Modified the unit tests for V2 file sources to verify that the `dataFilters` are passed.

Closes #27157 from guykhazma/PushdataFiltersInFileListing.

Authored-by: Guy Khazma <guykhag@gmail.com>
Signed-off-by: Gengliang Wang <gengliang.wang@databricks.com>
1 parent 94284c8 commit 2d59ca4
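
As a concrete illustration of the use case described in the commit message, a `FileIndex` implementation can override `listFiles` and use the pushed `dataFilters` to prune the listing against external metadata. The sketch below is hypothetical: the `MetadataSkippingFileIndex` class and its `skipsFile` helper are invented for illustration and are not part of this change; it only shows where the `dataFilters` argument would be consumed (assuming the Spark 3.x `InMemoryFileIndex`/`PartitionDirectory` APIs of this era).

```scala
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.datasources.{InMemoryFileIndex, PartitionDirectory}

// Hypothetical FileIndex that consults external metadata (e.g. per-file min/max
// statistics) to skip files that cannot match the pushed-down data filters.
class MetadataSkippingFileIndex(spark: SparkSession, rootPaths: Seq[Path])
  extends InMemoryFileIndex(spark, rootPaths, Map.empty, userSpecifiedSchema = None) {

  override def listFiles(
      partitionFilters: Seq[Expression],
      dataFilters: Seq[Expression]): Seq[PartitionDirectory] = {
    // Start from the usual partition-pruned listing ...
    val partitions = super.listFiles(partitionFilters, dataFilters)
    // ... then drop files that external metadata proves irrelevant to dataFilters.
    partitions.map { dir =>
      dir.copy(files = dir.files.filterNot(f => skipsFile(f.getPath, dataFilters)))
    }
  }

  // Invented helper: would consult external metadata for the file and decide
  // whether it can be skipped for the given data filters. Keeps every file here.
  private def skipsFile(path: Path, dataFilters: Seq[Expression]): Boolean = false
}
```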

10 files changed: +120 -31 lines

external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala

Lines changed: 5 additions & 3 deletions
@@ -37,7 +37,8 @@ case class AvroScan(
     readDataSchema: StructType,
     readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap,
-    partitionFilters: Seq[Expression] = Seq.empty) extends FileScan {
+    partitionFilters: Seq[Expression] = Seq.empty,
+    dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
   override def isSplitable(path: Path): Boolean = true
 
   override def createReaderFactory(): PartitionReaderFactory = {
@@ -53,8 +54,9 @@ case class AvroScan(
       dataSchema, readDataSchema, readPartitionSchema, parsedOptions)
   }
 
-  override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
-    this.copy(partitionFilters = partitionFilters)
+  override def withFilters(
+      partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
+    this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
 
   override def equals(obj: Any): Boolean = obj match {
     case a: AvroScan => super.equals(a) && dataSchema == a.dataSchema && options == a.options

external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala

Lines changed: 29 additions & 0 deletions
@@ -1566,6 +1566,7 @@ class AvroV2Suite extends AvroSuite {
       }
       assert(fileScan.nonEmpty)
       assert(fileScan.get.partitionFilters.nonEmpty)
+      assert(fileScan.get.dataFilters.nonEmpty)
       assert(fileScan.get.planInputPartitions().forall { partition =>
         partition.asInstanceOf[FilePartition].files.forall { file =>
           file.filePath.contains("p1=1") && file.filePath.contains("p2=2")
@@ -1575,6 +1576,34 @@ class AvroV2Suite extends AvroSuite {
     }
   }
 
+  test("Avro source v2: support passing data filters to FileScan without partitionFilters") {
+    withTempPath { dir =>
+      Seq(("a", 1, 2), ("b", 1, 2), ("c", 2, 1))
+        .toDF("value", "p1", "p2")
+        .write
+        .format("avro")
+        .save(dir.getCanonicalPath)
+      val df = spark
+        .read
+        .format("avro")
+        .load(dir.getCanonicalPath)
+        .where("value = 'a'")
+
+      val filterCondition = df.queryExecution.optimizedPlan.collectFirst {
+        case f: Filter => f.condition
+      }
+      assert(filterCondition.isDefined)
+
+      val fileScan = df.queryExecution.executedPlan collectFirst {
+        case BatchScanExec(_, f: AvroScan) => f
+      }
+      assert(fileScan.nonEmpty)
+      assert(fileScan.get.partitionFilters.isEmpty)
+      assert(fileScan.get.dataFilters.nonEmpty)
+      checkAnswer(df, Row("a", 1, 2))
+    }
+  }
+
   private def getBatchScanExec(plan: SparkPlan): BatchScanExec = {
     plan.find(_.isInstanceOf[BatchScanExec]).get.asInstanceOf[BatchScanExec]
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala

Lines changed: 13 additions & 9 deletions
@@ -28,20 +28,22 @@ import org.apache.spark.sql.types.StructType
 
 private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
 
-  private def getPartitionKeyFilters(
+  private def getPartitionKeyFiltersAndDataFilters(
       sparkSession: SparkSession,
       relation: LeafNode,
       partitionSchema: StructType,
       filters: Seq[Expression],
-      output: Seq[AttributeReference]): ExpressionSet = {
+      output: Seq[AttributeReference]): (ExpressionSet, Seq[Expression]) = {
     val normalizedFilters = DataSourceStrategy.normalizeExprs(
       filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), output)
     val partitionColumns =
       relation.resolve(partitionSchema, sparkSession.sessionState.analyzer.resolver)
     val partitionSet = AttributeSet(partitionColumns)
-    ExpressionSet(normalizedFilters.filter { f =>
+    val (partitionFilters, dataFilters) = normalizedFilters.partition(f =>
       f.references.subsetOf(partitionSet)
-    })
+    )
+
+    (ExpressionSet(partitionFilters), dataFilters)
   }
 
   private def rebuildPhysicalOperation(
@@ -72,7 +74,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
         _,
         _))
         if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined =>
-      val partitionKeyFilters = getPartitionKeyFilters(
+      val (partitionKeyFilters, _) = getPartitionKeyFiltersAndDataFilters(
         fsRelation.sparkSession, logicalRelation, partitionSchema, filters, logicalRelation.output)
       if (partitionKeyFilters.nonEmpty) {
         val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
@@ -92,11 +94,13 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
     case op @ PhysicalOperation(projects, filters,
         v2Relation @ DataSourceV2ScanRelation(_, scan: FileScan, output))
         if filters.nonEmpty && scan.readDataSchema.nonEmpty =>
-      val partitionKeyFilters = getPartitionKeyFilters(scan.sparkSession,
-        v2Relation, scan.readPartitionSchema, filters, output)
-      if (partitionKeyFilters.nonEmpty) {
+      val (partitionKeyFilters, dataFilters) =
+        getPartitionKeyFiltersAndDataFilters(scan.sparkSession, v2Relation,
+          scan.readPartitionSchema, filters, output)
+      // The dataFilters are pushed down only once
+      if (partitionKeyFilters.nonEmpty || (dataFilters.nonEmpty && scan.dataFilters.isEmpty)) {
         val prunedV2Relation =
-          v2Relation.copy(scan = scan.withPartitionFilters(partitionKeyFilters.toSeq))
+          v2Relation.copy(scan = scan.withFilters(partitionKeyFilters.toSeq, dataFilters))
         // The pushed down partition filters don't need to be reevaluated.
         val afterScanFilters =
           ExpressionSet(filters) -- partitionKeyFilters.filter(_.references.nonEmpty)
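
The heart of `getPartitionKeyFiltersAndDataFilters` is the single `partition` call above: a normalized predicate becomes a partition filter only if every column it references is a partition column; everything else is treated as a data filter. A stripped-down, self-contained illustration of that split, using a made-up `Pred` type and column names in place of Catalyst expressions:

```scala
// Toy model of the split done by getPartitionKeyFiltersAndDataFilters: a
// predicate is a partition filter iff all of its referenced columns are
// partition columns. Pred and the column names are invented for the example.
case class Pred(sql: String, references: Set[String])

object FilterSplitExample extends App {
  val partitionColumns = Set("p1", "p2")

  val normalizedFilters = Seq(
    Pred("p1 = 1", Set("p1")),
    Pred("p2 = 2", Set("p2")),
    Pred("value = 'a'", Set("value")),
    Pred("p1 = 1 AND value = 'a'", Set("p1", "value")) // mixed references => data filter
  )

  val (partitionFilters, dataFilters) =
    normalizedFilters.partition(f => f.references.subsetOf(partitionColumns))

  println(partitionFilters.map(_.sql)) // List(p1 = 1, p2 = 2)
  println(dataFilters.map(_.sql))      // List(value = 'a', p1 = 1 AND value = 'a')
}
```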

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala

Lines changed: 12 additions & 4 deletions
@@ -61,9 +61,15 @@ trait FileScan extends Scan with Batch with SupportsReportStatistics with Logging
   def partitionFilters: Seq[Expression]
 
   /**
-   * Create a new `FileScan` instance from the current one with different `partitionFilters`.
+   * Returns the data filters that can be use for file listing
    */
-  def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan
+  def dataFilters: Seq[Expression]
+
+  /**
+   * Create a new `FileScan` instance from the current one
+   * with different `partitionFilters` and `dataFilters`
+   */
+  def withFilters(partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan
 
   /**
    * If a file with `path` is unsplittable, return the unsplittable reason,
@@ -79,7 +85,8 @@ trait FileScan extends Scan with Batch with SupportsReportStatistics with Logging
   override def equals(obj: Any): Boolean = obj match {
     case f: FileScan =>
       fileIndex == f.fileIndex && readSchema == f.readSchema
-        ExpressionSet(partitionFilters) == ExpressionSet(f.partitionFilters)
+        ExpressionSet(partitionFilters) == ExpressionSet(f.partitionFilters) &&
+        ExpressionSet(dataFilters) == ExpressionSet(f.dataFilters)
 
     case _ => false
   }
@@ -92,6 +99,7 @@ trait FileScan extends Scan with Batch with SupportsReportStatistics with Logging
     val metadata: Map[String, String] = Map(
       "ReadSchema" -> readDataSchema.catalogString,
       "PartitionFilters" -> seqToString(partitionFilters),
+      "DataFilters" -> seqToString(dataFilters),
      "Location" -> locationDesc)
     val metadataStr = metadata.toSeq.sorted.map {
       case (key, value) =>
@@ -103,7 +111,7 @@ trait FileScan extends Scan with Batch with SupportsReportStatistics with Logging
   }
 
   protected def partitions: Seq[FilePartition] = {
-    val selectedPartitions = fileIndex.listFiles(partitionFilters, Seq.empty)
+    val selectedPartitions = fileIndex.listFiles(partitionFilters, dataFilters)
     val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions)
     val partitionAttributes = fileIndex.partitionSchema.toAttributes
     val attributeMap = partitionAttributes.map(a => normalizeName(a.name) -> a).toMap
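
Since `dataFilters` is now part of the scan's `metadata` map, it should appear in the physical plan description next to `PartitionFilters` and `Location`. The usage sketch below shows one way to observe this, assuming a local Spark 3.x session with the V2 file sources enabled via an empty `spark.sql.sources.useV1SourceList`; the path and the output shown in comments are illustrative only.

```scala
import org.apache.spark.sql.SparkSession

// Usage sketch: with File source V2 enabled, the scan node's description
// should include a DataFilters entry alongside PartitionFilters and Location.
object DataFiltersInExplain extends App {
  val spark = SparkSession.builder()
    .master("local[*]")
    .appName("data-filters-in-scan-description")
    .config("spark.sql.sources.useV1SourceList", "") // route file reads through V2
    .getOrCreate()
  import spark.implicits._

  val path = "/tmp/data-filters-demo" // placeholder path
  Seq(("a", 1), ("b", 2)).toDF("value", "p").write.mode("overwrite").json(path)

  spark.read.json(path).where("value = 'a'").explain()
  // Expected to print something along the lines of:
  //   ... JsonScan ... DataFilters: [isnotnull(value#0), (value#0 = a)],
  //   PartitionFilters: [], Location: InMemoryFileIndex[file:/tmp/data-filters-demo], ...

  spark.stop()
}
```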

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala

Lines changed: 5 additions & 3 deletions
@@ -40,7 +40,8 @@ case class CSVScan(
     readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap,
     pushedFilters: Array[Filter],
-    partitionFilters: Seq[Expression] = Seq.empty)
+    partitionFilters: Seq[Expression] = Seq.empty,
+    dataFilters: Seq[Expression] = Seq.empty)
   extends TextBasedFileScan(sparkSession, options) {
 
   private lazy val parsedOptions: CSVOptions = new CSVOptions(
@@ -91,8 +92,9 @@ case class CSVScan(
       dataSchema, readDataSchema, readPartitionSchema, parsedOptions, pushedFilters)
   }
 
-  override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
-    this.copy(partitionFilters = partitionFilters)
+  override def withFilters(
+      partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
+    this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
 
   override def equals(obj: Any): Boolean = obj match {
     case c: CSVScan => super.equals(c) && dataSchema == c.dataSchema && options == c.options &&

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala

Lines changed: 5 additions & 3 deletions
@@ -39,7 +39,8 @@ case class JsonScan(
     readDataSchema: StructType,
     readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap,
-    partitionFilters: Seq[Expression] = Seq.empty)
+    partitionFilters: Seq[Expression] = Seq.empty,
+    dataFilters: Seq[Expression] = Seq.empty)
   extends TextBasedFileScan(sparkSession, options) {
 
   private val parsedOptions = new JSONOptionsInRead(
@@ -88,8 +89,9 @@ case class JsonScan(
       dataSchema, readDataSchema, readPartitionSchema, parsedOptions)
   }
 
-  override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
-    this.copy(partitionFilters = partitionFilters)
+  override def withFilters(
+      partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
+    this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
 
   override def equals(obj: Any): Boolean = obj match {
     case j: JsonScan => super.equals(j) && dataSchema == j.dataSchema && options == j.options

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala

Lines changed: 5 additions & 3 deletions
@@ -38,7 +38,8 @@ case class OrcScan(
     readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap,
     pushedFilters: Array[Filter],
-    partitionFilters: Seq[Expression] = Seq.empty) extends FileScan {
+    partitionFilters: Seq[Expression] = Seq.empty,
+    dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
   override def isSplitable(path: Path): Boolean = true
 
   override def createReaderFactory(): PartitionReaderFactory = {
@@ -64,6 +65,7 @@ case class OrcScan(
     super.description() + ", PushedFilters: " + seqToString(pushedFilters)
   }
 
-  override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
-    this.copy(partitionFilters = partitionFilters)
+  override def withFilters(
+      partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
+    this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala

Lines changed: 5 additions & 3 deletions
@@ -41,7 +41,8 @@ case class ParquetScan(
     readPartitionSchema: StructType,
     pushedFilters: Array[Filter],
     options: CaseInsensitiveStringMap,
-    partitionFilters: Seq[Expression] = Seq.empty) extends FileScan {
+    partitionFilters: Seq[Expression] = Seq.empty,
+    dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
   override def isSplitable(path: Path): Boolean = true
 
   override def createReaderFactory(): PartitionReaderFactory = {
@@ -92,6 +93,7 @@ case class ParquetScan(
     super.description() + ", PushedFilters: " + seqToString(pushedFilters)
   }
 
-  override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
-    this.copy(partitionFilters = partitionFilters)
+  override def withFilters(
+      partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
+    this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala

Lines changed: 5 additions & 3 deletions
@@ -36,7 +36,8 @@ case class TextScan(
     readDataSchema: StructType,
     readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap,
-    partitionFilters: Seq[Expression] = Seq.empty)
+    partitionFilters: Seq[Expression] = Seq.empty,
+    dataFilters: Seq[Expression] = Seq.empty)
   extends TextBasedFileScan(sparkSession, options) {
 
   private val optionsAsScala = options.asScala.toMap
@@ -70,8 +71,9 @@ case class TextScan(
       readPartitionSchema, textOptions)
   }
 
-  override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
-    this.copy(partitionFilters = partitionFilters)
+  override def withFilters(
+      partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
+    this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
 
   override def equals(obj: Any): Boolean = obj match {
     case t: TextScan => super.equals(t) && options == t.options

sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala

Lines changed: 36 additions & 0 deletions
@@ -775,6 +775,7 @@ class FileBasedDataSourceSuite extends QueryTest
         }
         assert(fileScan.nonEmpty)
         assert(fileScan.get.partitionFilters.nonEmpty)
+        assert(fileScan.get.dataFilters.nonEmpty)
         assert(fileScan.get.planInputPartitions().forall { partition =>
           partition.asInstanceOf[FilePartition].files.forall { file =>
             file.filePath.contains("p1=1") && file.filePath.contains("p2=2")
@@ -786,6 +787,41 @@ class FileBasedDataSourceSuite extends QueryTest
       }
     }
   }
 
+  test("File source v2: support passing data filters to FileScan without partitionFilters") {
+    withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
+      allFileBasedDataSources.foreach { format =>
+        withTempPath { dir =>
+          Seq(("a", 1, 2), ("b", 1, 2), ("c", 2, 1))
+            .toDF("value", "p1", "p2")
+            .write
+            .format(format)
+            .partitionBy("p1", "p2")
+            .option("header", true)
+            .save(dir.getCanonicalPath)
+          val df = spark
+            .read
+            .format(format)
+            .option("header", true)
+            .load(dir.getCanonicalPath)
+            .where("value = 'a'")
+
+          val filterCondition = df.queryExecution.optimizedPlan.collectFirst {
+            case f: Filter => f.condition
+          }
+          assert(filterCondition.isDefined)
+
+          val fileScan = df.queryExecution.executedPlan collectFirst {
+            case BatchScanExec(_, f: FileScan) => f
+          }
+          assert(fileScan.nonEmpty)
+          assert(fileScan.get.partitionFilters.isEmpty)
+          assert(fileScan.get.dataFilters.nonEmpty)
+          checkAnswer(df, Row("a", 1, 2))
+        }
+      }
+    }
+  }
+
   test("File table location should include both values of option `path` and `paths`") {
     withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
       withTempPaths(3) { paths =>

Comments (0)