diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index a6a6cef5861f3..d74c4c9e43fe0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -18,7 +18,8 @@
 package org.apache.spark.sql.execution.datasources.parquet
 
 import org.apache.parquet.filter2.predicate._
-import org.apache.parquet.filter2.predicate.FilterApi._
+import org.apache.parquet.filter2.predicate.Operators.{Column, SupportsEqNotEq, SupportsLtGt}
+import org.apache.parquet.hadoop.metadata.ColumnPath
 import org.apache.parquet.io.api.Binary
 
 import org.apache.spark.sql.sources
@@ -29,6 +30,8 @@ import org.apache.spark.sql.types._
  * Some utility function to convert Spark data source filters to Parquet filters.
  */
 private[parquet] object ParquetFilters {
+  import ParquetColumns._
+
   private val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
     case BooleanType =>
       (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean])
@@ -235,3 +238,40 @@ private[parquet] object ParquetFilters {
     }
   }
 }
+
+/**
+ * Note that this is a hacky workaround to allow dots in column names. The column APIs in
+ * Parquet's `FilterApi` treat dot-separated names as nested field references, so here we mimic
+ * those column builders but build a single-element column path that may contain dots, since we
+ * do not currently push down filters on nested fields. The functions in this object are based
+ * on the code in `org.apache.parquet.filter2.predicate`.
+ */
+private[parquet] object ParquetColumns {
+  def intColumn(columnPath: String): Column[Integer] with SupportsLtGt = {
+    new Column[Integer] (ColumnPath.get(columnPath), classOf[Integer]) with SupportsLtGt
+  }
+
+  def longColumn(columnPath: String): Column[java.lang.Long] with SupportsLtGt = {
+    new Column[java.lang.Long] (
+      ColumnPath.get(columnPath), classOf[java.lang.Long]) with SupportsLtGt
+  }
+
+  def floatColumn(columnPath: String): Column[java.lang.Float] with SupportsLtGt = {
+    new Column[java.lang.Float] (
+      ColumnPath.get(columnPath), classOf[java.lang.Float]) with SupportsLtGt
+  }
+
+  def doubleColumn(columnPath: String): Column[java.lang.Double] with SupportsLtGt = {
+    new Column[java.lang.Double] (
+      ColumnPath.get(columnPath), classOf[java.lang.Double]) with SupportsLtGt
+  }
+
+  def booleanColumn(columnPath: String): Column[java.lang.Boolean] with SupportsEqNotEq = {
+    new Column[java.lang.Boolean] (
+      ColumnPath.get(columnPath), classOf[java.lang.Boolean]) with SupportsEqNotEq
+  }
+
+  def binaryColumn(columnPath: String): Column[Binary] with SupportsLtGt = {
+    new Column[Binary] (ColumnPath.get(columnPath), classOf[Binary]) with SupportsLtGt
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
index fa3c69612704d..e8a5950279b6c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
@@ -487,6 +487,20 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
     }
   }
 
+  test("no filter pushdown for nested field access") {
+    val table = createTable(
+      files = Seq("file1" -> 1),
+      format = classOf[TestFileFormatWithNestedSchema].getName)
+
+    checkScan(table.where("a1 = 1"))(_ => ())
+    // Check `a1` access pushes the predicate.
+    checkDataFilters(Set(IsNotNull("a1"), EqualTo("a1", 1)))
+
+    checkScan(table.where("a2.c1 = 1"))(_ => ())
+    // Check `a2.c1` access does not push the predicate.
+    checkDataFilters(Set(IsNotNull("a2")))
+  }
+
   // Helpers for checking the arguments passed to the FileFormat.
 
   protected val checkPartitionSchema =
@@ -537,7 +551,8 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
    */
   def createTable(
       files: Seq[(String, Int)],
-      buckets: Int = 0): DataFrame = {
+      buckets: Int = 0,
+      format: String = classOf[TestFileFormat].getName): DataFrame = {
     val tempDir = Utils.createTempDir()
     files.foreach {
       case (name, size) =>
@@ -547,7 +562,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
     }
 
     val df = spark.read
-      .format(classOf[TestFileFormat].getName)
+      .format(format)
       .load(tempDir.getCanonicalPath)
 
     if (buckets > 0) {
@@ -632,6 +647,22 @@ class TestFileFormat extends TextBasedFileFormat {
   }
 }
 
+/**
+ * A test [[FileFormat]] that records the arguments passed to buildReader, and returns nothing.
+ * Unlike the one above, this one has a nested schema.
+ */
+class TestFileFormatWithNestedSchema extends TestFileFormat {
+  override def inferSchema(
+      sparkSession: SparkSession,
+      options: Map[String, String],
+      files: Seq[FileStatus]): Option[StructType] =
+    Some(StructType(Nil)
+      .add("a1", IntegerType)
+      .add("a2",
+        StructType(Nil)
+          .add("c1", IntegerType)
+          .add("c2", IntegerType)))
+}
 
 class LocalityTestFileSystem extends RawLocalFileSystem {
   private val invocations = new AtomicInteger(0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index dd53b561326f3..eb0e43b9d1a48 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet
 import java.nio.charset.StandardCharsets
 
 import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators}
-import org.apache.parquet.filter2.predicate.FilterApi._
+import org.apache.parquet.filter2.predicate.FilterApi.{and, gt, lt}
 import org.apache.parquet.filter2.predicate.Operators.{Column => _, _}
 
 import org.apache.spark.sql._
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.parquet.ParquetColumns.{doubleColumn, intColumn}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -538,6 +539,49 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
       // scalastyle:on nonascii
     }
   }
+
+  test("SPARK-20364: Predicate pushdown for columns with a '.' in them") {
+    import testImplicits._
+
+    Seq(true, false).foreach { vectorized =>
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized.toString) {
+        val dfs = Seq(
+          Seq(Some(1), None).toDF("col.dots"),
+          Seq(Some(1L), None).toDF("col.dots"),
+          Seq(Some(1.0F), None).toDF("col.dots"),
+          Seq(Some(1.0D), None).toDF("col.dots"),
+          Seq(true, false).toDF("col.dots"),
+          Seq("apple", null).toDF("col.dots")
+        )
+
+        val predicates = Seq(
+          "`col.dots` > 0",
+          "`col.dots` >= 1L",
+          "`col.dots` < 2.0",
+          "`col.dots` <= 1.0D",
+          "`col.dots` == true",
+          "`col.dots` IS NOT NULL"
+        )
+
+        dfs.zip(predicates).foreach { case (df, predicate) =>
+          withTempPath { path =>
+            df.write.parquet(path.getAbsolutePath)
+            assert(spark.read.parquet(path.getAbsolutePath).where(predicate).count() == 1)
+          }
+        }
+      }
+    }
+
+    withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> false.toString) {
+      withTempPath { path =>
+        Seq("apple", null).toDF("col.dots").write.parquet(path.getAbsolutePath)
+        // This checks record-by-record filtering in Parquet's filter2.
+        val num = stripSparkFilter(
+          spark.read.parquet(path.getAbsolutePath).where("`col.dots` IS NULL")).count()
+        assert(num == 1)
+      }
+    }
+  }
 }
 
 class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] {