@@ -36,7 +36,8 @@ case class AvroScan(
readDataSchema: StructType,
readPartitionSchema: StructType,
options: CaseInsensitiveStringMap,
partitionFilters: Seq[Expression] = Seq.empty) extends FileScan {
partitionFilters: Seq[Expression] = Seq.empty,
dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
override def isSplitable(path: Path): Boolean = true

override def createReaderFactory(): PartitionReaderFactory = {
@@ -51,8 +52,9 @@ case class AvroScan(
dataSchema, readDataSchema, readPartitionSchema, caseSensitiveMap)
}

override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
this.copy(partitionFilters = partitionFilters)
override def withFilters(
partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)

override def equals(obj: Any): Boolean = obj match {
case a: AvroScan => super.equals(a) && dataSchema == a.dataSchema && options == a.options
@@ -1566,6 +1566,7 @@ class AvroV2Suite extends AvroSuite {
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.nonEmpty)
assert(fileScan.get.dataFilters.nonEmpty)
assert(fileScan.get.planInputPartitions().forall { partition =>
partition.asInstanceOf[FilePartition].files.forall { file =>
file.filePath.contains("p1=1") && file.filePath.contains("p2=2")
@@ -1575,6 +1576,34 @@
}
}

test("Avro source v2: support passing data filters to FileScan without partitionFilters") {
withTempPath { dir =>
Seq(("a", 1, 2), ("b", 1, 2), ("c", 2, 1))
.toDF("value", "p1", "p2")
.write
.format("avro")
.save(dir.getCanonicalPath)
val df = spark
.read
.format("avro")
.load(dir.getCanonicalPath)
.where("value = 'a'")

val filterCondition = df.queryExecution.optimizedPlan.collectFirst {
case f: Filter => f.condition
}
assert(filterCondition.isDefined)

val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: AvroScan) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.isEmpty)
assert(fileScan.get.dataFilters.nonEmpty)
checkAnswer(df, Row("a", 1, 2))
}
}

private def getBatchScanExec(plan: SparkPlan): BatchScanExec = {
plan.find(_.isInstanceOf[BatchScanExec]).get.asInstanceOf[BatchScanExec]
}
@@ -28,20 +28,22 @@ import org.apache.spark.sql.types.StructType

private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {

private def getPartitionKeyFilters(
private def getPartitionKeyFiltersAndDataFilters(
sparkSession: SparkSession,
relation: LeafNode,
partitionSchema: StructType,
filters: Seq[Expression],
output: Seq[AttributeReference]): ExpressionSet = {
output: Seq[AttributeReference]): (ExpressionSet, Seq[Expression]) = {
val normalizedFilters = DataSourceStrategy.normalizeExprs(
filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), output)
val partitionColumns =
relation.resolve(partitionSchema, sparkSession.sessionState.analyzer.resolver)
val partitionSet = AttributeSet(partitionColumns)
ExpressionSet(normalizedFilters.filter { f =>
val (partitionFilters, dataFilters) = normalizedFilters.partition(f =>
f.references.subsetOf(partitionSet)
})
)

(ExpressionSet(partitionFilters), dataFilters)
}

private def rebuildPhysicalOperation(
@@ -72,7 +74,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
_,
_))
if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined =>
val partitionKeyFilters = getPartitionKeyFilters(
val (partitionKeyFilters, _) = getPartitionKeyFiltersAndDataFilters(
fsRelation.sparkSession, logicalRelation, partitionSchema, filters, logicalRelation.output)
if (partitionKeyFilters.nonEmpty) {
val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
@@ -92,11 +94,13 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
case op @ PhysicalOperation(projects, filters,
v2Relation @ DataSourceV2ScanRelation(_, scan: FileScan, output))
if filters.nonEmpty && scan.readDataSchema.nonEmpty =>
val partitionKeyFilters = getPartitionKeyFilters(scan.sparkSession,
v2Relation, scan.readPartitionSchema, filters, output)
if (partitionKeyFilters.nonEmpty) {
val (partitionKeyFilters, dataFilters) =
getPartitionKeyFiltersAndDataFilters(scan.sparkSession, v2Relation,
scan.readPartitionSchema, filters, output)
// The dataFilters are pushed down only once
if (partitionKeyFilters.nonEmpty || (dataFilters.nonEmpty && scan.dataFilters.isEmpty)) {
Review comment (Contributor Author):

The reason for the condition

(dataFilters.nonEmpty && scan.dataFilters.isEmpty)

is that, unlike the partition filters, which are pushed down and do not need to be reevaluated (so partitionKeyFilters.nonEmpty becomes false on the next rule invocation), the data filters remain non-empty above the scan. The scan.dataFilters.isEmpty check therefore ensures the rule does not keep rewriting the same scan, which would otherwise cause a stack overflow.
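
A minimal sketch of that fixed-point argument, using plain strings in place of Catalyst expressions; shouldPrune and its arguments are hypothetical stand-ins for the rule's guard, not Spark API:

object PushdownGuardSketch {
  // Mirrors the guard: fire when there are partition filters to push, or when
  // there are data filters that the scan has not absorbed yet.
  def shouldPrune(
      partitionKeyFilters: Seq[String],
      dataFilters: Seq[String],
      scanDataFilters: Seq[String]): Boolean = {
    partitionKeyFilters.nonEmpty || (dataFilters.nonEmpty && scanDataFilters.isEmpty)
  }

  def main(args: Array[String]): Unit = {
    // First invocation: the scan holds no data filters yet, so the rule fires.
    println(shouldPrune(Nil, Seq("value = 'a'"), Nil))                  // true
    // Later invocations: the Filter node still sits above the scan, but the scan
    // already carries the data filters, so the guard is false and the optimizer
    // reaches a fixed point instead of recursing forever.
    println(shouldPrune(Nil, Seq("value = 'a'"), Seq("value = 'a'")))   // false
  }
}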

val prunedV2Relation =
v2Relation.copy(scan = scan.withPartitionFilters(partitionKeyFilters.toSeq))
v2Relation.copy(scan = scan.withFilters(partitionKeyFilters.toSeq, dataFilters))
// The pushed down partition filters don't need to be reevaluated.
val afterScanFilters =
ExpressionSet(filters) -- partitionKeyFilters.filter(_.references.nonEmpty)
@@ -61,9 +61,15 @@ trait FileScan extends Scan with Batch with SupportsReportStatistics with Logging
def partitionFilters: Seq[Expression]

/**
* Create a new `FileScan` instance from the current one with different `partitionFilters`.
* Returns the data filters that can be used for file listing.
*/
def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan
def dataFilters: Seq[Expression]

/**
* Create a new `FileScan` instance from the current one
* with different `partitionFilters` and `dataFilters`.
*/
def withFilters(partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan

/**
* If a file with `path` is unsplittable, return the unsplittable reason,
@@ -79,7 +85,8 @@ trait FileScan extends Scan with Batch with SupportsReportStatistics with Logging
override def equals(obj: Any): Boolean = obj match {
case f: FileScan =>
fileIndex == f.fileIndex && readSchema == f.readSchema &&
ExpressionSet(partitionFilters) == ExpressionSet(f.partitionFilters)
ExpressionSet(partitionFilters) == ExpressionSet(f.partitionFilters) &&
ExpressionSet(dataFilters) == ExpressionSet(f.dataFilters)

case _ => false
}
@@ -92,6 +99,7 @@ trait FileScan extends Scan with Batch with SupportsReportStatistics with Logging
val metadata: Map[String, String] = Map(
"ReadSchema" -> readDataSchema.catalogString,
"PartitionFilters" -> seqToString(partitionFilters),
"DataFilters" -> seqToString(dataFilters),
"Location" -> locationDesc)
val metadataStr = metadata.toSeq.sorted.map {
case (key, value) =>
@@ -103,7 +111,7 @@ trait FileScan extends Scan with Batch with SupportsReportStatistics with Logging
}

protected def partitions: Seq[FilePartition] = {
val selectedPartitions = fileIndex.listFiles(partitionFilters, Seq.empty)
val selectedPartitions = fileIndex.listFiles(partitionFilters, dataFilters)
val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions)
val partitionAttributes = fileIndex.partitionSchema.toAttributes
val attributeMap = partitionAttributes.map(a => normalizeName(a.name) -> a).toMap
@@ -40,7 +40,8 @@ case class CSVScan(
readPartitionSchema: StructType,
options: CaseInsensitiveStringMap,
pushedFilters: Array[Filter],
partitionFilters: Seq[Expression] = Seq.empty)
partitionFilters: Seq[Expression] = Seq.empty,
dataFilters: Seq[Expression] = Seq.empty)
extends TextBasedFileScan(sparkSession, options) {

private lazy val parsedOptions: CSVOptions = new CSVOptions(
@@ -91,8 +92,9 @@ case class CSVScan(
dataSchema, readDataSchema, readPartitionSchema, parsedOptions, pushedFilters)
}

override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
this.copy(partitionFilters = partitionFilters)
override def withFilters(
partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)

override def equals(obj: Any): Boolean = obj match {
case c: CSVScan => super.equals(c) && dataSchema == c.dataSchema && options == c.options &&
@@ -39,7 +39,8 @@ case class JsonScan(
readDataSchema: StructType,
readPartitionSchema: StructType,
options: CaseInsensitiveStringMap,
partitionFilters: Seq[Expression] = Seq.empty)
partitionFilters: Seq[Expression] = Seq.empty,
dataFilters: Seq[Expression] = Seq.empty)
extends TextBasedFileScan(sparkSession, options) {

private val parsedOptions = new JSONOptionsInRead(
@@ -88,8 +89,9 @@ case class JsonScan(
dataSchema, readDataSchema, readPartitionSchema, parsedOptions)
}

override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
this.copy(partitionFilters = partitionFilters)
override def withFilters(
partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)

override def equals(obj: Any): Boolean = obj match {
case j: JsonScan => super.equals(j) && dataSchema == j.dataSchema && options == j.options
@@ -38,7 +38,8 @@ case class OrcScan(
readPartitionSchema: StructType,
options: CaseInsensitiveStringMap,
pushedFilters: Array[Filter],
partitionFilters: Seq[Expression] = Seq.empty) extends FileScan {
partitionFilters: Seq[Expression] = Seq.empty,
dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
override def isSplitable(path: Path): Boolean = true

override def createReaderFactory(): PartitionReaderFactory = {
@@ -64,6 +65,7 @@ case class OrcScan(
super.description() + ", PushedFilters: " + seqToString(pushedFilters)
}

override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
this.copy(partitionFilters = partitionFilters)
override def withFilters(
partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
}
@@ -41,7 +41,8 @@ case class ParquetScan(
readPartitionSchema: StructType,
pushedFilters: Array[Filter],
options: CaseInsensitiveStringMap,
partitionFilters: Seq[Expression] = Seq.empty) extends FileScan {
partitionFilters: Seq[Expression] = Seq.empty,
dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
override def isSplitable(path: Path): Boolean = true

override def createReaderFactory(): PartitionReaderFactory = {
@@ -92,6 +93,7 @@ case class ParquetScan(
super.description() + ", PushedFilters: " + seqToString(pushedFilters)
}

override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
this.copy(partitionFilters = partitionFilters)
override def withFilters(
partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
}
@@ -36,7 +36,8 @@ case class TextScan(
readDataSchema: StructType,
readPartitionSchema: StructType,
options: CaseInsensitiveStringMap,
partitionFilters: Seq[Expression] = Seq.empty)
partitionFilters: Seq[Expression] = Seq.empty,
dataFilters: Seq[Expression] = Seq.empty)
extends TextBasedFileScan(sparkSession, options) {

private val optionsAsScala = options.asScala.toMap
@@ -70,8 +71,9 @@ case class TextScan(
readPartitionSchema, textOptions)
}

override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
this.copy(partitionFilters = partitionFilters)
override def withFilters(
partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)

override def equals(obj: Any): Boolean = obj match {
case t: TextScan => super.equals(t) && options == t.options
@@ -775,6 +775,7 @@ class FileBasedDataSourceSuite extends QueryTest
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.nonEmpty)
assert(fileScan.get.dataFilters.nonEmpty)
assert(fileScan.get.planInputPartitions().forall { partition =>
partition.asInstanceOf[FilePartition].files.forall { file =>
file.filePath.contains("p1=1") && file.filePath.contains("p2=2")
@@ -786,6 +787,41 @@ class FileBasedDataSourceSuite extends QueryTest
}
}

test("File source v2: support passing data filters to FileScan without partitionFilters") {
withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
allFileBasedDataSources.foreach { format =>
withTempPath { dir =>
Seq(("a", 1, 2), ("b", 1, 2), ("c", 2, 1))
.toDF("value", "p1", "p2")
.write
.format(format)
.partitionBy("p1", "p2")
.option("header", true)
.save(dir.getCanonicalPath)
val df = spark
.read
.format(format)
.option("header", true)
.load(dir.getCanonicalPath)
.where("value = 'a'")

val filterCondition = df.queryExecution.optimizedPlan.collectFirst {
case f: Filter => f.condition
}
assert(filterCondition.isDefined)

val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: FileScan) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.isEmpty)
assert(fileScan.get.dataFilters.nonEmpty)
checkAnswer(df, Row("a", 1, 2))
}
}
}
}

test("File table location should include both values of option `path` and `paths`") {
withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
withTempPaths(3) { paths =>