61 changes: 42 additions & 19 deletions python/pyspark/sql/tests.py
@@ -3141,6 +3141,7 @@ class ArrowTests(ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
from datetime import datetime
from decimal import Decimal
ReusedSQLTestCase.setUpClass()

# Synchronize default timezone between Python and Java
@@ -3157,11 +3158,15 @@ def setUpClass(cls):
StructField("3_long_t", LongType(), True),
StructField("4_float_t", FloatType(), True),
StructField("5_double_t", DoubleType(), True),
StructField("6_date_t", DateType(), True),
StructField("7_timestamp_t", TimestampType(), True)])
cls.data = [(u"a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
(u"b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
(u"c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
StructField("6_decimal_t", DecimalType(38, 18), True),
StructField("7_date_t", DateType(), True),
StructField("8_timestamp_t", TimestampType(), True)])
cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"),
datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
(u"b", 2, 20, 0.4, 4.0, Decimal("4.0"),
datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
(u"c", 3, 30, 0.8, 6.0, Decimal("6.0"),
datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]

@classmethod
def tearDownClass(cls):
@@ -3189,10 +3194,11 @@ def create_pandas_data_frame(self):
return pd.DataFrame(data=data_dict)

def test_unsupported_datatype(self):
schema = StructType([StructField("decimal", DecimalType(), True)])
schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
with QuietTest(self.sc):
self.assertRaises(Exception, lambda: df.toPandas())
with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
df.toPandas()

def test_null_conversion(self):
df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
@@ -3292,7 +3298,7 @@ def test_createDataFrame_respect_session_timezone(self):
self.assertNotEqual(result_ny, result_la)

# Correct result_la by adjusting 3 hours difference between Los Angeles and New York
result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '7_timestamp_t' else v
result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '8_timestamp_t' else v
for k, v in row.asDict().items()})
for row in result_la]
self.assertEqual(result_ny, result_la_corrected)
@@ -3316,11 +3322,11 @@ def test_createDataFrame_with_incorrect_schema(self):
def test_createDataFrame_with_names(self):
pdf = self.create_pandas_data_frame()
# Test that schema as a list of column names gets applied
df = self.spark.createDataFrame(pdf, schema=list('abcdefg'))
self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
df = self.spark.createDataFrame(pdf, schema=list('abcdefgh'))
self.assertEquals(df.schema.fieldNames(), list('abcdefgh'))
# Test that schema as tuple of column names gets applied
df = self.spark.createDataFrame(pdf, schema=tuple('abcdefg'))
self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
df = self.spark.createDataFrame(pdf, schema=tuple('abcdefgh'))
self.assertEquals(df.schema.fieldNames(), list('abcdefgh'))

def test_createDataFrame_column_name_encoding(self):
import pandas as pd
@@ -3343,7 +3349,7 @@ def test_createDataFrame_does_not_modify_input(self):
# Some series get converted for Spark to consume, this makes sure input is unchanged
pdf = self.create_pandas_data_frame()
# Use a nanosecond value to make sure it is not truncated
pdf.ix[0, '7_timestamp_t'] = pd.Timestamp(1)
pdf.ix[0, '8_timestamp_t'] = pd.Timestamp(1)
# Integers with nulls will get NaNs filled with 0 and will be casted
pdf.ix[1, '2_int_t'] = None
pdf_copy = pdf.copy(deep=True)
@@ -3513,17 +3519,20 @@ def test_vectorized_udf_basic(self):
col('id').alias('long'),
col('id').cast('float').alias('float'),
col('id').cast('double').alias('double'),
col('id').cast('decimal').alias('decimal'),
col('id').cast('boolean').alias('bool'))
f = lambda x: x
str_f = pandas_udf(f, StringType())
int_f = pandas_udf(f, IntegerType())
long_f = pandas_udf(f, LongType())
float_f = pandas_udf(f, FloatType())
double_f = pandas_udf(f, DoubleType())
decimal_f = pandas_udf(f, DecimalType())
bool_f = pandas_udf(f, BooleanType())
res = df.select(str_f(col('str')), int_f(col('int')),
long_f(col('long')), float_f(col('float')),
double_f(col('double')), bool_f(col('bool')))
double_f(col('double')), decimal_f('decimal'),
bool_f(col('bool')))
self.assertEquals(df.collect(), res.collect())

def test_vectorized_udf_null_boolean(self):
@@ -3589,6 +3598,16 @@ def test_vectorized_udf_null_double(self):
res = df.select(double_f(col('double')))
self.assertEquals(df.collect(), res.collect())

def test_vectorized_udf_null_decimal(self):
from decimal import Decimal
from pyspark.sql.functions import pandas_udf, col
data = [(Decimal(3.0),), (Decimal(5.0),), (Decimal(-1.0),), (None,)]
schema = StructType().add("decimal", DecimalType(38, 18))
df = self.spark.createDataFrame(data, schema)
decimal_f = pandas_udf(lambda x: x, DecimalType(38, 18))
res = df.select(decimal_f(col('decimal')))
self.assertEquals(df.collect(), res.collect())

def test_vectorized_udf_null_string(self):
from pyspark.sql.functions import pandas_udf, col
data = [("foo",), (None,), ("bar",), ("bar",)]
@@ -3606,17 +3625,20 @@ def test_vectorized_udf_datatype_string(self):
col('id').alias('long'),
col('id').cast('float').alias('float'),
col('id').cast('double').alias('double'),
col('id').cast('decimal').alias('decimal'),
col('id').cast('boolean').alias('bool'))
f = lambda x: x
str_f = pandas_udf(f, 'string')
int_f = pandas_udf(f, 'integer')
long_f = pandas_udf(f, 'long')
float_f = pandas_udf(f, 'float')
double_f = pandas_udf(f, 'double')
decimal_f = pandas_udf(f, 'decimal(38, 18)')
bool_f = pandas_udf(f, 'boolean')
res = df.select(str_f(col('str')), int_f(col('int')),
long_f(col('long')), float_f(col('float')),
double_f(col('double')), bool_f(col('bool')))
double_f(col('double')), decimal_f('decimal'),
bool_f(col('bool')))
self.assertEquals(df.collect(), res.collect())

def test_vectorized_udf_complex(self):
@@ -3712,12 +3734,12 @@ def test_vectorized_udf_varargs(self):

def test_vectorized_udf_unsupported_types(self):
from pyspark.sql.functions import pandas_udf, col
schema = StructType([StructField("dt", DecimalType(), True)])
schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
f = pandas_udf(lambda x: x, DecimalType())
f = pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
df.select(f(col('dt'))).collect()
df.select(f(col('map'))).collect()

def test_vectorized_udf_null_date(self):
from pyspark.sql.functions import pandas_udf, col
@@ -4011,7 +4033,8 @@ def test_wrong_args(self):
def test_unsupported_types(self):
from pyspark.sql.functions import pandas_udf, col, PandasUDFType
schema = StructType(
[StructField("id", LongType(), True), StructField("dt", DecimalType(), True)])
[StructField("id", LongType(), True),
StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([(1, None,)], schema=schema)
f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUP_MAP)
with QuietTest(self.sc):
2 changes: 1 addition & 1 deletion python/pyspark/sql/types.py
@@ -1617,7 +1617,7 @@ def to_arrow_type(dt):
elif type(dt) == DoubleType:
arrow_type = pa.float64()
elif type(dt) == DecimalType:
arrow_type = pa.decimal(dt.precision, dt.scale)
arrow_type = pa.decimal128(dt.precision, dt.scale)
Member Author: @wesm @BryanCutler Is this the right way to define the decimal type for Arrow? I also wonder whether there is a limit on precision and scale.

Member: Yes, that's the right way - it is now fixed at 128 bits internally. I believe the Arrow Java limit is the same as Spark's 38/38; not sure whether pyarrow is the same, but I think so.

elif type(dt) == StringType:
arrow_type = pa.string()
elif type(dt) == DateType:
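As a side note to the review thread above, here is a minimal pyarrow sketch of the fixed 128-bit decimal type. The sample values and the (38, 18) precision/scale only mirror the tests in this patch and are illustrative; it assumes a pyarrow version recent enough to expose pa.decimal128.

import pyarrow as pa
from decimal import Decimal

# Arrow stores decimals in a fixed 128-bit representation; precision/scale here
# mirror the DecimalType(38, 18) used in the tests above.
dec_type = pa.decimal128(38, 18)

# Build an Arrow array from Python Decimals, including a null, to confirm the type.
arr = pa.array([Decimal("2.0"), Decimal("4.0"), None], type=dec_type)
print(arr.type)        # decimal128(38, 18)
print(arr.null_count)  # 1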
@@ -53,6 +53,8 @@ object ArrowWriter {
case (LongType, vector: BigIntVector) => new LongWriter(vector)
case (FloatType, vector: Float4Vector) => new FloatWriter(vector)
case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector)
case (DecimalType.Fixed(precision, scale), vector: DecimalVector) =>
new DecimalWriter(vector, precision, scale)
case (StringType, vector: VarCharVector) => new StringWriter(vector)
case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector)
case (DateType, vector: DateDayVector) => new DateWriter(vector)
@@ -214,6 +216,25 @@ private[arrow] class DoubleWriter(val valueVector: Float8Vector) extends ArrowFieldWriter {
}
}

private[arrow] class DecimalWriter(
val valueVector: DecimalVector,
precision: Int,
scale: Int) extends ArrowFieldWriter {

override def setNull(): Unit = {
valueVector.setNull(count)
}

override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
val decimal = input.getDecimal(ordinal, precision, scale)
if (decimal.changePrecision(precision, scale)) {
valueVector.setSafe(count, decimal.toJavaBigDecimal)
} else {
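// changePrecision returns false when the value cannot be represented with the
// requested precision and scale, so such values are written as null.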
setNull()
}
}
}

private[arrow] class StringWriter(val valueVector: VarCharVector) extends ArrowFieldWriter {

override def setNull(): Unit = {
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType}
import org.apache.spark.sql.types.{BinaryType, Decimal, IntegerType, StructField, StructType}
import org.apache.spark.util.Utils


@@ -304,6 +304,70 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
collectAndValidate(df, json, "floating_point-double_precision.json")
}

test("decimal conversion") {
val json =
s"""
|{
| "schema" : {
| "fields" : [ {
| "name" : "a_d",
| "type" : {
| "name" : "decimal",
| "precision" : 38,
| "scale" : 18
| },
| "nullable" : true,
| "children" : [ ]
| }, {
| "name" : "b_d",
| "type" : {
| "name" : "decimal",
| "precision" : 38,
| "scale" : 18
| },
| "nullable" : true,
| "children" : [ ]
| } ]
| },
| "batches" : [ {
| "count" : 7,
| "columns" : [ {
| "name" : "a_d",
| "count" : 7,
| "VALIDITY" : [ 1, 1, 1, 1, 1, 1, 1 ],
| "DATA" : [
| "1000000000000000000",
| "2000000000000000000",
| "10000000000000000",
| "200000000000000000000",
| "100000000000000",
| "20000000000000000000000",
| "30000000000000000000" ]
| }, {
| "name" : "b_d",
| "count" : 7,
| "VALIDITY" : [ 1, 0, 0, 1, 0, 1, 0 ],
| "DATA" : [
| "1100000000000000000",
| "0",
| "0",
| "2200000000000000000",
| "0",
| "3300000000000000000",
| "0" ]
| } ]
| } ]
|}
""".stripMargin

val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0, 30.0).map(Decimal(_))
val b_d = List(Some(Decimal(1.1)), None, None, Some(Decimal(2.2)), None, Some(Decimal(3.3)),
Some(Decimal("123456789012345678901234567890")))
val df = a_d.zip(b_d).toDF("a_d", "b_d")

collectAndValidate(df, json, "decimalData.json")
}

test("index conversion") {
val data = List[Int](1, 2, 3, 4, 5, 6)
val json =
@@ -1153,7 +1217,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
assert(msg.getCause.getClass === classOf[UnsupportedOperationException])
}

runUnsupported { decimalData.toArrowPayload.collect() }
runUnsupported { mapData.toDF().toArrowPayload.collect() }
runUnsupported { complexData.toArrowPayload.collect() }
}
@@ -49,6 +49,7 @@ class ArrowWriterSuite extends SparkFunSuite {
case LongType => reader.getLong(rowId)
case FloatType => reader.getFloat(rowId)
case DoubleType => reader.getDouble(rowId)
case DecimalType.Fixed(precision, scale) => reader.getDecimal(rowId, precision, scale)
case StringType => reader.getUTF8String(rowId)
case BinaryType => reader.getBinary(rowId)
case DateType => reader.getInt(rowId)
@@ -66,6 +67,7 @@
check(LongType, Seq(1L, 2L, null, 4L))
check(FloatType, Seq(1.0f, 2.0f, null, 4.0f))
check(DoubleType, Seq(1.0d, 2.0d, null, 4.0d))
check(DecimalType.SYSTEM_DEFAULT, Seq(Decimal(1), Decimal(2), null, Decimal(4)))
check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString))
check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes()))
check(DateType, Seq(0, 1, 2, null, 4))