From d87547c05c0ab874dfce8e6ddca4ee454926b664 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Fri, 9 Feb 2018 12:40:41 +0900
Subject: [PATCH 1/3] toPandas conversion cleanup

---
 python/pyspark/sql/dataframe.py | 74 ++++++++++++++++++---------------
 python/pyspark/sql/session.py   |  4 +-
 python/pyspark/sql/tests.py     | 22 +++++++---
 3 files changed, 60 insertions(+), 40 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index faee870a2d2e..1f6bfd308ca2 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1941,12 +1941,24 @@ def toPandas(self):
             timezone = None
 
         if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true":
+            should_fall_back = False
             try:
-                from pyspark.sql.types import _check_dataframe_convert_date, \
-                    _check_dataframe_localize_timestamps
+                from pyspark.sql.types import to_arrow_schema
                 from pyspark.sql.utils import require_minimum_pyarrow_version
-                import pyarrow
                 require_minimum_pyarrow_version()
+                # Check if its schema is convertible to Arrow format.
+                to_arrow_schema(self.schema)
+            except Exception as e:
+                # Fall back to conversion without Arrow if an exception is raised.
+                should_fall_back = True
+                warnings.warn(
+                    "Arrow will not be used in toPandas: %s" % _exception_message(e))
+
+            if not should_fall_back:
+                import pyarrow
+                from pyspark.sql.types import _check_dataframe_convert_date, \
+                    _check_dataframe_localize_timestamps
+
                 tables = self._collectAsArrow()
                 if tables:
                     table = pyarrow.concat_tables(tables)
@@ -1955,38 +1967,34 @@ def toPandas(self):
                     return _check_dataframe_localize_timestamps(pdf, timezone)
                 else:
                     return pd.DataFrame.from_records([], columns=self.columns)
-            except ImportError as e:
-                msg = "note: pyarrow must be installed and available on calling Python process " \
-                      "if using spark.sql.execution.arrow.enabled=true"
-                raise ImportError("%s\n%s" % (_exception_message(e), msg))
-        else:
-            pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
 
-            dtype = {}
+        pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
+
+        dtype = {}
+        for field in self.schema:
+            pandas_type = _to_corrected_pandas_type(field.dataType)
+            # SPARK-21766: if an integer field is nullable and has null values, it can be
+            # inferred by pandas as float column. Once we convert the column with NaN back
+            # to integer type e.g., np.int16, we will hit exception. So we use the inferred
+            # float type, not the corrected type from the schema in this case.
+            if pandas_type is not None and \
+                not(isinstance(field.dataType, IntegralType) and field.nullable and
+                    pdf[field.name].isnull().any()):
+                dtype[field.name] = pandas_type
+
+        for f, t in dtype.items():
+            pdf[f] = pdf[f].astype(t, copy=False)
+
+        if timezone is None:
+            return pdf
+        else:
+            from pyspark.sql.types import _check_series_convert_timestamps_local_tz
             for field in self.schema:
-                pandas_type = _to_corrected_pandas_type(field.dataType)
-                # SPARK-21766: if an integer field is nullable and has null values, it can be
-                # inferred by pandas as float column. Once we convert the column with NaN back
-                # to integer type e.g., np.int16, we will hit exception. So we use the inferred
-                # float type, not the corrected type from the schema in this case.
-                if pandas_type is not None and \
-                    not(isinstance(field.dataType, IntegralType) and field.nullable and
-                        pdf[field.name].isnull().any()):
-                    dtype[field.name] = pandas_type
-
-            for f, t in dtype.items():
-                pdf[f] = pdf[f].astype(t, copy=False)
-
-            if timezone is None:
-                return pdf
-            else:
-                from pyspark.sql.types import _check_series_convert_timestamps_local_tz
-                for field in self.schema:
-                    # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
-                    if isinstance(field.dataType, TimestampType):
-                        pdf[field.name] = \
-                            _check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
-                return pdf
+                # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
+                if isinstance(field.dataType, TimestampType):
+                    pdf[field.name] = \
+                        _check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
+            return pdf
 
     def _collectAsArrow(self):
         """
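The dataframe.py change above boils down to a probe-then-commit pattern: `toPandas` now validates everything the Arrow path needs (the pyarrow version via `require_minimum_pyarrow_version()`, schema convertibility via `to_arrow_schema(self.schema)`) before collecting any data, and falls through to the plain `collect()` path with a warning when any check fails. Below is a minimal, self-contained sketch of that pattern in isolation; `convert_rows` and `UNSUPPORTED_TYPES` are hypothetical stand-ins for illustration, not Spark or Arrow APIs.

```python
import warnings

# Hypothetical stand-in for value types the fast path cannot handle
# (analogous to the MapType that to_arrow_schema rejects in the real code).
UNSUPPORTED_TYPES = (dict, set)

def convert_rows(rows, columns):
    should_fall_back = False
    try:
        # Probe all preconditions up front, before any expensive work,
        # mirroring require_minimum_pyarrow_version() + to_arrow_schema().
        for row in rows:
            if any(isinstance(v, UNSUPPORTED_TYPES) for v in row):
                raise TypeError("unsupported value in row: %r" % (row,))
    except Exception as e:
        # Fall back to the slow path instead of failing the whole call.
        should_fall_back = True
        warnings.warn("Fast path will not be used: %s" % str(e))

    if not should_fall_back:
        # The fast path would go here (Arrow collection in the real code).
        return ("fast", [dict(zip(columns, row)) for row in rows])
    # Slow path: always available, used whenever a probe failed.
    return ("slow", [dict(zip(columns, row)) for row in rows])

print(convert_rows([(1, {"a": 1})], ["id", "map"])[0])  # -> "slow", after a warning
```

The important property, which the patch also gains, is that the probe runs before `_collectAsArrow()`, so an unsupported schema no longer surfaces as a hard error after the data has already been collected.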
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index b3af9b82953f..c608129c283b 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -37,6 +37,7 @@
     _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter, \
     _parse_datatype_string
 from pyspark.sql.utils import install_exception_handler
+from pyspark.util import _exception_message
 
 __all__ = ["SparkSession"]
 
@@ -666,8 +667,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
             try:
                 return self._create_from_pandas_with_arrow(data, schema, timezone)
             except Exception as e:
-                warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e))
                 # Fallback to create DataFrame without arrow if raise some exception
+                warnings.warn(
+                    "Arrow will not be used in createDataFrame: %s" % _exception_message(e))
                 data = self._convert_from_pandas(data, schema, timezone)
 
         if isinstance(schema, StructType):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 6ace16955000..33154f405000 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -32,6 +32,7 @@
 import datetime
 import array
 import ctypes
+import warnings
 
 import py4j
 try:
@@ -48,12 +49,12 @@
 else:
     import unittest
 
+from pyspark.util import _exception_message
 _pandas_requirement_message = None
 try:
     from pyspark.sql.utils import require_minimum_pandas_version
     require_minimum_pandas_version()
 except ImportError as e:
-    from pyspark.util import _exception_message
     # If Pandas version requirement is not satisfied, skip related tests.
     _pandas_requirement_message = _exception_message(e)
 
@@ -62,7 +63,6 @@
     from pyspark.sql.utils import require_minimum_pyarrow_version
     require_minimum_pyarrow_version()
 except ImportError as e:
-    from pyspark.util import _exception_message
     # If Arrow version requirement is not satisfied, skip related tests.
     _pyarrow_requirement_message = _exception_message(e)
 
@@ -3437,12 +3437,22 @@ def create_pandas_data_frame(self):
         data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
         return pd.DataFrame(data=data_dict)
 
-    def test_unsupported_datatype(self):
+    def test_toPandas_fallback(self):
+        import pandas as pd
+
         schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
-        df = self.spark.createDataFrame([(None,)], schema=schema)
+        df = self.spark.createDataFrame([({u'a': 1},)], schema=schema)
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
-                df.toPandas()
+            with warnings.catch_warnings(record=True) as warns:
+                pdf = df.toPandas()
+                # Catch and check the last UserWarning.
+                user_warns = [
+                    warn.message for warn in warns if isinstance(warn.message, UserWarning)]
+                self.assertTrue(len(user_warns) > 0)
+                self.assertTrue(
+                    "Arrow will not be used in toPandas" in _exception_message(user_warns[-1]))
+
+        self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))
 
     def test_null_conversion(self):
         df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
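The rewritten test leans on `warnings.catch_warnings(record=True)`, which temporarily replaces the active warning filters and collects every emitted warning as a `WarningMessage` object, so the test can assert on the message text instead of letting it print. A standalone sketch of the idiom follows; the `simplefilter("always")` call is an addition here to defeat the default "show once" filter (the patch's test relies on its surrounding harness instead), and `fallback_operation` is a hypothetical stand-in.

```python
import warnings

def fallback_operation():
    # warnings.warn defaults to UserWarning, the category the test filters on.
    warnings.warn("Arrow will not be used in toPandas: example reason")
    return "fallback result"

with warnings.catch_warnings(record=True) as warns:
    warnings.simplefilter("always")  # record every warning, even repeats
    result = fallback_operation()

# Each recorded entry wraps the warning instance in .message.
user_warns = [w.message for w in warns if isinstance(w.message, UserWarning)]
assert len(user_warns) > 0
assert "Arrow will not be used in toPandas" in str(user_warns[-1])
print(result)  # the call still succeeded via the fallback
```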
From f46540e3d306b3d41ef335b71c8240f1cd2bd3f1 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Mon, 12 Feb 2018 16:39:24 +0900
Subject: [PATCH 2/3] Fix a nit while I am here :-)

---
 python/pyspark/sql/tests.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 33154f405000..ef3dd5731f2c 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -50,6 +50,7 @@
     import unittest
 
 from pyspark.util import _exception_message
+
 _pandas_requirement_message = None
 try:
     from pyspark.sql.utils import require_minimum_pandas_version

From 42dec467df4c332cedf474623243db7c929881d7 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Tue, 13 Feb 2018 07:59:31 +0900
Subject: [PATCH 3/3] Fix a nit

---
 python/pyspark/sql/dataframe.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 1f6bfd308ca2..7a547a8c3911 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1941,7 +1941,7 @@ def toPandas(self):
         timezone = None
 
         if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true":
-            should_fall_back = False
+            should_fallback = False
             try:
                 from pyspark.sql.types import to_arrow_schema
                 from pyspark.sql.utils import require_minimum_pyarrow_version
@@ -1950,11 +1950,11 @@ def toPandas(self):
                 to_arrow_schema(self.schema)
             except Exception as e:
                 # Fall back to conversion without Arrow if an exception is raised.
-                should_fall_back = True
+                should_fallback = True
                 warnings.warn(
                     "Arrow will not be used in toPandas: %s" % _exception_message(e))
 
-            if not should_fall_back:
+            if not should_fallback:
                 import pyarrow
                 from pyspark.sql.types import _check_dataframe_convert_date, \
                     _check_dataframe_localize_timestamps
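Finally, the SPARK-21766 comment that patch 1 moves explains why the dtype correction skips nullable integer columns containing nulls: pandas stores missing numeric values as NaN, which forces the column to a float dtype, and casting NaN back to a NumPy integer type raises. A quick demonstration of that behavior, assuming NumPy and a reasonably recent pandas are installed:

```python
import numpy as np
import pandas as pd

# A nullable "integer" column: the None becomes NaN, so pandas infers float64,
# exactly as toPandas sees after pd.DataFrame.from_records(self.collect(), ...).
pdf = pd.DataFrame.from_records([(1,), (None,)], columns=["x"])
print(pdf["x"].dtype)  # float64

# Forcing it back to the schema's integer type fails, which is the case the
# `isnull().any()` guard in toPandas avoids by keeping the inferred float type.
try:
    pdf["x"].astype(np.int16, copy=False)
except (ValueError, TypeError) as e:
    print("cast failed as expected: %s" % e)
```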