From 4d48d1c99befb597af76b4d0fcd6b8577a095caf Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 23 Jul 2024 20:59:02 -0400 Subject: [PATCH 1/2] Add ColumnNode AST --- .../apache/spark/sql/column/columnNodes.scala | 273 +++++++++++++++++ .../ColumnNodeToExpressionConverter.scala | 200 +++++++++++++ ...ColumnNodeToExpressionConverterSuite.scala | 274 ++++++++++++++++++ 3 files changed, 747 insertions(+) create mode 100644 sql/api/src/main/scala/org/apache/spark/sql/column/columnNodes.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/expressions/ColumnNodeToExpressionConverter.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/expressions/ColumnNodeToExpressionConverterSuite.scala diff --git a/sql/api/src/main/scala/org/apache/spark/sql/column/columnNodes.scala b/sql/api/src/main/scala/org/apache/spark/sql/column/columnNodes.scala new file mode 100644 index 000000000000..caed5aca8fbc --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/column/columnNodes.scala @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.column + +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} +import org.apache.spark.sql.types.{DataType, Metadata} + +/** + * AST for constructing columns. This API is implementation agnostic and allows us to build a + * single Column implementation that can be shared between implementations. Consequently a + * Dataframe API implementation will have to provide conversions from this AST to its + * implementation specific form (e.g. Catalyst expressions, or Connect protobuf messages). + * + * This API is a mirror image of Connect's expression.proto. There are a couple of extensions to + * make constructing nodes easier (e.g. [[CaseWhenOtherwise]]). We could not use the actual connect + * protobuf messages because of classpath clashes (e.g. Guava & gRPC) and Maven shading issues. + */ +private[sql] trait ColumnNode { + /** + * Origin where the node was created. + */ + def origin: Origin +} + +/** + * A literal column. + * + * @param value of the literal. This is the unconverted input value. + * @param dataType of the literal. If none is provided the dataType is inferred. + */ +private[sql] case class Literal( + value: Any, + dataType: Option[DataType] = None, + override val origin: Origin = CurrentOrigin.get) extends ColumnNode + +/** + * Reference to an attribute produced by one of the underlying DataFrames. + * + * @param unparsedIdentifier name of the attribute. + * @param planId id of the plan (Dataframe) that produces the attribute. + * @param isMetadataColumn whether this is a metadata column. + */ +private[sql] case class UnresolvedAttribute( + unparsedIdentifier: String, + planId: Option[Long] = None, + isMetadataColumn: Boolean = false, + override val origin: Origin = CurrentOrigin.get) + extends ColumnNode + +/** + * Reference to all columns in a namespace (global, a Dataframe, or a nested struct). + * + * @param unparsedTarget name of the namespace. None if the global namespace is supposed to be used.
+ * @param planId id of the plan (Dataframe) that produces the attribute. + */ +private[sql] case class UnresolvedStar( + unparsedTarget: Option[String], + planId: Option[Long] = None, + override val origin: Origin = CurrentOrigin.get) + extends ColumnNode + +/** + * Call a function. This can either be a built-in function, a UDF, or a UDF registered in the + * Catalog. + * + * @param functionName of the function to invoke. + * @param arguments to pass into the function. + * @param isDistinct (aggregate only) whether the input of the aggregate function should be + * de-duplicated. + */ +private[sql] case class UnresolvedFunction( + functionName: String, + arguments: Seq[ColumnNode], + isDistinct: Boolean = false, + isUserDefinedFunction: Boolean = false, + override val origin: Origin = CurrentOrigin.get) + extends ColumnNode + +/** + * Evaluate a SQL expression. + * + * @param expression text to execute. + */ +private[sql] case class SqlExpression( + expression: String, + override val origin: Origin = CurrentOrigin.get) extends ColumnNode + +/** + * Name a column, and (optionally) modify its metadata. + * + * @param child to name + * @param name to use + * @param metadata (optional) metadata to add. + */ +private[sql] case class Alias( + child: ColumnNode, + name: Seq[String], + metadata: Option[Metadata] = None, + override val origin: Origin = CurrentOrigin.get) extends ColumnNode + +/** + * Cast the value of a Column to a different [[DataType]]. The behavior of the cast can be + * influenced by the `evalMode`. + * + * @param child that produces the input value. + * @param dataType to cast to. + * @param evalMode (try/ansi/legacy) to use for the cast. 
+ */ +private[sql] case class Cast( + child: ColumnNode, + dataType: DataType, + evalMode: Option[Cast.EvalMode.Value] = None, + override val origin: Origin = CurrentOrigin.get) extends ColumnNode + +private[sql] object Cast { + object EvalMode extends Enumeration { + type EvalMode = Value + val Legacy, Ansi, Try = Value + } +} + +/** + * Reference to all columns in the global namespace that match a regex. + * + * @param regex that the column names must match. + * @param planId id of the plan (Dataframe) that produces the attribute. + */ +private[sql] case class UnresolvedRegex( + regex: String, + planId: Option[Long] = None, + override val origin: Origin = CurrentOrigin.get) extends ColumnNode + +/** + * Sort the input column. + * + * @param child to sort. + * @param sortDirection to sort in, either Ascending or Descending. + * @param nullOrdering where to place nulls, either at the beginning or the end. + */ +private[sql] case class SortOrder( + child: ColumnNode, + sortDirection: SortOrder.SortDirection.Value, + nullOrdering: SortOrder.NullOrdering.Value, + override val origin: Origin = CurrentOrigin.get) + extends ColumnNode + +private[sql] object SortOrder { + object SortDirection extends Enumeration { + type SortDirection = Value + val Ascending, Descending = Value + } + object NullOrdering extends Enumeration { + type NullOrdering = Value + val NullsFirst, NullsLast = Value + } +} + +/** + * Evaluate a function within a window. + * + * @param windowFunction function to execute. + * @param windowSpec of the window.
+ */ +private[sql] case class Window( + windowFunction: ColumnNode, + windowSpec: WindowSpec, + override val origin: Origin = CurrentOrigin.get) + extends ColumnNode + +private[sql] case class WindowSpec( + partitionColumns: Seq[ColumnNode], + sortColumns: Seq[SortOrder], + frame: Option[WindowFrame] = None) + +private[sql] case class WindowFrame( + frameType: WindowFrame.FrameType.Value, + lower: WindowFrame.FrameBoundary, + upper: WindowFrame.FrameBoundary) + +private[sql] object WindowFrame { + object FrameType extends Enumeration { + type FrameType = this.Value + val Row, Range = this.Value + } + + sealed trait FrameBoundary + object CurrentRow extends FrameBoundary + object Unbounded extends FrameBoundary + case class Value(value: ColumnNode) extends FrameBoundary +} + +/** + * Lambda function to execute. This is typically passed as an argument to a function. + * + * @param function to execute. + * @param arguments the bound lambda variables. + */ +private[sql] case class LambdaFunction( + function: ColumnNode, + arguments: Seq[UnresolvedNamedLambdaVariable], + override val origin: Origin = CurrentOrigin.get) extends ColumnNode + +/** + * Variable used in a [[LambdaFunction]]. + * + * @param name of the variable. + */ +private[sql] case class UnresolvedNamedLambdaVariable( + name: String, + override val origin: Origin = CurrentOrigin.get) extends ColumnNode + +/** + * Extract a value from a complex type. This can be a field from a struct, a value from a map, + * or an element from an array. + * + * @param child that produces a complex value. + * @param extraction that is used to access the complex type. This needs to be a string type for + * structs and maps, and it needs to be an integer for arrays. + */ +private[sql] case class UnresolvedExtractValue( + child: ColumnNode, + extraction: ColumnNode, + override val origin: Origin = CurrentOrigin.get) extends ColumnNode + +/** + * Update or drop the field of a struct.
+ * + * @param structExpression that will be updated. + * @param fieldName name of the field to update. + * @param valueExpression new value of the field. If this is None the field will be dropped. + */ +private[sql] case class UpdateFields( + structExpression: ColumnNode, + fieldName: String, + valueExpression: Option[ColumnNode] = None, + override val origin: Origin = CurrentOrigin.get) extends ColumnNode + +/** + * Evaluate one or more conditional branches. The value of the first branch for which the predicate + * evaluates to true is returned. If none of the branches evaluate to true, the value of `otherwise` + * is returned. + * + * @param branches to evaluate. Each entry is a pair of condition and value. + * @param otherwise (optional) to evaluate when none of the branches evaluate to true. + */ +private[sql] case class CaseWhenOtherwise( + branches: Seq[(ColumnNode, ColumnNode)], + otherwise: Option[ColumnNode] = None, + override val origin: Origin = CurrentOrigin.get) + extends ColumnNode + +/** + * Extension point that allows an implementation to use its column representation to be used in a + * generic column expression. This should only be used when the Column constructed is used within + * the implementation. + */ +private[sql] case class Extension( + value: Any, + override val origin: Origin = CurrentOrigin.get) extends ColumnNode diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ColumnNodeToExpressionConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ColumnNodeToExpressionConverter.scala new file mode 100644 index 000000000000..fb32c0e8398c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ColumnNodeToExpressionConverter.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.expressions + +import org.apache.spark.SparkException +import org.apache.spark.sql.{column, SparkSession} +import org.apache.spark.sql.catalyst.analysis.{MultiAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, Cast, CurrentRow, Descending, EvalMode, Expression, LambdaFunction, Literal, NullsFirst, NullsLast, RangeFrame, RowFrame, SortOrder, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, UnresolvedNamedLambdaVariable, UnspecifiedFrame, UpdateFields, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.parser.{ParserInterface, ParserUtils} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.CurrentOrigin +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.execution.SparkSqlParser +import org.apache.spark.sql.internal.SQLConf + +/** + * Convert a [[column.ColumnNode]] into an [[Expression]]. 
+ */ +private[sql] trait ColumnNodeToExpressionConverter extends (column.ColumnNode => Expression) { + + protected def parser: ParserInterface + protected def conf: SQLConf + + override def apply(node: column.ColumnNode): Expression = CurrentOrigin.withOrigin(node.origin) { + node match { + case column.Literal(value, Some(dataType), _) => + Literal.create(value, dataType) + + case column.Literal(value, None, _) => + Literal(value) + + case column.UnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn, _) => + convertUnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn) + + case column.UnresolvedStar(unparsedTarget, None, _) => + UnresolvedStar(unparsedTarget.map(UnresolvedAttribute.parseAttributeName)) + + case column.UnresolvedStar(None, Some(planId), _) => + UnresolvedDataFrameStar(planId) + + case column.UnresolvedRegex(ParserUtils.escapedIdentifier(columnNameRegex), _, _) => + UnresolvedRegex(columnNameRegex, None, conf.caseSensitiveAnalysis) + + case column.UnresolvedRegex( + ParserUtils.qualifiedEscapedIdentifier(nameParts, columnNameRegex), _, _) => + UnresolvedRegex(columnNameRegex, Some(nameParts), conf.caseSensitiveAnalysis) + + case column.UnresolvedRegex(unparsedIdentifier, planId, _) => + convertUnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn = false) + + case column.UnresolvedFunction(functionName, arguments, isDistinct, isUDF, _) => + val nameParts = if (isUDF) { + parser.parseMultipartIdentifier(functionName) + } else { + Seq(functionName) + } + UnresolvedFunction(nameParts, arguments.map(apply), isDistinct) + + case column.Alias(child, Seq(name), metadata, _) => + Alias(apply(child), name)(explicitMetadata = metadata) + + case column.Alias(child, names, None, _) if names.nonEmpty => + MultiAlias(apply(child), names) + + case column.Cast(child, dataType, evalMode, _) => + val convertedEvalMode = evalMode match { + case Some(column.Cast.EvalMode.Ansi) => EvalMode.ANSI + case Some(column.Cast.EvalMode.Legacy) => 
EvalMode.LEGACY + case Some(column.Cast.EvalMode.Try) => EvalMode.TRY + case _ => EvalMode.fromSQLConf(conf) + } + val cast = Cast( + apply(child), + CharVarcharUtils.replaceCharVarcharWithStringForCast(dataType), + None, + convertedEvalMode) + cast.setTagValue(Cast.USER_SPECIFIED_CAST, ()) + cast + + case column.SqlExpression(expression, _) => + parser.parseExpression(expression) + + case sortOrder: column.SortOrder => + convertSortOrder(sortOrder) + + case column.Window(function, spec, _) => + val frame = spec.frame match { + case Some(column.WindowFrame(frameType, lower, upper)) => + val convertedFrameType = frameType match { + case column.WindowFrame.FrameType.Range => RangeFrame + case column.WindowFrame.FrameType.Row => RowFrame + } + val convertedLower = lower match { + case column.WindowFrame.CurrentRow => CurrentRow + case column.WindowFrame.Unbounded => UnboundedPreceding + case column.WindowFrame.Value(node) => apply(node) + } + val convertedUpper = upper match { + case column.WindowFrame.CurrentRow => CurrentRow + case column.WindowFrame.Unbounded => UnboundedFollowing + case column.WindowFrame.Value(node) => apply(node) + } + SpecifiedWindowFrame(convertedFrameType, convertedLower, convertedUpper) + case None => + UnspecifiedFrame + } + WindowExpression( + apply(function), + WindowSpecDefinition( + partitionSpec = spec.partitionColumns.map(apply), + orderSpec = spec.sortColumns.map(convertSortOrder), + frameSpecification = frame)) + + case column.LambdaFunction(function, arguments, _) => + LambdaFunction( + apply(function), + arguments.map(convertUnresolvedNamedLambdaVariable)) + + case v: column.UnresolvedNamedLambdaVariable => + convertUnresolvedNamedLambdaVariable(v) + + case column.UnresolvedExtractValue(child, extraction, _) => + UnresolvedExtractValue(apply(child), apply(extraction)) + + case column.UpdateFields(struct, field, Some(value), _) => + UpdateFields(apply(struct), field, apply(value)) + + case column.UpdateFields(struct, field, None, 
_) => + UpdateFields(apply(struct), field) + + case column.CaseWhenOtherwise(branches, otherwise, _) => + CaseWhen( + branches = branches.map { case (condition, value) => + (apply(condition), apply(value)) + }, + elseValue = otherwise.map(apply)) + + case column.Extension(expression: Expression, _) => + expression + + case node => + throw SparkException.internalError("Unsupported ColumnNode: " + node) + } + } + + private def convertUnresolvedNamedLambdaVariable( + v: column.UnresolvedNamedLambdaVariable): UnresolvedNamedLambdaVariable = { + UnresolvedNamedLambdaVariable(Seq(v.name)) + } + + private def convertSortOrder(sortOrder: column.SortOrder): SortOrder = { + val sortDirection = sortOrder.sortDirection match { + case column.SortOrder.SortDirection.Ascending => Ascending + case column.SortOrder.SortDirection.Descending => Descending + } + val nullOrdering = sortOrder.nullOrdering match { + case column.SortOrder.NullOrdering.NullsFirst => NullsFirst + case column.SortOrder.NullOrdering.NullsLast => NullsLast + } + SortOrder(apply(sortOrder.child), sortDirection, nullOrdering, Nil) + } + + private def convertUnresolvedAttribute( + unparsedIdentifier: String, + planId: Option[Long], + isMetadataColumn: Boolean): UnresolvedAttribute = { + val attribute = UnresolvedAttribute.quotedString(unparsedIdentifier) + if (planId.isDefined) { + attribute.setTagValue(LogicalPlan.PLAN_ID_TAG, planId.get) + } + if (isMetadataColumn) { + attribute.setTagValue(LogicalPlan.IS_METADATA_COL, ()) + } + attribute + } +} + +object ColumnNodeToExpressionConverter extends ColumnNodeToExpressionConverter { + override protected def parser: ParserInterface = { + SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse { + new SparkSqlParser() + } + } + + override protected def conf: SQLConf = SQLConf.get +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ColumnNodeToExpressionConverterSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ColumnNodeToExpressionConverterSuite.scala new file mode 100644 index 000000000000..8a49caef6820 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ColumnNodeToExpressionConverterSuite.scala @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.expressions + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.sql.catalyst.analysis.{MultiAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, Cast, CurrentRow, Descending, DropField, EvalMode, Expression, ExprId, FrameType, LambdaFunction, Literal, NullOrdering, NullsFirst, NullsLast, RangeFrame, RowFrame, SortDirection, SortOrder, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, UnresolvedNamedLambdaVariable, UnspecifiedFrame, UpdateFields, WindowExpression, WindowSpecDefinition, WithField} +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} +import org.apache.spark.sql.column +import org.apache.spark.sql.execution.SparkSqlParser +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{BinaryType, DataType, DoubleType, IntegerType, LongType, ShortType, StringType} + +/** + * Test suite for [[column.ColumnNode]] to [[Expression]] conversions. 
+ */ +class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { + private object Converter extends ColumnNodeToExpressionConverter { + override val conf: SQLConf = new SQLConf + override val parser: ParserInterface = new SparkSqlParser + } + + private def testConversion(node: => column.ColumnNode, expected: Expression): Expression = { + val myOrigin = Origin() + CurrentOrigin.withOrigin(myOrigin) { + val expression = normalizeExpression(Converter(node)) + assert(expression == normalizeExpression(expected)) + assert(expression.origin eq myOrigin) + expression + } + } + + private def normalizeExpression(e: Expression): Expression = e.transform { + case a: Alias => + a.copy()(exprId = ExprId(0), a.qualifier, a.explicitMetadata, a.nonInheritableMetadataKeys) + } + + test("literal") { + testConversion(column.Literal(1), Literal(1, IntegerType)) + testConversion(column.Literal("foo", Option(StringType)), Literal.create("foo", StringType)) + } + + test("attribute") { + val expression1 = testConversion(column.UnresolvedAttribute("x"), UnresolvedAttribute("x")) + assert(expression1.getTagValue(LogicalPlan.PLAN_ID_TAG).isEmpty) + assert(expression1.getTagValue(LogicalPlan.IS_METADATA_COL).isEmpty) + + val expression2 = testConversion( + column.UnresolvedAttribute("y", Option(44L), isMetadataColumn = true), + UnresolvedAttribute("y")) + assert(expression2.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(44L)) + assert(expression2.getTagValue(LogicalPlan.IS_METADATA_COL).isDefined) + } + + test("star") { + testConversion(column.UnresolvedStar(None), UnresolvedStar(None)) + testConversion( + column.UnresolvedStar(Option("x.y.z")), + UnresolvedStar(Option(Seq("x", "y", "z")))) + testConversion( + column.UnresolvedStar(None, Option(10L)), + UnresolvedDataFrameStar(10L)) + } + + test("regex") { + testConversion( + column.UnresolvedRegex("`(_1)?+.+`"), + UnresolvedRegex("(_1)?+.+", None, caseSensitive = false)) + + val expression = testConversion( + 
column.UnresolvedRegex("a", planId = Option(11L)), + UnresolvedAttribute("a")) + assert(expression.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(11L)) + assert(expression.getTagValue(LogicalPlan.IS_METADATA_COL).isEmpty) + } + + test("function") { + testConversion( + column.UnresolvedFunction("+", Seq(column.UnresolvedAttribute("a"), column.Literal(1))), + UnresolvedFunction(Seq("+"), Seq(UnresolvedAttribute("a"), Literal(1)), isDistinct = false)) + testConversion( + column.UnresolvedFunction( + "db1.myAgg", + Seq(column.UnresolvedAttribute("a")), + isDistinct = true, + isUserDefinedFunction = true), + UnresolvedFunction( + Seq("db1", "myAgg"), + Seq(UnresolvedAttribute("a")), + isDistinct = true)) + } + + test("alias") { + testConversion( + column.Alias(column.Literal("qwe"), "newA" :: Nil), + Alias(Literal("qwe"), "newA")()) + testConversion( + column.Alias(column.UnresolvedAttribute("complex"), "newA" :: "newB" :: Nil), + MultiAlias(UnresolvedAttribute("complex"), Seq("newA", "newB"))) + } + + private def testCast( + dataType: DataType, + colEvalMode: column.Cast.EvalMode.Value, + catEvalMode: EvalMode.Value): Unit = { + testConversion( + column.Cast(column.UnresolvedAttribute("attr"), dataType, Option(colEvalMode)), + Cast(UnresolvedAttribute("attr"), dataType, evalMode = catEvalMode)) + } + + test("cast") { + testConversion( + column.Cast(column.UnresolvedAttribute("str"), DoubleType), + Cast(UnresolvedAttribute("str"), DoubleType)) + + testCast(LongType, column.Cast.EvalMode.Legacy, EvalMode.LEGACY) + testCast(BinaryType, column.Cast.EvalMode.Try, EvalMode.TRY) + testCast(ShortType, column.Cast.EvalMode.Ansi, EvalMode.ANSI) + } + + private def testSortOrder( + colDirection: column.SortOrder.SortDirection.SortDirection, + colNullOrdering: column.SortOrder.NullOrdering.NullOrdering, + catDirection: SortDirection, + catNullOrdering: NullOrdering): Unit = { + testConversion( + column.SortOrder(column.UnresolvedAttribute("unsorted"), colDirection, 
colNullOrdering), + new SortOrder(UnresolvedAttribute("unsorted"), catDirection, catNullOrdering, Nil)) + } + + test("sortOrder") { + testSortOrder( + column.SortOrder.SortDirection.Ascending, + column.SortOrder.NullOrdering.NullsFirst, + Ascending, + NullsFirst) + testSortOrder( + column.SortOrder.SortDirection.Ascending, + column.SortOrder.NullOrdering.NullsLast, + Ascending, + NullsLast) + testSortOrder( + column.SortOrder.SortDirection.Descending, + column.SortOrder.NullOrdering.NullsFirst, + Descending, + NullsFirst) + testSortOrder( + column.SortOrder.SortDirection.Descending, + column.SortOrder.NullOrdering.NullsLast, + Descending, + NullsLast) + } + + private def testWindowFrame( + colFrameType: column.WindowFrame.FrameType.FrameType, + colLower: column.WindowFrame.FrameBoundary, + colUpper: column.WindowFrame.FrameBoundary, + catFrameType: FrameType, + catLower: Expression, + catUpper: Expression): Unit = { + testConversion( + column.Window( + column.UnresolvedFunction("sum", Seq(column.UnresolvedAttribute("a"))), + column.WindowSpec( + Seq(column.UnresolvedAttribute("b"), column.UnresolvedAttribute("c")), + Seq(column.SortOrder( + column.UnresolvedAttribute("d"), + column.SortOrder.SortDirection.Descending, + column.SortOrder.NullOrdering.NullsLast)), + Option(column.WindowFrame(colFrameType, colLower, colUpper)))), + WindowExpression( + UnresolvedFunction("sum", Seq(UnresolvedAttribute("a")), isDistinct = false), + WindowSpecDefinition( + Seq(UnresolvedAttribute("b"), UnresolvedAttribute("c")), + Seq(SortOrder(UnresolvedAttribute("d"), Descending, NullsLast, Nil)), + SpecifiedWindowFrame(catFrameType, catLower, catUpper)))) + } + + test("window") { + testConversion( + column.Window( + column.UnresolvedFunction("sum", Seq(column.UnresolvedAttribute("a"))), + column.WindowSpec( + Seq(column.UnresolvedAttribute("b"), column.UnresolvedAttribute("c")), + Nil, + None)), + WindowExpression( + UnresolvedFunction("sum", Seq(UnresolvedAttribute("a")), isDistinct = 
false), + WindowSpecDefinition( + Seq(UnresolvedAttribute("b"), UnresolvedAttribute("c")), + Nil, + UnspecifiedFrame))) + testWindowFrame( + column.WindowFrame.FrameType.Row, + column.WindowFrame.Value(column.Literal(-10)), + column.WindowFrame.Unbounded, + RowFrame, + Literal(-10), + UnboundedFollowing) + testWindowFrame( + column.WindowFrame.FrameType.Range, + column.WindowFrame.Unbounded, + column.WindowFrame.CurrentRow, + RangeFrame, + UnboundedPreceding, + CurrentRow) + } + + test("lambda") { + val colX = column.UnresolvedNamedLambdaVariable("x") + val catX = UnresolvedNamedLambdaVariable(Seq("x")) + testConversion( + column.LambdaFunction( + column.UnresolvedFunction("+", Seq(colX, column.UnresolvedAttribute("y"))), + Seq(colX)), + LambdaFunction( + UnresolvedFunction("+", Seq(catX, UnresolvedAttribute("y")), isDistinct = false), + Seq(catX))) + } + + test("caseWhen") { + testConversion( + column.CaseWhenOtherwise( + Seq(column.UnresolvedAttribute("c1") -> column.Literal("r1")), + Option(column.Literal("fallback"))), + CaseWhen( + Seq(UnresolvedAttribute("c1") -> Literal("r1")), + Option(Literal("fallback"))) + ) + } + + test("extract field") { + testConversion( + column.UnresolvedExtractValue(column.UnresolvedAttribute("struct"), column.Literal("cl_a")), + UnresolvedExtractValue(UnresolvedAttribute("struct"), Literal("cl_a"))) + } + + test("update field") { + testConversion( + column.UpdateFields( + column.UnresolvedAttribute("struct"), + "col_b", + Option(column.Literal("cl_a"))), + UpdateFields(UnresolvedAttribute("struct"), Seq(WithField("col_b", Literal("cl_a"))))) + + testConversion( + column.UpdateFields(column.UnresolvedAttribute("struct"), "col_c", None), + UpdateFields(UnresolvedAttribute("struct"), Seq(DropField("col_c")))) + } + + test("extension") { + testConversion(column.Extension(UnresolvedAttribute("bar")), UnresolvedAttribute("bar")) + } + + test("unsupported") { + intercept[SparkException](Converter(column.Extension("kaboom"))) + } +} From 
2d30e57bdd918ab502ff5ff46cada19388ca1fa8 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 25 Jul 2024 23:16:39 -0400 Subject: [PATCH 2/2] Move to internal --- .../{column => internal}/columnNodes.scala | 2 +- .../ColumnNodeToExpressionConverter.scala | 200 ------------ .../ColumnNodeToExpressionConverter.scala | 200 ++++++++++++ ...ColumnNodeToExpressionConverterSuite.scala | 274 ---------------- ...ColumnNodeToExpressionConverterSuite.scala | 298 ++++++++++++++++++ 5 files changed, 499 insertions(+), 475 deletions(-) rename sql/api/src/main/scala/org/apache/spark/sql/{column => internal}/columnNodes.scala (99%) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/expressions/ColumnNodeToExpressionConverter.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverter.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/expressions/ColumnNodeToExpressionConverterSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala diff --git a/sql/api/src/main/scala/org/apache/spark/sql/column/columnNodes.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala similarity index 99% rename from sql/api/src/main/scala/org/apache/spark/sql/column/columnNodes.scala rename to sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala index caed5aca8fbc..971ad4281747 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/column/columnNodes.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql.column +package org.apache.spark.sql.internal import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} import org.apache.spark.sql.types.{DataType, Metadata} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ColumnNodeToExpressionConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ColumnNodeToExpressionConverter.scala deleted file mode 100644 index fb32c0e8398c..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ColumnNodeToExpressionConverter.scala +++ /dev/null @@ -1,200 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.sql.expressions - -import org.apache.spark.SparkException -import org.apache.spark.sql.{column, SparkSession} -import org.apache.spark.sql.catalyst.analysis.{MultiAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, Cast, CurrentRow, Descending, EvalMode, Expression, LambdaFunction, Literal, NullsFirst, NullsLast, RangeFrame, RowFrame, SortOrder, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, UnresolvedNamedLambdaVariable, UnspecifiedFrame, UpdateFields, WindowExpression, WindowSpecDefinition} -import org.apache.spark.sql.catalyst.parser.{ParserInterface, ParserUtils} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.trees.CurrentOrigin -import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.execution.SparkSqlParser -import org.apache.spark.sql.internal.SQLConf - -/** - * Convert a [[column.ColumnNode]] into an [[Expression]]. 
- */ -private[sql] trait ColumnNodeToExpressionConverter extends (column.ColumnNode => Expression) { - - protected def parser: ParserInterface - protected def conf: SQLConf - - override def apply(node: column.ColumnNode): Expression = CurrentOrigin.withOrigin(node.origin) { - node match { - case column.Literal(value, Some(dataType), _) => - Literal.create(value, dataType) - - case column.Literal(value, None, _) => - Literal(value) - - case column.UnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn, _) => - convertUnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn) - - case column.UnresolvedStar(unparsedTarget, None, _) => - UnresolvedStar(unparsedTarget.map(UnresolvedAttribute.parseAttributeName)) - - case column.UnresolvedStar(None, Some(planId), _) => - UnresolvedDataFrameStar(planId) - - case column.UnresolvedRegex(ParserUtils.escapedIdentifier(columnNameRegex), _, _) => - UnresolvedRegex(columnNameRegex, None, conf.caseSensitiveAnalysis) - - case column.UnresolvedRegex( - ParserUtils.qualifiedEscapedIdentifier(nameParts, columnNameRegex), _, _) => - UnresolvedRegex(columnNameRegex, Some(nameParts), conf.caseSensitiveAnalysis) - - case column.UnresolvedRegex(unparsedIdentifier, planId, _) => - convertUnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn = false) - - case column.UnresolvedFunction(functionName, arguments, isDistinct, isUDF, _) => - val nameParts = if (isUDF) { - parser.parseMultipartIdentifier(functionName) - } else { - Seq(functionName) - } - UnresolvedFunction(nameParts, arguments.map(apply), isDistinct) - - case column.Alias(child, Seq(name), metadata, _) => - Alias(apply(child), name)(explicitMetadata = metadata) - - case column.Alias(child, names, None, _) if names.nonEmpty => - MultiAlias(apply(child), names) - - case column.Cast(child, dataType, evalMode, _) => - val convertedEvalMode = evalMode match { - case Some(column.Cast.EvalMode.Ansi) => EvalMode.ANSI - case Some(column.Cast.EvalMode.Legacy) => 
EvalMode.LEGACY - case Some(column.Cast.EvalMode.Try) => EvalMode.TRY - case _ => EvalMode.fromSQLConf(conf) - } - val cast = Cast( - apply(child), - CharVarcharUtils.replaceCharVarcharWithStringForCast(dataType), - None, - convertedEvalMode) - cast.setTagValue(Cast.USER_SPECIFIED_CAST, ()) - cast - - case column.SqlExpression(expression, _) => - parser.parseExpression(expression) - - case sortOrder: column.SortOrder => - convertSortOrder(sortOrder) - - case column.Window(function, spec, _) => - val frame = spec.frame match { - case Some(column.WindowFrame(frameType, lower, upper)) => - val convertedFrameType = frameType match { - case column.WindowFrame.FrameType.Range => RangeFrame - case column.WindowFrame.FrameType.Row => RowFrame - } - val convertedLower = lower match { - case column.WindowFrame.CurrentRow => CurrentRow - case column.WindowFrame.Unbounded => UnboundedPreceding - case column.WindowFrame.Value(node) => apply(node) - } - val convertedUpper = upper match { - case column.WindowFrame.CurrentRow => CurrentRow - case column.WindowFrame.Unbounded => UnboundedFollowing - case column.WindowFrame.Value(node) => apply(node) - } - SpecifiedWindowFrame(convertedFrameType, convertedLower, convertedUpper) - case None => - UnspecifiedFrame - } - WindowExpression( - apply(function), - WindowSpecDefinition( - partitionSpec = spec.partitionColumns.map(apply), - orderSpec = spec.sortColumns.map(convertSortOrder), - frameSpecification = frame)) - - case column.LambdaFunction(function, arguments, _) => - LambdaFunction( - apply(function), - arguments.map(convertUnresolvedNamedLambdaVariable)) - - case v: column.UnresolvedNamedLambdaVariable => - convertUnresolvedNamedLambdaVariable(v) - - case column.UnresolvedExtractValue(child, extraction, _) => - UnresolvedExtractValue(apply(child), apply(extraction)) - - case column.UpdateFields(struct, field, Some(value), _) => - UpdateFields(apply(struct), field, apply(value)) - - case column.UpdateFields(struct, field, None, 
_) => - UpdateFields(apply(struct), field) - - case column.CaseWhenOtherwise(branches, otherwise, _) => - CaseWhen( - branches = branches.map { case (condition, value) => - (apply(condition), apply(value)) - }, - elseValue = otherwise.map(apply)) - - case column.Extension(expression: Expression, _) => - expression - - case node => - throw SparkException.internalError("Unsupported ColumnNode: " + node) - } - } - - private def convertUnresolvedNamedLambdaVariable( - v: column.UnresolvedNamedLambdaVariable): UnresolvedNamedLambdaVariable = { - UnresolvedNamedLambdaVariable(Seq(v.name)) - } - - private def convertSortOrder(sortOrder: column.SortOrder): SortOrder = { - val sortDirection = sortOrder.sortDirection match { - case column.SortOrder.SortDirection.Ascending => Ascending - case column.SortOrder.SortDirection.Descending => Descending - } - val nullOrdering = sortOrder.nullOrdering match { - case column.SortOrder.NullOrdering.NullsFirst => NullsFirst - case column.SortOrder.NullOrdering.NullsLast => NullsLast - } - SortOrder(apply(sortOrder.child), sortDirection, nullOrdering, Nil) - } - - private def convertUnresolvedAttribute( - unparsedIdentifier: String, - planId: Option[Long], - isMetadataColumn: Boolean): UnresolvedAttribute = { - val attribute = UnresolvedAttribute.quotedString(unparsedIdentifier) - if (planId.isDefined) { - attribute.setTagValue(LogicalPlan.PLAN_ID_TAG, planId.get) - } - if (isMetadataColumn) { - attribute.setTagValue(LogicalPlan.IS_METADATA_COL, ()) - } - attribute - } -} - -object ColumnNodeToExpressionConverter extends ColumnNodeToExpressionConverter { - override protected def parser: ParserInterface = { - SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse { - new SparkSqlParser() - } - } - - override protected def conf: SQLConf = SQLConf.get -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverter.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverter.scala new file mode 100644 index 000000000000..ff2d6f72c86a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverter.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.internal + +import org.apache.spark.SparkException +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.parser.{ParserInterface, ParserUtils} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.CurrentOrigin +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.execution.SparkSqlParser + +/** + * Convert a [[ColumnNode]] into an [[Expression]]. 
+ */ +private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expression) { + + protected def parser: ParserInterface + protected def conf: SQLConf + + override def apply(node: ColumnNode): Expression = CurrentOrigin.withOrigin(node.origin) { + node match { + case Literal(value, Some(dataType), _) => + expressions.Literal.create(value, dataType) + + case Literal(value, None, _) => + expressions.Literal(value) + + case UnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn, _) => + convertUnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn) + + case UnresolvedStar(unparsedTarget, None, _) => + analysis.UnresolvedStar(unparsedTarget.map(analysis.UnresolvedAttribute.parseAttributeName)) + + case UnresolvedStar(None, Some(planId), _) => + analysis.UnresolvedDataFrameStar(planId) + + case UnresolvedRegex(ParserUtils.escapedIdentifier(columnNameRegex), _, _) => + analysis.UnresolvedRegex(columnNameRegex, None, conf.caseSensitiveAnalysis) + + case UnresolvedRegex( + ParserUtils.qualifiedEscapedIdentifier(nameParts, columnNameRegex), _, _) => + analysis.UnresolvedRegex(columnNameRegex, Some(nameParts), conf.caseSensitiveAnalysis) + + case UnresolvedRegex(unparsedIdentifier, planId, _) => + convertUnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn = false) + + case UnresolvedFunction(functionName, arguments, isDistinct, isUDF, _) => + val nameParts = if (isUDF) { + parser.parseMultipartIdentifier(functionName) + } else { + Seq(functionName) + } + analysis.UnresolvedFunction(nameParts, arguments.map(apply), isDistinct) + + case Alias(child, Seq(name), metadata, _) => + expressions.Alias(apply(child), name)(explicitMetadata = metadata) + + case Alias(child, names, None, _) if names.nonEmpty => + analysis.MultiAlias(apply(child), names) + + case Cast(child, dataType, evalMode, _) => + val convertedEvalMode = evalMode match { + case Some(Cast.EvalMode.Ansi) => expressions.EvalMode.ANSI + case Some(Cast.EvalMode.Legacy) => 
expressions.EvalMode.LEGACY + case Some(Cast.EvalMode.Try) => expressions.EvalMode.TRY + case _ => expressions.EvalMode.fromSQLConf(conf) + } + val cast = expressions.Cast( + apply(child), + CharVarcharUtils.replaceCharVarcharWithStringForCast(dataType), + None, + convertedEvalMode) + cast.setTagValue(expressions.Cast.USER_SPECIFIED_CAST, ()) + cast + + case SqlExpression(expression, _) => + parser.parseExpression(expression) + + case sortOrder: SortOrder => + convertSortOrder(sortOrder) + + case Window(function, spec, _) => + val frame = spec.frame match { + case Some(WindowFrame(frameType, lower, upper)) => + val convertedFrameType = frameType match { + case WindowFrame.FrameType.Range => expressions.RangeFrame + case WindowFrame.FrameType.Row => expressions.RowFrame + } + val convertedLower = lower match { + case WindowFrame.CurrentRow => expressions.CurrentRow + case WindowFrame.Unbounded => expressions.UnboundedPreceding + case WindowFrame.Value(node) => apply(node) + } + val convertedUpper = upper match { + case WindowFrame.CurrentRow => expressions.CurrentRow + case WindowFrame.Unbounded => expressions.UnboundedFollowing + case WindowFrame.Value(node) => apply(node) + } + expressions.SpecifiedWindowFrame(convertedFrameType, convertedLower, convertedUpper) + case None => + expressions.UnspecifiedFrame + } + expressions.WindowExpression( + apply(function), + expressions.WindowSpecDefinition( + partitionSpec = spec.partitionColumns.map(apply), + orderSpec = spec.sortColumns.map(convertSortOrder), + frameSpecification = frame)) + + case LambdaFunction(function, arguments, _) => + expressions.LambdaFunction( + apply(function), + arguments.map(convertUnresolvedNamedLambdaVariable)) + + case v: UnresolvedNamedLambdaVariable => + convertUnresolvedNamedLambdaVariable(v) + + case UnresolvedExtractValue(child, extraction, _) => + analysis.UnresolvedExtractValue(apply(child), apply(extraction)) + + case UpdateFields(struct, field, Some(value), _) => + 
expressions.UpdateFields(apply(struct), field, apply(value)) + + case UpdateFields(struct, field, None, _) => + expressions.UpdateFields(apply(struct), field) + + case CaseWhenOtherwise(branches, otherwise, _) => + expressions.CaseWhen( + branches = branches.map { case (condition, value) => + (apply(condition), apply(value)) + }, + elseValue = otherwise.map(apply)) + + case Extension(expression: Expression, _) => + expression + + case node => + throw SparkException.internalError("Unsupported ColumnNode: " + node) + } + } + + private def convertUnresolvedNamedLambdaVariable( + v: UnresolvedNamedLambdaVariable): expressions.UnresolvedNamedLambdaVariable = { + expressions.UnresolvedNamedLambdaVariable(Seq(v.name)) + } + + private def convertSortOrder(sortOrder: SortOrder): expressions.SortOrder = { + val sortDirection = sortOrder.sortDirection match { + case SortOrder.SortDirection.Ascending => expressions.Ascending + case SortOrder.SortDirection.Descending => expressions.Descending + } + val nullOrdering = sortOrder.nullOrdering match { + case SortOrder.NullOrdering.NullsFirst => expressions.NullsFirst + case SortOrder.NullOrdering.NullsLast => expressions.NullsLast + } + expressions.SortOrder(apply(sortOrder.child), sortDirection, nullOrdering, Nil) + } + + private def convertUnresolvedAttribute( + unparsedIdentifier: String, + planId: Option[Long], + isMetadataColumn: Boolean): analysis.UnresolvedAttribute = { + val attribute = analysis.UnresolvedAttribute.quotedString(unparsedIdentifier) + if (planId.isDefined) { + attribute.setTagValue(LogicalPlan.PLAN_ID_TAG, planId.get) + } + if (isMetadataColumn) { + attribute.setTagValue(LogicalPlan.IS_METADATA_COL, ()) + } + attribute + } +} + +object ColumnNodeToExpressionConverter extends ColumnNodeToExpressionConverter { + override protected def parser: ParserInterface = { + SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse { + new SparkSqlParser() + } + } + + override protected def conf: SQLConf = 
SQLConf.get +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ColumnNodeToExpressionConverterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ColumnNodeToExpressionConverterSuite.scala deleted file mode 100644 index 8a49caef6820..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ColumnNodeToExpressionConverterSuite.scala +++ /dev/null @@ -1,274 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.sql.expressions - -import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.sql.catalyst.analysis.{MultiAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, Cast, CurrentRow, Descending, DropField, EvalMode, Expression, ExprId, FrameType, LambdaFunction, Literal, NullOrdering, NullsFirst, NullsLast, RangeFrame, RowFrame, SortDirection, SortOrder, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, UnresolvedNamedLambdaVariable, UnspecifiedFrame, UpdateFields, WindowExpression, WindowSpecDefinition, WithField} -import org.apache.spark.sql.catalyst.parser.ParserInterface -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} -import org.apache.spark.sql.column -import org.apache.spark.sql.execution.SparkSqlParser -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{BinaryType, DataType, DoubleType, IntegerType, LongType, ShortType, StringType} - -/** - * Test suite for [[column.ColumnNode]] to [[Expression]] conversions. 
- */ -class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { - private object Converter extends ColumnNodeToExpressionConverter { - override val conf: SQLConf = new SQLConf - override val parser: ParserInterface = new SparkSqlParser - } - - private def testConversion(node: => column.ColumnNode, expected: Expression): Expression = { - val myOrigin = Origin() - CurrentOrigin.withOrigin(myOrigin) { - val expression = normalizeExpression(Converter(node)) - assert(expression == normalizeExpression(expected)) - assert(expression.origin eq myOrigin) - expression - } - } - - private def normalizeExpression(e: Expression): Expression = e.transform { - case a: Alias => - a.copy()(exprId = ExprId(0), a.qualifier, a.explicitMetadata, a.nonInheritableMetadataKeys) - } - - test("literal") { - testConversion(column.Literal(1), Literal(1, IntegerType)) - testConversion(column.Literal("foo", Option(StringType)), Literal.create("foo", StringType)) - } - - test("attribute") { - val expression1 = testConversion(column.UnresolvedAttribute("x"), UnresolvedAttribute("x")) - assert(expression1.getTagValue(LogicalPlan.PLAN_ID_TAG).isEmpty) - assert(expression1.getTagValue(LogicalPlan.IS_METADATA_COL).isEmpty) - - val expression2 = testConversion( - column.UnresolvedAttribute("y", Option(44L), isMetadataColumn = true), - UnresolvedAttribute("y")) - assert(expression2.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(44L)) - assert(expression2.getTagValue(LogicalPlan.IS_METADATA_COL).isDefined) - } - - test("star") { - testConversion(column.UnresolvedStar(None), UnresolvedStar(None)) - testConversion( - column.UnresolvedStar(Option("x.y.z")), - UnresolvedStar(Option(Seq("x", "y", "z")))) - testConversion( - column.UnresolvedStar(None, Option(10L)), - UnresolvedDataFrameStar(10L)) - } - - test("regex") { - testConversion( - column.UnresolvedRegex("`(_1)?+.+`"), - UnresolvedRegex("(_1)?+.+", None, caseSensitive = false)) - - val expression = testConversion( - 
column.UnresolvedRegex("a", planId = Option(11L)), - UnresolvedAttribute("a")) - assert(expression.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(11L)) - assert(expression.getTagValue(LogicalPlan.IS_METADATA_COL).isEmpty) - } - - test("function") { - testConversion( - column.UnresolvedFunction("+", Seq(column.UnresolvedAttribute("a"), column.Literal(1))), - UnresolvedFunction(Seq("+"), Seq(UnresolvedAttribute("a"), Literal(1)), isDistinct = false)) - testConversion( - column.UnresolvedFunction( - "db1.myAgg", - Seq(column.UnresolvedAttribute("a")), - isDistinct = true, - isUserDefinedFunction = true), - UnresolvedFunction( - Seq("db1", "myAgg"), - Seq(UnresolvedAttribute("a")), - isDistinct = true)) - } - - test("alias") { - testConversion( - column.Alias(column.Literal("qwe"), "newA" :: Nil), - Alias(Literal("qwe"), "newA")()) - testConversion( - column.Alias(column.UnresolvedAttribute("complex"), "newA" :: "newB" :: Nil), - MultiAlias(UnresolvedAttribute("complex"), Seq("newA", "newB"))) - } - - private def testCast( - dataType: DataType, - colEvalMode: column.Cast.EvalMode.Value, - catEvalMode: EvalMode.Value): Unit = { - testConversion( - column.Cast(column.UnresolvedAttribute("attr"), dataType, Option(colEvalMode)), - Cast(UnresolvedAttribute("attr"), dataType, evalMode = catEvalMode)) - } - - test("cast") { - testConversion( - column.Cast(column.UnresolvedAttribute("str"), DoubleType), - Cast(UnresolvedAttribute("str"), DoubleType)) - - testCast(LongType, column.Cast.EvalMode.Legacy, EvalMode.LEGACY) - testCast(BinaryType, column.Cast.EvalMode.Try, EvalMode.TRY) - testCast(ShortType, column.Cast.EvalMode.Ansi, EvalMode.ANSI) - } - - private def testSortOrder( - colDirection: column.SortOrder.SortDirection.SortDirection, - colNullOrdering: column.SortOrder.NullOrdering.NullOrdering, - catDirection: SortDirection, - catNullOrdering: NullOrdering): Unit = { - testConversion( - column.SortOrder(column.UnresolvedAttribute("unsorted"), colDirection, 
colNullOrdering), - new SortOrder(UnresolvedAttribute("unsorted"), catDirection, catNullOrdering, Nil)) - } - - test("sortOrder") { - testSortOrder( - column.SortOrder.SortDirection.Ascending, - column.SortOrder.NullOrdering.NullsFirst, - Ascending, - NullsFirst) - testSortOrder( - column.SortOrder.SortDirection.Ascending, - column.SortOrder.NullOrdering.NullsLast, - Ascending, - NullsLast) - testSortOrder( - column.SortOrder.SortDirection.Descending, - column.SortOrder.NullOrdering.NullsFirst, - Descending, - NullsFirst) - testSortOrder( - column.SortOrder.SortDirection.Descending, - column.SortOrder.NullOrdering.NullsLast, - Descending, - NullsLast) - } - - private def testWindowFrame( - colFrameType: column.WindowFrame.FrameType.FrameType, - colLower: column.WindowFrame.FrameBoundary, - colUpper: column.WindowFrame.FrameBoundary, - catFrameType: FrameType, - catLower: Expression, - catUpper: Expression): Unit = { - testConversion( - column.Window( - column.UnresolvedFunction("sum", Seq(column.UnresolvedAttribute("a"))), - column.WindowSpec( - Seq(column.UnresolvedAttribute("b"), column.UnresolvedAttribute("c")), - Seq(column.SortOrder( - column.UnresolvedAttribute("d"), - column.SortOrder.SortDirection.Descending, - column.SortOrder.NullOrdering.NullsLast)), - Option(column.WindowFrame(colFrameType, colLower, colUpper)))), - WindowExpression( - UnresolvedFunction("sum", Seq(UnresolvedAttribute("a")), isDistinct = false), - WindowSpecDefinition( - Seq(UnresolvedAttribute("b"), UnresolvedAttribute("c")), - Seq(SortOrder(UnresolvedAttribute("d"), Descending, NullsLast, Nil)), - SpecifiedWindowFrame(catFrameType, catLower, catUpper)))) - } - - test("window") { - testConversion( - column.Window( - column.UnresolvedFunction("sum", Seq(column.UnresolvedAttribute("a"))), - column.WindowSpec( - Seq(column.UnresolvedAttribute("b"), column.UnresolvedAttribute("c")), - Nil, - None)), - WindowExpression( - UnresolvedFunction("sum", Seq(UnresolvedAttribute("a")), isDistinct = 
false), - WindowSpecDefinition( - Seq(UnresolvedAttribute("b"), UnresolvedAttribute("c")), - Nil, - UnspecifiedFrame))) - testWindowFrame( - column.WindowFrame.FrameType.Row, - column.WindowFrame.Value(column.Literal(-10)), - column.WindowFrame.Unbounded, - RowFrame, - Literal(-10), - UnboundedFollowing) - testWindowFrame( - column.WindowFrame.FrameType.Range, - column.WindowFrame.Unbounded, - column.WindowFrame.CurrentRow, - RangeFrame, - UnboundedPreceding, - CurrentRow) - } - - test("lambda") { - val colX = column.UnresolvedNamedLambdaVariable("x") - val catX = UnresolvedNamedLambdaVariable(Seq("x")) - testConversion( - column.LambdaFunction( - column.UnresolvedFunction("+", Seq(colX, column.UnresolvedAttribute("y"))), - Seq(colX)), - LambdaFunction( - UnresolvedFunction("+", Seq(catX, UnresolvedAttribute("y")), isDistinct = false), - Seq(catX))) - } - - test("caseWhen") { - testConversion( - column.CaseWhenOtherwise( - Seq(column.UnresolvedAttribute("c1") -> column.Literal("r1")), - Option(column.Literal("fallback"))), - CaseWhen( - Seq(UnresolvedAttribute("c1") -> Literal("r1")), - Option(Literal("fallback"))) - ) - } - - test("extract field") { - testConversion( - column.UnresolvedExtractValue(column.UnresolvedAttribute("struct"), column.Literal("cl_a")), - UnresolvedExtractValue(UnresolvedAttribute("struct"), Literal("cl_a"))) - } - - test("update field") { - testConversion( - column.UpdateFields( - column.UnresolvedAttribute("struct"), - "col_b", - Option(column.Literal("cl_a"))), - UpdateFields(UnresolvedAttribute("struct"), Seq(WithField("col_b", Literal("cl_a"))))) - - testConversion( - column.UpdateFields(column.UnresolvedAttribute("struct"), "col_c", None), - UpdateFields(UnresolvedAttribute("struct"), Seq(DropField("col_c")))) - } - - test("extension") { - testConversion(column.Extension(UnresolvedAttribute("bar")), UnresolvedAttribute("bar")) - } - - test("unsupported") { - intercept[SparkException](Converter(column.Extension("kaboom"))) - } -} diff 
--git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala new file mode 100644 index 000000000000..7d0a2e14ba7b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.internal + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId} +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} +import org.apache.spark.sql.execution.SparkSqlParser +import org.apache.spark.sql.types._ + +/** + * Test suite for [[ColumnNode]] to [[Expression]] conversions. 
+ */ +class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { + private object Converter extends ColumnNodeToExpressionConverter { + override val conf: SQLConf = new SQLConf + override val parser: ParserInterface = new SparkSqlParser + } + + private def testConversion(node: => ColumnNode, expected: Expression): Expression = { + val myOrigin = Origin() + CurrentOrigin.withOrigin(myOrigin) { + val expression = normalizeExpression(Converter(node)) + assert(expression == normalizeExpression(expected)) + assert(expression.origin eq myOrigin) + expression + } + } + + private def normalizeExpression(e: Expression): Expression = e.transform { + case a: expressions.Alias => + a.copy()(exprId = ExprId(0), a.qualifier, a.explicitMetadata, a.nonInheritableMetadataKeys) + } + + test("literal") { + testConversion(Literal(1), expressions.Literal(1, IntegerType)) + testConversion( + Literal("foo", Option(StringType)), + expressions.Literal.create("foo", StringType)) + } + + test("attribute") { + val expression1 = testConversion(UnresolvedAttribute("x"), analysis.UnresolvedAttribute("x")) + assert(expression1.getTagValue(LogicalPlan.PLAN_ID_TAG).isEmpty) + assert(expression1.getTagValue(LogicalPlan.IS_METADATA_COL).isEmpty) + + val expression2 = testConversion( + UnresolvedAttribute("y", Option(44L), isMetadataColumn = true), + analysis.UnresolvedAttribute("y")) + assert(expression2.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(44L)) + assert(expression2.getTagValue(LogicalPlan.IS_METADATA_COL).isDefined) + } + + test("star") { + testConversion(UnresolvedStar(None), analysis.UnresolvedStar(None)) + testConversion( + UnresolvedStar(Option("x.y.z")), + analysis.UnresolvedStar(Option(Seq("x", "y", "z")))) + testConversion( + UnresolvedStar(None, Option(10L)), + analysis.UnresolvedDataFrameStar(10L)) + } + + test("regex") { + testConversion( + UnresolvedRegex("`(_1)?+.+`"), + analysis.UnresolvedRegex("(_1)?+.+", None, caseSensitive = false)) + + val expression = 
testConversion( + UnresolvedRegex("a", planId = Option(11L)), + analysis.UnresolvedAttribute("a")) + assert(expression.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(11L)) + assert(expression.getTagValue(LogicalPlan.IS_METADATA_COL).isEmpty) + } + + /* Function nodes. NOTE(review): testConversion is defined outside this chunk; presumably it converts its first argument and compares the result with the second - confirm. Case 1 pairs a "+" call over attribute "a" and literal 1 with a non-distinct Catalyst UnresolvedFunction named Seq("+"). Case 2 pairs the distinct, user-defined "db1.myAgg" call with a Catalyst function whose name parts are Seq("db1", "myAgg"). */ test("function") { + testConversion( + UnresolvedFunction("+", Seq(UnresolvedAttribute("a"), Literal(1))), + analysis.UnresolvedFunction( + Seq("+"), + Seq(analysis.UnresolvedAttribute("a"), expressions.Literal(1)), + isDistinct = false)) + testConversion( + UnresolvedFunction( + "db1.myAgg", + Seq(UnresolvedAttribute("a")), + isDistinct = true, + isUserDefinedFunction = true), + analysis.UnresolvedFunction( + Seq("db1", "myAgg"), + Seq(analysis.UnresolvedAttribute("a")), + isDistinct = true)) + } + + /* Alias nodes: a single-name alias is paired with expressions.Alias, a two-name alias with analysis.MultiAlias. */ test("alias") { + testConversion( + Alias(Literal("qwe"), "newA" :: Nil), + expressions.Alias(expressions.Literal("qwe"), "newA")()) + testConversion( + Alias(UnresolvedAttribute("complex"), "newA" :: "newB" :: Nil), + analysis.MultiAlias(analysis.UnresolvedAttribute("complex"), Seq("newA", "newB"))) + } + + /* Helper: pairs a Cast of attribute "attr" to dataType using the column-side eval mode colEvalMode with a Catalyst Cast of the same attribute and type using catEvalMode. */ private def testCast( + dataType: DataType, + colEvalMode: Cast.EvalMode.Value, + catEvalMode: expressions.EvalMode.Value): Unit = { + testConversion( + Cast(UnresolvedAttribute("attr"), dataType, Option(colEvalMode)), + expressions.Cast(analysis.UnresolvedAttribute("attr"), dataType, evalMode = catEvalMode)) + } + + /* Cast nodes: first a cast with no explicit eval mode against a Catalyst Cast built without one, then the Legacy, Try and Ansi column modes against LEGACY, TRY and ANSI respectively. */ test("cast") { + testConversion( + Cast(UnresolvedAttribute("str"), DoubleType), + expressions.Cast(analysis.UnresolvedAttribute("str"), DoubleType)) + + testCast(LongType, Cast.EvalMode.Legacy, expressions.EvalMode.LEGACY) + testCast(BinaryType, Cast.EvalMode.Try, expressions.EvalMode.TRY) + testCast(ShortType, Cast.EvalMode.Ansi, expressions.EvalMode.ANSI) + } + + /* Helper: pairs a column SortOrder over attribute "unsorted" (colDirection, colNullOrdering) with a Catalyst SortOrder over the same attribute (catDirection, catNullOrdering) and Nil as its last argument. */ private def testSortOrder( + colDirection: SortOrder.SortDirection.SortDirection, + colNullOrdering: SortOrder.NullOrdering.NullOrdering, + catDirection: expressions.SortDirection, + catNullOrdering: expressions.NullOrdering): Unit = { + 
testConversion( + SortOrder(UnresolvedAttribute("unsorted"), colDirection, colNullOrdering), + new expressions.SortOrder( + analysis.UnresolvedAttribute("unsorted"), + catDirection, + catNullOrdering, + Nil)) + } + + /* Sort order nodes: covers all four combinations of Ascending/Descending with NullsFirst/NullsLast. */ test("sortOrder") { + testSortOrder( + SortOrder.SortDirection.Ascending, + SortOrder.NullOrdering.NullsFirst, + expressions.Ascending, + expressions.NullsFirst) + testSortOrder( + SortOrder.SortDirection.Ascending, + SortOrder.NullOrdering.NullsLast, + expressions.Ascending, + expressions.NullsLast) + testSortOrder( + SortOrder.SortDirection.Descending, + SortOrder.NullOrdering.NullsFirst, + expressions.Descending, + expressions.NullsFirst) + testSortOrder( + SortOrder.SortDirection.Descending, + SortOrder.NullOrdering.NullsLast, + expressions.Descending, + expressions.NullsLast) + } + + /* Helper: builds a Window over sum(a), partitioned by b and c, ordered by d descending with nulls last, using the given column frame type and boundaries. It is paired with a Catalyst WindowExpression carrying the same partitioning and ordering and a SpecifiedWindowFrame built from catFrameType, catLower and catUpper. */ private def testWindowFrame( + colFrameType: WindowFrame.FrameType.FrameType, + colLower: WindowFrame.FrameBoundary, + colUpper: WindowFrame.FrameBoundary, + catFrameType: expressions.FrameType, + catLower: Expression, + catUpper: Expression): Unit = { + testConversion( + Window( + UnresolvedFunction("sum", Seq(UnresolvedAttribute("a"))), + WindowSpec( + Seq(UnresolvedAttribute("b"), UnresolvedAttribute("c")), + Seq(SortOrder( + UnresolvedAttribute("d"), + SortOrder.SortDirection.Descending, + SortOrder.NullOrdering.NullsLast)), + Option(WindowFrame(colFrameType, colLower, colUpper)))), + expressions.WindowExpression( + analysis.UnresolvedFunction( + "sum", + Seq(analysis.UnresolvedAttribute("a")), + isDistinct = false), + expressions.WindowSpecDefinition( + Seq(analysis.UnresolvedAttribute("b"), analysis.UnresolvedAttribute("c")), + Seq(expressions.SortOrder( + analysis.UnresolvedAttribute("d"), + expressions.Descending, + expressions.NullsLast, + Nil)), + expressions.SpecifiedWindowFrame(catFrameType, catLower, catUpper)))) + } + + /* Window nodes: a spec with no ordering and no frame is paired with UnspecifiedFrame; a Row frame with lower Value(Literal(-10)) and upper Unbounded is paired with Literal(-10) and UnboundedFollowing; a Range frame with lower Unbounded and upper CurrentRow is paired with UnboundedPreceding and CurrentRow. */ test("window") { + testConversion( + Window( + UnresolvedFunction("sum", Seq(UnresolvedAttribute("a"))), + WindowSpec( + 
Seq(UnresolvedAttribute("b"), UnresolvedAttribute("c")), + Nil, + None)), + expressions.WindowExpression( + analysis.UnresolvedFunction( + "sum", + Seq(analysis.UnresolvedAttribute("a")), + isDistinct = false), + expressions.WindowSpecDefinition( + Seq(analysis.UnresolvedAttribute("b"), analysis.UnresolvedAttribute("c")), + Nil, + expressions.UnspecifiedFrame))) + testWindowFrame( + WindowFrame.FrameType.Row, + WindowFrame.Value(Literal(-10)), + WindowFrame.Unbounded, + expressions.RowFrame, + expressions.Literal(-10), + expressions.UnboundedFollowing) + testWindowFrame( + WindowFrame.FrameType.Range, + WindowFrame.Unbounded, + WindowFrame.CurrentRow, + expressions.RangeFrame, + expressions.UnboundedPreceding, + expressions.CurrentRow) + } + + /* Lambda nodes: the variable colX is used both in the function body and in the argument list, and is paired with catX in both positions on the Catalyst side. */ test("lambda") { + val colX = UnresolvedNamedLambdaVariable("x") + val catX = expressions.UnresolvedNamedLambdaVariable(Seq("x")) + testConversion( + LambdaFunction(UnresolvedFunction("+", Seq(colX, UnresolvedAttribute("y"))), Seq(colX)), + expressions.LambdaFunction( + analysis.UnresolvedFunction( + "+", + Seq(catX, analysis.UnresolvedAttribute("y")), + isDistinct = false), + Seq(catX))) + } + + /* CaseWhenOtherwise with one branch and an otherwise value is paired with a Catalyst CaseWhen holding the same branch and else value. */ test("caseWhen") { + testConversion( + CaseWhenOtherwise( + Seq(UnresolvedAttribute("c1") -> Literal("r1")), + Option(Literal("fallback"))), + expressions.CaseWhen( + Seq(analysis.UnresolvedAttribute("c1") -> expressions.Literal("r1")), + Option(expressions.Literal("fallback"))) + ) + } + + /* Extracting "cl_a" from attribute "struct" is paired with the Catalyst UnresolvedExtractValue over the same child and literal. */ test("extract field") { + testConversion( + UnresolvedExtractValue(UnresolvedAttribute("struct"), Literal("cl_a")), + analysis.UnresolvedExtractValue( + analysis.UnresolvedAttribute("struct"), + expressions.Literal("cl_a"))) + } + + /* UpdateFields nodes: a defined value is paired with a WithField operation, and None is paired with a DropField operation. */ test("update field") { + testConversion( + UpdateFields(UnresolvedAttribute("struct"), "col_b", Option(Literal("cl_a"))), + expressions.UpdateFields( + analysis.UnresolvedAttribute("struct"), + Seq(expressions.WithField("col_b", expressions.Literal("cl_a"))))) + + testConversion( + 
UpdateFields(UnresolvedAttribute("struct"), "col_c", None), + expressions.UpdateFields( + analysis.UnresolvedAttribute("struct"), + Seq(expressions.DropField("col_c")))) + } + + /* An Extension wrapping a Catalyst expression is paired with that same expression. */ test("extension") { + testConversion( + Extension(analysis.UnresolvedAttribute("bar")), + analysis.UnresolvedAttribute("bar")) + } + + /* An Extension wrapping a plain String is expected to make Converter throw a SparkException. */ test("unsupported") { + intercept[SparkException](Converter(Extension("kaboom"))) + } +}