diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index d7eb14356b8b..ea4f1592a7c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -347,6 +347,7 @@ class ParquetFileFormat
     val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal
     val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith
     val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold
+    val isCaseSensitive = sqlConf.caseSensitiveAnalysis
 
     (file: PartitionedFile) => {
       assert(file.partitionValues.numFields == partitionSchema.size)
@@ -372,7 +373,7 @@ class ParquetFileFormat
       val pushed = if (enableParquetFilterPushDown) {
         val parquetSchema = footerFileMetaData.getSchema
         val parquetFilters = new ParquetFilters(pushDownDate, pushDownTimestamp, pushDownDecimal,
-          pushDownStringStartWith, pushDownInFilterThreshold)
+          pushDownStringStartWith, pushDownInFilterThreshold, isCaseSensitive)
         filters
           // Collects all converted Parquet filter predicates. Notice that not all predicates can be
           // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap`
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index 58b4a769fcb6..0c286defb940 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet
 import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Long => JLong}
 import java.math.{BigDecimal => JBigDecimal}
 import java.sql.{Date, Timestamp}
+import java.util.Locale
 
 import scala.collection.JavaConverters.asScalaBufferConverter
 
@@ -31,7 +32,7 @@ import org.apache.parquet.schema.OriginalType._
 import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
 import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
 
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate
 import org.apache.spark.sql.sources
 import org.apache.spark.unsafe.types.UTF8String
@@ -44,7 +45,18 @@ private[parquet] class ParquetFilters(
     pushDownTimestamp: Boolean,
     pushDownDecimal: Boolean,
     pushDownStartWith: Boolean,
-    pushDownInFilterThreshold: Int) {
+    pushDownInFilterThreshold: Int,
+    caseSensitive: Boolean) {
+
+  /**
+   * Holds the information for a single field stored in the underlying parquet file.
+   *
+   * @param fieldName field name in the parquet file
+   * @param fieldType field type related info in the parquet file
+   */
+  private case class ParquetField(
+      fieldName: String,
+      fieldType: ParquetSchemaType)
 
   private case class ParquetSchemaType(
       originalType: OriginalType,
@@ -350,25 +362,38 @@ private[parquet] class ParquetFilters(
   }
 
   /**
-   * Returns a map from name of the column to the data type, if predicate push down applies.
+   * Returns a map from parquet field names to their field info, if predicate push down applies.
    */
-  private def getFieldMap(dataType: MessageType): Map[String, ParquetSchemaType] = dataType match {
-    case m: MessageType =>
-      // Here we don't flatten the fields in the nested schema but just look up through
-      // root fields. Currently, accessing to nested fields does not push down filters
-      // and it does not support to create filters for them.
-      m.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f =>
-        f.getName -> ParquetSchemaType(
-          f.getOriginalType, f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata)
-      }.toMap
-    case _ => Map.empty[String, ParquetSchemaType]
+  private def getFieldMap(dataType: MessageType): Map[String, ParquetField] = {
+    // Here we don't flatten the fields in the nested schema but just look up through
+    // root fields. Currently, accessing nested fields does not push down filters,
+    // and creating filters for them is not supported.
+    val primitiveFields =
+      dataType.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f =>
+        f.getName -> ParquetField(f.getName,
+          ParquetSchemaType(f.getOriginalType,
+            f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata))
+      }
+    if (caseSensitive) {
+      primitiveFields.toMap
+    } else {
+      // Don't consider ambiguity here, i.e. more than one field matched in case-insensitive
+      // mode: just skip pushdown for these fields, as they will trigger an exception when read.
+      // See SPARK-25132.
+      val dedupPrimitiveFields =
+        primitiveFields
+          .groupBy(_._1.toLowerCase(Locale.ROOT))
+          .filter(_._2.size == 1)
+          .mapValues(_.head._2)
+      CaseInsensitiveMap(dedupPrimitiveFields)
+    }
   }
 
   /**
    * Converts data sources filters to Parquet filter predicates.
    */
   def createFilter(schema: MessageType, predicate: sources.Filter): Option[FilterPredicate] = {
-    val nameToType = getFieldMap(schema)
+    val nameToParquetField = getFieldMap(schema)
 
     // Decimal type must make sure that filter value's scale matched the file.
     // If doesn't matched, which would cause data corruption.
@@ -381,7 +406,7 @@ private[parquet] class ParquetFilters(
     // Parquet's type in the given file should be matched to the value's type
     // in the pushed filter in order to push down the filter to Parquet.
     def valueCanMakeFilterOn(name: String, value: Any): Boolean = {
-      value == null || (nameToType(name) match {
+      value == null || (nameToParquetField(name).fieldType match {
         case ParquetBooleanType => value.isInstanceOf[JBoolean]
         case ParquetByteType | ParquetShortType | ParquetIntegerType => value.isInstanceOf[Number]
         case ParquetLongType => value.isInstanceOf[JLong]
@@ -408,7 +433,7 @@ private[parquet] class ParquetFilters(
     // filters for the column having dots in the names. Thus, we do not push down such filters.
     // See SPARK-20364.
     def canMakeFilterOn(name: String, value: Any): Boolean = {
-      nameToType.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value)
+      nameToParquetField.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value)
     }
 
     // NOTE:
@@ -428,29 +453,39 @@ private[parquet] class ParquetFilters(
 
     predicate match {
       case sources.IsNull(name) if canMakeFilterOn(name, null) =>
-        makeEq.lift(nameToType(name)).map(_(name, null))
+        makeEq.lift(nameToParquetField(name).fieldType)
+          .map(_(nameToParquetField(name).fieldName, null))
       case sources.IsNotNull(name) if canMakeFilterOn(name, null) =>
-        makeNotEq.lift(nameToType(name)).map(_(name, null))
+        makeNotEq.lift(nameToParquetField(name).fieldType)
+          .map(_(nameToParquetField(name).fieldName, null))
 
       case sources.EqualTo(name, value) if canMakeFilterOn(name, value) =>
-        makeEq.lift(nameToType(name)).map(_(name, value))
+        makeEq.lift(nameToParquetField(name).fieldType)
+          .map(_(nameToParquetField(name).fieldName, value))
       case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) =>
-        makeNotEq.lift(nameToType(name)).map(_(name, value))
+        makeNotEq.lift(nameToParquetField(name).fieldType)
+          .map(_(nameToParquetField(name).fieldName, value))
 
       case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) =>
-        makeEq.lift(nameToType(name)).map(_(name, value))
+        makeEq.lift(nameToParquetField(name).fieldType)
+          .map(_(nameToParquetField(name).fieldName, value))
       case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) =>
-        makeNotEq.lift(nameToType(name)).map(_(name, value))
+        makeNotEq.lift(nameToParquetField(name).fieldType)
+          .map(_(nameToParquetField(name).fieldName, value))
 
       case sources.LessThan(name, value) if canMakeFilterOn(name, value) =>
-        makeLt.lift(nameToType(name)).map(_(name, value))
+        makeLt.lift(nameToParquetField(name).fieldType)
+          .map(_(nameToParquetField(name).fieldName, value))
       case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) =>
-        makeLtEq.lift(nameToType(name)).map(_(name, value))
+        makeLtEq.lift(nameToParquetField(name).fieldType)
+          .map(_(nameToParquetField(name).fieldName, value))
 
       case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) =>
-        makeGt.lift(nameToType(name)).map(_(name, value))
+        makeGt.lift(nameToParquetField(name).fieldType)
+          .map(_(nameToParquetField(name).fieldName, value))
       case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) =>
-        makeGtEq.lift(nameToType(name)).map(_(name, value))
+        makeGtEq.lift(nameToParquetField(name).fieldType)
+          .map(_(nameToParquetField(name).fieldName, value))
 
       case sources.And(lhs, rhs) =>
         // At here, it is not safe to just convert one side if we do not understand the
@@ -477,7 +512,8 @@ private[parquet] class ParquetFilters(
       case sources.In(name, values) if canMakeFilterOn(name, values.head)
        && values.distinct.length <= pushDownInFilterThreshold =>
        values.distinct.flatMap { v =>
-          makeEq.lift(nameToType(name)).map(_(name, v))
+          makeEq.lift(nameToParquetField(name).fieldType)
+            .map(_(nameToParquetField(name).fieldName, v))
         }.reduceLeftOption(FilterApi.or)
 
       case sources.StringStartsWith(name, prefix)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index be4f498c921a..7ebb75009555 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -25,6 +25,7 @@ import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Operato
 import org.apache.parquet.filter2.predicate.FilterApi._
 import org.apache.parquet.filter2.predicate.Operators.{Column => _, _}
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
@@ -60,7 +61,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
   private lazy val parquetFilters =
     new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownTimestamp,
       conf.parquetFilterPushDownDecimal, conf.parquetFilterPushDownStringStartWith,
-      conf.parquetFilterPushDownInFilterThreshold)
+      conf.parquetFilterPushDownInFilterThreshold, conf.caseSensitiveAnalysis)
 
   override def beforeEach(): Unit = {
     super.beforeEach()
@@ -1021,6 +1022,118 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
       }
     }
   }
+
+  test("SPARK-25207: Case-insensitive field resolution for pushdown when reading parquet") {
+    def createParquetFilter(caseSensitive: Boolean): ParquetFilters = {
+      new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownTimestamp,
+        conf.parquetFilterPushDownDecimal, conf.parquetFilterPushDownStringStartWith,
+        conf.parquetFilterPushDownInFilterThreshold, caseSensitive)
+    }
+    val caseSensitiveParquetFilters = createParquetFilter(caseSensitive = true)
+    val caseInsensitiveParquetFilters = createParquetFilter(caseSensitive = false)
+
+    def testCaseInsensitiveResolution(
+        schema: StructType,
+        expected: FilterPredicate,
+        filter: sources.Filter): Unit = {
+      val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema)
+
+      assertResult(Some(expected)) {
+        caseInsensitiveParquetFilters.createFilter(parquetSchema, filter)
+      }
+      assertResult(None) {
+        caseSensitiveParquetFilters.createFilter(parquetSchema, filter)
+      }
+    }
+
+    val schema = StructType(Seq(StructField("cint", IntegerType)))
+
+    testCaseInsensitiveResolution(
+      schema, FilterApi.eq(intColumn("cint"), null.asInstanceOf[Integer]), sources.IsNull("CINT"))
+
+    testCaseInsensitiveResolution(
+      schema,
+      FilterApi.notEq(intColumn("cint"), null.asInstanceOf[Integer]),
+      sources.IsNotNull("CINT"))
+
+    testCaseInsensitiveResolution(
+      schema, FilterApi.eq(intColumn("cint"), 1000: Integer), sources.EqualTo("CINT", 1000))
+
+    testCaseInsensitiveResolution(
+      schema,
+      FilterApi.notEq(intColumn("cint"), 1000: Integer),
+      sources.Not(sources.EqualTo("CINT", 1000)))
+
+    testCaseInsensitiveResolution(
+      schema, FilterApi.eq(intColumn("cint"), 1000: Integer), sources.EqualNullSafe("CINT", 1000))
+
+    testCaseInsensitiveResolution(
+      schema,
+      FilterApi.notEq(intColumn("cint"), 1000: Integer),
+      sources.Not(sources.EqualNullSafe("CINT", 1000)))
+
+    testCaseInsensitiveResolution(
+      schema,
+      FilterApi.lt(intColumn("cint"), 1000: Integer), sources.LessThan("CINT", 1000))
+
+    testCaseInsensitiveResolution(
+      schema,
+      FilterApi.ltEq(intColumn("cint"), 1000: Integer),
+      sources.LessThanOrEqual("CINT", 1000))
+
+    testCaseInsensitiveResolution(
+      schema, FilterApi.gt(intColumn("cint"), 1000: Integer), sources.GreaterThan("CINT", 1000))
+
+    testCaseInsensitiveResolution(
+      schema,
+      FilterApi.gtEq(intColumn("cint"), 1000: Integer),
+      sources.GreaterThanOrEqual("CINT", 1000))
+
+    testCaseInsensitiveResolution(
+      schema,
+      FilterApi.or(
+        FilterApi.eq(intColumn("cint"), 10: Integer),
+        FilterApi.eq(intColumn("cint"), 20: Integer)),
+      sources.In("CINT", Array(10, 20)))
+
+    val dupFieldSchema = StructType(
+      Seq(StructField("cint", IntegerType), StructField("cINT", IntegerType)))
+    val dupParquetSchema = new SparkToParquetSchemaConverter(conf).convert(dupFieldSchema)
+    assertResult(None) {
+      caseInsensitiveParquetFilters.createFilter(
+        dupParquetSchema, sources.EqualTo("CINT", 1000))
+    }
+  }
+
+  test("SPARK-25207: exception when duplicate fields in case-insensitive mode") {
+    withTempPath { dir =>
+      val count = 10
+      val tableName = "spark_25207"
+      val tableDir = dir.getAbsoluteFile + "/table"
+      withTable(tableName) {
+        withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+          spark.range(count).selectExpr("id as A", "id as B", "id as b")
+            .write.mode("overwrite").parquet(tableDir)
+        }
+        sql(
+          s"""
+             |CREATE TABLE $tableName (A LONG, B LONG) USING PARQUET LOCATION '$tableDir'
+           """.stripMargin)
+
+        withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+          val e = intercept[SparkException] {
+            sql(s"select a from $tableName where b > 0").collect()
+          }
+          assert(e.getCause.isInstanceOf[RuntimeException] && e.getCause.getMessage.contains(
+            """Found duplicate field(s) "B": [B, b] in case-insensitive mode"""))
+        }
+
+        withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+          checkAnswer(sql(s"select A from $tableName where B > 0"), (1 until count).map(Row(_)))
+        }
+      }
+    }
+  }
 }
 
 class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] {
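Note (not part of the patch): a minimal standalone sketch of the de-duplication idea used in getFieldMap above, in plain Scala without Spark classes; the object and method names (DedupSketch, dedup) are made up for illustration. Field names are grouped by their lower-cased form and any ambiguous group is dropped, so those columns are simply skipped for pushdown in case-insensitive mode (see SPARK-25132).

import java.util.Locale

object DedupSketch {
  // Stand-in for the name -> ParquetField mapping: names that collide when lower-cased
  // ("B" vs "b") are dropped; unique names are keyed by their lower-cased form so that
  // lookups can be done case-insensitively.
  def dedup(fieldNames: Seq[String]): Map[String, String] = {
    fieldNames.map(n => n -> n)
      .groupBy(_._1.toLowerCase(Locale.ROOT))
      .filter(_._2.size == 1)
      .map { case (lowerName, single) => lowerName -> single.head._2 }
  }

  def main(args: Array[String]): Unit = {
    println(dedup(Seq("cint", "B", "b")))   // Map(cint -> cint); "B"/"b" are skipped
  }
}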