From cc95f00798be1ea70eb987a54af92bb9efbe5193 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Sat, 13 Jul 2019 16:39:56 +0200
Subject: [PATCH] [SPARK-28369][SQL] Honor spark.sql.decimalOperations.nullOnOverflow in ScalaUDF result

---
 .../sql/catalyst/CatalystTypeConverters.scala |  5 +++-
 .../expressions/mathExpressions.scala         |  1 +
 .../org/apache/spark/sql/types/Decimal.scala  | 12 ++-------
 .../catalyst/expressions/ScalaUDFSuite.scala  | 25 ++++++++++++++++++-
 4 files changed, 31 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 6020b068155fc..488252aa0c7b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -343,6 +343,9 @@ object CatalystTypeConverters {
 
   private class DecimalConverter(dataType: DecimalType)
     extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
+
+    private val nullOnOverflow = SQLConf.get.decimalOperationsNullOnOverflow
+
     override def toCatalystImpl(scalaValue: Any): Decimal = {
       val decimal = scalaValue match {
         case d: BigDecimal => Decimal(d)
@@ -353,7 +356,7 @@ object CatalystTypeConverters {
           s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " +
             s"cannot be converted to ${dataType.catalogString}")
       }
-      decimal.toPrecision(dataType.precision, dataType.scale)
+      decimal.toPrecision(dataType.precision, dataType.scale, Decimal.ROUND_HALF_UP, nullOnOverflow)
     }
     override def toScala(catalystValue: Decimal): JavaBigDecimal = {
       if (catalystValue == null) null
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index bdeb9ed29e0ac..7e39942b7a7bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -1106,6 +1106,7 @@ abstract class RoundBase(child: Expression, scale: Expression,
     dataType match {
       case DecimalType.Fixed(_, s) =>
         val decimal = input1.asInstanceOf[Decimal]
+        // Overflow cannot happen here, so there is no need to handle nullOnOverflow
         decimal.toPrecision(decimal.precision, s, mode)
       case ByteType =>
         BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 1bf322af21799..a5d1a72d62d5c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -414,20 +414,12 @@ final class Decimal extends Ordered[Decimal] with Serializable {
 
   def floor: Decimal = if (scale == 0) this else {
     val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
-    val res = toPrecision(newPrecision, 0, ROUND_FLOOR)
-    if (res == null) {
-      throw new AnalysisException(s"Overflow when setting precision to $newPrecision")
-    }
-    res
+    toPrecision(newPrecision, 0, ROUND_FLOOR, nullOnOverflow = false)
   }
 
   def ceil: Decimal = if (scale == 0) this else {
     val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
-    val res = toPrecision(newPrecision, 0, ROUND_CEILING)
-    if (res == null) {
-      throw new AnalysisException(s"Overflow when setting precision to $newPrecision")
-    }
-    res
+    toPrecision(newPrecision, 0, ROUND_CEILING, nullOnOverflow = false)
   }
 
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
index df92fa3475bd9..981ef57c051fd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
@@ -21,7 +21,8 @@ import java.util.Locale
 
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
-import org.apache.spark.sql.types.{IntegerType, StringType}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DecimalType, IntegerType, StringType}
 
 class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
 
@@ -54,4 +55,26 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
     ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, false :: Nil).genCode(ctx)
     assert(ctx.inlinedMutableStates.isEmpty)
   }
+
+  test("SPARK-28369: honor nullOnOverflow config for ScalaUDF") {
+    withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "false") {
+      val udf = ScalaUDF(
+        (a: java.math.BigDecimal) => a.multiply(new java.math.BigDecimal(100)),
+        DecimalType.SYSTEM_DEFAULT,
+        Literal(BigDecimal("12345678901234567890.123")) :: Nil, false :: Nil)
+      val e1 = intercept[ArithmeticException](udf.eval())
+      assert(e1.getMessage.contains("cannot be represented as Decimal"))
+      val e2 = intercept[SparkException] {
+        checkEvaluationWithUnsafeProjection(udf, null)
+      }
+      assert(e2.getCause.isInstanceOf[ArithmeticException])
+    }
+    withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "true") {
+      val udf = ScalaUDF(
+        (a: java.math.BigDecimal) => a.multiply(new java.math.BigDecimal(100)),
+        DecimalType.SYSTEM_DEFAULT,
+        Literal(BigDecimal("12345678901234567890.123")) :: Nil, false :: Nil)
+      checkEvaluation(udf, null)
+    }
+  }
 }
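
Note (illustrative sketch, not part of the patch): the contract this change builds
on is the nullOnOverflow parameter of Decimal.toPrecision, which the diff itself
exercises in two ways: floor/ceil now pass nullOnOverflow = false, so overflow there
surfaces as ArithmeticException rather than the old AnalysisException, while
DecimalConverter forwards the session setting. A minimal, standalone sketch of the
two behaviors, assuming only the catalyst module on the classpath (the object name
is made up):

    import org.apache.spark.sql.types.Decimal

    object NullOnOverflowSketch extends App {
      // 22 integer digits: too many for DecimalType.SYSTEM_DEFAULT, which is
      // Decimal(38, 18) and therefore fits at most 20 integer digits.
      val big = Decimal(BigDecimal("12345678901234567890.123") * 100)

      // nullOnOverflow = true: the out-of-range value degrades to null.
      println(big.toPrecision(38, 18, Decimal.ROUND_HALF_UP, nullOnOverflow = true))

      // nullOnOverflow = false: the same conversion throws an ArithmeticException
      // ("... cannot be represented as Decimal(38, 18)"), matching the message
      // asserted in the ScalaUDFSuite test above.
      try {
        big.toPrecision(38, 18, Decimal.ROUND_HALF_UP, nullOnOverflow = false)
      } catch {
        case e: ArithmeticException => println(e.getMessage)
      }
    }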
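
End-to-end, the user-visible effect could look like the following sketch (assuming
a running SparkSession named spark; the times100 UDF and the input value are made
up, and a java.math.BigDecimal result is assumed to map to
DecimalType.SYSTEM_DEFAULT, i.e. Decimal(38, 18), as in the test above):

    import org.apache.spark.sql.functions.udf
    import spark.implicits._

    val times100 = udf((d: java.math.BigDecimal) => d.multiply(new java.math.BigDecimal(100)))
    val df = Seq(BigDecimal("12345678901234567890.123")).toDF("v")

    // With the flag on, an overflowing UDF result becomes null...
    spark.conf.set("spark.sql.decimalOperations.nullOnOverflow", "true")
    df.select(times100(df("v"))).show()    // shows null

    // ...and with the flag off, evaluation raises ArithmeticException.
    spark.conf.set("spark.sql.decimalOperations.nullOnOverflow", "false")
    df.select(times100(df("v"))).collect() // throws ArithmeticException

Before this patch, DecimalConverter always called toPrecision with its default
nullOnOverflow = true, so the flag-off case above silently produced null instead
of failing.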