From 6fae959dbc7a89303b0a02c389217d6f126272fb Mon Sep 17 00:00:00 2001
From: Kevin Yu
Date: Wed, 27 Jan 2016 22:22:46 -0800
Subject: [PATCH 1/4] resolve the UnresolvedAttribute for aliases in GROUP BY clause

---
 .../sql/catalyst/analysis/Analyzer.scala | 50 +++++++++++++++++++
 .../org/apache/spark/sql/JoinSuite.scala | 12 +++++
 2 files changed, 62 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e362b55d80cd..60def366a3cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -462,6 +462,43 @@ class Analyzer(
           Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child)
         }
 
+      case a @ Aggregate(groups, aggs, child)
+          if (groups.nonEmpty && aggs.nonEmpty && child.resolved
+            && aggs.map(_.resolved).reduceLeft(_ && _)
+            && (!groups.map(_.resolved).reduceLeft(_ && _))) =>
+        val newGroups = groups.map(g => g transformUp {
+          case u @ UnresolvedAttribute(nameParts) =>
+            val result = withPosition(u) {
+              resolveAliases(nameParts,
+                aggs.collect { case a: Alias => a }, resolver).getOrElse(u)
+            }
+            logDebug(s"Resolving $u to $result")
+            result
+          case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
+            ExtractValue(child, fieldExpr, resolver)
+        })
+
+        val q = if (!newGroups.zip(groups).map(p => p._1 fastEquals (p._2))
+            .reduceLeft(_ && _)) {
+          Aggregate(newGroups, aggs, child)
+        } else {
+          a
+        }
+
+        logTrace(s"Attempting to resolve ${q.simpleString}")
+        q transformExpressionsUp {
+          case u @ UnresolvedAttribute(nameParts) =>
+            // Leave unchanged if resolution fails. Hopefully will be resolved next round.
+            val result =
+              withPosition(u) {
+                q.resolveChildren(nameParts, resolver).getOrElse(u)
+              }
+            logDebug(s"Resolving $u to $result")
+            result
+          case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
+            ExtractValue(child, fieldExpr, resolver)
+        }
+
       case q: LogicalPlan =>
         logTrace(s"Attempting to resolve ${q.simpleString}")
         q transformExpressionsUp {
@@ -493,6 +530,19 @@ class Analyzer(
     def containsStar(exprs: Seq[Expression]): Boolean =
       exprs.exists(_.collect { case _: Star => true }.nonEmpty)
   }
+  /**
+   * Find matching Aliases for attribute name
+   */
+  private def resolveAliases(
+      nameParts: Seq[String],
+      aliases: Seq[Alias],
+      resolver: Resolver): Option[Alias] = {
+    val matches = if (nameParts.length == 1) {
+      aliases.distinct.filter(a => resolver(a.name, nameParts.head))
+    } else Nil
+    if (matches.isEmpty) None
+    else Some(matches.head)
+  }
 
   private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = {
     ordering.map { order =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 9a3c262e9485..3911068f4402 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -32,6 +32,18 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     df.queryExecution.optimizedPlan.statistics.sizeInBytes
   }
 
+  test("spark-10777 order by") {
+
+    val df1 = sql("select a r, sum(b) s FROM testData2 GROUP BY r")
+
+    val df2 = sql("SELECT * FROM ( select a r, sum(b) s FROM testData2 GROUP BY r) t")
+
+    val df3 = sql("SELECT r as c1, min(s) over () as c2 FROM" +
+      "( select a r, sum(b) s FROM testData2 GROUP BY r) t order by r")
+
+    val df4 = sql("select a r, sum(b) s FROM testData2 GROUP BY r, s")
+  }
+
   test("equi-join is hash-join") {
     val x = testData2.as("x")
     val y = testData2.as("y")

From da86a01221b1d24b53183d9c4d605554dbbc6897 Mon Sep 17 00:00:00 2001
From: Kevin Yu
Date: Thu, 28 Jan 2016 13:47:31 -0800
Subject: [PATCH 2/4] move testcase to SQLQuerySuite.scala

---
 .../scala/org/apache/spark/sql/JoinSuite.scala |  2 ++
 .../org/apache/spark/sql/SQLQuerySuite.scala   | 15 +++++++++++++++
 2 files changed, 17 insertions(+)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 3911068f4402..988264b187f6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -36,6 +36,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
 
     val df1 = sql("select a r, sum(b) s FROM testData2 GROUP BY r")
 
+
+
     val df2 = sql("SELECT * FROM ( select a r, sum(b) s FROM testData2 GROUP BY r) t")
 
     val df3 = sql("SELECT r as c1, min(s) over () as c2 FROM" +
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index bd987ae1bb03..edb40b1bef07 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2079,4 +2079,19 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     assert(rdd.takeAsync(2147483638).get.size === 3)
   }
+  test("SPARK-10777: resolve the alias defined in aggregation expression used in group by") {
+    val structDf = testData2.select("a", "b").as("record")
+
+    checkAnswer(
+      sql("SELECT a as r, sum(b) as s from testData2 GROUP BY r"),
+      Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
+    )
+
+    checkAnswer(
+      sql("SELECT * FROM " +
+        "(SELECT a as r, sum(b) as s from testData2 GROUP BY r) t"),
+      Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
+    )
+
+  }
 }

From 6ad40c781166499613dd7ec74e8b76b3880c4e33 Mon Sep 17 00:00:00 2001
From: Kevin Yu
Date: Thu, 28 Jan 2016 13:53:11 -0800
Subject: [PATCH 3/4] remove the change in JoinSuite.scala

---
 .../scala/org/apache/spark/sql/JoinSuite.scala | 14 --------------
 1 file changed, 14 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 988264b187f6..9a3c262e9485 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -32,20 +32,6 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     df.queryExecution.optimizedPlan.statistics.sizeInBytes
   }
 
-  test("spark-10777 order by") {
-
-    val df1 = sql("select a r, sum(b) s FROM testData2 GROUP BY r")
-
-
-
-    val df2 = sql("SELECT * FROM ( select a r, sum(b) s FROM testData2 GROUP BY r) t")
-
-    val df3 = sql("SELECT r as c1, min(s) over () as c2 FROM" +
-      "( select a r, sum(b) s FROM testData2 GROUP BY r) t order by r")
-
-    val df4 = sql("select a r, sum(b) s FROM testData2 GROUP BY r, s")
-  }
-
   test("equi-join is hash-join") {
     val x = testData2.as("x")
     val y = testData2.as("y")

From b0157d4ec7c2898837948dc1882a57217d9caf1e Mon Sep 17 00:00:00 2001
From: Kevin Yu
Date: Fri, 29 Jan 2016 23:46:32 -0800
Subject: [PATCH 4/4] adding testcase

---
 .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index edb40b1bef07..779c509c5bdc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2093,5 +2093,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
     )
 
+    checkAnswer(
+      sql("SELECT r as c1, min(s) over () as c2 FROM " +
+        "(SELECT a as r, sum(b) as s from testData2 GROUP BY a) t order by r"),
+      Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
+    )
+
+    //val df = sql("select a r, sum(b) s FROM testData2 GROUP BY r")
+    //val df = sql("SELECT * FROM ( select a r, sum(b) s FROM testData2 GROUP BY r) t")
   }
 }
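
For reference, the alias-matching rule that patch 1 adds through resolveAliases can be sketched outside of Catalyst. The snippet below is a standalone approximation, not the Analyzer code itself: SimpleAlias, caseInsensitiveResolver and resolveAlias are hypothetical stand-ins for Catalyst's Alias, the session Resolver and the new helper, and, as in the patch, only single-part names are matched against the SELECT-list aliases.

object GroupByAliasSketch {

  // A SELECT-list alias: `sum(b) AS s` is modeled as SimpleAlias("s", "sum(b)").
  case class SimpleAlias(name: String, expr: String)

  // Spark's default resolver is case-insensitive; model that here.
  val caseInsensitiveResolver: (String, String) => Boolean =
    (a, b) => a.equalsIgnoreCase(b)

  // Mirrors the patch's rule: only a single-part name is matched against the
  // SELECT-list aliases, and the first match (if any) wins.
  def resolveAlias(
      nameParts: Seq[String],
      aliases: Seq[SimpleAlias],
      resolver: (String, String) => Boolean): Option[SimpleAlias] = {
    if (nameParts.length == 1) {
      aliases.distinct.find(a => resolver(a.name, nameParts.head))
    } else {
      None
    }
  }

  def main(args: Array[String]): Unit = {
    // SELECT a AS r, sum(b) AS s FROM testData2 GROUP BY r
    val selectAliases = Seq(SimpleAlias("r", "a"), SimpleAlias("s", "sum(b)"))

    // "r" is not a column of testData2, but it matches the alias on `a`.
    println(resolveAlias(Seq("r"), selectAliases, caseInsensitiveResolver))
    // prints: Some(SimpleAlias(r,a))

    // A qualified name such as t.r is left alone by this rule.
    println(resolveAlias(Seq("t", "r"), selectAliases, caseInsensitiveResolver))
    // prints: None
  }
}

With that rule in place, the unresolved "r" in GROUP BY r is replaced by the matching SELECT-list alias before the rest of the Aggregate is resolved, which is what the checkAnswer calls added to SQLQuerySuite exercise.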