diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index c3dc1a46fd3c1..d1d0b8b8fe5d9 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -23,6 +23,7 @@
 
 if sys.version >= '3':
     basestring = unicode = str
+    xrange = range
 else:
     from itertools import imap as map
 
@@ -416,6 +417,50 @@ def _createFromLocal(self, data, schema):
             data = [schema.toInternal(row) for row in data]
         return self._sc.parallelize(data), schema
 
+    def _get_numpy_record_dtypes(self, rec):
+        """
+        Used when converting a pandas.DataFrame to Spark using to_records(), this will correct
+        the dtypes of records so they can be properly loaded into Spark.
+        :param rec: a numpy record to check dtypes
+        :return corrected dtypes for a numpy.record or None if no correction needed
+        """
+        import numpy as np
+        cur_dtypes = rec.dtype
+        col_names = cur_dtypes.names
+        record_type_list = []
+        has_rec_fix = False
+        for i in xrange(len(cur_dtypes)):
+            curr_type = cur_dtypes[i]
+            # If type is a datetime64 timestamp, convert to microseconds
+            # NOTE: if dtype is datetime[ns] then np.record.tolist() will output values as longs,
+            # conversion from [us] or lower will lead to py datetime objects, see SPARK-22417
+            if curr_type == np.dtype('datetime64[ns]'):
+                curr_type = 'datetime64[us]'
+                has_rec_fix = True
+            record_type_list.append((str(col_names[i]), curr_type))
+        return record_type_list if has_rec_fix else None
+
+    def _convert_from_pandas(self, pdf, schema):
+        """
+        Convert a pandas.DataFrame to list of records that can be used to make a DataFrame
+        :return tuple of list of records and schema
+        """
+        # If no schema supplied by user then get the names of columns only
+        if schema is None:
+            schema = [str(x) for x in pdf.columns]
+
+        # Convert pandas.DataFrame to list of numpy records
+        np_records = pdf.to_records(index=False)
+
+        # Check if any columns need to be fixed for Spark to infer properly
+        if len(np_records) > 0:
+            record_type_list = self._get_numpy_record_dtypes(np_records[0])
+            if record_type_list is not None:
+                return [r.astype(record_type_list).tolist() for r in np_records], schema
+
+        # Convert list of numpy records to python lists
+        return [r.tolist() for r in np_records], schema
+
     @since(2.0)
     @ignore_unicode_prefix
     def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
@@ -512,9 +557,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
         except Exception:
             has_pandas = False
         if has_pandas and isinstance(data, pandas.DataFrame):
-            if schema is None:
-                schema = [str(x) for x in data.columns]
-            data = [r.tolist() for r in data.to_records(index=False)]
+            data, schema = self._convert_from_pandas(data, schema)
         if isinstance(schema, StructType):
             verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True
 
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 483f39aeef66a..eb0d4e29a5978 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2592,6 +2592,21 @@ def test_create_dataframe_from_array_of_long(self):
         df = self.spark.createDataFrame(data)
         self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807]))
 
+    @unittest.skipIf(not _have_pandas, "Pandas not installed")
+    def test_create_dataframe_from_pandas_with_timestamp(self):
+        import pandas as pd
+        from datetime import datetime
+        pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
+                            "d": [pd.Timestamp.now().date()]})
+        # test types are inferred correctly without specifying schema
+        df = self.spark.createDataFrame(pdf)
+        self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
+        self.assertTrue(isinstance(df.schema['d'].dataType, DateType))
+        # test with schema will accept pdf as input
+        df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp")
+        self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
+        self.assertTrue(isinstance(df.schema['d'].dataType, DateType))
+
 
 class HiveSparkSubmitTests(SparkSubmitTests):
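
For context, below is a minimal standalone sketch of the dtype fix-up that the new _convert_from_pandas path relies on. It needs only pandas and numpy (no Spark); the column name "ts" and the sample value are illustrative and not taken from the patch.

from datetime import datetime

import numpy as np
import pandas as pd

# Hypothetical one-column frame; pandas stores the timestamps as datetime64[ns].
pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)]})
rec = pdf.to_records(index=False)[0]

# With the default datetime64[ns] dtype, np.record.tolist() yields a long
# (nanoseconds since the epoch), which Spark would infer as LongType.
print(rec.tolist())

# Casting the record to datetime64[us] makes tolist() return datetime.datetime
# objects, so schema inference produces TimestampType (the SPARK-22417 fix).
print(rec.astype([("ts", "datetime64[us]")]).tolist())

Targeting [us] matches the NOTE in _get_numpy_record_dtypes: Python datetime objects carry at most microsecond precision, so microseconds is the finest unit that still round-trips to datetime values rather than raw longs.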