diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 96d1e42ffafe..f0a040bceb28 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -718,8 +718,8 @@ primaryExpression | '(' query ')' #subqueryExpression | qualifiedName '(' (setQuantifier? argument+=expression (',' argument+=expression)*)? ')' (OVER windowSpec)? #functionCall - | IDENTIFIER '->' expression #lambda - | '(' IDENTIFIER (',' IDENTIFIER)+ ')' '->' expression #lambda + | identifier '->' expression #lambda + | '(' identifier (',' identifier)+ ')' '->' expression #lambda | value=primaryExpression '[' index=valueExpression ']' #subscript | identifier #columnReference | base=primaryExpression '.' fieldName=identifier #dereference diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 838fc4d84a5d..0bc87095ce53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1559,7 +1559,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging * Create an [[LambdaFunction]]. 
*/ override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) { - val arguments = ctx.IDENTIFIER().asScala.map { name => + val arguments = ctx.identifier().asScala.map { name => UnresolvedNamedLambdaVariable(UnresolvedAttribute.quoted(name.getText).nameParts) } val function = expression(ctx.expression).transformUp { diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql index 02ad5e353868..8d5d9fae7a73 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql @@ -83,3 +83,12 @@ select transform_values(ys, (k, v) -> v + 1) as v from nested; -- Transform values in a map using values select transform_values(ys, (k, v) -> k + v) as v from nested; + +-- use non reserved keywords: all is non reserved only if !ansi +select transform(ys, all -> all * all) as v from values (array(32, 97)) as t(ys); +select transform(ys, (all, i) -> all + i) as v from values (array(32, 97)) as t(ys); + +set spark.sql.ansi.enabled=true; +select transform(ys, all -> all * all) as v from values (array(32, 97)) as t(ys); +select transform(ys, (all, i) -> all + i) as v from values (array(32, 97)) as t(ys); +set spark.sql.ansi.enabled=false; diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out index 1b7c6f4f7625..0b78076588c1 100644 --- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 27 +-- Number of queries: 33 -- !query 0 @@ -254,3 +254,63 @@ struct> -- !query 26 output {1:2,2:4,3:6} {4:8,5:10,6:12} + + +-- !query 27 +select transform(ys, all -> all * all) as v 
from values (array(32, 97)) as t(ys) +-- !query 27 schema +struct> +-- !query 27 output +[1024,9409] + + +-- !query 28 +select transform(ys, (all, i) -> all + i) as v from values (array(32, 97)) as t(ys) +-- !query 28 schema +struct> +-- !query 28 output +[32,98] + + +-- !query 29 +set spark.sql.ansi.enabled=true +-- !query 29 schema +struct +-- !query 29 output +spark.sql.ansi.enabled true + + +-- !query 30 +select transform(ys, all -> all * all) as v from values (array(32, 97)) as t(ys) +-- !query 30 schema +struct<> +-- !query 30 output +org.apache.spark.sql.catalyst.parser.ParseException + +no viable alternative at input 'all'(line 1, pos 21) + +== SQL == +select transform(ys, all -> all * all) as v from values (array(32, 97)) as t(ys) +---------------------^^^ + + +-- !query 31 +select transform(ys, (all, i) -> all + i) as v from values (array(32, 97)) as t(ys) +-- !query 31 schema +struct<> +-- !query 31 output +org.apache.spark.sql.catalyst.parser.ParseException + +no viable alternative at input 'all'(line 1, pos 22) + +== SQL == +select transform(ys, (all, i) -> all + i) as v from values (array(32, 97)) as t(ys) +----------------------^^^ + + +-- !query 32 +set spark.sql.ansi.enabled=false +-- !query 32 schema +struct +-- !query 32 output +spark.sql.ansi.enabled false