diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala
index cac2d6e62612..7caa2a4ad04f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala
@@ -56,6 +56,11 @@ class InMemoryFileIndex(
   extends PartitioningAwareFileIndex(
     sparkSession, parameters, userSpecifiedSchema, fileStatusCache) {
 
+  assert(userSpecifiedPartitionSpec.isEmpty ||
+    userSpecifiedPartitionSpec.get.partitions.map(_.path).equals(rootPathsSpecified),
+    s"The rootPathsSpecified ($rootPathsSpecified) is inconsistent with the file paths " +
+      s"of userSpecifiedPartitionSpec (${userSpecifiedPartitionSpec.get.partitions.map(_.path)}).")
+
   // Filter out streaming metadata dirs or files such as "/.../_spark_metadata" (the metadata dir)
   // or "/.../_spark_metadata/0" (a file in the metadata dir). `rootPathsSpecified` might contain
   // such streaming metadata dir or files, e.g. when after globbing "basePath/*" where "basePath"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
index 2e09c729529a..6546ec261868 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.{expressions, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
-import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.sql.types.StructType
 
 /**
  * An abstract class that represents [[FileIndex]]s that are aware of partitioned tables.
@@ -102,6 +102,10 @@ abstract class PartitioningAwareFileIndex(
 
   override def sizeInBytes: Long = allFiles().map(_.getLen).sum
 
+  def sizeInBytesOfPartitions(partitions: Seq[PartitionDirectory]): Long = {
+    partitions.flatMap(_.files).map(_.getLen).sum
+  }
+
   def allFiles(): Seq[FileStatus] = {
     val files = if (partitionSpec().partitionColumns.isEmpty && !recursiveFileLookup) {
       // For each of the root input paths, get the list of files inside them
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
index 6e05aa56f4f7..96eedeac4da1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
@@ -159,11 +159,20 @@ trait FileScan extends Scan with Batch with SupportsReportStatistics with Loggin
     partitions.toArray
   }
 
+  private def getSizeInBytes(): Long = {
+    if (fileIndex.partitionSpec().partitionColumns.isEmpty || partitionFilters.isEmpty) {
+      fileIndex.sizeInBytes
+    } else {
+      fileIndex.sizeInBytesOfPartitions(fileIndex.listFiles(partitionFilters, dataFilters))
+    }
+  }
+
   override def estimateStatistics(): Statistics = {
     new Statistics {
       override def sizeInBytes(): OptionalLong = {
         val compressionFactor = sparkSession.sessionState.conf.fileCompressionFactor
-        val size = (compressionFactor * fileIndex.sizeInBytes).toLong
+        val sizeInBytes = getSizeInBytes()
+        val size = (compressionFactor * sizeInBytes).toLong
         OptionalLong.of(size)
       }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index c87095812848..841e40ab6337 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -787,6 +787,42 @@ class FileBasedDataSourceSuite extends QueryTest
     }
   }
 
+  test("File source v2: involve partition filters in statistic estimation") {
+    withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
+      allFileBasedDataSources.foreach { format =>
+        withTempPath { dir =>
+          Seq(("a", 1, 2), ("b", 1, 2), ("c", 2, 1))
+            .toDF("value", "p1", "p2")
+            .write
+            .format(format)
+            .partitionBy("p1", "p2")
+            .option("header", true)
+            .save(dir.getCanonicalPath)
+          val df1 = spark
+            .read
+            .format(format)
+            .option("header", true)
+            .load(dir.getCanonicalPath)
+            .where("p1 = 1 and p2 = 2")
+          val df2 = spark
+            .read
+            .format(format)
+            .option("header", true)
+            .load(dir.getCanonicalPath)
+            .where("p1 = 2 and p2 = 1")
+          val fileScan1 = df1.queryExecution.executedPlan collectFirst {
+            case BatchScanExec(_, f: FileScan) => f
+          }
+          val fileScan2 = df2.queryExecution.executedPlan collectFirst {
+            case BatchScanExec(_, f: FileScan) => f
+          }
+          assert(fileScan1.get.estimateStatistics().sizeInBytes().getAsLong / 2 ===
+            fileScan2.get.estimateStatistics().sizeInBytes().getAsLong)
+        }
+      }
+    }
+  }
+
   test("File source v2: support passing data filters to FileScan without partitionFilters") {
     withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
       allFileBasedDataSources.foreach { format =>