diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e58b0ae64784..477863a1b86d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1973,15 +1973,9 @@ class Analyzer( } // We get an aggregate function, we need to wrap it in an AggregateExpression. case agg: AggregateFunction => - // TODO: SPARK-30276 Support Filter expression allows simultaneous use of DISTINCT - if (filter.isDefined) { - if (isDistinct) { - failAnalysis("DISTINCT and FILTER cannot be used in aggregate functions " + - "at the same time") - } else if (!filter.get.deterministic) { - failAnalysis("FILTER expression is non-deterministic, " + - "it cannot be used in aggregate functions") - } + if (filter.isDefined && !filter.get.deterministic) { + failAnalysis("FILTER expression is non-deterministic, " + + "it cannot be used in aggregate functions") } AggregateExpression(agg, Complete, isDistinct, filter) // This function is not an aggregate function, just return the resolved one. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 15aa02ff677d..af3a8fe684bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete} +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types.IntegerType @@ -81,10 +81,10 @@ import org.apache.spark.sql.types.IntegerType * COUNT(DISTINCT cat1) as cat1_cnt, * COUNT(DISTINCT cat2) as cat2_cnt, * SUM(value) FILTER (WHERE id > 1) AS total - * FROM - * data - * GROUP BY - * key + * FROM + * data + * GROUP BY + * key * }}} * * This translates to the following (pseudo) logical plan: @@ -93,7 +93,7 @@ import org.apache.spark.sql.types.IntegerType * key = ['key] * functions = [COUNT(DISTINCT 'cat1), * COUNT(DISTINCT 'cat2), - * sum('value) with FILTER('id > 1)] + * sum('value) FILTER (WHERE 'id > 1)] * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) * LocalTableScan [...] * }}} @@ -108,7 +108,7 @@ import org.apache.spark.sql.types.IntegerType * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) * Aggregate( * key = ['key, 'cat1, 'cat2, 'gid] - * functions = [sum('value) with FILTER('id > 1)] + * functions = [sum('value) FILTER (WHERE 'id > 1)] * output = ['key, 'cat1, 'cat2, 'gid, 'total]) * Expand( * projections = [('key, null, null, 0, cast('value as bigint), 'id), @@ -118,6 +118,49 @@ import org.apache.spark.sql.types.IntegerType * LocalTableScan [...] * }}} * + * Third example: aggregate function with distinct and filter clauses (in sql): + * {{{ + * SELECT + * COUNT(DISTINCT cat1) FILTER (WHERE id > 1) as cat1_cnt, + * COUNT(DISTINCT cat2) FILTER (WHERE id > 2) as cat2_cnt, + * SUM(value) FILTER (WHERE id > 3) AS total + * FROM + * data + * GROUP BY + * key + * }}} + * + * This translates to the following (pseudo) logical plan: + * {{{ + * Aggregate( + * key = ['key] + * functions = [COUNT(DISTINCT 'cat1) FILTER (WHERE 'id > 1), + * COUNT(DISTINCT 'cat2) FILTER (WHERE 'id > 2), + * sum('value) FILTER (WHERE 'id > 3)] + * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) + * LocalTableScan [...] + * }}} + * + * This rule rewrites this logical plan to the following (pseudo) logical plan: + * {{{ + * Aggregate( + * key = ['key] + * functions = [count(if (('gid = 1) and 'max_cond1) 'cat1 else null), + * count(if (('gid = 2) and 'max_cond2) 'cat2 else null), + * first(if (('gid = 0)) 'total else null) ignore nulls] + * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) + * Aggregate( + * key = ['key, 'cat1, 'cat2, 'gid] + * functions = [max('cond1), max('cond2), sum('value) FILTER (WHERE 'id > 3)] + * output = ['key, 'cat1, 'cat2, 'gid, 'max_cond1, 'max_cond2, 'total]) + * Expand( + * projections = [('key, null, null, 0, null, null, cast('value as bigint), 'id), + * ('key, 'cat1, null, 1, 'id > 1, null, null, null), + * ('key, null, 'cat2, 2, null, 'id > 2, null, null)] + * output = ['key, 'cat1, 'cat2, 'gid, 'cond1, 'cond2, 'value, 'id]) + * LocalTableScan [...] + * }}} + * * The rule does the following things here: * 1. Expand the data. There are three aggregation groups in this query: * i. the non-distinct group; @@ -126,15 +169,24 @@ import org.apache.spark.sql.types.IntegerType * An expand operator is inserted to expand the child data for each group. The expand will null * out all unused columns for the given group; this must be done in order to ensure correctness * later on. Groups can by identified by a group id (gid) column added by the expand operator. + * If distinct group exists filter clause, the expand will calculate the filter and output it's + * result (e.g. cond1) which will be used to calculate the global conditions (e.g. max_cond1) + * equivalent to filter clauses. * 2. De-duplicate the distinct paths and aggregate the non-aggregate path. The group by clause of * this aggregate consists of the original group by clause, all the requested distinct columns * and the group id. Both de-duplication of distinct column and the aggregation of the * non-distinct group take advantage of the fact that we group by the group id (gid) and that we - * have nulled out all non-relevant columns the given group. + * have nulled out all non-relevant columns the given group. If distinct group exists filter + * clause, we will use max to aggregate the results (e.g. cond1) of the filter output in the + * previous step. These aggregate will output the global conditions (e.g. max_cond1) equivalent + * to filter clauses. * 3. Aggregating the distinct groups and combining this with the results of the non-distinct - * aggregation. In this step we use the group id to filter the inputs for the aggregate - * functions. The result of the non-distinct group are 'aggregated' by using the first operator, - * it might be more elegant to use the native UDAF merge mechanism for this in the future. + * aggregation. In this step we use the group id and the global condition to filter the inputs + * for the aggregate functions. If the global condition (e.g. max_cond1) is true, it means at + * least one row of a distinct value satisfies the filter. This distinct value should be included + * in the aggregate function. The result of the non-distinct group are 'aggregated' by using + * the first operator, it might be more elegant to use the native UDAF merge mechanism for this + * in the future. * * This rule duplicates the input data by two or more times (# distinct groups + an optional * non-distinct group). This will put quite a bit of memory pressure of the used aggregate and @@ -144,28 +196,24 @@ import org.apache.spark.sql.types.IntegerType */ object RewriteDistinctAggregates extends Rule[LogicalPlan] { - private def mayNeedtoRewrite(exprs: Seq[Expression]): Boolean = { - val distinctAggs = exprs.flatMap { _.collect { - case ae: AggregateExpression if ae.isDistinct => ae - }} - // We need at least two distinct aggregates for this rule because aggregation - // strategy can handle a single distinct group. + private def mayNeedtoRewrite(a: Aggregate): Boolean = { + val aggExpressions = collectAggregateExprs(a) + val distinctAggs = aggExpressions.filter(_.isDistinct) + // We need at least two distinct aggregates or the single distinct aggregate group exists filter + // clause for this rule because aggregation strategy can handle a single distinct aggregate + // group without filter clause. // This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a). - distinctAggs.size > 1 + distinctAggs.size > 1 || distinctAggs.exists(_.filter.isDefined) } def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) => rewrite(a) + case a: Aggregate if mayNeedtoRewrite(a) => rewrite(a) } def rewrite(a: Aggregate): Aggregate = { - // Collect all aggregate expressions. - val aggExpressions = a.aggregateExpressions.flatMap { e => - e.collect { - case ae: AggregateExpression => ae - } - } + val aggExpressions = collectAggregateExprs(a) + val distinctAggs = aggExpressions.filter(_.isDistinct) // Extract distinct aggregate expressions. val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e => @@ -184,8 +232,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } } - // Aggregation strategy can handle queries with a single distinct group. - if (distinctAggGroups.size > 1) { + // Aggregation strategy can handle queries with a single distinct group without filter clause. + if (distinctAggGroups.size > 1 || distinctAggs.exists(_.filter.isDefined)) { // Create the attributes for the grouping id and the group by clause. val gid = AttributeReference("gid", IntegerType, nullable = false)() val groupByMap = a.groupingExpressions.collect { @@ -195,7 +243,13 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { val groupByAttrs = groupByMap.map(_._2) // Functions used to modify aggregate functions and their inputs. - def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e)) + def evalWithinGroup(id: Literal, e: Expression, condition: Option[Expression]) = + if (condition.isDefined) { + If(And(EqualTo(gid, id), condition.get), e, nullify(e)) + } else { + If(EqualTo(gid, id), e, nullify(e)) + } + def patchAggregateFunctionChildren( af: AggregateFunction)( attrs: Expression => Option[Expression]): AggregateFunction = { @@ -207,13 +261,28 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2) + // Setup all the filters in distinct aggregate. + val (distinctAggFilters, distinctAggFilterAttrs, maxConds) = distinctAggs.collect { + case AggregateExpression(_, _, _, filter, _) if filter.isDefined => + val (e, attr) = expressionAttributePair(filter.get) + val aggregateExp = Max(attr).toAggregateExpression() + (e, attr, Alias(aggregateExp, attr.name)()) + }.unzip3 // Setup expand & aggregate operators for distinct aggregate expressions. val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap + val distinctAggFilterAttrLookup = distinctAggFilters.zip(maxConds.map(_.toAttribute)).toMap val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map { case ((group, expressions), i) => val id = Literal(i + 1) + // Expand projection for filter + val filters = expressions.filter(_.filter.isDefined).map(_.filter.get) + val filterProjection = distinctAggFilters.map { + case e if filters.contains(e) => e + case e => nullify(e) + } + // Expand projection val projection = distinctAggChildren.map { case e if group.contains(e) => e @@ -224,12 +293,17 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { val operators = expressions.map { e => val af = e.aggregateFunction val naf = patchAggregateFunctionChildren(af) { x => - distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _)) + val condition = if (e.filter.isDefined) { + e.filter.map(distinctAggFilterAttrLookup.get(_)).get + } else { + None + } + distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _, condition)) } - (e, e.copy(aggregateFunction = naf, isDistinct = false)) + (e, e.copy(aggregateFunction = naf, isDistinct = false, filter = None)) } - (projection, operators) + (projection ++ filterProjection, operators) } // Setup expand for the 'regular' aggregate expressions. @@ -257,7 +331,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // Select the result of the first aggregate in the last aggregate. val result = AggregateExpression( - aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), true), + aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute, None), true), mode = Complete, isDistinct = false) @@ -280,6 +354,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { Seq(a.groupingExpressions ++ distinctAggChildren.map(nullify) ++ Seq(regularGroupId) ++ + distinctAggFilters.map(nullify) ++ regularAggChildren) } else { Seq.empty[Seq[Expression]] @@ -297,7 +372,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // Construct the expand operator. val expand = Expand( regularAggProjection ++ distinctAggProjections, - groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2), + groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ distinctAggFilterAttrs ++ + regularAggChildAttrMap.map(_._2), a.child) // Construct the first aggregate operator. This de-duplicates all the children of @@ -305,7 +381,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid val firstAggregate = Aggregate( firstAggregateGroupBy, - firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2), + firstAggregateGroupBy ++ maxConds ++ regularAggOperatorMap.map(_._2), expand) // Construct the second aggregate @@ -331,6 +407,13 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } } + private def collectAggregateExprs(a: Aggregate): Seq[AggregateExpression] = { + // Collect all aggregate expressions. + a.aggregateExpressions.flatMap { _.collect { + case ae: AggregateExpression => ae + }} + } + private def nullify(e: Expression) = Literal.create(null, e.dataType) private def expressionAttributePair(e: Expression) = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 166ffec44a60..a99f7e2be6e7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -207,11 +207,6 @@ class AnalysisErrorSuite extends AnalysisTest { "FILTER (WHERE c > 1)"), "FILTER predicate specified, but aggregate is not an aggregate function" :: Nil) - errorTest( - "DISTINCT aggregate function with filter predicate", - CatalystSqlParser.parsePlan("SELECT count(DISTINCT a) FILTER (WHERE c > 1) FROM TaBlE2"), - "DISTINCT and FILTER cannot be used in aggregate functions at the same time" :: Nil) - errorTest( "non-deterministic filter predicate in aggregate functions", CatalystSqlParser.parsePlan("SELECT count(a) FILTER (WHERE rand(int(c)) > 1) FROM TaBlE2"), diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql index 4c1816e93b08..24d303621fae 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql @@ -36,8 +36,13 @@ SELECT COUNT(id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp; SELECT COUNT(id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp; SELECT COUNT(id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")) FROM emp; SELECT COUNT(id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") = "2001-01-01") FROM emp; --- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT --- SELECT COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp; +SELECT COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp; +SELECT COUNT(DISTINCT id), COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp; +SELECT COUNT(DISTINCT id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")), COUNT(DISTINCT id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp; +SELECT SUM(salary), COUNT(DISTINCT id), COUNT(DISTINCT id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp; +SELECT COUNT(DISTINCT 1) FILTER (WHERE a = 1) FROM testData; +SELECT COUNT(DISTINCT id) FILTER (WHERE true) FROM emp; +SELECT COUNT(DISTINCT id) FILTER (WHERE false) FROM emp; -- Aggregate with filter and non-empty GroupBy expressions. SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a; @@ -47,8 +52,11 @@ SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > date "2003-01-01") FROM emp SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_date("2003-01-01")) FROM emp GROUP BY dept_id; SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_timestamp("2003-01-01 00:00:00")) FROM emp GROUP BY dept_id; SELECT dept_id, SUM(salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2003-01-01") FROM emp GROUP BY dept_id; --- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT --- SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id; +SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id; +SELECT dept_id, SUM(DISTINCT salary), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id; +SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE hiredate > date "2001-01-01"), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id; +SELECT dept_id, COUNT(id), SUM(DISTINCT salary), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2001-01-01") FROM emp GROUP BY dept_id; +SELECT b, COUNT(DISTINCT 1) FILTER (WHERE a = 1) FROM testData GROUP BY b; -- Aggregate with filter and grouped by literals. SELECT 'foo', COUNT(a) FILTER (WHERE b <= 2) FROM testData GROUP BY 1; @@ -61,13 +69,24 @@ select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id; select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id; select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id; --- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT --- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id; --- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; --- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id; --- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id; --- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name), sum(salary) from emp group by dept_id; --- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id + dept_id > 500), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id > 200), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id + dept_id > 500), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id > 200), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id; +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id + dept_id > 500), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; +select dept_id, sum(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id; +select dept_id, sum(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id; +select dept_id, avg(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id; +select dept_id, count(distinct emp_name, hiredate) filter (where id > 200), sum(salary) from emp group by dept_id; +select dept_id, count(distinct emp_name, hiredate) filter (where id > 0), sum(salary) from emp group by dept_id; +select dept_id, count(distinct 1), count(distinct 1) filter (where id > 200), sum(salary) from emp group by dept_id; -- Aggregate with filter and grouped by literals (hash aggregate), here the input table is filtered using WHERE. SELECT 'foo', APPROX_COUNT_DISTINCT(a) FILTER (WHERE b >= 0) FROM testData WHERE a = 0 GROUP BY 1; @@ -81,9 +100,8 @@ SELECT a + 2, COUNT(b) FILTER (WHERE b IN (1, 2)) FROM testData GROUP BY a + 1; SELECT a + 1 + 1, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY a + 1; -- Aggregate with filter, foldable input and multiple distinct groups. --- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT --- SELECT COUNT(DISTINCT b) FILTER (WHERE b > 0), COUNT(DISTINCT b, c) FILTER (WHERE b > 0 AND c > 2) --- FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a; +SELECT COUNT(DISTINCT b) FILTER (WHERE b > 0), COUNT(DISTINCT b, c) FILTER (WHERE b > 0 AND c > 2) +FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a; -- Check analysis exceptions SELECT a AS k, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY k; diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql index 746b67723483..657ea59ec8f1 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql @@ -241,10 +241,9 @@ select sum(1/ten) filter (where ten > 0) from tenk1; -- select ten, sum(distinct four) filter (where four::text ~ '123') from onek a -- group by ten; --- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT --- select ten, sum(distinct four) filter (where four > 10) from onek a --- group by ten --- having exists (select 1 from onek b where sum(distinct a.four) = b.four); +select ten, sum(distinct four) filter (where four > 10) from onek a +group by ten +having exists (select 1 from onek b where sum(distinct a.four) = b.four); -- [SPARK-28682] ANSI SQL: Collation Support -- select max(foo COLLATE "C") filter (where (bar collate "POSIX") > '0') diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/groupingsets.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/groupingsets.sql index fc54d179f742..45617c53166a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/groupingsets.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/groupingsets.sql @@ -336,9 +336,8 @@ order by 2,1; -- order by 2,1; -- FILTER queries --- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT --- select ten, sum(distinct four) filter (where string(four) like '123') from onek a --- group by rollup(ten); +select ten, sum(distinct four) filter (where string(four) like '123') from onek a +group by rollup(ten); -- More rescan tests -- [SPARK-27877] ANSI SQL: LATERAL derived table(T491) diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out index d41d25280146..c349d9d84c22 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 37 +-- Number of queries: 68 -- !query @@ -94,6 +94,62 @@ struct +-- !query output +2 + + +-- !query +SELECT COUNT(DISTINCT id), COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp +-- !query schema +struct +-- !query output +8 2 + + +-- !query +SELECT COUNT(DISTINCT id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")), COUNT(DISTINCT id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp +-- !query schema +struct +-- !query output +2 2 + + +-- !query +SELECT SUM(salary), COUNT(DISTINCT id), COUNT(DISTINCT id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp +-- !query schema +struct +-- !query output +2450.0 8 2 + + +-- !query +SELECT COUNT(DISTINCT 1) FILTER (WHERE a = 1) FROM testData +-- !query schema +struct +-- !query output +1 + + +-- !query +SELECT COUNT(DISTINCT id) FILTER (WHERE true) FROM emp +-- !query schema +struct +-- !query output +8 + + +-- !query +SELECT COUNT(DISTINCT id) FILTER (WHERE false) FROM emp +-- !query schema +struct +-- !query output +0 + + -- !query SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a -- !query schema @@ -177,6 +233,68 @@ struct "2001-01-01 00:00:00") FROM emp GROUP BY dept_id +-- !query schema +struct 2001-01-01 00:00:00)):double> +-- !query output +10 300.0 +100 400.0 +20 300.0 +30 400.0 +70 150.0 +NULL NULL + + +-- !query +SELECT dept_id, SUM(DISTINCT salary), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id +-- !query schema +struct 2001-01-01 00:00:00)):double> +-- !query output +10 300.0 300.0 +100 400.0 400.0 +20 300.0 300.0 +30 400.0 400.0 +70 150.0 150.0 +NULL 400.0 NULL + + +-- !query +SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE hiredate > date "2001-01-01"), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id +-- !query schema +struct DATE '2001-01-01')):double,sum(DISTINCT salary) FILTER (WHERE (date_format(CAST(hiredate AS TIMESTAMP), yyyy-MM-dd HH:mm:ss) > 2001-01-01 00:00:00)):double> +-- !query output +10 300.0 300.0 +100 400.0 400.0 +20 300.0 300.0 +30 400.0 400.0 +70 150.0 150.0 +NULL NULL NULL + + +-- !query +SELECT dept_id, COUNT(id), SUM(DISTINCT salary), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2001-01-01") FROM emp GROUP BY dept_id +-- !query schema +struct 2001-01-01)):double> +-- !query output +10 3 300.0 300.0 +100 2 400.0 400.0 +20 1 300.0 300.0 +30 1 400.0 400.0 +70 1 150.0 150.0 +NULL 1 400.0 NULL + + +-- !query +SELECT b, COUNT(DISTINCT 1) FILTER (WHERE a = 1) FROM testData GROUP BY b +-- !query schema +struct +-- !query output +1 1 +2 1 +NULL 0 + + -- !query SELECT 'foo', COUNT(a) FILTER (WHERE b <= 2) FROM testData GROUP BY 1 -- !query schema @@ -261,6 +379,240 @@ struct 200), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,sum(salary):double> +-- !query output +10 0 400.0 +100 2 800.0 +20 1 300.0 +30 1 400.0 +70 1 150.0 +NULL 1 400.0 + + +-- !query +select dept_id, count(distinct emp_name) filter (where id + dept_id > 500), sum(salary) from emp group by dept_id +-- !query schema +struct 500)):bigint,sum(salary):double> +-- !query output +10 0 400.0 +100 2 800.0 +20 0 300.0 +30 0 400.0 +70 1 150.0 +NULL 0 400.0 + + +-- !query +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id > 200), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,sum(salary):double> +-- !query output +10 2 0 400.0 +100 2 2 800.0 +20 1 1 300.0 +30 1 1 400.0 +70 1 1 150.0 +NULL 1 1 400.0 + + +-- !query +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id + dept_id > 500), sum(salary) from emp group by dept_id +-- !query schema +struct 500)):bigint,sum(salary):double> +-- !query output +10 2 0 400.0 +100 2 2 800.0 +20 1 0 300.0 +30 1 0 400.0 +70 1 1 150.0 +NULL 1 0 400.0 + + +-- !query +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id > 200), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id +-- !query schema +struct 200)):bigint,sum(salary):double,sum(salary) FILTER (WHERE (id > 200)):double> +-- !query output +10 2 0 400.0 NULL +100 2 2 800.0 800.0 +20 1 1 300.0 300.0 +30 1 1 400.0 400.0 +70 1 1 150.0 150.0 +NULL 1 1 400.0 400.0 + + +-- !query +select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id + dept_id > 500), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id +-- !query schema +struct 500)):bigint,sum(salary):double,sum(salary) FILTER (WHERE (id > 200)):double> +-- !query output +10 2 0 400.0 NULL +100 2 2 800.0 800.0 +20 1 0 300.0 300.0 +30 1 0 400.0 400.0 +70 1 1 150.0 150.0 +NULL 1 0 400.0 400.0 + + +-- !query +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT hiredate):bigint,sum(salary):double> +-- !query output +10 0 2 400.0 +100 2 2 800.0 +20 1 1 300.0 +30 1 1 400.0 +70 1 1 150.0 +NULL 1 1 400.0 + + +-- !query +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT hiredate) FILTER (WHERE (hiredate > DATE '2003-01-01')):bigint,sum(salary):double> +-- !query output +10 0 1 400.0 +100 2 1 800.0 +20 1 0 300.0 +30 1 1 400.0 +70 1 1 150.0 +NULL 1 0 400.0 + + +-- !query +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT hiredate) FILTER (WHERE (hiredate > DATE '2003-01-01')):bigint,sum(salary) FILTER (WHERE (salary < 400.0)):double> +-- !query output +10 0 1 400.0 +100 2 1 NULL +20 1 0 300.0 +30 1 1 NULL +70 1 1 150.0 +NULL 1 0 NULL + + +-- !query +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT hiredate) FILTER (WHERE (hiredate > DATE '2003-01-01')):bigint,sum(salary) FILTER (WHERE (salary < 400.0)):double,sum(salary) FILTER (WHERE (id > 200)):double> +-- !query output +10 0 1 400.0 NULL +100 2 1 NULL 800.0 +20 1 0 300.0 300.0 +30 1 1 NULL 400.0 +70 1 1 150.0 150.0 +NULL 1 0 NULL 400.0 + + +-- !query +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT emp_name):bigint,sum(salary):double> +-- !query output +10 0 2 400.0 +100 2 2 800.0 +20 1 1 300.0 +30 1 1 400.0 +70 1 1 150.0 +NULL 1 1 400.0 + + +-- !query +select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT emp_name) FILTER (WHERE (hiredate > DATE '2003-01-01')):bigint,sum(salary):double> +-- !query output +10 0 1 400.0 +100 2 1 800.0 +20 1 0 300.0 +30 1 1 400.0 +70 1 1 150.0 +NULL 1 0 400.0 + + +-- !query +select dept_id, sum(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT hiredate):bigint,sum(salary):double> +-- !query output +10 NULL 2 400.0 +100 1500 2 800.0 +20 320 1 300.0 +30 430 1 400.0 +70 870 1 150.0 +NULL NULL 1 400.0 + + +-- !query +select dept_id, sum(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,count(DISTINCT hiredate) FILTER (WHERE (hiredate > DATE '2003-01-01')):bigint,sum(salary):double> +-- !query output +10 NULL 1 400.0 +100 1500 1 800.0 +20 320 0 300.0 +30 430 1 400.0 +70 870 1 150.0 +NULL NULL 0 400.0 + + +-- !query +select dept_id, avg(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id +-- !query schema +struct 200)):double,count(DISTINCT hiredate) FILTER (WHERE (hiredate > DATE '2003-01-01')):bigint,sum(salary) FILTER (WHERE (salary < 400.0)):double> +-- !query output +10 NULL 1 400.0 +100 750.0 1 NULL +20 320.0 0 300.0 +30 430.0 1 NULL +70 870.0 1 150.0 +NULL NULL 0 NULL + + +-- !query +select dept_id, count(distinct emp_name, hiredate) filter (where id > 200), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,sum(salary):double> +-- !query output +10 0 400.0 +100 2 800.0 +20 1 300.0 +30 1 400.0 +70 1 150.0 +NULL 1 400.0 + + +-- !query +select dept_id, count(distinct emp_name, hiredate) filter (where id > 0), sum(salary) from emp group by dept_id +-- !query schema +struct 0)):bigint,sum(salary):double> +-- !query output +10 2 400.0 +100 2 800.0 +20 1 300.0 +30 1 400.0 +70 1 150.0 +NULL 1 400.0 + + +-- !query +select dept_id, count(distinct 1), count(distinct 1) filter (where id > 200), sum(salary) from emp group by dept_id +-- !query schema +struct 200)):bigint,sum(salary):double> +-- !query output +10 1 0 400.0 +100 1 1 800.0 +20 1 1 300.0 +30 1 1 400.0 +70 1 1 150.0 +NULL 1 1 400.0 + + -- !query SELECT 'foo', APPROX_COUNT_DISTINCT(a) FILTER (WHERE b >= 0) FROM testData WHERE a = 0 GROUP BY 1 -- !query schema @@ -309,6 +661,15 @@ struct<((a + 1) + 1):int,count(b) FILTER (WHERE (b > 0)):bigint> NULL 1 +-- !query +SELECT COUNT(DISTINCT b) FILTER (WHERE b > 0), COUNT(DISTINCT b, c) FILTER (WHERE b > 0 AND c > 2) +FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a +-- !query schema +struct 0)):bigint,count(DISTINCT b, c) FILTER (WHERE ((b > 0) AND (c > 2))):bigint> +-- !query output +1 1 + + -- !query SELECT a AS k, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY k -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out index 69f96b02782e..e1f735e5fe1d 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part3.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 4 +-- Number of queries: 5 -- !query @@ -27,6 +27,20 @@ struct 0)):d 2828.9682539682954 +-- !query +select ten, sum(distinct four) filter (where four > 10) from onek a +group by ten +having exists (select 1 from onek b where sum(distinct a.four) = b.four) +-- !query schema +struct 10)):bigint> +-- !query output +0 NULL +2 NULL +4 NULL +6 NULL +8 NULL + + -- !query select (select count(*) from (values (1)) t0(inner_c)) diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/groupingsets.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/groupingsets.sql.out index 7312c2087629..2619634d7d56 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/groupingsets.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/groupingsets.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 54 +-- Number of queries: 55 -- !query @@ -443,6 +443,25 @@ struct NULL 1 +-- !query +select ten, sum(distinct four) filter (where string(four) like '123') from onek a +group by rollup(ten) +-- !query schema +struct +-- !query output +0 NULL +1 NULL +2 NULL +3 NULL +4 NULL +5 NULL +6 NULL +7 NULL +8 NULL +9 NULL +NULL NULL + + -- !query select count(*) from gstest4 group by rollup(unhashable_col,unsortable_col) -- !query schema