Bump pyarrow to 0.12.1 #649

Merged · 3 commits · May 18, 2020
7 changes: 3 additions & 4 deletions docs/sql-pyspark-pandas-with-arrow.md
@@ -7,7 +7,7 @@ displayTitle: PySpark Usage Guide for Pandas with Apache Arrow
* Table of contents
{:toc}

## Apache Arrow in Spark
## Apache Arrow in PySpark

Apache Arrow is an in-memory columnar data format that is used in Spark to efficiently transfer
data between JVM and Python processes. This currently is most beneficial to Python users that
@@ -20,7 +20,7 @@ working with Arrow-enabled data.

If you install PySpark using pip, then PyArrow can be brought in as an extra dependency of the
SQL module with the command `pip install pyspark[sql]`. Otherwise, you must ensure that PyArrow
is installed and available on all cluster nodes. The current supported version is 0.8.0.
is installed and available on all cluster nodes. The current supported version is 0.12.1.
You can install using pip or conda from the conda-forge channel. See PyArrow
[installation](https://arrow.apache.org/docs/python/install.html) for details.
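
A minimal sketch of how a job might confirm at runtime that the PyArrow visible to its Python workers meets the 0.12.1 minimum stated above (illustrative only; the constant name is invented for the example):

```python
# Illustrative check: fail fast if the installed PyArrow is older than the documented minimum.
from distutils.version import LooseVersion

import pyarrow as pa

MINIMUM_PYARROW_VERSION = "0.12.1"  # minimum documented in this guide

if LooseVersion(pa.__version__) < LooseVersion(MINIMUM_PYARROW_VERSION):
    raise ImportError("PyArrow >= %s is required for Arrow-based transfers, found %s"
                      % (MINIMUM_PYARROW_VERSION, pa.__version__))
```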

@@ -128,8 +128,7 @@ For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/p
### Supported SQL Types

Currently, all Spark SQL data types are supported by Arrow-based conversion except `MapType`,
`ArrayType` of `TimestampType`, and nested `StructType`. `BinaryType` is supported only when
installed PyArrow is equal to or higher than 0.10.0.
`ArrayType` of `TimestampType`, and nested `StructType`.
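
With the minimum PyArrow now at 0.12.1, `BinaryType` no longer needs a version guard, so an Arrow-backed `toPandas()` over a binary column works unconditionally. A rough, illustrative sketch (the local session setup is only for the example):

```python
# Sketch: a BinaryType column converts through Arrow without any PyArrow version check.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

df = spark.createDataFrame([(bytearray(b"a"),), (bytearray(b"bb"),)], "value: binary")
pdf = df.toPandas()  # Arrow-based conversion; values arrive as Python bytes
print(pdf["value"].tolist())
```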

### Setting Arrow Batch Size

53 changes: 17 additions & 36 deletions python/pyspark/serializers.py
@@ -265,10 +265,14 @@ def __init__(self, timezone, safecheck, assign_cols_by_name):
self._safecheck = safecheck
self._assign_cols_by_name = assign_cols_by_name

def arrow_to_pandas(self, arrow_column, data_type):
from pyspark.sql.types import _arrow_column_to_pandas, _check_series_localize_timestamps
def arrow_to_pandas(self, arrow_column):
from pyspark.sql.types import _check_series_localize_timestamps

# If the given column is a date type column, creates a series of datetime.date directly
# instead of creating datetime64[ns] as intermediate data to avoid overflow caused by
# datetime64[ns] type handling.
s = arrow_column.to_pandas(date_as_object=True)

s = _arrow_column_to_pandas(arrow_column, data_type)
s = _check_series_localize_timestamps(s, self._timezone)
return s
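
A standalone sketch (not part of the patch) of why the comment above asks for `datetime.date` objects rather than `datetime64[ns]`: dates past 2262-04-11 cannot be represented at nanosecond resolution, while plain Python dates can. Assuming a recent PyArrow:

```python
# Sketch: materializing dates as datetime.date avoids the datetime64[ns] range limit.
import datetime

import pyarrow as pa

# 2262-04-12 and later cannot be stored as pandas datetime64[ns] values.
arr = pa.array([datetime.date(2262, 4, 12), datetime.date(9999, 1, 1)])
s = arr.to_pandas(date_as_object=True)
print(type(s.iloc[0]), s.iloc[0])  # <class 'datetime.date'> 2262-04-12
```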

@@ -280,8 +284,6 @@ def _create_batch(self, series):
:param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
:return: Arrow RecordBatch
"""
import decimal
from distutils.version import LooseVersion
import pandas as pd
import pyarrow as pa
from pyspark.sql.types import _check_series_convert_timestamps_internal
@@ -294,24 +296,10 @@ def create_array(s, t):
def create_array(s, t):
mask = s.isnull()
# Ensure timestamp series are in expected form for Spark internal representation
# TODO: maybe don't need None check anymore as of Arrow 0.9.1
if t is not None and pa.types.is_timestamp(t):
s = _check_series_convert_timestamps_internal(s.fillna(0), self._timezone)
# TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2
return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False)
elif t is not None and pa.types.is_string(t) and sys.version < '3':
# TODO: need decode before converting to Arrow in Python 2
# TODO: don't need as of Arrow 0.9.1
return pa.Array.from_pandas(s.apply(
lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t)
elif t is not None and pa.types.is_decimal(t) and \
LooseVersion("0.9.0") <= LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
# TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0.
return pa.Array.from_pandas(s.apply(
lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t)
elif LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
# TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0.
return pa.Array.from_pandas(s, mask=mask, type=t)

try:
array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
@@ -345,12 +333,7 @@ def create_array(s, t):
for i, field in enumerate(t)]

struct_arrs, struct_names = zip(*arrs_names)

# TODO: from_arrays args switched for v0.9.0, remove when bump min pyarrow version
if LooseVersion(pa.__version__) < LooseVersion("0.9.0"):
arrs.append(pa.StructArray.from_arrays(struct_names, struct_arrs))
else:
arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
else:
arrs.append(create_array(s, t))

@@ -370,10 +353,8 @@ def load_stream(self, stream):
"""
batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
import pyarrow as pa
from pyspark.sql.types import from_arrow_type
for batch in batches:
yield [self.arrow_to_pandas(c, from_arrow_type(c.type))
for c in pa.Table.from_batches([batch]).itercolumns()]
yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]

def __repr__(self):
return "ArrowStreamPandasSerializer"
@@ -389,17 +370,17 @@ def __init__(self, timezone, safecheck, assign_cols_by_name, df_for_struct=False
.__init__(timezone, safecheck, assign_cols_by_name)
self._df_for_struct = df_for_struct

def arrow_to_pandas(self, arrow_column, data_type):
from pyspark.sql.types import StructType, \
_arrow_column_to_pandas, _check_dataframe_localize_timestamps
def arrow_to_pandas(self, arrow_column):
import pyarrow.types as types

if self._df_for_struct and type(data_type) == StructType:
if self._df_for_struct and types.is_struct(arrow_column.type):
import pandas as pd
series = [_arrow_column_to_pandas(column, field.dataType).rename(field.name)
for column, field in zip(arrow_column.flatten(), data_type)]
s = _check_dataframe_localize_timestamps(pd.concat(series, axis=1), self._timezone)
series = [super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(column)
.rename(field.name)
for column, field in zip(arrow_column.flatten(), arrow_column.type)]
s = pd.concat(series, axis=1)
else:
s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column, data_type)
s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column)
return s
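
For reference, a standalone sketch (not part of the patch) of the flatten-and-concat idea used in the struct branch above, written against plain PyArrow and pandas:

```python
# Sketch: flatten an Arrow struct column into per-field pandas Series, then concat them.
import pandas as pd
import pyarrow as pa

struct = pa.StructArray.from_arrays(
    [pa.array([1, 2]), pa.array(["a", "b"])], names=["x", "y"])

series = [child.to_pandas().rename(field.name)
          for child, field in zip(struct.flatten(), struct.type)]
pdf = pd.concat(series, axis=1)
print(pdf)
```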

def dump_stream(self, iterator, stream):
8 changes: 5 additions & 3 deletions python/pyspark/sql/dataframe.py
@@ -2107,13 +2107,15 @@ def toPandas(self):
# of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled.
if use_arrow:
try:
from pyspark.sql.types import _arrow_table_to_pandas, \
_check_dataframe_localize_timestamps
from pyspark.sql.types import _check_dataframe_localize_timestamps
import pyarrow
batches = self._collectAsArrow()
if len(batches) > 0:
table = pyarrow.Table.from_batches(batches)
pdf = _arrow_table_to_pandas(table, self.schema)
# Pandas DataFrame created from PyArrow uses datetime64[ns] for date type
# values, but we should use datetime.date to match the behavior with when
# Arrow optimization is disabled.
pdf = table.to_pandas(date_as_object=True)
return _check_dataframe_localize_timestamps(pdf, timezone)
else:
return pd.DataFrame.from_records([], columns=self.columns)
6 changes: 1 addition & 5 deletions python/pyspark/sql/session.py
@@ -545,11 +545,7 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):

# Create the Spark schema from list of names passed in with Arrow types
if isinstance(schema, (list, tuple)):
if LooseVersion(pa.__version__) < LooseVersion("0.12.0"):
temp_batch = pa.RecordBatch.from_pandas(pdf[0:100], preserve_index=False)
arrow_schema = temp_batch.schema
else:
arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
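
`pa.Schema.from_pandas` inspects the whole DataFrame, which is why the old workaround of building a temporary RecordBatch from the first 100 rows can go away. A small illustrative sketch:

```python
# Sketch: infer an Arrow schema directly from a pandas DataFrame (PyArrow 0.12+).
import pandas as pd
import pyarrow as pa

pdf = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})
arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
print(arrow_schema)  # a: int64, b: string
```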

# TODO(rshkv): Remove when we stop supporting Python 2 (#678)
if sys.version < '3' and LooseVersion(pa.__version__) >= LooseVersion("0.10.0"):
39 changes: 6 additions & 33 deletions python/pyspark/sql/tests/test_arrow.py
@@ -47,7 +47,6 @@ class ArrowTests(ReusedSQLTestCase):
def setUpClass(cls):
from datetime import date, datetime
from decimal import Decimal
from distutils.version import LooseVersion
super(ArrowTests, cls).setUpClass()
cls.warnings_lock = threading.Lock()

@@ -69,23 +68,16 @@ def setUpClass(cls):
StructField("5_double_t", DoubleType(), True),
StructField("6_decimal_t", DecimalType(38, 18), True),
StructField("7_date_t", DateType(), True),
StructField("8_timestamp_t", TimestampType(), True)])
StructField("8_timestamp_t", TimestampType(), True),
StructField("9_binary_t", BinaryType(), True)])
cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"),
date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1), bytearray(b"a")),
(u"b", 2, 20, 0.4, 4.0, Decimal("4.0"),
date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2), bytearray(b"bb")),
(u"c", 3, 30, 0.8, 6.0, Decimal("6.0"),
date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3)),
date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3), bytearray(b"ccc")),
(u"d", 4, 40, 1.0, 8.0, Decimal("8.0"),
date(2262, 4, 12), datetime(2262, 3, 3, 3, 3, 3))]

# TODO: remove version check once minimum pyarrow version is 0.10.0
if LooseVersion("0.10.0") <= LooseVersion(pa.__version__):
cls.schema.add(StructField("9_binary_t", BinaryType(), True))
cls.data[0] = cls.data[0] + (bytearray(b"a"),)
cls.data[1] = cls.data[1] + (bytearray(b"bb"),)
cls.data[2] = cls.data[2] + (bytearray(b"ccc"),)
cls.data[3] = cls.data[3] + (bytearray(b"dddd"),)
date(2262, 4, 12), datetime(2262, 3, 3, 3, 3, 3), bytearray(b"dddd"))]

@classmethod
def tearDownClass(cls):
@@ -124,23 +116,13 @@ def test_toPandas_fallback_enabled(self):
assert_frame_equal(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))

def test_toPandas_fallback_disabled(self):
from distutils.version import LooseVersion

schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
with QuietTest(self.sc):
with self.warnings_lock:
with self.assertRaisesRegexp(Exception, 'Unsupported type'):
df.toPandas()

# TODO: remove BinaryType check once minimum pyarrow version is 0.10.0
if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
schema = StructType([StructField("binary", BinaryType(), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'Unsupported type.*BinaryType'):
df.toPandas()

def test_null_conversion(self):
df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
self.data)
@@ -391,20 +373,11 @@ def test_createDataFrame_fallback_enabled(self):
self.assertEqual(df.collect(), [Row(a={u'a': 1})])

def test_createDataFrame_fallback_disabled(self):
from distutils.version import LooseVersion

with QuietTest(self.sc):
with self.assertRaisesRegexp(TypeError, 'Unsupported type'):
self.spark.createDataFrame(
pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")

# TODO: remove BinaryType check once minimum pyarrow version is 0.10.0
if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
with QuietTest(self.sc):
with self.assertRaisesRegexp(TypeError, 'Unsupported type.*BinaryType'):
self.spark.createDataFrame(
pd.DataFrame([[{'a': b'aaa'}]]), "a: binary")

# Regression test for SPARK-23314
def test_timestamp_dst(self):
# Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am
48 changes: 17 additions & 31 deletions python/pyspark/sql/tests/test_pandas_udf.py
@@ -198,62 +198,48 @@ def foofoo(x, y):
)

def test_pandas_udf_detect_unsafe_type_conversion(self):
from distutils.version import LooseVersion
import pandas as pd
import numpy as np
import pyarrow as pa

values = [1.0] * 3
pdf = pd.DataFrame({'A': values})
df = self.spark.createDataFrame(pdf).repartition(1)

@pandas_udf(returnType="int")
def udf(column):
return pd.Series(np.linspace(0, 1, 3))
return pd.Series(np.linspace(0, 1, len(column)))

# Since 0.11.0, PyArrow supports the feature to raise an error for unsafe cast.
if LooseVersion(pa.__version__) >= LooseVersion("0.11.0"):
with self.sql_conf({
"spark.sql.execution.pandas.arrowSafeTypeConversion": True}):
with self.assertRaisesRegexp(Exception,
"Exception thrown when converting pandas.Series"):
df.select(['A']).withColumn('udf', udf('A')).collect()
with self.sql_conf({
"spark.sql.execution.pandas.arrowSafeTypeConversion": True}):
with self.assertRaisesRegexp(Exception,
"Exception thrown when converting pandas.Series"):
df.select(['A']).withColumn('udf', udf('A')).collect()

# Disabling Arrow safe type check.
with self.sql_conf({
"spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
df.select(['A']).withColumn('udf', udf('A')).collect()

def test_pandas_udf_arrow_overflow(self):
from distutils.version import LooseVersion
import pandas as pd
import pyarrow as pa

df = self.spark.range(0, 1)

@pandas_udf(returnType="byte")
def udf(column):
return pd.Series([128])

# Arrow 0.11.0+ allows enabling or disabling safe type check.
if LooseVersion(pa.__version__) >= LooseVersion("0.11.0"):
# When enabling safe type check, Arrow 0.11.0+ disallows overflow cast.
with self.sql_conf({
"spark.sql.execution.pandas.arrowSafeTypeConversion": True}):
with self.assertRaisesRegexp(Exception,
"Exception thrown when converting pandas.Series"):
df.withColumn('udf', udf('id')).collect()

# Disabling safe type check, let Arrow do the cast anyway.
with self.sql_conf({"spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
return pd.Series([128] * len(column))

# When enabling safe type check, Arrow 0.11.0+ disallows overflow cast.
with self.sql_conf({
"spark.sql.execution.pandas.arrowSafeTypeConversion": True}):
with self.assertRaisesRegexp(Exception,
"Exception thrown when converting pandas.Series"):
df.withColumn('udf', udf('id')).collect()
else:
# SQL config `arrowSafeTypeConversion` no matters for older Arrow.
# Overflow cast causes an error.
with self.sql_conf({"spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
with self.assertRaisesRegexp(Exception,
"Integer value out of bounds"):
df.withColumn('udf', udf('id')).collect()

# Disabling safe type check, let Arrow do the cast anyway.
with self.sql_conf({"spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
df.withColumn('udf', udf('id')).collect()
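
The behavior these tests exercise comes from PyArrow's `safe` flag: an out-of-range value is rejected when `safe=True` and cast anyway when `safe=False`. A standalone sketch (not part of the patch), assuming a recent PyArrow:

```python
# Sketch: the safe flag decides whether an out-of-range cast raises or truncates.
import pandas as pd
import pyarrow as pa

s = pd.Series([128])  # does not fit in a signed byte

try:
    pa.Array.from_pandas(s, type=pa.int8(), safe=True)
except pa.ArrowInvalid as e:
    print("rejected:", e)

print(pa.Array.from_pandas(s, type=pa.int8(), safe=False))  # 128 wraps to -128
```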


if __name__ == "__main__":