From edb32b91cfd1d8e249b43ed6dc06810de702b3df Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Wed, 12 Feb 2020 19:04:17 -0800 Subject: [PATCH 01/15] Fix the incorrect results for sum for decimal overflow, support reporting null for it when ansiEnabled is false and throw exception if ansiEnabled is true --- .../catalyst/expressions/aggregate/Sum.scala | 90 ++++++++++++++++--- .../expressions/decimalExpressions.scala | 28 ++++++ 2 files changed, 106 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index d2daaac72fc85..5f689ff0312f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -61,38 +61,104 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast private lazy val sumDataType = resultType private lazy val sum = AttributeReference("sum", sumDataType)() + private lazy val overflow = AttributeReference("overflow", BooleanType)() private lazy val zero = Literal.default(sumDataType) - override lazy val aggBufferAttributes = sum :: Nil + override lazy val aggBufferAttributes = sum :: overflow :: Nil override lazy val initialValues: Seq[Expression] = Seq( - /* sum = */ Literal.create(null, sumDataType) + /* sum = */ Literal.create(null, sumDataType), + /* overflow = */ Literal.create(false, BooleanType) ) override lazy val updateExpressions: Seq[Expression] = { if (child.nullable) { - Seq( - /* sum = */ - coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) - ) + if (!SQLConf.get.ansiEnabled) { + Seq( + /* sum = */ + resultType match { + case d: DecimalType => coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) + case _ => coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) + }, + /* 
overflow = */ + resultType match { + case d: DecimalType => + If (overflow, true, HasOverflow(coalesce(sum, zero) + child.cast(sumDataType), d)) + case _ => If(overflow, true, false) + }) + } else { + Seq( + /* sum = */ + resultType match { + case d: DecimalType => coalesce( + CheckOverflow( + coalesce(sum, zero) + child.cast(sumDataType), d, !SQLConf.get.ansiEnabled), sum) + case _ => coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) + }, + /* overflow = */ + false + ) + } } else { - Seq( - /* sum = */ - coalesce(sum, zero) + child.cast(sumDataType) - ) + if (!SQLConf.get.ansiEnabled) { + Seq( + /* sum = */ + resultType match { + case d: DecimalType => coalesce(sum, zero) + child.cast(sumDataType) + case _ => coalesce(sum, zero) + child.cast(sumDataType) + }, + /* overflow = */ + resultType match { + case d: DecimalType => + If(overflow, true, HasOverflow(coalesce(sum, zero) + child.cast(sumDataType), d)) + case _ => If(overflow, true, false) + }) + } else { + Seq( + /* sum = */ + resultType match { + case d: DecimalType => coalesce( + CheckOverflow( + coalesce(sum, zero) + child.cast(sumDataType), d, !SQLConf.get.ansiEnabled), sum) + case _ => coalesce(sum, zero) + child.cast(sumDataType) + }, + /* overflow = */ + false + ) + } } } override lazy val mergeExpressions: Seq[Expression] = { Seq( /* sum = */ - coalesce(coalesce(sum.left, zero) + sum.right, sum.left) + resultType match { + case d: DecimalType => + if (!SQLConf.get.ansiEnabled) { + coalesce(coalesce(sum.left, zero) + sum.right, sum.left) + } else { + coalesce(CheckOverflow( + coalesce(sum.left, zero) + sum.right, d, !SQLConf.get.ansiEnabled), sum.left) + } + case _ => coalesce(coalesce(sum.left, zero) + sum.right, sum.left) + }, + /* overflow = */ + resultType match { + case d: DecimalType => + if (!SQLConf.get.ansiEnabled) { + If(overflow.left || overflow.right, + true, HasOverflow(coalesce(sum.left, zero) + sum.right, d)) + } else { + If(overflow.left || overflow.right, true, false) + } + } 
) } override lazy val evaluateExpression: Expression = resultType match { - case d: DecimalType => CheckOverflow(sum, d, !SQLConf.get.ansiEnabled) + case d: DecimalType => If(overflow && !SQLConf.get.ansiEnabled, + Literal.create(null, sumDataType) , sum) case _ => sum } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 9014ebfe2f96a..ce00c0fa3228a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -144,3 +144,31 @@ case class CheckOverflow( override def sql: String = child.sql } + +case class HasOverflow( + child: Expression, + inputType: DecimalType) extends UnaryExpression { + + override def dataType: DataType= BooleanType + + override def nullable: Boolean = false + + override def nullSafeEval(input: Any): Any = + input.asInstanceOf[Decimal].changePrecision( + inputType.precision, + inputType.scale, + Decimal.ROUND_HALF_UP) + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, eval => { + s""" + |${ev.value} = $eval.changePrecision( + | ${inputType.precision}, ${inputType.scale}, Decimal.ROUND_HALF_UP()); + """.stripMargin + }) + } + + override def toString: String = s"HasOverflow($child, $inputType)" + + override def sql: String = child.sql +} From 8fe9aa4ca6d709e520b3862969f15da6dc18fe3e Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Wed, 12 Feb 2020 19:48:44 -0800 Subject: [PATCH 02/15] Add the jira testcase --- .../org/apache/spark/sql/DataFrameSuite.scala | 47 +++++++++++-------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala 
index f20e684bf7657..7658a7a73122c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -213,25 +213,34 @@ class DataFrameSuite extends QueryTest } } - test("Star Expansion - ds.explode should fail with a meaningful message if it takes a star") { - val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv") - val e = intercept[AnalysisException] { - df.explode($"*") { case Row(prefix: String, csv: String) => - csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq - }.queryExecution.assertAnalyzed() - } - assert(e.getMessage.contains("Invalid usage of '*' in explode/json_tuple/UDTF")) - - checkAnswer( - df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) => - csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq - }, - Row("1", "1,2", "1:1") :: - Row("1", "1,2", "1:2") :: - Row("2", "4", "2:4") :: - Row("3", "7,8,9", "3:7") :: - Row("3", "7,8,9", "3:8") :: - Row("3", "7,8,9", "3:9") :: Nil) + test("SPARK-28067 - Aggregate sum should not return wrong results for decimal overflow") { + Seq(true, false).foreach { ansiEnabled => + withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) { + val df = Seq( + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum") + val df2 = df.withColumnRenamed("decNum", "decNum2").join(df, "intNum").agg(sum("decNum")) + if (!ansiEnabled) { + checkAnswer(df2, Row(null)) + } else { + val e = intercept[SparkException] { + 
df2.collect() + } + assert(e.getCause.getClass.equals(classOf[ArithmeticException])) + assert(e.getCause.getMessage.contains("cannot be represented as Decimal")) + } + } + } } test("Star Expansion - explode should fail with a meaningful message if it takes a star") { From 1adc512b41647f7b788e2af5a3c6659dac34439d Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Wed, 12 Feb 2020 19:49:19 -0800 Subject: [PATCH 03/15] Change formatting --- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 7658a7a73122c..4e88c7e6bb69a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -213,7 +213,7 @@ class DataFrameSuite extends QueryTest } } - test("SPARK-28067 - Aggregate sum should not return wrong results for decimal overflow") { + test("SPARK-28067: Aggregate sum should not return wrong results for decimal overflow") { Seq(true, false).foreach { ansiEnabled => withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) { val df = Seq( From d90c790943d3aa49063e1636eeabd6b298ff2a0a Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Mon, 17 Feb 2020 17:17:56 -0800 Subject: [PATCH 04/15] hasoverflow changes --- .../catalyst/expressions/aggregate/Sum.scala | 102 +++++++++--------- .../expressions/decimalExpressions.scala | 9 +- 2 files changed, 56 insertions(+), 55 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 5f689ff0312f0..1ae30f354480d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -61,7 +61,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast private lazy val sumDataType = resultType private lazy val sum = AttributeReference("sum", sumDataType)() - private lazy val overflow = AttributeReference("overflow", BooleanType)() + private lazy val overflow = AttributeReference("overflow", BooleanType, false)() private lazy val zero = Literal.default(sumDataType) @@ -73,21 +73,31 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast ) override lazy val updateExpressions: Seq[Expression] = { - if (child.nullable) { - if (!SQLConf.get.ansiEnabled) { + if (!SQLConf.get.ansiEnabled) { + if (child.nullable) { Seq( /* sum = */ - resultType match { - case d: DecimalType => coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) - case _ => coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) - }, + coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum), /* overflow = */ resultType match { case d: DecimalType => - If (overflow, true, HasOverflow(coalesce(sum, zero) + child.cast(sumDataType), d)) - case _ => If(overflow, true, false) + If(overflow, true, + HasOverflow(coalesce(sum, zero) + child.cast(sumDataType), d)) + case _ => false }) } else { + Seq( + /* sum = */ + coalesce(sum, zero) + child.cast(sumDataType), + /* overflow = */ + resultType match { + case d: DecimalType => + If(overflow, true, HasOverflow(coalesce(sum, zero) + child.cast(sumDataType), d)) + case _ => false + }) + } + } else { + if (child.nullable) { Seq( /* sum = */ resultType match { @@ -97,30 +107,16 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast case _ => coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) }, /* overflow = */ + // overflow flag doesnt need any updates since CheckOverflow will throw exception + // if overflow happens false ) - } - } else { - if 
(!SQLConf.get.ansiEnabled) { - Seq( - /* sum = */ - resultType match { - case d: DecimalType => coalesce(sum, zero) + child.cast(sumDataType) - case _ => coalesce(sum, zero) + child.cast(sumDataType) - }, - /* overflow = */ - resultType match { - case d: DecimalType => - If(overflow, true, HasOverflow(coalesce(sum, zero) + child.cast(sumDataType), d)) - case _ => If(overflow, true, false) - }) } else { Seq( /* sum = */ resultType match { - case d: DecimalType => coalesce( - CheckOverflow( - coalesce(sum, zero) + child.cast(sumDataType), d, !SQLConf.get.ansiEnabled), sum) + case d: DecimalType => CheckOverflow( + coalesce(sum, zero) + child.cast(sumDataType), d, !SQLConf.get.ansiEnabled) case _ => coalesce(sum, zero) + child.cast(sumDataType) }, /* overflow = */ @@ -131,34 +127,38 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast } override lazy val mergeExpressions: Seq[Expression] = { - Seq( - /* sum = */ - resultType match { - case d: DecimalType => - if (!SQLConf.get.ansiEnabled) { - coalesce(coalesce(sum.left, zero) + sum.right, sum.left) - } else { - coalesce(CheckOverflow( - coalesce(sum.left, zero) + sum.right, d, !SQLConf.get.ansiEnabled), sum.left) - } - case _ => coalesce(coalesce(sum.left, zero) + sum.right, sum.left) - }, - /* overflow = */ - resultType match { - case d: DecimalType => - if (!SQLConf.get.ansiEnabled) { - If(overflow.left || overflow.right, + if (!SQLConf.get.ansiEnabled) { + Seq( + /* sum = */ + coalesce(coalesce(sum.left, zero) + sum.right, sum.left), + /* overflow = */ + resultType match { + case d: DecimalType => + If(coalesce(overflow.left, false) || coalesce(overflow.right, false), true, HasOverflow(coalesce(sum.left, zero) + sum.right, d)) - } else { - If(overflow.left || overflow.right, true, false) - } - } - ) + case _ => + If(coalesce(overflow.left, false) || coalesce(overflow.right, false), true, false) + } + ) + } else { + Seq( + /* sum = */ + resultType match { + case d: DecimalType => + 
coalesce( + CheckOverflow(coalesce(sum.left, zero) + sum.right, d, !SQLConf.get.ansiEnabled), + sum.left) + case _ => coalesce(coalesce(sum.left, zero) + sum.right, sum.left) + }, + /* overflow = */ + If(coalesce(overflow.left, false) || coalesce(overflow.right, false), true, false) + ) + } } override lazy val evaluateExpression: Expression = resultType match { - case d: DecimalType => If(overflow && !SQLConf.get.ansiEnabled, - Literal.create(null, sumDataType) , sum) + case d: DecimalType => + If(overflow && !SQLConf.get.ansiEnabled, Literal.create(null, resultType), sum) case _ => sum } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index ce00c0fa3228a..c731e1d8c56b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -149,12 +149,12 @@ case class HasOverflow( child: Expression, inputType: DecimalType) extends UnaryExpression { - override def dataType: DataType= BooleanType + override def dataType: DataType = BooleanType - override def nullable: Boolean = false + override def nullable: Boolean = true override def nullSafeEval(input: Any): Any = - input.asInstanceOf[Decimal].changePrecision( + !input.asInstanceOf[Decimal].changePrecision( inputType.precision, inputType.scale, Decimal.ROUND_HALF_UP) @@ -162,8 +162,9 @@ case class HasOverflow( override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { s""" - |${ev.value} = $eval.changePrecision( + |${ev.value} = !$eval.changePrecision( | ${inputType.precision}, ${inputType.scale}, Decimal.ROUND_HALF_UP()); + | ${ev.isNull} = false; """.stripMargin }) } From 136e6dcee64a00e7c42c0a2e0fe2e1e84a258edb Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati 
Date: Mon, 2 Mar 2020 19:07:24 -0800 Subject: [PATCH 05/15] Fix formatting --- .../spark/sql/catalyst/expressions/decimalExpressions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index c731e1d8c56b5..c817ed813a1e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -164,7 +164,7 @@ case class HasOverflow( s""" |${ev.value} = !$eval.changePrecision( | ${inputType.precision}, ${inputType.scale}, Decimal.ROUND_HALF_UP()); - | ${ev.isNull} = false; + |${ev.isNull} = false; """.stripMargin }) } From 6979e8dceed4ad632ee8114ccf4561ed63f35d19 Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Tue, 7 Apr 2020 17:47:42 -0700 Subject: [PATCH 06/15] Add a new implementation for sum for dec type to handle overflow, Add a new DecimalSum and add a analyzer rule and changes to optimizer rule. 
Add tests For now the DecimalSum has the code to handle other types, incase we want to put this logic back into Sum after discussion --- .../sql/catalyst/analysis/Analyzer.scala | 10 ++ .../expressions/aggregate/DecimalSum.scala | 110 +++++++++++++++++ .../catalyst/expressions/aggregate/Sum.scala | 92 ++------------- .../expressions/decimalExpressions.scala | 23 +++- .../sql/catalyst/optimizer/Optimizer.scala | 4 +- .../org/apache/spark/sql/types/Decimal.scala | 5 + .../org/apache/spark/sql/DataFrameSuite.scala | 111 +++++++++++++----- 7 files changed, 243 insertions(+), 112 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DecimalSum.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7a2b4e63e133a..8ca2d869203e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -240,6 +240,7 @@ class Analyzer( ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: + UseDecimalSum :: TimeWindowing :: ResolveInlineTables(conf) :: ResolveHigherOrderFunctions(v1SessionCatalog) :: @@ -3066,6 +3067,15 @@ class Analyzer( } } + /** + * Substitute Sum on decimal type to use DecimalSum implementation as it handles overflow + */ + object UseDecimalSum extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { + case Sum(e) if e.resolved && e.dataType.isInstanceOf[DecimalType] => DecimalSum(e) + } + } + /** Rule to mostly resolve, normalize and rewrite column names based on case sensitivity. 
*/ object ResolveAlterTableChanges extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DecimalSum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DecimalSum.scala new file mode 100644 index 0000000000000..2803a10457cd2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DecimalSum.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +case class DecimalSum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. 
+ override def dataType: DataType = resultType + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + + private lazy val resultType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType.bounded(precision + 10, scale) + case _: IntegralType => LongType + case _ => DoubleType + } + + private lazy val sumDataType = resultType + + private lazy val sum = AttributeReference("sum", sumDataType)() + private lazy val overflow = AttributeReference("overflow", BooleanType, false)() + + private lazy val zero = Literal.default(resultType) + + override lazy val aggBufferAttributes = sum :: overflow :: Nil + + override lazy val initialValues: Seq[Expression] = Seq( + /* sum = */ Literal.create(null, sumDataType), + /* overflow = */ Literal.create(false, BooleanType) + ) + + override lazy val updateExpressions: Seq[Expression] = { + if (child.nullable) { + resultType match { + case d: DecimalType => + Seq( + If(overflow, Literal.create(null, sumDataType), coalesce( + CheckOverflow(coalesce(sum, zero) + child.cast(sumDataType), d, true), sum)), + overflow || + coalesce(HasOverflow(coalesce(sum, zero) + child.cast(sumDataType), d), false) + ) + case _ => Seq(coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum), false) + } + } else { + resultType match { + case d: DecimalType => + Seq( + If(overflow, Literal.create(null, sumDataType), + CheckOverflow(coalesce(sum, zero) + child.cast(sumDataType), d, true)), + overflow || + coalesce(HasOverflow(coalesce(sum, zero) + child.cast(sumDataType), d), false) + ) + case _ => Seq(coalesce(sum, zero) + child.cast(sumDataType), false) + } + } + } + + override lazy val mergeExpressions: Seq[Expression] = { + resultType match { + case d: DecimalType => + Seq( + If(coalesce(overflow.left, false) || coalesce(overflow.right, false), + Literal.create(null, d), + coalesce(CheckOverflow(coalesce(sum.left, zero) + sum.right, d, true), sum.left)), + If(coalesce(overflow.left, false) || 
coalesce(overflow.right, false), + true, HasOverflow(coalesce(sum.left, zero) + sum.right, d)) + ) + case _ => + Seq( + coalesce(coalesce(sum.left, zero) + sum.right, sum.left), + If(coalesce(overflow.left, false) || coalesce(overflow.right, false), true, false) + ) + } + } + + override lazy val evaluateExpression: Expression = resultType match { + case d: DecimalType => + If(EqualTo(overflow, true), + If(!SQLConf.get.ansiEnabled, + Literal.create(null, sumDataType), + OverflowException(resultType, "Arithmetic Operation overflow")), + sum) + case _ => sum + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 1ae30f354480d..d2daaac72fc85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -61,104 +61,38 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast private lazy val sumDataType = resultType private lazy val sum = AttributeReference("sum", sumDataType)() - private lazy val overflow = AttributeReference("overflow", BooleanType, false)() private lazy val zero = Literal.default(sumDataType) - override lazy val aggBufferAttributes = sum :: overflow :: Nil + override lazy val aggBufferAttributes = sum :: Nil override lazy val initialValues: Seq[Expression] = Seq( - /* sum = */ Literal.create(null, sumDataType), - /* overflow = */ Literal.create(false, BooleanType) + /* sum = */ Literal.create(null, sumDataType) ) override lazy val updateExpressions: Seq[Expression] = { - if (!SQLConf.get.ansiEnabled) { - if (child.nullable) { - Seq( - /* sum = */ - coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum), - /* overflow = */ - resultType match { - case d: DecimalType => - If(overflow, true, - HasOverflow(coalesce(sum, zero) + 
child.cast(sumDataType), d)) - case _ => false - }) - } else { - Seq( - /* sum = */ - coalesce(sum, zero) + child.cast(sumDataType), - /* overflow = */ - resultType match { - case d: DecimalType => - If(overflow, true, HasOverflow(coalesce(sum, zero) + child.cast(sumDataType), d)) - case _ => false - }) - } - } else { - if (child.nullable) { - Seq( - /* sum = */ - resultType match { - case d: DecimalType => coalesce( - CheckOverflow( - coalesce(sum, zero) + child.cast(sumDataType), d, !SQLConf.get.ansiEnabled), sum) - case _ => coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) - }, - /* overflow = */ - // overflow flag doesnt need any updates since CheckOverflow will throw exception - // if overflow happens - false - ) - } else { - Seq( - /* sum = */ - resultType match { - case d: DecimalType => CheckOverflow( - coalesce(sum, zero) + child.cast(sumDataType), d, !SQLConf.get.ansiEnabled) - case _ => coalesce(sum, zero) + child.cast(sumDataType) - }, - /* overflow = */ - false - ) - } - } - } - - override lazy val mergeExpressions: Seq[Expression] = { - if (!SQLConf.get.ansiEnabled) { + if (child.nullable) { Seq( /* sum = */ - coalesce(coalesce(sum.left, zero) + sum.right, sum.left), - /* overflow = */ - resultType match { - case d: DecimalType => - If(coalesce(overflow.left, false) || coalesce(overflow.right, false), - true, HasOverflow(coalesce(sum.left, zero) + sum.right, d)) - case _ => - If(coalesce(overflow.left, false) || coalesce(overflow.right, false), true, false) - } + coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) ) } else { Seq( /* sum = */ - resultType match { - case d: DecimalType => - coalesce( - CheckOverflow(coalesce(sum.left, zero) + sum.right, d, !SQLConf.get.ansiEnabled), - sum.left) - case _ => coalesce(coalesce(sum.left, zero) + sum.right, sum.left) - }, - /* overflow = */ - If(coalesce(overflow.left, false) || coalesce(overflow.right, false), true, false) + coalesce(sum, zero) + child.cast(sumDataType) ) } } + override 
lazy val mergeExpressions: Seq[Expression] = { + Seq( + /* sum = */ + coalesce(coalesce(sum.left, zero) + sum.right, sum.left) + ) + } + override lazy val evaluateExpression: Expression = resultType match { - case d: DecimalType => - If(overflow && !SQLConf.get.ansiEnabled, Literal.create(null, resultType), sum) + case d: DecimalType => CheckOverflow(sum, d, !SQLConf.get.ansiEnabled) case _ => sum } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index c817ed813a1e8..603afbb315919 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -173,3 +174,23 @@ case class HasOverflow( override def sql: String = child.sql } + +case class OverflowException(dtype: DataType, msg: String) extends LeafExpression { + + override def dataType: DataType = dtype + + override def nullable: Boolean = false + + def eval(input: InternalRow): Any = { + Decimal.throwArithmeticException(msg) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ev.copy(code = code""" + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |${ev.value} = Decimal.throwArithmeticException("${msg}"); + |""", isNull = FalseLiteral) + } + + override def toString: String = "OverflowException" +} diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e59e3b999aa7f..5d843fab3b25b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1450,7 +1450,7 @@ object DecimalAggregates extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { case we @ WindowExpression(ae @ AggregateExpression(af, _, _, _, _), _) => af match { - case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => + case DecimalSum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))), prec + 10, scale) @@ -1464,7 +1464,7 @@ object DecimalAggregates extends Rule[LogicalPlan] { case _ => we } case ae @ AggregateExpression(af, _, _, _, _) => af match { - case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => + case DecimalSum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale) case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index f32e48e1cc128..75e3cf4ad7a67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -651,4 +651,9 @@ object Decimal { override def quot(x: Decimal, y: Decimal): Decimal = x quot y override def rem(x: Decimal, y: Decimal): Decimal = x % y } + + + def 
throwArithmeticException(msg: String): Decimal = { + throw new ArithmeticException(msg) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 4e88c7e6bb69a..31eec5d36bfa5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -200,49 +200,100 @@ class DataFrameSuite extends QueryTest Seq(true, false).foreach { ansiEnabled => withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) { val structDf = largeDecimals.select("a").agg(sum("a")) - if (!ansiEnabled) { - checkAnswer(structDf, Row(null)) - } else { - val e = intercept[SparkException] { - structDf.collect + checkAnsi(structDf, ansiEnabled) + } + } + } + + private def checkAnsi(df: DataFrame, ansiEnabled: Boolean): Unit = { + if (!ansiEnabled) { + checkAnswer(df, Row(null)) + } else { + val e = intercept[SparkException] { + df.collect() + } + assert(e.getCause.getClass.equals(classOf[ArithmeticException])) + assert(e.getCause.getMessage.contains("Arithmetic Operation overflow")) + } + } + + test("test sum on null decimal values") { + Seq("true", "false").foreach { wholeStageEnabled => + withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) { + Seq("true", "false").foreach { ansiEnabled => + withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled)) { + val df = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d")) + checkAnswer(df.agg(sum($"d")), Row(null)) + df.agg(sum($"d")).show } - assert(e.getCause.getClass.equals(classOf[ArithmeticException])) - assert(e.getCause.getMessage.contains("cannot be represented as Decimal")) } } } } test("SPARK-28067: Aggregate sum should not return wrong results for decimal overflow") { - Seq(true, false).foreach { ansiEnabled => - withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) { - val df = Seq( - 
(BigDecimal("10000000000000000000"), 1), - (BigDecimal("10000000000000000000"), 1), - (BigDecimal("10000000000000000000"), 2), - (BigDecimal("10000000000000000000"), 2), - (BigDecimal("10000000000000000000"), 2), - (BigDecimal("10000000000000000000"), 2), - (BigDecimal("10000000000000000000"), 2), - (BigDecimal("10000000000000000000"), 2), - (BigDecimal("10000000000000000000"), 2), - (BigDecimal("10000000000000000000"), 2), - (BigDecimal("10000000000000000000"), 2), - (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum") - val df2 = df.withColumnRenamed("decNum", "decNum2").join(df, "intNum").agg(sum("decNum")) - if (!ansiEnabled) { - checkAnswer(df2, Row(null)) - } else { - val e = intercept[SparkException] { - df2.collect() + Seq("true", "false").foreach { wholeStageEnabled => + withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) { + Seq(true, false).foreach { ansiEnabled => + withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) { + val df0 = Seq( + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum") + val df1 = Seq( + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum") + val df = df0.union(df1) + val df2 = df.withColumnRenamed("decNum", "decNum2"). 
+ join(df, "intNum").agg(sum("decNum")) + checkAnsi(df2, ansiEnabled) + + val decStr = "1" + "0" * 19 + val d1 = spark.range(0, 12, 1, 1) + val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d")) + checkAnsi(d2, ansiEnabled) + + val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1)) + val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d")) + checkAnsi(d4, ansiEnabled) + + val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"), + lit(1).as("key")).groupBy("key").agg(sum($"d").alias("sumd")).select($"sumd") + checkAnsi(d5, ansiEnabled) } - assert(e.getCause.getClass.equals(classOf[ArithmeticException])) - assert(e.getCause.getMessage.contains("cannot be represented as Decimal")) } } } } + test("Star Expansion - ds.explode should fail with a meaningful message if it takes a star") { + val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv") + val e = intercept[AnalysisException] { + df.explode($"*") { case Row(prefix: String, csv: String) => + csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq + }.queryExecution.assertAnalyzed() + } + assert(e.getMessage.contains("Invalid usage of '*' in explode/json_tuple/UDTF")) + + checkAnswer( + df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) => + csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq + }, + Row("1", "1,2", "1:1") :: + Row("1", "1,2", "1:2") :: + Row("2", "4", "2:4") :: + Row("3", "7,8,9", "3:7") :: + Row("3", "7,8,9", "3:8") :: + Row("3", "7,8,9", "3:9") :: Nil) + } + test("Star Expansion - explode should fail with a meaningful message if it takes a star") { val df = Seq(("1,2"), ("4"), ("7,8,9")).toDF("csv") val e = intercept[AnalysisException] { From 4119e029adab3b4d69c2851b0bceb5a064f1c25b Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Thu, 9 Apr 2020 14:51:32 -0700 Subject: [PATCH 07/15] Put back the decimal handling changes into Sum like how we started with --- 
.../sql/catalyst/analysis/Analyzer.scala | 10 -- .../expressions/aggregate/DecimalSum.scala | 110 ------------------ .../catalyst/expressions/aggregate/Sum.scala | 83 ++++++++++--- .../sql/catalyst/optimizer/Optimizer.scala | 4 +- .../sql-tests/results/explain.sql.out | 8 +- 5 files changed, 74 insertions(+), 141 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DecimalSum.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8ca2d869203e0..7a2b4e63e133a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -240,7 +240,6 @@ class Analyzer( ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: - UseDecimalSum :: TimeWindowing :: ResolveInlineTables(conf) :: ResolveHigherOrderFunctions(v1SessionCatalog) :: @@ -3067,15 +3066,6 @@ class Analyzer( } } - /** - * Substitute Sum on decimal type to use DecimalSum implementation as it handles overflow - */ - object UseDecimalSum extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { - case Sum(e) if e.resolved && e.dataType.isInstanceOf[DecimalType] => DecimalSum(e) - } - } - /** Rule to mostly resolve, normalize and rewrite column names based on case sensitivity. 
*/ object ResolveAlterTableChanges extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DecimalSum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DecimalSum.scala deleted file mode 100644 index 2803a10457cd2..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DecimalSum.scala +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types._ - -case class DecimalSum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { - - override def children: Seq[Expression] = child :: Nil - - override def nullable: Boolean = true - - // Return data type. 
- override def dataType: DataType = resultType - - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) - - private lazy val resultType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType.bounded(precision + 10, scale) - case _: IntegralType => LongType - case _ => DoubleType - } - - private lazy val sumDataType = resultType - - private lazy val sum = AttributeReference("sum", sumDataType)() - private lazy val overflow = AttributeReference("overflow", BooleanType, false)() - - private lazy val zero = Literal.default(resultType) - - override lazy val aggBufferAttributes = sum :: overflow :: Nil - - override lazy val initialValues: Seq[Expression] = Seq( - /* sum = */ Literal.create(null, sumDataType), - /* overflow = */ Literal.create(false, BooleanType) - ) - - override lazy val updateExpressions: Seq[Expression] = { - if (child.nullable) { - resultType match { - case d: DecimalType => - Seq( - If(overflow, Literal.create(null, sumDataType), coalesce( - CheckOverflow(coalesce(sum, zero) + child.cast(sumDataType), d, true), sum)), - overflow || - coalesce(HasOverflow(coalesce(sum, zero) + child.cast(sumDataType), d), false) - ) - case _ => Seq(coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum), false) - } - } else { - resultType match { - case d: DecimalType => - Seq( - If(overflow, Literal.create(null, sumDataType), - CheckOverflow(coalesce(sum, zero) + child.cast(sumDataType), d, true)), - overflow || - coalesce(HasOverflow(coalesce(sum, zero) + child.cast(sumDataType), d), false) - ) - case _ => Seq(coalesce(sum, zero) + child.cast(sumDataType), false) - } - } - } - - override lazy val mergeExpressions: Seq[Expression] = { - resultType match { - case d: DecimalType => - Seq( - If(coalesce(overflow.left, false) || coalesce(overflow.right, false), - Literal.create(null, d), - coalesce(CheckOverflow(coalesce(sum.left, zero) + sum.right, d, true), sum.left)), - If(coalesce(overflow.left, false) || 
coalesce(overflow.right, false), - true, HasOverflow(coalesce(sum.left, zero) + sum.right, d)) - ) - case _ => - Seq( - coalesce(coalesce(sum.left, zero) + sum.right, sum.left), - If(coalesce(overflow.left, false) || coalesce(overflow.right, false), true, false) - ) - } - } - - override lazy val evaluateExpression: Expression = resultType match { - case d: DecimalType => - If(EqualTo(overflow, true), - If(!SQLConf.get.ansiEnabled, - Literal.create(null, sumDataType), - OverflowException(resultType, "Arithmetic Operation overflow")), - sum) - case _ => sum - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index d2daaac72fc85..2d9c0f1ad18cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -62,37 +62,90 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast private lazy val sum = AttributeReference("sum", sumDataType)() + private lazy val overflow = AttributeReference("overflow", BooleanType, false)() + private lazy val zero = Literal.default(sumDataType) - override lazy val aggBufferAttributes = sum :: Nil + override lazy val aggBufferAttributes = sum :: overflow :: Nil override lazy val initialValues: Seq[Expression] = Seq( - /* sum = */ Literal.create(null, sumDataType) + /* sum = */ Literal.create(null, sumDataType), + /* overflow = */ Literal.create(false, BooleanType) ) + /** + * For decimal types, update will do the following: + * We have a overflow flag and when it is true, it indicates overflow has happened + * 1. Start initial state with overflow = false, sum = null + * 2. Set sum to null if the value overflows else sum contains the intermediate sum + * 3. If overflow flag is true, keep sum as null + * 4. 
If overflow happened, then set overflow flag to true + */ override lazy val updateExpressions: Seq[Expression] = { if (child.nullable) { - Seq( - /* sum = */ - coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) - ) + resultType match { + case d: DecimalType => + Seq( + If(overflow, sum, coalesce( + CheckOverflow(coalesce(sum, zero) + child.cast(sumDataType), d, true), sum)), + overflow || + coalesce(HasOverflow(coalesce(sum, zero) + child.cast(sumDataType), d), false) + ) + case _ => Seq(coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum), false) + } } else { - Seq( - /* sum = */ - coalesce(sum, zero) + child.cast(sumDataType) - ) + resultType match { + case d: DecimalType => + Seq( + If(overflow, sum, + CheckOverflow(coalesce(sum, zero) + child.cast(sumDataType), d, true)), + overflow || + coalesce(HasOverflow(coalesce(sum, zero) + child.cast(sumDataType), d), false) + ) + case _ => Seq(coalesce(sum, zero) + child.cast(sumDataType), false) + } } } + /** + * + * Decimal handling: + * If any of the left or right portion of the agg buffers has the overflow flag to true, + * then sum is set to null else sum is added for both sum.left and sum.right + * and if the value overflows it is set to null. + * If we have already seen overflow , then set overflow to true, else check if the addition + * overflowed and update the overflow buffer. 
+ */ override lazy val mergeExpressions: Seq[Expression] = { - Seq( - /* sum = */ - coalesce(coalesce(sum.left, zero) + sum.right, sum.left) - ) + resultType match { + case d: DecimalType => + Seq( + If(coalesce(overflow.left, false) || coalesce(overflow.right, false), + Literal.create(null, d), + coalesce(CheckOverflow(coalesce(sum.left, zero) + sum.right, d, true), sum.left)), + If(coalesce(overflow.left, false) || coalesce(overflow.right, false), + true, HasOverflow(coalesce(sum.left, zero) + sum.right, d)) + ) + case _ => + Seq( + coalesce(coalesce(sum.left, zero) + sum.right, sum.left), + false + ) + } } + /** + * Decimal handling: + * If overflow buffer is true, and if ansiEnabled is true then throw exception, else return null + * If overflow did not happen, then return the sum value + */ override lazy val evaluateExpression: Expression = resultType match { - case d: DecimalType => CheckOverflow(sum, d, !SQLConf.get.ansiEnabled) + case d: DecimalType => + If(EqualTo(overflow, true), + If(!SQLConf.get.ansiEnabled, + Literal.create(null, sumDataType), + OverflowException(resultType, "Arithmetic Operation overflow")), + sum) case _ => sum } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5d843fab3b25b..e59e3b999aa7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1450,7 +1450,7 @@ object DecimalAggregates extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { case we @ WindowExpression(ae @ AggregateExpression(af, _, _, _, _), _) => af match { - case DecimalSum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => + case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 
<= MAX_LONG_DIGITS => MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))), prec + 10, scale) @@ -1464,7 +1464,7 @@ object DecimalAggregates extends Rule[LogicalPlan] { case _ => we } case ae @ AggregateExpression(af, _, _, _, _) => af match { - case DecimalSum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => + case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale) case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index 2b07dac0e5d0a..55c51783bd5bf 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -918,15 +918,15 @@ Input [2]: [key#x, val#x] Input [2]: [key#x, val#x] Keys: [] Functions [3]: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))] -Aggregate Attributes [3]: [count#xL, sum#xL, count#xL] -Results [3]: [count#xL, sum#xL, count#xL] +Aggregate Attributes [4]: [count#xL, sum#xL, overflow#x, count#xL] +Results [4]: [count#xL, sum#xL, overflow#x, count#xL] (4) Exchange -Input [3]: [count#xL, sum#xL, count#xL] +Input [4]: [count#xL, sum#xL, overflow#x, count#xL] Arguments: SinglePartition, true, [id=#x] (5) HashAggregate [codegen id : 2] -Input [3]: [count#xL, sum#xL, count#xL] +Input [4]: [count#xL, sum#xL, overflow#x, count#xL] Keys: [] Functions [3]: [count(val#x), sum(cast(key#x as bigint)), count(key#x)] Aggregate Attributes [3]: [count(val#x)#xL, sum(cast(key#x as bigint))#xL, count(key#x)#xL] From cc2fec066f2ec1bfc7563ef7b37c935608656b11 Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Mon, 27 Apr 2020 12:43:30 -0700 Subject: [PATCH 08/15] Use a new 
flag isEmptyOrNulls to handle the scenario with all nulls as well and also to use it for identifying overflow scenarios --- .../catalyst/expressions/aggregate/Sum.scala | 108 +++++++++++------- .../expressions/decimalExpressions.scala | 29 ----- .../sql-tests/results/explain.sql.out | 8 +- .../org/apache/spark/sql/DataFrameSuite.scala | 7 +- 4 files changed, 75 insertions(+), 77 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 2d9c0f1ad18cc..76a4f4e8ead67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -62,91 +62,113 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast private lazy val sum = AttributeReference("sum", sumDataType)() - private lazy val overflow = AttributeReference("overflow", BooleanType, false)() + private lazy val isEmptyOrNulls = AttributeReference("isEmptyOrNulls", BooleanType, false)() private lazy val zero = Literal.default(sumDataType) - override lazy val aggBufferAttributes = sum :: overflow :: Nil + override lazy val aggBufferAttributes = sum :: isEmptyOrNulls :: Nil override lazy val initialValues: Seq[Expression] = Seq( - /* sum = */ Literal.create(null, sumDataType), - /* overflow = */ Literal.create(false, BooleanType) + /* sum = */ zero, + /* isEmptyOrNulls = */ Literal.create(true, BooleanType) ) /** - * For decimal types, update will do the following: - * We have a overflow flag and when it is true, it indicates overflow has happened - * 1. Start initial state with overflow = false, sum = null - * 2. Set sum to null if the value overflows else sum contains the intermediate sum - * 3. If overflow flag is true, keep sum as null - * 4. 
If overflow happened, then set overflow flag to true + * For decimal types and when child is nullable: + * isEmptyOrNulls flag is a boolean to represent if there are no rows or if all rows that + * have been seen are null. This will be used to identify if the end result of sum in + * evaluateExpression should be null or not. + * + * Update of the isEmptyOrNulls flag: + * If this flag is false, then keep it as is. + * If this flag is true, then check if the incoming value is null and if it is null, keep it + * as true else update it to false. + * Once this flag is switched to false, it will remain false. + * + * The update of the sum is as follows: + * If sum is null, then we have a case of overflow, so keep sum as is. + * If sum is not null, and the incoming value is not null, then perform the addition along + * with the overflow checking. Note, that if overflow occurs, then sum will be null here. + * If the new incoming value is null, we will keep the sum in buffer as is and skip this + * incoming null */ override lazy val updateExpressions: Seq[Expression] = { if (child.nullable) { resultType match { case d: DecimalType => Seq( - If(overflow, sum, coalesce( - CheckOverflow(coalesce(sum, zero) + child.cast(sumDataType), d, true), sum)), - overflow || - coalesce(HasOverflow(coalesce(sum, zero) + child.cast(sumDataType), d), false) + /* sum */ + If(IsNull(sum), sum, + If(IsNotNull(child.cast(sumDataType)), + CheckOverflow(sum + child.cast(sumDataType), d, true), sum)), + /* isEmptyOrNulls */ + If(isEmptyOrNulls, IsNull(child.cast(sumDataType)), isEmptyOrNulls) + ) + case _ => + Seq( + coalesce(sum + child.cast(sumDataType), sum), + If(isEmptyOrNulls, IsNull(child.cast(sumDataType)), isEmptyOrNulls) ) - case _ => Seq(coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum), false) } } else { resultType match { case d: DecimalType => Seq( - If(overflow, sum, - CheckOverflow(coalesce(sum, zero) + child.cast(sumDataType), d, true)), - overflow || - 
coalesce(HasOverflow(coalesce(sum, zero) + child.cast(sumDataType), d), false) + /* sum */ + If(IsNull(sum), sum, CheckOverflow(sum + child.cast(sumDataType), d, true)), + /* isEmptyOrNulls */ + false ) - case _ => Seq(coalesce(sum, zero) + child.cast(sumDataType), false) + case _ => Seq(sum + child.cast(sumDataType), false) } } } /** + * For decimal type: + * update of the sum is as follows: + * Check if either portion of the left.sum or right.sum has overflowed + * If it has, then the sum value will remain null. + * If it did not have overflow, then add the sum.left and sum.right and check for overflow. * - * Decimal handling: - * If any of the left or right portion of the agg buffers has the overflow flag to true, - * then sum is set to null else sum is added for both sum.left and sum.right - * and if the value overflows it is set to null. - * If we have already seen overflow , then set overflow to true, else check if the addition - * overflowed and update the overflow buffer. + * isEmptyOrNulls: Set to false if either one of the left or right is set to false. This + * means we have seen atleast a row that was not null. + * If the value from bufferLeft and bufferRight are both true, then this will be true. 
*/ override lazy val mergeExpressions: Seq[Expression] = { resultType match { case d: DecimalType => Seq( - If(coalesce(overflow.left, false) || coalesce(overflow.right, false), - Literal.create(null, d), - coalesce(CheckOverflow(coalesce(sum.left, zero) + sum.right, d, true), sum.left)), - If(coalesce(overflow.left, false) || coalesce(overflow.right, false), - true, HasOverflow(coalesce(sum.left, zero) + sum.right, d)) - ) + /* sum = */ + If(And(IsNull(sum.left), EqualTo(isEmptyOrNulls.left, false)) || + And(IsNull(sum.right), EqualTo(isEmptyOrNulls.right, false)), + Literal.create(null, resultType), + CheckOverflow(sum.left + sum.right, d, true)), + /* isEmptyOrNulls = */ + And(isEmptyOrNulls.left, isEmptyOrNulls.right) + ) case _ => Seq( - coalesce(coalesce(sum.left, zero) + sum.right, sum.left), - false + coalesce(sum.left + sum.right, sum.left), + And(isEmptyOrNulls.left, isEmptyOrNulls.right) ) } } /** - * Decimal handling: - * If overflow buffer is true, and if ansiEnabled is true then throw exception, else return null - * If overflow did not happen, then return the sum value + * If the isEmptyOrNulls is true, then it means either there are no rows, or all the rows were + * null, so the result will be null. + * If the isEmptyOrNulls is false, then if sum is null that means an overflow has happened. + * So now, if ansi is enabled, then throw exception, if not then return null. + * If sum is not null, then return the sum. 
*/ override lazy val evaluateExpression: Expression = resultType match { case d: DecimalType => - If(EqualTo(overflow, true), - If(!SQLConf.get.ansiEnabled, - Literal.create(null, sumDataType), - OverflowException(resultType, "Arithmetic Operation overflow")), - sum) - case _ => sum + If(EqualTo(isEmptyOrNulls, true), + Literal.create(null, sumDataType), + If(And(SQLConf.get.ansiEnabled, IsNull(sum)), + OverflowException(resultType, "Arithmetic Operation overflow"), sum)) + case _ => If(EqualTo(isEmptyOrNulls, true), Literal.create(null, resultType), sum) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 603afbb315919..ddd7940fc0e19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -146,35 +146,6 @@ case class CheckOverflow( override def sql: String = child.sql } -case class HasOverflow( - child: Expression, - inputType: DecimalType) extends UnaryExpression { - - override def dataType: DataType = BooleanType - - override def nullable: Boolean = true - - override def nullSafeEval(input: Any): Any = - !input.asInstanceOf[Decimal].changePrecision( - inputType.precision, - inputType.scale, - Decimal.ROUND_HALF_UP) - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, eval => { - s""" - |${ev.value} = !$eval.changePrecision( - | ${inputType.precision}, ${inputType.scale}, Decimal.ROUND_HALF_UP()); - |${ev.isNull} = false; - """.stripMargin - }) - } - - override def toString: String = s"HasOverflow($child, $inputType)" - - override def sql: String = child.sql -} - case class OverflowException(dtype: DataType, msg: String) extends LeafExpression { override def dataType: DataType = dtype diff 
--git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index 55c51783bd5bf..f0063cb43d54a 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -918,15 +918,15 @@ Input [2]: [key#x, val#x] Input [2]: [key#x, val#x] Keys: [] Functions [3]: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))] -Aggregate Attributes [4]: [count#xL, sum#xL, overflow#x, count#xL] -Results [4]: [count#xL, sum#xL, overflow#x, count#xL] +Aggregate Attributes [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL] +Results [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL] (4) Exchange -Input [4]: [count#xL, sum#xL, overflow#x, count#xL] +Input [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL] Arguments: SinglePartition, true, [id=#x] (5) HashAggregate [codegen id : 2] -Input [4]: [count#xL, sum#xL, overflow#x, count#xL] +Input [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL] Keys: [] Functions [3]: [count(val#x), sum(cast(key#x as bigint)), count(key#x)] Aggregate Attributes [3]: [count(val#x)#xL, sum(cast(key#x as bigint))#xL, count(key#x)#xL] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 31eec5d36bfa5..ae82fdb79e473 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -224,7 +224,6 @@ class DataFrameSuite extends QueryTest withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled)) { val df = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d")) checkAnswer(df.agg(sum($"d")), Row(null)) - df.agg(sum($"d")).show } } } @@ -267,6 +266,12 @@ class DataFrameSuite extends QueryTest val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"), 
lit(1).as("key")).groupBy("key").agg(sum($"d").alias("sumd")).select($"sumd") checkAnsi(d5, ansiEnabled) + + val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d")) + + val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), BigDecimal("9"* 20 + ".123")). + toDF("d") + checkAnsi(nullsDf.union(largeDecimals).agg(sum($"d")), ansiEnabled) } } } From 23739c9e1810758c640cc71ebb9db39373a05399 Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Mon, 27 Apr 2020 15:01:52 -0700 Subject: [PATCH 09/15] Rebase --- .../test/resources/sql-tests/results/explain-aqe.sql.out | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out index 36757863ffcb5..6245b035ddc90 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out @@ -709,15 +709,15 @@ ReadSchema: struct Input [2]: [key#x, val#x] Keys: [] Functions [3]: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))] -Aggregate Attributes [3]: [count#xL, sum#xL, count#xL] -Results [3]: [count#xL, sum#xL, count#xL] +Aggregate Attributes [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL] +Results [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL] (3) Exchange -Input [3]: [count#xL, sum#xL, count#xL] +Input [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL] Arguments: SinglePartition, true, [id=#x] (4) HashAggregate -Input [3]: [count#xL, sum#xL, count#xL] +Input [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL] Keys: [] Functions [3]: [count(val#x), sum(cast(key#x as bigint)), count(key#x)] Aggregate Attributes [3]: [count(val#x)#xL, sum(cast(key#x as bigint))#xL, count(key#x)#xL] From fa4537821337c45b3db26d6b4042a7103d044eac Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Mon, 27 Apr 2020 23:59:41 
-0700 Subject: [PATCH 10/15] rebase --- .../src/test/scala/org/apache/spark/sql/ExplainSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala index d41d624f1762d..e4c697b0f6346 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala @@ -379,12 +379,12 @@ class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuit |""".stripMargin, s""" |(11) ShuffleQueryStage - |Output [5]: [k#x, count#xL, sum#xL, sum#x, count#xL] + |Output [6]: [k#x, count#xL, sum#xL, isEmptyOrNulls#x, sum#x, count#xL] |Arguments: 1 |""".stripMargin, s""" |(12) CustomShuffleReader - |Input [5]: [k#x, count#xL, sum#xL, sum#x, count#xL] + |Input [6]: [k#x, count#xL, sum#xL, isEmptyOrNulls#x, sum#x, count#xL] |Arguments: coalesced |""".stripMargin, s""" From fbd80a65a6c02d2513dc978ed02bd6da09609dd5 Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Thu, 14 May 2020 22:04:18 -0700 Subject: [PATCH 11/15] Add new test and also make the agg buffer changes only for decimal type --- .../catalyst/expressions/aggregate/Sum.scala | 29 +++++++------- .../sql-tests/results/explain-aqe.sql.out | 8 ++-- .../sql-tests/results/explain.sql.out | 8 ++-- .../org/apache/spark/sql/DataFrameSuite.scala | 38 +++++++++++++++---- .../org/apache/spark/sql/ExplainSuite.scala | 4 +- 5 files changed, 54 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 76a4f4e8ead67..b315de3091dc9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ 
-66,12 +66,15 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast private lazy val zero = Literal.default(sumDataType) - override lazy val aggBufferAttributes = sum :: isEmptyOrNulls :: Nil + override lazy val aggBufferAttributes = resultType match { + case _: DecimalType => sum :: isEmptyOrNulls :: Nil + case _ => sum :: Nil + } - override lazy val initialValues: Seq[Expression] = Seq( - /* sum = */ zero, - /* isEmptyOrNulls = */ Literal.create(true, BooleanType) - ) + override lazy val initialValues: Seq[Expression] = resultType match { + case _: DecimalType => Seq(zero, Literal.create(true, BooleanType)) + case other => Seq(Literal.create(null, other)) + } /** * For decimal types and when child is nullable: @@ -105,10 +108,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast If(isEmptyOrNulls, IsNull(child.cast(sumDataType)), isEmptyOrNulls) ) case _ => - Seq( - coalesce(sum + child.cast(sumDataType), sum), - If(isEmptyOrNulls, IsNull(child.cast(sumDataType)), isEmptyOrNulls) - ) + Seq(coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)) } } else { resultType match { @@ -119,13 +119,15 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast /* isEmptyOrNulls */ false ) - case _ => Seq(sum + child.cast(sumDataType), false) + case _ => Seq(coalesce(sum, zero) + child.cast(sumDataType)) } } } /** * For decimal type: + * If isEmptyOrNulls is false and if sum is null, then it means we have an overflow. + * * update of the sum is as follows: * Check if either portion of the left.sum or right.sum has overflowed * If it has, then the sum value will remain null. 
@@ -148,10 +150,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast And(isEmptyOrNulls.left, isEmptyOrNulls.right) ) case _ => - Seq( - coalesce(sum.left + sum.right, sum.left), - And(isEmptyOrNulls.left, isEmptyOrNulls.right) - ) + Seq(coalesce(coalesce(sum.left, zero) + sum.right, sum.left)) } } @@ -168,7 +167,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast Literal.create(null, sumDataType), If(And(SQLConf.get.ansiEnabled, IsNull(sum)), OverflowException(resultType, "Arithmetic Operation overflow"), sum)) - case _ => If(EqualTo(isEmptyOrNulls, true), Literal.create(null, resultType), sum) + case _ => sum } } diff --git a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out index 6245b035ddc90..36757863ffcb5 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out @@ -709,15 +709,15 @@ ReadSchema: struct Input [2]: [key#x, val#x] Keys: [] Functions [3]: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))] -Aggregate Attributes [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL] -Results [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL] +Aggregate Attributes [3]: [count#xL, sum#xL, count#xL] +Results [3]: [count#xL, sum#xL, count#xL] (3) Exchange -Input [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL] +Input [3]: [count#xL, sum#xL, count#xL] Arguments: SinglePartition, true, [id=#x] (4) HashAggregate -Input [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL] +Input [3]: [count#xL, sum#xL, count#xL] Keys: [] Functions [3]: [count(val#x), sum(cast(key#x as bigint)), count(key#x)] Aggregate Attributes [3]: [count(val#x)#xL, sum(cast(key#x as bigint))#xL, count(key#x)#xL] diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out 
b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index f0063cb43d54a..2b07dac0e5d0a 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -918,15 +918,15 @@ Input [2]: [key#x, val#x] Input [2]: [key#x, val#x] Keys: [] Functions [3]: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))] -Aggregate Attributes [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL] -Results [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL] +Aggregate Attributes [3]: [count#xL, sum#xL, count#xL] +Results [3]: [count#xL, sum#xL, count#xL] (4) Exchange -Input [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL] +Input [3]: [count#xL, sum#xL, count#xL] Arguments: SinglePartition, true, [id=#x] (5) HashAggregate [codegen id : 2] -Input [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL] +Input [3]: [count#xL, sum#xL, count#xL] Keys: [] Functions [3]: [count(val#x), sum(cast(key#x as bigint)), count(key#x)] Aggregate Attributes [3]: [count(val#x)#xL, sum(cast(key#x as bigint))#xL, count(key#x)#xL] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ae82fdb79e473..f4119b412e7e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -200,14 +200,14 @@ class DataFrameSuite extends QueryTest Seq(true, false).foreach { ansiEnabled => withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) { val structDf = largeDecimals.select("a").agg(sum("a")) - checkAnsi(structDf, ansiEnabled) + checkAnsi(structDf, ansiEnabled, Row(null)) } } } - private def checkAnsi(df: DataFrame, ansiEnabled: Boolean): Unit = { + private def checkAnsi(df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row ): Unit = { if (!ansiEnabled) { - checkAnswer(df, Row(null)) 
+ checkAnswer(df, expectedAnswer) } else { val e = intercept[SparkException] { df.collect() @@ -252,26 +252,48 @@ class DataFrameSuite extends QueryTest val df = df0.union(df1) val df2 = df.withColumnRenamed("decNum", "decNum2"). join(df, "intNum").agg(sum("decNum")) - checkAnsi(df2, ansiEnabled) + + val expectedAnswer = Row(null) + checkAnsi(df2, ansiEnabled, expectedAnswer) val decStr = "1" + "0" * 19 val d1 = spark.range(0, 12, 1, 1) val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d")) - checkAnsi(d2, ansiEnabled) + checkAnsi(d2, ansiEnabled, expectedAnswer) val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1)) val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d")) - checkAnsi(d4, ansiEnabled) + checkAnsi(d4, ansiEnabled, expectedAnswer) val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"), lit(1).as("key")).groupBy("key").agg(sum($"d").alias("sumd")).select($"sumd") - checkAnsi(d5, ansiEnabled) + checkAnsi(d5, ansiEnabled, expectedAnswer) val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d")) val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), BigDecimal("9"* 20 + ".123")). 
toDF("d") - checkAnsi(nullsDf.union(largeDecimals).agg(sum($"d")), ansiEnabled) + checkAnsi(nullsDf.union(largeDecimals).agg(sum($"d")), ansiEnabled, expectedAnswer) + + val df3 = Seq( + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("50000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum") + + val df4 = Seq( + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum") + + val df5 = Seq( + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("20000000000000000000"), 2)).toDF("decNum", "intNum") + + val df6 = df3.union(df4).union(df5) + val df7 = df6.groupBy("intNum").agg(sum("decNum"), countDistinct("decNum")). + filter("intNum == 1") + checkAnsi(df7, ansiEnabled, Row(1, null, 2)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala index e4c697b0f6346..d41d624f1762d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala @@ -379,12 +379,12 @@ class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuit |""".stripMargin, s""" |(11) ShuffleQueryStage - |Output [6]: [k#x, count#xL, sum#xL, isEmptyOrNulls#x, sum#x, count#xL] + |Output [5]: [k#x, count#xL, sum#xL, sum#x, count#xL] |Arguments: 1 |""".stripMargin, s""" |(12) CustomShuffleReader - |Input [6]: [k#x, count#xL, sum#xL, isEmptyOrNulls#x, sum#x, count#xL] + |Input [5]: [k#x, count#xL, sum#xL, sum#x, count#xL] |Arguments: coalesced |""".stripMargin, s""" From de2d68fd1dbc21d8f5b5e40b475c29bd113e33ec Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 21 May 2020 23:40:02 +0800 Subject: [PATCH 12/15] simplify --- .../catalyst/expressions/aggregate/Sum.scala | 60 +++++++------------ .../expressions/decimalExpressions.scala | 
57 ++++++++++++++---- .../org/apache/spark/sql/types/Decimal.scala | 5 -- .../org/apache/spark/sql/DataFrameSuite.scala | 51 +++++++++------- 4 files changed, 96 insertions(+), 77 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index b315de3091dc9..d70fe92e7fe54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -62,18 +62,18 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast private lazy val sum = AttributeReference("sum", sumDataType)() - private lazy val isEmptyOrNulls = AttributeReference("isEmptyOrNulls", BooleanType, false)() + private lazy val isEmpty = AttributeReference("isEmpty", BooleanType, nullable = false)() private lazy val zero = Literal.default(sumDataType) override lazy val aggBufferAttributes = resultType match { - case _: DecimalType => sum :: isEmptyOrNulls :: Nil + case _: DecimalType => sum :: isEmpty :: Nil case _ => sum :: Nil } override lazy val initialValues: Seq[Expression] = resultType match { - case _: DecimalType => Seq(zero, Literal.create(true, BooleanType)) - case other => Seq(Literal.create(null, other)) + case _: DecimalType => Seq(Literal(null, resultType), Literal(true, BooleanType)) + case _ => Seq(Literal(null, resultType)) } /** @@ -97,29 +97,18 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast */ override lazy val updateExpressions: Seq[Expression] = { if (child.nullable) { + val updateSumExpr = coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) resultType match { - case d: DecimalType => - Seq( - /* sum */ - If(IsNull(sum), sum, - If(IsNotNull(child.cast(sumDataType)), - CheckOverflow(sum + child.cast(sumDataType), d, true), sum)), - /* 
isEmptyOrNulls */ - If(isEmptyOrNulls, IsNull(child.cast(sumDataType)), isEmptyOrNulls) - ) - case _ => - Seq(coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)) + case _: DecimalType => + Seq(updateSumExpr, isEmpty && child.isNull) + case _ => Seq(updateSumExpr) } } else { + val updateSumExpr = coalesce(sum, zero) + child.cast(sumDataType) resultType match { - case d: DecimalType => - Seq( - /* sum */ - If(IsNull(sum), sum, CheckOverflow(sum + child.cast(sumDataType), d, true)), - /* isEmptyOrNulls */ - false - ) - case _ => Seq(coalesce(sum, zero) + child.cast(sumDataType)) + case _: DecimalType => + Seq(updateSumExpr, Literal(false, BooleanType)) + case _ => Seq(updateSumExpr) } } } @@ -138,19 +127,15 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast * If the value from bufferLeft and bufferRight are both true, then this will be true. */ override lazy val mergeExpressions: Seq[Expression] = { + val mergeSumExpr = coalesce(coalesce(sum.left, zero) + sum.right, sum.left) resultType match { - case d: DecimalType => + case _: DecimalType => + val inputOverflow = !isEmpty.right && sum.right.isNull + val bufferOverflow = !isEmpty.left && sum.left.isNull Seq( - /* sum = */ - If(And(IsNull(sum.left), EqualTo(isEmptyOrNulls.left, false)) || - And(IsNull(sum.right), EqualTo(isEmptyOrNulls.right, false)), - Literal.create(null, resultType), - CheckOverflow(sum.left + sum.right, d, true)), - /* isEmptyOrNulls = */ - And(isEmptyOrNulls.left, isEmptyOrNulls.right) - ) - case _ => - Seq(coalesce(coalesce(sum.left, zero) + sum.right, sum.left)) + If(inputOverflow || bufferOverflow, Literal.create(null, sumDataType), mergeSumExpr), + isEmpty.left && isEmpty.right) + case _ => Seq(mergeSumExpr) } } @@ -163,11 +148,8 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast */ override lazy val evaluateExpression: Expression = resultType match { case d: DecimalType => - If(EqualTo(isEmptyOrNulls, true), - 
Literal.create(null, sumDataType), - If(And(SQLConf.get.ansiEnabled, IsNull(sum)), - OverflowException(resultType, "Arithmetic Operation overflow"), sum)) + If(isEmpty, Literal.create(null, sumDataType), + CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled)) case _ => sum } - } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index ddd7940fc0e19..9edd5cac75c5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -146,22 +146,53 @@ case class CheckOverflow( override def sql: String = child.sql } -case class OverflowException(dtype: DataType, msg: String) extends LeafExpression { - - override def dataType: DataType = dtype +// A variant `CheckOverflow`, which treats null as overflow. This is necessary in `Sum`. 
+case class CheckOverflowInSum( + child: Expression, + dataType: DecimalType, + nullOnOverflow: Boolean) extends UnaryExpression { - override def nullable: Boolean = false + override def nullable: Boolean = true - def eval(input: InternalRow): Any = { - Decimal.throwArithmeticException(msg) + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + if (nullOnOverflow) null else throw new ArithmeticException("Overflow in sum of decimals.") + } else { + input.asInstanceOf[Decimal].toPrecision( + dataType.precision, + dataType.scale, + Decimal.ROUND_HALF_UP, + nullOnOverflow) + } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.copy(code = code""" - |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - |${ev.value} = Decimal.throwArithmeticException("${msg}"); - |""", isNull = FalseLiteral) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + val nullHandling = if (nullOnOverflow) { + "" + } else { + s""" + |throw new ArithmeticException("Overflow in sum of decimals."); + |""".stripMargin + } + val code = code""" + |${childGen.code} + |boolean ${ev.isNull} = ${childGen.isNull}; + |Decimal ${ev.value} = null; + |if (${childGen.isNull}) { + | $nullHandling + |} else { + | ${ev.value} = ${childGen.value}.toPrecision( + | ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow); + | ${ev.isNull} = ${ev.value} == null; + |} + |""".stripMargin + + ev.copy(code = code) } - override def toString: String = "OverflowException" + override def toString: String = s"CheckOverflowInSum($child, $dataType, $nullOnOverflow)" + + override def sql: String = child.sql } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 75e3cf4ad7a67..f32e48e1cc128 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -651,9 +651,4 @@ object Decimal { override def quot(x: Decimal, y: Decimal): Decimal = x quot y override def rem(x: Decimal, y: Decimal): Decimal = x % y } - - - def throwArithmeticException(msg: String): Decimal = { - throw new ArithmeticException(msg) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f4119b412e7e3..bbcb9df455501 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -192,6 +192,28 @@ class DataFrameSuite extends QueryTest structDf.select(xxhash64($"a", $"record.*"))) } + private def assertDecimalSumOverflow( + df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = { + if (!ansiEnabled) { + try { + checkAnswer(df, expectedAnswer) + } catch { + case e: SparkException if e.getCause.isInstanceOf[ArithmeticException] => + // This is an existing bug that we can write overflowed decimal to UnsafeRow but fail + // to read it. 
+ assert(e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38")) + } + } else { + val e = intercept[SparkException] { + df.collect + } + assert(e.getCause.isInstanceOf[ArithmeticException]) + assert(e.getCause.getMessage.contains("cannot be represented as Decimal") || + e.getCause.getMessage.contains("Overflow in sum of decimals") || + e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38")) + } + } + test("SPARK-28224: Aggregate sum big decimal overflow") { val largeDecimals = spark.sparkContext.parallelize( DecimalData(BigDecimal("1"* 20 + ".123"), BigDecimal("1"* 20 + ".123")) :: @@ -200,24 +222,12 @@ class DataFrameSuite extends QueryTest Seq(true, false).foreach { ansiEnabled => withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) { val structDf = largeDecimals.select("a").agg(sum("a")) - checkAnsi(structDf, ansiEnabled, Row(null)) - } - } - } - - private def checkAnsi(df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row ): Unit = { - if (!ansiEnabled) { - checkAnswer(df, expectedAnswer) - } else { - val e = intercept[SparkException] { - df.collect() + assertDecimalSumOverflow(structDf, ansiEnabled, Row(null)) } - assert(e.getCause.getClass.equals(classOf[ArithmeticException])) - assert(e.getCause.getMessage.contains("Arithmetic Operation overflow")) } } - test("test sum on null decimal values") { + test("SPARK-28067: sum of null decimal values") { Seq("true", "false").foreach { wholeStageEnabled => withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) { Seq("true", "false").foreach { ansiEnabled => @@ -254,26 +264,27 @@ class DataFrameSuite extends QueryTest join(df, "intNum").agg(sum("decNum")) val expectedAnswer = Row(null) - checkAnsi(df2, ansiEnabled, expectedAnswer) + assertDecimalSumOverflow(df2, ansiEnabled, expectedAnswer) val decStr = "1" + "0" * 19 val d1 = spark.range(0, 12, 1, 1) val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d")) - 
checkAnsi(d2, ansiEnabled, expectedAnswer) + assertDecimalSumOverflow(d2, ansiEnabled, expectedAnswer) val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1)) val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d")) - checkAnsi(d4, ansiEnabled, expectedAnswer) + assertDecimalSumOverflow(d4, ansiEnabled, expectedAnswer) val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"), lit(1).as("key")).groupBy("key").agg(sum($"d").alias("sumd")).select($"sumd") - checkAnsi(d5, ansiEnabled, expectedAnswer) + assertDecimalSumOverflow(d5, ansiEnabled, expectedAnswer) val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d")) val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), BigDecimal("9"* 20 + ".123")). toDF("d") - checkAnsi(nullsDf.union(largeDecimals).agg(sum($"d")), ansiEnabled, expectedAnswer) + assertDecimalSumOverflow( + nullsDf.union(largeDecimals).agg(sum($"d")), ansiEnabled, expectedAnswer) val df3 = Seq( (BigDecimal("10000000000000000000"), 1), @@ -293,7 +304,7 @@ class DataFrameSuite extends QueryTest val df6 = df3.union(df4).union(df5) val df7 = df6.groupBy("intNum").agg(sum("decNum"), countDistinct("decNum")). 
filter("intNum == 1") - checkAnsi(df7, ansiEnabled, Row(1, null, 2)) + assertDecimalSumOverflow(df7, ansiEnabled, Row(1, null, 2)) } } } From 8339e286a15abf274226148b1bcf302a9667bbd1 Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Thu, 28 May 2020 18:54:26 -0700 Subject: [PATCH 13/15] Fix the test failure --- .../spark/sql/catalyst/expressions/decimalExpressions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 9edd5cac75c5e..9f0408a380f04 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -159,7 +159,7 @@ case class CheckOverflowInSum( if (value == null) { if (nullOnOverflow) null else throw new ArithmeticException("Overflow in sum of decimals.") } else { - input.asInstanceOf[Decimal].toPrecision( + value.asInstanceOf[Decimal].toPrecision( dataType.precision, dataType.scale, Decimal.ROUND_HALF_UP, From 59a00c4e1092579532c37569261fb830c194f891 Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Thu, 28 May 2020 20:10:06 -0700 Subject: [PATCH 14/15] Cleanup comments --- .../catalyst/expressions/aggregate/Sum.scala | 28 +++---------------- 1 file changed, 4 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index d70fe92e7fe54..57098d3e0fb6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -76,25 +76,6 @@ case class Sum(child: Expression) extends 
DeclarativeAggregate with ImplicitCast case _ => Seq(Literal(null, resultType)) } - /** - * For decimal types and when child is nullable: - * isEmptyOrNulls flag is a boolean to represent if there are no rows or if all rows that - * have been seen are null. This will be used to identify if the end result of sum in - * evaluateExpression should be null or not. - * - * Update of the isEmptyOrNulls flag: - * If this flag is false, then keep it as is. - * If this flag is true, then check if the incoming value is null and if it is null, keep it - * as true else update it to false. - * Once this flag is switched to false, it will remain false. - * - * The update of the sum is as follows: - * If sum is null, then we have a case of overflow, so keep sum as is. - * If sum is not null, and the incoming value is not null, then perform the addition along - * with the overflow checking. Note, that if overflow occurs, then sum will be null here. - * If the new incoming value is null, we will keep the sum in buffer as is and skip this - * incoming null - */ override lazy val updateExpressions: Seq[Expression] = { if (child.nullable) { val updateSumExpr = coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) @@ -115,16 +96,15 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast /** * For decimal type: - * If isEmptyOrNulls is false and if sum is null, then it means we have an overflow. + * If isEmpty is false and if sum is null, then it means we have an overflow. * * update of the sum is as follows: * Check if either portion of the left.sum or right.sum has overflowed * If it has, then the sum value will remain null. * If it did not have overflow, then add the sum.left and sum.right and check for overflow. * - * isEmptyOrNulls: Set to false if either one of the left or right is set to false. This + * isEmpty: Set to false if either one of the left or right is set to false. This * means we have seen atleast a row that was not null. 
- * If the value from bufferLeft and bufferRight are both true, then this will be true. */ override lazy val mergeExpressions: Seq[Expression] = { val mergeSumExpr = coalesce(coalesce(sum.left, zero) + sum.right, sum.left) @@ -140,9 +120,9 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast } /** - * If the isEmptyOrNulls is true, then it means either there are no rows, or all the rows were + * If the isEmpty is true, then it means either there are no rows, or all the rows were * null, so the result will be null. - * If the isEmptyOrNulls is false, then if sum is null that means an overflow has happened. + * If the isEmpty is false, then if sum is null that means an overflow has happened. * So now, if ansi is enabled, then throw exception, if not then return null. * If sum is not null, then return the sum. */ From 77958880245cca238bd976900e57715f6f96a3c4 Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Mon, 1 Jun 2020 11:14:54 -0700 Subject: [PATCH 15/15] cleanup code comments --- .../spark/sql/catalyst/expressions/aggregate/Sum.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 57098d3e0fb6e..6e850267100fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -96,15 +96,15 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast /** * For decimal type: - * If isEmpty is false and if sum is null, then it means we have an overflow. + * If isEmpty is false and if sum is null, then it means we have had an overflow. 
 * * update of the sum is as follows: * Check if either portion of the left.sum or right.sum has overflowed * If it has, then the sum value will remain null. - * If it did not have overflow, then add the sum.left and sum.right and check for overflow. + * If it did not have overflow, then add the sum.left and sum.right * * isEmpty: Set to false if either one of the left or right is set to false. This - * means we have seen atleast a row that was not null. + * means we have seen at least a value that was not null. */ override lazy val mergeExpressions: Seq[Expression] = { val mergeSumExpr = coalesce(coalesce(sum.left, zero) + sum.right, sum.left) @@ -120,8 +120,8 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast } /** - * If the isEmpty is true, then it means either there are no rows, or all the rows were - * null, so the result will be null. + * If the isEmpty is true, then it means there were no values to begin with or all the values + * were null, so the result will be null. * If the isEmpty is false, then if sum is null that means an overflow has happened. * So now, if ansi is enabled, then throw exception, if not then return null. * If sum is not null, then return the sum.