25 commits
4a6f903
Reuse completeNextStageWithFetchFailure
beliefer Jun 19, 2020
96456e2
Merge remote-tracking branch 'upstream/master'
beliefer Jul 1, 2020
4314005
Merge remote-tracking branch 'upstream/master'
beliefer Jul 3, 2020
d6af4a7
Merge remote-tracking branch 'upstream/master'
beliefer Jul 9, 2020
f69094f
Merge remote-tracking branch 'upstream/master'
beliefer Jul 16, 2020
b86a42d
Merge remote-tracking branch 'upstream/master'
beliefer Jul 25, 2020
2ac5159
Merge branch 'master' of github.com:beliefer/spark
beliefer Jul 25, 2020
9021d6c
Merge remote-tracking branch 'upstream/master'
beliefer Jul 28, 2020
74a2ef4
Merge branch 'master' of github.com:beliefer/spark
beliefer Jul 28, 2020
199aa6f
Support single distinct group with filter.
beliefer Jul 28, 2020
a73f11e
Support distinct agg with filter
beliefer Jul 29, 2020
72e95f1
Supplement doc and comment.
beliefer Jul 29, 2020
8e82e83
Add test case and regenerate golden files.
beliefer Jul 29, 2020
4ba808b
Add test case and regenerate golden files.
beliefer Jul 29, 2020
145a9dd
Optimize code
beliefer Jul 30, 2020
0fcf643
Update doc
beliefer Jul 30, 2020
92a37a9
Optimize code.
beliefer Jul 30, 2020
7362dfb
Optimize code.
beliefer Jul 30, 2020
9828158
Merge remote-tracking branch 'upstream/master'
beliefer Jul 31, 2020
fbb051b
Merge branch 'master' into support-distinct-with-filter
beliefer Jul 31, 2020
9939ea7
Add tests case like distinct 1
beliefer Jul 31, 2020
2dc6f32
Optimize code
beliefer Jul 31, 2020
abafc20
Optimize code
beliefer Jul 31, 2020
39583dd
Optimize code
beliefer Aug 3, 2020
883973b
Optimize code
beliefer Aug 3, 2020
@@ -1973,15 +1973,9 @@ class Analyzer(
}
// We get an aggregate function, we need to wrap it in an AggregateExpression.
case agg: AggregateFunction =>
// TODO: SPARK-30276 Support Filter expression allows simultaneous use of DISTINCT
if (filter.isDefined) {
if (isDistinct) {
failAnalysis("DISTINCT and FILTER cannot be used in aggregate functions " +
"at the same time")
} else if (!filter.get.deterministic) {
failAnalysis("FILTER expression is non-deterministic, " +
"it cannot be used in aggregate functions")
}
if (filter.isDefined && !filter.get.deterministic) {
failAnalysis("FILTER expression is non-deterministic, " +
"it cannot be used in aggregate functions")
}
AggregateExpression(agg, Complete, isDistinct, filter)
// This function is not an aggregate function, just return the resolved one.
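With the DISTINCT restriction removed above, a query that combines DISTINCT and FILTER in the same aggregate is no longer rejected during analysis; only a non-deterministic FILTER expression still fails. A minimal usage sketch follows — the local SparkSession, toy data, and the data view are assumptions for illustration, not part of this diff.

// Minimal sketch; none of the names below come from the PR itself.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("distinct-filter-demo").getOrCreate()
import spark.implicits._

Seq((1, "a", 0), (1, "b", 2), (2, "a", 3))
  .toDF("key", "cat1", "id")
  .createOrReplaceTempView("data")

// Before this change, the analyzer rejected this query with
// "DISTINCT and FILTER cannot be used in aggregate functions at the same time".
spark.sql(
  """SELECT key,
    |       COUNT(DISTINCT cat1) FILTER (WHERE id > 1) AS cat1_cnt
    |FROM data
    |GROUP BY key""".stripMargin)
  .show()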
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete}
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.IntegerType
@@ -81,10 +81,10 @@ import org.apache.spark.sql.types.IntegerType
* COUNT(DISTINCT cat1) as cat1_cnt,
* COUNT(DISTINCT cat2) as cat2_cnt,
* SUM(value) FILTER (WHERE id > 1) AS total
* FROM
* data
* GROUP BY
* key
* FROM
* data
* GROUP BY
* key
* }}}
*
* This translates to the following (pseudo) logical plan:
@@ -93,7 +93,7 @@ import org.apache.spark.sql.types.IntegerType
* key = ['key]
* functions = [COUNT(DISTINCT 'cat1),
* COUNT(DISTINCT 'cat2),
* sum('value) with FILTER('id > 1)]
* sum('value) FILTER (WHERE 'id > 1)]
* output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
* LocalTableScan [...]
* }}}
@@ -108,7 +108,7 @@ import org.apache.spark.sql.types.IntegerType
* output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
* Aggregate(
* key = ['key, 'cat1, 'cat2, 'gid]
* functions = [sum('value) with FILTER('id > 1)]
* functions = [sum('value) FILTER (WHERE 'id > 1)]
* output = ['key, 'cat1, 'cat2, 'gid, 'total])
* Expand(
* projections = [('key, null, null, 0, cast('value as bigint), 'id),
@@ -118,6 +118,49 @@ import org.apache.spark.sql.types.IntegerType
* LocalTableScan [...]
* }}}
*
* Third example: aggregate function with distinct and filter clauses (in sql):
* {{{
* SELECT
* COUNT(DISTINCT cat1) FILTER (WHERE id > 1) as cat1_cnt,
* COUNT(DISTINCT cat2) FILTER (WHERE id > 2) as cat2_cnt,
* SUM(value) FILTER (WHERE id > 3) AS total
* FROM
* data
* GROUP BY
* key
* }}}
*
* This translates to the following (pseudo) logical plan:
* {{{
* Aggregate(
* key = ['key]
* functions = [COUNT(DISTINCT 'cat1) FILTER (WHERE 'id > 1),
* COUNT(DISTINCT 'cat2) FILTER (WHERE 'id > 2),
* sum('value) FILTER (WHERE 'id > 3)]
* output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
* LocalTableScan [...]
* }}}
*
* This rule rewrites this logical plan to the following (pseudo) logical plan:
* {{{
* Aggregate(
* key = ['key]
* functions = [count(if (('gid = 1) and 'max_cond1) 'cat1 else null),
* count(if (('gid = 2) and 'max_cond2) 'cat2 else null),
* first(if (('gid = 0)) 'total else null) ignore nulls]
* output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
* Aggregate(
* key = ['key, 'cat1, 'cat2, 'gid]
* functions = [max('cond1), max('cond2), sum('value) FILTER (WHERE 'id > 3)]
* output = ['key, 'cat1, 'cat2, 'gid, 'max_cond1, 'max_cond2, 'total])
* Expand(
* projections = [('key, null, null, 0, null, null, cast('value as bigint), 'id),
* ('key, 'cat1, null, 1, 'id > 1, null, null, null),
* ('key, null, 'cat2, 2, null, 'id > 2, null, null)]
* output = ['key, 'cat1, 'cat2, 'gid, 'cond1, 'cond2, 'value, 'id])
* LocalTableScan [...]
* }}}
*
* The rule does the following things here:
* 1. Expand the data. There are three aggregation groups in this query:
* i. the non-distinct group;
@@ -126,15 +169,24 @@ import org.apache.spark.sql.types.IntegerType
* An expand operator is inserted to expand the child data for each group. The expand will null
* out all unused columns for the given group; this must be done in order to ensure correctness
* later on. Groups can be identified by a group id (gid) column added by the expand operator.
* If a distinct group has a filter clause, the expand operator will also evaluate the filter and
* output its result (e.g. cond1), which is later used to compute a global condition (e.g.
* max_cond1) equivalent to the filter clause.
* 2. De-duplicate the distinct paths and aggregate the non-aggregate path. The group by clause of
* this aggregate consists of the original group by clause, all the requested distinct columns
* and the group id. Both de-duplication of distinct column and the aggregation of the
* non-distinct group take advantage of the fact that we group by the group id (gid) and that we
* have nulled out all non-relevant columns the given group.
* have nulled out all non-relevant columns for the given group. If a distinct group has a filter
* clause, we use max to aggregate the filter results (e.g. cond1) produced in the previous step.
* These aggregates output the global conditions (e.g. max_cond1) equivalent to the filter
* clauses.
* 3. Aggregating the distinct groups and combining this with the results of the non-distinct
* aggregation. In this step we use the group id to filter the inputs for the aggregate
* functions. The result of the non-distinct group are 'aggregated' by using the first operator,
* it might be more elegant to use the native UDAF merge mechanism for this in the future.
* aggregation. In this step we use the group id and the global conditions to filter the inputs
* for the aggregate functions. If the global condition (e.g. max_cond1) is true, at least one
* row with that distinct value satisfies the filter, so the value should be included in the
* aggregate function. The results of the non-distinct group are 'aggregated' using the first
* operator; it might be more elegant to use the native UDAF merge mechanism for this in the
* future.
*
* This rule duplicates the input data by two or more times (# distinct groups + an optional
* non-distinct group). This will put quite a bit of memory pressure on the used aggregate and
@@ -144,28 +196,24 @@ import org.apache.spark.sql.types.IntegerType
*/
object RewriteDistinctAggregates extends Rule[LogicalPlan] {

private def mayNeedtoRewrite(exprs: Seq[Expression]): Boolean = {
val distinctAggs = exprs.flatMap { _.collect {
case ae: AggregateExpression if ae.isDistinct => ae
}}
// We need at least two distinct aggregates for this rule because aggregation
// strategy can handle a single distinct group.
private def mayNeedtoRewrite(a: Aggregate): Boolean = {
val aggExpressions = collectAggregateExprs(a)
val distinctAggs = aggExpressions.filter(_.isDistinct)
// We need at least two distinct aggregates, or a single distinct aggregate group with a filter
// clause, for this rule to apply, because the aggregation strategy can already handle a single
// distinct aggregate group without a filter clause.
// This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a).
distinctAggs.size > 1
distinctAggs.size > 1 || distinctAggs.exists(_.filter.isDefined)
}

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) => rewrite(a)
case a: Aggregate if mayNeedtoRewrite(a) => rewrite(a)
}

def rewrite(a: Aggregate): Aggregate = {

// Collect all aggregate expressions.
val aggExpressions = a.aggregateExpressions.flatMap { e =>
e.collect {
case ae: AggregateExpression => ae
}
}
val aggExpressions = collectAggregateExprs(a)
val distinctAggs = aggExpressions.filter(_.isDistinct)

// Extract distinct aggregate expressions.
val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e =>
@@ -184,8 +232,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
}
}

// Aggregation strategy can handle queries with a single distinct group.
if (distinctAggGroups.size > 1) {
// Aggregation strategy can handle queries with a single distinct group without a filter clause.
if (distinctAggGroups.size > 1 || distinctAggs.exists(_.filter.isDefined)) {
// Create the attributes for the grouping id and the group by clause.
val gid = AttributeReference("gid", IntegerType, nullable = false)()
val groupByMap = a.groupingExpressions.collect {
@@ -195,7 +243,13 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
val groupByAttrs = groupByMap.map(_._2)

// Functions used to modify aggregate functions and their inputs.
def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
def evalWithinGroup(id: Literal, e: Expression, condition: Option[Expression]) =
if (condition.isDefined) {
If(And(EqualTo(gid, id), condition.get), e, nullify(e))
} else {
If(EqualTo(gid, id), e, nullify(e))
}

def patchAggregateFunctionChildren(
af: AggregateFunction)(
attrs: Expression => Option[Expression]): AggregateFunction = {
@@ -207,13 +261,28 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
// Set up all the filters in the distinct aggregates.
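// Each filter is aggregated with max over its boolean result: max over booleans is true iff at
// least one row in the group satisfied the filter, yielding the global condition (e.g.
// max_cond1) that the second aggregate consumes.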
val (distinctAggFilters, distinctAggFilterAttrs, maxConds) = distinctAggs.collect {
case AggregateExpression(_, _, _, filter, _) if filter.isDefined =>
val (e, attr) = expressionAttributePair(filter.get)
val aggregateExp = Max(attr).toAggregateExpression()
(e, attr, Alias(aggregateExp, attr.name)())
}.unzip3

// Setup expand & aggregate operators for distinct aggregate expressions.
val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap
val distinctAggFilterAttrLookup = distinctAggFilters.zip(maxConds.map(_.toAttribute)).toMap
val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
case ((group, expressions), i) =>
val id = Literal(i + 1)

// Expand projection for filter
val filters = expressions.filter(_.filter.isDefined).map(_.filter.get)
val filterProjection = distinctAggFilters.map {
case e if filters.contains(e) => e
case e => nullify(e)
}

// Expand projection
val projection = distinctAggChildren.map {
case e if group.contains(e) => e
@@ -224,12 +293,17 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
val operators = expressions.map { e =>
val af = e.aggregateFunction
val naf = patchAggregateFunctionChildren(af) { x =>
distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _))
val condition = if (e.filter.isDefined) {
e.filter.map(distinctAggFilterAttrLookup.get(_)).get
} else {
None
}
distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _, condition))
}
(e, e.copy(aggregateFunction = naf, isDistinct = false))
(e, e.copy(aggregateFunction = naf, isDistinct = false, filter = None))
}

(projection, operators)
(projection ++ filterProjection, operators)
}

// Setup expand for the 'regular' aggregate expressions.
@@ -257,7 +331,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {

// Select the result of the first aggregate in the last aggregate.
val result = AggregateExpression(
aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), true),
aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute, None), true),
mode = Complete,
isDistinct = false)

@@ -280,6 +354,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
Seq(a.groupingExpressions ++
distinctAggChildren.map(nullify) ++
Seq(regularGroupId) ++
distinctAggFilters.map(nullify) ++
regularAggChildren)
} else {
Seq.empty[Seq[Expression]]
@@ -297,15 +372,16 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
// Construct the expand operator.
val expand = Expand(
regularAggProjection ++ distinctAggProjections,
groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2),
groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ distinctAggFilterAttrs ++
regularAggChildAttrMap.map(_._2),
a.child)

// Construct the first aggregate operator. This de-duplicates all the children of
// distinct operators, and applies the regular aggregate operators.
val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid
val firstAggregate = Aggregate(
firstAggregateGroupBy,
firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2),
firstAggregateGroupBy ++ maxConds ++ regularAggOperatorMap.map(_._2),
expand)

// Construct the second aggregate
@@ -331,6 +407,13 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
}
}

private def collectAggregateExprs(a: Aggregate): Seq[AggregateExpression] = {
// Collect all aggregate expressions.
a.aggregateExpressions.flatMap { _.collect {
case ae: AggregateExpression => ae
}}
}

private def nullify(e: Expression) = Literal.create(null, e.dataType)

private def expressionAttributePair(e: Expression) =
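To observe the rewrite described in the scaladoc above, one can inspect the optimized plan of a query that mixes distinct and filtered aggregates. A rough sketch, reusing the hypothetical local session and implicits from the earlier sketch; the data mirrors the scaladoc's third example and is not part of this diff.

// Rough sketch only; `spark` and the `data` view are illustrative assumptions.
Seq((1, "a", "x", 1L, 0), (1, "b", "y", 2L, 2), (2, "a", "x", 3L, 4))
  .toDF("key", "cat1", "cat2", "value", "id")
  .createOrReplaceTempView("data")

// The optimized logical plan printed by explain(true) should contain the Expand and the two
// Aggregate operators described above (a gid column, max(cond) conditions, conditional counts).
spark.sql(
  """SELECT key,
    |       COUNT(DISTINCT cat1) FILTER (WHERE id > 1) AS cat1_cnt,
    |       COUNT(DISTINCT cat2) FILTER (WHERE id > 2) AS cat2_cnt,
    |       SUM(value) FILTER (WHERE id > 3) AS total
    |FROM data
    |GROUP BY key""".stripMargin)
  .explain(true)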
@@ -207,11 +207,6 @@ class AnalysisErrorSuite extends AnalysisTest {
"FILTER (WHERE c > 1)"),
"FILTER predicate specified, but aggregate is not an aggregate function" :: Nil)

errorTest(
"DISTINCT aggregate function with filter predicate",
CatalystSqlParser.parsePlan("SELECT count(DISTINCT a) FILTER (WHERE c > 1) FROM TaBlE2"),
"DISTINCT and FILTER cannot be used in aggregate functions at the same time" :: Nil)

errorTest(
"non-deterministic filter predicate in aggregate functions",
CatalystSqlParser.parsePlan("SELECT count(a) FILTER (WHERE rand(int(c)) > 1) FROM TaBlE2"),
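The removed errorTest above asserted that DISTINCT and FILTER could not be combined; with the restriction lifted, a positive counterpart would be the natural replacement. A hedged sketch of such a check (not part of this diff), using the existing assertAnalysisSuccess helper from AnalysisTest and the same parser and test relation as the removed test:

// Sketch only: asserts that the previously rejected query now passes analysis.
assertAnalysisSuccess(
  CatalystSqlParser.parsePlan("SELECT count(DISTINCT a) FILTER (WHERE c > 1) FROM TaBlE2"))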