From cf0e445b600baad201062e77651bee3d907b42fd Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 19 Sep 2016 14:47:23 -0700 Subject: [PATCH 1/3] changePrecision() on compact decimal should respect rounding mode --- .../org/apache/spark/sql/types/Decimal.scala | 28 ++++++++++++++++--- .../apache/spark/sql/types/DecimalSuite.scala | 13 +++++++++ 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index cc8175c0a366..f25d672c36f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -242,10 +242,30 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (scale < _scale) { // Easier case: we just need to divide our scale down val diff = _scale - scale - val droppedDigits = longVal % POW_10(diff) - longVal /= POW_10(diff) - if (math.abs(droppedDigits) * 2 >= POW_10(diff)) { - longVal += (if (longVal < 0) -1L else 1L) + val pow10diff = POW_10(diff) + // % and / always round to 0 + val droppedDigits = longVal % pow10diff + longVal /= pow10diff + roundMode match { + case ROUND_FLOOR => + if (droppedDigits < 0) { + longVal += -1L + } + case ROUND_CEILING => + if (droppedDigits > 0) { + longVal += 1L + } + case ROUND_HALF_UP => + if (math.abs(droppedDigits) * 2 >= pow10diff) { + longVal += (if (longVal < 0) -1L else 1L) + } + case ROUND_HALF_EVEN => + val doubled = math.abs(droppedDigits) * 2 + if (doubled > pow10diff || doubled == pow10diff && longVal % 2 != 0) { + longVal += (if (longVal < 0) -1L else 1L) + } + case _ => + sys.error(s"Not supported rounding mode: $roundMode") } } else if (scale > _scale) { // We might be able to multiply longVal by a power of 10 and not overflow, but if not, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index a10c0e39eb68..fca088fdfc76 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.types import org.scalatest.PrivateMethodTester import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.Decimal._ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { /** Check that a Decimal has the given string representation, precision and scale */ @@ -191,4 +192,16 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { assert(new Decimal().set(100L, 10, 0).toUnscaledLong === 100L) assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue) } + + test("changePrecision() on compact decimal should respect rounding mode") { + Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode => + Seq("1.0", "1.1", "1.6", "2.5", "5.5", "-1.0", "-1.1", "-1.6", "-2.5", "-5.5").foreach { n => + val bd = BigDecimal(n) + val unscaled = (bd * 10).toLongExact + val d = Decimal(unscaled, 8, 1) + assert(d.changePrecision(10, 0, mode)) + assert(d.toString === bd.setScale(0, mode).toString(), s"num: $n, mode: $mode") + } + } + } } From bb15eefed57ea32551d8b7985d9e54d4a512afe6 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 21 Sep 2016 12:57:50 -0700 Subject: [PATCH 2/3] fix bug --- .../scala/org/apache/spark/sql/types/Decimal.scala | 4 ++-- .../org/apache/spark/sql/types/DecimalSuite.scala | 14 ++++++++------ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index f25d672c36f5..70859052872d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -257,12 +257,12 @@ final class Decimal extends Ordered[Decimal] with Serializable { } case ROUND_HALF_UP => if (math.abs(droppedDigits) * 2 >= pow10diff) { - longVal += (if (longVal < 0) -1L else 1L) + longVal += (if (droppedDigits < 0) -1L else 1L) } case ROUND_HALF_EVEN => val doubled = math.abs(droppedDigits) * 2 if (doubled > pow10diff || doubled == pow10diff && longVal % 2 != 0) { - longVal += (if (longVal < 0) -1L else 1L) + longVal += (if (droppedDigits < 0) -1L else 1L) } case _ => sys.error(s"Not supported rounding mode: $roundMode") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index fca088fdfc76..102f954df63a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -195,12 +195,14 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { test("changePrecision() on compact decimal should respect rounding mode") { Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode => - Seq("1.0", "1.1", "1.6", "2.5", "5.5", "-1.0", "-1.1", "-1.6", "-2.5", "-5.5").foreach { n => - val bd = BigDecimal(n) - val unscaled = (bd * 10).toLongExact - val d = Decimal(unscaled, 8, 1) - assert(d.changePrecision(10, 0, mode)) - assert(d.toString === bd.setScale(0, mode).toString(), s"num: $n, mode: $mode") + Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n => + Seq("", "-").foreach { sigh => + val bd = BigDecimal(sigh + n) + val unscaled = (bd * 10).toLongExact + val d = Decimal(unscaled, 8, 1) + assert(d.changePrecision(10, 0, mode)) + assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sigh$n, mode: $mode") + } } } } From 7008cd36f21b14817c28dbb9a4ea9547174f0f01 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 21 Sep 2016 15:41:11 -0700 Subject: [PATCH 3/3] Update DecimalSuite.scala Fix typo --- .../scala/org/apache/spark/sql/types/DecimalSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 102f954df63a..52d0692524d0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -196,12 +196,12 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { test("changePrecision() on compact decimal should respect rounding mode") { Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode => Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n => - Seq("", "-").foreach { sigh => - val bd = BigDecimal(sigh + n) + Seq("", "-").foreach { sign => + val bd = BigDecimal(sign + n) val unscaled = (bd * 10).toLongExact val d = Decimal(unscaled, 8, 1) assert(d.changePrecision(10, 0, mode)) - assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sigh$n, mode: $mode") + assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode") } } }