Skip to content

Commit 8a5a4f6

Browse files
author
Derek Sabry
committed
Add Sort() case
1 parent b4cfcbf commit 8a5a4f6

File tree

2 files changed

+40
-14
lines changed

2 files changed

+40
-14
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class Analyzer(
7070
Batch("Resolution", fixedPoint,
7171
ResolveRelations ::
7272
ResolveReferences ::
73-
ResolveGroupByReferences ::
73+
ResolveNumericReferences ::
7474
ResolveGroupingAnalytics ::
7575
ResolvePivot ::
7676
ResolveSortReferences ::
@@ -180,23 +180,42 @@ class Analyzer(
180180
}
181181

182182
/**
183-
* Replaces queries of the form "SELECT expr FROM A GROUP BY 1"
184-
* with a query of the form "SELECT expr FROM A GROUP BY expr"
183+
* Replaces queries of the form "SELECT expr FROM A GROUP BY 1 ORDER BY 1"
184+
* with a query of the form "SELECT expr FROM A GROUP BY expr ORDER BY expr"
185185
*/
186-
object ResolveGroupByReferences extends Rule[LogicalPlan] {
186+
object ResolveNumericReferences extends Rule[LogicalPlan] {
187187

188188
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
189189
case Aggregate(groups, aggs, child) =>
190-
Aggregate(groups.map(group => group match {
191-
case g if g.prettyString forall Character.isDigit =>
192-
aggs(g.prettyString.toInt - 1) match {
193-
case u : UnresolvedAlias =>
194-
u.child
195-
case a : Alias =>
196-
a.child
190+
val newGroups = groups.map {
191+
case group if group.isInstanceOf[Literal] && group.dataType.isInstanceOf[IntegralType] =>
192+
aggs(group.toString.toInt - 1) match {
193+
case u: UnresolvedAlias =>
194+
u.child match {
195+
case UnresolvedStar(_) => // Can't replace literal with column yet
196+
group
197+
case _ => u.child
198+
}
199+
case a: Alias => a.child
200+
case a: AttributeReference => a
197201
}
198-
case _ => group
199-
}), aggs, child)
202+
case group => group
203+
}
204+
Aggregate(newGroups, aggs, child)
205+
case Sort(ordering, global, child) =>
206+
val newOrdering = ordering.map {
207+
case o if o.child.isInstanceOf[Literal] && o.dataType.isInstanceOf[IntegralType] =>
208+
val newExpr = child.asInstanceOf[Project].projectList(o.child.toString.toInt - 1)
209+
match {
210+
case u: UnresolvedAlias =>
211+
u.child
212+
case a: Alias =>
213+
a.child
214+
}
215+
SortOrder(newExpr, o.direction)
216+
case other => other
217+
}
218+
Sort(newOrdering, global, child)
200219
}
201220
}
202221

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2028,10 +2028,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
20282028
Row(false) :: Row(true) :: Nil)
20292029
}
20302030

2031-
test("SPARK-12063: Group by Column Number identifier") {
2031+
test("SPARK-12063: Group by Columns Number") {
20322032
checkAnswer(
20332033
sql("SELECT a, SUM(b) FROM testData2 GROUP BY 1"),
20342034
Seq(Row(1, 3), Row(2, 3), Row(3, 3)))
20352035
}
2036+
2037+
test("SPARK-12063: Order by Column Number") {
2038+
Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v").registerTempTable("ord")
2039+
checkAnswer(
2040+
sql("SELECT v from ord order by 1 desc"),
2041+
Row(5) :: Row(3) :: Row(2) :: Row(1) :: Nil)
2042+
}
20362043

20372044
}

0 commit comments

Comments
 (0)