From 9e60762d830c320967742d80cb17c55631f6b11a Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 26 Jul 2017 13:34:31 +0900 Subject: [PATCH 1/5] Add DecimalType support to ArrowWriter. --- .../sql/execution/arrow/ArrowWriter.scala | 21 +++++ .../arrow/ArrowConvertersSuite.scala | 82 ++++++++++++++++++- .../execution/arrow/ArrowWriterSuite.scala | 2 + 3 files changed, 103 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 11ba04d2ce9a..33e532260ecd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -53,6 +53,8 @@ object ArrowWriter { case (LongType, vector: NullableBigIntVector) => new LongWriter(vector) case (FloatType, vector: NullableFloat4Vector) => new FloatWriter(vector) case (DoubleType, vector: NullableFloat8Vector) => new DoubleWriter(vector) + case (DecimalType.Fixed(precision, scale), vector: NullableDecimalVector) => + new DecimalWriter(vector, precision, scale) case (StringType, vector: NullableVarCharVector) => new StringWriter(vector) case (BinaryType, vector: NullableVarBinaryVector) => new BinaryWriter(vector) case (ArrayType(_, _), vector: ListVector) => @@ -224,6 +226,25 @@ private[arrow] class DoubleWriter(val valueVector: NullableFloat8Vector) extends } } +private[arrow] class DecimalWriter( + val valueVector: NullableDecimalVector, + precision: Int, + scale: Int) extends ArrowFieldWriter { + + override def valueMutator: NullableDecimalVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setIndexDefined(count) + val decimal = input.getDecimal(ordinal, precision, scale) + decimal.changePrecision(precision, scale) + DecimalUtility.writeBigDecimalToArrowBuf(decimal.toJavaBigDecimal, valueVector.getBuffer, count) + } +} + private[arrow] class StringWriter(val valueVector: NullableVarCharVector) extends ArrowFieldWriter { override def valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 4893b52f240e..464bb97e67e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -32,7 +32,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType} +import org.apache.spark.sql.types.{BinaryType, Decimal, IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -391,6 +391,85 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { collectAndValidate(df, json, "floating_point-double_precision.json") } + test("decimal conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_d", + | "nullable" : true, + | "type" : { + | "name" : "decimal", + | "precision" : 38, + | "scale" : 18 + | }, + | "children" : [ ], + | "typeLayout" : { + 
| "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | }, { + | "name" : "b_d", + | "nullable" : true, + | "type" : { + | "name" : "decimal", + | "precision" : 38, + | "scale" : 18 + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 6, + | "columns" : [ { + | "name" : "a_d", + | "count" : 6, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "DATA" : [ + | 1.000000000000000000, + | 2.000000000000000000, + | 0.010000000000000000, + | 200.000000000000000000, + | 0.000100000000000000, + | 20000.000000000000000000 ] + | }, { + | "name" : "b_d", + | "count" : 6, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "DATA" : [ + | 1.100000000000000000, + | 0E-18, + | 0E-18, + | 2.200000000000000000, + | 0E-18, + | 3.300000000000000000 ] + | } ] + | } ] + |} + """.stripMargin + + val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0).map(Decimal(_)) + val b_d = List(Some(Decimal(1.1)), None, None, Some(Decimal(2.2)), None, Some(Decimal(3.3))) + val df = a_d.zip(b_d).toDF("a_d", "b_d") + + collectAndValidate(df, json, "decimalData.json") + } + test("index conversion") { val data = List[Int](1, 2, 3, 4, 5, 6) val json = @@ -1482,7 +1561,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) } - runUnsupported { decimalData.toArrowPayload.collect() } runUnsupported { mapData.toDF().toArrowPayload.collect() } runUnsupported { complexData.toArrowPayload.collect() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index e9a629315f5f..4d7d36d29fa5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -49,6 +49,7 @@ class ArrowWriterSuite extends SparkFunSuite { case LongType => reader.getLong(rowId) case FloatType => reader.getFloat(rowId) case DoubleType => reader.getDouble(rowId) + case DecimalType.Fixed(precision, scale) => reader.getDecimal(rowId, precision, scale) case StringType => reader.getUTF8String(rowId) case BinaryType => reader.getBinary(rowId) } @@ -64,6 +65,7 @@ class ArrowWriterSuite extends SparkFunSuite { check(LongType, Seq(1L, 2L, null, 4L)) check(FloatType, Seq(1.0f, 2.0f, null, 4.0f)) check(DoubleType, Seq(1.0d, 2.0d, null, 4.0d)) + check(DecimalType.SYSTEM_DEFAULT, Seq(Decimal(1), Decimal(2), null, Decimal(4))) check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString)) check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes())) } From 4a697b3048218df5ae8167ba88ff9a1e9bcba60b Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 12 Dec 2017 14:41:36 +0900 Subject: [PATCH 2/5] Use Arrow 0.8 APIs and fix a test. 
--- .../sql/execution/arrow/ArrowWriter.scala | 11 ++-- .../arrow/ArrowConvertersSuite.scala | 50 ++++++------------- 2 files changed, 20 insertions(+), 41 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 259dee3b45d7..25c6309c5c49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -53,7 +53,7 @@ object ArrowWriter { case (LongType, vector: BigIntVector) => new LongWriter(vector) case (FloatType, vector: Float4Vector) => new FloatWriter(vector) case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector) - case (DecimalType.Fixed(precision, scale), vector: NullableDecimalVector) => + case (DecimalType.Fixed(precision, scale), vector: DecimalVector) => new DecimalWriter(vector, precision, scale) case (StringType, vector: VarCharVector) => new StringWriter(vector) case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector) @@ -217,21 +217,18 @@ private[arrow] class DoubleWriter(val valueVector: Float8Vector) extends ArrowFi } private[arrow] class DecimalWriter( - val valueVector: NullableDecimalVector, + val valueVector: DecimalVector, precision: Int, scale: Int) extends ArrowFieldWriter { - override def valueMutator: NullableDecimalVector#Mutator = valueVector.getMutator() - override def setNull(): Unit = { - valueMutator.setNull(count) + valueVector.setNull(count) } override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { - valueMutator.setIndexDefined(count) val decimal = input.getDecimal(ordinal, precision, scale) decimal.changePrecision(precision, scale) - DecimalUtility.writeBigDecimalToArrowBuf(decimal.toJavaBigDecimal, valueVector.getBuffer, count) + valueVector.setSafe(count, decimal.toJavaBigDecimal) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 02754a81cda6..5a32809034e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -311,40 +311,22 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "schema" : { | "fields" : [ { | "name" : "a_d", - | "nullable" : true, | "type" : { | "name" : "decimal", | "precision" : 38, | "scale" : 18 | }, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } + | "nullable" : true, + | "children" : [ ] | }, { | "name" : "b_d", - | "nullable" : true, | "type" : { | "name" : "decimal", | "precision" : 38, | "scale" : 18 | }, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } + | "nullable" : true, + | "children" : [ ] | } ] | }, | "batches" : [ { @@ -354,23 +336,23 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "count" : 6, | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], | "DATA" : [ - | 1.000000000000000000, - | 2.000000000000000000, - | 0.010000000000000000, - | 200.000000000000000000, - | 0.000100000000000000, - | 20000.000000000000000000 ] + | 
"1000000000000000000", + | "2000000000000000000", + | "10000000000000000", + | "200000000000000000000", + | "100000000000000", + | "20000000000000000000000" ] | }, { | "name" : "b_d", | "count" : 6, | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], | "DATA" : [ - | 1.100000000000000000, - | 0E-18, - | 0E-18, - | 2.200000000000000000, - | 0E-18, - | 3.300000000000000000 ] + | "1100000000000000000", + | "0", + | "0", + | "2200000000000000000", + | "0", + | "3300000000000000000" ] | } ] | } ] |} From bcadac822d0204091ca25c8b03a812826b7319da Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 18 Dec 2017 15:46:30 +0900 Subject: [PATCH 3/5] Fix Python side. --- python/pyspark/sql/tests.py | 47 +++++++++++++++++++++++++++---------- python/pyspark/sql/types.py | 2 +- 2 files changed, 35 insertions(+), 14 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6fdfda1cc831..39b0de1968cd 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3141,6 +3141,7 @@ class ArrowTests(ReusedSQLTestCase): @classmethod def setUpClass(cls): from datetime import datetime + from decimal import Decimal ReusedSQLTestCase.setUpClass() # Synchronize default timezone between Python and Java @@ -3157,11 +3158,15 @@ def setUpClass(cls): StructField("3_long_t", LongType(), True), StructField("4_float_t", FloatType(), True), StructField("5_double_t", DoubleType(), True), - StructField("6_date_t", DateType(), True), - StructField("7_timestamp_t", TimestampType(), True)]) - cls.data = [(u"a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), - (u"b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), - (u"c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + StructField("6_decimal_t", DecimalType(38, 18), True), + StructField("7_date_t", DateType(), True), + StructField("8_timestamp_t", TimestampType(), True)]) + cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"), + datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), + (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"), + datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), + (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"), + datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] @classmethod def tearDownClass(cls): @@ -3292,7 +3297,7 @@ def test_createDataFrame_respect_session_timezone(self): self.assertNotEqual(result_ny, result_la) # Correct result_la by adjusting 3 hours difference between Los Angeles and New York - result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '7_timestamp_t' else v + result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '8_timestamp_t' else v for k, v in row.asDict().items()}) for row in result_la] self.assertEqual(result_ny, result_la_corrected) @@ -3316,11 +3321,11 @@ def test_createDataFrame_with_incorrect_schema(self): def test_createDataFrame_with_names(self): pdf = self.create_pandas_data_frame() # Test that schema as a list of column names gets applied - df = self.spark.createDataFrame(pdf, schema=list('abcdefg')) - self.assertEquals(df.schema.fieldNames(), list('abcdefg')) + df = self.spark.createDataFrame(pdf, schema=list('abcdefgh')) + self.assertEquals(df.schema.fieldNames(), list('abcdefgh')) # Test that schema as tuple of column names gets applied - df = self.spark.createDataFrame(pdf, schema=tuple('abcdefg')) - self.assertEquals(df.schema.fieldNames(), list('abcdefg')) + df = self.spark.createDataFrame(pdf, schema=tuple('abcdefgh')) + self.assertEquals(df.schema.fieldNames(), list('abcdefgh')) def 
test_createDataFrame_column_name_encoding(self): import pandas as pd @@ -3343,7 +3348,7 @@ def test_createDataFrame_does_not_modify_input(self): # Some series get converted for Spark to consume, this makes sure input is unchanged pdf = self.create_pandas_data_frame() # Use a nanosecond value to make sure it is not truncated - pdf.ix[0, '7_timestamp_t'] = pd.Timestamp(1) + pdf.ix[0, '8_timestamp_t'] = pd.Timestamp(1) # Integers with nulls will get NaNs filled with 0 and will be casted pdf.ix[1, '2_int_t'] = None pdf_copy = pdf.copy(deep=True) @@ -3513,6 +3518,7 @@ def test_vectorized_udf_basic(self): col('id').alias('long'), col('id').cast('float').alias('float'), col('id').cast('double').alias('double'), + col('id').cast('decimal').alias('decimal'), col('id').cast('boolean').alias('bool')) f = lambda x: x str_f = pandas_udf(f, StringType()) @@ -3520,10 +3526,12 @@ def test_vectorized_udf_basic(self): long_f = pandas_udf(f, LongType()) float_f = pandas_udf(f, FloatType()) double_f = pandas_udf(f, DoubleType()) + decimal_f = pandas_udf(f, DecimalType()) bool_f = pandas_udf(f, BooleanType()) res = df.select(str_f(col('str')), int_f(col('int')), long_f(col('long')), float_f(col('float')), - double_f(col('double')), bool_f(col('bool'))) + double_f(col('double')), decimal_f('decimal'), + bool_f(col('bool'))) self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_null_boolean(self): @@ -3589,6 +3597,16 @@ def test_vectorized_udf_null_double(self): res = df.select(double_f(col('double'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_null_decimal(self): + from decimal import Decimal + from pyspark.sql.functions import pandas_udf, col + data = [(Decimal(3.0),), (Decimal(5.0),), (Decimal(-1.0),), (None,)] + schema = StructType().add("decimal", DecimalType(38, 18)) + df = self.spark.createDataFrame(data, schema) + decimal_f = pandas_udf(lambda x: x, DecimalType(38, 18)) + res = df.select(decimal_f(col('decimal'))) + self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_null_string(self): from pyspark.sql.functions import pandas_udf, col data = [("foo",), (None,), ("bar",), ("bar",)] @@ -3606,6 +3624,7 @@ def test_vectorized_udf_datatype_string(self): col('id').alias('long'), col('id').cast('float').alias('float'), col('id').cast('double').alias('double'), + col('id').cast('decimal').alias('decimal'), col('id').cast('boolean').alias('bool')) f = lambda x: x str_f = pandas_udf(f, 'string') @@ -3613,10 +3632,12 @@ def test_vectorized_udf_datatype_string(self): long_f = pandas_udf(f, 'long') float_f = pandas_udf(f, 'float') double_f = pandas_udf(f, 'double') + decimal_f = pandas_udf(f, 'decimal(38, 18)') bool_f = pandas_udf(f, 'boolean') res = df.select(str_f(col('str')), int_f(col('int')), long_f(col('long')), float_f(col('float')), - double_f(col('double')), bool_f(col('bool'))) + double_f(col('double')), decimal_f('decimal'), + bool_f(col('bool'))) self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_complex(self): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 46d9a417414b..ecffae72c414 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1617,7 +1617,7 @@ def to_arrow_type(dt): elif type(dt) == DoubleType: arrow_type = pa.float64() elif type(dt) == DecimalType: - arrow_type = pa.decimal(dt.precision, dt.scale) + arrow_type = pa.decimal128(dt.precision, dt.scale) elif type(dt) == StringType: arrow_type = pa.string() elif type(dt) == DateType: From 
e29f8330e455e6bf66d2c69f2875a70ae2c71cdb Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 18 Dec 2017 17:06:14 +0900 Subject: [PATCH 4/5] Modify tests for unsupported types. --- python/pyspark/sql/tests.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 39b0de1968cd..e0d666f95550 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3194,10 +3194,11 @@ def create_pandas_data_frame(self): return pd.DataFrame(data=data_dict) def test_unsupported_datatype(self): - schema = StructType([StructField("decimal", DecimalType(), True)]) + schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): - self.assertRaises(Exception, lambda: df.toPandas()) + with self.assertRaisesRegexp(Exception, 'Unsupported data type'): + df.toPandas() def test_null_conversion(self): df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + @@ -3733,12 +3734,12 @@ def test_vectorized_udf_varargs(self): def test_vectorized_udf_unsupported_types(self): from pyspark.sql.functions import pandas_udf, col - schema = StructType([StructField("dt", DecimalType(), True)]) + schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) - f = pandas_udf(lambda x: x, DecimalType()) + f = pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): - df.select(f(col('dt'))).collect() + df.select(f(col('map'))).collect() def test_vectorized_udf_null_date(self): from pyspark.sql.functions import pandas_udf, col @@ -4032,7 +4033,8 @@ def test_wrong_args(self): def test_unsupported_types(self): from pyspark.sql.functions import pandas_udf, col, PandasUDFType schema = StructType( - [StructField("id", LongType(), True), StructField("dt", DecimalType(), True)]) + [StructField("id", LongType(), True), + StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(1, None,)], schema=schema) f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUP_MAP) with QuietTest(self.sc): From 025f2987c54fb3c9c7de36c525a8caaa72d1f3ee Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 25 Dec 2017 14:38:34 +0900 Subject: [PATCH 5/5] Check the return value of `changePrecision()` for the case of overflow. 
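
Decimal.changePrecision() mutates the value in place and returns false when the value
cannot be represented with the requested precision and scale; the previous revision
ignored that return value. DecimalWriter now writes a null slot when the conversion
fails, and the test gains a row whose b_d value overflows decimal(38, 18). A small
illustration of the overflow case the updated test covers, assuming Spark's
org.apache.spark.sql.types.Decimal API (the sample values mirror the ones added to
ArrowConvertersSuite in this patch):

    import org.apache.spark.sql.types.Decimal

    val fits = Decimal("3.3")
    assert(fits.changePrecision(38, 18))   // representable as decimal(38, 18)

    // 30 integer digits plus 18 fractional digits need 48 digits of precision,
    // which exceeds 38, so changePrecision() reports failure instead of truncating.
    val overflow = Decimal("123456789012345678901234567890")
    assert(!overflow.changePrecision(38, 18))
    // DecimalWriter.setValue() now calls setNull() in this case, which is why the
    // seventh slot of column b_d has VALIDITY 0 in the expected JSON below.
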
--- .../sql/execution/arrow/ArrowWriter.scala | 7 +++++-- .../arrow/ArrowConvertersSuite.scala | 21 +++++++++++-------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 25c6309c5c49..22b63513548f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -227,8 +227,11 @@ private[arrow] class DecimalWriter( override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { val decimal = input.getDecimal(ordinal, precision, scale) - decimal.changePrecision(precision, scale) - valueVector.setSafe(count, decimal.toJavaBigDecimal) + if (decimal.changePrecision(precision, scale)) { + valueVector.setSafe(count, decimal.toJavaBigDecimal) + } else { + setNull() + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 5a32809034e7..261df06100ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -330,36 +330,39 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | } ] | }, | "batches" : [ { - | "count" : 6, + | "count" : 7, | "columns" : [ { | "name" : "a_d", - | "count" : 6, - | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ], + | "count" : 7, + | "VALIDITY" : [ 1, 1, 1, 1, 1, 1, 1 ], | "DATA" : [ | "1000000000000000000", | "2000000000000000000", | "10000000000000000", | "200000000000000000000", | "100000000000000", - | "20000000000000000000000" ] + | "20000000000000000000000", + | "30000000000000000000" ] | }, { | "name" : "b_d", - | "count" : 6, - | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ], + | "count" : 7, + | "VALIDITY" : [ 1, 0, 0, 1, 0, 1, 0 ], | "DATA" : [ | "1100000000000000000", | "0", | "0", | "2200000000000000000", | "0", - | "3300000000000000000" ] + | "3300000000000000000", + | "0" ] | } ] | } ] |} """.stripMargin - val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0).map(Decimal(_)) - val b_d = List(Some(Decimal(1.1)), None, None, Some(Decimal(2.2)), None, Some(Decimal(3.3))) + val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0, 30.0).map(Decimal(_)) + val b_d = List(Some(Decimal(1.1)), None, None, Some(Decimal(2.2)), None, Some(Decimal(3.3)), + Some(Decimal("123456789012345678901234567890"))) val df = a_d.zip(b_d).toDF("a_d", "b_d") collectAndValidate(df, json, "decimalData.json")