@@ -57,7 +57,11 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
    case _ => DoubleType
  }

-  private lazy val sumDataType = resultType
+  private lazy val sumDataType = child.dataType match {
+    case LongType if SQLConf.get.getConf(SQLConf.SUM_DECIMAL_BUFFER_FOR_LONG) =>
+      DecimalType.BigIntDecimal
+    case _ => resultType
+  }

  private lazy val sum = AttributeReference("sum", sumDataType)()

@@ -90,9 +94,15 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
    )
  }

-  override lazy val evaluateExpression: Expression = resultType match {
-    case d: DecimalType => CheckOverflow(sum, d, SQLConf.get.decimalOperationsNullOnOverflow)
-    case _ => sum
+  override lazy val evaluateExpression: Expression = {
+    val res = sumDataType match {
+      case d: DecimalType => CheckOverflow(sum, d, SQLConf.get.decimalOperationsNullOnOverflow)
+      case _ => sum
+    }
+    if (sumDataType == resultType) {
+      res
+    } else {
+      Cast(res, resultType)
+    }
  }

}
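
Taken together, the two hunks above switch the sum buffer for LongType input to Decimal(38, 0) (DecimalType.BigIntDecimal) when the new flag is enabled, and cast the overflow-checked result back to the original result type at the end. Below is a minimal standalone sketch of that buffering idea in plain Scala, not the actual Catalyst expressions; the object and method names are illustrative only.

// Sketch: accumulate into an unbounded BigDecimal buffer (standing in for Decimal(38, 0)),
// then narrow back to Long at the end (standing in for the final Cast(res, resultType)).
object DecimalBufferedSumSketch {
  def sumLongs(values: Seq[Long]): Long = {
    val buffer = values.foldLeft(BigDecimal(0))((acc, v) => acc + BigDecimal(v))
    // The final narrowing can still fail if the total itself does not fit in a Long.
    if (!buffer.isValidLong) throw new ArithmeticException(s"sum $buffer does not fit in a Long")
    buffer.toLongExact
  }

  def main(args: Array[String]): Unit = {
    // 100 + Long.MaxValue overflows a Long temporarily, but -1000 brings the total back in range.
    println(sumLongs(Seq(100L, Long.MaxValue, -1000L)))  // Long.MaxValue - 900
  }
}
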
@@ -1814,6 +1814,16 @@ object SQLConf {
      .booleanConf
      .createWithDefault(false)

+  val SUM_DECIMAL_BUFFER_FOR_LONG =
+    buildConf("spark.sql.sum.decimalBufferForLong")
+      .doc("If it is set to true, sum of long uses decimal type for the buffer. When false " +
+        "(default), long is used for the buffer. If spark.sql.arithmeticOperations.failOnOverFlow" +
+        " is turned on, setting this config to true allows operations with temporary " +
+        "overflows to execute properly, instead of throwing the exception raised when this flag is false.")
+      .internal()
+      .booleanConf
+      .createWithDefault(false)
+
  val LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE =
    buildConf("spark.sql.legacy.parser.havingWithoutGroupByAsWhere")
      .internal()
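
For reference, a possible spark-shell session exercising the new flag; this is a sketch that assumes the patch is applied, uses the config keys referenced above, and expects the value asserted by the test added below.

// Enable overflow checking together with the decimal buffer for sum over longs.
spark.conf.set("spark.sql.arithmeticOperations.failOnOverFlow", "true")
spark.conf.set("spark.sql.sum.decimalBufferForLong", "true")

// Mirrors the data used in the new test: the running sum exceeds Long.MaxValue temporarily,
// but the final total fits in a Long, so no ArithmeticException is expected with the flag on.
val df = sc.parallelize(Seq(100L, Long.MaxValue, -1000L), 1).toDF("a")
df.selectExpr("sum(a)").show(false)  // expected: 9223372036854774907 (Long.MaxValue - 900)
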
@@ -21,6 +21,7 @@ import scala.util.Random

import org.scalatest.Matchers.the

+import org.apache.spark.SparkException
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -942,4 +943,20 @@ class DataFrameAggregateSuite extends QueryTest with SharedSparkSession {
      assert(error.message.contains("function count_if requires boolean type"))
    }
  }
+
+  test("SPARK-28610: temporary overflow on sum of long should not fail") {
+    withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true",
+      SQLConf.SUM_DECIMAL_BUFFER_FOR_LONG.key -> "true") {
+      val df = sparkContext.parallelize(Seq(100L, Long.MaxValue, -1000L), 1).toDF("a")
+      checkAnswer(df.select(sum($"a")), Row(Long.MaxValue - 900L))
+    }
+    withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true",
+      SQLConf.SUM_DECIMAL_BUFFER_FOR_LONG.key -> "false") {
+      val df = sparkContext.parallelize(Seq(100L, Long.MaxValue, -1000L), 1).toDF("a")
+      val e = intercept[SparkException] {
+        df.select(sum($"a")).collect()
+      }
+      assert(e.getCause.isInstanceOf[ArithmeticException])
+    }
+  }
}