diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index 21ac32adca6e..25303475a73c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -79,8 +79,8 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])
 private[sql] object TypeCollection {

   /**
-   * Types that include numeric types and interval type. They are only used in unary_minus,
-   * unary_positive, add and subtract operations.
+   * Types that include numeric types and interval type, which support numeric type calculations,
+   * i.e. unary_minus, unary_positive, sum, avg, min, max, add and subtract operations.
    */
   val NumericAndInterval = TypeCollection(NumericType, CalendarIntervalType)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index e1bca44dfccf..77a779a2f310 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -268,9 +268,9 @@ class Dataset[T] private[sql](
     }
   }

-  private[sql] def numericColumns: Seq[Expression] = {
-    schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
-      queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver).get
+  private[sql] def numericCalculationSupportedColumns: Seq[Expression] = {
+    queryExecution.analyzed.output.filter { attr =>
+      TypeCollection.NumericAndInterval.acceptsType(attr.dataType)
     }
   }

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index b1ba7d453873..52bd0ecb1fff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{NumericType, StructType}
+import org.apache.spark.sql.types.{StructType, TypeCollection}

 /**
  * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]],
@@ -88,20 +88,20 @@ class RelationalGroupedDataset protected[sql](
     case expr: Expression => Alias(expr, toPrettySQL(expr))()
   }

-  private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction)
-    : DataFrame = {
+  private[this] def aggregateNumericOrIntervalColumns(
+      colNames: String*)(f: Expression => AggregateFunction): DataFrame = {
     val columnExprs = if (colNames.isEmpty) {
-      // No columns specified. Use all numeric columns.
-      df.numericColumns
+      // No columns specified. Use all numeric calculation supported columns.
+      df.numericCalculationSupportedColumns
     } else {
-      // Make sure all specified columns are numeric.
+      // Make sure all specified columns are numeric calculation supported columns.
       colNames.map { colName =>
         val namedExpr = df.resolve(colName)
-        if (!namedExpr.dataType.isInstanceOf[NumericType]) {
+        if (!TypeCollection.NumericAndInterval.acceptsType(namedExpr.dataType)) {
           throw new AnalysisException(
-            s""""$colName" is not a numeric column. """ +
""" + - "Aggregation function can only be applied on a numeric column.") + s""""$colName" is not a numeric or calendar interval column. """ + + "Aggregation function can only be applied on a numeric or calendar interval column.") } namedExpr } @@ -269,7 +269,8 @@ class RelationalGroupedDataset protected[sql]( def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)).toAggregateExpression(), "count")())) /** - * Compute the average value for each numeric columns for each group. This is an alias for `avg`. + * Compute the average value for each numeric or calender interval columns for each group. This + * is an alias for `avg`. * The resulting `DataFrame` will also contain the grouping columns. * When specified columns are given, only compute the average values for them. * @@ -277,11 +278,11 @@ class RelationalGroupedDataset protected[sql]( */ @scala.annotation.varargs def mean(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Average) + aggregateNumericOrIntervalColumns(colNames : _*)(Average) } /** - * Compute the max value for each numeric columns for each group. + * Compute the max value for each numeric calender interval columns for each group. * The resulting `DataFrame` will also contain the grouping columns. * When specified columns are given, only compute the max values for them. * @@ -289,11 +290,11 @@ class RelationalGroupedDataset protected[sql]( */ @scala.annotation.varargs def max(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Max) + aggregateNumericOrIntervalColumns(colNames : _*)(Max) } /** - * Compute the mean value for each numeric columns for each group. + * Compute the mean value for each numeric calender interval columns for each group. * The resulting `DataFrame` will also contain the grouping columns. * When specified columns are given, only compute the mean values for them. * @@ -301,11 +302,11 @@ class RelationalGroupedDataset protected[sql]( */ @scala.annotation.varargs def avg(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Average) + aggregateNumericOrIntervalColumns(colNames : _*)(Average) } /** - * Compute the min value for each numeric column for each group. + * Compute the min value for each numeric calender interval column for each group. * The resulting `DataFrame` will also contain the grouping columns. * When specified columns are given, only compute the min values for them. * @@ -313,11 +314,11 @@ class RelationalGroupedDataset protected[sql]( */ @scala.annotation.varargs def min(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Min) + aggregateNumericOrIntervalColumns(colNames : _*)(Min) } /** - * Compute the sum for each numeric columns for each group. + * Compute the sum for each numeric calender interval columns for each group. * The resulting `DataFrame` will also contain the grouping columns. * When specified columns are given, only compute the sum for them. 
    *
@@ -325,7 +326,7 @@ class RelationalGroupedDataset protected[sql](
    */
   @scala.annotation.varargs
   def sum(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Sum)
+    aggregateNumericOrIntervalColumns(colNames : _*)(Sum)
   }

   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 944e4212b1bf..a08ef19c3ac5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -964,4 +964,24 @@ class DataFrameAggregateSuite extends QueryTest with SharedSparkSession {
       Row(3, new CalendarInterval(0, 3, 0)) :: Nil)
     assert(df3.queryExecution.executedPlan.find(_.isInstanceOf[HashAggregateExec]).isDefined)
   }
+
+  test("Dataset agg functions support calendar intervals") {
+    val df1 = Seq((1, "1 day"), (2, "2 day"), (3, "3 day"), (3, null)).toDF("a", "b")
+    val df2 = df1.select('a, 'b cast CalendarIntervalType).groupBy('a % 2)
+    checkAnswer(df2.sum("b"),
+      Row(0, new CalendarInterval(0, 2, 0)) ::
+        Row(1, new CalendarInterval(0, 4, 0)) :: Nil)
+    checkAnswer(df2.avg("b"),
+      Row(0, new CalendarInterval(0, 2, 0)) ::
+        Row(1, new CalendarInterval(0, 2, 0)) :: Nil)
+    checkAnswer(df2.mean("b"),
+      Row(0, new CalendarInterval(0, 2, 0)) ::
+        Row(1, new CalendarInterval(0, 2, 0)) :: Nil)
+    checkAnswer(df2.max("b"),
+      Row(0, new CalendarInterval(0, 2, 0)) ::
+        Row(1, new CalendarInterval(0, 3, 0)) :: Nil)
+    checkAnswer(df2.min("b"),
+      Row(0, new CalendarInterval(0, 2, 0)) ::
+        Row(1, new CalendarInterval(0, 1, 0)) :: Nil)
+  }
 }
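
Note: the snippet below is not part of the patch. It is a minimal, self-contained sketch of how the extended aggregation API could be exercised once this change is applied; the SparkSession setup, the object name IntervalAggregationExample, and the column names "a" and "b" are illustrative assumptions, not code from this PR.

// Hypothetical usage sketch (assumes this patch is applied; names are illustrative).
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.CalendarIntervalType

object IntervalAggregationExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("interval-agg").getOrCreate()
    import spark.implicits._

    // Build a DataFrame with an integer key and a CalendarIntervalType column,
    // mirroring the setup used in the new DataFrameAggregateSuite test.
    val df = Seq((1, "1 day"), (2, "2 day"), (3, "3 day"))
      .toDF("a", "b")
      .select($"a", $"b".cast(CalendarIntervalType).as("b"))

    // With this change, sum/avg/mean/max/min accept the interval column;
    // previously these calls failed with "... is not a numeric column".
    df.groupBy($"a" % 2).sum("b").show()
    df.groupBy($"a" % 2).avg("b").show()
    df.groupBy($"a" % 2).max("b").show()
    df.groupBy($"a" % 2).min("b").show()

    spark.stop()
  }
}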