Skip to content

Commit

Permalink
[SPARK-41554] fix changing of Decimal scale when scale decreased by more than 18
Browse files Browse the repository at this point in the history

This is a backport PR for #39099

Closes #39813 from fe2s/branch-3.3-fix-decimal-scaling.

Authored-by: oleksii.diagiliev <oleksii.diagiliev@workday.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
  • Loading branch information
oleksii.diagiliev authored and srowen committed Feb 3, 2023
1 parent 6e0dfa9 commit 2d539c5
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -397,30 +397,42 @@ final class Decimal extends Ordered[Decimal] with Serializable {
if (scale < _scale) {
// Easier case: we just need to divide our scale down
val diff = _scale - scale
val pow10diff = POW_10(diff)
// % and / always round to 0
val droppedDigits = longVal % pow10diff
longVal /= pow10diff
roundMode match {
case ROUND_FLOOR =>
if (droppedDigits < 0) {
longVal += -1L
}
case ROUND_CEILING =>
if (droppedDigits > 0) {
longVal += 1L
}
case ROUND_HALF_UP =>
if (math.abs(droppedDigits) * 2 >= pow10diff) {
longVal += (if (droppedDigits < 0) -1L else 1L)
}
case ROUND_HALF_EVEN =>
val doubled = math.abs(droppedDigits) * 2
if (doubled > pow10diff || doubled == pow10diff && longVal % 2 != 0) {
longVal += (if (droppedDigits < 0) -1L else 1L)
}
case _ =>
throw QueryExecutionErrors.unsupportedRoundingMode(roundMode)
// longVal stores at most MAX_LONG_DIGITS digits, so when more digits than that are dropped
// nothing of the magnitude survives and only the rounding adjustment remains. Otherwise
// divide by the matching power of 10 and round based on the digits that were dropped.
if (diff > MAX_LONG_DIGITS) {
  longVal = roundMode match {
    // Any negative value floors to -1 unit of the new scale, everything else to 0
    case ROUND_FLOOR => if (longVal < 0) -1L else 0L
    // Any positive value is lifted to +1 unit of the new scale, everything else to 0
    case ROUND_CEILING => if (longVal > 0) 1L else 0L
    // |value| is below half a unit of the new scale, so half-based modes always give 0
    case ROUND_HALF_UP | ROUND_HALF_EVEN => 0L
    // Report the same error as the branch below, so the failure for an unsupported
    // rounding mode does not depend on how large the scale delta is
    case _ => throw QueryExecutionErrors.unsupportedRoundingMode(roundMode)
  }
} else {
  val pow10diff = POW_10(diff)
  // % and / always round to 0
  val droppedDigits = longVal % pow10diff
  longVal /= pow10diff
  roundMode match {
    case ROUND_FLOOR =>
      // Truncation moved a negative value towards 0; step it back down
      if (droppedDigits < 0) {
        longVal += -1L
      }
    case ROUND_CEILING =>
      // Truncation moved a positive value towards 0; step it back up
      if (droppedDigits > 0) {
        longVal += 1L
      }
    case ROUND_HALF_UP =>
      // Dropped part is at least half a unit: round away from 0
      if (math.abs(droppedDigits) * 2 >= pow10diff) {
        longVal += (if (droppedDigits < 0) -1L else 1L)
      }
    case ROUND_HALF_EVEN =>
      // Round away from 0 above half a unit; on an exact tie only when the kept value is odd
      val doubled = math.abs(droppedDigits) * 2
      if (doubled > pow10diff || doubled == pow10diff && longVal % 2 != 0) {
        longVal += (if (droppedDigits < 0) -1L else 1L)
      }
    case _ =>
      throw QueryExecutionErrors.unsupportedRoundingMode(roundMode)
  }
}
} else if (scale > _scale) {
// We might be able to multiply longVal by a power of 10 and not overflow, but if not,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ import org.apache.spark.sql.types.Decimal._
import org.apache.spark.unsafe.types.UTF8String

class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper {

// Rounding modes that the tests in this suite iterate over when checking scale changes.
val allSupportedRoundModes = Seq(ROUND_HALF_UP, ROUND_HALF_EVEN, ROUND_CEILING, ROUND_FLOOR)

/** Check that a Decimal has the given string representation, precision and scale */
private def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = {
assert(d.toString === string)
Expand Down Expand Up @@ -222,7 +225,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper
}

test("changePrecision/toPrecision on compact decimal should respect rounding mode") {
Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode =>
allSupportedRoundModes.foreach { mode =>
Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n =>
Seq("", "-").foreach { sign =>
val bd = BigDecimal(sign + n)
Expand Down Expand Up @@ -315,4 +318,52 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper
}
}
}

// 18 is the maximum number of digits that fit in Decimal's compact long
test("SPARK-41554: decrease/increase scale by 18 and more on compact decimal") {
  val precision = 38

  // Builds a compact Decimal from `unscaled` at `scaleFrom`, moves it to `scaleTo` with the
  // given rounding mode, and compares the outcome with BigDecimal.setScale as the reference.
  def checkScaleChange(unscaled: Long, scaleFrom: Int, scaleTo: Int,
      roundMode: BigDecimal.RoundingMode.Value): Unit = {
    val decimal = Decimal(unscaled, precision, scaleFrom)
    checkCompact(decimal, true)
    decimal.changePrecision(precision, scaleTo, roundMode)
    val expected = BigDecimal(unscaled, scaleFrom).setScale(scaleTo, roundMode)
    assert(decimal.toBigDecimal === expected,
      s"unscaled: $unscaled, scaleFrom: $scaleFrom, scaleTo: $scaleTo, mode: $roundMode")
  }

  val unscaledNums = Seq(
    0L, 1L, 10L, 51L, 123L, 523L,
    // 18 digits
    912345678901234567L,
    112345678901234567L,
    512345678901234567L
  )

  // (from, to) scale pairs such as (38, 18) or (-20, -2): every start scale combined with
  // every delta, in both directions and with both signs
  val steps = Seq(38, 20, 19, 18)
  val scalePairs = steps.flatMap { scale =>
    steps.flatMap { delta =>
      val shifted = scale - delta
      Seq((scale, shifted), (-scale, -shifted), (shifted, scale), (-shifted, -scale))
    }
  }

  unscaledNums.foreach { unscaled =>
    allSupportedRoundModes.foreach { mode =>
      scalePairs.foreach { case (scaleFrom, scaleTo) =>
        Seq(1L, -1L).foreach { sign =>
          val signedUnscaled = unscaled * sign
          val usesNegativeScale = scaleFrom < 0 || scaleTo < 0
          if (usesNegativeScale) {
            // The legacy flag is switched on for the cases that involve a negative scale
            withSQLConf(SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED.key -> "true") {
              checkScaleChange(signedUnscaled, scaleFrom, scaleTo, mode)
            }
          } else {
            checkScaleChange(signedUnscaled, scaleFrom, scaleTo, mode)
          }
        }
      }
    }
  }
}

}

0 comments on commit 2d539c5

Please sign in to comment.