diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala
index e1268ac2ce58..bb840e69d99a 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala
@@ -21,6 +21,7 @@ import scala.collection.JavaConverters._
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.connector.read.PartitionReaderFactory
 import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
 import org.apache.spark.sql.execution.datasources.v2.FileScan
@@ -34,19 +35,30 @@ case class AvroScan(
     dataSchema: StructType,
     readDataSchema: StructType,
     readPartitionSchema: StructType,
-    options: CaseInsensitiveStringMap)
-  extends FileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema) {
-  override def isSplitable(path: Path): Boolean = true
-
-  override def createReaderFactory(): PartitionReaderFactory = {
-    val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
-    // Hadoop Configurations are case sensitive.
-    val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
-    val broadcastedConf = sparkSession.sparkContext.broadcast(
-      new SerializableConfiguration(hadoopConf))
-    // The partition values are already truncated in `FileScan.partitions`.
-    // We should use `readPartitionSchema` as the partition schema here.
-    AvroPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
-      dataSchema, readDataSchema, readPartitionSchema, caseSensitiveMap)
-  }
+    options: CaseInsensitiveStringMap,
+    partitionFilters: Seq[Expression] = Seq.empty) extends FileScan {
+  override def isSplitable(path: Path): Boolean = true
+
+  override def createReaderFactory(): PartitionReaderFactory = {
+    val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
+    // Hadoop Configurations are case sensitive.
+    val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
+    val broadcastedConf = sparkSession.sparkContext.broadcast(
+      new SerializableConfiguration(hadoopConf))
+    // The partition values are already truncated in `FileScan.partitions`.
+    // We should use `readPartitionSchema` as the partition schema here.
+    AvroPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
+      dataSchema, readDataSchema, readPartitionSchema, caseSensitiveMap)
+  }
+
+  override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
+    this.copy(partitionFilters = partitionFilters)
+
+  override def equals(obj: Any): Boolean = obj match {
+    case a: AvroScan => super.equals(a) && dataSchema == a.dataSchema && options == a.options
+
+    case _ => false
+  }
+
+  override def hashCode(): Int = super.hashCode()
 }
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index dc60cfe41ca7..3f2744014c19 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -36,10 +36,15 @@ import org.apache.commons.io.FileUtils
 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql._
 import org.apache.spark.sql.TestingUDT.{IntervalData, NullData, NullUDT}
-import org.apache.spark.sql.execution.datasources.DataSource
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.plans.logical.Filter
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.datasources.{DataSource, FilePartition}
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.v2.avro.AvroScan
 import org.apache.spark.util.Utils
 
 abstract class AvroSuite extends QueryTest with SharedSparkSession {
@@ -1502,8 +1507,75 @@ class AvroV1Suite extends AvroSuite {
 }
 
 class AvroV2Suite extends AvroSuite {
+  import testImplicits._
+
   override protected def sparkConf: SparkConf = super
     .sparkConf
     .set(SQLConf.USE_V1_SOURCE_LIST, "")
+
+  test("Avro source v2: support partition pruning") {
+    withTempPath { dir =>
+      Seq(("a", 1, 2), ("b", 1, 2), ("c", 2, 1))
+        .toDF("value", "p1", "p2")
+        .write
+        .format("avro")
+        .partitionBy("p1", "p2")
+        .option("header", true)
+        .save(dir.getCanonicalPath)
+      val df = spark
+        .read
+        .format("avro")
+        .option("header", true)
+        .load(dir.getCanonicalPath)
+        .where("p1 = 1 and p2 = 2 and value != \"a\"")
+
+      val filterCondition = df.queryExecution.optimizedPlan.collectFirst {
+        case f: Filter => f.condition
+      }
+      assert(filterCondition.isDefined)
+      // The partition filters should be pushed down and do not need to be reevaluated.
+      assert(filterCondition.get.collectFirst {
+        case a: AttributeReference if a.name == "p1" || a.name == "p2" => a
+      }.isEmpty)
+
+      val fileScan = df.queryExecution.executedPlan collectFirst {
+        case BatchScanExec(_, f: AvroScan) => f
+      }
+      assert(fileScan.nonEmpty)
+      assert(fileScan.get.partitionFilters.nonEmpty)
+      assert(fileScan.get.planInputPartitions().forall { partition =>
+        partition.asInstanceOf[FilePartition].files.forall { file =>
+          file.filePath.contains("p1=1") && file.filePath.contains("p2=2")
+        }
+      })
+      checkAnswer(df, Row("b", 1, 2))
+    }
+  }
+
+  private def getBatchScanExec(plan: SparkPlan): BatchScanExec = {
+    plan.find(_.isInstanceOf[BatchScanExec]).get.asInstanceOf[BatchScanExec]
+  }
+
+  test("Avro source v2: same result with different orders of data filters and partition filters") {
+    withTempPath { path =>
+      val tmpDir = path.getCanonicalPath
+      spark
+        .range(10)
+        .selectExpr("id as a", "id + 1 as b", "id + 2 as c", "id + 3 as d")
+        .write
+        .partitionBy("a", "b")
+        .format("avro")
+        .save(tmpDir)
+      val df = spark.read.format("avro").load(tmpDir)
+      // partition filters: a > 1 AND b < 9
+      // data filters: c > 1 AND d < 9
+      val plan1 = df.where("a > 1 AND b < 9 AND c > 1 AND d < 9").queryExecution.sparkPlan
+      val plan2 = df.where("b < 9 AND a > 1 AND d < 9 AND c > 1").queryExecution.sparkPlan
+      assert(plan1.sameResult(plan2))
+      val scan1 = getBatchScanExec(plan1)
+      val scan2 = getBatchScanExec(plan2)
+      assert(scan1.sameResult(scan2))
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index e65faefad5b9..013d94768a2a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -37,7 +37,7 @@ class SparkOptimizer(
 
   override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] =
     // TODO: move SchemaPruning into catalyst
-    SchemaPruning :: PruneFileSourcePartitions :: V2ScanRelationPushDown :: Nil
+    SchemaPruning :: V2ScanRelationPushDown :: PruneFileSourcePartitions :: Nil
 
   override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+
     Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
index 02d629721327..7fd154ccac44 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
@@ -17,13 +17,46 @@
 
 package org.apache.spark.sql.execution.datasources
 
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.catalog.CatalogStatistics
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, FileScan, FileTable}
+import org.apache.spark.sql.types.StructType
 
 private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
+
+  private def getPartitionKeyFilters(
+      sparkSession: SparkSession,
+      relation: LeafNode,
+      partitionSchema: StructType,
+      filters: Seq[Expression],
+      output: Seq[AttributeReference]): ExpressionSet = {
+    val normalizedFilters = DataSourceStrategy.normalizeExprs(
+      filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), output)
+    val partitionColumns =
+      relation.resolve(partitionSchema, sparkSession.sessionState.analyzer.resolver)
+    val partitionSet = AttributeSet(partitionColumns)
+    ExpressionSet(normalizedFilters.filter { f =>
+      f.references.subsetOf(partitionSet)
+    })
+  }
+
+  private def rebuildPhysicalOperation(
+      projects: Seq[NamedExpression],
+      filters: Seq[Expression],
+      relation: LeafNode): Project = {
+    val withFilter = if (filters.nonEmpty) {
+      val filterExpression = filters.reduceLeft(And)
+      Filter(filterExpression, relation)
+    } else {
+      relation
+    }
+    Project(projects, withFilter)
+  }
+
   override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
     case op @ PhysicalOperation(projects, filters, logicalRelation @
@@ -39,31 +72,35 @@
           _,
           _))
         if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined =>
-      val normalizedFilters = DataSourceStrategy.normalizeExprs(
-        filters.filterNot(SubqueryExpression.hasSubquery), logicalRelation.output)
-
-      val sparkSession = fsRelation.sparkSession
-      val partitionColumns =
-        logicalRelation.resolve(
-          partitionSchema, sparkSession.sessionState.analyzer.resolver)
-      val partitionSet = AttributeSet(partitionColumns)
-      val partitionKeyFilters = ExpressionSet(normalizedFilters.filter { f =>
-        f.references.subsetOf(partitionSet)
-      })
-
+      val partitionKeyFilters = getPartitionKeyFilters(
+        fsRelation.sparkSession, logicalRelation, partitionSchema, filters, logicalRelation.output)
       if (partitionKeyFilters.nonEmpty) {
         val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
         val prunedFsRelation =
-          fsRelation.copy(location = prunedFileIndex)(sparkSession)
+          fsRelation.copy(location = prunedFileIndex)(fsRelation.sparkSession)
         // Change table stats based on the sizeInBytes of pruned files
         val withStats = logicalRelation.catalogTable.map(_.copy(
           stats = Some(CatalogStatistics(sizeInBytes = BigInt(prunedFileIndex.sizeInBytes)))))
         val prunedLogicalRelation = logicalRelation.copy(
           relation = prunedFsRelation, catalogTable = withStats)
         // Keep partition-pruning predicates so that they are visible in physical planning
-        val filterExpression = filters.reduceLeft(And)
-        val filter = Filter(filterExpression, prunedLogicalRelation)
-        Project(projects, filter)
+        rebuildPhysicalOperation(projects, filters, prunedLogicalRelation)
+      } else {
+        op
+      }
+
+    case op @ PhysicalOperation(projects, filters,
+        v2Relation @ DataSourceV2ScanRelation(_, scan: FileScan, output))
+      if filters.nonEmpty && scan.readDataSchema.nonEmpty =>
+      val partitionKeyFilters = getPartitionKeyFilters(scan.sparkSession,
+        v2Relation, scan.readPartitionSchema, filters, output)
+      if (partitionKeyFilters.nonEmpty) {
+        val prunedV2Relation =
+          v2Relation.copy(scan = scan.withPartitionFilters(partitionKeyFilters.toSeq))
+        // The pushed down partition filters don't need to be reevaluated.
+        val afterScanFilters =
+          ExpressionSet(filters) -- partitionKeyFilters.filter(_.references.nonEmpty)
+        rebuildPhysicalOperation(projects, afterScanFilters.toSeq, prunedV2Relation)
       } else {
         op
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
index 55104a2b21de..a22e1ccfe451 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.IO_WARNING_LARGEFILETHRESHOLD
 import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.connector.read.{Batch, InputPartition, Scan, Statistics, SupportsReportStatistics}
 import org.apache.spark.sql.execution.PartitionedFileUtil
@@ -32,13 +33,7 @@ import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
 
-abstract class FileScan(
-    sparkSession: SparkSession,
-    fileIndex: PartitioningAwareFileIndex,
-    readDataSchema: StructType,
-    readPartitionSchema: StructType)
-  extends Scan
-  with Batch with SupportsReportStatistics with Logging {
+trait FileScan extends Scan with Batch with SupportsReportStatistics with Logging {
   /**
    * Returns whether a file with `path` could be split or not.
    */
@@ -46,6 +41,30 @@ abstract class FileScan(
     false
   }
 
+  def sparkSession: SparkSession
+
+  def fileIndex: PartitioningAwareFileIndex
+
+  /**
+   * Returns the required data schema
+   */
+  def readDataSchema: StructType
+
+  /**
+   * Returns the required partition schema
+   */
+  def readPartitionSchema: StructType
+
+  /**
+   * Returns the filters that can be used for partition pruning
+   */
+  def partitionFilters: Seq[Expression]
+
+  /**
+   * Create a new `FileScan` instance from the current one with different `partitionFilters`.
+   */
+  def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan
+
   /**
    * If a file with `path` is unsplittable, return the unsplittable reason,
    * otherwise return `None`.
@@ -55,11 +74,24 @@ abstract class FileScan(
     "undefined"
   }
 
+  protected def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")
+
+  override def equals(obj: Any): Boolean = obj match {
+    case f: FileScan =>
+      fileIndex == f.fileIndex && readSchema == f.readSchema &&
+        ExpressionSet(partitionFilters) == ExpressionSet(f.partitionFilters)
+
+    case _ => false
+  }
+
+  override def hashCode(): Int = getClass.hashCode()
+
   override def description(): String = {
     val locationDesc =
       fileIndex.getClass.getSimpleName + fileIndex.rootPaths.mkString("[", ", ", "]")
     val metadata: Map[String, String] = Map(
       "ReadSchema" -> readDataSchema.catalogString,
+      "PartitionFilters" -> seqToString(partitionFilters),
       "Location" -> locationDesc)
     val metadataStr = metadata.toSeq.sorted.map {
       case (key, value) =>
@@ -71,7 +103,7 @@ abstract class FileScan(
   }
 
   protected def partitions: Seq[FilePartition] = {
-    val selectedPartitions = fileIndex.listFiles(Seq.empty, Seq.empty)
+    val selectedPartitions = fileIndex.listFiles(partitionFilters, Seq.empty)
     val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions)
     val partitionAttributes = fileIndex.partitionSchema.toAttributes
     val attributeMap = partitionAttributes.map(a => normalizeName(a.name) -> a).toMap
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TextBasedFileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TextBasedFileScan.scala
index 7ddd99a0293b..1ca3fd42c059 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TextBasedFileScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TextBasedFileScan.scala
@@ -29,11 +29,7 @@ import org.apache.spark.util.Utils
 
 abstract class TextBasedFileScan(
     sparkSession: SparkSession,
-    fileIndex: PartitioningAwareFileIndex,
-    readDataSchema: StructType,
-    readPartitionSchema: StructType,
-    options: CaseInsensitiveStringMap)
-  extends FileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema) {
+    options: CaseInsensitiveStringMap) extends FileScan {
   @transient private lazy val codecFactory: CompressionCodecFactory = new CompressionCodecFactory(
     sparkSession.sessionState.newHadoopConfWithOptions(options.asScala.toMap))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
index 5125de9313a4..78b04aa811e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
@@ -22,11 +22,11 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.catalyst.csv.CSVOptions
-import org.apache.spark.sql.catalyst.expressions.ExprUtils
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExprUtils}
 import org.apache.spark.sql.connector.read.PartitionReaderFactory
 import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
 import org.apache.spark.sql.execution.datasources.csv.CSVDataSource
-import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan
+import org.apache.spark.sql.execution.datasources.v2.{FileScan, TextBasedFileScan}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.SerializableConfiguration
@@ -37,8 +37,9 @@ case class CSVScan(
     dataSchema: StructType,
     readDataSchema: StructType,
     readPartitionSchema: StructType,
-    options: CaseInsensitiveStringMap)
-  extends TextBasedFileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema, options) {
+    options: CaseInsensitiveStringMap,
+    partitionFilters: Seq[Expression] = Seq.empty)
+  extends TextBasedFileScan(sparkSession, options) {
 
   private lazy val parsedOptions: CSVOptions = new CSVOptions(
     options.asScala.toMap,
@@ -87,4 +88,15 @@ case class CSVScan(
     CSVPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
       dataSchema, readDataSchema, readPartitionSchema, parsedOptions)
   }
+
+  override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
+    this.copy(partitionFilters = partitionFilters)
+
+  override def equals(obj: Any): Boolean = obj match {
+    case c: CSVScan => super.equals(c) && dataSchema == c.dataSchema && options == c.options
+
+    case _ => false
+  }
+
+  override def hashCode(): Int = super.hashCode()
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala
index a64b78d3c830..153b402476c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala
@@ -21,13 +21,13 @@ import scala.collection.JavaConverters._
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.sql.{AnalysisException, SparkSession}
-import org.apache.spark.sql.catalyst.expressions.ExprUtils
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExprUtils}
 import org.apache.spark.sql.catalyst.json.JSONOptionsInRead
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.connector.read.PartitionReaderFactory
 import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
 import org.apache.spark.sql.execution.datasources.json.JsonDataSource
-import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan
+import org.apache.spark.sql.execution.datasources.v2.{FileScan, TextBasedFileScan}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.SerializableConfiguration
@@ -38,8 +38,9 @@ case class JsonScan(
     dataSchema: StructType,
     readDataSchema: StructType,
     readPartitionSchema: StructType,
-    options: CaseInsensitiveStringMap)
-  extends TextBasedFileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema, options) {
+    options: CaseInsensitiveStringMap,
+    partitionFilters: Seq[Expression] = Seq.empty)
+  extends TextBasedFileScan(sparkSession, options) {
 
   private val parsedOptions = new JSONOptionsInRead(
     CaseInsensitiveMap(options.asScala.toMap),
@@ -86,4 +87,15 @@ case class JsonScan(
     JsonPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
       dataSchema, readDataSchema, readPartitionSchema, parsedOptions)
   }
+
+  override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
+    this.copy(partitionFilters = partitionFilters)
+
+  override def equals(obj: Any): Boolean = obj match {
+    case j: JsonScan => super.equals(j) && dataSchema == j.dataSchema && options == j.options
+
+    case _ => false
+  }
+
+  override def hashCode(): Int = super.hashCode()
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
index 40784516a6f3..f0595cb6d09c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
@@ -20,6 +20,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.connector.read.PartitionReaderFactory
 import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
 import org.apache.spark.sql.execution.datasources.v2.FileScan
@@ -36,8 +37,8 @@ case class OrcScan(
     readDataSchema: StructType,
     readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap,
-    pushedFilters: Array[Filter])
-  extends FileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema) {
+    pushedFilters: Array[Filter],
+    partitionFilters: Seq[Expression] = Seq.empty) extends FileScan {
   override def isSplitable(path: Path): Boolean = true
 
   override def createReaderFactory(): PartitionReaderFactory = {
@@ -51,15 +52,18 @@ case class OrcScan(
 
   override def equals(obj: Any): Boolean = obj match {
     case o: OrcScan =>
-      fileIndex == o.fileIndex && dataSchema == o.dataSchema &&
-        readDataSchema == o.readDataSchema && readPartitionSchema == o.readPartitionSchema &&
-        options == o.options && equivalentFilters(pushedFilters, o.pushedFilters)
+      super.equals(o) && dataSchema == o.dataSchema && options == o.options &&
+        equivalentFilters(pushedFilters, o.pushedFilters)
+
     case _ => false
   }
 
   override def hashCode(): Int = getClass.hashCode()
 
   override def description(): String = {
-    super.description() + ", PushedFilters: " + pushedFilters.mkString("[", ", ", "]")
+    super.description() + ", PushedFilters: " + seqToString(pushedFilters)
   }
+
+  override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
+    this.copy(partitionFilters = partitionFilters)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
index cf16a174d9e2..44179e2e42a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
@@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.parquet.hadoop.ParquetInputFormat
 
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.connector.read.PartitionReaderFactory
 import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetReadSupport, ParquetWriteSupport}
@@ -39,8 +40,8 @@ case class ParquetScan(
     readDataSchema: StructType,
     readPartitionSchema: StructType,
     pushedFilters: Array[Filter],
-    options: CaseInsensitiveStringMap)
-  extends FileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema) {
+    options: CaseInsensitiveStringMap,
+    partitionFilters: Seq[Expression] = Seq.empty) extends FileScan {
   override def isSplitable(path: Path): Boolean = true
 
   override def createReaderFactory(): PartitionReaderFactory = {
@@ -80,15 +81,17 @@ case class ParquetScan(
 
   override def equals(obj: Any): Boolean = obj match {
     case p: ParquetScan =>
-      fileIndex == p.fileIndex && dataSchema == p.dataSchema &&
-        readDataSchema == p.readDataSchema && readPartitionSchema == p.readPartitionSchema &&
-        options == p.options && equivalentFilters(pushedFilters, p.pushedFilters)
+      super.equals(p) && dataSchema == p.dataSchema && options == p.options &&
+        equivalentFilters(pushedFilters, p.pushedFilters)
     case _ => false
   }
 
   override def hashCode(): Int = getClass.hashCode()
 
   override def description(): String = {
-    super.description() + ", PushedFilters: " + pushedFilters.mkString("[", ", ", "]")
+    super.description() + ", PushedFilters: " + seqToString(pushedFilters)
   }
+
+  override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
+    this.copy(partitionFilters = partitionFilters)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala
index a2c42db59d7f..cf6595e5c126 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala
@@ -21,10 +21,11 @@ import scala.collection.JavaConverters._
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.connector.read.PartitionReaderFactory
 import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
 import org.apache.spark.sql.execution.datasources.text.TextOptions
-import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan
+import org.apache.spark.sql.execution.datasources.v2.{FileScan, TextBasedFileScan}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.SerializableConfiguration
@@ -34,8 +35,9 @@ case class TextScan(
     fileIndex: PartitioningAwareFileIndex,
     readDataSchema: StructType,
     readPartitionSchema: StructType,
-    options: CaseInsensitiveStringMap)
-  extends TextBasedFileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema, options) {
+    options: CaseInsensitiveStringMap,
+    partitionFilters: Seq[Expression] = Seq.empty)
+  extends TextBasedFileScan(sparkSession, options) {
 
   private val optionsAsScala = options.asScala.toMap
   private lazy val textOptions: TextOptions = new TextOptions(optionsAsScala)
@@ -67,4 +69,15 @@ case class TextScan(
     TextPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
       readDataSchema, readPartitionSchema, textOptions)
   }
+
+  override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan =
+    this.copy(partitionFilters = partitionFilters)
+
+  override def equals(obj: Any): Boolean = obj match {
+    case t: TextScan => super.equals(t) && options == t.options
+
+    case _ => false
+  }
+
+  override def hashCode(): Int = super.hashCode()
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index d4f76858af95..b8b27b52c67f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -28,8 +28,11 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.SparkException
 import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.sql.TestingUDT.{IntervalUDT, NullData, NullUDT}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
+import org.apache.spark.sql.catalyst.plans.logical.Filter
+import org.apache.spark.sql.execution.datasources.FilePartition
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanRelation, FileScan}
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetTable
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
@@ -726,6 +729,49 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSparkSession {
     }
   }
 
+  test("File source v2: support partition pruning") {
+    withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
+      allFileBasedDataSources.foreach { format =>
+        withTempPath { dir =>
+          Seq(("a", 1, 2), ("b", 1, 2), ("c", 2, 1))
+            .toDF("value", "p1", "p2")
+            .write
+            .format(format)
+            .partitionBy("p1", "p2")
+            .option("header", true)
+            .save(dir.getCanonicalPath)
+          val df = spark
+            .read
+            .format(format)
+            .option("header", true)
+            .load(dir.getCanonicalPath)
+            .where("p1 = 1 and p2 = 2 and value != \"a\"")
+
+          val filterCondition = df.queryExecution.optimizedPlan.collectFirst {
+            case f: Filter => f.condition
+          }
+          assert(filterCondition.isDefined)
+          // The partition filters should be pushed down and do not need to be reevaluated.
+          assert(filterCondition.get.collectFirst {
+            case a: AttributeReference if a.name == "p1" || a.name == "p2" => a
+          }.isEmpty)
+
+          val fileScan = df.queryExecution.executedPlan collectFirst {
+            case BatchScanExec(_, f: FileScan) => f
+          }
+          assert(fileScan.nonEmpty)
+          assert(fileScan.get.partitionFilters.nonEmpty)
+          assert(fileScan.get.planInputPartitions().forall { partition =>
+            partition.asInstanceOf[FilePartition].files.forall { file =>
+              file.filePath.contains("p1=1") && file.filePath.contains("p2=2")
+            }
+          })
+          checkAnswer(df, Row("b", 1, 2))
+        }
+      }
+    }
+  }
+
   test("File table location should include both values of option `path` and `paths`") {
     withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
       withTempPaths(3) { paths =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala
index 528c3474a17c..388744bd0fd6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala
@@ -119,14 +119,14 @@ abstract class OrcTest extends QueryTest with FileBasedDataSourceTest with Befor
 
     query.queryExecution.optimizedPlan match {
       case PhysicalOperation(_, filters,
-        DataSourceV2ScanRelation(_, OrcScan(_, _, _, _, _, _, _, pushedFilters), _)) =>
+        DataSourceV2ScanRelation(_, o: OrcScan, _)) =>
         assert(filters.nonEmpty, "No filter is analyzed from the given query")
         if (noneSupported) {
-          assert(pushedFilters.isEmpty, "Unsupported filters should not show in pushed filters")
+          assert(o.pushedFilters.isEmpty, "Unsupported filters should not show in pushed filters")
         } else {
-          assert(pushedFilters.nonEmpty, "No filter is pushed down")
-          val maybeFilter = OrcFilters.createFilter(query.schema, pushedFilters)
-          assert(maybeFilter.isEmpty, s"Couldn't generate filter predicate for $pushedFilters")
+          assert(o.pushedFilters.nonEmpty, "No filter is pushed down")
+          val maybeFilter = OrcFilters.createFilter(query.schema, o.pushedFilters)
+          assert(maybeFilter.isEmpty, s"Couldn't generate filter predicate for ${o.pushedFilters}")
         }
 
       case _ =>
diff --git a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
index d09236a93433..526ce5cb7085 100644
--- a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
+++ b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
@@ -53,12 +53,11 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession {
       .where(Column(predicate))
 
     query.queryExecution.optimizedPlan match {
-      case PhysicalOperation(_, filters,
-        DataSourceV2ScanRelation(_, OrcScan(_, _, _, _, _, _, _, pushedFilters), _)) =>
+      case PhysicalOperation(_, filters, DataSourceV2ScanRelation(_, o: OrcScan, _)) =>
         assert(filters.nonEmpty, "No filter is analyzed from the given query")
-        assert(pushedFilters.nonEmpty, "No filter is pushed down")
-        val maybeFilter = OrcFilters.createFilter(query.schema, pushedFilters)
-        assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pushedFilters")
+        assert(o.pushedFilters.nonEmpty, "No filter is pushed down")
+        val maybeFilter = OrcFilters.createFilter(query.schema, o.pushedFilters)
+        assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for ${o.pushedFilters}")
         checker(maybeFilter.get)
 
       case _ =>
diff --git a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
index b95a32ef85dd..f88fec7ed4d6 100644
--- a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
+++ b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
@@ -54,12 +54,11 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession {
       .where(Column(predicate))
 
     query.queryExecution.optimizedPlan match {
-      case PhysicalOperation(_, filters,
-        DataSourceV2ScanRelation(_, OrcScan(_, _, _, _, _, _, _, pushedFilters), _)) =>
+      case PhysicalOperation(_, filters, DataSourceV2ScanRelation(_, o: OrcScan, _)) =>
         assert(filters.nonEmpty, "No filter is analyzed from the given query")
-        assert(pushedFilters.nonEmpty, "No filter is pushed down")
-        val maybeFilter = OrcFilters.createFilter(query.schema, pushedFilters)
-        assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pushedFilters")
+        assert(o.pushedFilters.nonEmpty, "No filter is pushed down")
+        val maybeFilter = OrcFilters.createFilter(query.schema, o.pushedFilters)
+        assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for ${o.pushedFilters}")
         checker(maybeFilter.get)
 
       case _ =>
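Usage note (not part of the patch): a minimal sketch of how the new pruning surfaces to a user of the V2 file sources, mirroring the tests added above. It assumes spark.sql.sources.useV1SourceList is set to "" so the V2 path is planned; the output path and column names below are made up for the example.

    // Sketch only: with V2 file sources planned, a filter on a partition column
    // should end up in FileScan.partitionFilters, so FileScan.partitions lists
    // only the matching directories via fileIndex.listFiles(partitionFilters, ...).
    import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan}

    spark.range(10).selectExpr("id AS value", "id % 2 AS p").write
      .partitionBy("p").parquet("/tmp/pruning_demo")  // hypothetical path
    val df = spark.read.parquet("/tmp/pruning_demo").where("p = 1")
    val scan = df.queryExecution.executedPlan.collectFirst {
      case BatchScanExec(_, f: FileScan) => f
    }
    assert(scan.exists(_.partitionFilters.nonEmpty))  // pruning applied before file listing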