Skip to content
This repository was archived by the owner on Feb 13, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.planning
package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField}
import org.apache.spark.sql.types.StructField
Expand All @@ -25,7 +25,7 @@ import org.apache.spark.sql.types.StructField
* This is in contrast to the [[GetStructField]] case class extractor which returns the field
* ordinal instead of the field itself.
*/
private[planning] object GetStructFieldObject {
private[execution] object GetStructFieldObject {
def unapply(getStructField: GetStructField): Option[(Expression, StructField)] =
Some((
getStructField.child,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.planning
package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
Expand All @@ -26,29 +26,32 @@ import org.apache.spark.sql.types._
* are adjusted to fit the schema. All other expressions are left as-is. This
* class is motivated by columnar nested schema pruning.
*/
case class ProjectionOverSchema(schema: StructType) {
private[execution] case class ProjectionOverSchema(schema: StructType) {
private val fieldNames = schema.fieldNames.toSet

def unapply(expr: Expression): Option[Expression] = getProjection(expr)

private def getProjection(expr: Expression): Option[Expression] =
expr match {
case a @ AttributeReference(name, _, _, _) if (fieldNames.contains(name)) =>
Some(a.copy(dataType = schema(name).dataType)(a.exprId, a.qualifier))
case a: AttributeReference if fieldNames.contains(a.name) =>
Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier))
case GetArrayItem(child, arrayItemOrdinal) =>
getProjection(child).map { projection => GetArrayItem(projection, arrayItemOrdinal) }
case GetArrayStructFields(child, StructField(name, _, _, _), _, numFields, containsNull) =>
getProjection(child).map(p => (p, p.dataType)).map {
case a: GetArrayStructFields =>
getProjection(a.child).map(p => (p, p.dataType)).map {
case (projection, ArrayType(projSchema @ StructType(_), _)) =>
GetArrayStructFields(projection,
projSchema(name), projSchema.fieldIndex(name), projSchema.size, containsNull)
projSchema(a.field.name),
projSchema.fieldIndex(a.field.name),
projSchema.size,
a.containsNull)
}
case GetMapValue(child, key) =>
getProjection(child).map { projection => GetMapValue(projection, key) }
case GetStructFieldObject(child, StructField(name, _, _, _)) =>
case GetStructFieldObject(child, field: StructField) =>
getProjection(child).map(p => (p, p.dataType)).map {
case (projection, projSchema @ StructType(_)) =>
GetStructField(projection, projSchema.fieldIndex(name))
case (projection, projSchema: StructType) =>
GetStructField(projection, projSchema.fieldIndex(field.name))
}
case _ =>
None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.planning
package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
Expand All @@ -24,27 +24,27 @@ import org.apache.spark.sql.types._
* A Scala extractor that builds a [[org.apache.spark.sql.types.StructField]] from a Catalyst
* complex type extractor. For example, consider a relation with the following schema:
*
* {{{
* root
* |-- name: struct (nullable = true)
* | |-- first: string (nullable = true)
* | |-- last: string (nullable = true)
* }}}
* {{{
* root
* |-- name: struct (nullable = true)
* | |-- first: string (nullable = true)
* | |-- last: string (nullable = true)
* }}}
*
* Further, suppose we take the select expression `name.first`. This will parse into an
* `Alias(child, "first")`. Ignoring the alias, `child` matches the following pattern:
*
* {{{
* GetStructFieldObject(
* AttributeReference("name", StructType(_), _, _),
* StructField("first", StringType, _, _))
* }}}
* {{{
* GetStructFieldObject(
* AttributeReference("name", StructType(_), _, _),
* StructField("first", StringType, _, _))
* }}}
*
* [[SelectedField]] converts that expression into
*
* {{{
* StructField("name", StructType(Array(StructField("first", StringType))))
* }}}
* {{{
* StructField("name", StructType(Array(StructField("first", StringType))))
* }}}
*
* by mapping each complex type extractor to a [[org.apache.spark.sql.types.StructField]] with the
* same name as its child (or "parent" going right to left in the select expression) and a data
Expand All @@ -54,7 +54,7 @@ import org.apache.spark.sql.types._
*
* @param expr the top-level complex type extractor
*/
object SelectedField {
private[execution] object SelectedField {
def unapply(expr: Expression): Option[StructField] = {
// If this expression is an alias, work on its child instead
val unaliased = expr match {
Expand Down Expand Up @@ -85,25 +85,25 @@ object SelectedField {
field @ StructField(name, dataType, nullable, metadata), _, _, _) =>
val childField = fieldOpt.map(field => StructField(name,
wrapStructType(dataType, field),
nullable, metadata)).getOrElse(field)
selectField(child, Some(childField))
nullable, metadata)).orElse(Some(field))
selectField(child, childField)
// Handles case "expr0.field", where "expr0" is of array type.
case GetArrayStructFields(child,
field @ StructField(name, dataType, nullable, metadata), _, _, containsNull) =>
field @ StructField(name, dataType, nullable, metadata), _, _, _) =>
val childField =
fieldOpt.map(field => StructField(name,
wrapStructType(dataType, field),
nullable, metadata)).getOrElse(field)
selectField(child, Some(childField))
nullable, metadata)).orElse(Some(field))
selectField(child, childField)
// Handles case "expr0.field[key]", where "expr0" is of struct type and "expr0.field" is of
// map type.
case GetMapValue(x @ GetStructFieldObject(child, field @ StructField(name,
dataType,
nullable, metadata)), _) =>
val childField = fieldOpt.map(field => StructField(name,
wrapStructType(dataType, field),
nullable, metadata)).getOrElse(field)
selectField(child, Some(childField))
nullable, metadata)).orElse(Some(field))
selectField(child, childField)
// Handles case "expr0.field[key]", where "expr0.field" is of map type.
case GetMapValue(child, _) =>
selectField(child, fieldOpt)
Expand All @@ -112,8 +112,8 @@ object SelectedField {
field @ StructField(name, dataType, nullable, metadata)) =>
val childField = fieldOpt.map(field => StructField(name,
wrapStructType(dataType, field),
nullable, metadata)).getOrElse(field)
selectField(child, Some(childField))
nullable, metadata)).orElse(Some(field))
selectField(child, childField)
case _ =>
None
}
Expand Down
Loading