-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-27986][SQL] Support ANSI SQL filter clause for aggregate expression #26656
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from all commits
17b76e2
3f0583f
f64d14c
d521be1
8e342da
0e56d03
fd6461f
f32ac4d
5d33dab
9ea4736
4dcd0d3
060d3d4
4443883
8beff8a
4d0c3aa
14f2b21
895f6ac
675dca9
1c1cf52
fb8f477
b677268
b831855
4c644ca
255650a
518aa4f
d979509
6082e57
ed80517
3d37370
392c18d
9a127e4
967b135
07f774a
c86b691
81c9482
747b3ab
f66c180
3652aef
8bfff6f
0911a76
61bf6fd
ea472aa
583d51f
030a9dc
df643ba
14daee6
ce51461
ce53930
cb31eea
f154622
4d1413f
0d20561
bc2ad92
1297e03
f56400a
eb856df
cffe318
4523616
1cb0725
33d2b5b
c3e0f6a
6c878d3
40e31be
affb6c0
de11c4d
7c40292
258a6c6
4a494ae
8cdd92d
46c4980
d3f38f2
d40dd9f
2518692
94a4a06
9adfd2d
0a4a5a2
66ceeca
3a350cb
01f306e
87697ec
4dd7527
53a6f2a
b29ef0f
c15389d
0cabcb6
b58c126
a3bb997
e520938
a2a79d5
ff1147f
b3584c8
e03a959
998929c
a3d71f4
91fe90e
eb96463
97bb440
bb14439
c1dbb27
000ae72
27ad46e
436b1c0
d5aa8ee
5b2a1b9
6f9e839
d98ea41
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -33,11 +33,14 @@ import org.apache.spark.sql.types.DataType | |
| case class ResolveHigherOrderFunctions(catalog: SessionCatalog) extends Rule[LogicalPlan] { | ||
|
|
||
| override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { | ||
| case u @ UnresolvedFunction(fn, children, false) | ||
| case u @ UnresolvedFunction(fn, children, false, filter) | ||
| if hasLambdaAndResolvedArguments(children) => | ||
| withPosition(u) { | ||
| catalog.lookupFunction(fn, children) match { | ||
| case func: HigherOrderFunction => func | ||
| case func: HigherOrderFunction => | ||
| filter.foreach(_.failAnalysis("FILTER predicate specified, " + | ||
| s"but ${func.prettyName} is not an aggregate function")) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you add tests for this path?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. OK |
||
| func | ||
| case other => other.failAnalysis( | ||
| "A lambda function should only be used in a higher order function. However, " + | ||
| s"its class is ${other.getClass.getCanonicalName}, which is not a " + | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,7 +28,7 @@ import org.apache.spark.sql.types.IntegerType | |
| * aggregation in which the regular aggregation expressions and every distinct clause is aggregated | ||
| * in a separate group. The results are then combined in a second aggregate. | ||
| * | ||
| * For example (in scala): | ||
| * First example: query without filter clauses (in scala): | ||
| * {{{ | ||
| * val data = Seq( | ||
| * ("a", "ca1", "cb1", 10), | ||
|
|
@@ -75,6 +75,49 @@ import org.apache.spark.sql.types.IntegerType | |
| * LocalTableScan [...] | ||
| * }}} | ||
| * | ||
| * Second example: aggregate function without distinct and with filter clauses (in sql): | ||
| * {{{ | ||
| * SELECT | ||
| * COUNT(DISTINCT cat1) as cat1_cnt, | ||
| * COUNT(DISTINCT cat2) as cat2_cnt, | ||
| * SUM(value) FILTER (WHERE id > 1) AS total | ||
| * FROM | ||
| * data | ||
| * GROUP BY | ||
| * key | ||
| * }}} | ||
| * | ||
| * This translates to the following (pseudo) logical plan: | ||
| * {{{ | ||
| * Aggregate( | ||
| * key = ['key] | ||
| * functions = [COUNT(DISTINCT 'cat1), | ||
| * COUNT(DISTINCT 'cat2), | ||
| * sum('value) with FILTER('id > 1)] | ||
| * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) | ||
| * LocalTableScan [...] | ||
| * }}} | ||
| * | ||
| * This rule rewrites this logical plan to the following (pseudo) logical plan: | ||
| * {{{ | ||
| * Aggregate( | ||
| * key = ['key] | ||
| * functions = [count(if (('gid = 1)) 'cat1 else null), | ||
| * count(if (('gid = 2)) 'cat2 else null), | ||
| * first(if (('gid = 0)) 'total else null) ignore nulls] | ||
| * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) | ||
| * Aggregate( | ||
| * key = ['key, 'cat1, 'cat2, 'gid] | ||
| * functions = [sum('value) with FILTER('id > 1)] | ||
| * output = ['key, 'cat1, 'cat2, 'gid, 'total]) | ||
| * Expand( | ||
| * projections = [('key, null, null, 0, cast('value as bigint), 'id), | ||
cloud-fan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| * ('key, 'cat1, null, 1, null, null), | ||
| * ('key, null, 'cat2, 2, null, null)] | ||
| * output = ['key, 'cat1, 'cat2, 'gid, 'value, 'id]) | ||
| * LocalTableScan [...] | ||
| * }}} | ||
| * | ||
| * The rule does the following things here: | ||
| * 1. Expand the data. There are three aggregation groups in this query: | ||
| * i. the non-distinct group; | ||
|
|
@@ -183,9 +226,10 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { | |
| // only expand unfoldable children | ||
| val regularAggExprs = aggExpressions | ||
| .filter(e => !e.isDistinct && e.children.exists(!_.foldable)) | ||
| val regularAggChildren = regularAggExprs | ||
| val regularAggFunChildren = regularAggExprs | ||
| .flatMap(_.aggregateFunction.children.filter(!_.foldable)) | ||
| .distinct | ||
| val regularAggFilterAttrs = regularAggExprs.flatMap(_.filterAttributes) | ||
| val regularAggChildren = (regularAggFunChildren ++ regularAggFilterAttrs).distinct | ||
| val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair) | ||
|
|
||
| // Setup aggregates for 'regular' aggregate expressions. | ||
|
|
@@ -194,7 +238,12 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { | |
| val regularAggOperatorMap = regularAggExprs.map { e => | ||
| // Perform the actual aggregation in the initial aggregate. | ||
| val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup.get) | ||
| val operator = Alias(e.copy(aggregateFunction = af), e.sql)() | ||
| // We changed the attributes in the [[Expand]] output using expressionAttributePair. | ||
| // So we need to replace the attributes in FILTER expression with new ones. | ||
| val filterOpt = e.filter.map(_.transform { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you leave some comments for #26656 (comment) here?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. OK. The comments are indeed needed there. |
||
| case a: Attribute => regularAggChildAttrLookup.getOrElse(a, a) | ||
| }) | ||
| val operator = Alias(e.copy(aggregateFunction = af, filter = filterOpt), e.sql)() | ||
|
|
||
| // Select the result of the first aggregate in the last aggregate. | ||
| val result = AggregateExpression( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you update
TableIdentifierParserSuite, too? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a function named FILTER, so I removed it from
TableIdentifierParserSuite.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah, I see.