Skip to content

Commit 2b91d99

Browse files
dilipbiswalgatorsmile
authored andcommitted
[SPARK-24424][SQL] Support ANSI-SQL compliant syntax for GROUPING SET
## What changes were proposed in this pull request? Enhances the parser and analyzer to support ANSI compliant syntax for GROUPING SET. As part of this change we derive the grouping expressions from user supplied groupings in the grouping sets clause. ```SQL SELECT c1, c2, max(c3) FROM t1 GROUP BY GROUPING SETS ((c1), (c1, c2)) ``` ## How was this patch tested? Added tests in SQLQueryTestSuite and ResolveGroupingAnalyticsSuite. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Dilip Biswal <dbiswal@us.ibm.com> Closes #21813 from dilipbiswal/spark-24424.
1 parent a5925c1 commit 2b91d99

File tree

5 files changed

+210
-3
lines changed

5 files changed

+210
-3
lines changed

sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,7 @@ aggregation
406406
WITH kind=ROLLUP
407407
| WITH kind=CUBE
408408
| kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')')?
409+
| GROUP BY kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')'
409410
;
410411

411412
groupingSet

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -442,17 +442,35 @@ class Analyzer(
442442
child: LogicalPlan): LogicalPlan = {
443443
val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
444444

445+
// In case of ANSI-SQL compliant syntax for GROUPING SETS, groupByExprs is optional and
446+
// can be null. In such case, we derive the groupByExprs from the user supplied values for
447+
// grouping sets.
448+
val finalGroupByExpressions = if (groupByExprs == Nil) {
449+
selectedGroupByExprs.flatten.foldLeft(Seq.empty[Expression]) { (result, currentExpr) =>
450+
// Only unique expressions are included in the group by expressions and is determined
451+
// based on their semantic equality. Example. grouping sets ((a * b), (b * a)) results
452+
// in grouping expression (a * b)
453+
if (result.find(_.semanticEquals(currentExpr)).isDefined) {
454+
result
455+
} else {
456+
result :+ currentExpr
457+
}
458+
}
459+
} else {
460+
groupByExprs
461+
}
462+
445463
// Expand works by setting grouping expressions to null as determined by the
446464
// `selectedGroupByExprs`. To prevent these null values from being used in an aggregate
447465
// instead of the original value we need to create new aliases for all group by expressions
448466
// that will only be used for the intended purpose.
449-
val groupByAliases = constructGroupByAlias(groupByExprs)
467+
val groupByAliases = constructGroupByAlias(finalGroupByExpressions)
450468

451469
val expand = constructExpand(selectedGroupByExprs, child, groupByAliases, gid)
452470
val groupingAttrs = expand.output.drop(child.output.length)
453471

454472
val aggregations = constructAggregateExprs(
455-
groupByExprs, aggregationExprs, groupByAliases, groupingAttrs, gid)
473+
finalGroupByExpressions, aggregationExprs, groupByAliases, groupingAttrs, gid)
456474

457475
Aggregate(groupingAttrs, aggregations, expand)
458476
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,34 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
9191
assertAnalysisError(originalPlan3, Seq("doesn't show up in the GROUP BY list"))
9292
}
9393

