diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 56cc2a274bb7..75f1aa7185ef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2428,6 +2428,10 @@ class Analyzer(
         }
         wsc.copy(partitionSpec = newPartitionSpec, orderSpec = newOrderSpec)
 
+      case WindowExpression(ae: AggregateExpression, _) if ae.filter.isDefined =>
+        failAnalysis(
+          "window aggregate function with filter predicate is not supported yet.")
+
       // Extract Windowed AggregateExpression
       case we @ WindowExpression(
           ae @ AggregateExpression(function, _, _, _, _),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 4ec737fd9b70..e769e038c960 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -308,6 +308,9 @@ trait CheckAnalysis extends PredicateHelper {
           case a: AggregateExpression if a.isDistinct =>
             e.failAnalysis(
               "distinct aggregates are not allowed in observed metrics, but found: " + s.sql)
+          case a: AggregateExpression if a.filter.isDefined =>
+            e.failAnalysis("aggregates with filter predicate are not allowed in " +
+              "observed metrics, but found: " + s.sql)
           case _: Attribute if !seenAggregate =>
             e.failAnalysis (s"attribute ${s.sql} can only be used as an argument to an " +
               "aggregate function.")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 7023dbe2a367..5cc0453135c0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -164,6 +164,22 @@ class AnalysisErrorSuite extends AnalysisTest {
         UnspecifiedFrame)).as("window")),
     "Distinct window functions are not supported" :: Nil)
 
+  errorTest(
+    "window aggregate function with filter predicate",
+    testRelation2.select(
+      WindowExpression(
+        AggregateExpression(
+          Count(UnresolvedAttribute("b")),
+          Complete,
+          isDistinct = false,
+          filter = Some(UnresolvedAttribute("b") > 1)),
+        WindowSpecDefinition(
+          UnresolvedAttribute("a") :: Nil,
+          SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
+          UnspecifiedFrame)).as("window")),
+    "window aggregate function with filter predicate is not supported" :: Nil
+  )
+
   errorTest(
     "distinct function",
     CatalystSqlParser.parsePlan("SELECT hex(DISTINCT a) FROM TaBlE"),
@@ -191,12 +207,12 @@ class AnalysisErrorSuite extends AnalysisTest {
     "FILTER predicate specified, but aggregate is not an aggregate function" :: Nil)
 
   errorTest(
-    "DISTINCT and FILTER cannot be used in aggregate functions at the same time",
+    "DISTINCT aggregate function with filter predicate",
     CatalystSqlParser.parsePlan("SELECT count(DISTINCT a) FILTER (WHERE c > 1) FROM TaBlE2"),
     "DISTINCT and FILTER cannot be used in aggregate functions at the same time" :: Nil)
 
   errorTest(
-    "FILTER expression is non-deterministic, it cannot be used in aggregate functions",
+    "non-deterministic filter predicate in aggregate functions",
     CatalystSqlParser.parsePlan("SELECT count(a) FILTER (WHERE rand(int(c)) > 1) FROM TaBlE2"),
     "FILTER expression is non-deterministic, it cannot be used in aggregate functions" :: Nil)
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 5405009c9e20..c747d394b1bc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{Count, Sum}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count, Sum}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan
 import org.apache.spark.sql.catalyst.plans.{Cross, Inner}
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -736,5 +736,13 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       b :: ScalarSubquery(subquery, Nil).as("sum") :: Nil,
       CollectMetrics("evt1", count :: Nil, tblB))
     assertAnalysisError(query, "Multiple definitions of observed metrics" :: "evt1" :: Nil)
+
+    // Aggregate with filter predicate - fail
+    val sumWithFilter = sum.transform {
+      case a: AggregateExpression => a.copy(filter = Some(true))
+    }.asInstanceOf[NamedExpression]
+    assertAnalysisError(
+      CollectMetrics("evt1", sumWithFilter :: Nil, testRelation),
+      "aggregates with filter predicate are not allowed" :: Nil)
   }
 }
diff --git a/sql/core/src/test/resources/sql-tests/inputs/window.sql b/sql/core/src/test/resources/sql-tests/inputs/window.sql
index e25a25241830..3d05dfda6c3f 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/window.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/window.sql
@@ -120,3 +120,8 @@ SELECT cate, sum(val) OVER (w)
 FROM testData
 WHERE val is not null
 WINDOW w AS (PARTITION BY cate ORDER BY val);
+
+-- with filter predicate
+SELECT val, cate,
+count(val) FILTER (WHERE val > 1) OVER(PARTITION BY cate)
+FROM testData ORDER BY cate, val;
\ No newline at end of file
diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out
index f795374735f5..625088f90ced 100644
--- a/sql/core/src/test/resources/sql-tests/results/window.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 23
+-- Number of queries: 24
 
 
 -- !query
@@ -380,3 +380,14 @@ a 4
 b 1
 b 3
 b 6
+
+
+-- !query
+SELECT val, cate,
+count(val) FILTER (WHERE val > 1) OVER(PARTITION BY cate)
+FROM testData ORDER BY cate, val
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+window aggregate function with filter predicate is not supported yet.;