diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/GetStructFieldObject.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GetStructFieldObject.scala
similarity index 93%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/GetStructFieldObject.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/GetStructFieldObject.scala
index 033792a9ac72..c88b2f8c034f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/GetStructFieldObject.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GetStructFieldObject.scala
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
-package org.apache.spark.sql.catalyst.planning
+package org.apache.spark.sql.execution

 import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField}
 import org.apache.spark.sql.types.StructField
@@ -25,7 +25,7 @@ import org.apache.spark.sql.types.StructField
  * This is in contrast to the [[GetStructField]] case class extractor which returns the field
  * ordinal instead of the field itself.
  */
-private[planning] object GetStructFieldObject {
+private[execution] object GetStructFieldObject {
   def unapply(getStructField: GetStructField): Option[(Expression, StructField)] =
     Some((
       getStructField.child,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/ProjectionOverSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala
similarity index 72%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/ProjectionOverSchema.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala
index 39d8a102d605..2236f18b0da1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/ProjectionOverSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ProjectionOverSchema.scala
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
-package org.apache.spark.sql.catalyst.planning
+package org.apache.spark.sql.execution

 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
@@ -26,29 +26,32 @@ import org.apache.spark.sql.types._
  * are adjusted to fit the schema. All other expressions are left as-is. This
  * class is motivated by columnar nested schema pruning.
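For context on how this extractor is consumed: `ParquetSchemaPruning.buildNewProjection`, later in this patch, matches it inside a `transformDown` over analyzed expressions. A minimal sketch of that usage (the pruned schema here is illustrative, and `expr` stands for any analyzed expression; given the new `private[execution]` visibility, such code would live under `org.apache.spark.sql.execution`):

    import org.apache.spark.sql.catalyst.expressions.Expression
    import org.apache.spark.sql.types.{StringType, StructField, StructType}

    // A data schema pruned down to just name.first.
    val prunedSchema = StructType(Seq(
      StructField("name", StructType(Seq(StructField("first", StringType))))))
    val projectionOverSchema = ProjectionOverSchema(prunedSchema)

    // Rewrites attribute references and complex-type extractors so their data
    // types, field ordinals and field counts fit the pruned schema; expressions
    // the extractor does not recognize are left as-is.
    def rewrite(expr: Expression): Expression = expr.transformDown {
      case projectionOverSchema(projected) => projected
    }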
*/ -case class ProjectionOverSchema(schema: StructType) { +private[execution] case class ProjectionOverSchema(schema: StructType) { private val fieldNames = schema.fieldNames.toSet def unapply(expr: Expression): Option[Expression] = getProjection(expr) private def getProjection(expr: Expression): Option[Expression] = expr match { - case a @ AttributeReference(name, _, _, _) if (fieldNames.contains(name)) => - Some(a.copy(dataType = schema(name).dataType)(a.exprId, a.qualifier)) + case a: AttributeReference if fieldNames.contains(a.name) => + Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier)) case GetArrayItem(child, arrayItemOrdinal) => getProjection(child).map { projection => GetArrayItem(projection, arrayItemOrdinal) } - case GetArrayStructFields(child, StructField(name, _, _, _), _, numFields, containsNull) => - getProjection(child).map(p => (p, p.dataType)).map { + case a: GetArrayStructFields => + getProjection(a.child).map(p => (p, p.dataType)).map { case (projection, ArrayType(projSchema @ StructType(_), _)) => GetArrayStructFields(projection, - projSchema(name), projSchema.fieldIndex(name), projSchema.size, containsNull) + projSchema(a.field.name), + projSchema.fieldIndex(a.field.name), + projSchema.size, + a.containsNull) } case GetMapValue(child, key) => getProjection(child).map { projection => GetMapValue(projection, key) } - case GetStructFieldObject(child, StructField(name, _, _, _)) => + case GetStructFieldObject(child, field: StructField) => getProjection(child).map(p => (p, p.dataType)).map { - case (projection, projSchema @ StructType(_)) => - GetStructField(projection, projSchema.fieldIndex(name)) + case (projection, projSchema: StructType) => + GetStructField(projection, projSchema.fieldIndex(field.name)) } case _ => None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/SelectedField.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala similarity index 85% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/SelectedField.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala index dc1e00290bed..0e7c593f9fb6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/SelectedField.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.planning +package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -24,27 +24,27 @@ import org.apache.spark.sql.types._ * A Scala extractor that builds a [[org.apache.spark.sql.types.StructField]] from a Catalyst * complex type extractor. For example, consider a relation with the following schema: * - * {{{ - * root - * |-- name: struct (nullable = true) - * | |-- first: string (nullable = true) - * | |-- last: string (nullable = true) - * }}} + * {{{ + * root + * |-- name: struct (nullable = true) + * | |-- first: string (nullable = true) + * | |-- last: string (nullable = true) + * }}} * * Further, suppose we take the select expression `name.first`. This will parse into an * `Alias(child, "first")`. 
Ignoring the alias, `child` matches the following pattern: * - * {{{ - * GetStructFieldObject( - * AttributeReference("name", StructType(_), _, _), - * StructField("first", StringType, _, _)) - * }}} + * {{{ + * GetStructFieldObject( + * AttributeReference("name", StructType(_), _, _), + * StructField("first", StringType, _, _)) + * }}} * * [[SelectedField]] converts that expression into * - * {{{ - * StructField("name", StructType(Array(StructField("first", StringType)))) - * }}} + * {{{ + * StructField("name", StructType(Array(StructField("first", StringType)))) + * }}} * * by mapping each complex type extractor to a [[org.apache.spark.sql.types.StructField]] with the * same name as its child (or "parent" going right to left in the select expression) and a data @@ -54,7 +54,7 @@ import org.apache.spark.sql.types._ * * @param expr the top-level complex type extractor */ -object SelectedField { +private[execution] object SelectedField { def unapply(expr: Expression): Option[StructField] = { // If this expression is an alias, work on its child instead val unaliased = expr match { @@ -85,16 +85,16 @@ object SelectedField { field @ StructField(name, dataType, nullable, metadata), _, _, _) => val childField = fieldOpt.map(field => StructField(name, wrapStructType(dataType, field), - nullable, metadata)).getOrElse(field) - selectField(child, Some(childField)) + nullable, metadata)).orElse(Some(field)) + selectField(child, childField) // Handles case "expr0.field", where "expr0" is of array type. case GetArrayStructFields(child, - field @ StructField(name, dataType, nullable, metadata), _, _, containsNull) => + field @ StructField(name, dataType, nullable, metadata), _, _, _) => val childField = fieldOpt.map(field => StructField(name, wrapStructType(dataType, field), - nullable, metadata)).getOrElse(field) - selectField(child, Some(childField)) + nullable, metadata)).orElse(Some(field)) + selectField(child, childField) // Handles case "expr0.field[key]", where "expr0" is of struct type and "expr0.field" is of // map type. case GetMapValue(x @ GetStructFieldObject(child, field @ StructField(name, @@ -102,8 +102,8 @@ object SelectedField { nullable, metadata)), _) => val childField = fieldOpt.map(field => StructField(name, wrapStructType(dataType, field), - nullable, metadata)).getOrElse(field) - selectField(child, Some(childField)) + nullable, metadata)).orElse(Some(field)) + selectField(child, childField) // Handles case "expr0.field[key]", where "expr0.field" is of map type. 
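To make the extractor's contract concrete, here is roughly how the scaladoc's `name.first` example evaluates; a hand-built sketch, not code from this patch (and, given the new visibility, callable only from within `org.apache.spark.sql.execution`):

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GetStructField}
    import org.apache.spark.sql.types.{StringType, StructField, StructType}

    // The analyzed form of the select expression `name.first`.
    val name = AttributeReference("name", StructType(Seq(
      StructField("first", StringType),
      StructField("last", StringType))))()
    val expr = GetStructField(name, 0, Some("first"))

    SelectedField.unapply(expr)
    // => Some(StructField("name", StructType(Seq(StructField("first", StringType)))))

The match in `selectField` continues below with the map-extractor cases.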
case GetMapValue(child, _) => selectField(child, fieldOpt) @@ -112,8 +112,8 @@ object SelectedField { field @ StructField(name, dataType, nullable, metadata)) => val childField = fieldOpt.map(field => StructField(name, wrapStructType(dataType, field), - nullable, metadata)).getOrElse(field) - selectField(child, Some(childField)) + nullable, metadata)).orElse(Some(field)) + selectField(child, childField) case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala index 47465051a739..15b8615ec365 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.execution.datasources.parquet import org.apache.spark.sql.catalyst.expressions.{And, Attribute, Expression, NamedExpression} -import org.apache.spark.sql.catalyst.planning.{PhysicalOperation, ProjectionOverSchema, SelectedField} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{ProjectionOverSchema, SelectedField} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType} @@ -42,78 +43,28 @@ private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] { private def apply0(plan: LogicalPlan): LogicalPlan = plan transformDown { case op @ PhysicalOperation(projects, filters, - l @ LogicalRelation(hadoopFsRelation @ HadoopFsRelation(_, _, - dataSchema, _, parquetFormat: ParquetFileFormat, _), _, _, _)) => - val projectionRootFields = projects.flatMap(getRootFields) - val filterRootFields = filters.flatMap(getRootFields) - val requestedRootFields = (projectionRootFields ++ filterRootFields).distinct - - // If [[requestedRootFields]] includes a nested field, continue. Otherwise, - // return [[op]] - if (requestedRootFields.exists { case RootField(_, derivedFromAtt) => !derivedFromAtt }) { - // Merge the requested root fields into a single schema. Note the ordering of the fields - // in the resulting schema may differ from their ordering in the logical relation's - // original schema - val mergedSchema = requestedRootFields - .map { case RootField(field, _) => StructType(Array(field)) } - .reduceLeft(_ merge _) - val dataSchemaFieldNames = dataSchema.fieldNames.toSet - val mergedDataSchema = - StructType(mergedSchema.filter(f => dataSchemaFieldNames.contains(f.name))) - // Sort the fields of [[mergedDataSchema]] according to their order in [[dataSchema]], - // recursively. This makes [[mergedDataSchema]] a pruned schema of [[dataSchema]] - val prunedDataSchema = - sortLeftFieldsByRight(mergedDataSchema, dataSchema).asInstanceOf[StructType] + l @ LogicalRelation(hadoopFsRelation: HadoopFsRelation, _, _, _)) + if canPruneRelation(hadoopFsRelation) => + val requestedRootFields = identifyRootFields(projects, filters) + + // If requestedRootFields includes a nested field, continue. 
Otherwise, + // return op + if (requestedRootFields.exists { root: RootField => !root.derivedFromAtt }) { + val dataSchema = hadoopFsRelation.dataSchema + val prunedDataSchema = pruneDataSchema(dataSchema, requestedRootFields) // If the data schema is different from the pruned data schema, continue. Otherwise, - // return [[op]]. We effect this comparison by counting the number of "leaf" fields in - // each schemata, assuming the fields in [[prunedDataSchema]] are a subset of the fields - // in [[dataSchema]]. + // return op. We effect this comparison by counting the number of "leaf" fields in + // each schemata, assuming the fields in prunedDataSchema are a subset of the fields + // in dataSchema. if (countLeaves(dataSchema) > countLeaves(prunedDataSchema)) { val prunedParquetRelation = hadoopFsRelation.copy(dataSchema = prunedDataSchema)(hadoopFsRelation.sparkSession) - // We need to replace the expression ids of the pruned relation output attributes - // with the expression ids of the original relation output attributes so that - // references to the original relation's output are not broken - val outputIdMap = l.output.map(att => (att.name, att.exprId)).toMap - val prunedRelationOutput = - prunedParquetRelation - .schema - .toAttributes - .map { - case att if outputIdMap.contains(att.name) => - att.withExprId(outputIdMap(att.name)) - case att => att - } - val prunedRelation = - l.copy(relation = prunedParquetRelation, output = prunedRelationOutput) - + val prunedRelation = buildPrunedRelation(l, prunedParquetRelation) val projectionOverSchema = ProjectionOverSchema(prunedDataSchema) - // Construct a new target for our projection by rewriting and - // including the original filters where available - val projectionChild = - if (filters.nonEmpty) { - val projectedFilters = filters.map(_.transformDown { - case projectionOverSchema(expr) => expr - }) - val newFilterCondition = projectedFilters.reduce(And) - Filter(newFilterCondition, prunedRelation) - } else { - prunedRelation - } - - // Construct the new projections of our [[Project]] by - // rewriting the original projections - val newProjects = projects.map(_.transformDown { - case projectionOverSchema(expr) => expr - }).map { case expr: NamedExpression => expr } - - logDebug(s"New projects:\n${newProjects.map(_.treeString).mkString("\n")}") - logDebug(s"Pruned data schema:\n${prunedDataSchema.treeString}") - - Project(newProjects, projectionChild) + buildNewProjection(projects, filters, prunedRelation, projectionOverSchema) } else { op } @@ -122,25 +73,113 @@ private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] { } } + /** + * Checks to see if the given relation is Parquet and can be pruned. + */ + private def canPruneRelation(fsRelation: HadoopFsRelation) = + fsRelation.fileFormat.isInstanceOf[ParquetFileFormat] + + /** + * Returns the set of fields from the Parquet file that the query plan needs. + */ + private def identifyRootFields(projects: Seq[NamedExpression], filters: Seq[Expression]) = { + val projectionRootFields = projects.flatMap(getRootFields) + val filterRootFields = filters.flatMap(getRootFields) + + (projectionRootFields ++ filterRootFields).distinct + } + + /** + * Builds the new output [[Project]] Spark SQL operator that has the pruned output relation. 
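In outline, the method introduced here rebuilds the query as a `Project` over an optional conjunctive `Filter` over the pruned relation. A condensed sketch of that same assembly, under the same parameter names as the method below:

    import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression}
    import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}

    // Filters already rewritten against the pruned schema are folded into a
    // single conjunction beneath the Project; with no filters, the Project
    // sits directly on the pruned relation.
    def planShape(
        newProjects: Seq[NamedExpression],
        rewrittenFilters: Seq[Expression],
        prunedRelation: LogicalPlan): Project = {
      val child =
        if (rewrittenFilters.nonEmpty) Filter(rewrittenFilters.reduce(And), prunedRelation)
        else prunedRelation
      Project(newProjects, child)
    }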
+ */ + private def buildNewProjection( + projects: Seq[NamedExpression], filters: Seq[Expression], prunedRelation: LogicalRelation, + projectionOverSchema: ProjectionOverSchema) = { + // Construct a new target for our projection by rewriting and + // including the original filters where available + val projectionChild = + if (filters.nonEmpty) { + val projectedFilters = filters.map(_.transformDown { + case projectionOverSchema(expr) => expr + }) + val newFilterCondition = projectedFilters.reduce(And) + Filter(newFilterCondition, prunedRelation) + } else { + prunedRelation + } + + // Construct the new projections of our Project by + // rewriting the original projections + val newProjects = projects.map(_.transformDown { + case projectionOverSchema(expr) => expr + }).map { case expr: NamedExpression => expr } + + if (log.isDebugEnabled) { + logDebug(s"New projects:\n${newProjects.map(_.treeString).mkString("\n")}") + } + + Project(newProjects, projectionChild) + } + + /** + * Filters the schema from the given file by the requested fields. + * Schema field ordering from the file is preserved. + */ + private def pruneDataSchema( + fileDataSchema: StructType, + requestedRootFields: Seq[RootField]) = { + // Merge the requested root fields into a single schema. Note the ordering of the fields + // in the resulting schema may differ from their ordering in the logical relation's + // original schema + val mergedSchema = requestedRootFields + .map { case RootField(field, _) => StructType(Array(field)) } + .reduceLeft(_ merge _) + val dataSchemaFieldNames = fileDataSchema.fieldNames.toSet + val mergedDataSchema = + StructType(mergedSchema.filter(f => dataSchemaFieldNames.contains(f.name))) + // Sort the fields of mergedDataSchema according to their order in dataSchema, + // recursively. This makes mergedDataSchema a pruned schema of dataSchema + sortLeftFieldsByRight(mergedDataSchema, fileDataSchema).asInstanceOf[StructType] + } + + private def buildPrunedRelation( + outputRelation: LogicalRelation, + parquetRelation: HadoopFsRelation) = { + // We need to replace the expression ids of the pruned relation output attributes + // with the expression ids of the original relation output attributes so that + // references to the original relation's output are not broken + val outputIdMap = outputRelation.output.map(att => (att.name, att.exprId)).toMap + val prunedRelationOutput = + parquetRelation + .schema + .toAttributes + .map { + case att if outputIdMap.contains(att.name) => + att.withExprId(outputIdMap(att.name)) + case att => att + } + outputRelation.copy(relation = parquetRelation, output = prunedRelationOutput) + } + /** * Gets the root (aka top-level, no-parent) [[StructField]]s for the given [[Expression]]. - * When [[expr]] is an [[Attribute]], construct a field around it and indicate that that + * When expr is an [[Attribute]], construct a field around it and indicate that that * field was derived from an attribute. */ private def getRootFields(expr: Expression): Seq[RootField] = { expr match { case att: Attribute => - RootField(StructField(att.name, att.dataType, att.nullable), true) :: Nil - case SelectedField(field) => RootField(field, false) :: Nil + RootField(StructField(att.name, att.dataType, att.nullable), derivedFromAtt = true) :: Nil + case SelectedField(field) => RootField(field, derivedFromAtt = false) :: Nil case _ => expr.children.flatMap(getRootFields) } } /** - * Counts the "leaf" fields of the given [[dataType]]. 
Informally, this is the + * Counts the "leaf" fields of the given dataType. Informally, this is the * number of fields of non-complex data type in the tree representation of - * [[dataType]]. + * [[DataType]]. */ private def countLeaves(dataType: DataType): Int = { dataType match { @@ -153,10 +192,10 @@ private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] { } /** - * Sorts the fields and descendant fields of structs in [[left]] according to their order in - * [[right]]. This function assumes that the fields of [[left]] are a subset of the fields of - * [[right]], recursively. That is, [[left]] is a "subschema" of [[right]], ignoring order of - * fields. + * Sorts the fields and descendant fields of structs in left according to their order in + * right. This function assumes that the fields of left are a subset of the fields of + * right, recursively. That is, left is a "subschema" of right, ignoring order of + * fields. */ private def sortLeftFieldsByRight(left: DataType, right: DataType): DataType = (left, right) match { @@ -179,7 +218,7 @@ private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] { StructField(fieldName, sortedLeftFieldType) } StructType(sortedLeftFields) - case (left, _) => left + case _ => left } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/FileSchemaPruningTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/FileSchemaPruningTest.scala deleted file mode 100644 index 52911034ff06..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/FileSchemaPruningTest.scala +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.scalactic.Equality -import org.scalatest.Assertions - -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.types.StructType - -private[sql] trait FileSchemaPruningTest { - _: Assertions => - - private val schemaEquality = new Equality[StructType] { - override def areEqual(a: StructType, b: Any) = - b match { - case otherType: StructType => a sameType otherType - case _ => false - } - } - - protected def checkScan(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = { - checkScanSchemata(df, expectedSchemaCatalogStrings: _*) - // We check here that we can execute the query without throwing an exception. 
The results - // themselves are irrelevant, and should be checked elsewhere as needed - df.collect() - } - - private def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = { - val fileSourceScanSchemata = - df.queryExecution.executedPlan.collect { - case scan: FileSourceScanExec => scan.requiredSchema - } - assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size, - s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " + - s"but expected ${expectedSchemaCatalogStrings}") - fileSourceScanSchemata.zip(expectedSchemaCatalogStrings).foreach { - case (scanSchema, expectedScanSchemaCatalogString) => - val expectedScanSchema = CatalystSqlParser.parseDataType(expectedScanSchemaCatalogString) - implicit val equality = schemaEquality - assert(scanSchema === expectedScanSchema) - } - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/SelectedFieldSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SelectedFieldSuite.scala similarity index 61% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/SelectedFieldSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/SelectedFieldSuite.scala index f4ed7570ab25..05f7e3ce8388 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/SelectedFieldSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SelectedFieldSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.planning +package org.apache.spark.sql.execution import org.scalatest.BeforeAndAfterAll import org.scalatest.exceptions.TestFailedException @@ -27,36 +27,14 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ -// scalastyle:off line.size.limit class SelectedFieldSuite extends SparkFunSuite with BeforeAndAfterAll { + private val ignoredField = StructField("col1", StringType, nullable = false) + // The test schema as a tree string, i.e. 
`schema.treeString` // root // |-- col1: string (nullable = false) // |-- col2: struct (nullable = true) // | |-- field1: integer (nullable = true) - // | |-- field2: array (nullable = true) - // | | |-- element: integer (containsNull = false) - // | |-- field3: array (nullable = false) - // | | |-- element: struct (containsNull = true) - // | | | |-- subfield1: integer (nullable = true) - // | | | |-- subfield2: integer (nullable = true) - // | | | |-- subfield3: array (nullable = true) - // | | | | |-- element: integer (containsNull = true) - // | |-- field4: map (nullable = true) - // | | |-- key: string - // | | |-- value: struct (valueContainsNull = false) - // | | | |-- subfield1: integer (nullable = true) - // | | | |-- subfield2: array (nullable = true) - // | | | | |-- element: integer (containsNull = false) - // | |-- field5: array (nullable = false) - // | | |-- element: struct (containsNull = true) - // | | | |-- subfield1: struct (nullable = false) - // | | | | |-- subsubfield1: integer (nullable = true) - // | | | | |-- subsubfield2: integer (nullable = true) - // | | | |-- subfield2: struct (nullable = true) - // | | | | |-- subsubfield1: struct (nullable = true) - // | | | | | |-- subsubsubfield1: string (nullable = true) - // | | | | |-- subsubfield2: integer (nullable = true) // | |-- field6: struct (nullable = true) // | | |-- subfield1: string (nullable = false) // | | |-- subfield2: string (nullable = true) @@ -64,80 +42,12 @@ class SelectedFieldSuite extends SparkFunSuite with BeforeAndAfterAll { // | | |-- subfield1: struct (nullable = true) // | | | |-- subsubfield1: integer (nullable = true) // | | | |-- subsubfield2: integer (nullable = true) - // | |-- field8: map (nullable = true) - // | | |-- key: string - // | | |-- value: array (valueContainsNull = false) - // | | | |-- element: struct (containsNull = true) - // | | | | |-- subfield1: integer (nullable = true) - // | | | | |-- subfield2: array (nullable = true) - // | | | | | |-- element: integer (containsNull = false) // | |-- field9: map (nullable = true) // | | |-- key: string // | | |-- value: integer (valueContainsNull = false) - // |-- col3: array (nullable = false) - // | |-- element: struct (containsNull = false) - // | | |-- field1: struct (nullable = true) - // | | | |-- subfield1: integer (nullable = false) - // | | | |-- subfield2: integer (nullable = true) - // | | |-- field2: map (nullable = true) - // | | | |-- key: string - // | | | |-- value: integer (valueContainsNull = false) - // |-- col4: map (nullable = false) - // | |-- key: string - // | |-- value: struct (valueContainsNull = false) - // | | |-- field1: struct (nullable = true) - // | | | |-- subfield1: integer (nullable = false) - // | | | |-- subfield2: integer (nullable = true) - // | | |-- field2: map (nullable = true) - // | | | |-- key: string - // | | | |-- value: integer (valueContainsNull = false) - // |-- col5: array (nullable = true) - // | |-- element: map (containsNull = true) - // | | |-- key: string - // | | |-- value: struct (valueContainsNull = false) - // | | | |-- field1: struct (nullable = true) - // | | | | |-- subfield1: integer (nullable = true) - // | | | | |-- subfield2: integer (nullable = true) - // |-- col6: map (nullable = true) - // | |-- key: string - // | |-- value: array (valueContainsNull = true) - // | | |-- element: struct (containsNull = false) - // | | | |-- field1: struct (nullable = true) - // | | | | |-- subfield1: integer (nullable = true) - // | | | | |-- subfield2: integer (nullable = true) - 
// |-- col7: array (nullable = true) - // | |-- element: struct (containsNull = true) - // | | |-- field1: integer (nullable = false) - // | | |-- field2: struct (nullable = true) - // | | | |-- subfield1: integer (nullable = false) - // | | |-- field3: array (nullable = true) - // | | | |-- element: struct (containsNull = true) - // | | | | |-- subfield1: integer (nullable = false) - // |-- col8: array (nullable = true) - // | |-- element: struct (containsNull = true) - // | | |-- field1: array (nullable = false) - // | | | |-- element: integer (containsNull = false) - private val schema = - StructType( - StructField("col1", StringType, nullable = false) :: + private val nestedComplex = StructType(ignoredField :: StructField("col2", StructType( StructField("field1", IntegerType) :: - StructField("field2", ArrayType(IntegerType, containsNull = false)) :: - StructField("field3", ArrayType(StructType( - StructField("subfield1", IntegerType) :: - StructField("subfield2", IntegerType) :: - StructField("subfield3", ArrayType(IntegerType)) :: Nil)), nullable = false) :: - StructField("field4", MapType(StringType, StructType( - StructField("subfield1", IntegerType) :: - StructField("subfield2", ArrayType(IntegerType, containsNull = false)) :: Nil), valueContainsNull = false)) :: - StructField("field5", ArrayType(StructType( - StructField("subfield1", StructType( - StructField("subsubfield1", IntegerType) :: - StructField("subsubfield2", IntegerType) :: Nil), nullable = false) :: - StructField("subfield2", StructType( - StructField("subsubfield1", StructType( - StructField("subsubsubfield1", StringType) :: Nil)) :: - StructField("subsubfield2", IntegerType) :: Nil)) :: Nil)), nullable = false) :: StructField("field6", StructType( StructField("subfield1", StringType, nullable = false) :: StructField("subfield2", StringType) :: Nil)) :: @@ -145,158 +55,178 @@ class SelectedFieldSuite extends SparkFunSuite with BeforeAndAfterAll { StructField("subfield1", StructType( StructField("subsubfield1", IntegerType) :: StructField("subsubfield2", IntegerType) :: Nil)) :: Nil)) :: - StructField("field8", MapType(StringType, ArrayType(StructType( - StructField("subfield1", IntegerType) :: - StructField("subfield2", ArrayType(IntegerType, containsNull = false)) :: Nil)), valueContainsNull = false)) :: - StructField("field9", MapType(StringType, IntegerType, valueContainsNull = false)) :: Nil)) :: - StructField("col3", ArrayType(StructType( - StructField("field1", StructType( - StructField("subfield1", IntegerType, nullable = false) :: - StructField("subfield2", IntegerType) :: Nil)) :: - StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) :: Nil), containsNull = false), nullable = false) :: - StructField("col4", MapType(StringType, StructType( - StructField("field1", StructType( - StructField("subfield1", IntegerType, nullable = false) :: - StructField("subfield2", IntegerType) :: Nil)) :: - StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) :: Nil), valueContainsNull = false), nullable = false) :: - StructField("col5", ArrayType(MapType(StringType, StructType( - StructField("field1", StructType( - StructField("subfield1", IntegerType) :: - StructField("subfield2", IntegerType) :: Nil)) :: Nil), valueContainsNull = false))) :: - StructField("col6", MapType(StringType, ArrayType(StructType( - StructField("field1", StructType( - StructField("subfield1", IntegerType) :: - StructField("subfield2", IntegerType) :: Nil)) :: Nil), containsNull = false))) :: 
- StructField("col7", ArrayType(StructType( - StructField("field1", IntegerType, nullable = false) :: - StructField("field2", StructType( - StructField("subfield1", IntegerType, nullable = false) :: Nil)) :: - StructField("field3", ArrayType(StructType( - StructField("subfield1", IntegerType, nullable = false) :: Nil))) :: Nil))) :: - StructField("col8", ArrayType(StructType( - StructField("field1", ArrayType(IntegerType, containsNull = false), nullable = false) :: Nil))) :: Nil) - - private val testRelation = LocalRelation(schema.toAttributes) + StructField("field9", + MapType(StringType, IntegerType, valueContainsNull = false)) :: Nil)) :: Nil) test("SelectedField should not match an attribute reference") { - assertResult(None)(unapplySelect("col1")) - assertResult(None)(unapplySelect("col1 as foo")) - assertResult(None)(unapplySelect("col2")) + val testRelation = LocalRelation(nestedComplex.toAttributes) + assertResult(None)(unapplySelect("col1", testRelation)) + assertResult(None)(unapplySelect("col1 as foo", testRelation)) + assertResult(None)(unapplySelect("col2", testRelation)) } - info("For a relation with schema\n" + indent(schema.treeString)) + // |-- col1: string (nullable = false) + // |-- col2: struct (nullable = true) + // | |-- field2: array (nullable = true) + // | | |-- element: integer (containsNull = false) + // | |-- field3: array (nullable = false) + // | | |-- element: struct (containsNull = true) + // | | | |-- subfield1: integer (nullable = true) + // | | | |-- subfield2: integer (nullable = true) + // | | | |-- subfield3: array (nullable = true) + // | | | | |-- element: integer (containsNull = true) + private val structOfArray = StructType(ignoredField :: + StructField("col2", StructType( + StructField("field2", ArrayType(IntegerType, containsNull = false)) :: + StructField("field3", ArrayType(StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", IntegerType) :: + StructField("subfield3", ArrayType(IntegerType)) :: Nil)), nullable = false) + :: Nil)) + :: Nil) - testSelect("col2.field2", "col2.field2[0] as foo") { + testSelect(structOfArray, "col2.field2", "col2.field2[0] as foo") { StructField("col2", StructType( StructField("field2", ArrayType(IntegerType, containsNull = false)) :: Nil)) } - testSelect("col2.field9", "col2.field9['foo'] as foo") { + testSelect(nestedComplex, "col2.field9", "col2.field9['foo'] as foo") { StructField("col2", StructType( StructField("field9", MapType(StringType, IntegerType, valueContainsNull = false)) :: Nil)) } - testSelect("col2.field3.subfield3", "col2.field3[0].subfield3 as foo", + testSelect(structOfArray, "col2.field3.subfield3", "col2.field3[0].subfield3 as foo", "col2.field3.subfield3[0] as foo", "col2.field3[0].subfield3[0] as foo") { StructField("col2", StructType( StructField("field3", ArrayType(StructType( StructField("subfield3", ArrayType(IntegerType)) :: Nil)), nullable = false) :: Nil)) } - testSelect("col2.field3.subfield1") { + testSelect(structOfArray, "col2.field3.subfield1") { StructField("col2", StructType( StructField("field3", ArrayType(StructType( StructField("subfield1", IntegerType) :: Nil)), nullable = false) :: Nil)) } - testSelect("col2.field5.subfield1") { + // |-- col1: string (nullable = false) + // |-- col2: struct (nullable = true) + // | |-- field4: map (nullable = true) + // | | |-- key: string + // | | |-- value: struct (valueContainsNull = false) + // | | | |-- subfield1: integer (nullable = true) + // | | | |-- subfield2: array (nullable = true) + // | | | | |-- 
element: integer (containsNull = false) + // | |-- field8: map (nullable = true) + // | | |-- key: string + // | | |-- value: array (valueContainsNull = false) + // | | | |-- element: struct (containsNull = true) + // | | | | |-- subfield1: integer (nullable = true) + // | | | | |-- subfield2: array (nullable = true) + // | | | | | |-- element: integer (containsNull = false) + private val structWithMap = StructType( + ignoredField :: StructField("col2", StructType( - StructField("field5", ArrayType(StructType( - StructField("subfield1", StructType( - StructField("subsubfield1", IntegerType) :: - StructField("subsubfield2", IntegerType) :: Nil), nullable = false) :: Nil)), nullable = false) :: Nil)) - } - - testSelect("col3.field1.subfield1") { - StructField("col3", ArrayType(StructType( - StructField("field1", StructType( - StructField("subfield1", IntegerType, nullable = false) :: Nil)) :: Nil), containsNull = false), nullable = false) - } - - testSelect("col3.field2['foo'] as foo") { - StructField("col3", ArrayType(StructType( - StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) :: Nil), containsNull = false), nullable = false) - } - - testSelect("col4['foo'].field1.subfield1 as foo") { - StructField("col4", MapType(StringType, StructType( - StructField("field1", StructType( - StructField("subfield1", IntegerType, nullable = false) :: Nil)) :: Nil), valueContainsNull = false), nullable = false) - } - - testSelect("col4['foo'].field2['bar'] as foo") { - StructField("col4", MapType(StringType, StructType( - StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) :: Nil), valueContainsNull = false), nullable = false) - } + StructField("field4", MapType(StringType, StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", ArrayType(IntegerType, containsNull = false)) :: Nil + ), valueContainsNull = false)) :: + StructField("field8", MapType(StringType, ArrayType(StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", ArrayType(IntegerType, containsNull = false)) :: Nil) + ), valueContainsNull = false)) :: Nil + )) :: Nil + ) - testSelect("col5[0]['foo'].field1.subfield1 as foo") { - StructField("col5", ArrayType(MapType(StringType, StructType( - StructField("field1", StructType( - StructField("subfield1", IntegerType) :: Nil)) :: Nil), valueContainsNull = false))) + testSelect(structWithMap, "col2.field4['foo'].subfield1 as foo") { + StructField("col2", StructType( + StructField("field4", MapType(StringType, StructType( + StructField("subfield1", IntegerType) :: Nil), valueContainsNull = false)) :: Nil)) } - testSelect("col6['foo'][0].field1.subfield1 as foo") { - StructField("col6", MapType(StringType, ArrayType(StructType( - StructField("field1", StructType( - StructField("subfield1", IntegerType) :: Nil)) :: Nil), containsNull = false))) + testSelect(structWithMap, + "col2.field4['foo'].subfield2 as foo", "col2.field4['foo'].subfield2[0] as foo") { + StructField("col2", StructType( + StructField("field4", MapType(StringType, StructType( + StructField("subfield2", ArrayType(IntegerType, containsNull = false)) + :: Nil), valueContainsNull = false)) :: Nil)) } - testSelect("col2.field5.subfield1.subsubfield1") { + // |-- col1: string (nullable = false) + // |-- col2: struct (nullable = true) + // | |-- field5: array (nullable = false) + // | | |-- element: struct (containsNull = true) + // | | | |-- subfield1: struct (nullable = false) + // | | | | |-- subsubfield1: integer (nullable 
= true) + // | | | | |-- subsubfield2: integer (nullable = true) + // | | | |-- subfield2: struct (nullable = true) + // | | | | |-- subsubfield1: struct (nullable = true) + // | | | | | |-- subsubsubfield1: string (nullable = true) + // | | | | |-- subsubfield2: integer (nullable = true) + private val structWithArray = StructType( + ignoredField :: StructField("col2", StructType( StructField("field5", ArrayType(StructType( StructField("subfield1", StructType( - StructField("subsubfield1", IntegerType) :: Nil), nullable = false) :: Nil)), nullable = false) :: Nil)) - } + StructField("subsubfield1", IntegerType) :: + StructField("subsubfield2", IntegerType) :: Nil), nullable = false) :: + StructField("subfield2", StructType( + StructField("subsubfield1", StructType( + StructField("subsubsubfield1", StringType) :: Nil)) :: + StructField("subsubfield2", IntegerType) :: Nil)) :: Nil)), nullable = false) :: Nil) + ) :: Nil + ) - testSelect("col2.field5.subfield2.subsubfield1.subsubsubfield1") { + testSelect(structWithArray, "col2.field5.subfield1") { StructField("col2", StructType( StructField("field5", ArrayType(StructType( - StructField("subfield2", StructType( - StructField("subsubfield1", StructType( - StructField("subsubsubfield1", StringType) :: Nil)) :: Nil)) :: Nil)), nullable = false) :: Nil)) + StructField("subfield1", StructType( + StructField("subsubfield1", IntegerType) :: + StructField("subsubfield2", IntegerType) :: Nil), nullable = false) + :: Nil)), nullable = false) :: Nil)) } - testSelect("col2.field4['foo'].subfield1 as foo") { + testSelect(structWithArray, "col2.field5.subfield1.subsubfield1") { StructField("col2", StructType( - StructField("field4", MapType(StringType, StructType( - StructField("subfield1", IntegerType) :: Nil), valueContainsNull = false)) :: Nil)) + StructField("field5", ArrayType(StructType( + StructField("subfield1", StructType( + StructField("subsubfield1", IntegerType) :: Nil), nullable = false) + :: Nil)), nullable = false) :: Nil)) } - testSelect("col2.field4['foo'].subfield2 as foo", "col2.field4['foo'].subfield2[0] as foo") { + testSelect(structWithArray, "col2.field5.subfield2.subsubfield1.subsubsubfield1") { StructField("col2", StructType( - StructField("field4", MapType(StringType, StructType( - StructField("subfield2", ArrayType(IntegerType, containsNull = false)) :: Nil), valueContainsNull = false)) :: Nil)) + StructField("field5", ArrayType(StructType( + StructField("subfield2", StructType( + StructField("subsubfield1", StructType( + StructField("subsubsubfield1", StringType) :: Nil)) :: Nil)) + :: Nil)), nullable = false) :: Nil)) } - testSelect("col2.field8['foo'][0].subfield1 as foo") { + testSelect(structWithMap, "col2.field8['foo'][0].subfield1 as foo") { StructField("col2", StructType( StructField("field8", MapType(StringType, ArrayType(StructType( StructField("subfield1", IntegerType) :: Nil)), valueContainsNull = false)) :: Nil)) } - testSelect("col2.field1") { + testSelect(nestedComplex, "col2.field1") { StructField("col2", StructType( StructField("field1", IntegerType) :: Nil)) } - testSelect("col2.field6") { + testSelect(nestedComplex, "col2.field6") { StructField("col2", StructType( StructField("field6", StructType( StructField("subfield1", StringType, nullable = false) :: StructField("subfield2", StringType) :: Nil)) :: Nil)) } - testSelect("col2.field7.subfield1") { + testSelect(nestedComplex, "col2.field6.subfield1") { + StructField("col2", StructType( + StructField("field6", StructType( + StructField("subfield1", StringType, 
nullable = false) :: Nil)) :: Nil)) + } + + testSelect(nestedComplex, "col2.field7.subfield1") { StructField("col2", StructType( StructField("field7", StructType( StructField("subfield1", StructType( @@ -304,32 +234,165 @@ class SelectedFieldSuite extends SparkFunSuite with BeforeAndAfterAll { StructField("subsubfield2", IntegerType) :: Nil)) :: Nil)) :: Nil)) } - testSelect("col2.field6.subfield1") { - StructField("col2", StructType( - StructField("field6", StructType( - StructField("subfield1", StringType, nullable = false) :: Nil)) :: Nil)) + // |-- col1: string (nullable = false) + // |-- col3: array (nullable = false) + // | |-- element: struct (containsNull = false) + // | | |-- field1: struct (nullable = true) + // | | | |-- subfield1: integer (nullable = false) + // | | | |-- subfield2: integer (nullable = true) + // | | |-- field2: map (nullable = true) + // | | | |-- key: string + // | | | |-- value: integer (valueContainsNull = false) + private val arrayWithStructAndMap = StructType(Array( + StructField("col3", ArrayType(StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType, nullable = false) :: + StructField("subfield2", IntegerType) :: Nil)) :: + StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) + :: Nil), containsNull = false), nullable = false) + )) + + testSelect(arrayWithStructAndMap, "col3.field1.subfield1") { + StructField("col3", ArrayType(StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil)) + :: Nil), containsNull = false), nullable = false) + } + + testSelect(arrayWithStructAndMap, "col3.field2['foo'] as foo") { + StructField("col3", ArrayType(StructType( + StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) + :: Nil), containsNull = false), nullable = false) + } + + // |-- col1: string (nullable = false) + // |-- col4: map (nullable = false) + // | |-- key: string + // | |-- value: struct (valueContainsNull = false) + // | | |-- field1: struct (nullable = true) + // | | | |-- subfield1: integer (nullable = false) + // | | | |-- subfield2: integer (nullable = true) + // | | |-- field2: map (nullable = true) + // | | | |-- key: string + // | | | |-- value: integer (valueContainsNull = false) + private val col4 = StructType(Array(ignoredField, + StructField("col4", MapType(StringType, StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType, nullable = false) :: + StructField("subfield2", IntegerType) :: Nil)) :: + StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) + :: Nil), valueContainsNull = false), nullable = false) + )) + + testSelect(col4, "col4['foo'].field1.subfield1 as foo") { + StructField("col4", MapType(StringType, StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil)) + :: Nil), valueContainsNull = false), nullable = false) } - testSelect("col7.field1", "col7[0].field1 as foo", "col7.field1[0] as foo") { + testSelect(col4, "col4['foo'].field2['bar'] as foo") { + StructField("col4", MapType(StringType, StructType( + StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) + :: Nil), valueContainsNull = false), nullable = false) + } + + // |-- col1: string (nullable = false) + // |-- col5: array (nullable = true) + // | |-- element: map (containsNull = true) + // | | |-- key: string + // | | |-- value: struct (valueContainsNull = false) + // | | | |-- 
field1: struct (nullable = true) + // | | | | |-- subfield1: integer (nullable = true) + // | | | | |-- subfield2: integer (nullable = true) + private val arrayOfStruct = StructType(Array(ignoredField, + StructField("col5", ArrayType(MapType(StringType, StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", IntegerType) :: Nil)) :: Nil), valueContainsNull = false))) + )) + + testSelect(arrayOfStruct, "col5[0]['foo'].field1.subfield1 as foo") { + StructField("col5", ArrayType(MapType(StringType, StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType) :: Nil)) :: Nil), valueContainsNull = false))) + } + + // |-- col1: string (nullable = false) + // |-- col6: map (nullable = true) + // | |-- key: string + // | |-- value: array (valueContainsNull = true) + // | | |-- element: struct (containsNull = false) + // | | | |-- field1: struct (nullable = true) + // | | | | |-- subfield1: integer (nullable = true) + // | | | | |-- subfield2: integer (nullable = true) + private val mapOfArray = StructType(Array(ignoredField, + StructField("col6", MapType(StringType, ArrayType(StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", IntegerType) :: Nil)) :: Nil), containsNull = false))))) + + testSelect(mapOfArray, "col6['foo'][0].field1.subfield1 as foo") { + StructField("col6", MapType(StringType, ArrayType(StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType) :: Nil)) :: Nil), containsNull = false))) + } + + // An array with a struct with a different fields + // |-- col1: string (nullable = false) + // |-- col7: array (nullable = true) + // | |-- element: struct (containsNull = true) + // | | |-- field1: integer (nullable = false) + // | | |-- field2: struct (nullable = true) + // | | | |-- subfield1: integer (nullable = false) + // | | |-- field3: array (nullable = true) + // | | | |-- element: struct (containsNull = true) + // | | | | |-- subfield1: integer (nullable = false) + private val arrayWithMultipleFields = StructType(Array(ignoredField, + StructField("col7", ArrayType(StructType( + StructField("field1", IntegerType, nullable = false) :: + StructField("field2", StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil)) :: + StructField("field3", ArrayType(StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil))) :: Nil))))) + + testSelect(arrayWithMultipleFields, + "col7.field1", "col7[0].field1 as foo", "col7.field1[0] as foo") { StructField("col7", ArrayType(StructType( StructField("field1", IntegerType, nullable = false) :: Nil))) } - testSelect("col7.field2.subfield1") { + testSelect(arrayWithMultipleFields, "col7.field2.subfield1") { StructField("col7", ArrayType(StructType( StructField("field2", StructType( StructField("subfield1", IntegerType, nullable = false) :: Nil)) :: Nil))) } - testSelect("col7.field3.subfield1") { + testSelect(arrayWithMultipleFields, "col7.field3.subfield1") { StructField("col7", ArrayType(StructType( StructField("field3", ArrayType(StructType( StructField("subfield1", IntegerType, nullable = false) :: Nil))) :: Nil))) } - testSelect("col8.field1", "col8[0].field1 as foo", "col8.field1[0] as foo", "col8[0].field1[0] as foo") { + // Array with a nested int array + // |-- col1: string (nullable = false) + // |-- col8: array (nullable = true) + // | |-- element: struct (containsNull = true) + // | | |-- field1: array (nullable = false) 
+ // | | | |-- element: integer (containsNull = false) + private val arrayOfArray = StructType(Array(ignoredField, + StructField("col8", + ArrayType(StructType(Array(StructField("field1", + ArrayType(IntegerType, containsNull = false), nullable = false)))) + ))) + + testSelect(arrayOfArray, "col8.field1", + "col8[0].field1 as foo", + "col8.field1[0] as foo", + "col8[0].field1[0] as foo") { StructField("col8", ArrayType(StructType( - StructField("field1", ArrayType(IntegerType, containsNull = false), nullable = false) :: Nil))) + StructField("field1", ArrayType(IntegerType, containsNull = false), nullable = false) + :: Nil))) } def assertResult(expected: StructField)(actual: StructField)(selectExpr: String): Unit = { @@ -350,18 +413,19 @@ class SelectedFieldSuite extends SparkFunSuite with BeforeAndAfterAll { // Test that the given SELECT expressions prune the test schema to the single-column schema // defined by the given field - private def testSelect(selectExpr: String, otherSelectExprs: String*)(expected: StructField) { - val selectExprs = selectExpr +: otherSelectExprs + private def testSelect(inputSchema: StructType, selectExprs: String*) + (expected: StructField) { test(s"SELECT ${selectExprs.map(s => s""""$s"""").mkString(", ")} should select the schema\n" + - indent(StructType(expected :: Nil).treeString)) { + indent(StructType(expected :: Nil).treeString)) { for (selectExpr <- selectExprs) { - assertSelect(selectExpr, expected) + assertSelect(selectExpr, expected, inputSchema) } } } - private def assertSelect(expr: String, expected: StructField) = { - unapplySelect(expr) match { + private def assertSelect(expr: String, expected: StructField, inputSchema: StructType): Unit = { + val relation = LocalRelation(inputSchema.toAttributes) + unapplySelect(expr, relation) match { case Some(field) => assertResult(expected)(field)(expr) case None => @@ -373,16 +437,19 @@ class SelectedFieldSuite extends SparkFunSuite with BeforeAndAfterAll { } } - private def unapplySelect(expr: String) = { - val parsedExpr = - CatalystSqlParser.parseExpression(expr) match { - case namedExpr: NamedExpression => namedExpr - } - val select = testRelation.select(parsedExpr) + private def unapplySelect(expr: String, relation: LocalRelation) = { + val parsedExpr = parseAsCatalystExpression(Seq(expr)).head + val select = relation.select(parsedExpr) val analyzed = select.analyze SelectedField.unapply(analyzed.expressions.head) } + private def parseAsCatalystExpression(exprs: Seq[String]) = { + exprs.map(CatalystSqlParser.parseExpression(_) match { + case namedExpr: NamedExpression => namedExpr + }) + } + // Indent every line in `string` by four spaces private def indent(string: String) = string.replaceAll("(?m)^", " ") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala index d9e4696c3fa1..b70e19462d8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala @@ -19,15 +19,18 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File -import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.execution.FileSchemaPruningTest +import org.scalactic.Equality + +import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import 
org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StructType class ParquetSchemaPruningSuite extends QueryTest with ParquetTest - with FileSchemaPruningTest with SharedSQLContext { case class FullName(first: String, middle: String, last: String) case class Contact( @@ -35,14 +38,14 @@ class ParquetSchemaPruningSuite name: FullName, address: String, pets: Int, - friends: Array[FullName] = Array(), - relatives: Map[String, FullName] = Map()) + friends: Array[FullName] = Array.empty, + relatives: Map[String, FullName] = Map.empty) val janeDoe = FullName("Jane", "X.", "Doe") val johnDoe = FullName("John", "Y.", "Doe") val susanSmith = FullName("Susan", "Z.", "Smith") - val contacts = + private val contacts = Contact(0, janeDoe, "123 Main Street", 1, friends = Array(susanSmith), relatives = Map("brother" -> johnDoe)) :: Contact(1, johnDoe, "321 Wall Street", 3, relatives = Map("sister" -> janeDoe)) :: Nil @@ -50,7 +53,7 @@ class ParquetSchemaPruningSuite case class Name(first: String, last: String) case class BriefContact(id: Int, name: Name, address: String) - val briefContacts = + private val briefContacts = BriefContact(2, Name("Janet", "Jones"), "567 Maple Drive") :: BriefContact(3, Name("Jim", "Jones"), "6242 Ash Street") :: Nil @@ -65,10 +68,10 @@ class ParquetSchemaPruningSuite case class BriefContactWithDataPartitionColumn(id: Int, name: Name, address: String, p: Int) - val contactsWithDataPartitionColumn = + private val contactsWithDataPartitionColumn = contacts.map { case Contact(id, name, address, pets, friends, relatives) => ContactWithDataPartitionColumn(id, name, address, pets, friends, relatives, 1) } - val briefContactsWithDataPartitionColumn = + private val briefContactsWithDataPartitionColumn = briefContacts.map { case BriefContact(id, name, address) => BriefContactWithDataPartitionColumn(id, name, address, 2) } @@ -161,10 +164,10 @@ class ParquetSchemaPruningSuite } withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - test(s"Parquet-mr reader - without partition data column - $testName") { + test(s"Native Parquet reader - without partition data column - $testName") { withContacts(testThunk) } - test(s"Parquet-mr reader - with partition data column - $testName") { + test(s"Native Parquet reader - with partition data column - $testName") { withContactsWithDataPartitionColumn(testThunk) } } @@ -195,4 +198,35 @@ class ParquetSchemaPruningSuite testThunk } } + + private val schemaEquality = new Equality[StructType] { + override def areEqual(a: StructType, b: Any): Boolean = + b match { + case otherType: StructType => a.sameType(otherType) + case _ => false + } + } + + protected def checkScan(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = { + checkScanSchemata(df, expectedSchemaCatalogStrings: _*) + // We check here that we can execute the query without throwing an exception. 
The results
+    // themselves are irrelevant, and should be checked elsewhere as needed.
+    df.collect()
+  }
+
+  private def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = {
+    val fileSourceScanSchemata =
+      df.queryExecution.executedPlan.collect {
+        case scan: FileSourceScanExec => scan.requiredSchema
+      }
+    assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size,
+      s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " +
+      s"but expected ${expectedSchemaCatalogStrings.size}: $expectedSchemaCatalogStrings")
+    fileSourceScanSchemata.zip(expectedSchemaCatalogStrings).foreach {
+      case (scanSchema, expectedScanSchemaCatalogString) =>
+        val expectedScanSchema = CatalystSqlParser.parseDataType(expectedScanSchemaCatalogString)
+        implicit val equality = schemaEquality
+        assert(scanSchema === expectedScanSchema)
+    }
+  }
 }
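For completeness, here is how the helpers above combine in practice. A hypothetical assertion against the `contacts` fixture defined earlier in this suite, with nested schema pruning enabled:

    // Prune the scan down to name.first and verify the schema actually read
    // from Parquet; checkScan also executes the query via df.collect().
    val query = sql("select name.first from contacts")
    checkScan(query, "struct<name:struct<first:string>>")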