diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index d04fe9249d06..758133368e52 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -57,7 +57,15 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
     case _ => DoubleType
   }
 
-  private lazy val sumDataType = resultType
+  // Type of the aggregation buffer. For LongType input, a decimal buffer is used when
+  // SQLConf.SUM_DECIMAL_BUFFER_FOR_LONG is enabled, so that intermediate sums which
+  // temporarily exceed the long range do not overflow; otherwise the buffer type is the
+  // same as the result type.
+  private lazy val sumDataType = child.dataType match {
+    case LongType if SQLConf.get.getConf(SQLConf.SUM_DECIMAL_BUFFER_FOR_LONG) =>
+      DecimalType.BigIntDecimal
+    case _ => resultType
+  }
 
   private lazy val sum = AttributeReference("sum", sumDataType)()
 
@@ -90,9 +98,19 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
     )
   }
 
-  override lazy val evaluateExpression: Expression = resultType match {
-    case d: DecimalType => CheckOverflow(sum, d, SQLConf.get.decimalOperationsNullOnOverflow)
-    case _ => sum
+  // For a decimal buffer the sum is wrapped in CheckOverflow against the buffer type. When
+  // the buffer type differs from the result type, the value is cast back to resultType.
+  // NOTE(review): confirm that this Cast reports overflow (rather than wrapping) when the
+  // final sum does not fit resultType and failOnOverflow is enabled.
+  override lazy val evaluateExpression: Expression = {
+    val res = sumDataType match {
+      case d: DecimalType => CheckOverflow(sum, d, SQLConf.get.decimalOperationsNullOnOverflow)
+      case _ => sum
+    }
+    if (sumDataType == resultType) {
+      res
+    } else {
+      Cast(res, resultType)
+    }
   }
-
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index c5c94708aa7e..3a634b5bc99c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1814,6 +1814,16 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val SUM_DECIMAL_BUFFER_FOR_LONG =
+    buildConf("spark.sql.sum.decimalBufferForLong")
+      .doc("If it is set to true, sum of long uses decimal type for the buffer. When false " +
+        "(default), long is used for the buffer. If spark.sql.arithmeticOperations.failOnOverFlow" +
+        " is turned on, having this config set to true allows operations which have temporary " +
+        "overflows to execute properly without the exception thrown when this flag is false.")
+      .internal()
+      .booleanConf
+      .createWithDefault(false)
+
   val LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE =
     buildConf("spark.sql.legacy.parser.havingWithoutGroupByAsWhere")
       .internal()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index ec7b636c8f69..69b2c3bc0f38 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -21,6 +21,7 @@ import scala.util.Random
 
 import org.scalatest.Matchers.the
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -942,4 +943,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSparkSession {
       assert(error.message.contains("function count_if requires boolean type"))
     }
   }
+
+  test("SPARK-28610: temporary overflow on sum of long should not fail") {
+    // Single partition, so rows are summed in order: 100 + Long.MaxValue leaves the long
+    // range, and adding -1000 brings the total back in, i.e. the overflow is only temporary.
+    val df = sparkContext.parallelize(Seq(100L, Long.MaxValue, -1000L), 1).toDF("a")
+    // With the decimal buffer the temporary overflow is absorbed and the sum is returned.
+    withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true",
+      SQLConf.SUM_DECIMAL_BUFFER_FOR_LONG.key -> "true") {
+      checkAnswer(df.select(sum($"a")), Row(Long.MaxValue - 900L))
+    }
+    // With the long buffer the same input is expected to fail with an ArithmeticException.
+    withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true",
+      SQLConf.SUM_DECIMAL_BUFFER_FOR_LONG.key -> "false") {
+      val e = intercept[SparkException] {
+        df.select(sum($"a")).collect()
+      }
+      assert(e.getCause.isInstanceOf[ArithmeticException])
+    }
+  }
 }