Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,10 @@ object RowEncoder {
"fromJavaDate",
inputObject :: Nil)

case _: DecimalType =>
case d: DecimalType =>
StaticInvoke(
Decimal.getClass,
DecimalType.SYSTEM_DEFAULT,
d,
"fromDecimal",
inputObject :: Nil)

Expand Down Expand Up @@ -162,7 +162,7 @@ object RowEncoder {
* `org.apache.spark.sql.types.Decimal`.
*/
private def externalDataTypeForInput(dt: DataType): DataType = dt match {
// In order to support both Decimal and java BigDecimal in external row, we make this
// In order to support both Decimal and java/scala BigDecimal in external row, we make this
// as java.lang.Object.
case _: DecimalType => ObjectType(classOf[java.lang.Object])
case _ => externalDataTypeFor(dt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ object Decimal {
def fromDecimal(value: Any): Decimal = {
value match {
case j: java.math.BigDecimal => apply(j)
case d: BigDecimal => apply(d)
case d: Decimal => d
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
encodeDecodeTest(new java.lang.Double(-3.7), "boxed double")

encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal")
// encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal")
encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal")

encodeDecodeTest(Decimal("32131413.211321313"), "catalyst decimal")

Expand Down Expand Up @@ -336,6 +336,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
Arrays.deepEquals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]])
case (b1: Array[_], b2: Array[_]) =>
Arrays.equals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]])
case (left: Comparable[Any], right: Comparable[Any]) => left.compareTo(right) == 0
case _ => input == convertedBack
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,21 +143,38 @@ class RowEncoderSuite extends SparkFunSuite {
assert(input.getStruct(0) == convertedBack.getStruct(0))
}

test("encode/decode Decimal") {
test("encode/decode decimal type") {
val schema = new StructType()
.add("int", IntegerType)
.add("string", StringType)
.add("double", DoubleType)
.add("decimal", DecimalType.SYSTEM_DEFAULT)
.add("java_decimal", DecimalType.SYSTEM_DEFAULT)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel it is still good to keep int, string, and double columns?

.add("scala_decimal", DecimalType.SYSTEM_DEFAULT)
.add("catalyst_decimal", DecimalType.SYSTEM_DEFAULT)

val encoder = RowEncoder(schema)

val input: Row = Row(100, "test", 0.123, Decimal(1234.5678))
val javaDecimal = new java.math.BigDecimal("1234.5678")
val scalaDecimal = BigDecimal("1234.5678")
val catalystDecimal = Decimal("1234.5678")

val input = Row(100, "test", 0.123, javaDecimal, scalaDecimal, catalystDecimal)
val row = encoder.toRow(input)
val convertedBack = encoder.fromRow(row)
// Decimal inside external row will be converted back to Java BigDecimal when decoding.
assert(input.get(3).asInstanceOf[Decimal].toJavaBigDecimal
.compareTo(convertedBack.getDecimal(3)) == 0)
// Decimal will be converted back to Java BigDecimal when decoding.
assert(convertedBack.getDecimal(3).compareTo(javaDecimal) == 0)
assert(convertedBack.getDecimal(4).compareTo(scalaDecimal.bigDecimal) == 0)
assert(convertedBack.getDecimal(5).compareTo(catalystDecimal.toJavaBigDecimal) == 0)
}

test("RowEncoder should preserve decimal precision and scale") {
val schema = new StructType().add("decimal", DecimalType(10, 5), false)
val encoder = RowEncoder(schema)
val decimal = Decimal("67123.45")
val input = Row(decimal)
val row = encoder.toRow(input)

assert(row.toSeq(schema).head == decimal)
}

test("RowEncoder should preserve schema nullability") {
Expand Down