Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,43 @@ class Analyzer(
Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child)
}

case a @ Aggregate(groups, aggs, child)
if (groups.nonEmpty && aggs.nonEmpty && child.resolved
&& aggs.map(_.resolved).reduceLeft(_ && _)
&& (!groups.map(_.resolved).reduceLeft(_ && _))) =>
val newGroups = groups.map(g => g transformUp {
case u @ UnresolvedAttribute(nameParts) =>
val result = withPosition(u) {
resolveAliases(nameParts,
aggs.collect { case a: Alias => a }, resolver).getOrElse(u)
}
logDebug(s"Resolving $u to $result")
result
case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
ExtractValue(child, fieldExpr, resolver)
})

val q = if (!newGroups.zip(groups).map(p => p._1 fastEquals (p._2))
.reduceLeft(_ && _)) {
Aggregate(newGroups, aggs, child)
} else {
a
}

logTrace(s"Attempting to resolve ${q.simpleString}")
q transformExpressionsUp {
case u @ UnresolvedAttribute(nameParts) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result =
withPosition(u) {
q.resolveChildren(nameParts, resolver).getOrElse(u)
}
logDebug(s"Resolving $u to $result")
result
case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
ExtractValue(child, fieldExpr, resolver)
}

case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
q transformExpressionsUp {
Expand Down Expand Up @@ -493,6 +530,19 @@ class Analyzer(
def containsStar(exprs: Seq[Expression]): Boolean =
exprs.exists(_.collect { case _: Star => true }.nonEmpty)
}
/**
* Find matching Aliases for attribute name
*/
private def resolveAliases(
nameParts: Seq[String],
aliases: Seq[Alias],
resolver: Resolver): Option[Alias] = {
val matches = if (nameParts.length == 1) {
aliases.distinct.filter(a => resolver(a.name, nameParts.head))
} else Nil
if (matches.isEmpty) None
else Some(matches.head)
}

private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = {
ordering.map { order =>
Expand Down
23 changes: 23 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2079,4 +2079,27 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
assert(rdd.takeAsync(2147483638).get.size === 3)
}

test("SPARK-10777: resolve the alias defined in aggregation expression used in group by") {
val structDf = testData2.select("a", "b").as("record")

checkAnswer(
sql("SELECT a as r, sum(b) as s from testData2 GROUP BY r"),
Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
)

checkAnswer(
sql("SELECT * FROM " +
"(SELECT a as r, sum(b) as s from testData2 GROUP BY r) t"),
Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
)

checkAnswer(
sql("SELECT r as c1, min(s) as c2 FROM " +
"(SELECT a as r, sum(b) as s from testData2 GROUP BY a) t order by r"),
Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
)

//val df = sql("select a r, sum(b) s FROM testData2 GROUP BY r")
//val df = sql("SELECT * FROM ( select a r, sum(b) s FROM testData2 GROUP BY r) t")
}
}