diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala index 256b18771052a..860d20f897690 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala @@ -33,7 +33,7 @@ class SubstituteUnresolvedOrdinals(conf: SQLConf) extends Rule[LogicalPlan] { case _ => false } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) => val newOrders = s.order.map { case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) => diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql index 6566338f3d4a9..928f766b4add2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql @@ -52,7 +52,7 @@ select count(a), a from (select 1 as a) tmp group by 2 having a > 0; -- mixed cases: group-by ordinals and aliases select a, a AS k, count(b) from data group by k, 1; --- turn of group by ordinal +-- turn off group by ordinal set spark.sql.groupByOrdinal=false; -- can now group by negative literal diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index d802557b36ec9..69ea62ef5eb74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -557,4 +557,20 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } assert(e.message.contains("aggregate functions are not allowed in GROUP BY")) } + + test("SPARK-21580 ints in aggregation expressions are taken as group-by ordinal.") { + checkAnswer( + testData2.groupBy(lit(3), lit(4)).agg(lit(6), lit(7), sum("b")), + Seq(Row(3, 4, 6, 7, 9))) + checkAnswer( + testData2.groupBy(lit(3), lit(4)).agg(lit(6), 'b, sum("b")), + Seq(Row(3, 4, 6, 1, 3), Row(3, 4, 6, 2, 6))) + + checkAnswer( + spark.sql("SELECT 3, 4, SUM(b) FROM testData2 GROUP BY 1, 2"), + Seq(Row(3, 4, 9))) + checkAnswer( + spark.sql("SELECT 3 AS c, 4 AS d, SUM(b) FROM testData2 GROUP BY c, d"), + Seq(Row(3, 4, 9))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 7c500728bdec9..df54fc4c95810 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2023,4 +2023,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df1.join(df2, $"t1.i" === $"t2.i").cache().count() == 1) } } + + test("order-by ordinal.") { + checkAnswer( + testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)), + Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2))) + } }