diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 1a8d0e310aec..346663616192 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -187,7 +187,23 @@ case class FileSourceScanExec(
     "InputPaths" -> relation.location.paths.mkString(", "))
 
   private lazy val inputRDD: RDD[InternalRow] = {
-    val selectedPartitions = relation.location.listFiles(partitionFilters)
+    val originalPartitions = relation.location.listFiles(partitionFilters)
+    val filteredPartitions = if (relation.location.paths.isEmpty) {
+      originalPartitions
+    } else {
+      relation.fileFormat.filterPartitions(
+        dataFilters,
+        outputSchema,
+        relation.sparkSession.sparkContext.hadoopConfiguration,
+        relation.location.allFiles(),
+        relation.location.paths.head,
+        originalPartitions)
+    }
+    val totalFilesRaw = originalPartitions.map(_.files.size).sum
+    val totalFilesFiltered = filteredPartitions.map(_.files.size).sum
+    logInfo(s"Filtered down total number of partitions to ${filteredPartitions.size}" +
+      s" from ${originalPartitions.size}, " +
+      s"total number of files to ${totalFilesFiltered} from ${totalFilesRaw}")
 
     val readFile: (PartitionedFile) => Iterator[InternalRow] =
       relation.fileFormat.buildReaderWithPartitionValues(
@@ -201,9 +217,9 @@ case class FileSourceScanExec(
 
     relation.bucketSpec match {
       case Some(bucketing) if relation.sparkSession.sessionState.conf.bucketingEnabled =>
-        createBucketedReadRDD(bucketing, readFile, selectedPartitions, relation)
+        createBucketedReadRDD(bucketing, readFile, filteredPartitions, relation)
       case _ =>
-        createNonBucketedReadRDD(readFile, selectedPartitions, relation)
+        createNonBucketedReadRDD(readFile, filteredPartitions, relation)
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
index e03a2323c749..fe17080be7e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
@@ -219,6 +219,21 @@ trait FileFormat {
     false
   }
 
+  /**
+   * Allows a FileFormat to provide a pluggable way of using pushed-down filters to eliminate
+   * partitions before execution. By default no pruning is performed and the original
+   * partitioning is preserved.
+   */
+  def filterPartitions(
+      filters: Seq[Filter],
+      schema: StructType,
+      conf: Configuration,
+      allFiles: Seq[FileStatus],
+      root: Path,
+      partitions: Seq[Partition]): Seq[Partition] = {
+    partitions
+  }
+
   /**
    * Returns whether a file with `path` could be splitted or not.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index 9c4778acf53d..ecca2bdee386 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -27,12 +27,14 @@ import scala.util.{Failure, Try}
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.mapreduce._
-import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit}
+import org.apache.hadoop.mapreduce.lib.input.FileSplit
 import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
 import org.apache.parquet.{Log => ApacheParquetLog}
-import org.apache.parquet.filter2.compat.FilterCompat
+import org.apache.parquet.filter2.compat.{FilterCompat, RowGroupFilter}
 import org.apache.parquet.filter2.predicate.FilterApi
+import org.apache.parquet.format.converter.ParquetMetadataConverter
 import org.apache.parquet.hadoop._
+import org.apache.parquet.hadoop.metadata.ParquetMetadata
 import org.apache.parquet.hadoop.util.ContextUtil
 import org.apache.parquet.schema.MessageType
 import org.slf4j.bridge.SLF4JBridgeHandler
@@ -58,6 +60,9 @@ class ParquetFileFormat
   with Logging
   with Serializable {
 
+  // Cache for the Parquet summary metadata so it is only read once per format instance
+  @transient @volatile private var cachedMetadata: ParquetMetadata = _
+
   override def shortName(): String = "parquet"
 
   override def toString: String = "ParquetFormat"
@@ -423,6 +428,64 @@ class ParquetFileFormat
       sqlContext.sessionState.newHadoopConf(),
       options)
   }
+
+  override def filterPartitions(
+      filters: Seq[Filter],
+      schema: StructType,
+      conf: Configuration,
+      allFiles: Seq[FileStatus],
+      root: Path,
+      partitions: Seq[Partition]): Seq[Partition] = {
+    // Read the "_metadata" summary file if available; it contains the block headers of all
+    // data files. On S3 it is better to grab all of the footers in one batch rather than
+    // having to open every single file just to read its footer.
+    allFiles.find(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE).map { stat =>
+      val metadata = getOrReadMetadata(conf, stat)
+      partitions.map { part =>
+        filterByMetadata(
+          filters,
+          schema,
+          conf,
+          root,
+          metadata,
+          part)
+      }.filterNot(_.files.isEmpty)
+    }.getOrElse(partitions)
+  }
+
+  private def filterByMetadata(
+      filters: Seq[Filter],
+      schema: StructType,
+      conf: Configuration,
+      root: Path,
+      metadata: ParquetMetadata,
+      partition: Partition): Partition = {
+    val blockMetadatas = metadata.getBlocks.asScala
+    val parquetSchema = metadata.getFileMetaData.getSchema
+    val conjunctiveFilter = filters
+      .flatMap(ParquetFilters.createFilter(schema, _))
+      .reduceOption(FilterApi.and)
+    conjunctiveFilter.map { conjunction =>
+      val filteredBlocks = RowGroupFilter.filterRowGroups(
+        FilterCompat.get(conjunction), blockMetadatas.asJava, parquetSchema).asScala.map { bmd =>
+          new Path(root, bmd.getPath).toString
+        }
+      Partition(partition.values, partition.files.filter { f =>
+        filteredBlocks.contains(f.getPath.toString)
+      })
+    }.getOrElse(partition)
+  }
+
+  private def getOrReadMetadata(conf: Configuration, stat: FileStatus): ParquetMetadata = {
+    if (cachedMetadata == null) {
+      logInfo("Reading summary metadata into cache in ParquetFileFormat")
+      cachedMetadata = ParquetFileReader.readFooter(conf, stat, ParquetMetadataConverter.NO_FILTER)
+    } else {
+      logInfo("Using cached summary metadata")
+    }
+    cachedMetadata
+  }
+
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index 9dd8d9f80496..2621c12655d1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -703,6 +703,16 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
       }
     }
   }
+
+  test("SPARK-17059: Allow FileFormat to specify partition pruning strategy") {
+    withSQLConf(ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true") {
+      withTempPath { path =>
+        Seq(1, 2, 3).toDF("x").write.parquet(path.getCanonicalPath)
+        val df = spark.read.parquet(path.getCanonicalPath).where("x = 0")
+        assert(df.rdd.partitions.length == 0)
+      }
+    }
+  }
 }
 
 object TestingUDT {
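To show how the new `filterPartitions` hook is meant to compose, here is a minimal sketch (not part of the patch) of a format that reuses the Parquet row-group pruning added above and then layers one extra rule on top. The class name `ExampleFilteringParquetFormat` and the zero-length-file rule are illustrative assumptions, and the sketch assumes it is compiled inside the `org.apache.spark.sql.execution.datasources.parquet` package so the internal `Partition` type used by the hook is accessible.

```scala
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}

import org.apache.spark.sql.execution.datasources.Partition
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType

// Hypothetical example only: a FileFormat that reuses the metadata-based pruning from
// ParquetFileFormat and then applies an additional, format-specific rule of its own.
class ExampleFilteringParquetFormat extends ParquetFileFormat {

  override def filterPartitions(
      filters: Seq[Filter],
      schema: StructType,
      conf: Configuration,
      allFiles: Seq[FileStatus],
      root: Path,
      partitions: Seq[Partition]): Seq[Partition] = {
    // First let ParquetFileFormat prune row groups using the "_metadata" summary file.
    val pruned = super.filterPartitions(filters, schema, conf, allFiles, root, partitions)
    // Then drop zero-length files and discard any partition left with no files at all.
    pruned
      .map(p => Partition(p.values, p.files.filter(_.getLen > 0)))
      .filterNot(_.files.isEmpty)
  }
}
```

Because the hook runs on the driver before input splits are created, an implementation only needs file-level metadata (here just `FileStatus.getLen`); it never touches row data.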
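For exercising the change outside the test suite, the following sketch mirrors the new ParquetQuerySuite case as a standalone driver program. The output path, app name, and master are illustrative; the program assumes the filter is selective enough that every row group's statistics rule it out, so the scan should produce an RDD with zero partitions.

```scala
import org.apache.parquet.hadoop.ParquetOutputFormat

import org.apache.spark.sql.SparkSession

object RowGroupPruningDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")                      // illustrative; any master works
      .appName("row-group-pruning-demo")
      .getOrCreate()
    import spark.implicits._

    // Ask parquet-mr to write the "_metadata" summary file; filterPartitions() reads it on
    // the driver instead of opening every data file's footer.
    spark.conf.set(ParquetOutputFormat.ENABLE_JOB_SUMMARY, "true")

    val path = "/tmp/row-group-pruning-demo"   // illustrative output location
    Seq(1, 2, 3).toDF("x").write.mode("overwrite").parquet(path)

    // "x = 0" matches no row-group statistics, so every block is pruned before execution.
    val pruned = spark.read.parquet(path).where("x = 0")
    println(s"Partitions after pruning: ${pruned.rdd.partitions.length}")   // expected: 0

    spark.stop()
  }
}
```

On a build without this patch the same program should still return an empty result, but the scan keeps one partition per input split; the log line added to `FileSourceScanExec` makes the difference easy to spot.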