diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index 5a5cb5cf03d4..5bade7edacab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources.parquet
 
 import java.io.Serializable
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.parquet.filter2.predicate._
 import org.apache.parquet.filter2.predicate.FilterApi._
 import org.apache.parquet.io.api.Binary
@@ -222,10 +224,26 @@ private[sql] object ParquetFilters {
     case _ => Array.empty[(String, DataType)]
   }
 
+  /**
+   * Returns referenced columns in [[sources.Filter]].
+   */
+  def referencedColumns(schema: StructType, predicate: sources.Filter): Array[String] = {
+    val referencedCols = ArrayBuffer.empty[String]
+    createParquetFilter(schema, predicate, referencedCols)
+    referencedCols.distinct.toArray
+  }
+
   /**
    * Converts data sources filters to Parquet filter predicates.
    */
   def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = {
+    createParquetFilter(schema, predicate)
+  }
+
+  private def createParquetFilter(
+      schema: StructType,
+      predicate: sources.Filter,
+      referencedCols: ArrayBuffer[String] = ArrayBuffer.empty[String]): Option[FilterPredicate] = {
     val dataTypeOf = getFieldMap(schema).toMap
 
     relaxParquetValidTypeMap
@@ -247,31 +265,42 @@ private[sql] object ParquetFilters {
 
     predicate match {
       case sources.IsNull(name) if dataTypeOf.contains(name) =>
+        referencedCols += name
         makeEq.lift(dataTypeOf(name)).map(_(name, null))
       case sources.IsNotNull(name) if dataTypeOf.contains(name) =>
+        referencedCols += name
         makeNotEq.lift(dataTypeOf(name)).map(_(name, null))
 
       case sources.EqualTo(name, value) if dataTypeOf.contains(name) =>
+        referencedCols += name
         makeEq.lift(dataTypeOf(name)).map(_(name, value))
       case sources.Not(sources.EqualTo(name, value)) if dataTypeOf.contains(name) =>
+        referencedCols += name
         makeNotEq.lift(dataTypeOf(name)).map(_(name, value))
 
       case sources.EqualNullSafe(name, value) if dataTypeOf.contains(name) =>
+        referencedCols += name
         makeEq.lift(dataTypeOf(name)).map(_(name, value))
       case sources.Not(sources.EqualNullSafe(name, value)) if dataTypeOf.contains(name) =>
+        referencedCols += name
         makeNotEq.lift(dataTypeOf(name)).map(_(name, value))
 
       case sources.LessThan(name, value) if dataTypeOf.contains(name) =>
+        referencedCols += name
         makeLt.lift(dataTypeOf(name)).map(_(name, value))
       case sources.LessThanOrEqual(name, value) if dataTypeOf.contains(name) =>
+        referencedCols += name
         makeLtEq.lift(dataTypeOf(name)).map(_(name, value))
 
      case sources.GreaterThan(name, value) if dataTypeOf.contains(name) =>
+        referencedCols += name
         makeGt.lift(dataTypeOf(name)).map(_(name, value))
      case sources.GreaterThanOrEqual(name, value) if dataTypeOf.contains(name) =>
+        referencedCols += name
         makeGtEq.lift(dataTypeOf(name)).map(_(name, value))
 
       case sources.In(name, valueSet) =>
+        referencedCols += name
         makeInSet.lift(dataTypeOf(name)).map(_(name, valueSet.toSet))
 
       case sources.And(lhs, rhs) =>
@@ -283,18 +312,18 @@ private[sql] object ParquetFilters {
         // Pushing one side of AND down is only safe to do at the top level.
         // You can see ParquetRelation's initializeLocalJobFunc method as an example.
         for {
-          lhsFilter <- createFilter(schema, lhs)
-          rhsFilter <- createFilter(schema, rhs)
+          lhsFilter <- createParquetFilter(schema, lhs, referencedCols)
+          rhsFilter <- createParquetFilter(schema, rhs, referencedCols)
         } yield FilterApi.and(lhsFilter, rhsFilter)
 
       case sources.Or(lhs, rhs) =>
         for {
-          lhsFilter <- createFilter(schema, lhs)
-          rhsFilter <- createFilter(schema, rhs)
+          lhsFilter <- createParquetFilter(schema, lhs, referencedCols)
+          rhsFilter <- createParquetFilter(schema, rhs, referencedCols)
         } yield FilterApi.or(lhsFilter, rhsFilter)
 
       case sources.Not(pred) =>
-        createFilter(schema, pred).map(FilterApi.not)
+        createParquetFilter(schema, pred, referencedCols).map(FilterApi.not)
 
       case _ => None
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index 184cbb2f296b..e0c1d6f91188 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -141,6 +141,11 @@ private[sql] class ParquetRelation(
       .map(_.toBoolean)
       .getOrElse(sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED))
 
+  // When merging schemas is enabled and the column of the given filter does not exist,
+  // Parquet emits an exception which is an issue of Parquet (PARQUET-389).
+  private val safeParquetFilterPushDown =
+    sqlContext.conf.parquetFilterPushDown && !shouldMergeSchemas
+
   private val mergeRespectSummaries =
     sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES)
 
@@ -300,13 +305,26 @@ private[sql] class ParquetRelation(
     }
   }
 
+  override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
+    // The unsafe row RecordReader does not support row by row filtering so for this case
+    // it should wrap this with Spark-side filtering.
+    val enableUnsafeRowParquetReader =
+      sqlContext.getConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key).toBoolean
+    val shouldHandleFilters = safeParquetFilterPushDown && !enableUnsafeRowParquetReader
+    if (shouldHandleFilters) {
+      filters.filter(ParquetFilters.createFilter(dataSchema, _).isEmpty)
+    } else {
+      filters
+    }
+  }
+
   override def buildInternalScan(
       requiredColumns: Array[String],
       filters: Array[Filter],
       inputFiles: Array[FileStatus],
       broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = {
     val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA)
-    val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown
+    val parquetFilterPushDown = safeParquetFilterPushDown
     val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString
     val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp
 
@@ -576,6 +594,15 @@ private[sql] object ParquetRelation extends Logging {
     conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[CatalystReadSupport].getName)
 
     // Try to push down filters when filter push-down is enabled.
+    val safeRequiredColumns = if (parquetFilterPushDown) {
+      val referencedColumns = filters
+        // Collects all columns referenced in Parquet filter predicates.
+        .flatMap(filter => ParquetFilters.referencedColumns(dataSchema, filter))
+      (requiredColumns ++ referencedColumns).distinct
+    } else {
+      requiredColumns
+    }
+
     if (parquetFilterPushDown) {
       filters
         // Collects all converted Parquet filter predicates. Notice that not all predicates can be
@@ -587,7 +614,7 @@ private[sql] object ParquetRelation extends Logging {
     }
 
     conf.set(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA, {
-      val requestedSchema = StructType(requiredColumns.map(dataSchema(_)))
+      val requestedSchema = StructType(safeRequiredColumns.map(dataSchema(_)))
       CatalystSchemaConverter.checkFieldNames(requestedSchema).json
     })
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index c1e3f386b256..dd8c5ac7bca0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -343,7 +343,7 @@ object SQLConf {
   val PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED = booleanConf(
     key = "spark.sql.parquet.enableUnsafeRowRecordReader",
     defaultValue = Some(true),
-    doc = "Enables using the custom ParquetUnsafeRowRecordReader.")
+    doc = "Enables using the custom UnsafeRowParquetRecordReader.")
 
   // Note: this can not be enabled all the time because the reader will not be returning UnsafeRows.
   // Doing so is very expensive and we should remove this requirement instead of fixing it here.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index fbffe867e4b7..c3f797fc2bf5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.execution.{PhysicalRDD, WholeStageCodegen}
 import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelation}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -79,7 +80,8 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
            assert(f.getClass === filterClass)
          }
        }
-      checker(stripSparkFilter(query), expected)
+      checkPlan(query)
+      checker(query, expected)
     }
   }
 }
@@ -108,6 +110,14 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
     checkFilterPredicate(df, predicate, filterClass, checkBinaryAnswer _, expected)
   }
 
+  private def checkPlan(df: DataFrame): Unit = {
+    val executedPlan = df.queryExecution.executedPlan
+    assert(executedPlan.isInstanceOf[WholeStageCodegen])
+    // Check if SparkPlan Filter is removed and this plan only has `PhysicalRDD`.
+    val childPlan = executedPlan.asInstanceOf[WholeStageCodegen].plan
+    assert(childPlan.isInstanceOf[PhysicalRDD])
+  }
+
   private def checkBinaryFilterPredicate
       (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Array[Byte])
       (implicit df: DataFrame): Unit = {
@@ -444,6 +454,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
   // The unsafe row RecordReader does not support row by row filtering so run it with it disabled.
   test("SPARK-11661 Still pushdown filters returned by unhandledFilters") {
     import testImplicits._
+    withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") {
     withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") {
       withTempPath { dir =>
@@ -451,11 +462,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
         (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path)
 
         val df = sqlContext.read.parquet(path).filter("a = 2")
+        // Check if SparkPlan Filter is removed and this plan only has `PhysicalRDD`.
+        checkPlan(df)
         // The result should be single row.
         // When a filter is pushed to Parquet, Parquet can apply it to every row.
         // So, we can check the number of rows returned from the Parquet
         // to make sure our filter pushdown work.
-        assert(stripSparkFilter(df).count == 1)
+        assert(df.count == 1)
       }
     }
   }
@@ -518,30 +531,31 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
 
   test("SPARK-11164: test the parquet filter in") {
     import testImplicits._
-    withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") {
-      withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") {
-        withTempPath { dir =>
-          val path = s"${dir.getCanonicalPath}/table1"
-          (1 to 5).map(i => (i.toFloat, i%3)).toDF("a", "b").write.parquet(path)
+    withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true",
+      SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") {
+      withTempPath { dir =>
+        val path = s"${dir.getCanonicalPath}/table1"
+        (1 to 5).map(i => (i.toFloat, i%3)).toDF("a", "b").write.parquet(path)
 
-          // When a filter is pushed to Parquet, Parquet can apply it to every row.
-          // So, we can check the number of rows returned from the Parquet
-          // to make sure our filter pushdown work.
-          val df = sqlContext.read.parquet(path).where("b in (0,2)")
-          assert(stripSparkFilter(df).count == 3)
+        val df = sqlContext.read.parquet(path).where("b in (0,2)")
+        checkPlan(df)
+        assert(df.count == 3)
 
-          val df1 = sqlContext.read.parquet(path).where("not (b in (1))")
-          assert(stripSparkFilter(df1).count == 3)
+        val df1 = sqlContext.read.parquet(path).where("not (b in (1))")
+        checkPlan(df1)
+        assert(df1.count == 3)
 
-          val df2 = sqlContext.read.parquet(path).where("not (b in (1,3) or a <= 2)")
-          assert(stripSparkFilter(df2).count == 2)
+        val df2 = sqlContext.read.parquet(path).where("not (b in (1,3) or a <= 2)")
+        checkPlan(df2)
+        assert(df2.count == 2)
 
-          val df3 = sqlContext.read.parquet(path).where("not (b in (1,3) and a <= 2)")
-          assert(stripSparkFilter(df3).count == 4)
+        val df3 = sqlContext.read.parquet(path).where("not (b in (1,3) and a <= 2)")
+        checkPlan(df3)
+        assert(df3.count == 4)
 
-          val df4 = sqlContext.read.parquet(path).where("not (a <= 2)")
-          assert(stripSparkFilter(df4).count == 3)
-        }
+        val df4 = sqlContext.read.parquet(path).where("not (a <= 2)")
+        checkPlan(df4)
+        assert(df4.count == 3)
       }
     }
   }
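
Note for reviewers: the three pieces of the patch work together -- push-down is disabled under schema merging (PARQUET-389), unhandledFilters reports the predicates Spark must still evaluate, and the requested schema is widened with the columns referenced by pushed filters. The following is a minimal, self-contained Scala sketch of that contract, not Spark code: the object PushdownSketch, the toy Filter hierarchy, and the Unsupported case are invented for illustration only; the real logic lives in ParquetFilters and ParquetRelation as shown above.

// Minimal sketch only: models the handled/unhandled split and the column widening.
// It does not depend on Spark; every name here is illustrative.
object PushdownSketch {
  sealed trait Filter
  case class EqualTo(column: String, value: Any) extends Filter
  case class GreaterThan(column: String, value: Any) extends Filter
  case class Unsupported(description: String) extends Filter

  // Stand-in for ParquetFilters.createFilter: Some(...) only for convertible predicates.
  def createFilter(schema: Set[String], f: Filter): Option[String] = f match {
    case EqualTo(c, v) if schema(c)     => Some(s"eq($c, $v)")
    case GreaterThan(c, v) if schema(c) => Some(s"gt($c, $v)")
    case _                              => None
  }

  // Stand-in for ParquetFilters.referencedColumns.
  def referencedColumns(f: Filter): Seq[String] = f match {
    case EqualTo(c, _)     => Seq(c)
    case GreaterThan(c, _) => Seq(c)
    case _                 => Seq.empty
  }

  // Stand-in for ParquetRelation.unhandledFilters: whatever cannot be converted stays with Spark.
  def unhandledFilters(schema: Set[String], filters: Seq[Filter]): Seq[Filter] =
    filters.filter(createFilter(schema, _).isEmpty)

  def main(args: Array[String]): Unit = {
    val schema  = Set("a", "b", "c")
    val filters = Seq(EqualTo("a", 2), Unsupported("a LIKE '%x%'"))

    // The LIKE-style predicate is reported back, so Spark keeps a Filter node for it.
    println(unhandledFilters(schema, filters))

    // Columns referenced by pushed filters are added to the required columns,
    // so the `a = 2` predicate can be evaluated by the reader even when only `b` is selected.
    val requiredColumns = Seq("b")
    val safeRequired    = (requiredColumns ++ filters.flatMap(referencedColumns)).distinct
    println(safeRequired)
  }
}

Running the sketch prints List(Unsupported(a LIKE '%x%')) and List(b, a): the unconvertible predicate stays on the Spark side, while column a is pulled into the requested column set, which mirrors what safeRequiredColumns does in ParquetRelation.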