From af78b62962decd44df231ff1947639de95231e82 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Sat, 3 Aug 2019 11:27:20 +0200
Subject: [PATCH 1/5] [SPARK-28610][SQL] Allow having a decimal buffer for long sum

---
 .../catalyst/expressions/aggregate/Sum.scala  | 15 +++++++++++++--
 .../apache/spark/sql/internal/SQLConf.scala   | 10 ++++++++++
 .../spark/sql/DataFrameAggregateSuite.scala   | 18 ++++++++++++++++++
 3 files changed, 41 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index ef204ec82c52..acde399b3407 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 @ExpressionDescription(
@@ -56,7 +57,11 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
     case _ => DoubleType
   }
 
-  private lazy val sumDataType = resultType
+  private lazy val sumDataType = child.dataType match {
+    case LongType if SQLConf.get.getConf(SQLConf.SUM_DECIMAL_BUFFER_FOR_LONG) =>
+      DecimalType.BigIntDecimal
+    case _ => resultType
+  }
 
   private lazy val sum = AttributeReference("sum", sumDataType)()
 
@@ -89,5 +94,11 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
     )
   }
 
-  override lazy val evaluateExpression: Expression = sum
+  override lazy val evaluateExpression: Expression = {
+    if (sumDataType == resultType) {
+      sum
+    } else {
+      Cast(sum, resultType)
+    }
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 2fede591fc80..40b2873600e7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1789,6 +1789,16 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val SUM_DECIMAL_BUFFER_FOR_LONG =
+    buildConf("spark.sql.sum.decimalBufferForLong")
+      .doc("When set to true, sum of long uses a decimal type for the buffer. When false " +
+        "(default), a long buffer is used. If spark.sql.arithmeticOperations.failOnOverFlow " +
+        "is turned on, setting this config to true lets operations with temporary overflows " +
+        "complete successfully, instead of throwing the exception raised when this flag is false.")
+      .internal()
+      .booleanConf
+      .createWithDefault(false)
+
   val LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE =
     buildConf("spark.sql.legacy.parser.havingWithoutGroupByAsWhere")
       .internal()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index e49ef012f5eb..27bbbfc3d5a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.SparkException
+
 import scala.util.Random
 
 import org.scalatest.Matchers.the
@@ -927,4 +929,20 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       assert(error.message.contains("function count_if requires boolean type"))
     }
   }
+
+  test("SPARK-28610: temporary overflow on sum of long should not fail") {
+    withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true",
+      SQLConf.SUM_DECIMAL_BUFFER_FOR_LONG.key -> "true") {
+      val df = sparkContext.parallelize(Seq(100L, Long.MaxValue, -1000L), 1).toDF("a")
+      checkAnswer(df.select(sum($"a")), Row(Long.MaxValue - 900L))
+    }
+    withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true",
+      SQLConf.SUM_DECIMAL_BUFFER_FOR_LONG.key -> "false") {
+      val df = sparkContext.parallelize(Seq(100L, Long.MaxValue, -1000L), 1).toDF("a")
+      val e = intercept[SparkException] {
+        df.select(sum($"a")).collect()
+      }
+      assert(e.getCause.isInstanceOf[ArithmeticException])
+    }
+  }
 }

From e81e8fa657395e2819c0550d2ddd8eb9886da473 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Sun, 4 Aug 2019 10:08:24 +0200
Subject: [PATCH 2/5] fix scalastyle

---
 .../scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 27bbbfc3d5a4..0acc6e5c2051 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -17,12 +17,11 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.SparkException
-
 import scala.util.Random
 
 import org.scalatest.Matchers.the
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
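Note: the idea behind PATCH 1/5, sketched as self-contained plain Scala (illustrative names, not Spark internals). A long buffer fails as soon as any partial sum overflows, even when later rows bring the total back into range; an unbounded decimal buffer absorbs the temporary overflow, and the final value is cast back to long (mirroring Cast(sum, resultType) above) whenever the true total fits:

    object DecimalBufferSketch {
      // Long buffer: Math.addExact throws ArithmeticException on the first
      // overflowing partial sum, even if the overall total fits in a long.
      def sumWithLongBuffer(xs: Seq[Long]): Long =
        xs.foldLeft(0L)((acc, x) => Math.addExact(acc, x))

      // Decimal buffer: partial sums cannot overflow; only the final value is
      // converted back to long, like Cast(sum, resultType) in the patch.
      def sumWithDecimalBuffer(xs: Seq[Long]): Long =
        xs.foldLeft(BigDecimal(0))(_ + _).toLongExact

      def main(args: Array[String]): Unit = {
        val data = Seq(100L, Long.MaxValue, -1000L)  // same rows as the test above
        println(sumWithDecimalBuffer(data))          // 9223372036854774907 (Long.MaxValue - 900)
        // sumWithLongBuffer(data) throws at 100L + Long.MaxValue
      }
    }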
From 573851116a6325ee2be1648a5990af59c08687ab Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Thu, 8 Aug 2019 10:10:09 +0200
Subject: [PATCH 3/5] return decimal

---
 .../catalyst/expressions/aggregate/Sum.scala  | 16 +++-------
 .../apache/spark/sql/internal/SQLConf.scala   | 10 +++---
 .../spark/sql/DataFrameAggregateSuite.scala   | 32 ++++++++++++++++---
 3 files changed, 36 insertions(+), 22 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index acde399b3407..7881d4e69e8c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -53,15 +53,13 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
   private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
       DecimalType.bounded(precision + 10, scale)
+    case LongType if SQLConf.get.getConf(SQLConf.SUM_DECIMAL_RESULT_FOR_LONG) =>
+      DecimalType.BigIntDecimal
     case _: IntegralType => LongType
     case _ => DoubleType
   }
 
-  private lazy val sumDataType = child.dataType match {
-    case LongType if SQLConf.get.getConf(SQLConf.SUM_DECIMAL_BUFFER_FOR_LONG) =>
-      DecimalType.BigIntDecimal
-    case _ => resultType
-  }
+  private lazy val sumDataType = resultType
 
   private lazy val sum = AttributeReference("sum", sumDataType)()
 
@@ -94,11 +92,5 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
     )
   }
 
-  override lazy val evaluateExpression: Expression = {
-    if (sumDataType == resultType) {
-      sum
-    } else {
-      Cast(sum, resultType)
-    }
-  }
+  override lazy val evaluateExpression: Expression = sum
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 40b2873600e7..57d5d123e5b3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1789,12 +1789,10 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
-  val SUM_DECIMAL_BUFFER_FOR_LONG =
-    buildConf("spark.sql.sum.decimalBufferForLong")
-      .doc("When set to true, sum of long uses a decimal type for the buffer. When false " +
-        "(default), a long buffer is used. If spark.sql.arithmeticOperations.failOnOverFlow " +
-        "is turned on, setting this config to true lets operations with temporary overflows " +
-        "complete successfully, instead of throwing the exception raised when this flag is false.")
+  val SUM_DECIMAL_RESULT_FOR_LONG =
+    buildConf("spark.sql.sum.decimalResultForLong")
+      .doc("When set to true, sum of long returns a decimal type. When false (default), " +
+        "a long is returned.")
       .internal()
       .booleanConf
       .createWithDefault(false)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 0acc6e5c2051..3eac910efdda 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -929,19 +929,43 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     }
   }
 
-  test("SPARK-28610: temporary overflow on sum of long should not fail") {
-    withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true",
-      SQLConf.SUM_DECIMAL_BUFFER_FOR_LONG.key -> "true") {
+  test("SPARK-28610: overflow on sum of long should not fail with decimal") {
+    // Temporary overflow
+    withSQLConf(SQLConf.SUM_DECIMAL_RESULT_FOR_LONG.key -> "true") {
+      val df = sparkContext.parallelize(Seq(100L, Long.MaxValue, -1000L), 1).toDF("a")
+      checkAnswer(df.select(sum($"a")), Row(BigDecimal(Long.MaxValue - 900L)))
+    }
+    withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "false",
+      SQLConf.SUM_DECIMAL_RESULT_FOR_LONG.key -> "false") {
       val df = sparkContext.parallelize(Seq(100L, Long.MaxValue, -1000L), 1).toDF("a")
       checkAnswer(df.select(sum($"a")), Row(Long.MaxValue - 900L))
     }
     withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true",
-      SQLConf.SUM_DECIMAL_BUFFER_FOR_LONG.key -> "false") {
+      SQLConf.SUM_DECIMAL_RESULT_FOR_LONG.key -> "false") {
       val df = sparkContext.parallelize(Seq(100L, Long.MaxValue, -1000L), 1).toDF("a")
       val e = intercept[SparkException] {
         df.select(sum($"a")).collect()
       }
       assert(e.getCause.isInstanceOf[ArithmeticException])
     }
+    // Resulting overflow
+    withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "false",
+      SQLConf.SUM_DECIMAL_RESULT_FOR_LONG.key -> "false") {
+      val df = sparkContext.parallelize(Seq(100L, Long.MaxValue, 1000L), 1).toDF("a")
+      // wrong result
+      checkAnswer(df.select(sum($"a")), Row(Long.MinValue + 1099L))
+    }
+    withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true",
+      SQLConf.SUM_DECIMAL_RESULT_FOR_LONG.key -> "false") {
+      val df = sparkContext.parallelize(Seq(100L, Long.MaxValue, 1000L), 1).toDF("a")
+      val e = intercept[SparkException] {
+        df.select(sum($"a")).collect()
+      }
+      assert(e.getCause.isInstanceOf[ArithmeticException])
+    }
+    withSQLConf(SQLConf.SUM_DECIMAL_RESULT_FOR_LONG.key -> "true") {
+      val df = sparkContext.parallelize(Seq(100L, Long.MaxValue, 1000L), 1).toDF("a")
+      checkAnswer(df.select(sum($"a")), Row(BigDecimal("9223372036854776907")))
+    }
   }
 }
From aae4642670126fdadaf593c24ea14c2bf66ebae0 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Thu, 22 Aug 2019 10:24:18 +0200
Subject: [PATCH 4/5] Revert "return decimal"

This reverts commit 573851116a6325ee2be1648a5990af59c08687ab.
---
 .../catalyst/expressions/aggregate/Sum.scala  | 16 +++++++---
 .../apache/spark/sql/internal/SQLConf.scala   | 10 +++---
 .../spark/sql/DataFrameAggregateSuite.scala   | 32 +++----------------
 3 files changed, 22 insertions(+), 36 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 7881d4e69e8c..acde399b3407 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -53,13 +53,15 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
   private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
       DecimalType.bounded(precision + 10, scale)
-    case LongType if SQLConf.get.getConf(SQLConf.SUM_DECIMAL_RESULT_FOR_LONG) =>
-      DecimalType.BigIntDecimal
     case _: IntegralType => LongType
     case _ => DoubleType
   }
 
-  private lazy val sumDataType = resultType
+  private lazy val sumDataType = child.dataType match {
+    case LongType if SQLConf.get.getConf(SQLConf.SUM_DECIMAL_BUFFER_FOR_LONG) =>
+      DecimalType.BigIntDecimal
+    case _ => resultType
+  }
 
   private lazy val sum = AttributeReference("sum", sumDataType)()
 
@@ -92,5 +94,11 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
     )
   }
 
-  override lazy val evaluateExpression: Expression = sum
+  override lazy val evaluateExpression: Expression = {
+    if (sumDataType == resultType) {
+      sum
+    } else {
+      Cast(sum, resultType)
+    }
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 57d5d123e5b3..40b2873600e7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1789,10 +1789,12 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
-  val SUM_DECIMAL_RESULT_FOR_LONG =
-    buildConf("spark.sql.sum.decimalResultForLong")
-      .doc("When set to true, sum of long returns a decimal type. When false (default), " +
-        "a long is returned.")
+  val SUM_DECIMAL_BUFFER_FOR_LONG =
+    buildConf("spark.sql.sum.decimalBufferForLong")
+      .doc("When set to true, sum of long uses a decimal type for the buffer. When false " +
+        "(default), a long buffer is used. If spark.sql.arithmeticOperations.failOnOverFlow " +
+        "is turned on, setting this config to true lets operations with temporary overflows " +
+        "complete successfully, instead of throwing the exception raised when this flag is false.")
       .internal()
       .booleanConf
       .createWithDefault(false)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 3eac910efdda..0acc6e5c2051 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -929,43 +929,19 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     }
   }
 
-  test("SPARK-28610: overflow on sum of long should not fail with decimal") {
-    // Temporary overflow
-    withSQLConf(SQLConf.SUM_DECIMAL_RESULT_FOR_LONG.key -> "true") {
-      val df = sparkContext.parallelize(Seq(100L, Long.MaxValue, -1000L), 1).toDF("a")
-      checkAnswer(df.select(sum($"a")), Row(BigDecimal(Long.MaxValue - 900L)))
-    }
-    withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "false",
-      SQLConf.SUM_DECIMAL_RESULT_FOR_LONG.key -> "false") {
+  test("SPARK-28610: temporary overflow on sum of long should not fail") {
+    withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true",
+      SQLConf.SUM_DECIMAL_BUFFER_FOR_LONG.key -> "true") {
       val df = sparkContext.parallelize(Seq(100L, Long.MaxValue, -1000L), 1).toDF("a")
       checkAnswer(df.select(sum($"a")), Row(Long.MaxValue - 900L))
     }
     withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true",
-      SQLConf.SUM_DECIMAL_RESULT_FOR_LONG.key -> "false") {
+      SQLConf.SUM_DECIMAL_BUFFER_FOR_LONG.key -> "false") {
       val df = sparkContext.parallelize(Seq(100L, Long.MaxValue, -1000L), 1).toDF("a")
       val e = intercept[SparkException] {
         df.select(sum($"a")).collect()
       }
       assert(e.getCause.isInstanceOf[ArithmeticException])
     }
-    // Resulting overflow
-    withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "false",
-      SQLConf.SUM_DECIMAL_RESULT_FOR_LONG.key -> "false") {
-      val df = sparkContext.parallelize(Seq(100L, Long.MaxValue, 1000L), 1).toDF("a")
-      // wrong result
-      checkAnswer(df.select(sum($"a")), Row(Long.MinValue + 1099L))
-    }
-    withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true",
-      SQLConf.SUM_DECIMAL_RESULT_FOR_LONG.key -> "false") {
-      val df = sparkContext.parallelize(Seq(100L, Long.MaxValue, 1000L), 1).toDF("a")
-      val e = intercept[SparkException] {
-        df.select(sum($"a")).collect()
-      }
-      assert(e.getCause.isInstanceOf[ArithmeticException])
-    }
-    withSQLConf(SQLConf.SUM_DECIMAL_RESULT_FOR_LONG.key -> "true") {
-      val df = sparkContext.parallelize(Seq(100L, Long.MaxValue, 1000L), 1).toDF("a")
-      checkAnswer(df.select(sum($"a")), Row(BigDecimal("9223372036854776907")))
-    }
   }
 }
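Note: PATCH 5/5 below appears to land after a rebase (the Sum.scala base index changes from acde399b3407 to 815035bb8ac9), so its pre-image already contains a CheckOverflow computed from resultType that was then discarded: the raw `sum` was returned unchecked. The fix computes the check from the buffer type and actually returns it. Why the check matters, as a plain-Scala sketch (not the Catalyst expressions themselves):

    // Casting an out-of-range decimal buffer to long without a bounds check
    // silently truncates to the low 64 bits.
    val buffer = BigDecimal("9223372036854776907")  // Long.MaxValue + 1100
    val unchecked = buffer.toBigInt.longValue       // -9223372036854774709, silently wrong
    // A CheckOverflow-style guard (null-on-overflow behaviour) yields None instead:
    val inRange = buffer >= BigDecimal(Long.MinValue) && buffer <= BigDecimal(Long.MaxValue)
    val checked: Option[Long] = if (inRange) Some(buffer.toLongExact) else None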
From bd27f5806e96d72a347708d55c89d92f46ec1200 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Thu, 22 Aug 2019 17:28:58 +0200
Subject: [PATCH 5/5] fix

---
 .../spark/sql/catalyst/expressions/aggregate/Sum.scala | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 815035bb8ac9..758133368e52 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -95,14 +95,14 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
   }
 
   override lazy val evaluateExpression: Expression = {
-    val res = resultType match {
+    val res = sumDataType match {
       case d: DecimalType => CheckOverflow(sum, d, SQLConf.get.decimalOperationsNullOnOverflow)
       case _ => sum
     }
     if (sumDataType == resultType) {
-      sum
+      res
     } else {
-      Cast(sum, resultType)
+      Cast(res, resultType)
    }
  }
 }
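Note: end-to-end usage as exercised by the tests, sketched under the assumption of an existing SparkSession bound to a stable value `spark`; the config keys are the internal ones introduced in PATCH 1/5:

    import org.apache.spark.sql.functions.sum
    import spark.implicits._

    spark.conf.set("spark.sql.arithmeticOperations.failOnOverFlow", "true")
    spark.conf.set("spark.sql.sum.decimalBufferForLong", "true")

    val df = Seq(100L, Long.MaxValue, -1000L).toDF("a")
    // 100 + Long.MaxValue overflows a long buffer, but the decimal buffer absorbs
    // it; the final total fits in a long, so no exception is thrown.
    df.select(sum($"a")).show()  // 9223372036854774907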