diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index 6ee24ee0c1913..416329b8bc58c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -268,4 +268,9 @@ class MetadataBuilder { map.put(key, value) this } + + def remove(key: String): this.type = { + map.remove(key) + this + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 3bd733fa2d26c..da0c92864e9bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -334,6 +334,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru object StructType extends AbstractDataType { + private[sql] val metadataKeyForOptionalField = "_OPTIONAL_" + override private[sql] def defaultConcreteType: DataType = new StructType override private[sql] def acceptsType(other: DataType): Boolean = { @@ -359,6 +361,18 @@ object StructType extends AbstractDataType { protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) + def removeMetadata(key: String, dt: DataType): DataType = + dt match { + case StructType(fields) => + val newFields = fields.map { f => + val mb = new MetadataBuilder() + f.copy(dataType = removeMetadata(key, f.dataType), + metadata = mb.withMetadata(f.metadata).remove(key).build()) + } + StructType(newFields) + case _ => dt + } + private[sql] def merge(left: DataType, right: DataType): DataType = (left, right) match { case (ArrayType(leftElementType, leftContainsNull), @@ -376,24 +390,32 @@ object StructType extends AbstractDataType { case (StructType(leftFields), StructType(rightFields)) => val newFields = ArrayBuffer.empty[StructField] + // This metadata will record the fields that only exist in one of two StructTypes + val optionalMeta = new MetadataBuilder() val rightMapped = fieldsMap(rightFields) leftFields.foreach { case leftField @ StructField(leftName, leftType, leftNullable, _) => rightMapped.get(leftName) .map { case rightField @ StructField(_, rightType, rightNullable, _) => - leftField.copy( - dataType = merge(leftType, rightType), - nullable = leftNullable || rightNullable) - } - .orElse(Some(leftField)) + leftField.copy( + dataType = merge(leftType, rightType), + nullable = leftNullable || rightNullable) + } + .orElse { + optionalMeta.putBoolean(metadataKeyForOptionalField, true) + Some(leftField.copy(metadata = optionalMeta.build())) + } .foreach(newFields += _) } val leftMapped = fieldsMap(leftFields) rightFields .filterNot(f => leftMapped.get(f.name).nonEmpty) - .foreach(newFields += _) + .foreach { f => + optionalMeta.putBoolean(metadataKeyForOptionalField, true) + newFields += f.copy(metadata = optionalMeta.build()) + } StructType(newFields) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 706ecd29d1355..c2bbca7c33f28 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -122,7 +122,9 @@ class DataTypeSuite extends SparkFunSuite { val right = StructType(List()) val merged = left.merge(right) - assert(merged === left) + assert(DataType.equalsIgnoreCompatibleNullability(merged, left)) + assert(merged("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(merged("b").metadata.getBoolean(StructType.metadataKeyForOptionalField)) } test("merge where left is empty") { @@ -135,8 +137,9 @@ class DataTypeSuite extends SparkFunSuite { val merged = left.merge(right) - assert(right === merged) - + assert(DataType.equalsIgnoreCompatibleNullability(merged, right)) + assert(merged("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(merged("b").metadata.getBoolean(StructType.metadataKeyForOptionalField)) } test("merge where both are non-empty") { @@ -154,7 +157,10 @@ class DataTypeSuite extends SparkFunSuite { val merged = left.merge(right) - assert(merged === expected) + assert(DataType.equalsIgnoreCompatibleNullability(merged, expected)) + assert(merged("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(merged("b").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(merged("c").metadata.getBoolean(StructType.metadataKeyForOptionalField)) } test("merge where right contains type conflict") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index e9b734b0abf50..5a5cb5cf03d4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -207,11 +207,26 @@ private[sql] object ParquetFilters { */ } + /** + * SPARK-11955: The optional fields will have metadata StructType.metadataKeyForOptionalField. + * These fields only exist in one side of merged schemas. Due to that, we can't push down filters + * using such fields, otherwise Parquet library will throw exception. Here we filter out such + * fields. + */ + private def getFieldMap(dataType: DataType): Array[(String, DataType)] = dataType match { + case StructType(fields) => + fields.filter { f => + !f.metadata.contains(StructType.metadataKeyForOptionalField) || + !f.metadata.getBoolean(StructType.metadataKeyForOptionalField) + }.map(f => f.name -> f.dataType) ++ fields.flatMap { f => getFieldMap(f.dataType) } + case _ => Array.empty[(String, DataType)] + } + /** * Converts data sources filters to Parquet filter predicates. */ def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = { - val dataTypeOf = schema.map(f => f.name -> f.dataType).toMap + val dataTypeOf = getFieldMap(schema).toMap relaxParquetValidTypeMap @@ -231,29 +246,29 @@ private[sql] object ParquetFilters { // Probably I missed something and obviously this should be changed. predicate match { - case sources.IsNull(name) => + case sources.IsNull(name) if dataTypeOf.contains(name) => makeEq.lift(dataTypeOf(name)).map(_(name, null)) - case sources.IsNotNull(name) => + case sources.IsNotNull(name) if dataTypeOf.contains(name) => makeNotEq.lift(dataTypeOf(name)).map(_(name, null)) - case sources.EqualTo(name, value) => + case sources.EqualTo(name, value) if dataTypeOf.contains(name) => makeEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.Not(sources.EqualTo(name, value)) => + case sources.Not(sources.EqualTo(name, value)) if dataTypeOf.contains(name) => makeNotEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.EqualNullSafe(name, value) => + case sources.EqualNullSafe(name, value) if dataTypeOf.contains(name) => makeEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.Not(sources.EqualNullSafe(name, value)) => + case sources.Not(sources.EqualNullSafe(name, value)) if dataTypeOf.contains(name) => makeNotEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.LessThan(name, value) => + case sources.LessThan(name, value) if dataTypeOf.contains(name) => makeLt.lift(dataTypeOf(name)).map(_(name, value)) - case sources.LessThanOrEqual(name, value) => + case sources.LessThanOrEqual(name, value) if dataTypeOf.contains(name) => makeLtEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.GreaterThan(name, value) => + case sources.GreaterThan(name, value) if dataTypeOf.contains(name) => makeGt.lift(dataTypeOf(name)).map(_(name, value)) - case sources.GreaterThanOrEqual(name, value) => + case sources.GreaterThanOrEqual(name, value) if dataTypeOf.contains(name) => makeGtEq.lift(dataTypeOf(name)).map(_(name, value)) case sources.In(name, valueSet) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index b460ec1d26047..f87590095d344 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -258,7 +258,12 @@ private[sql] class ParquetRelation( job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) ParquetOutputFormat.setWriteSupportClass(job, classOf[CatalystWriteSupport]) - CatalystWriteSupport.setSchema(dataSchema, conf) + + // We want to clear this temporary metadata from saving into Parquet file. + // This metadata is only useful for detecting optional columns when pushdowning filters. + val dataSchemaToWrite = StructType.removeMetadata(StructType.metadataKeyForOptionalField, + dataSchema).asInstanceOf[StructType] + CatalystWriteSupport.setSchema(dataSchemaToWrite, conf) // Sets flags for `CatalystSchemaConverter` (which converts Catalyst schema to Parquet schema) // and `CatalystWriteSupport` (writing actual rows to Parquet files). @@ -304,10 +309,6 @@ private[sql] class ParquetRelation( val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp - // When merging schemas is enabled and the column of the given filter does not exist, - // Parquet emits an exception which is an issue of Parquet (PARQUET-389). - val safeParquetFilterPushDown = !shouldMergeSchemas && parquetFilterPushDown - // Parquet row group size. We will use this value as the value for // mapreduce.input.fileinputformat.split.minsize and mapred.min.split.size if the value // of these flags are smaller than the parquet row group size. @@ -321,7 +322,7 @@ private[sql] class ParquetRelation( dataSchema, parquetBlockSize, useMetadataCache, - safeParquetFilterPushDown, + parquetFilterPushDown, assumeBinaryIsString, assumeInt96IsTimestamp) _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 97c5313f0feff..1796b3af0e37a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelation} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -379,9 +380,47 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // If the "c = 1" filter gets pushed down, this query will throw an exception which // Parquet emits. This is a Parquet issue (PARQUET-389). + val df = sqlContext.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a") checkAnswer( - sqlContext.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a"), - (1 to 1).map(i => Row(i, i.toString, null))) + df, + Row(1, "1", null)) + + // The fields "a" and "c" only exist in one Parquet file. + assert(df.schema("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(df.schema("c").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + + val pathThree = s"${dir.getCanonicalPath}/table3" + df.write.parquet(pathThree) + + // We will remove the temporary metadata when writing Parquet file. + val schema = sqlContext.read.parquet(pathThree).schema + assert(schema.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) + + val pathFour = s"${dir.getCanonicalPath}/table4" + val dfStruct = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + dfStruct.select(struct("a").as("s")).write.parquet(pathFour) + + val pathFive = s"${dir.getCanonicalPath}/table5" + val dfStruct2 = sparkContext.parallelize(Seq((1, 1))).toDF("c", "b") + dfStruct2.select(struct("c").as("s")).write.parquet(pathFive) + + // If the "s.c = 1" filter gets pushed down, this query will throw an exception which + // Parquet emits. + val dfStruct3 = sqlContext.read.parquet(pathFour, pathFive).filter("s.c = 1") + .selectExpr("s") + checkAnswer(dfStruct3, Row(Row(null, 1))) + + // The fields "s.a" and "s.c" only exist in one Parquet file. + val field = dfStruct3.schema("s").dataType.asInstanceOf[StructType] + assert(field("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(field("c").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + + val pathSix = s"${dir.getCanonicalPath}/table6" + dfStruct3.write.parquet(pathSix) + + // We will remove the temporary metadata when writing Parquet file. + val forPathSix = sqlContext.read.parquet(pathSix).schema + assert(forPathSix.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) } } }