increase minimum version of pyarrow to 0.12.1 and remove prior workarounds

BryanCutler committed Apr 9, 2019
1 parent f62f44f commit 87dc661
Showing 10 changed files with 66 additions and 213 deletions.
48 changes: 17 additions & 31 deletions python/pyspark/serializers.py
@@ -260,10 +260,14 @@ def __init__(self, timezone, safecheck, assign_cols_by_name):
self._safecheck = safecheck
self._assign_cols_by_name = assign_cols_by_name

def arrow_to_pandas(self, arrow_column, data_type):
from pyspark.sql.types import _arrow_column_to_pandas, _check_series_localize_timestamps
def arrow_to_pandas(self, arrow_column):
from pyspark.sql.types import _check_series_localize_timestamps

# If the given column is a date type column, creates a series of datetime.date directly
# instead of creating datetime64[ns] as intermediate data to avoid overflow caused by
# datetime64[ns] type handling.
s = arrow_column.to_pandas(date_as_object=True)

s = _arrow_column_to_pandas(arrow_column, data_type)
s = _check_series_localize_timestamps(s, self._timezone)
return s
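
For context, a standalone sketch (not part of this diff) of what
to_pandas(date_as_object=True) buys here, assuming a recent pyarrow is installed:

    import datetime
    import pyarrow as pa

    # A date near the upper bound of the datetime64[ns] range (~2262-04-11); keeping
    # dates as Python datetime.date objects sidesteps the nanosecond-timestamp overflow.
    arr = pa.array([datetime.date(2262, 4, 12), None], type=pa.date32())

    s = arr.to_pandas(date_as_object=True)
    print(s.dtype)  # object
    print(s[0])     # datetime.date(2262, 4, 12), not a datetime64[ns] value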

@@ -275,8 +279,6 @@ def _create_batch(self, series):
:param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
:return: Arrow RecordBatch
"""
import decimal
from distutils.version import LooseVersion
import pandas as pd
import pyarrow as pa
from pyspark.sql.types import _check_series_convert_timestamps_internal
@@ -289,7 +291,6 @@ def create_array(s, t):
def create_array(s, t):
mask = s.isnull()
# Ensure timestamp series are in expected form for Spark internal representation
# TODO: maybe don't need None check anymore as of Arrow 0.9.1
if t is not None and pa.types.is_timestamp(t):
s = _check_series_convert_timestamps_internal(s.fillna(0), self._timezone)
# TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2
@@ -299,14 +300,6 @@ def create_array(s, t):
# TODO: don't need as of Arrow 0.9.1
return pa.Array.from_pandas(s.apply(
lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t)
elif t is not None and pa.types.is_decimal(t) and \
LooseVersion("0.9.0") <= LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
# TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0.
return pa.Array.from_pandas(s.apply(
lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t)
elif LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
# TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0.
return pa.Array.from_pandas(s, mask=mask, type=t)

try:
array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
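
As an aside, the safe= flag above is pyarrow's safe-cast check. An illustrative
sketch (not Spark code), assuming pyarrow >= 0.11 and pandas are available:

    import pandas as pd
    import pyarrow as pa

    s = pd.Series([1.0, 1.5, 2.0])

    # safe=True rejects lossy casts such as float -> int32 when a value has a fraction.
    try:
        pa.Array.from_pandas(s, type=pa.int32(), safe=True)
    except pa.ArrowInvalid as e:
        print("unsafe cast rejected:", e)

    # safe=False lets Arrow truncate the values instead.
    print(pa.Array.from_pandas(s, type=pa.int32(), safe=False))  # 1, 1, 2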
@@ -340,12 +333,7 @@ def create_array(s, t):
for i, field in enumerate(t)]

struct_arrs, struct_names = zip(*arrs_names)

# TODO: from_arrays args switched for v0.9.0, remove when bump min pyarrow version
if LooseVersion(pa.__version__) < LooseVersion("0.9.0"):
arrs.append(pa.StructArray.from_arrays(struct_names, struct_arrs))
else:
arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
else:
arrs.append(create_array(s, t))
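
The call above uses the argument order introduced in pyarrow 0.9.0 (child arrays
first, field names second), which is what the removed version branch special-cased.
A quick sketch, for illustration only:

    import pyarrow as pa

    ids = pa.array([1, 2, 3], type=pa.int64())
    labels = pa.array(["a", "b", "c"], type=pa.string())

    # pyarrow >= 0.9.0 signature: StructArray.from_arrays(arrays, names)
    struct = pa.StructArray.from_arrays([ids, labels], ["id", "label"])
    print(struct.type)  # struct<id: int64, label: string>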

@@ -365,10 +353,8 @@ def load_stream(self, stream):
"""
batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
import pyarrow as pa
from pyspark.sql.types import from_arrow_type
for batch in batches:
yield [self.arrow_to_pandas(c, from_arrow_type(c.type))
for c in pa.Table.from_batches([batch]).itercolumns()]
yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]

def __repr__(self):
return "ArrowStreamPandasSerializer"
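
To illustrate the per-column path load_stream now takes, a self-contained sketch
(made-up data, not Spark code) that turns one record batch into a list of pandas
Series, one per column:

    import pandas as pd
    import pyarrow as pa

    pdf = pd.DataFrame({"x": [1, 2, 3], "y": ["a", "b", "c"]})
    batch = pa.RecordBatch.from_pandas(pdf, preserve_index=False)

    table = pa.Table.from_batches([batch])
    series = [col.to_pandas() for col in table.itercolumns()]
    print([s.tolist() for s in series])  # [[1, 2, 3], ['a', 'b', 'c']]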
@@ -384,17 +370,17 @@ def __init__(self, timezone, safecheck, assign_cols_by_name, df_for_struct=False
.__init__(timezone, safecheck, assign_cols_by_name)
self._df_for_struct = df_for_struct

def arrow_to_pandas(self, arrow_column, data_type):
from pyspark.sql.types import StructType, \
_arrow_column_to_pandas, _check_dataframe_localize_timestamps
def arrow_to_pandas(self, arrow_column):
import pyarrow.types as types

if self._df_for_struct and type(data_type) == StructType:
if self._df_for_struct and types.is_struct(arrow_column.type):
import pandas as pd
series = [_arrow_column_to_pandas(column, field.dataType).rename(field.name)
for column, field in zip(arrow_column.flatten(), data_type)]
s = _check_dataframe_localize_timestamps(pd.concat(series, axis=1), self._timezone)
series = [super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(column)
.rename(field.name)
for column, field in zip(arrow_column.flatten(), arrow_column.type)]
s = pd.concat(series, axis=1)
else:
s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column, data_type)
s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column)
return s

def dump_stream(self, iterator, stream):
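
A minimal sketch (outside Spark) of the struct handling in arrow_to_pandas above:
a struct-typed Arrow column is flattened into one child array per field and
re-assembled as a pandas DataFrame. The field names are made up; assumes a recent
pyarrow:

    import pandas as pd
    import pyarrow as pa

    arr = pa.array([{"x": 1, "y": "a"}, {"x": 2, "y": "b"}])  # struct<x: int64, y: string>

    if pa.types.is_struct(arr.type):
        # One child array per struct field, each renamed to its field name.
        series = [child.to_pandas().rename(field.name)
                  for child, field in zip(arr.flatten(), arr.type)]
        print(pd.concat(series, axis=1))
        #    x  y
        # 0  1  a
        # 1  2  b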
8 changes: 5 additions & 3 deletions python/pyspark/sql/dataframe.py
@@ -2138,13 +2138,15 @@ def toPandas(self):
# of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled.
if use_arrow:
try:
from pyspark.sql.types import _arrow_table_to_pandas, \
_check_dataframe_localize_timestamps
from pyspark.sql.types import _check_dataframe_localize_timestamps
import pyarrow
batches = self._collectAsArrow()
if len(batches) > 0:
table = pyarrow.Table.from_batches(batches)
pdf = _arrow_table_to_pandas(table, self.schema)
# Pandas DataFrame created from PyArrow uses datetime64[ns] for date type
# values, but we should use datetime.date to match the behavior with when
# Arrow optimization is disabled.
pdf = table.to_pandas(date_as_object=True)
return _check_dataframe_localize_timestamps(pdf, timezone)
else:
return pd.DataFrame.from_records([], columns=self.columns)
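
Illustrative only: the new toPandas path boils down to concatenating the collected
record batches into one Table and converting it with dates kept as objects (the
Spark-side timezone localization step is left out). Made-up data:

    import datetime
    import pandas as pd
    import pyarrow as pa

    batches = [
        pa.RecordBatch.from_pandas(pd.DataFrame({"d": [datetime.date(1969, 1, 1)]}),
                                   preserve_index=False),
        pa.RecordBatch.from_pandas(pd.DataFrame({"d": [datetime.date(2262, 4, 12)]}),
                                   preserve_index=False),
    ]

    table = pa.Table.from_batches(batches)
    pdf = table.to_pandas(date_as_object=True)
    print(pdf["d"].map(type).tolist())  # [<class 'datetime.date'>, <class 'datetime.date'>]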
7 changes: 1 addition & 6 deletions python/pyspark/sql/session.py
@@ -530,7 +530,6 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
data types will be used to coerce the data in Pandas to Arrow conversion.
"""
from distutils.version import LooseVersion
from pyspark.serializers import ArrowStreamPandasSerializer
from pyspark.sql.types import from_arrow_type, to_arrow_type, TimestampType
from pyspark.sql.utils import require_minimum_pandas_version, \
@@ -544,11 +543,7 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):

# Create the Spark schema from list of names passed in with Arrow types
if isinstance(schema, (list, tuple)):
if LooseVersion(pa.__version__) < LooseVersion("0.12.0"):
temp_batch = pa.RecordBatch.from_pandas(pdf[0:100], preserve_index=False)
arrow_schema = temp_batch.schema
else:
arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
struct = StructType()
for name, field in zip(schema, arrow_schema):
struct.add(name, from_arrow_type(field.type), nullable=field.nullable)
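
For reference, a small sketch (not from the commit) of the schema inference that
replaces the old 100-row RecordBatch.from_pandas workaround; assumes pyarrow >= 0.12:

    import pandas as pd
    import pyarrow as pa

    pdf = pd.DataFrame({"a": [1, 2], "b": ["x", "y"], "c": [1.5, 2.5]})

    # Infers the Arrow schema from the whole frame without building a RecordBatch.
    arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
    for field in arrow_schema:
        print(field.name, field.type)  # a int64 / b string / c double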
39 changes: 6 additions & 33 deletions python/pyspark/sql/tests/test_arrow.py
@@ -46,7 +46,6 @@ class ArrowTests(ReusedSQLTestCase):
def setUpClass(cls):
from datetime import date, datetime
from decimal import Decimal
from distutils.version import LooseVersion
super(ArrowTests, cls).setUpClass()
cls.warnings_lock = threading.Lock()

@@ -68,23 +67,16 @@ def setUpClass(cls):
StructField("5_double_t", DoubleType(), True),
StructField("6_decimal_t", DecimalType(38, 18), True),
StructField("7_date_t", DateType(), True),
StructField("8_timestamp_t", TimestampType(), True)])
StructField("8_timestamp_t", TimestampType(), True),
StructField("9_binary_t", BinaryType(), True)])
cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"),
date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1), bytearray(b"a")),
(u"b", 2, 20, 0.4, 4.0, Decimal("4.0"),
date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2), bytearray(b"bb")),
(u"c", 3, 30, 0.8, 6.0, Decimal("6.0"),
date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3)),
date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3), bytearray(b"ccc")),
(u"d", 4, 40, 1.0, 8.0, Decimal("8.0"),
date(2262, 4, 12), datetime(2262, 3, 3, 3, 3, 3))]

# TODO: remove version check once minimum pyarrow version is 0.10.0
if LooseVersion("0.10.0") <= LooseVersion(pa.__version__):
cls.schema.add(StructField("9_binary_t", BinaryType(), True))
cls.data[0] = cls.data[0] + (bytearray(b"a"),)
cls.data[1] = cls.data[1] + (bytearray(b"bb"),)
cls.data[2] = cls.data[2] + (bytearray(b"ccc"),)
cls.data[3] = cls.data[3] + (bytearray(b"dddd"),)
date(2262, 4, 12), datetime(2262, 3, 3, 3, 3, 3), bytearray(b"dddd"))]

@classmethod
def tearDownClass(cls):
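
The new 9_binary_t column leans on pyarrow's binary type support, which is why the
old 0.10.0 version gate existed. A tiny sketch of the mapping, illustration only:

    import pyarrow as pa

    arr = pa.array([bytearray(b"a"), b"bb", None], type=pa.binary())
    print(arr.type)         # binary
    print(arr.to_pylist())  # [b'a', b'bb', None]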
@@ -123,23 +115,13 @@ def test_toPandas_fallback_enabled(self):
assert_frame_equal(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))

def test_toPandas_fallback_disabled(self):
from distutils.version import LooseVersion

schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
with QuietTest(self.sc):
with self.warnings_lock:
with self.assertRaisesRegexp(Exception, 'Unsupported type'):
df.toPandas()

# TODO: remove BinaryType check once minimum pyarrow version is 0.10.0
if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
schema = StructType([StructField("binary", BinaryType(), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'Unsupported type.*BinaryType'):
df.toPandas()

def test_null_conversion(self):
df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
self.data)
@@ -348,20 +330,11 @@ def test_createDataFrame_fallback_enabled(self):
self.assertEqual(df.collect(), [Row(a={u'a': 1})])

def test_createDataFrame_fallback_disabled(self):
from distutils.version import LooseVersion

with QuietTest(self.sc):
with self.assertRaisesRegexp(TypeError, 'Unsupported type'):
self.spark.createDataFrame(
pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")

# TODO: remove BinaryType check once minimum pyarrow version is 0.10.0
if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
with QuietTest(self.sc):
with self.assertRaisesRegexp(TypeError, 'Unsupported type.*BinaryType'):
self.spark.createDataFrame(
pd.DataFrame([[{'a': b'aaa'}]]), "a: binary")

# Regression test for SPARK-23314
def test_timestamp_dst(self):
# Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am
48 changes: 17 additions & 31 deletions python/pyspark/sql/tests/test_pandas_udf.py
@@ -198,62 +198,48 @@ def foofoo(x, y):
)

def test_pandas_udf_detect_unsafe_type_conversion(self):
from distutils.version import LooseVersion
import pandas as pd
import numpy as np
import pyarrow as pa

values = [1.0] * 3
pdf = pd.DataFrame({'A': values})
df = self.spark.createDataFrame(pdf).repartition(1)

@pandas_udf(returnType="int")
def udf(column):
return pd.Series(np.linspace(0, 1, 3))
return pd.Series(np.linspace(0, 1, len(column)))

# Since 0.11.0, PyArrow supports the feature to raise an error for unsafe cast.
if LooseVersion(pa.__version__) >= LooseVersion("0.11.0"):
with self.sql_conf({
"spark.sql.execution.pandas.arrowSafeTypeConversion": True}):
with self.assertRaisesRegexp(Exception,
"Exception thrown when converting pandas.Series"):
df.select(['A']).withColumn('udf', udf('A')).collect()
with self.sql_conf({
"spark.sql.execution.pandas.arrowSafeTypeConversion": True}):
with self.assertRaisesRegexp(Exception,
"Exception thrown when converting pandas.Series"):
df.select(['A']).withColumn('udf', udf('A')).collect()

# Disabling Arrow safe type check.
with self.sql_conf({
"spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
df.select(['A']).withColumn('udf', udf('A')).collect()

def test_pandas_udf_arrow_overflow(self):
from distutils.version import LooseVersion
import pandas as pd
import pyarrow as pa

df = self.spark.range(0, 1)

@pandas_udf(returnType="byte")
def udf(column):
return pd.Series([128])

# Arrow 0.11.0+ allows enabling or disabling safe type check.
if LooseVersion(pa.__version__) >= LooseVersion("0.11.0"):
# When enabling safe type check, Arrow 0.11.0+ disallows overflow cast.
with self.sql_conf({
"spark.sql.execution.pandas.arrowSafeTypeConversion": True}):
with self.assertRaisesRegexp(Exception,
"Exception thrown when converting pandas.Series"):
df.withColumn('udf', udf('id')).collect()

# Disabling safe type check, let Arrow do the cast anyway.
with self.sql_conf({"spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
return pd.Series([128] * len(column))

# When enabling safe type check, Arrow 0.11.0+ disallows overflow cast.
with self.sql_conf({
"spark.sql.execution.pandas.arrowSafeTypeConversion": True}):
with self.assertRaisesRegexp(Exception,
"Exception thrown when converting pandas.Series"):
df.withColumn('udf', udf('id')).collect()
else:
# SQL config `arrowSafeTypeConversion` no matters for older Arrow.
# Overflow cast causes an error.
with self.sql_conf({"spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
with self.assertRaisesRegexp(Exception,
"Integer value out of bounds"):
df.withColumn('udf', udf('id')).collect()

# Disabling safe type check, let Arrow do the cast anyway.
with self.sql_conf({"spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
df.withColumn('udf', udf('id')).collect()
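
To make the overflow behaviour concrete, an illustrative pyarrow-only sketch (no
Spark session needed); assumes pyarrow >= 0.11:

    import pandas as pd
    import pyarrow as pa

    s = pd.Series([128])  # out of range for a signed byte

    # With the safe check on, Arrow rejects the overflowing cast to int8.
    try:
        pa.Array.from_pandas(s, type=pa.int8(), safe=True)
    except pa.ArrowInvalid as e:
        print("overflow rejected:", e)

    # With the check off, the value is cast anyway and wraps around.
    print(pa.Array.from_pandas(s, type=pa.int8(), safe=False))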


if __name__ == "__main__":
28 changes: 9 additions & 19 deletions python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
@@ -21,7 +21,6 @@

from collections import OrderedDict
from decimal import Decimal
from distutils.version import LooseVersion

from pyspark.sql import Row
from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, PandasUDFType
@@ -65,20 +64,17 @@ def test_supported_types(self):
1, 2, 3,
4, 5, 1.1,
2.2, Decimal(1.123),
[1, 2, 2], True, 'hello'
[1, 2, 2], True, 'hello',
bytearray([0x01, 0x02])
]
output_fields = [
('id', IntegerType()), ('byte', ByteType()), ('short', ShortType()),
('int', IntegerType()), ('long', LongType()), ('float', FloatType()),
('double', DoubleType()), ('decim', DecimalType(10, 3)),
('array', ArrayType(IntegerType())), ('bool', BooleanType()), ('str', StringType())
('array', ArrayType(IntegerType())), ('bool', BooleanType()), ('str', StringType()),
('bin', BinaryType())
]

# TODO: Add BinaryType to variables above once minimum pyarrow version is 0.10.0
if LooseVersion(pa.__version__) >= LooseVersion("0.10.0"):
values.append(bytearray([0x01, 0x02]))
output_fields.append(('bin', BinaryType()))

output_schema = StructType([StructField(*x) for x in output_fields])
df = self.spark.createDataFrame([values], schema=output_schema)

@@ -95,6 +91,7 @@ def test_supported_types(self):
bool=False if pdf.bool else True,
str=pdf.str + 'there',
array=pdf.array,
bin=pdf.bin
),
output_schema,
PandasUDFType.GROUPED_MAP
@@ -112,6 +109,7 @@ def test_supported_types(self):
bool=False if pdf.bool else True,
str=pdf.str + 'there',
array=pdf.array,
bin=pdf.bin
),
output_schema,
PandasUDFType.GROUPED_MAP
@@ -130,6 +128,7 @@ def test_supported_types(self):
bool=False if pdf.bool else True,
str=pdf.str + 'there',
array=pdf.array,
bin=pdf.bin
),
output_schema,
PandasUDFType.GROUPED_MAP
@@ -291,10 +290,6 @@ def test_unsupported_types(self):
StructField('struct', StructType([StructField('l', LongType())])),
]

# TODO: Remove this if-statement once minimum pyarrow version is 0.10.0
if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
unsupported_types.append(StructField('bin', BinaryType()))

for unsupported_type in unsupported_types:
schema = StructType([StructField('id', LongType(), True), unsupported_type])
with QuietTest(self.sc):
@@ -466,13 +461,8 @@ def invalid_positional_types(pdf):
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, "KeyError: 'id'"):
grouped_df.apply(column_name_typo).collect()
if LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
# TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0.
with self.assertRaisesRegexp(Exception, "No cast implemented"):
grouped_df.apply(invalid_positional_types).collect()
else:
with self.assertRaisesRegexp(Exception, "an integer is required"):
grouped_df.apply(invalid_positional_types).collect()
with self.assertRaisesRegexp(Exception, "an integer is required"):
grouped_df.apply(invalid_positional_types).collect()

def test_positional_assignment_conf(self):
with self.sql_conf({