From cca6757b36fbe8a73a81570625f5efa6e24bd8c6 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 2 Nov 2017 16:03:00 -0700 Subject: [PATCH 1/4] added fix for pandas timestamp to convert to microseconds for createDataFrame --- python/pyspark/sql/session.py | 34 +++++++++++++++++++++++++++++++++- python/pyspark/sql/tests.py | 10 ++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index c3dc1a46fd3c1..a7d5265d41c16 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -512,9 +512,41 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr except Exception: has_pandas = False if has_pandas and isinstance(data, pandas.DataFrame): + import numpy as np + + # Convert pandas.DataFrame to list of numpy records + np_records = data.to_records(index=False) + + # Check if any columns need to be fixed for Spark to infer properly + record_type_list = None + if schema is None and len(np_records) > 0: + cur_dtypes = np_records[0].dtype + col_names = cur_dtypes.names + record_type_list = [] + has_rec_fix = False + for i in xrange(len(cur_dtypes)): + curr_type = cur_dtypes[i] + # If type is a datetime64 timestamp, convert to microseconds + # NOTE: if dtype is M8[ns] then np.record.tolist() will output values as longs, + # this conversion will lead to an output of py datetime objects, see SPARK-22417 + if curr_type == np.dtype('M8[ns]'): + curr_type = 'M8[us]' + has_rec_fix = True + record_type_list.append((str(col_names[i]), curr_type)) + if not has_rec_fix: + record_type_list = None + + # If no schema supplied by user then get the names of columns only if schema is None: schema = [str(x) for x in data.columns] - data = [r.tolist() for r in data.to_records(index=False)] + + # Convert list of numpy records to python lists + if record_type_list is not None: + def fix_rec(rec): + return rec.astype(record_type_list) + data = [fix_rec(r).tolist() for r in np_records] + else: + data = [r.tolist() for r in np_records] if isinstance(schema, StructType): verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 483f39aeef66a..a94506b23ece6 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2592,6 +2592,16 @@ def test_create_dataframe_from_array_of_long(self): df = self.spark.createDataFrame(data) self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807])) + @unittest.skipIf(not _have_pandas, "Pandas not installed") + def test_create_dataframe_from_pandas_with_timestamp(self): + import pandas as pd + from datetime import datetime + pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], + "d": [pd.Timestamp.now().date()]}) + df = self.spark.createDataFrame(pdf) + self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType)) + self.assertTrue(isinstance(df.schema['d'].dataType, DateType)) + class HiveSparkSubmitTests(SparkSubmitTests): From 839bb5076ceeb297299e8e4c408cf437979b2bb9 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 2 Nov 2017 16:13:20 -0700 Subject: [PATCH 2/4] did not need a func to apply fix --- python/pyspark/sql/session.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index a7d5265d41c16..8a81966c25665 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -542,9 +542,7 @@ def 
createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr # Convert list of numpy records to python lists if record_type_list is not None: - def fix_rec(rec): - return rec.astype(record_type_list) - data = [fix_rec(r).tolist() for r in np_records] + data = [r.astype(record_type_list).tolist() for r in np_records] else: data = [r.tolist() for r in np_records] From 3944b5ca3f2ac588ee4a997ca55c10ed1456c7cf Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 2 Nov 2017 21:10:14 -0700 Subject: [PATCH 3/4] moved conversion to internal methods --- python/pyspark/sql/session.py | 79 ++++++++++++++++++++--------------- 1 file changed, 46 insertions(+), 33 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 8a81966c25665..518fb8869b1cb 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -23,6 +23,7 @@ if sys.version >= '3': basestring = unicode = str + xrange = range else: from itertools import imap as map @@ -416,6 +417,50 @@ def _createFromLocal(self, data, schema): data = [schema.toInternal(row) for row in data] return self._sc.parallelize(data), schema + def _getNumpyRecordDtypes(self, rec): + """ + Used when converting a pandas.DataFrame to Spark using to_records(), this will correct + the dtypes of records so they can be properly loaded into Spark. + :param rec: a numpy record to check dtypes + :return corrected dtypes for a numpy.record or None if no correction needed + """ + import numpy as np + cur_dtypes = rec.dtype + col_names = cur_dtypes.names + record_type_list = [] + has_rec_fix = False + for i in xrange(len(cur_dtypes)): + curr_type = cur_dtypes[i] + # If type is a datetime64 timestamp, convert to microseconds + # NOTE: if dtype is M8[ns] then np.record.tolist() will output values as longs, + # this conversion will lead to an output of py datetime objects, see SPARK-22417 + if curr_type == np.dtype('M8[ns]'): + curr_type = 'M8[us]' + has_rec_fix = True + record_type_list.append((str(col_names[i]), curr_type)) + return record_type_list if has_rec_fix else None + + def _convertFromPandas(self, pdf, schema): + """ + Convert a pandas.DataFrame to list of records that can be used to make a DataFrame + :return tuple of list of records and schema + """ + # Convert pandas.DataFrame to list of numpy records + np_records = pdf.to_records(index=False) + + # If no schema supplied by user then get the names of columns only + if schema is None: + schema = [str(x) for x in pdf.columns] + + # Check if any columns need to be fixed for Spark to infer properly + if len(np_records) > 0: + record_type_list = self._getNumpyRecordDtypes(np_records[0]) + if record_type_list is not None: + return [r.astype(record_type_list).tolist() for r in np_records], schema + + # Convert list of numpy records to python lists + return [r.tolist() for r in np_records], schema + @since(2.0) @ignore_unicode_prefix def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True): @@ -512,39 +557,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr except Exception: has_pandas = False if has_pandas and isinstance(data, pandas.DataFrame): - import numpy as np - - # Convert pandas.DataFrame to list of numpy records - np_records = data.to_records(index=False) - - # Check if any columns need to be fixed for Spark to infer properly - record_type_list = None - if schema is None and len(np_records) > 0: - cur_dtypes = np_records[0].dtype - col_names = cur_dtypes.names - record_type_list = 
[] - has_rec_fix = False - for i in xrange(len(cur_dtypes)): - curr_type = cur_dtypes[i] - # If type is a datetime64 timestamp, convert to microseconds - # NOTE: if dtype is M8[ns] then np.record.tolist() will output values as longs, - # this conversion will lead to an output of py datetime objects, see SPARK-22417 - if curr_type == np.dtype('M8[ns]'): - curr_type = 'M8[us]' - has_rec_fix = True - record_type_list.append((str(col_names[i]), curr_type)) - if not has_rec_fix: - record_type_list = None - - # If no schema supplied by user then get the names of columns only - if schema is None: - schema = [str(x) for x in data.columns] - - # Convert list of numpy records to python lists - if record_type_list is not None: - data = [r.astype(record_type_list).tolist() for r in np_records] - else: - data = [r.tolist() for r in np_records] + data, schema = self._convertFromPandas(data, schema) if isinstance(schema, StructType): verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True From 355edfce1d36acf97a190799896f9d48349e2726 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 6 Nov 2017 13:00:38 -0800 Subject: [PATCH 4/4] use datetime64 instead of M8, minor fixes --- python/pyspark/sql/session.py | 30 +++++++++++++++--------------- python/pyspark/sql/tests.py | 5 +++++ 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 518fb8869b1cb..d1d0b8b8fe5d9 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -417,7 +417,7 @@ def _createFromLocal(self, data, schema): data = [schema.toInternal(row) for row in data] return self._sc.parallelize(data), schema - def _getNumpyRecordDtypes(self, rec): + def _get_numpy_record_dtypes(self, rec): """ Used when converting a pandas.DataFrame to Spark using to_records(), this will correct the dtypes of records so they can be properly loaded into Spark. 
@@ -432,31 +432,31 @@ def _getNumpyRecordDtypes(self, rec): for i in xrange(len(cur_dtypes)): curr_type = cur_dtypes[i] # If type is a datetime64 timestamp, convert to microseconds - # NOTE: if dtype is M8[ns] then np.record.tolist() will output values as longs, - # this conversion will lead to an output of py datetime objects, see SPARK-22417 - if curr_type == np.dtype('M8[ns]'): - curr_type = 'M8[us]' + # NOTE: if dtype is datetime[ns] then np.record.tolist() will output values as longs, + # conversion from [us] or lower will lead to py datetime objects, see SPARK-22417 + if curr_type == np.dtype('datetime64[ns]'): + curr_type = 'datetime64[us]' has_rec_fix = True record_type_list.append((str(col_names[i]), curr_type)) return record_type_list if has_rec_fix else None - def _convertFromPandas(self, pdf, schema): + def _convert_from_pandas(self, pdf, schema): """ Convert a pandas.DataFrame to list of records that can be used to make a DataFrame :return tuple of list of records and schema """ - # Convert pandas.DataFrame to list of numpy records - np_records = pdf.to_records(index=False) - # If no schema supplied by user then get the names of columns only if schema is None: schema = [str(x) for x in pdf.columns] - # Check if any columns need to be fixed for Spark to infer properly - if len(np_records) > 0: - record_type_list = self._getNumpyRecordDtypes(np_records[0]) - if record_type_list is not None: - return [r.astype(record_type_list).tolist() for r in np_records], schema + # Convert pandas.DataFrame to list of numpy records + np_records = pdf.to_records(index=False) + + # Check if any columns need to be fixed for Spark to infer properly + if len(np_records) > 0: + record_type_list = self._get_numpy_record_dtypes(np_records[0]) + if record_type_list is not None: + return [r.astype(record_type_list).tolist() for r in np_records], schema # Convert list of numpy records to python lists return [r.tolist() for r in np_records], schema @@ -557,7 +557,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr except Exception: has_pandas = False if has_pandas and isinstance(data, pandas.DataFrame): - data, schema = self._convertFromPandas(data, schema) + data, schema = self._convert_from_pandas(data, schema) if isinstance(schema, StructType): verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a94506b23ece6..eb0d4e29a5978 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2598,9 +2598,14 @@ def test_create_dataframe_from_pandas_with_timestamp(self): from datetime import datetime pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], "d": [pd.Timestamp.now().date()]}) + # test types are inferred correctly without specifying schema df = self.spark.createDataFrame(pdf) self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType)) self.assertTrue(isinstance(df.schema['d'].dataType, DateType)) + # test with schema will accept pdf as input + df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp") + self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType)) + self.assertTrue(isinstance(df.schema['d'].dataType, DateType)) class HiveSparkSubmitTests(SparkSubmitTests):
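
The core of the change above is how numpy renders datetime64 values when a record is
converted back to Python objects: with a datetime64[ns] dtype, record.tolist() yields
nanosecond longs (which Spark's schema inference reads as LongType), while datetime64[us]
and coarser units yield datetime.datetime objects (inferred as TimestampType), per the
NOTE comment and SPARK-22417. The following is a minimal standalone sketch of that
behavior, not part of the patch itself; it assumes only pandas/numpy are installed, and
the column name "ts" is illustrative:

    import pandas as pd

    # A one-row pandas DataFrame with a timestamp column; pandas stores it as datetime64[ns].
    pdf = pd.DataFrame({"ts": [pd.Timestamp("2017-10-31 01:01:01")]})
    records = pdf.to_records(index=False)

    # With the original datetime64[ns] dtype, tolist() returns a long count of
    # nanoseconds since the epoch, so Spark would infer LongType for the column.
    print(records[0].tolist())

    # Casting the record to microsecond precision first, as _convert_from_pandas() does,
    # makes tolist() produce a datetime.datetime, so Spark infers TimestampType instead.
    fixed = [r.astype([("ts", "datetime64[us]")]).tolist() for r in records]
    print(fixed[0])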