ParquetFilters.scala
@@ -18,7 +18,8 @@
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.parquet.filter2.predicate._
import org.apache.parquet.filter2.predicate.FilterApi._
import org.apache.parquet.filter2.predicate.Operators.{Column, SupportsEqNotEq, SupportsLtGt}
import org.apache.parquet.hadoop.metadata.ColumnPath
import org.apache.parquet.io.api.Binary

import org.apache.spark.sql.sources
@@ -29,6 +30,8 @@ import org.apache.spark.sql.types._
*/
private[parquet] object ParquetFilters {

  import ParquetColumns._

  private val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
    case BooleanType =>
      (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean])
@@ -235,3 +238,40 @@ private[parquet] object ParquetFilters {
    }
  }
}

/**
 * Note that this is a hacky workaround to allow dots in column names. The column APIs in
 * Parquet's `FilterApi` currently accept only dot-separated names, so here we mirror those
 * column factories but always build a single-segment column path, which may contain dots,
 * because we don't currently push down filters with nested fields. The functions in this
 * object are based on the code in `org.apache.parquet.filter2.predicate`.
 */
private[parquet] object ParquetColumns {
  def intColumn(columnPath: String): Column[Integer] with SupportsLtGt = {
    new Column[Integer](ColumnPath.get(columnPath), classOf[Integer]) with SupportsLtGt
  }

  def longColumn(columnPath: String): Column[java.lang.Long] with SupportsLtGt = {
    new Column[java.lang.Long](
      ColumnPath.get(columnPath), classOf[java.lang.Long]) with SupportsLtGt
  }

  def floatColumn(columnPath: String): Column[java.lang.Float] with SupportsLtGt = {
    new Column[java.lang.Float](
      ColumnPath.get(columnPath), classOf[java.lang.Float]) with SupportsLtGt
  }

  def doubleColumn(columnPath: String): Column[java.lang.Double] with SupportsLtGt = {
    new Column[java.lang.Double](
      ColumnPath.get(columnPath), classOf[java.lang.Double]) with SupportsLtGt
  }

  def booleanColumn(columnPath: String): Column[java.lang.Boolean] with SupportsEqNotEq = {
    new Column[java.lang.Boolean](
      ColumnPath.get(columnPath), classOf[java.lang.Boolean]) with SupportsEqNotEq
  }

  def binaryColumn(columnPath: String): Column[Binary] with SupportsLtGt = {
    new Column[Binary](ColumnPath.get(columnPath), classOf[Binary]) with SupportsLtGt
  }
}
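
As a side note, here is a minimal sketch (not part of this patch, and assuming the standard parquet-column APIs) of the behaviour the object above works around: `FilterApi`'s column factories parse their argument as a dot-separated path, while `ParquetColumns` keeps the whole name as a single path segment.

import org.apache.parquet.filter2.predicate.FilterApi
import org.apache.parquet.hadoop.metadata.ColumnPath

// FilterApi splits on dots: "col.dots" becomes the two-segment path
// [col, dots], i.e. a reference to a nested field.
val viaFilterApi = FilterApi.intColumn("col.dots")
assert(viaFilterApi.getColumnPath == ColumnPath.fromDotString("col.dots"))

// ParquetColumns keeps the single-segment path [col.dots], i.e. one
// top-level column whose name contains a dot.
val viaParquetColumns = ParquetColumns.intColumn("col.dots")
assert(viaParquetColumns.getColumnPath == ColumnPath.get("col.dots"))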
FileSourceStrategySuite.scala
@@ -487,6 +487,20 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper
    }
  }

test("no filter pushdown for nested field access") {
val table = createTable(
files = Seq("file1" -> 1),
format = classOf[TestFileFormatWithNestedSchema].getName)

checkScan(table.where("a1 = 1"))(_ => ())
// Check `a1` access pushes the predicate.
checkDataFilters(Set(IsNotNull("a1"), EqualTo("a1", 1)))

checkScan(table.where("a2.c1 = 1"))(_ => ())
// Check `a2.c1` access does not push the predicate.
checkDataFilters(Set(IsNotNull("a2")))
}

  // Helpers for checking the arguments passed to the FileFormat.

  protected val checkPartitionSchema =
@@ -537,7 +551,8 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper
*/
  def createTable(
      files: Seq[(String, Int)],
      buckets: Int = 0): DataFrame = {
      buckets: Int = 0,
      format: String = classOf[TestFileFormat].getName): DataFrame = {
    val tempDir = Utils.createTempDir()
    files.foreach {
      case (name, size) =>
@@ -547,7 +562,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper
    }

    val df = spark.read
      .format(classOf[TestFileFormat].getName)
      .format(format)
      .load(tempDir.getCanonicalPath)

    if (buckets > 0) {
@@ -632,6 +647,22 @@ class TestFileFormat extends TextBasedFileFormat {
  }
}

/**
 * A test [[FileFormat]] that records the arguments passed to buildReader and returns nothing.
 * Unlike [[TestFileFormat]] above, it infers a nested schema.
 */
class TestFileFormatWithNestedSchema extends TestFileFormat {
  override def inferSchema(
      sparkSession: SparkSession,
      options: Map[String, String],
      files: Seq[FileStatus]): Option[StructType] =
    Some(StructType(Nil)
      .add("a1", IntegerType)
      .add("a2",
        StructType(Nil)
          .add("c1", IntegerType)
          .add("c2", IntegerType)))
}
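
For context, a small sketch (not part of the patch) of how Spark SQL distinguishes the two meanings of a dot in a column reference; `df` is a hypothetical DataFrame with the nested schema above, and `dotsDf` is a hypothetical DataFrame like those in the ParquetFilterSuite test below:

// Unquoted dot: nested field access into the struct column a2.
df.where("a2.c1 = 1")

// Backticks: a single top-level column literally named "col.dots".
dotsDf.where("`col.dots` = 1")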

class LocalityTestFileSystem extends RawLocalFileSystem {
  private val invocations = new AtomicInteger(0)
ParquetFilterSuite.scala
@@ -20,14 +20,15 @@ package org.apache.spark.sql.execution.datasources.parquet
import java.nio.charset.StandardCharsets

import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators}
import org.apache.parquet.filter2.predicate.FilterApi._
import org.apache.parquet.filter2.predicate.FilterApi.{and, gt, lt}
import org.apache.parquet.filter2.predicate.Operators.{Column => _, _}

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.execution.datasources.parquet.ParquetColumns.{doubleColumn, intColumn}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
@@ -538,6 +539,49 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext
      // scalastyle:on nonascii
    }
  }

test("SPARK-20364: Predicate pushdown for columns with a '.' in them") {
Contributor:
Is there another existing test that checks that pushdown for the struct.field1 syntax still works correctly? I'm not sure how to reference inner fields of a struct column, as I don't use them much personally, but I want to make sure that's not broken as a result of this change.

Member (author):
To my knowledge, we don't push down filters with nested columns. Let me check whether we already have the negative case explicitly, and add it if it is missing.

Contributor:
Thanks for the change -- I wasn't sure whether predicate pushdown worked on nested columns, and the new test confirms that it does not.

    import testImplicits._

    Seq(true, false).foreach { vectorized =>
      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized.toString) {
        val dfs = Seq(
          Seq(Some(1), None).toDF("col.dots"),
          Seq(Some(1L), None).toDF("col.dots"),
          Seq(Some(1.0F), None).toDF("col.dots"),
          Seq(Some(1.0D), None).toDF("col.dots"),
          Seq(true, false).toDF("col.dots"),
          Seq("apple", null).toDF("col.dots")
        )

        val predicates = Seq(
          "`col.dots` > 0",
          "`col.dots` >= 1L",
          "`col.dots` < 2.0",
          "`col.dots` <= 1.0D",
          "`col.dots` == true",
          "`col.dots` IS NOT NULL"
        )

        dfs.zip(predicates).foreach { case (df, predicate) =>
          withTempPath { path =>
            df.write.parquet(path.getAbsolutePath)
            assert(spark.read.parquet(path.getAbsolutePath).where(predicate).count() == 1)
          }
        }
      }
    }

    withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> false.toString) {
      withTempPath { path =>
        Seq("apple", null).toDF("col.dots").write.parquet(path.getAbsolutePath)
        // This checks record-by-record filtering in Parquet's filter2. The vectorized
        // reader is disabled above because it does not apply row-level filters, and
        // stripSparkFilter removes Spark's own Filter operator so that only rows
        // surviving the pushed-down Parquet filter are counted.
        val num = stripSparkFilter(
          spark.read.parquet(path.getAbsolutePath).where("`col.dots` IS NULL")).count()
        assert(num == 1)
      }
    }
  }
}

class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] {