Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1920,15 +1920,9 @@ class Analyzer(
}
// We get an aggregate function, we need to wrap it in an AggregateExpression.
case agg: AggregateFunction =>
// TODO: SPARK-30276 Support Filter expression allows simultaneous use of DISTINCT
if (filter.isDefined) {
if (isDistinct) {
failAnalysis("DISTINCT and FILTER cannot be used in aggregate functions " +
"at the same time")
} else if (!filter.get.deterministic) {
failAnalysis("FILTER expression is non-deterministic, " +
"it cannot be used in aggregate functions")
}
if (filter.isDefined && !filter.get.deterministic) {
failAnalysis("FILTER expression is non-deterministic, " +
"it cannot be used in aggregate functions")
}
AggregateExpression(agg, Complete, isDistinct, filter)
// This function is not an aggregate function, just return the resolved one.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ package object dsl {
/** Builds a non-distinct `Count` aggregate expression over `e`. */
def count(e: Expression): Expression = {
  val countFunction = Count(e)
  countFunction.toAggregateExpression()
}
/** Builds a `Count` aggregate expression over `e` with the distinct flag set. */
def countDistinct(e: Expression*): Expression = {
  val countFunction = Count(e)
  countFunction.toAggregateExpression(isDistinct = true)
}
/**
 * Builds a `Count` aggregate expression over `e` with the distinct flag set,
 * attaching the optional `filter` predicate to the resulting aggregate expression.
 */
def countDistinct(filter: Option[Expression], e: Expression*): Expression = {
  val countFunction = Count(e)
  countFunction.toAggregateExpression(isDistinct = true, filter = filter)
}
/**
 * Builds a non-distinct `HyperLogLogPlusPlus` aggregate expression over `e`.
 * `rsd` is forwarded unchanged to `HyperLogLogPlusPlus` (presumably the allowed relative
 * standard deviation -- confirm against that class).
 */
def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression = {
  val approxFunction = HyperLogLogPlusPlus(e, rsd)
  approxFunction.toAggregateExpression()
}
/** Builds a non-distinct `Average` aggregate expression over `e`. */
def avg(e: Expression): Expression = {
  val averageFunction = Average(e)
  averageFunction.toAggregateExpression()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,15 +216,21 @@ abstract class AggregateFunction extends Expression {
def toAggregateExpression(): AggregateExpression = toAggregateExpression(isDistinct = false)

/**
* Wraps this [[AggregateFunction]] in an [[AggregateExpression]] and sets `isDistinct`
* flag of the [[AggregateExpression]] to the given value because
* Wraps this [[AggregateFunction]] in an [[AggregateExpression]] with `isDistinct`
* flag and an optional `filter` of the [[AggregateExpression]] to the given value because
* [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode,
* and the flag indicating if this aggregation is distinct aggregation or not.
* An [[AggregateFunction]] should not be used without being wrapped in
* the flag indicating if this aggregation is distinct aggregation or not and the optional
* `filter`. An [[AggregateFunction]] should not be used without being wrapped in
* an [[AggregateExpression]].
*/
def toAggregateExpression(isDistinct: Boolean): AggregateExpression = {
AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct)
def toAggregateExpression(
    isDistinct: Boolean,
    filter: Option[Expression] = None): AggregateExpression = {
  // Wrap this function itself; the mode is always `Complete`, while the distinct flag
  // and the optional filter predicate are passed through exactly as given.
  AggregateExpression(this, Complete, isDistinct, filter)
}

def sql(isDistinct: Boolean): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.IntegerType

Expand All @@ -31,10 +31,10 @@ import org.apache.spark.sql.types.IntegerType
* First example: query without filter clauses (in scala):
* {{{
* val data = Seq(
* ("a", "ca1", "cb1", 10),
* ("a", "ca1", "cb2", 5),
* ("b", "ca1", "cb1", 13))
* .toDF("key", "cat1", "cat2", "value")
* (1, "a", "ca1", "cb1", 10),
* (2, "a", "ca1", "cb2", 5),
* (3, "b", "ca1", "cb1", 13))
* .toDF("id", "key", "cat1", "cat2", "value")
* data.createOrReplaceTempView("data")
*
* val agg = data.groupBy($"key")
Expand Down Expand Up @@ -118,7 +118,110 @@ import org.apache.spark.sql.types.IntegerType
* LocalTableScan [...]
* }}}
*
* The rule does the following things here:
* Third example: a single distinct aggregate function with a filter clause and no
* other distinct aggregate function (in sql):
* {{{
* SELECT
* COUNT(DISTINCT cat1) FILTER (WHERE id > 1) as cat1_cnt,
* SUM(value) AS total
* FROM
* data
* GROUP BY
* key
* }}}
*
* This translates to the following (pseudo) logical plan:
* {{{
* Aggregate(
* key = ['key]
* functions = [COUNT(DISTINCT 'cat1) with FILTER('id > 1),
* sum('value)]
* output = ['key, 'cat1_cnt, 'total])
* LocalTableScan [...]
* }}}
*
* This rule rewrites this logical plan to the following (pseudo) logical plan:
* {{{
* Aggregate(
* key = ['key]
* functions = [count('_gen_distinct_1),
* sum('value)]
* output = ['key, 'cat1_cnt, 'total])
* Project(
* projectionList = ['key, if ('id > 1) 'cat1 else null, cast('value as bigint)]
* output = ['key, '_gen_distinct_1, 'value])
* LocalTableScan [...]
* }}}
*
* Fourth example: two distinct aggregate functions, one of which has a filter clause (in sql):
* {{{
* SELECT
* COUNT(DISTINCT cat1) FILTER (WHERE id > 1) as cat1_cnt,
* COUNT(DISTINCT cat2) as cat2_cnt,
* SUM(value) AS total
* FROM
* data
* GROUP BY
* key
* }}}
*
* This translates to the following (pseudo) logical plan:
* {{{
* Aggregate(
* key = ['key]
* functions = [COUNT(DISTINCT 'cat1) with FILTER('id > 1),
* COUNT(DISTINCT 'cat2),
* sum('value)]
* output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
* LocalTableScan [...]
* }}}
*
* This rule rewrites this logical plan to the following (pseudo) logical plan:
* {{{
* Aggregate(
* key = ['key]
* functions = [count(if (('gid = 1)) '_gen_distinct_1 else null),
* count(if (('gid = 2)) '_gen_distinct_2 else null),
* first(if (('gid = 0)) 'total else null) ignore nulls]
* output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
* Aggregate(
* key = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid]
* functions = [sum('value)]
* output = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid, 'total])
* Expand(
* projections = [('key, null, null, 0, cast('value as bigint)),
* ('key, if ('id > 1) 'cat1 else null, null, 1, null),
* ('key, null, 'cat2, 2, null)]
* output = ['key, '_gen_distinct_1, '_gen_distinct_2, 'gid, 'value])
* LocalTableScan [...]
* }}}
*
* The rule consists of the two phases as follows:
*
* In the first phase, if the aggregate query contains distinct aggregations with
* filter clauses, project the output of the child of the aggregate query:
* 1. Project the data. There are three aggregation groups in this query:
* i. the non-distinct group;
* ii. the distinct 'cat1 group;
* iii. the distinct 'cat2 group with filter clause.
* Because there is at least one distinct group with a filter clause (e.g. the distinct 'cat2
* group with filter clause), the data is projected.
* 2. Avoid projections that may output the same attributes. There are three aggregation groups
* in this query:
* i. the non-distinct group;
* ii. the distinct 'cat1 group;
* iii. the distinct 'cat1 group with filter clause.
* The attributes referenced by different distinct aggregate expressions are likely to overlap,
* and if no additional processing is performed, data loss will occur. If we directly output
* the attributes of the aggregate expression, we may get two attributes 'cat1. To prevent
* this, we generate new attributes (e.g. '_gen_distinct_1) and replace the original ones.
*
* Why do we need the first phase? It guarantees that the filter clauses are computed locally
* in the first aggregate.
* Note: after generating new attributes, the aggregate may have at least two distinct
* aggregates, so we need the second phase too.
*
* In the second phase, rewrite a query with two or more distinct groups:
* 1. Expand the data. There are three aggregation groups in this query:
* i. the non-distinct group;
* ii. the distinct 'cat1 group;
Expand All @@ -135,6 +238,9 @@ import org.apache.spark.sql.types.IntegerType
* aggregation. In this step we use the group id to filter the inputs for the aggregate
* functions. The result of the non-distinct group are 'aggregated' by using the first operator,
* it might be more elegant to use the native UDAF merge mechanism for this in the future.
* 4. If the first phase inserted a project operator as the child of aggregate and the second phase
* already decided to insert an expand operator as the child of aggregate, the second phase will
* merge the project operator with expand operator.
*
* This rule duplicates the input data by two or more times (# distinct groups + an optional
* non-distinct group). This will put quite a bit of memory pressure of the used aggregate and
Expand All @@ -148,24 +254,107 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
val distinctAggs = exprs.flatMap { _.collect {
case ae: AggregateExpression if ae.isDistinct => ae
}}
// We need at least two distinct aggregates for this rule because aggregation
// strategy can handle a single distinct group.
// We need at least two distinct aggregates or a single distinct aggregate with a filter for
// this rule because aggregation strategy can handle a single distinct group without a filter.
// This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a).
distinctAggs.size > 1
distinctAggs.size > 1 || distinctAggs.exists(_.filter.isDefined)
}

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) => rewrite(a)
case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) =>
val (aggregate, projected) = projectFiltersInDistinctAggregates(a)
rewriteDistinctAggregates(aggregate, projected)
}

def rewrite(a: Aggregate): Aggregate = {
private def projectFiltersInDistinctAggregates(a: Aggregate): (Aggregate, Boolean) = {
val aggExpressions = collectAggregateExprs(a)
val (distinctAggExpressions, regularAggExpressions) = aggExpressions.partition(_.isDistinct)
if (distinctAggExpressions.exists(_.filter.isDefined)) {
// Constructs pairs between old and new expressions for regular aggregates. Because we
// will construct a new `Aggregate` and the children of the distinct aggregates will be
// changed to generated ones, we need to create new references to avoid collisions between
// distinct and regular aggregate children.
val regularAggExprs = regularAggExpressions.filter(_.children.exists(!_.foldable))
val regularFunChildren = regularAggExprs
.flatMap(_.aggregateFunction.children.filter(!_.foldable))
val regularFilterAttrs = regularAggExprs.flatMap(_.filterAttributes)
val regularAggChildren = (regularFunChildren ++ regularFilterAttrs).distinct
val regularAggChildrenMap = regularAggChildren.map {
case ne: NamedExpression => ne -> ne
case other => other -> Alias(other, other.toString)()
}
val namedRegularAggChildren = regularAggChildrenMap.map(_._2)
val regularAggChildAttrLookup = regularAggChildrenMap.map { kv =>
(kv._1, kv._2.toAttribute)
}.toMap
val regularAggPairs = regularAggExprs.map {
case ae @ AggregateExpression(af, _, _, filter, _) =>
val newChildren = af.children.map(c => regularAggChildAttrLookup.getOrElse(c, c))
val raf = af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
val filterOpt = filter.map(_.transform {
case a: Attribute => regularAggChildAttrLookup.getOrElse(a, a)
})
val aggExpr = ae.copy(aggregateFunction = raf, filter = filterOpt)
(ae, aggExpr)
}

// Collect all aggregate expressions.
val aggExpressions = a.aggregateExpressions.flatMap { e =>
e.collect {
case ae: AggregateExpression => ae
// Constructs pairs between old and new expressions for distinct aggregates, too.
val distinctAggExprs = distinctAggExpressions.filter(e => e.children.exists(!_.foldable))
val (projections, distinctAggPairs) = distinctAggExprs.map {
case ae @ AggregateExpression(af, _, _, filter, _) =>
// First, in order to reduce costs, it is better to handle the filter clause locally.
// e.g. for COUNT (DISTINCT a) FILTER (WHERE id > 1), evaluate the expression
// If(id > 1) 'a else null first, and use the result as output.
// Second, if at least two DISTINCT aggregate expressions reference the same
// attributes, we need to construct generated attributes so that the output is not
// lost. e.g. SUM (DISTINCT a), COUNT (DISTINCT a) FILTER (WHERE id > 1) will output
// attribute '_gen_distinct_1 and attribute '_gen_distinct_2 instead of two 'a.
// Note: this aliasing may result in at least two distinct groups, so we
// still need to call `rewriteDistinctAggregates`.
val unfoldableChildren = af.children.filter(!_.foldable)
// Expand projection
val projectionMap = unfoldableChildren.map {
case e if filter.isDefined =>
val ife = If(filter.get, e, nullify(e))
e -> Alias(ife, s"_gen_distinct_${NamedExpression.newExprId.id}")()
// For convenience and unification, we always alias the distinct column, even if
// there is no filter.
case e => e -> Alias(e, s"_gen_distinct_${NamedExpression.newExprId.id}")()
}
val projection = projectionMap.map(_._2)
val exprAttrs = projectionMap.map { kv =>
(kv._1, kv._2.toAttribute)
}
val exprAttrLookup = exprAttrs.toMap
val newChildren = af.children.map(c => exprAttrLookup.getOrElse(c, c))
val raf = af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
val aggExpr = ae.copy(aggregateFunction = raf, filter = None)
(projection, (ae, aggExpr))
}.unzip
// Construct the aggregate input projection.
val namedGroupingExpressions = a.groupingExpressions.map {
case ne: NamedExpression => ne
case other => Alias(other, other.toString)()
}
val rewriteAggProjection =
namedGroupingExpressions ++ namedRegularAggChildren ++ projections.flatten
// Construct the project operator.
val project = Project(rewriteAggProjection, a.child)
val groupByAttrs = namedGroupingExpressions.map(_.toAttribute)
val rewriteAggExprLookup = (distinctAggPairs ++ regularAggPairs).toMap
val patchedAggExpressions = a.aggregateExpressions.map { e =>
e.transformDown {
case ae: AggregateExpression => rewriteAggExprLookup.getOrElse(ae, ae)
}.asInstanceOf[NamedExpression]
}
(Aggregate(groupByAttrs, patchedAggExpressions, project), true)
} else {
(a, false)
}
}

private def rewriteDistinctAggregates(a: Aggregate, projected: Boolean): Aggregate = {
val aggExpressions = collectAggregateExprs(a)

// Extract distinct aggregate expressions.
val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e =>
Expand Down Expand Up @@ -205,7 +394,15 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {

// Setup unique distinct aggregate children.
val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
val distinctAggChildAttrMap = if (projected) {
// To facilitate merging Project with Expand, there is no need to create a new reference here.
distinctAggChildren.map {
case ar: AttributeReference => ar -> ar
case other => expressionAttributePair(other)
}
} else {
distinctAggChildren.map(expressionAttributePair)
}
val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)

// Setup expand & aggregate operators for distinct aggregate expressions.
Expand Down Expand Up @@ -294,11 +491,27 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
regularAggNulls
}

val (projections, expandChild) = if (projected) {
// If `projectFiltersInDistinctAggregates` inserts Project as child of Aggregate and
// `rewriteDistinctAggregates` will insert Expand here, merge Project with the Expand.
val projectAttributeExpressionMap = a.child.asInstanceOf[Project].projectList.map {
case ne: NamedExpression => ne.name -> ne
}.toMap
val projections = (regularAggProjection ++ distinctAggProjections).map {
case projection: Seq[Expression] => projection.map {
case ne: NamedExpression => projectAttributeExpressionMap.getOrElse(ne.name, ne)
case other => other
}
}
(projections, a.child.asInstanceOf[Project].child)
} else {
(regularAggProjection ++ distinctAggProjections, a.child)
}
// Construct the expand operator.
val expand = Expand(
regularAggProjection ++ distinctAggProjections,
projections,
groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2),
a.child)
expandChild)

// Construct the first aggregate operator. This de-duplicates all the children of
// distinct operators, and applies the regular aggregate operators.
Expand Down Expand Up @@ -331,6 +544,14 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
}
}

/**
 * Gathers every [[AggregateExpression]] found anywhere inside the aggregate's output
 * expressions, in traversal order and without de-duplication.
 */
private def collectAggregateExprs(a: Aggregate): Seq[AggregateExpression] =
  for {
    outputExpr <- a.aggregateExpressions
    aggExpr <- outputExpr.collect { case ae: AggregateExpression => ae }
  } yield aggExpr

/** Creates a null literal that carries the same data type as `e`. */
private def nullify(e: Expression) = {
  val targetType = e.dataType
  Literal.create(null, targetType)
}

private def expressionAttributePair(e: Expression) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,6 @@ class AnalysisErrorSuite extends AnalysisTest {
"FILTER (WHERE c > 1)"),
"FILTER predicate specified, but aggregate is not an aggregate function" :: Nil)

errorTest(
"DISTINCT aggregate function with filter predicate",
CatalystSqlParser.parsePlan("SELECT count(DISTINCT a) FILTER (WHERE c > 1) FROM TaBlE2"),
"DISTINCT and FILTER cannot be used in aggregate functions at the same time" :: Nil)

errorTest(
"non-deterministic filter predicate in aggregate functions",
CatalystSqlParser.parsePlan("SELECT count(a) FILTER (WHERE rand(int(c)) > 1) FROM TaBlE2"),
Expand Down
Loading