diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e362b55d80cd..60def366a3cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -462,6 +462,43 @@ class Analyzer(
         Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child)
       }
 
+      case a @ Aggregate(groups, aggs, child)
+        if (groups.nonEmpty && aggs.nonEmpty && child.resolved
+          && aggs.forall(_.resolved) && !groups.forall(_.resolved)) =>
+        // Resolve grouping expressions against aliases defined in the aggregate
+        // list, e.g. SELECT a AS r, sum(b) FROM t GROUP BY r.
+        val newGroups = groups.map(g => g transformUp {
+          case u @ UnresolvedAttribute(nameParts) =>
+            val result = withPosition(u) {
+              resolveAliases(nameParts,
+                aggs.collect { case alias: Alias => alias }, resolver).getOrElse(u)
+            }
+            logDebug(s"Resolving $u to $result")
+            result
+          case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
+            ExtractValue(child, fieldExpr, resolver)
+        })
+
+        // Only rebuild the Aggregate when a grouping expression actually changed.
+        val q = if (newGroups.zip(groups).exists(p => !(p._1 fastEquals p._2))) {
+          Aggregate(newGroups, aggs, child)
+        } else {
+          a
+        }
+
+        logTrace(s"Attempting to resolve ${q.simpleString}")
+        q transformExpressionsUp {
+          case u @ UnresolvedAttribute(nameParts) =>
+            // Leave unchanged if resolution fails. Hopefully will be resolved next round.
+            val result = withPosition(u) {
+              q.resolveChildren(nameParts, resolver).getOrElse(u)
+            }
+            logDebug(s"Resolving $u to $result")
+            result
+          case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
+            ExtractValue(child, fieldExpr, resolver)
+        }
+
       case q: LogicalPlan =>
         logTrace(s"Attempting to resolve ${q.simpleString}")
         q transformExpressionsUp {
@@ -493,6 +530,21 @@ class Analyzer(
     def containsStar(exprs: Seq[Expression]): Boolean =
       exprs.exists(_.collect { case _: Star => true }.nonEmpty)
   }
+
+  /**
+   * Finds the first alias in `aliases` whose name matches a single-part
+   * attribute name, using the given resolver. Multi-part names (e.g. `t.c`)
+   * never refer to an alias, so they yield None.
+   */
+  private def resolveAliases(
+      nameParts: Seq[String],
+      aliases: Seq[Alias],
+      resolver: Resolver): Option[Alias] = {
+    if (nameParts.length == 1) {
+      aliases.distinct.find(a => resolver(a.name, nameParts.head))
+    } else {
+      None
+    }
+  }
 
   private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = {
     ordering.map { order =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index bd987ae1bb03..779c509c5bdc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2079,4 +2079,22 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     assert(rdd.takeAsync(2147483638).get.size === 3)
   }
 
+  test("SPARK-10777: resolve the alias defined in aggregation expression used in group by") {
+    checkAnswer(
+      sql("SELECT a as r, sum(b) as s from testData2 GROUP BY r"),
+      Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT * FROM " +
+        "(SELECT a as r, sum(b) as s from testData2 GROUP BY r) t"),
+      Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT r as c1, min(s) as c2 FROM " +
+        "(SELECT a as r, sum(b) as s from testData2 GROUP BY a) t order by r"),
+      Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
+    )
+  }
 }