diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala
new file mode 100644
index 000000000000..971ad4281747
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala
@@ -0,0 +1,273 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.internal
+
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
+import org.apache.spark.sql.types.{DataType, Metadata}
+
+/**
+ * AST for constructing columns. This API is implementation agnostic and allows us to build a
+ * single Column implementation that can be shared between implementations. Consequently, every
+ * Dataframe API implementation has to provide a conversion from this AST to its implementation
+ * specific form (e.g. Catalyst expressions, or Connect protobuf messages).
+ *
+ * This API is a mirror image of Connect's expression.proto. There are a couple of extensions to
+ * make constructing nodes easier (e.g. [[CaseWhenOtherwise]]). We could not use the actual connect
+ * protobuf messages because of classpath clashes (e.g. Guava & gRPC) and Maven shading issues.
+ */
+private[sql] trait ColumnNode {
+  /**
+   * Origin where the node was created. The node classes declared in this file default it to
+   * `CurrentOrigin.get`, i.e. the origin that is current when the node is constructed.
+   */
+  def origin: Origin
+}
+
+/**
+ * A literal column.
+ *
+ * @param value of the literal. This is the unconverted input value.
+ * @param dataType of the literal. If none is provided the dataType is inferred.
+ */
+private[sql] case class Literal(
+    value: Any,
+    dataType: Option[DataType] = None,
+    override val origin: Origin = CurrentOrigin.get) extends ColumnNode
+
+/**
+ * Reference to an attribute produced by one of the underlying DataFrames.
+ *
+ * @param unparsedIdentifier name of the attribute, in its unparsed textual form.
+ * @param planId id of the plan (Dataframe) that produces the attribute. None when the attribute
+ *               is not tied to a specific plan.
+ * @param isMetadataColumn whether this is a metadata column.
+ */
+private[sql] case class UnresolvedAttribute(
+    unparsedIdentifier: String,
+    planId: Option[Long] = None,
+    isMetadataColumn: Boolean = false,
+    override val origin: Origin = CurrentOrigin.get)
+  extends ColumnNode
+
+/**
+ * Reference to all columns in a namespace (global, a Dataframe, or a nested struct).
+ *
+ * NOTE(review): the Catalyst converter added in this change supports either a target without a
+ * planId, or a planId without a target; a node that sets both is rejected as unsupported.
+ *
+ * @param unparsedTarget name of the namespace. None if the global namespace is supposed to be used.
+ * @param planId id of the plan (Dataframe) whose columns are referenced.
+ */
+private[sql] case class UnresolvedStar(
+    unparsedTarget: Option[String],
+    planId: Option[Long] = None,
+    override val origin: Origin = CurrentOrigin.get)
+  extends ColumnNode
+
+/**
+ * Call a function. This can either be a built-in function, a UDF, or a UDF registered in the
+ * Catalog.
+ *
+ * @param functionName of the function to invoke.
+ * @param arguments to pass into the function.
+ * @param isDistinct (aggregate only) whether the input of the aggregate function should be
+ *                   de-duplicated.
+ * @param isUserDefinedFunction whether `functionName` names a user defined function. The
+ *                              Catalyst converter added in this change parses the name as a
+ *                              multi-part identifier only when this flag is set.
+ */
+private[sql] case class UnresolvedFunction(
+    functionName: String,
+    arguments: Seq[ColumnNode],
+    isDistinct: Boolean = false,
+    isUserDefinedFunction: Boolean = false,
+    override val origin: Origin = CurrentOrigin.get)
+  extends ColumnNode
+
+/**
+ * Evaluate a SQL expression.
+ *
+ * @param expression text of the SQL expression. It is kept as text in this node; parsing is
+ *                   left to the implementation that converts the node.
+ */
+private[sql] case class SqlExpression(
+    expression: String,
+    override val origin: Origin = CurrentOrigin.get) extends ColumnNode
+
+/**
+ * Name a column, and (optionally) modify its metadata.
+ *
+ * NOTE(review): the Catalyst converter added in this change turns a single name into a regular
+ * alias and several names (without metadata) into a multi-alias. An empty name list, or several
+ * names combined with metadata, is rejected as unsupported.
+ *
+ * @param child to name
+ * @param name to use. More than one name can be given for a child that produces several columns.
+ * @param metadata (optional) metadata to add.
+ */
+private[sql] case class Alias(
+    child: ColumnNode,
+    name: Seq[String],
+    metadata: Option[Metadata] = None,
+    override val origin: Origin = CurrentOrigin.get) extends ColumnNode
+
+/**
+ * Cast the value of a Column to a different [[DataType]]. The behavior of the cast can be
+ * influenced by the `evalMode`.
+ *
+ * @param child that produces the input value.
+ * @param dataType to cast to.
+ * @param evalMode (try/ansi/legacy) to use for the cast. When None, the Catalyst converter added
+ *                 in this change derives the mode from the SQL configuration.
+ */
+private[sql] case class Cast(
+    child: ColumnNode,
+    dataType: DataType,
+    evalMode: Option[Cast.EvalMode.Value] = None,
+    override val origin: Origin = CurrentOrigin.get) extends ColumnNode
+
+private[sql] object Cast {
+  /** Evaluation modes that a [[Cast]] can request explicitly: `Legacy`, `Ansi` or `Try`. */
+  object EvalMode extends Enumeration {
+    type EvalMode = Value
+    val Legacy, Ansi, Try = Value
+  }
+}
+
+/**
+ * Reference to all columns whose name matches a regex.
+ *
+ * @param regex text of the column regex, e.g. "`(_1)?+.+`". The Catalyst converter added in this
+ *              change treats an escaped identifier (optionally qualified) as a regex and falls
+ *              back to a plain attribute reference for any other text.
+ * @param planId id of the plan (Dataframe) that produces the attribute. The Catalyst converter
+ *               added in this change only uses it for the attribute fallback.
+ */
+private[sql] case class UnresolvedRegex(
+    regex: String,
+    planId: Option[Long] = None,
+    override val origin: Origin = CurrentOrigin.get) extends ColumnNode
+
+/**
+ * Sort the input column.
+ *
+ * @param child to sort.
+ * @param sortDirection to sort in, either Ascending or Descending.
+ * @param nullOrdering where to place nulls, either at the begin or the end.
+ */
+private[sql] case class SortOrder(
+    child: ColumnNode,
+    sortDirection: SortOrder.SortDirection.Value,
+    nullOrdering: SortOrder.NullOrdering.Value,
+    override val origin: Origin = CurrentOrigin.get)
+  extends ColumnNode
+
+private[sql] object SortOrder {
+  /** Direction of a [[SortOrder]]. */
+  object SortDirection extends Enumeration {
+    type SortDirection = Value
+    val Ascending, Descending = Value
+  }
+  /** Placement of nulls in a [[SortOrder]]. */
+  object NullOrdering extends Enumeration {
+    type NullOrdering = Value
+    val NullsFirst, NullsLast = Value
+  }
+}
+
+/**
+ * Evaluate a function within a window.
+ *
+ * @param windowFunction function to execute.
+ * @param windowSpec of the window.
+ */
+private[sql] case class Window(
+    windowFunction: ColumnNode,
+    windowSpec: WindowSpec,
+    override val origin: Origin = CurrentOrigin.get)
+  extends ColumnNode
+
+/**
+ * Specification of a window. Note that this is not a [[ColumnNode]] itself and carries no origin.
+ *
+ * @param partitionColumns columns that define the partitioning of the window.
+ * @param sortColumns ordering of the rows of the window.
+ * @param frame (optional) frame of the window. The Catalyst converter added in this change maps
+ *              None to an unspecified frame.
+ */
+private[sql] case class WindowSpec(
+    partitionColumns: Seq[ColumnNode],
+    sortColumns: Seq[SortOrder],
+    frame: Option[WindowFrame] = None)
+
+/**
+ * Frame of a window.
+ *
+ * @param frameType of the frame, either Row or Range.
+ * @param lower boundary of the frame.
+ * @param upper boundary of the frame.
+ */
+private[sql] case class WindowFrame(
+    frameType: WindowFrame.FrameType.Value,
+    lower: WindowFrame.FrameBoundary,
+    upper: WindowFrame.FrameBoundary)
+
+private[sql] object WindowFrame {
+  /** Type of a [[WindowFrame]]. */
+  object FrameType extends Enumeration {
+    // NOTE(review): `Value` is presumably qualified with `this` to keep it apart from the `Value`
+    // frame boundary that is declared in the enclosing object - confirm before simplifying.
+    type FrameType = this.Value
+    val Row, Range = this.Value
+  }
+
+  /** Lower or upper boundary of a [[WindowFrame]]. */
+  sealed trait FrameBoundary
+  /** Boundary at the current row. */
+  object CurrentRow extends FrameBoundary
+  /**
+   * No boundary. The Catalyst converter added in this change maps this to unbounded preceding
+   * when used as the lower boundary and to unbounded following when used as the upper boundary.
+   */
+  object Unbounded extends FrameBoundary
+  /** Boundary that is defined by a column expression. */
+  case class Value(value: ColumnNode) extends FrameBoundary
+}
+
+/**
+ * Lambda function to execute. This is typically passed as an argument to a function.
+ *
+ * @param function to execute.
+ * @param arguments the bound lambda variables.
+ */
+private[sql] case class LambdaFunction(
+    function: ColumnNode,
+    arguments: Seq[UnresolvedNamedLambdaVariable],
+    override val origin: Origin = CurrentOrigin.get) extends ColumnNode
+
+/**
+ * Variable used in a [[LambdaFunction]].
+ *
+ * @param name of the variable.
+ */
+private[sql] case class UnresolvedNamedLambdaVariable(
+    name: String,
+    override val origin: Origin = CurrentOrigin.get) extends ColumnNode
+
+/**
+ * Extract a value from a complex type. This can be a field from a struct, a value from a map,
+ * or an element from an array.
+ *
+ * @param child that produces a complex value.
+ * @param extraction that is used to access the complex type. This needs to be a string type for
+ *                   structs and maps, and it needs to be an integer for arrays.
+ */
+private[sql] case class UnresolvedExtractValue(
+    child: ColumnNode,
+    extraction: ColumnNode,
+    override val origin: Origin = CurrentOrigin.get) extends ColumnNode
+
+/**
+ * Update or drop the field of a struct.
+ *
+ * @param structExpression that will be updated.
+ * @param fieldName name of the field to update or drop.
+ * @param valueExpression new value of the field. If this is None the field will be dropped.
+ */
+private[sql] case class UpdateFields(
+    structExpression: ColumnNode,
+    fieldName: String,
+    valueExpression: Option[ColumnNode] = None,
+    override val origin: Origin = CurrentOrigin.get) extends ColumnNode
+
+/**
+ * Evaluate one or more conditional branches. The value of the first branch for which the predicate
+ * evaluates to true is returned. If none of the branches evaluate to true, the value of
+ * `otherwise` is returned.
+ *
+ * @param branches to evaluate. Each entry is a pair of condition and value.
+ * @param otherwise (optional) to evaluate when none of the branches evaluate to true.
+ */
+private[sql] case class CaseWhenOtherwise(
+    branches: Seq[(ColumnNode, ColumnNode)],
+    otherwise: Option[ColumnNode] = None,
+    override val origin: Origin = CurrentOrigin.get)
+  extends ColumnNode
+
+/**
+ * Extension point that allows an implementation to use its own column representation in a
+ * generic column expression. This should only be used when the Column constructed is used within
+ * the implementation.
+ */
+private[sql] case class Extension(
+    // Opaque, implementation specific payload. The Catalyst converter added in this change only
+    // accepts an Expression here and reports any other value as an unsupported node.
+    value: Any,
+    override val origin: Origin = CurrentOrigin.get) extends ColumnNode
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverter.scala
new file mode 100644
index 000000000000..ff2d6f72c86a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverter.scala
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.internal
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.analysis
+import org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.parser.{ParserInterface, ParserUtils}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.CurrentOrigin
+import org.apache.spark.sql.catalyst.util.CharVarcharUtils
+import org.apache.spark.sql.execution.SparkSqlParser
+
+/**
+ * Convert a [[ColumnNode]] into an [[Expression]].
+ *
+ * Every node is converted with its own origin installed as the current origin, so that the
+ * expressions created for it are attributed to the place where the node was created. Node shapes
+ * without a Catalyst counterpart are reported as an internal error.
+ */
+private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expression) {
+
+  /** Parser used for SQL expression text and for multi-part function names. */
+  protected def parser: ParserInterface
+
+  /** Configuration used for case sensitivity and for the default cast evaluation mode. */
+  protected def conf: SQLConf
+
+  /**
+   * Convert `node`, including all of its children, into a Catalyst [[Expression]].
+   *
+   * @param node to convert.
+   * @return the Catalyst expression for the node.
+   * @throws SparkException (internal error) when the node cannot be converted.
+   */
+  override def apply(node: ColumnNode): Expression = CurrentOrigin.withOrigin(node.origin) {
+    node match {
+      case Literal(value, Some(dataType), _) =>
+        expressions.Literal.create(value, dataType)
+
+      case Literal(value, None, _) =>
+        expressions.Literal(value)
+
+      case UnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn, _) =>
+        convertUnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn)
+
+      // A star either has an (optional) target or a plan id. The combination of both is not
+      // matched here and ends up in the unsupported case at the bottom.
+      case UnresolvedStar(unparsedTarget, None, _) =>
+        analysis.UnresolvedStar(unparsedTarget.map(analysis.UnresolvedAttribute.parseAttributeName))
+
+      case UnresolvedStar(None, Some(planId), _) =>
+        analysis.UnresolvedDataFrameStar(planId)
+
+      // The order of the three regex cases matters: only text that is not an (optionally
+      // qualified) escaped identifier is treated as a regular attribute reference.
+      case UnresolvedRegex(ParserUtils.escapedIdentifier(columnNameRegex), _, _) =>
+        analysis.UnresolvedRegex(columnNameRegex, None, conf.caseSensitiveAnalysis)
+
+      case UnresolvedRegex(
+          ParserUtils.qualifiedEscapedIdentifier(nameParts, columnNameRegex), _, _) =>
+        analysis.UnresolvedRegex(columnNameRegex, Some(nameParts), conf.caseSensitiveAnalysis)
+
+      case UnresolvedRegex(unparsedIdentifier, planId, _) =>
+        convertUnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn = false)
+
+      case UnresolvedFunction(functionName, arguments, isDistinct, isUDF, _) =>
+        // Only the names of user defined functions are parsed; any other name is used verbatim.
+        val nameParts = if (isUDF) {
+          parser.parseMultipartIdentifier(functionName)
+        } else {
+          Seq(functionName)
+        }
+        analysis.UnresolvedFunction(nameParts, arguments.map(apply), isDistinct)
+
+      case Alias(child, Seq(name), metadata, _) =>
+        expressions.Alias(apply(child), name)(explicitMetadata = metadata)
+
+      case Alias(child, names, None, _) if names.nonEmpty =>
+        analysis.MultiAlias(apply(child), names)
+
+      case Cast(child, dataType, evalMode, _) =>
+        // Without an explicit mode the mode is derived from the configuration.
+        val convertedEvalMode = evalMode match {
+          case Some(Cast.EvalMode.Ansi) => expressions.EvalMode.ANSI
+          case Some(Cast.EvalMode.Legacy) => expressions.EvalMode.LEGACY
+          case Some(Cast.EvalMode.Try) => expressions.EvalMode.TRY
+          case _ => expressions.EvalMode.fromSQLConf(conf)
+        }
+        val cast = expressions.Cast(
+          apply(child),
+          CharVarcharUtils.replaceCharVarcharWithStringForCast(dataType),
+          None,
+          convertedEvalMode)
+        cast.setTagValue(expressions.Cast.USER_SPECIFIED_CAST, ())
+        cast
+
+      case SqlExpression(expression, _) =>
+        parser.parseExpression(expression)
+
+      case sortOrder: SortOrder =>
+        convertSortOrder(sortOrder)
+
+      case Window(function, spec, _) =>
+        val frame = spec.frame match {
+          case Some(WindowFrame(frameType, lower, upper)) =>
+            val convertedFrameType = frameType match {
+              case WindowFrame.FrameType.Range => expressions.RangeFrame
+              case WindowFrame.FrameType.Row => expressions.RowFrame
+            }
+            // `Unbounded` has no direction of its own; the direction follows from the side of
+            // the frame it is used on.
+            val convertedLower = lower match {
+              case WindowFrame.CurrentRow => expressions.CurrentRow
+              case WindowFrame.Unbounded => expressions.UnboundedPreceding
+              case WindowFrame.Value(node) => apply(node)
+            }
+            val convertedUpper = upper match {
+              case WindowFrame.CurrentRow => expressions.CurrentRow
+              case WindowFrame.Unbounded => expressions.UnboundedFollowing
+              case WindowFrame.Value(node) => apply(node)
+            }
+            expressions.SpecifiedWindowFrame(convertedFrameType, convertedLower, convertedUpper)
+          case None =>
+            expressions.UnspecifiedFrame
+        }
+        expressions.WindowExpression(
+          apply(function),
+          expressions.WindowSpecDefinition(
+            partitionSpec = spec.partitionColumns.map(apply),
+            orderSpec = spec.sortColumns.map(convertSortOrder),
+            frameSpecification = frame))
+
+      case LambdaFunction(function, arguments, _) =>
+        expressions.LambdaFunction(
+          apply(function),
+          arguments.map(convertUnresolvedNamedLambdaVariable))
+
+      case v: UnresolvedNamedLambdaVariable =>
+        convertUnresolvedNamedLambdaVariable(v)
+
+      case UnresolvedExtractValue(child, extraction, _) =>
+        analysis.UnresolvedExtractValue(apply(child), apply(extraction))
+
+      case UpdateFields(struct, field, Some(value), _) =>
+        expressions.UpdateFields(apply(struct), field, apply(value))
+
+      case UpdateFields(struct, field, None, _) =>
+        expressions.UpdateFields(apply(struct), field)
+
+      case CaseWhenOtherwise(branches, otherwise, _) =>
+        expressions.CaseWhen(
+          branches = branches.map { case (condition, value) =>
+            (apply(condition), apply(value))
+          },
+          elseValue = otherwise.map(apply))
+
+      case Extension(expression: Expression, _) =>
+        expression
+
+      case node =>
+        throw SparkException.internalError(s"Unsupported ColumnNode: $node")
+    }
+  }
+
+  /**
+   * Convert a lambda variable. The origin of the variable is installed here as well, because a
+   * [[LambdaFunction]] converts its arguments through this method without going through `apply`.
+   */
+  private def convertUnresolvedNamedLambdaVariable(
+      v: UnresolvedNamedLambdaVariable): expressions.UnresolvedNamedLambdaVariable = {
+    CurrentOrigin.withOrigin(v.origin) {
+      expressions.UnresolvedNamedLambdaVariable(Seq(v.name))
+    }
+  }
+
+  /**
+   * Convert a sort order. The origin of the sort order is installed here as well, because a
+   * [[Window]] converts its sort columns through this method without going through `apply`.
+   */
+  private def convertSortOrder(sortOrder: SortOrder): expressions.SortOrder = {
+    CurrentOrigin.withOrigin(sortOrder.origin) {
+      val sortDirection = sortOrder.sortDirection match {
+        case SortOrder.SortDirection.Ascending => expressions.Ascending
+        case SortOrder.SortDirection.Descending => expressions.Descending
+      }
+      val nullOrdering = sortOrder.nullOrdering match {
+        case SortOrder.NullOrdering.NullsFirst => expressions.NullsFirst
+        case SortOrder.NullOrdering.NullsLast => expressions.NullsLast
+      }
+      expressions.SortOrder(apply(sortOrder.child), sortDirection, nullOrdering, Nil)
+    }
+  }
+
+  /**
+   * Create an unresolved attribute for `unparsedIdentifier`, tagged with the plan id (when one is
+   * given) and with the metadata column marker (when requested).
+   */
+  private def convertUnresolvedAttribute(
+      unparsedIdentifier: String,
+      planId: Option[Long],
+      isMetadataColumn: Boolean): analysis.UnresolvedAttribute = {
+    val attribute = analysis.UnresolvedAttribute.quotedString(unparsedIdentifier)
+    planId.foreach(attribute.setTagValue(LogicalPlan.PLAN_ID_TAG, _))
+    if (isMetadataColumn) {
+      attribute.setTagValue(LogicalPlan.IS_METADATA_COL, ())
+    }
+    attribute
+  }
+}
+
+/**
+ * Default converter. It uses the parser of the active session, or a fresh [[SparkSqlParser]] when
+ * there is no active session, together with the configuration returned by `SQLConf.get`.
+ */
+object ColumnNodeToExpressionConverter extends ColumnNodeToExpressionConverter {
+  override protected def parser: ParserInterface = {
+    SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse {
+      new SparkSqlParser()
+    }
+  }
+
+  override protected def conf: SQLConf = SQLConf.get
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala
new file mode 100644
index 000000000000..7d0a2e14ba7b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala
@@ -0,0 +1,298 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.internal
+
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.sql.catalyst.analysis
+import org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId}
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
+import org.apache.spark.sql.execution.SparkSqlParser
+import org.apache.spark.sql.types._
+
+/**
+ * Test suite for [[ColumnNode]] to [[Expression]] conversions.
+ */
+class ColumnNodeToExpressionConverterSuite extends SparkFunSuite {
+  // Converter under test. It is wired to a fresh SQLConf and a fresh parser instead of the
+  // session based defaults of the ColumnNodeToExpressionConverter object.
+  private object Converter extends ColumnNodeToExpressionConverter {
+    override val conf: SQLConf = new SQLConf
+    override val parser: ParserInterface = new SparkSqlParser
+  }
+
+  /**
+   * Convert `node` and check that the result equals `expected` and carries the test origin.
+   *
+   * `node` is a by-name parameter on purpose: it is evaluated inside `withOrigin(myOrigin)`, so
+   * nodes that are constructed in the argument expression pick up `myOrigin` through their
+   * `CurrentOrigin.get` default. Nodes that are built before the call do not.
+   *
+   * @param node to convert.
+   * @param expected Catalyst expression.
+   * @return the (normalized) converted expression, for additional assertions by the caller.
+   */
+  private def testConversion(node: => ColumnNode, expected: Expression): Expression = {
+    val myOrigin = Origin()
+    CurrentOrigin.withOrigin(myOrigin) {
+      val expression = normalizeExpression(Converter(node))
+      assert(expression == normalizeExpression(expected))
+      // Reference equality: the origin must be this very instance, not merely an equal one.
+      assert(expression.origin eq myOrigin)
+      expression
+    }
+  }
+
+  // Rewrites every Alias to use ExprId(0) so that the expression ids of the actual and the
+  // expected tree cannot make the comparison fail.
+  // NOTE(review): presumably needed because each Alias gets a freshly generated id - confirm.
+  private def normalizeExpression(e: Expression): Expression = e.transform {
+    case a: expressions.Alias =>
+      a.copy()(exprId = ExprId(0), a.qualifier, a.explicitMetadata, a.nonInheritableMetadataKeys)
+  }
+
+  test("literal") {
+    testConversion(Literal(1), expressions.Literal(1, IntegerType))
+    testConversion(
+      Literal("foo", Option(StringType)),
+      expressions.Literal.create("foo", StringType))
+  }
+
+  test("attribute") {
+    // Without a plan id and metadata flag no tags are expected on the attribute.
+    val expression1 = testConversion(UnresolvedAttribute("x"), analysis.UnresolvedAttribute("x"))
+    assert(expression1.getTagValue(LogicalPlan.PLAN_ID_TAG).isEmpty)
+    assert(expression1.getTagValue(LogicalPlan.IS_METADATA_COL).isEmpty)
+
+    val expression2 = testConversion(
+      UnresolvedAttribute("y", Option(44L), isMetadataColumn = true),
+      analysis.UnresolvedAttribute("y"))
+    assert(expression2.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(44L))
+    assert(expression2.getTagValue(LogicalPlan.IS_METADATA_COL).isDefined)
+  }
+
+  test("star") {
+    testConversion(UnresolvedStar(None), analysis.UnresolvedStar(None))
+    testConversion(
+      UnresolvedStar(Option("x.y.z")),
+      analysis.UnresolvedStar(Option(Seq("x", "y", "z"))))
+    testConversion(
+      UnresolvedStar(None, Option(10L)),
+      analysis.UnresolvedDataFrameStar(10L))
+  }
+
+  test("regex") {
+    testConversion(
+      UnresolvedRegex("`(_1)?+.+`"),
+      analysis.UnresolvedRegex("(_1)?+.+", None, caseSensitive = false))
+
+    // Text that is not an escaped identifier is converted into a plain attribute.
+    val expression = testConversion(
+      UnresolvedRegex("a", planId = Option(11L)),
+      analysis.UnresolvedAttribute("a"))
+    assert(expression.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(11L))
+    assert(expression.getTagValue(LogicalPlan.IS_METADATA_COL).isEmpty)
+  }
+
+  test("function") {
+    testConversion(
+      UnresolvedFunction("+", Seq(UnresolvedAttribute("a"), Literal(1))),
+      analysis.UnresolvedFunction(
+        Seq("+"),
+        Seq(analysis.UnresolvedAttribute("a"), expressions.Literal(1)),
+        isDistinct = false))
+    // The name of a user defined function is split into its parts.
+    testConversion(
+      UnresolvedFunction(
+        "db1.myAgg",
+        Seq(UnresolvedAttribute("a")),
+        isDistinct = true,
+        isUserDefinedFunction = true),
+      analysis.UnresolvedFunction(
+        Seq("db1", "myAgg"),
+        Seq(analysis.UnresolvedAttribute("a")),
+        isDistinct = true))
+  }
+
+  test("alias") {
+    testConversion(
+      Alias(Literal("qwe"), "newA" :: Nil),
+      expressions.Alias(expressions.Literal("qwe"), "newA")())
+    testConversion(
+      Alias(UnresolvedAttribute("complex"), "newA" :: "newB" :: Nil),
+      analysis.MultiAlias(analysis.UnresolvedAttribute("complex"), Seq("newA", "newB")))
+  }
+
+  // Checks that a cast with the explicit column eval mode `colEvalMode` is converted into a
+  // Catalyst cast with the eval mode `catEvalMode`.
+  private def testCast(
+      dataType: DataType,
+      colEvalMode: Cast.EvalMode.Value,
+      catEvalMode: expressions.EvalMode.Value): Unit = {
+    testConversion(
+      Cast(UnresolvedAttribute("attr"), dataType, Option(colEvalMode)),
+      expressions.Cast(analysis.UnresolvedAttribute("attr"), dataType, evalMode = catEvalMode))
+  }
+
+  test("cast") {
+    // No explicit eval mode on either side.
+    testConversion(
+      Cast(UnresolvedAttribute("str"), DoubleType),
+      expressions.Cast(analysis.UnresolvedAttribute("str"), DoubleType))
+
+    testCast(LongType, Cast.EvalMode.Legacy, expressions.EvalMode.LEGACY)
+    testCast(BinaryType, Cast.EvalMode.Try, expressions.EvalMode.TRY)
+    testCast(ShortType, Cast.EvalMode.Ansi, expressions.EvalMode.ANSI)
+  }
+
+  // Checks the conversion of a single combination of sort direction and null ordering.
+  private def testSortOrder(
+      colDirection: SortOrder.SortDirection.SortDirection,
+      colNullOrdering: SortOrder.NullOrdering.NullOrdering,
+      catDirection: expressions.SortDirection,
+      catNullOrdering: expressions.NullOrdering): Unit = {
+    testConversion(
+      SortOrder(UnresolvedAttribute("unsorted"), colDirection, colNullOrdering),
+      new expressions.SortOrder(
+        analysis.UnresolvedAttribute("unsorted"),
+        catDirection,
+        catNullOrdering,
+        Nil))
+  }
+
+  test("sortOrder") {
+    // All four combinations of direction and null ordering.
+    testSortOrder(
+      SortOrder.SortDirection.Ascending,
+      SortOrder.NullOrdering.NullsFirst,
+      expressions.Ascending,
+      expressions.NullsFirst)
+    testSortOrder(
+      SortOrder.SortDirection.Ascending,
+      SortOrder.NullOrdering.NullsLast,
+      expressions.Ascending,
+      expressions.NullsLast)
+    testSortOrder(
+      SortOrder.SortDirection.Descending,
+      SortOrder.NullOrdering.NullsFirst,
+      expressions.Descending,
+      expressions.NullsFirst)
+    testSortOrder(
+      SortOrder.SortDirection.Descending,
+      SortOrder.NullOrdering.NullsLast,
+      expressions.Descending,
+      expressions.NullsLast)
+  }
+
+  // Checks the conversion of a window with a fixed function, partitioning and ordering, and the
+  // given frame. The `col*` parameters describe the frame of the node, the `cat*` parameters the
+  // frame that is expected on the Catalyst side.
+  private def testWindowFrame(
+      colFrameType: WindowFrame.FrameType.FrameType,
+      colLower: WindowFrame.FrameBoundary,
+      colUpper: WindowFrame.FrameBoundary,
+      catFrameType: expressions.FrameType,
+      catLower: Expression,
+      catUpper: Expression): Unit = {
+    testConversion(
+      Window(
+        UnresolvedFunction("sum", Seq(UnresolvedAttribute("a"))),
+        WindowSpec(
+          Seq(UnresolvedAttribute("b"), UnresolvedAttribute("c")),
+          Seq(SortOrder(
+            UnresolvedAttribute("d"),
+            SortOrder.SortDirection.Descending,
+            SortOrder.NullOrdering.NullsLast)),
+          Option(WindowFrame(colFrameType, colLower, colUpper)))),
+      expressions.WindowExpression(
+        analysis.UnresolvedFunction(
+          "sum",
+          Seq(analysis.UnresolvedAttribute("a")),
+          isDistinct = false),
+        expressions.WindowSpecDefinition(
+          Seq(analysis.UnresolvedAttribute("b"), analysis.UnresolvedAttribute("c")),
+          Seq(expressions.SortOrder(
+            analysis.UnresolvedAttribute("d"),
+            expressions.Descending,
+            expressions.NullsLast,
+            Nil)),
+          expressions.SpecifiedWindowFrame(catFrameType, catLower, catUpper))))
+  }
+
+  test("window") {
+    // A window without ordering and without a frame.
+    testConversion(
+      Window(
+        UnresolvedFunction("sum", Seq(UnresolvedAttribute("a"))),
+        WindowSpec(
+          Seq(UnresolvedAttribute("b"), UnresolvedAttribute("c")),
+          Nil,
+          None)),
+      expressions.WindowExpression(
+        analysis.UnresolvedFunction(
+          "sum",
+          Seq(analysis.UnresolvedAttribute("a")),
+          isDistinct = false),
+        expressions.WindowSpecDefinition(
+          Seq(analysis.UnresolvedAttribute("b"), analysis.UnresolvedAttribute("c")),
+          Nil,
+          expressions.UnspecifiedFrame)))
+    // `Unbounded` as the upper boundary is expected to become unbounded following ...
+    testWindowFrame(
+      WindowFrame.FrameType.Row,
+      WindowFrame.Value(Literal(-10)),
+      WindowFrame.Unbounded,
+      expressions.RowFrame,
+      expressions.Literal(-10),
+      expressions.UnboundedFollowing)
+    // ... and unbounded preceding as the lower boundary.
+    testWindowFrame(
+      WindowFrame.FrameType.Range,
+      WindowFrame.Unbounded,
+      WindowFrame.CurrentRow,
+      expressions.RangeFrame,
+      expressions.UnboundedPreceding,
+      expressions.CurrentRow)
+  }
+
+  test("lambda") {
+    // The same variable is used in the body and in the argument list of the lambda.
+    val colX = UnresolvedNamedLambdaVariable("x")
+    val catX = expressions.UnresolvedNamedLambdaVariable(Seq("x"))
+    testConversion(
+      LambdaFunction(UnresolvedFunction("+", Seq(colX, UnresolvedAttribute("y"))), Seq(colX)),
+      expressions.LambdaFunction(
+        analysis.UnresolvedFunction(
+          "+",
+          Seq(catX, analysis.UnresolvedAttribute("y")),
+          isDistinct = false),
+        Seq(catX)))
+  }
+
+  test("caseWhen") {
+    testConversion(
+      CaseWhenOtherwise(
+        Seq(UnresolvedAttribute("c1") -> Literal("r1")),
+        Option(Literal("fallback"))),
+      expressions.CaseWhen(
+        Seq(analysis.UnresolvedAttribute("c1") -> expressions.Literal("r1")),
+        Option(expressions.Literal("fallback")))
+    )
+  }
+
+  test("extract field") {
+    testConversion(
+      UnresolvedExtractValue(UnresolvedAttribute("struct"), Literal("cl_a")),
+      analysis.UnresolvedExtractValue(
+        analysis.UnresolvedAttribute("struct"),
+        expressions.Literal("cl_a")))
+  }
+
+  test("update field") {
+    // A value updates the field ...
+    testConversion(
+      UpdateFields(UnresolvedAttribute("struct"), "col_b", Option(Literal("cl_a"))),
+      expressions.UpdateFields(
+        analysis.UnresolvedAttribute("struct"),
+        Seq(expressions.WithField("col_b", expressions.Literal("cl_a")))))
+
+    // ... and no value drops it.
+    testConversion(
+      UpdateFields(UnresolvedAttribute("struct"), "col_c", None),
+      expressions.UpdateFields(
+        analysis.UnresolvedAttribute("struct"),
+        Seq(expressions.DropField("col_c"))))
+  }
+
+  test("extension") {
+    // An extension that wraps an Expression is converted into that expression.
+    testConversion(
+      Extension(analysis.UnresolvedAttribute("bar")),
+      analysis.UnresolvedAttribute("bar"))
+  }
+
+  test("unsupported") {
+    // An extension that wraps anything but an Expression cannot be converted.
+    intercept[SparkException](Converter(Extension("kaboom")))
+  }
+}