Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -966,7 +966,7 @@ class Analyzer(
case p if !p.childrenResolved => p
// Replace the index with the related attribute for ORDER BY,
// which is a 1-base position of the projection list.
case s @ Sort(orders, global, child)
case Sort(orders, global, child)
if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) =>
val newOrders = orders map {
case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) =>
Expand All @@ -983,17 +983,11 @@ class Analyzer(

// Replace the index with the corresponding expression in aggregateExpressions. The index is
// a 1-base position of aggregateExpressions, which is output columns (select expression)
case a @ Aggregate(groups, aggs, child) if aggs.forall(_.resolved) &&
case Aggregate(groups, aggs, child) if aggs.forall(_.resolved) &&
groups.exists(_.isInstanceOf[UnresolvedOrdinal]) =>
val newGroups = groups.map {
case ordinal @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size =>
aggs(index - 1) match {
case e if ResolveAggregateFunctions.containsAggregate(e) =>
ordinal.failAnalysis(
s"GROUP BY position $index is an aggregate function, and " +
"aggregate functions are not allowed in GROUP BY")
case o => o
}
case u @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size =>
aggs(index - 1)
case ordinal @ UnresolvedOrdinal(index) =>
ordinal.failAnalysis(
s"GROUP BY position $index is not in select list " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,11 @@ trait CheckAnalysis extends PredicateHelper {
}

def checkValidGroupingExprs(expr: Expression): Unit = {
if (expr.find(_.isInstanceOf[AggregateExpression]).isDefined) {
failAnalysis(
"aggregate functions are not allowed in GROUP BY, but found " + expr.sql)
}

// Check if the data type of expr is orderable.
if (!RowOrdering.isOrderable(expr.dataType)) {
failAnalysis(
Expand All @@ -271,8 +276,8 @@ trait CheckAnalysis extends PredicateHelper {
}
}

aggregateExprs.foreach(checkValidAggregateExpression)
groupingExprs.foreach(checkValidGroupingExprs)
aggregateExprs.foreach(checkValidAggregateExpression)

case Sort(orders, _, _) =>
orders.foreach { order =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ select a, b, sum(b) from data group by 3
struct<>
-- !query 11 output
org.apache.spark.sql.AnalysisException
GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 39
aggregate functions are not allowed in GROUP BY, but found sum(CAST(data.`b` AS BIGINT));


-- !query 12
Expand All @@ -131,7 +131,7 @@ select a, b, sum(b) + 2 from data group by 3
struct<>
-- !query 12 output
org.apache.spark.sql.AnalysisException
GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 43
aggregate functions are not allowed in GROUP BY, but found (sum(CAST(data.`b` AS BIGINT)) + CAST(2 AS BIGINT));


-- !query 13
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -538,4 +538,11 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
Seq(Row(3, 0, 0.0, 1, 5.0), Row(2, 1, 4.0, 0, 0.0))
)
}

test("aggregate function in GROUP BY") {
val e = intercept[AnalysisException] {
testData.groupBy(sum($"key")).count()
}
assert(e.message.contains("aggregate functions are not allowed in GROUP BY"))
}
}