94+
test("grouping sets with no explicit group by expressions") {
95+
val originalPlan = GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)),
96+
Nil, r1,
97+
Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))))
98+
val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")),
99+
Expand(
100+
Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
101+
Seq(a, b, c, a, b, gid),
102+
Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
103+
checkAnalysis(originalPlan, expected)
104+
105+
// Computation of grouping expression should remove duplicate expression based on their
106+
// semantics (semanticEqual).
107+
val originalPlan2 = GroupingSets(Seq(Seq(Multiply(unresolved_a, Literal(2))),
108+
Seq(Multiply(Literal(2), unresolved_a), unresolved_b)), Nil, r1,
109+
Seq(UnresolvedAlias(Multiply(unresolved_a, Literal(2))),
110+
unresolved_b, UnresolvedAlias(count(unresolved_c))))
111+
112+
val resultPlan = getAnalyzer(true).executeAndCheck(originalPlan2)
113+
val gExpressions = resultPlan.asInstanceOf[Aggregate].groupingExpressions
114+
assert(gExpressions.size == 3)
115+
val firstGroupingExprAttrName =
116+
gExpressions(0).asInstanceOf[AttributeReference].name.replaceAll("#[0-9]*", "#0")
117+
assert(firstGroupingExprAttrName == "(a#0 * 2)")
118+
assert(gExpressions(1).asInstanceOf[AttributeReference].name == "b")
119+
assert(gExpressions(2).asInstanceOf[AttributeReference].name == VirtualColumn.groupingIdName)
120+
}
121+
94122
test("cube") {
95123
val originalPlan = Aggregate(Seq(Cube(Seq(unresolved_a, unresolved_b))),
96124
Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1)

sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,41 @@ SELECT a, b, c, count(d) FROM grouping GROUP BY a, b, c GROUPING SETS ((a));
1313
-- SPARK-17849: grouping set throws NPE #3
1414
SELECT a, b, c, count(d) FROM grouping GROUP BY a, b, c GROUPING SETS ((c));
1515

16+
-- Group sets without explicit group by
17+
SELECT c1, sum(c2) FROM (VALUES ('x', 10, 0), ('y', 20, 0)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1);
1618

19+
-- Group sets without group by and with grouping
20+
SELECT c1, sum(c2), grouping(c1) FROM (VALUES ('x', 10, 0), ('y', 20, 0)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1);
21+
22+
-- Mutiple grouping within a grouping set
23+
SELECT c1, c2, Sum(c3), grouping__id
24+
FROM (VALUES ('x', 'a', 10), ('y', 'b', 20) ) AS t (c1, c2, c3)
25+
GROUP BY GROUPING SETS ( ( c1 ), ( c2 ) )
26+
HAVING GROUPING__ID > 1;
27+
28+
-- Group sets without explicit group by
29+
SELECT grouping(c1) FROM (VALUES ('x', 'a', 10), ('y', 'b', 20)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1,c2);
30+
31+
-- Mutiple grouping within a grouping set
32+
SELECT -c1 AS c1 FROM (values (1,2), (3,2)) t(c1, c2) GROUP BY GROUPING SETS ((c1), (c1, c2));
33+
34+
-- complex expression in grouping sets
35+
SELECT a + b, b, sum(c) FROM (VALUES (1,1,1),(2,2,2)) AS t(a,b,c) GROUP BY GROUPING SETS ( (a + b), (b));
36+
37+
-- complex expression in grouping sets
38+
SELECT a + b, b, sum(c) FROM (VALUES (1,1,1),(2,2,2)) AS t(a,b,c) GROUP BY GROUPING SETS ( (a + b), (b + a), (b));
39+
40+
-- more query constructs with grouping sets
41+
SELECT c1 AS col1, c2 AS col2
42+
FROM (VALUES (1, 2), (3, 2)) t(c1, c2)
43+
GROUP BY GROUPING SETS ( ( c1 ), ( c1, c2 ) )
44+
HAVING col2 IS NOT NULL
45+
ORDER BY -col1;
46+
47+
-- negative tests - must have at least one grouping expression
48+
SELECT a, b, c, count(d) FROM grouping GROUP BY WITH ROLLUP;
49+
50+
SELECT a, b, c, count(d) FROM grouping GROUP BY WITH CUBE;
51+
52+
SELECT c1 FROM (values (1,2), (3,2)) t(c1, c2) GROUP BY GROUPING SETS (());
1753

sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out

Lines changed: 125 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 4
2+
-- Number of queries: 15
33

44

55
-- !query 0
@@ -40,3 +40,127 @@ struct<a:string,b:string,c:string,count(d):bigint>
4040
NULL NULL 3 1
4141
NULL NULL 6 1
4242
NULL NULL 9 1
43+
44+
45+
-- !query 4
46+
SELECT c1, sum(c2) FROM (VALUES ('x', 10, 0), ('y', 20, 0)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1)
47+
-- !query 4 schema
48+
struct<c1:string,sum(c2):bigint>
49+
-- !query 4 output
50+
x 10
51+
y 20
52+
53+
54+
-- !query 5
55+
SELECT c1, sum(c2), grouping(c1) FROM (VALUES ('x', 10, 0), ('y', 20, 0)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1)
56+
-- !query 5 schema
57+
struct<c1:string,sum(c2):bigint,grouping(c1):tinyint>
58+
-- !query 5 output
59+
x 10 0
60+
y 20 0
61+
62+
63+
-- !query 6
64+
SELECT c1, c2, Sum(c3), grouping__id
65+
FROM (VALUES ('x', 'a', 10), ('y', 'b', 20) ) AS t (c1, c2, c3)
66+
GROUP BY GROUPING SETS ( ( c1 ), ( c2 ) )
67+
HAVING GROUPING__ID > 1
68+
-- !query 6 schema
69+
struct<c1:string,c2:string,sum(c3):bigint,grouping__id:int>
70+
-- !query 6 output
71+
NULL a 10 2
72+
NULL b 20 2
73+
74+
75+
-- !query 7
76+
SELECT grouping(c1) FROM (VALUES ('x', 'a', 10), ('y', 'b', 20)) AS t (c1, c2, c3) GROUP BY GROUPING SETS (c1,c2)
77+
-- !query 7 schema
78+
struct<grouping(c1):tinyint>
79+
-- !query 7 output
80+
0
81+
0
82+
1
83+
1
84+
85+
86+
-- !query 8
87+
SELECT -c1 AS c1 FROM (values (1,2), (3,2)) t(c1, c2) GROUP BY GROUPING SETS ((c1), (c1, c2))
88+
-- !query 8 schema
89+
struct<c1:int>
90+
-- !query 8 output
91+
-1
92+
-1
93+
-3
94+
-3
95+
96+
97+
-- !query 9
98+
SELECT a + b, b, sum(c) FROM (VALUES (1,1,1),(2,2,2)) AS t(a,b,c) GROUP BY GROUPING SETS ( (a + b), (b))
99+
-- !query 9 schema
100+
struct<(a + b):int,b:int,sum(c):bigint>
101+
-- !query 9 output
102+
2 NULL 1
103+
4 NULL 2
104+
NULL 1 1
105+
NULL 2 2
106+
107+
108+
-- !query 10
109+
SELECT a + b, b, sum(c) FROM (VALUES (1,1,1),(2,2,2)) AS t(a,b,c) GROUP BY GROUPING SETS ( (a + b), (b + a), (b))
110+
-- !query 10 schema
111+
struct<(a + b):int,b:int,sum(c):bigint>
112+
-- !query 10 output
113+
2 NULL 2
114+
4 NULL 4
115+
NULL 1 1
116+
NULL 2 2
117+
118+
119+
-- !query 11
120+
SELECT c1 AS col1, c2 AS col2
121+
FROM (VALUES (1, 2), (3, 2)) t(c1, c2)
122+
GROUP BY GROUPING SETS ( ( c1 ), ( c1, c2 ) )
123+
HAVING col2 IS NOT NULL
124+
ORDER BY -col1
125+
-- !query 11 schema
126+
struct<col1:int,col2:int>
127+
-- !query 11 output
128+
3 2
129+
1 2
130+
131+
132+
-- !query 12
133+
SELECT a, b, c, count(d) FROM grouping GROUP BY WITH ROLLUP
134+
-- !query 12 schema
135+
struct<>
136+
-- !query 12 output
137+
org.apache.spark.sql.catalyst.parser.ParseException
138+
139+
extraneous input 'ROLLUP' expecting <EOF>(line 1, pos 53)
140+
141+
== SQL ==
142+
SELECT a, b, c, count(d) FROM grouping GROUP BY WITH ROLLUP
143+
-----------------------------------------------------^^^
144+
145+
146+
-- !query 13
147+
SELECT a, b, c, count(d) FROM grouping GROUP BY WITH CUBE
148+
-- !query 13 schema
149+
struct<>
150+
-- !query 13 output
151+
org.apache.spark.sql.catalyst.parser.ParseException
152+
153+
extraneous input 'CUBE' expecting <EOF>(line 1, pos 53)
154+
155+
== SQL ==
156+
SELECT a, b, c, count(d) FROM grouping GROUP BY WITH CUBE
157+
-----------------------------------------------------^^^
158+
159+
160+
-- !query 14
161+
SELECT c1 FROM (values (1,2), (3,2)) t(c1, c2) GROUP BY GROUPING SETS (())
162+
-- !query 14 schema
163+
struct<>
164+
-- !query 14 output
165+
org.apache.spark.sql.AnalysisException
166+
expression '`c1`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;

0 commit comments

Comments
 (0)