diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d0eb9c2c90bdf..e4407516de33a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -465,12 +465,13 @@ class Analyzer( // Find aggregate expressions and evaluate them early, since they can't be evaluated in a // Sort. - val (withAggsRemoved, aliasedAggregateList) = newOrdering.map { - case aggOrdering if aggOrdering.collect { case a: AggregateExpression => a }.nonEmpty => - val aliased = Alias(aggOrdering.child, "_aggOrdering")() + val (withAggsRemoved, aliasedAggregateList) = newOrdering.zipWithIndex.map { + case (aggOrdering, idx) + if aggOrdering.collect { case a: AggregateExpression => a }.nonEmpty => + val aliased = Alias(aggOrdering.child, s"${aggOrdering.toString}_$idx")() (aggOrdering.copy(child = aliased.toAttribute), Some(aliased)) - case other => (other, None) + case (other, _) => (other, None) }.unzip val missing = missingAttr ++ aliasedAggregateList.flatten diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c329fdb2a6bb1..415c8dfdbd640 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.errors.DialectException +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -1625,4 +1626,21 @@ class 
SQLQuerySuite extends QueryTest with SharedSQLContext {
       checkAnswer(sql("select count(num) from 1one"), Row(10))
     }
   }
+
+  test("SPARK-10044: resolving reference for sorting with aggregation") {
+    withTempTable("mytable") {
+      sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString))
+        .toDF("key", "value") // value is key.toString, so max(value) compares strings
+        .registerTempTable("mytable")
+      checkAnswer(sql( // SQL form: ORDER BY aggregates that are not in the select list
+        """select max(value) as _aggOrdering_0 from mytable group by key % 2
+          |order by max(concat(value,",", key)), min(substr(value, 0, 4))
+          |""".stripMargin), Row("8") :: Row("9") :: Nil)
+      // Same grouping and ordering expressed through the DataFrame API.
+      checkAnswer(
+        sqlContext.table("mytable").groupBy($"key" % 2).agg(max($"value"))
+        .orderBy(max(concat($"value", lit(","), $"key")), min(substring($"value", 0, 4))),
+        Row(0, "8") :: Row(1, "9") :: Nil)
+    }
+  }
 }