diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 996c548e1329c..17f906c698de2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, FunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @ExpressionDescription( @@ -30,8 +31,6 @@ import org.apache.spark.sql.types._ 2.0 > SELECT _FUNC_(col) FROM VALUES (1), (2), (NULL) AS tab(col); 1.5 - > SELECT _FUNC_(cast(v as interval)) FROM VALUES ('-1 weeks'), ('2 seconds'), (null) t(v); - -3 days -11 hours -59 minutes -59 seconds """, since = "1.0.0") case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { @@ -40,7 +39,10 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit override def children: Seq[Expression] = child :: Nil - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function average") override def nullable: Boolean = true @@ -50,13 +52,11 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit private lazy val resultType = child.dataType match { case DecimalType.Fixed(p, s) => DecimalType.bounded(p + 4, s + 4) - case interval: CalendarIntervalType => interval case _ => DoubleType } private lazy val sumDataType = child.dataType match { case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s) - case interval: CalendarIntervalType => interval case _ => DoubleType } @@ -79,9 +79,6 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit override lazy val evaluateExpression = child.dataType match { case _: DecimalType => DecimalPrecision.decimalAndDecimal(sum / count.cast(DecimalType.LongDecimal)).cast(resultType) - case CalendarIntervalType => - val newCount = If(EqualTo(count, Literal(0L)), Literal(null, LongType), count) - DivideInterval(sum.cast(resultType), newCount.cast(DoubleType)) case _ => sum.cast(resultType) / count.cast(resultType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 87f1a4f02e4fc..8bfd889ea0563 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -// scalastyle:off line.size.limit @ExpressionDescription( usage = "_FUNC_(expr) - Returns the sum calculated from values of a group.", examples = """ @@ -33,11 +34,8 @@ import org.apache.spark.sql.types._ 25 > SELECT _FUNC_(col) FROM VALUES (NULL), (NULL) AS tab(col); NULL - > SELECT _FUNC_(cast(col as interval)) FROM VALUES ('1 seconds'), ('2 seconds'), (null) tab(col); - 3 seconds """, since = "1.0.0") -// scalastyle:on line.size.limit case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { override def children: Seq[Expression] = child :: Nil @@ -47,12 +45,14 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast // Return data type. override def dataType: DataType = resultType - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function sum") private lazy val resultType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType.bounded(precision + 10, scale) - case _: CalendarIntervalType => CalendarIntervalType case _: IntegralType => LongType case _ => DoubleType } @@ -61,7 +61,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast private lazy val sum = AttributeReference("sum", sumDataType)() - private lazy val zero = Literal.default(resultType) + private lazy val zero = Literal.default(sumDataType) override lazy val aggBufferAttributes = sum :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 86a1f1fb58a07..46634c93148b1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -158,8 +158,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(Min(Symbol("mapField")), "min does not support ordering on type") assertError(Max(Symbol("mapField")), "max does not support ordering on type") - assertError(Sum(Symbol("booleanField")), "requires (numeric or interval) type") - assertError(Average(Symbol("booleanField")), "requires (numeric or interval) type") + assertError(Sum(Symbol("booleanField")), "function sum requires numeric type") + assertError(Average(Symbol("booleanField")), "function average requires numeric type") } test("check types for others") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/interval.sql b/sql/core/src/test/resources/sql-tests/inputs/interval.sql index a4e621e9639d4..facd6321a1bd4 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/interval.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/interval.sql @@ -84,70 +84,6 @@ select interval (-30) day; select interval (a + 1) day; select interval 30 day day day; --- sum interval values --- null -select sum(cast(null as interval)); - --- empty set -select sum(cast(v as interval)) from VALUES ('1 seconds') t(v) where 1=0; - --- basic interval sum -select sum(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v); -select sum(cast(v as interval)) from VALUES ('-1 seconds'), ('2 seconds'), (null) t(v); -select sum(cast(v as interval)) from VALUES ('-1 seconds'), ('-2 seconds'), (null) t(v); -select sum(cast(v as interval)) from VALUES ('-1 weeks'), ('2 seconds'), (null) t(v); - --- group by -select - i, - sum(cast(v as interval)) -from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v) -group by i; - --- having -select - sum(cast(v as interval)) as sv -from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v) -having sv is not null; - --- window -SELECT - i, - sum(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) -FROM VALUES(1, '1 seconds'), (1, '2 seconds'), (2, NULL), (2, NULL) t(i,v); - --- average with interval type --- null -select avg(cast(v as interval)) from VALUES (null) t(v); - --- empty set -select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where 1=0; - --- basic interval avg -select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v); -select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('2 seconds'), (null) t(v); -select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('-2 seconds'), (null) t(v); -select avg(cast(v as interval)) from VALUES ('-1 weeks'), ('2 seconds'), (null) t(v); - --- group by -select - i, - avg(cast(v as interval)) -from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v) -group by i; - --- having -select - avg(cast(v as interval)) as sv -from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v) -having sv is not null; - --- window -SELECT - i, - avg(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) -FROM VALUES (1,'1 seconds'), (1,'2 seconds'), (2,NULL), (2,NULL) t(i,v); - -- Interval year-month arithmetic create temporary view interval_arithmetic as diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index 7fdb4c53d1dcb..4b465406ee0ed 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 99 +-- Number of queries: 81 -- !query @@ -631,180 +631,6 @@ select interval 30 day day day -----------------------^^^ --- !query -select sum(cast(null as interval)) --- !query schema -struct --- !query output -NULL - - --- !query -select sum(cast(v as interval)) from VALUES ('1 seconds') t(v) where 1=0 --- !query schema -struct --- !query output -NULL - - --- !query -select sum(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) --- !query schema -struct --- !query output -3 seconds - - --- !query -select sum(cast(v as interval)) from VALUES ('-1 seconds'), ('2 seconds'), (null) t(v) --- !query schema -struct --- !query output -1 seconds - - --- !query -select sum(cast(v as interval)) from VALUES ('-1 seconds'), ('-2 seconds'), (null) t(v) --- !query schema -struct --- !query output --3 seconds - - --- !query -select sum(cast(v as interval)) from VALUES ('-1 weeks'), ('2 seconds'), (null) t(v) --- !query schema -struct --- !query output --7 days 2 seconds - - --- !query -select - i, - sum(cast(v as interval)) -from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v) -group by i --- !query schema -struct --- !query output -1 -2 days -2 2 seconds -3 NULL - - --- !query -select - sum(cast(v as interval)) as sv -from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v) -having sv is not null --- !query schema -struct --- !query output --2 days 2 seconds - - --- !query -SELECT - i, - sum(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) -FROM VALUES(1, '1 seconds'), (1, '2 seconds'), (2, NULL), (2, NULL) t(i,v) --- !query schema -struct --- !query output -1 2 seconds -1 3 seconds -2 NULL -2 NULL - - --- !query -select avg(cast(v as interval)) from VALUES (null) t(v) --- !query schema -struct --- !query output -NULL - - --- !query -select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where 1=0 --- !query schema -struct --- !query output -NULL - - --- !query -select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) --- !query schema -struct --- !query output -1.5 seconds - - --- !query -select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('2 seconds'), (null) t(v) --- !query schema -struct --- !query output -0.5 seconds - - --- !query -select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('-2 seconds'), (null) t(v) --- !query schema -struct --- !query output --1.5 seconds - - --- !query -select avg(cast(v as interval)) from VALUES ('-1 weeks'), ('2 seconds'), (null) t(v) --- !query schema -struct --- !query output --3 days -11 hours -59 minutes -59 seconds - - --- !query -select - i, - avg(cast(v as interval)) -from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v) -group by i --- !query schema -struct --- !query output -1 -1 days -2 2 seconds -3 NULL - - --- !query -select - avg(cast(v as interval)) as sv -from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v) -having sv is not null --- !query schema -struct --- !query output --15 hours -59 minutes -59.333333 seconds - - --- !query -SELECT - i, - avg(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) -FROM VALUES (1,'1 seconds'), (1,'2 seconds'), (2,NULL), (2,NULL) t(i,v) --- !query schema -struct --- !query output -1 1.5 seconds -1 2 seconds -2 NULL -2 NULL - - -- !query create temporary view interval_arithmetic as select CAST(dateval AS date), CAST(tsval AS timestamp) from values diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out index 3c4b4301d0025..0509594ac13d3 100644 --- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 99 +-- Number of queries: 81 -- !query @@ -615,180 +615,6 @@ select interval 30 day day day ---------------------------^^^ --- !query -select sum(cast(null as interval)) --- !query schema -struct --- !query output -NULL - - --- !query -select sum(cast(v as interval)) from VALUES ('1 seconds') t(v) where 1=0 --- !query schema -struct --- !query output -NULL - - --- !query -select sum(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) --- !query schema -struct --- !query output -3 seconds - - --- !query -select sum(cast(v as interval)) from VALUES ('-1 seconds'), ('2 seconds'), (null) t(v) --- !query schema -struct --- !query output -1 seconds - - --- !query -select sum(cast(v as interval)) from VALUES ('-1 seconds'), ('-2 seconds'), (null) t(v) --- !query schema -struct --- !query output --3 seconds - - --- !query -select sum(cast(v as interval)) from VALUES ('-1 weeks'), ('2 seconds'), (null) t(v) --- !query schema -struct --- !query output --7 days 2 seconds - - --- !query -select - i, - sum(cast(v as interval)) -from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v) -group by i --- !query schema -struct --- !query output -1 -2 days -2 2 seconds -3 NULL - - --- !query -select - sum(cast(v as interval)) as sv -from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v) -having sv is not null --- !query schema -struct --- !query output --2 days 2 seconds - - --- !query -SELECT - i, - sum(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) -FROM VALUES(1, '1 seconds'), (1, '2 seconds'), (2, NULL), (2, NULL) t(i,v) --- !query schema -struct --- !query output -1 2 seconds -1 3 seconds -2 NULL -2 NULL - - --- !query -select avg(cast(v as interval)) from VALUES (null) t(v) --- !query schema -struct --- !query output -NULL - - --- !query -select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where 1=0 --- !query schema -struct --- !query output -NULL - - --- !query -select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) --- !query schema -struct --- !query output -1.5 seconds - - --- !query -select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('2 seconds'), (null) t(v) --- !query schema -struct --- !query output -0.5 seconds - - --- !query -select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('-2 seconds'), (null) t(v) --- !query schema -struct --- !query output --1.5 seconds - - --- !query -select avg(cast(v as interval)) from VALUES ('-1 weeks'), ('2 seconds'), (null) t(v) --- !query schema -struct --- !query output --3 days -11 hours -59 minutes -59 seconds - - --- !query -select - i, - avg(cast(v as interval)) -from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v) -group by i --- !query schema -struct --- !query output -1 -1 days -2 2 seconds -3 NULL - - --- !query -select - avg(cast(v as interval)) as sv -from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v) -having sv is not null --- !query schema -struct --- !query output --15 hours -59 minutes -59.333333 seconds - - --- !query -SELECT - i, - avg(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) -FROM VALUES (1,'1 seconds'), (1,'2 seconds'), (2,NULL), (2,NULL) t(i,v) --- !query schema -struct --- !query output -1 1.5 seconds -1 2 seconds -2 NULL -2 NULL - - -- !query create temporary view interval_arithmetic as select CAST(dateval AS date), CAST(tsval AS timestamp) from values diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index d7df75fd0e2c3..288f3dac36621 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -957,17 +957,4 @@ class DataFrameAggregateSuite extends QueryTest assert(error.message.contains("function count_if requires boolean type")) } } - - test("calendar interval agg support hash aggregate") { - val df1 = Seq((1, "1 day"), (2, "2 day"), (3, "3 day"), (3, null)).toDF("a", "b") - val df2 = df1.select(avg($"b" cast CalendarIntervalType)) - checkAnswer(df2, Row(new CalendarInterval(0, 2, 0)) :: Nil) - assert(find(df2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined) - val df3 = df1.groupBy($"a").agg(avg($"b" cast CalendarIntervalType)) - checkAnswer(df3, - Row(1, new CalendarInterval(0, 1, 0)) :: - Row(2, new CalendarInterval(0, 2, 0)) :: - Row(3, new CalendarInterval(0, 3, 0)) :: Nil) - assert(find(df3.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined) - } }