diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 32f7a98ccab8..29d23adf0feb 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -509,7 +509,11 @@ fromStatementBody querySpecification : transformClause fromClause? - whereClause? #transformQuerySpecification + lateralView* + whereClause? + aggregationClause? + havingClause? + windowClause? #transformQuerySpecification | selectClause fromClause? lateralView* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f5f0f5c2037a..4949017df4ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1558,14 +1558,9 @@ class Analyzer(override val catalogManager: CatalogManager) } else { a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child)) } - // If the script transformation input contains Stars, expand it. 
+ // TODO: Remove this logic (see SPARK-34035) case t: ScriptTransformation if containsStar(t.input) => - t.copy( - input = t.input.flatMap { - case s: Star => s.expand(t.child, resolver) - case o => o :: Nil - } - ) + t.copy(input = t.child.output) case g: Generate if containsStar(g.generator.children) => throw QueryCompilationErrors.invalidStarUsageError("explode/json_tuple/UDTF") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index e448af8469f7..dde66d592ff8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -150,7 +150,11 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg withTransformQuerySpecification( ctx, ctx.transformClause, + ctx.lateralView, ctx.whereClause, + ctx.aggregationClause, + ctx.havingClause, + ctx.windowClause, plan ) } else { @@ -587,7 +591,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg val from = OneRowRelation().optional(ctx.fromClause) { visitFromClause(ctx.fromClause) } - withTransformQuerySpecification(ctx, ctx.transformClause, ctx.whereClause, from) + withTransformQuerySpecification( + ctx, + ctx.transformClause, + ctx.lateralView, + ctx.whereClause, + ctx.aggregationClause, + ctx.havingClause, + ctx.windowClause, + from + ) } override def visitRegularQuerySpecification( @@ -641,14 +654,12 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg private def withTransformQuerySpecification( ctx: ParserRuleContext, transformClause: TransformClauseContext, + lateralView: java.util.List[LateralViewContext], whereClause: WhereClauseContext, - relation: LogicalPlan): LogicalPlan = withOrigin(ctx) { - // Add where. 
- val withFilter = relation.optionalMap(whereClause)(withWhereClause) - - // Create the transform. - val expressions = visitNamedExpressionSeq(transformClause.namedExpressionSeq) - + aggregationClause: AggregationClauseContext, + havingClause: HavingClauseContext, + windowClause: WindowClauseContext, + relation: LogicalPlan): LogicalPlan = withOrigin(ctx) { // Create the attributes. val (attributes, schemaLess) = if (transformClause.colTypeList != null) { // Typed return columns. @@ -664,12 +675,22 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg AttributeReference("value", StringType)()), true) } - // Create the transform. + val plan = visitCommonSelectQueryClausePlan( + relation, + lateralView, + transformClause.namedExpressionSeq, + whereClause, + aggregationClause, + havingClause, + windowClause, + isDistinct = false) + ScriptTransformation( - expressions, + // TODO: Remove this logic (see SPARK-34035) + Seq(UnresolvedStar(None)), string(transformClause.script), attributes, - withFilter, + plan, withScriptIOSchema( ctx, transformClause.inRowFormat, @@ -697,13 +718,40 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg havingClause: HavingClauseContext, windowClause: WindowClauseContext, relation: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val isDistinct = selectClause.setQuantifier() != null && + selectClause.setQuantifier().DISTINCT() != null + + val plan = visitCommonSelectQueryClausePlan( + relation, + lateralView, + selectClause.namedExpressionSeq, + whereClause, + aggregationClause, + havingClause, + windowClause, + isDistinct) + + // Hint + selectClause.hints.asScala.foldRight(plan)(withHints) + } + + def visitCommonSelectQueryClausePlan( + relation: LogicalPlan, + lateralView: java.util.List[LateralViewContext], + namedExpressionSeq: NamedExpressionSeqContext, + whereClause: WhereClauseContext, + aggregationClause: AggregationClauseContext, + havingClause: HavingClauseContext, + 
windowClause: WindowClauseContext, + isDistinct: Boolean): LogicalPlan = { // Add lateral views. val withLateralView = lateralView.asScala.foldLeft(relation)(withGenerate) // Add where. val withFilter = withLateralView.optionalMap(whereClause)(withWhereClause) - val expressions = visitNamedExpressionSeq(selectClause.namedExpressionSeq) + val expressions = visitNamedExpressionSeq(namedExpressionSeq) + // Add aggregation or a project. val namedExpressions = expressions.map { case e: NamedExpression => e @@ -737,9 +785,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } // Distinct - val withDistinct = if ( - selectClause.setQuantifier() != null && - selectClause.setQuantifier().DISTINCT() != null) { + val withDistinct = if (isDistinct) { Distinct(withProject) } else { withProject @@ -748,8 +794,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg // Window val withWindow = withDistinct.optionalMap(windowClause)(withWindowClause) - // Hint - selectClause.hints.asScala.foldRight(withWindow)(withHints) + withWindow } // Script Transform's input/output format. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index f75f9c174256..778e81365830 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction} +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -1061,11 +1061,11 @@ class PlanParserSuite extends AnalysisTest { |FROM testData """.stripMargin, ScriptTransformation( - Seq('a, 'b, 'c), + Seq(UnresolvedStar(None)), "cat", Seq(AttributeReference("key", StringType)(), AttributeReference("value", StringType)()), - UnresolvedRelation(TableIdentifier("testData")), + Project(Seq('a, 'b, 'c), UnresolvedRelation(TableIdentifier("testData"))), ScriptInputOutputSchema(List.empty, List.empty, None, None, List.empty, List.empty, None, None, true)) ) @@ -1078,12 +1078,12 @@ class PlanParserSuite extends AnalysisTest { |FROM testData """.stripMargin, ScriptTransformation( - Seq('a, 'b, 'c), + Seq(UnresolvedStar(None)), "cat", Seq(AttributeReference("a", StringType)(), AttributeReference("b", StringType)(), AttributeReference("c", 
StringType)()), - UnresolvedRelation(TableIdentifier("testData")), + Project(Seq('a, 'b, 'c), UnresolvedRelation(TableIdentifier("testData"))), ScriptInputOutputSchema(List.empty, List.empty, None, None, List.empty, List.empty, None, None, false))) @@ -1095,12 +1095,12 @@ class PlanParserSuite extends AnalysisTest { |FROM testData """.stripMargin, ScriptTransformation( - Seq('a, 'b, 'c), + Seq(UnresolvedStar(None)), "cat", Seq(AttributeReference("a", IntegerType)(), AttributeReference("b", StringType)(), AttributeReference("c", LongType)()), - UnresolvedRelation(TableIdentifier("testData")), + Project(Seq('a, 'b, 'c), UnresolvedRelation(TableIdentifier("testData"))), ScriptInputOutputSchema(List.empty, List.empty, None, None, List.empty, List.empty, None, None, false))) @@ -1124,12 +1124,12 @@ class PlanParserSuite extends AnalysisTest { |FROM testData """.stripMargin, ScriptTransformation( - Seq('a, 'b, 'c), + Seq(UnresolvedStar(None)), "cat", Seq(AttributeReference("a", StringType)(), AttributeReference("b", StringType)(), AttributeReference("c", StringType)()), - UnresolvedRelation(TableIdentifier("testData")), + Project(Seq('a, 'b, 'c), UnresolvedRelation(TableIdentifier("testData"))), ScriptInputOutputSchema( Seq(("TOK_TABLEROWFORMATFIELD", "\t"), ("TOK_TABLEROWFORMATCOLLITEMS", "\u0002"), diff --git a/sql/core/src/test/resources/sql-tests/inputs/transform.sql b/sql/core/src/test/resources/sql-tests/inputs/transform.sql index bfb2474eb4ae..e89404c5f845 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/transform.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/transform.sql @@ -5,6 +5,12 @@ CREATE OR REPLACE TEMPORARY VIEW t AS SELECT * FROM VALUES ('3', true, unhex('537061726B2053514C'), tinyint(3), 3, smallint(300), bigint(3), float(3.0), 3.0, Decimal(3.0), timestamp('1997-02-10 17:32:01-08'), date('2000-04-03')) AS t(a, b, c, d, e, f, g, h, i, j, k, l); +CREATE OR REPLACE TEMPORARY VIEW script_trans AS SELECT * FROM VALUES +(1, 2, 3), +(4, 5, 
6), +(7, 8, 9) +AS script_trans(a, b, c); + SELECT TRANSFORM(a) USING 'cat' AS (a) FROM t; @@ -184,6 +190,132 @@ SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k, l FROM ( FROM t ) tmp; +SELECT TRANSFORM(b, a, CAST(c AS STRING)) + USING 'cat' AS (a, b, c) +FROM script_trans +WHERE a <= 4; + +SELECT TRANSFORM(1, 2, 3) + USING 'cat' AS (a, b, c) +FROM script_trans +WHERE a <= 4; + +SELECT TRANSFORM(1, 2) + USING 'cat' AS (a INT, b INT) +FROM script_trans +LIMIT 1; + +SELECT TRANSFORM( + b AS d5, a, + CASE + WHEN c > 100 THEN 1 + WHEN c < 100 THEN 2 + ELSE 3 END) + USING 'cat' AS (a, b, c) +FROM script_trans +WHERE a <= 4; + +SELECT TRANSFORM(b, a, c + 1) + USING 'cat' AS (a, b, c) +FROM script_trans +WHERE a <= 4; + +SELECT TRANSFORM(*) + USING 'cat' AS (a, b, c) +FROM script_trans +WHERE a <= 4; + +SELECT TRANSFORM(b AS d, MAX(a) as max_a, CAST(SUM(c) AS STRING)) + USING 'cat' AS (a, b, c) +FROM script_trans +WHERE a <= 4 +GROUP BY b; + +SELECT TRANSFORM(b AS d, MAX(a) FILTER (WHERE a > 3) AS max_a, CAST(SUM(c) AS STRING)) + USING 'cat' AS (a,b,c) +FROM script_trans +WHERE a <= 4 +GROUP BY b; + +SELECT TRANSFORM(b, MAX(a) as max_a, CAST(sum(c) AS STRING)) + USING 'cat' AS (a, b, c) +FROM script_trans +WHERE a <= 2 +GROUP BY b; + +SELECT TRANSFORM(b, MAX(a) as max_a, CAST(SUM(c) AS STRING)) + USING 'cat' AS (a, b, c) +FROM script_trans +WHERE a <= 4 +GROUP BY b +HAVING max_a > 0; + +SELECT TRANSFORM(b, MAX(a) as max_a, CAST(SUM(c) AS STRING)) + USING 'cat' AS (a, b, c) +FROM script_trans +WHERE a <= 4 +GROUP BY b +HAVING max(a) > 1; + +SELECT TRANSFORM(b, MAX(a) OVER w as max_a, CAST(SUM(c) OVER w AS STRING)) + USING 'cat' AS (a, b, c) +FROM script_trans +WHERE a <= 4 +WINDOW w AS (PARTITION BY b ORDER BY a); + +SELECT TRANSFORM(b, MAX(a) as max_a, CAST(SUM(c) AS STRING), myCol, myCol2) + USING 'cat' AS (a, b, c, d, e) +FROM script_trans +LATERAL VIEW explode(array(array(1,2,3))) myTable AS myCol +LATERAL VIEW explode(myTable.myCol) myTable2 AS myCol2 +WHERE a 
<= 4 +GROUP BY b, myCol, myCol2 +HAVING max(a) > 1; + +FROM( + FROM script_trans + SELECT TRANSFORM(a, b) + USING 'cat' AS (`a` INT, b STRING) +) t +SELECT a + 1; + +FROM( + SELECT TRANSFORM(a, SUM(b) b) + USING 'cat' AS (`a` INT, b STRING) + FROM script_trans + GROUP BY a +) t +SELECT (b + 1) AS result +ORDER BY result; + +MAP k / 10 USING 'cat' AS (one) FROM (SELECT 10 AS k); + +FROM (SELECT 1 AS key, 100 AS value) src +MAP src.*, src.key, CAST(src.key / 10 AS INT), CAST(src.key % 10 AS INT), src.value + USING 'cat' AS (k, v, tkey, ten, one, tvalue); + +SELECT TRANSFORM(1) + USING 'cat' AS (a) +FROM script_trans +HAVING true; + +SET spark.sql.legacy.parser.havingWithoutGroupByAsWhere=true; + +SELECT TRANSFORM(1) + USING 'cat' AS (a) +FROM script_trans +HAVING true; + +SET spark.sql.legacy.parser.havingWithoutGroupByAsWhere=false; + +SET spark.sql.parser.quotedRegexColumnNames=true; + +SELECT TRANSFORM(`(a|b)?+.+`) + USING 'cat' AS (c) +FROM script_trans; + +SET spark.sql.parser.quotedRegexColumnNames=false; + -- SPARK-34634: self join using CTE contains transform WITH temp AS ( SELECT TRANSFORM(a) USING 'cat' AS (b string) FROM t diff --git a/sql/core/src/test/resources/sql-tests/results/transform.sql.out b/sql/core/src/test/resources/sql-tests/results/transform.sql.out index 16cab8ab495a..1fa165be1a87 100644 --- a/sql/core/src/test/resources/sql-tests/results/transform.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/transform.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 17 +-- Number of queries: 42 -- !query @@ -14,6 +14,18 @@ struct<> +-- !query +CREATE OR REPLACE TEMPORARY VIEW script_trans AS SELECT * FROM VALUES +(1, 2, 3), +(4, 5, 6), +(7, 8, 9) +AS script_trans(a, b, c) +-- !query schema +struct<> +-- !query output + + + -- !query SELECT TRANSFORM(a) USING 'cat' AS (a) @@ -131,18 +143,10 @@ USING 'cat' AS (a, b) FROM t GROUP BY b -- !query schema -struct<> +struct<a:string,b:string> -- !query output 
-org.apache.spark.sql.catalyst.parser.ParseException - -mismatched input 'GROUP' expecting {<EOF>, ';'}(line 4, pos 0) - -== SQL == -SELECT TRANSFORM(b, max(a), sum(f)) -USING 'cat' AS (a, b) -FROM t -GROUP BY b -^^^ +false 2 +true 3 -- !query @@ -335,6 +339,294 @@ SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k, l FROM ( ) tmp +-- !query +SELECT TRANSFORM(b, a, CAST(c AS STRING)) + USING 'cat' AS (a, b, c) +FROM script_trans +WHERE a <= 4 +-- !query schema +struct<a:string,b:string,c:string> +-- !query output +2 1 3 +5 4 6 + + +-- !query +SELECT TRANSFORM(1, 2, 3) + USING 'cat' AS (a, b, c) +FROM script_trans +WHERE a <= 4 +-- !query schema +struct<a:string,b:string,c:string> +-- !query output +1 2 3 +1 2 3 + + +-- !query +SELECT TRANSFORM(1, 2) + USING 'cat' AS (a INT, b INT) +FROM script_trans +LIMIT 1 +-- !query schema +struct<a:int,b:int> +-- !query output +1 2 + + +-- !query +SELECT TRANSFORM( + b AS d5, a, + CASE + WHEN c > 100 THEN 1 + WHEN c < 100 THEN 2 + ELSE 3 END) + USING 'cat' AS (a, b, c) +FROM script_trans +WHERE a <= 4 +-- !query schema +struct<a:string,b:string,c:string> +-- !query output +2 1 2 +5 4 2 + + +-- !query +SELECT TRANSFORM(b, a, c + 1) + USING 'cat' AS (a, b, c) +FROM script_trans +WHERE a <= 4 +-- !query schema +struct<a:string,b:string,c:string> +-- !query output +2 1 4 +5 4 7 + + +-- !query +SELECT TRANSFORM(*) + USING 'cat' AS (a, b, c) +FROM script_trans +WHERE a <= 4 +-- !query schema +struct<a:string,b:string,c:string> +-- !query output +1 2 3 +4 5 6 + + +-- !query +SELECT TRANSFORM(b AS d, MAX(a) as max_a, CAST(SUM(c) AS STRING)) + USING 'cat' AS (a, b, c) +FROM script_trans +WHERE a <= 4 +GROUP BY b +-- !query schema +struct<a:string,b:string,c:string> +-- !query output +2 1 3 +5 4 6 + + +-- !query +SELECT TRANSFORM(b AS d, MAX(a) FILTER (WHERE a > 3) AS max_a, CAST(SUM(c) AS STRING)) + USING 'cat' AS (a,b,c) +FROM script_trans +WHERE a <= 4 +GROUP BY b +-- !query schema +struct<a:string,b:string,c:string> +-- !query output +2 null 3 +5 4 6 + + +-- !query +SELECT TRANSFORM(b, MAX(a) as max_a, CAST(sum(c) AS STRING)) + USING 'cat' AS (a, b, c) +FROM script_trans +WHERE a <= 2 +GROUP BY b +-- !query schema +struct<a:string,b:string,c:string> +-- !query 
output +2 1 3 + + +-- !query +SELECT TRANSFORM(b, MAX(a) as max_a, CAST(SUM(c) AS STRING)) + USING 'cat' AS (a, b, c) +FROM script_trans +WHERE a <= 4 +GROUP BY b +HAVING max_a > 0 +-- !query schema +struct<a:string,b:string,c:string> +-- !query output +2 1 3 +5 4 6 + + +-- !query +SELECT TRANSFORM(b, MAX(a) as max_a, CAST(SUM(c) AS STRING)) + USING 'cat' AS (a, b, c) +FROM script_trans +WHERE a <= 4 +GROUP BY b +HAVING max(a) > 1 +-- !query schema +struct<a:string,b:string,c:string> +-- !query output +5 4 6 + + +-- !query +SELECT TRANSFORM(b, MAX(a) OVER w as max_a, CAST(SUM(c) OVER w AS STRING)) + USING 'cat' AS (a, b, c) +FROM script_trans +WHERE a <= 4 +WINDOW w AS (PARTITION BY b ORDER BY a) +-- !query schema +struct<a:string,b:string,c:string> +-- !query output +2 1 3 +5 4 6 + + +-- !query +SELECT TRANSFORM(b, MAX(a) as max_a, CAST(SUM(c) AS STRING), myCol, myCol2) + USING 'cat' AS (a, b, c, d, e) +FROM script_trans +LATERAL VIEW explode(array(array(1,2,3))) myTable AS myCol +LATERAL VIEW explode(myTable.myCol) myTable2 AS myCol2 +WHERE a <= 4 +GROUP BY b, myCol, myCol2 +HAVING max(a) > 1 +-- !query schema +struct<a:string,b:string,c:string,d:string,e:string> +-- !query output +5 4 6 [1, 2, 3] 1 +5 4 6 [1, 2, 3] 2 +5 4 6 [1, 2, 3] 3 + + +-- !query +FROM( + FROM script_trans + SELECT TRANSFORM(a, b) + USING 'cat' AS (`a` INT, b STRING) +) t +SELECT a + 1 +-- !query schema +struct<(a + 1):int> +-- !query output +2 +5 +8 + + +-- !query +FROM( + SELECT TRANSFORM(a, SUM(b) b) + USING 'cat' AS (`a` INT, b STRING) + FROM script_trans + GROUP BY a +) t +SELECT (b + 1) AS result +ORDER BY result +-- !query schema +struct<result:double> +-- !query output +3.0 +6.0 +9.0 + + +-- !query +MAP k / 10 USING 'cat' AS (one) FROM (SELECT 10 AS k) +-- !query schema +struct<one:string> +-- !query output +1.0 + + +-- !query +FROM (SELECT 1 AS key, 100 AS value) src +MAP src.*, src.key, CAST(src.key / 10 AS INT), CAST(src.key % 10 AS INT), src.value + USING 'cat' AS (k, v, tkey, ten, one, tvalue) +-- !query schema +struct<k:string,v:string,tkey:string,ten:string,one:string,tvalue:string> +-- !query output +1 100 1 0 1 100 + + +-- !query +SELECT TRANSFORM(1) + USING 'cat' AS (a) +FROM script_trans 
+HAVING true +-- !query schema +struct<a:string> +-- !query output +1 + + +-- !query +SET spark.sql.legacy.parser.havingWithoutGroupByAsWhere=true +-- !query schema +struct<key:string,value:string> +-- !query output +spark.sql.legacy.parser.havingWithoutGroupByAsWhere true + + +-- !query +SELECT TRANSFORM(1) + USING 'cat' AS (a) +FROM script_trans +HAVING true +-- !query schema +struct<a:string> +-- !query output +1 +1 +1 + + +-- !query +SET spark.sql.legacy.parser.havingWithoutGroupByAsWhere=false +-- !query schema +struct<key:string,value:string> +-- !query output +spark.sql.legacy.parser.havingWithoutGroupByAsWhere false + + +-- !query +SET spark.sql.parser.quotedRegexColumnNames=true +-- !query schema +struct<key:string,value:string> +-- !query output +spark.sql.parser.quotedRegexColumnNames true + + +-- !query +SELECT TRANSFORM(`(a|b)?+.+`) + USING 'cat' AS (c) +FROM script_trans +-- !query schema +struct<c:string> +-- !query output +3 +6 +9 + + +-- !query +SET spark.sql.parser.quotedRegexColumnNames=false +-- !query schema +struct<key:string,value:string> +-- !query output +spark.sql.parser.quotedRegexColumnNames false + -- !query WITH temp AS ( SELECT TRANSFORM(a) USING 'cat' AS (b string) FROM t diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index f1788e9c31af..a037acbe7fdb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.execution import scala.collection.JavaConverters._ import org.apache.spark.internal.config.ConfigEntry -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Concat, SortOrder} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, 
TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedHaving, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Concat, GreaterThan, Literal, NullsFirst, SortOrder, UnresolvedWindowExpression, UnspecifiedFrame, WindowSpecDefinition, WindowSpecReference} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.{CreateTempViewUsing, RefreshResource} @@ -262,31 +262,17 @@ class SparkSqlParserSuite extends AnalysisTest { } test("SPARK-32608: script transform with row format delimit") { - assertEqual( + val rowFormat = """ - |SELECT TRANSFORM(a, b, c) | ROW FORMAT DELIMITED | FIELDS TERMINATED BY ',' | COLLECTION ITEMS TERMINATED BY '#' | MAP KEYS TERMINATED BY '@' | LINES TERMINATED BY '\n' | NULL DEFINED AS 'null' - | USING 'cat' AS (a, b, c) - | ROW FORMAT DELIMITED - | FIELDS TERMINATED BY ',' - | COLLECTION ITEMS TERMINATED BY '#' - | MAP KEYS TERMINATED BY '@' - | LINES TERMINATED BY '\n' - | NULL DEFINED AS 'NULL' - |FROM testData - """.stripMargin, - ScriptTransformation( - Seq('a, 'b, 'c), - "cat", - Seq(AttributeReference("a", StringType)(), - AttributeReference("b", StringType)(), - AttributeReference("c", StringType)()), - UnresolvedRelation(TableIdentifier("testData")), + """.stripMargin + + val ioSchema = ScriptInputOutputSchema( Seq(("TOK_TABLEROWFORMATFIELD", ","), ("TOK_TABLEROWFORMATCOLLITEMS", "#"), @@ -296,9 +282,141 @@ class SparkSqlParserSuite extends AnalysisTest { Seq(("TOK_TABLEROWFORMATFIELD", ","), ("TOK_TABLEROWFORMATCOLLITEMS", "#"), ("TOK_TABLEROWFORMATMAPKEYS", "@"), - ("TOK_TABLEROWFORMATNULL", "NULL"), + ("TOK_TABLEROWFORMATNULL", "null"), ("TOK_TABLEROWFORMATLINES", "\n")), None, None, - List.empty, List.empty, None, None, false))) + List.empty, List.empty, None, 
None, false) + + assertEqual( + s""" + |SELECT TRANSFORM(a, b, c) + | $rowFormat + | USING 'cat' AS (a, b, c) + | $rowFormat + |FROM testData + """.stripMargin, + ScriptTransformation( + Seq(UnresolvedStar(None)), + "cat", + Seq(AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", StringType)()), + Project(Seq('a, 'b, 'c), + UnresolvedRelation(TableIdentifier("testData"))), + ioSchema)) + + assertEqual( + s""" + |SELECT TRANSFORM(a, sum(b), max(c)) + | $rowFormat + | USING 'cat' AS (a, b, c) + | $rowFormat + |FROM testData + |GROUP BY a + |HAVING sum(b) > 10 + """.stripMargin, + ScriptTransformation( + Seq(UnresolvedStar(None)), + "cat", + Seq(AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", StringType)()), + UnresolvedHaving( + GreaterThan( + UnresolvedFunction("sum", Seq(UnresolvedAttribute("b")), isDistinct = false), + Literal(10)), + Aggregate( + Seq('a), + Seq( + 'a, + UnresolvedAlias( + UnresolvedFunction("sum", Seq(UnresolvedAttribute("b")), isDistinct = false), None), + UnresolvedAlias( + UnresolvedFunction("max", Seq(UnresolvedAttribute("c")), isDistinct = false), None) + ), + UnresolvedRelation(TableIdentifier("testData")))), + ioSchema)) + + assertEqual( + s""" + |SELECT TRANSFORM(a, sum(b) OVER w, max(c) OVER w) + | $rowFormat + | USING 'cat' AS (a, b, c) + | $rowFormat + |FROM testData + |WINDOW w AS (PARTITION BY a ORDER BY b) + """.stripMargin, + ScriptTransformation( + Seq(UnresolvedStar(None)), + "cat", + Seq(AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", StringType)()), + WithWindowDefinition( + Map("w" -> WindowSpecDefinition( + Seq('a), + Seq(SortOrder('b, Ascending, NullsFirst, Seq.empty)), + UnspecifiedFrame)), + Project( + Seq( + 'a, + UnresolvedAlias( + UnresolvedWindowExpression( + UnresolvedFunction("sum", Seq(UnresolvedAttribute("b")), isDistinct = false), + 
WindowSpecReference("w")), None), + UnresolvedAlias( + UnresolvedWindowExpression( + UnresolvedFunction("max", Seq(UnresolvedAttribute("c")), isDistinct = false), + WindowSpecReference("w")), None) + ), + UnresolvedRelation(TableIdentifier("testData")))), + ioSchema)) + + assertEqual( + s""" + |SELECT TRANSFORM(a, sum(b), max(c)) + | $rowFormat + | USING 'cat' AS (a, b, c) + | $rowFormat + |FROM testData + |LATERAL VIEW explode(array(array(1,2,3))) myTable AS myCol + |LATERAL VIEW explode(myTable.myCol) myTable2 AS myCol2 + |GROUP BY a, myCol, myCol2 + |HAVING sum(b) > 10 + """.stripMargin, + ScriptTransformation( + Seq(UnresolvedStar(None)), + "cat", + Seq(AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", StringType)()), + UnresolvedHaving( + GreaterThan( + UnresolvedFunction("sum", Seq(UnresolvedAttribute("b")), isDistinct = false), + Literal(10)), + Aggregate( + Seq('a, 'myCol, 'myCol2), + Seq( + 'a, + UnresolvedAlias( + UnresolvedFunction("sum", Seq(UnresolvedAttribute("b")), isDistinct = false), None), + UnresolvedAlias( + UnresolvedFunction("max", Seq(UnresolvedAttribute("c")), isDistinct = false), None) + ), + Generate( + UnresolvedGenerator( + FunctionIdentifier("explode"), + Seq(UnresolvedAttribute("myTable.myCol"))), + Nil, false, Option("mytable2"), Seq('myCol2), + Generate( + UnresolvedGenerator( + FunctionIdentifier("explode"), + Seq(UnresolvedFunction("array", + Seq( + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), false)), + false))), + Nil, false, Option("mytable"), Seq('myCol), + UnresolvedRelation(TableIdentifier("testData")))))), + ioSchema)) } test("SPARK-32607: Script Transformation ROW FORMAT DELIMITED" + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index 96f9421e1d98..35fdf198fffb 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command import java.util.Locale import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute, UnresolvedStar} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan @@ -279,16 +279,16 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { } test("transform query spec") { - val p = ScriptTransformation( - Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), - "func", Seq.empty, plans.table("e"), null) + val p = Project(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), plans.table("e")) + val s = ScriptTransformation(Seq(UnresolvedStar(None)), "func", Seq.empty, p, null) compareTransformQuery("select transform(a, b) using 'func' from e where f < 10", - p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) + s.copy(child = p.copy(child = p.child.where('f < 10)), + output = Seq('key.string, 'value.string))) compareTransformQuery("map a, b using 'func' as c, d from e", - p.copy(output = Seq('c.string, 'd.string))) + s.copy(output = Seq('c.string, 'd.string))) compareTransformQuery("reduce a, b using 'func' as (c int, d decimal(10, 0)) from e", - p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) + s.copy(output = Seq('c.int, 'd.decimal(10, 0)))) } test("use backticks in output of Script Transform") {