Skip to content

Commit 47f6534

Browse files
committed
[SPARK-16955][SQL] Fix analysis error when using ordinal in ORDER BY or GROUP BY
1 parent ac84fb6 commit 47f6534

File tree

6 files changed

+148
-17
lines changed

6 files changed

+148
-17
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.planning.IntegerIndex
3232
import org.apache.spark.sql.catalyst.plans._
3333
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _}
3434
import org.apache.spark.sql.catalyst.rules._
35-
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
35+
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, TreeNodeRef}
3636
import org.apache.spark.sql.catalyst.util.toPrettySQL
3737
import org.apache.spark.sql.types._
3838

@@ -84,7 +84,8 @@ class Analyzer(
8484
Batch("Substitution", fixedPoint,
8585
CTESubstitution,
8686
WindowsSubstitution,
87-
EliminateUnions),
87+
EliminateUnions,
88+
new UnresolvedOrdinalSubstitution(conf)),
8889
Batch("Resolution", fixedPoint,
8990
ResolveRelations ::
9091
ResolveReferences ::
@@ -545,7 +546,7 @@ class Analyzer(
545546
p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
546547
// If the aggregate function argument contains Stars, expand it.
547548
case a: Aggregate if containsStar(a.aggregateExpressions) =>
548-
if (conf.groupByOrdinal && a.groupingExpressions.exists(IntegerIndex.unapply(_).nonEmpty)) {
549+
if (a.groupingExpressions.exists(_.isInstanceOf[UnresolvedOrdinal])) {
549550
failAnalysis(
550551
"Star (*) is not allowed in select list when GROUP BY ordinal position is used")
551552
} else {
@@ -716,9 +717,9 @@ class Analyzer(
716717
// Replace the index with the related attribute for ORDER BY,
717718
// which is a 1-base position of the projection list.
718719
case s @ Sort(orders, global, child)
719-
if conf.orderByOrdinal && orders.exists(o => IntegerIndex.unapply(o.child).nonEmpty) =>
720+
if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) =>
720721
val newOrders = orders map {
721-
case s @ SortOrder(IntegerIndex(index), direction) =>
722+
case s @ SortOrder(UnresolvedOrdinal(index), direction) =>
722723
if (index > 0 && index <= child.output.size) {
723724
SortOrder(child.output(index - 1), direction)
724725
} else {
@@ -732,19 +733,18 @@ class Analyzer(
732733

733734
// Replace the index with the corresponding expression in aggregateExpressions. The index is
734735
// a 1-base position of aggregateExpressions, which is output columns (select expression)
735-
case a @ Aggregate(groups, aggs, child)
736-
if conf.groupByOrdinal && aggs.forall(_.resolved) &&
737-
groups.exists(IntegerIndex.unapply(_).nonEmpty) =>
736+
case a @ Aggregate(groups, aggs, child) if aggs.forall(_.resolved) &&
737+
groups.exists(_.isInstanceOf[UnresolvedOrdinal]) =>
738738
val newGroups = groups.map {
739-
case ordinal @ IntegerIndex(index) if index > 0 && index <= aggs.size =>
739+
case ordinal @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size =>
740740
aggs(index - 1) match {
741741
case e if ResolveAggregateFunctions.containsAggregate(e) =>
742742
ordinal.failAnalysis(
743743
s"GROUP BY position $index is an aggregate function, and " +
744744
"aggregate functions are not allowed in GROUP BY")
745745
case o => o
746746
}
747-
case ordinal @ IntegerIndex(index) =>
747+
case ordinal @ UnresolvedOrdinal(index) =>
748748
ordinal.failAnalysis(
749749
s"GROUP BY position $index is not in select list " +
750750
s"(valid range is [1, ${aggs.size}])")
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.analysis
19+
20+
import org.apache.spark.sql.catalyst.CatalystConf
21+
import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder}
22+
import org.apache.spark.sql.catalyst.planning.IntegerIndex
23+
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
24+
import org.apache.spark.sql.catalyst.rules.Rule
25+
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
26+
27+
/**
28+
* Replaces ordinal in 'order by' or 'group by' with UnresolvedOrdinal expression.
29+
*/
30+
class UnresolvedOrdinalSubstitution(conf: CatalystConf) extends Rule[LogicalPlan] {
31+
private def isIntegerLiteral(sorter: Expression) = IntegerIndex.unapply(sorter).nonEmpty
32+
33+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
34+
case s @ Sort(orders, global, child) if conf.orderByOrdinal &&
35+
orders.exists(o => isIntegerLiteral(o.child)) =>
36+
val newOrders = orders.map {
37+
case order @ SortOrder(ordinal @ IntegerIndex(index: Int), _) =>
38+
val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
39+
withOrigin(order.origin)(order.copy(child = newOrdinal))
40+
case other => other
41+
}
42+
withOrigin(s.origin)(s.copy(order = newOrders))
43+
case a @ Aggregate(groups, aggs, child) if conf.groupByOrdinal &&
44+
groups.exists(isIntegerLiteral(_)) =>
45+
val newGroups = groups.map {
46+
case ordinal @ IntegerIndex(index) =>
47+
withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
48+
case other => other
49+
}
50+
withOrigin(a.origin)(a.copy(groupingExpressions = newGroups))
51+
}
52+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,3 +370,21 @@ case class GetColumnByOrdinal(ordinal: Int, dataType: DataType) extends LeafExpr
370370
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
371371
override lazy val resolved = false
372372
}
373+
374+
/**
375+
* Represents unresolved ordinal used in order by or group by.
376+
*
377+
* For example:
378+
* {{{
379+
* select a from table order by 1
380+
* select a from table group by 1
381+
* }}}
382+
* @param ordinal ordinal starts from 1, instead of 0
383+
*/
384+
case class UnresolvedOrdinal(ordinal: Int)
385+
extends LeafExpression with Unevaluable with NonSQLExpression {
386+
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
387+
override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
388+
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
389+
override lazy val resolved = false
390+
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.analysis
1919

20-
import org.apache.spark.sql.catalyst.TableIdentifier
20+
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier}
2121
import org.apache.spark.sql.catalyst.dsl.expressions._
2222
import org.apache.spark.sql.catalyst.dsl.plans._
2323
import org.apache.spark.sql.catalyst.expressions._
@@ -377,4 +377,43 @@ class AnalysisSuite extends AnalysisTest {
377377
assertExpressionType(sum(Divide(Decimal(1), 2.0)), DoubleType)
378378
assertExpressionType(sum(Divide(1.0, Decimal(2.0))), DoubleType)
379379
}
380+
381+
test("test rule UnresolvedOrdinalSubstitution, replaces ordinal in order by or group by") {
382+
val a = testRelation2.output(0)
383+
val b = testRelation2.output(1)
384+
val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true)
385+
386+
// Expression OrderByOrdinal is unresolved.
387+
assert(!UnresolvedOrdinal(0).resolved)
388+
389+
// Tests order by ordinal, apply single rule.
390+
val plan = testRelation2.orderBy(Literal(1).asc, Literal(2).asc)
391+
comparePlans(
392+
new UnresolvedOrdinalSubstitution(conf).apply(plan),
393+
testRelation2.orderBy(UnresolvedOrdinal(1).asc, UnresolvedOrdinal(2).asc))
394+
395+
// Tests order by ordinal, do full analysis
396+
checkAnalysis(plan, testRelation2.orderBy(a.asc, b.asc))
397+
398+
// order by ordinal can be turned off by config
399+
comparePlans(
400+
new UnresolvedOrdinalSubstitution(conf.copy(orderByOrdinal = false)).apply(plan),
401+
testRelation2.orderBy(Literal(1).asc, Literal(2).asc))
402+
403+
404+
// Tests group by ordinal, apply single rule.
405+
val plan2 = testRelation2.groupBy(Literal(1), Literal(2))('a, 'b)
406+
comparePlans(
407+
new UnresolvedOrdinalSubstitution(conf).apply(plan2),
408+
testRelation2.groupBy(UnresolvedOrdinal(1), UnresolvedOrdinal(2))('a, 'b))
409+
410+
// Tests group by ordinal, do full analysis
411+
checkAnalysis(plan2, testRelation2.groupBy(a, b)(a, b))
412+
413+
// group by ordinal can be turned off by config
414+
comparePlans(
415+
new UnresolvedOrdinalSubstitution(conf.copy(groupByOrdinal = false)).apply(plan2),
416+
testRelation2.groupBy(Literal(1), Literal(2))('a, 'b))
417+
418+
}
380419
}

sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ select a, rand(0), sum(b) from data group by a, 2;
4343
-- negative case: star
4444
select * from data group by a, b, 1;
4545

46+
-- group by ordinal followed by order by
47+
select a, count(a) from (select 1 as a) tmp group by 1 order by 1;
48+
49+
-- group by ordinal followed by having
50+
select count(a), a from (select 1 as a) tmp group by 2 having a > 0;
51+
4652
-- turn of group by ordinal
4753
set spark.sql.groupByOrdinal=false;
4854

sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 17
2+
-- Number of queries: 19
33

44

55
-- !query 0
@@ -153,16 +153,32 @@ Star (*) is not allowed in select list when GROUP BY ordinal position is used;
153153

154154

155155
-- !query 15
156-
set spark.sql.groupByOrdinal=false
156+
select a, count(a) from (select 1 as a) tmp group by 1 order by 1
157157
-- !query 15 schema
158-
struct<key:string,value:string>
158+
struct<a:int,count(a):bigint>
159159
-- !query 15 output
160-
spark.sql.groupByOrdinal
160+
1 1
161161

162162

163163
-- !query 16
164-
select sum(b) from data group by -1
164+
select count(a), a from (select 1 as a) tmp group by 2 having a > 0
165165
-- !query 16 schema
166-
struct<sum(b):bigint>
166+
struct<count(a):bigint,a:int>
167167
-- !query 16 output
168+
1 1
169+
170+
171+
-- !query 17
172+
set spark.sql.groupByOrdinal=false
173+
-- !query 17 schema
174+
struct<key:string,value:string>
175+
-- !query 17 output
176+
spark.sql.groupByOrdinal
177+
178+
179+
-- !query 18
180+
select sum(b) from data group by -1
181+
-- !query 18 schema
182+
struct<sum(b):bigint>
183+
-- !query 18 output
168184
9

0 commit comments

Comments
 (0)