Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
2dc9db4
Support distinct with filter
beliefer Feb 1, 2020
6a32d83
Add results of test case
beliefer Feb 1, 2020
5c38bbe
Optimize code
beliefer Feb 2, 2020
a6498f9
Fix incorrect sql
beliefer Feb 2, 2020
c6caf73
Resolve conflict
beliefer Feb 7, 2020
cd00f91
Fix conflict
beliefer Feb 7, 2020
4a6f903
Reuse completeNextStageWithFetchFailure
beliefer Jun 19, 2020
96456e2
Merge remote-tracking branch 'upstream/master'
beliefer Jul 1, 2020
4314005
Merge remote-tracking branch 'upstream/master'
beliefer Jul 3, 2020
bd314cb
Merge branch 'master' into same_distinct_aggregate_with_filter
beliefer Jul 3, 2020
a56f2b0
Optimize comments
beliefer Jul 3, 2020
7d6ada4
Expand to Project
beliefer Jul 3, 2020
529b69e
Expand to Project
beliefer Jul 3, 2020
54f6d84
change Expand to Project
beliefer Jul 4, 2020
a7bcbc9
Optimize code
beliefer Jul 4, 2020
73dc600
Optimize code
beliefer Jul 4, 2020
5a4ca02
Supplement docs.
beliefer Jul 7, 2020
70ff08e
Merge project with expand
beliefer Jul 8, 2020
16d8c1d
Merge project with expand
beliefer Jul 8, 2020
d6af4a7
Merge remote-tracking branch 'upstream/master'
beliefer Jul 9, 2020
5cd1439
Merge branch 'master' into same_distinct_aggregate_with_filter
beliefer Jul 9, 2020
3c49156
Supplement comments.
beliefer Jul 9, 2020
762e839
Optimize code.
beliefer Jul 9, 2020
12e6fbc
Unified implementation of filter in regular aggregates and distinct a…
beliefer Jul 13, 2020
d531864
Update comments.
beliefer Jul 13, 2020
5bbbfd7
Optimize code.
beliefer Jul 13, 2020
20ad143
Update comments.
beliefer Jul 14, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1920,15 +1920,9 @@ class Analyzer(
}
// We get an aggregate function, we need to wrap it in an AggregateExpression.
case agg: AggregateFunction =>
// TODO: SPARK-30276 Support Filter expression allows simultaneous use of DISTINCT
if (filter.isDefined) {
if (isDistinct) {
failAnalysis("DISTINCT and FILTER cannot be used in aggregate functions " +
"at the same time")
} else if (!filter.get.deterministic) {
failAnalysis("FILTER expression is non-deterministic, " +
"it cannot be used in aggregate functions")
}
if (filter.isDefined && !filter.get.deterministic) {
failAnalysis("FILTER expression is non-deterministic, " +
"it cannot be used in aggregate functions")
}
AggregateExpression(agg, Complete, isDistinct, filter)
// This function is not an aggregate function, just return the resolved one.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ package object dsl {
// Builds a COUNT over `e` and wraps it with the no-argument toAggregateExpression().
def count(e: Expression): Expression = {
  val counting = Count(e)
  counting.toAggregateExpression()
}
// Builds a COUNT over the given expressions and wraps it with the distinct flag set.
def countDistinct(e: Expression*): Expression = {
  val counting = Count(e)
  counting.toAggregateExpression(isDistinct = true)
}
// Same as countDistinct(e*), but also forwards an optional FILTER predicate to the
// resulting aggregate expression.
def countDistinct(filter: Option[Expression], e: Expression*): Expression = {
  val counting = Count(e)
  counting.toAggregateExpression(isDistinct = true, filter = filter)
}
// Builds a HyperLogLogPlusPlus aggregate over `e`, passing `rsd` (default 0.05) through unchanged.
def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression = {
  val sketch = HyperLogLogPlusPlus(e, rsd)
  sketch.toAggregateExpression()
}
// Builds an Average over `e` and wraps it with the no-argument toAggregateExpression().
def avg(e: Expression): Expression = {
  val averaging = Average(e)
  averaging.toAggregateExpression()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,15 +216,21 @@ abstract class AggregateFunction extends Expression {
// Convenience overload: delegates with the distinct flag turned off.
def toAggregateExpression(): AggregateExpression = {
  toAggregateExpression(isDistinct = false)
}

/**
* Wraps this [[AggregateFunction]] in an [[AggregateExpression]] and sets `isDistinct`
* flag of the [[AggregateExpression]] to the given value because
* Wraps this [[AggregateFunction]] in an [[AggregateExpression]], setting its `isDistinct`
* flag and its optional `filter` to the given values, because
* [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode,
* and the flag indicating if this aggregation is distinct aggregation or not.
* An [[AggregateFunction]] should not be used without being wrapped in
* the flag indicating whether this aggregation is a distinct aggregation, and the optional
* `filter`. An [[AggregateFunction]] should not be used without being wrapped in
* an [[AggregateExpression]].
*/
def toAggregateExpression(isDistinct: Boolean): AggregateExpression = {
AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct)
def toAggregateExpression(
    isDistinct: Boolean,
    filter: Option[Expression] = None): AggregateExpression = {
  // The wrapper is always created in Complete mode; the distinct flag and the optional
  // filter are passed through exactly as supplied by the caller.
  AggregateExpression(this, Complete, isDistinct, filter)
}

def sql(isDistinct: Boolean): String = {
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,6 @@ class AnalysisErrorSuite extends AnalysisTest {
"FILTER (WHERE c > 1)"),
"FILTER predicate specified, but aggregate is not an aggregate function" :: Nil)

errorTest(
"DISTINCT aggregate function with filter predicate",
CatalystSqlParser.parsePlan("SELECT count(DISTINCT a) FILTER (WHERE c > 1) FROM TaBlE2"),
"DISTINCT and FILTER cannot be used in aggregate functions at the same time" :: Nil)

errorTest(
"non-deterministic filter predicate in aggregate functions",
CatalystSqlParser.parsePlan("SELECT count(a) FILTER (WHERE rand(int(c)) > 1) FROM TaBlE2"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL}
import org.apache.spark.sql.types.{IntegerType, StringType}
Expand All @@ -42,6 +42,16 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
case _ => fail(s"Plan is not rewritten:\n$rewrite")
}

// Fails the test unless the plan is an Aggregate whose direct child is a Project.
private def checkGenerate(generate: LogicalPlan): Unit = {
  val hasExpectedShape = generate match {
    case Aggregate(_, _, _: Project) => true
    case _ => false
  }
  if (!hasExpectedShape) {
    fail(s"Plan is not generated:\n$generate")
  }
}

// Fails the test unless the plan is an Aggregate over an Aggregate whose child is an Expand.
private def checkGenerateAndRewrite(rewrite: LogicalPlan): Unit = {
  val hasExpectedShape = rewrite match {
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) => true
    case _ => false
  }
  if (!hasExpectedShape) {
    fail(s"Plan is not rewritten:\n$rewrite")
  }
}

test("single distinct group") {
val input = testRelation
.groupBy('a)(countDistinct('e))
Expand All @@ -50,6 +60,13 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
comparePlans(input, rewrite)
}

test("single distinct group with filter") {
  // One distinct aggregate carrying a FILTER predicate (d = ""), grouped by 'a.
  val filteredCount = countDistinct(Some(EqualTo('d, Literal(""))), 'e)
  val analyzed = testRelation.groupBy('a)(filteredCount).analyze
  val rewritten = RewriteDistinctAggregates(analyzed)
  checkGenerate(rewritten)
}

test("single distinct group with partial aggregates") {
val input = testRelation
.groupBy('a, 'd)(
Expand All @@ -67,6 +84,13 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
checkRewrite(RewriteDistinctAggregates(input))
}

test("multiple distinct groups with filter") {
  // Two different distinct groups: ('b, 'c) with a FILTER predicate (d = ""), and 'd without one.
  val filteredCount = countDistinct(Some(EqualTo('d, Literal(""))), 'b, 'c)
  val plainCount = countDistinct('d)
  val analyzed = testRelation.groupBy('a)(filteredCount, plainCount).analyze
  val rewritten = RewriteDistinctAggregates(analyzed)
  checkGenerateAndRewrite(rewritten)
}

test("multiple distinct groups with partial aggregates") {
val input = testRelation
.groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,44 +157,19 @@ abstract class AggregationIterator(
inputAttributes: Seq[Attribute]): (InternalRow, InternalRow) => Unit = {
val joinedRow = new JoinedRow
if (expressions.nonEmpty) {
val mergeExpressions =
functions.zip(expressions.map(ae => (ae.mode, ae.isDistinct, ae.filter))).flatMap {
case (ae: DeclarativeAggregate, (mode, isDistinct, filter)) =>
mode match {
case Partial | Complete =>
if (filter.isDefined) {
ae.updateExpressions.zip(ae.aggBufferAttributes).map {
case (updateExpr, attr) => If(filter.get, updateExpr, attr)
}
} else {
ae.updateExpressions
}
case PartialMerge | Final => ae.mergeExpressions
}
case (agg: AggregateFunction, _) => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
}
// Initialize predicates for aggregate functions if necessary
val predicateOptions = expressions.map {
case AggregateExpression(_, mode, _, Some(filter), _) =>
mode match {
case Partial | Complete =>
val predicate = Predicate.create(filter, inputAttributes)
predicate.initialize(partIndex)
Some(predicate)
case _ => None
val mergeExpressions = functions.zip(expressions).flatMap {
case (ae: DeclarativeAggregate, expression) =>
expression.mode match {
case Partial | Complete => ae.updateExpressions
case PartialMerge | Final => ae.mergeExpressions
}
case _ => None
case (agg: AggregateFunction, _) => Seq.fill(agg.aggBufferAttributes.length)(NoOp)
}
val updateFunctions = functions.zipWithIndex.collect {
case (ae: ImperativeAggregate, i) =>
expressions(i).mode match {
case Partial | Complete =>
if (predicateOptions(i).isDefined) {
(buffer: InternalRow, row: InternalRow) =>
if (predicateOptions(i).get.eval(row)) { ae.update(buffer, row) }
} else {
(buffer: InternalRow, row: InternalRow) => ae.update(buffer, row)
}
(buffer: InternalRow, row: InternalRow) => ae.update(buffer, row)
case PartialMerge | Final =>
(buffer: InternalRow, row: InternalRow) => ae.merge(buffer, row)
}
Expand Down
41 changes: 27 additions & 14 deletions sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ SELECT COUNT(id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp;
SELECT COUNT(id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp;
SELECT COUNT(id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")) FROM emp;
SELECT COUNT(id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") = "2001-01-01") FROM emp;
-- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT
-- SELECT COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp;
SELECT COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp;
SELECT COUNT(DISTINCT id), COUNT(DISTINCT id) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") = "2001-01-01 00:00:00") FROM emp;
SELECT COUNT(DISTINCT id) FILTER (WHERE hiredate = to_timestamp("2001-01-01 00:00:00")), COUNT(DISTINCT id) FILTER (WHERE hiredate = to_date('2001-01-01 00:00:00')) FROM emp;
SELECT SUM(salary), COUNT(DISTINCT id), COUNT(DISTINCT id) FILTER (WHERE hiredate = date "2001-01-01") FROM emp;

-- Aggregate with filter and non-empty GroupBy expressions.
SELECT a, COUNT(b) FILTER (WHERE a >= 2) FROM testData GROUP BY a;
Expand All @@ -44,8 +46,10 @@ SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > date "2003-01-01") FROM emp
SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_date("2003-01-01")) FROM emp GROUP BY dept_id;
SELECT dept_id, SUM(salary) FILTER (WHERE hiredate > to_timestamp("2003-01-01 00:00:00")) FROM emp GROUP BY dept_id;
SELECT dept_id, SUM(salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2003-01-01") FROM emp GROUP BY dept_id;
-- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT
-- SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id;
SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id;
SELECT dept_id, SUM(DISTINCT salary), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id;
SELECT dept_id, SUM(DISTINCT salary) FILTER (WHERE hiredate > date "2001-01-01"), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd HH:mm:ss") > "2001-01-01 00:00:00") FROM emp GROUP BY dept_id;
SELECT dept_id, COUNT(id), SUM(DISTINCT salary), SUM(DISTINCT salary) FILTER (WHERE date_format(hiredate, "yyyy-MM-dd") > "2001-01-01") FROM emp GROUP BY dept_id;

-- Aggregate with filter and grouped by literals.
SELECT 'foo', COUNT(a) FILTER (WHERE b <= 2) FROM testData GROUP BY 1;
Expand All @@ -58,13 +62,23 @@ select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary),
select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id;
select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id;
select dept_id, count(distinct emp_name), count(distinct hiredate), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id + dept_id > 500) from emp group by dept_id;
-- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT
-- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id;
-- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id;
-- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id;
-- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id;
-- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name), sum(salary) from emp group by dept_id;
-- select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id;
select dept_id, count(distinct emp_name) filter (where id > 200), sum(salary) from emp group by dept_id;
select dept_id, count(distinct emp_name) filter (where id + dept_id > 500), sum(salary) from emp group by dept_id;
select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id > 200), sum(salary) from emp group by dept_id;
select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id + dept_id > 500), sum(salary) from emp group by dept_id;
select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id > 200), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id;
select dept_id, count(distinct emp_name), count(distinct emp_name) filter (where id + dept_id > 500), sum(salary), sum(salary) filter (where id > 200) from emp group by dept_id;
select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id;
select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id;
select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id;
select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D), sum(salary) filter (where id > 200) from emp group by dept_id;
select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name), sum(salary) from emp group by dept_id;
select dept_id, count(distinct emp_name) filter (where id > 200), count(distinct emp_name) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id;
select dept_id, sum(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate), sum(salary) from emp group by dept_id;
select dept_id, sum(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) from emp group by dept_id;
select dept_id, avg(distinct (id + dept_id)) filter (where id > 200), count(distinct hiredate) filter (where hiredate > date "2003-01-01"), sum(salary) filter (where salary < 400.00D) from emp group by dept_id;
select dept_id, count(distinct emp_name, hiredate) filter (where id > 200), sum(salary) from emp group by dept_id;
select dept_id, count(distinct emp_name, hiredate) filter (where id > 0), sum(salary) from emp group by dept_id;

-- Aggregate with filter and grouped by literals (hash aggregate), here the input table is filtered using WHERE.
SELECT 'foo', APPROX_COUNT_DISTINCT(a) FILTER (WHERE b >= 0) FROM testData WHERE a = 0 GROUP BY 1;
Expand All @@ -78,9 +92,8 @@ SELECT a + 2, COUNT(b) FILTER (WHERE b IN (1, 2)) FROM testData GROUP BY a + 1;
SELECT a + 1 + 1, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY a + 1;

-- Aggregate with filter, foldable input and multiple distinct groups.
-- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT
-- SELECT COUNT(DISTINCT b) FILTER (WHERE b > 0), COUNT(DISTINCT b, c) FILTER (WHERE b > 0 AND c > 2)
-- FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a;
SELECT COUNT(DISTINCT b) FILTER (WHERE b > 0), COUNT(DISTINCT b, c) FILTER (WHERE b > 0 AND c > 2)
FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a;

-- Check analysis exceptions
SELECT a AS k, COUNT(b) FILTER (WHERE b > 0) FROM testData GROUP BY k;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,9 @@ select sum(1/ten) filter (where ten > 0) from tenk1;
-- select ten, sum(distinct four) filter (where four::text ~ '123') from onek a
-- group by ten;

-- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT
-- select ten, sum(distinct four) filter (where four > 10) from onek a
-- group by ten
-- having exists (select 1 from onek b where sum(distinct a.four) = b.four);
select ten, sum(distinct four) filter (where four > 10) from onek a
group by ten
having exists (select 1 from onek b where sum(distinct a.four) = b.four);

-- [SPARK-28682] ANSI SQL: Collation Support
-- select max(foo COLLATE "C") filter (where (bar collate "POSIX") > '0')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,9 +336,8 @@ order by 2,1;
-- order by 2,1;

-- FILTER queries
-- [SPARK-30276] Support Filter expression allows simultaneous use of DISTINCT
-- select ten, sum(distinct four) filter (where string(four) like '123') from onek a
-- group by rollup(ten);
select ten, sum(distinct four) filter (where string(four) like '123') from onek a
group by rollup(ten);

-- More rescan tests
-- [SPARK-27877] ANSI SQL: LATERAL derived table(T491)
Expand Down
Loading