diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 831ef62d74c3..0cb58fab47ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -327,6 +327,13 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val PARQUET_RECORD_FILTER_ENABLED = buildConf("spark.sql.parquet.recordLevelFilter.enabled")
+    .doc("If true, enables Parquet's native record-level filtering using the pushed down " +
+      "filters. This configuration only has an effect when 'spark.sql.parquet.filterPushdown' " +
+      "is enabled.")
+    .booleanConf
+    .createWithDefault(false)
+
   val PARQUET_OUTPUT_COMMITTER_CLASS = buildConf("spark.sql.parquet.output.committer.class")
     .doc("The output committer class used by Parquet. The specified class needs to be a " +
       "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " +
@@ -1173,6 +1180,8 @@ class SQLConf extends Serializable with Logging {
 
   def writeLegacyParquetFormat: Boolean = getConf(PARQUET_WRITE_LEGACY_FORMAT)
 
+  def parquetRecordFilterEnabled: Boolean = getConf(PARQUET_RECORD_FILTER_ENABLED)
+
   def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING)
 
   def columnNameOfCorruptRecord: String = getConf(COLUMN_NAME_OF_CORRUPT_RECORD)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index a48f8d517b6a..044b1a89d57c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -335,6 +335,8 @@ class ParquetFileFormat
     val enableVectorizedReader: Boolean =
       sparkSession.sessionState.conf.parquetVectorizedReaderEnabled &&
       resultSchema.forall(_.dataType.isInstanceOf[AtomicType])
+    val enableRecordFilter: Boolean =
+      sparkSession.sessionState.conf.parquetRecordFilterEnabled
     // Whole stage codegen (PhysicalRDD) is able to deal with batches directly
     val returningBatch = supportBatch(sparkSession, resultSchema)
 
@@ -374,13 +376,11 @@ class ParquetFileFormat
       } else {
        logDebug(s"Falling back to parquet-mr")
        // ParquetRecordReader returns UnsafeRow
-        val reader = pushed match {
-          case Some(filter) =>
-            new ParquetRecordReader[UnsafeRow](
-              new ParquetReadSupport,
-              FilterCompat.get(filter, null))
-          case _ =>
-            new ParquetRecordReader[UnsafeRow](new ParquetReadSupport)
+        val reader = if (pushed.isDefined && enableRecordFilter) {
+          val parquetFilter = FilterCompat.get(pushed.get, null)
+          new ParquetRecordReader[UnsafeRow](new ParquetReadSupport, parquetFilter)
+        } else {
+          new ParquetRecordReader[UnsafeRow](new ParquetReadSupport)
        }
        reader.initialize(split, hadoopAttemptContext)
        reader
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 90f6620d990c..33801954ebd5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -45,8 +45,29 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2}
  *
  * 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred
  *    data type is nullable.
+ *
+ * NOTE:
+ *
+ * This file explicitly enables record-level filtering in beforeEach(). If a new test case
+ * depends on this configuration, make sure to set the configuration explicitly within
+ * the test itself.
  */
 class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext {
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    // Note that many tests here require record-level filtering to be enabled.
+    spark.conf.set(SQLConf.PARQUET_RECORD_FILTER_ENABLED.key, "true")
+  }
+
+  override def afterEach(): Unit = {
+    try {
+      spark.conf.unset(SQLConf.PARQUET_RECORD_FILTER_ENABLED.key)
+    } finally {
+      super.afterEach()
+    }
+  }
+
   private def checkFilterPredicate(
       df: DataFrame,
       predicate: Predicate,
@@ -369,7 +390,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
 
   test("Filter applied on merged Parquet schema with new column should work") {
     import testImplicits._
-    Seq("true", "false").map { vectorized =>
+    Seq("true", "false").foreach { vectorized =>
       withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true",
         SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true",
         SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) {
@@ -491,7 +512,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
     }
   }
 
-  test("Fiters should be pushed down for vectorized Parquet reader at row group level") {
+  test("Filters should be pushed down for vectorized Parquet reader at row group level") {
     import testImplicits._
 
     withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true",
@@ -555,6 +576,32 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
       }
     }
   }
+
+  test("Filters should be pushed down for Parquet readers at row group level") {
+    import testImplicits._
+
+    withSQLConf(
+      // Makes sure disabling 'spark.sql.parquet.recordLevelFilter.enabled' still enables
+      // row group level filtering.
+      SQLConf.PARQUET_RECORD_FILTER_ENABLED.key -> "false",
+      SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true",
+      SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+      withTempPath { path =>
+        val data = (1 to 1024)
+        data.toDF("a").coalesce(1)
+          .write.option("parquet.block.size", 512)
+          .parquet(path.getAbsolutePath)
+        val df = spark.read.parquet(path.getAbsolutePath).filter("a == 500")
+        // Here, we strip the Spark side filter and check the actual results from Parquet.
+        val actual = stripSparkFilter(df).collect().length
+        // Since those are filtered at row group level, the result count should be less
+        // than the total length but should not be a single record.
+        // Note that, if record level filtering is enabled, it should be a single record.
+        // If no filter is pushed down to Parquet, it should be the total length of data.
+        assert(actual > 1 && actual < data.length)
+      }
+    }
+  }
 }
 
 class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] {
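
Following the NOTE added to the suite, a new test that depends on record-level filtering should set `SQLConf.PARQUET_RECORD_FILTER_ENABLED` itself rather than rely on the `beforeEach()` default. Below is a minimal sketch of such a test, intended to sit inside `ParquetFilterSuite`; the test name and data are illustrative and not part of the patch. It mirrors the new row-group test above, but with the flag enabled it expects a single surviving record, per the comment in that test.

```scala
  test("record-level filtering returns only the matching record (illustrative)") {
    import testImplicits._

    // Set the configuration explicitly inside the test, as the NOTE recommends,
    // so the test does not silently depend on the suite-wide beforeEach() default.
    withSQLConf(
      SQLConf.PARQUET_RECORD_FILTER_ENABLED.key -> "true",
      SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true",
      SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
      withTempPath { path =>
        // Small row groups, so row group pruning alone cannot reduce the result
        // to a single record; only record-level filtering can.
        (1 to 1024).toDF("a").coalesce(1)
          .write.option("parquet.block.size", 512)
          .parquet(path.getAbsolutePath)
        val df = spark.read.parquet(path.getAbsolutePath).filter("a == 500")
        // With record-level filtering enabled, Parquet itself should return exactly
        // the matching record even after the Spark-side filter is stripped.
        assert(stripSparkFilter(df).collect().length === 1)
      }
    }
  }
```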
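Outside the test suite, the feature is driven purely by configuration. The sketch below shows how a user might enable it, assuming a local SparkSession; the path, column name, and data are illustrative and not taken from the patch. Since the change applies to the parquet-mr fallback path, the vectorized reader is disabled for the example.

```scala
import org.apache.spark.sql.SparkSession

object RecordLevelFilterExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("parquet-record-level-filter")
      .getOrCreate()
    import spark.implicits._

    // Illustrative data: a single integer column written to a temporary path.
    val path = "/tmp/parquet-record-filter-example"
    (1 to 1024).toDF("a").write.mode("overwrite").parquet(path)

    // Row group level pruning is driven by 'spark.sql.parquet.filterPushdown'.
    spark.conf.set("spark.sql.parquet.filterPushdown", "true")
    // The new flag additionally lets parquet-mr drop non-matching records while
    // reading, instead of leaving all record-level filtering to Spark.
    spark.conf.set("spark.sql.parquet.recordLevelFilter.enabled", "true")
    // Record-level filtering only affects the parquet-mr reader, so the
    // vectorized reader is disabled here for illustration.
    spark.conf.set("spark.sql.parquet.enableVectorizedReader", "false")

    val filtered = spark.read.parquet(path).filter("a = 500")
    filtered.show()

    spark.stop()
  }
}
```

With the flag left at its default of false, pushed-down filters still prune whole row groups (as the new row-group test asserts); non-matching records inside surviving row groups are then filtered on the Spark side.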