Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,9 @@ object CatalystTypeConverters {

private class DecimalConverter(dataType: DecimalType)
extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {

private val nullOnOverflow = SQLConf.get.decimalOperationsNullOnOverflow

override def toCatalystImpl(scalaValue: Any): Decimal = {
val decimal = scalaValue match {
case d: BigDecimal => Decimal(d)
Expand All @@ -353,7 +356,7 @@ object CatalystTypeConverters {
s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) "
+ s"cannot be converted to ${dataType.catalogString}")
}
decimal.toPrecision(dataType.precision, dataType.scale)
decimal.toPrecision(dataType.precision, dataType.scale, Decimal.ROUND_HALF_UP, nullOnOverflow)
}
override def toScala(catalystValue: Decimal): JavaBigDecimal = {
if (catalystValue == null) null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1106,6 +1106,7 @@ abstract class RoundBase(child: Expression, scale: Expression,
dataType match {
case DecimalType.Fixed(_, s) =>
val decimal = input1.asInstanceOf[Decimal]
// Overflow cannot happen, so no need to control nullOnOverflow
decimal.toPrecision(decimal.precision, s, mode)
case ByteType =>
BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -414,20 +414,12 @@ final class Decimal extends Ordered[Decimal] with Serializable {

def floor: Decimal = if (scale == 0) this else {
val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
val res = toPrecision(newPrecision, 0, ROUND_FLOOR)
if (res == null) {
throw new AnalysisException(s"Overflow when setting precision to $newPrecision")
}
res
toPrecision(newPrecision, 0, ROUND_FLOOR, nullOnOverflow = false)
}

def ceil: Decimal = if (scale == 0) this else {
val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
val res = toPrecision(newPrecision, 0, ROUND_CEILING)
if (res == null) {
throw new AnalysisException(s"Overflow when setting precision to $newPrecision")
}
res
toPrecision(newPrecision, 0, ROUND_CEILING, nullOnOverflow = false)
Copy link
Member

@viirya viirya Jul 15, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This throws exception on overflow, I think it preserves current behavior, but don't we want to respect decimalOperationsNullOnOverflow here too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, I don't think that is really an issue here. I mean, I see no way ceil and floor can produce an overflow, they rather reduce the needed precision. So I think this case cannot really happen and it is fine to just throw an exception

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ import java.util.Locale

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DecimalType, IntegerType, StringType}

class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

Expand Down Expand Up @@ -54,4 +55,26 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, false :: Nil).genCode(ctx)
assert(ctx.inlinedMutableStates.isEmpty)
}

test("SPARK-28369: honor nullOnOverflow config for ScalaUDF") {
withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "false") {
val udf = ScalaUDF(
(a: java.math.BigDecimal) => a.multiply(new java.math.BigDecimal(100)),
DecimalType.SYSTEM_DEFAULT,
Literal(BigDecimal("12345678901234567890.123")) :: Nil, false :: Nil)
val e1 = intercept[ArithmeticException](udf.eval())
assert(e1.getMessage.contains("cannot be represented as Decimal"))
val e2 = intercept[SparkException] {
checkEvaluationWithUnsafeProjection(udf, null)
}
assert(e2.getCause.isInstanceOf[ArithmeticException])
}
withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "true") {
val udf = ScalaUDF(
(a: java.math.BigDecimal) => a.multiply(new java.math.BigDecimal(100)),
DecimalType.SYSTEM_DEFAULT,
Literal(BigDecimal("12345678901234567890.123")) :: Nil, false :: Nil)
checkEvaluation(udf, null)
}
}
}