Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,30 @@ final class Decimal extends Ordered[Decimal] with Serializable {
if (scale < _scale) {
// Easier case: we just need to divide our scale down
val diff = _scale - scale
val droppedDigits = longVal % POW_10(diff)
longVal /= POW_10(diff)
if (math.abs(droppedDigits) * 2 >= POW_10(diff)) {
longVal += (if (longVal < 0) -1L else 1L)
val pow10diff = POW_10(diff)
// % and / always round to 0
val droppedDigits = longVal % pow10diff
longVal /= pow10diff
roundMode match {
case ROUND_FLOOR =>
if (droppedDigits < 0) {
longVal += -1L
}
case ROUND_CEILING =>
if (droppedDigits > 0) {
longVal += 1L
}
case ROUND_HALF_UP =>
if (math.abs(droppedDigits) * 2 >= pow10diff) {
longVal += (if (droppedDigits < 0) -1L else 1L)
}
case ROUND_HALF_EVEN =>
val doubled = math.abs(droppedDigits) * 2
if (doubled > pow10diff || doubled == pow10diff && longVal % 2 != 0) {
longVal += (if (droppedDigits < 0) -1L else 1L)
}
case _ =>
sys.error(s"Not supported rounding mode: $roundMode")
}
} else if (scale > _scale) {
// We might be able to multiply longVal by a power of 10 and not overflow, but if not,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.types
import org.scalatest.PrivateMethodTester

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.Decimal._

class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
/** Check that a Decimal has the given string representation, precision and scale */
Expand Down Expand Up @@ -191,4 +192,18 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
assert(new Decimal().set(100L, 10, 0).toUnscaledLong === 100L)
assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue)
}

test("changePrecision() on compact decimal should respect rounding mode") {
Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode =>
Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n =>
Seq("", "-").foreach { sign =>
val bd = BigDecimal(sign + n)
val unscaled = (bd * 10).toLongExact
val d = Decimal(unscaled, 8, 1)
assert(d.changePrecision(10, 0, mode))
assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode")
}
}
}
}
}