From 6e317aab4e56445729061f7b9b4ca0734a051849 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Tue, 14 Apr 2020 17:08:20 -0700
Subject: [PATCH] Backport SPARK-31186: toPandas should not fail on duplicate
 column names.

---
 python/pyspark/sql/dataframe.py | 40 ++++++++++++++++++++++++++++++++++------
 python/pyspark/sql/tests.py     | 18 ++++++++++++++++++
 2 files changed, 52 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index a2651d2781c2..b58d976abbeb 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -27,6 +27,7 @@
 from itertools import imap as map
 from cgi import escape as html_escape
 
+from collections import Counter
 import warnings
 
 from pyspark import copy_func, since, _NoValue
@@ -2148,9 +2149,16 @@ def toPandas(self):
 
         # Below is toPandas without Arrow optimization.
         pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
+        column_counter = Counter(self.columns)
+
+        dtype = [None] * len(self.schema)
+        for fieldIdx, field in enumerate(self.schema):
+            # For a duplicate column name, we use `iloc` to access the column by position.
+            if column_counter[field.name] > 1:
+                pandas_col = pdf.iloc[:, fieldIdx]
+            else:
+                pandas_col = pdf[field.name]
 
-        dtype = {}
-        for field in self.schema:
             pandas_type = _to_corrected_pandas_type(field.dataType)
             # SPARK-21766: if an integer field is nullable and has null values, it can be
             # inferred by pandas as float column. Once we convert the column with NaN back
@@ -2158,11 +2166,31 @@ def toPandas(self):
             # float type, not the corrected type from the schema in this case.
             if pandas_type is not None and \
                 not(isinstance(field.dataType, IntegralType) and field.nullable and
-                    pdf[field.name].isnull().any()):
-                dtype[field.name] = pandas_type
+                    pandas_col.isnull().any()):
+                dtype[fieldIdx] = pandas_type
+
+        df = pd.DataFrame()
+        for index, t in enumerate(dtype):
+            column_name = self.schema[index].name
+
+            # For a duplicate column name, we use `iloc` to access the column by position.
+            if column_counter[column_name] > 1:
+                series = pdf.iloc[:, index]
+            else:
+                series = pdf[column_name]
+
+            if t is not None:
+                series = series.astype(t, copy=False)
+
+            # The `insert` API makes a copy of the data; we only use it for Series with
+            # duplicate column names. `pdf.iloc[:, index] = series` doesn't always work
+            # because `iloc` may return either a view or a copy depending on the context.
+            if column_counter[column_name] > 1:
+                df.insert(index, column_name, series, allow_duplicates=True)
+            else:
+                df[column_name] = series
 
-        for f, t in dtype.items():
-            pdf[f] = pdf[f].astype(t, copy=False)
+        pdf = df
 
         if timezone is None:
             return pdf
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 02842673b277..d359e005fc9d 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3296,6 +3296,24 @@ def test_to_pandas(self):
         self.assertEquals(types[4], np.object)  # datetime.date
         self.assertEquals(types[5], 'datetime64[ns]')
 
+    @unittest.skipIf(not _have_pandas, _pandas_requirement_message)
+    def test_to_pandas_on_cross_join(self):
+        import numpy as np
+
+        sql = """
+        select t1.*, t2.* from (
+            select explode(sequence(1, 3)) v
+        ) t1 left join (
+            select explode(sequence(1, 3)) v
+        ) t2
+        """
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+            df = self.spark.sql(sql)
+            pdf = df.toPandas()
+            types = pdf.dtypes
+            self.assertEquals(types.iloc[0], np.int32)
+            self.assertEquals(types.iloc[1], np.int32)
+
     @unittest.skipIf(_have_pandas, "Required Pandas was found.")
     def test_to_pandas_required_pandas_not_found(self):
         with QuietTest(self.sc):
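
Note on duplicate column labels: the switch from name-based to positional
access is the core of the fix, because pandas resolves a duplicated label to
every matching column. A minimal pandas-only sketch of the behavior the new
code works around (data and column names here are illustrative):

    import pandas as pd

    pdf = pd.DataFrame([[1, 10], [2, 20]], columns=["v", "v"])

    # Name-based access matches both "v" columns, so the result is a
    # two-column DataFrame rather than a Series.
    print(type(pdf["v"]))        # <class 'pandas.core.frame.DataFrame'>

    # Positional access always yields exactly one column as a Series.
    print(type(pdf.iloc[:, 0]))  # <class 'pandas.core.series.Series'>

This is also why `dtype` changes from a dict keyed by field name to a list
indexed by field position: with duplicate names, `dtype[field.name]` would
silently merge entries that belong to distinct columns.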
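
Note on `insert(..., allow_duplicates=True)`: writing back through
`pdf.iloc[:, index] = ...` is unreliable because `iloc` may hand out a view
or a copy, so the patch assembles a fresh frame instead. A sketch of that
construction pattern, under the same illustrative data as above:

    import pandas as pd

    src = pd.DataFrame([[1, 10], [2, 20]], columns=["v", "v"])
    out = pd.DataFrame()

    for idx in range(src.shape[1]):
        series = src.iloc[:, idx].astype("int16", copy=False)
        # `insert` copies the data but tolerates a repeated label; plain
        # `out["v"] = series` would overwrite the column added earlier.
        out.insert(idx, src.columns[idx], series, allow_duplicates=True)

    print(out.dtypes)  # both "v" columns carry the corrected int16 dtype

The extra copy made by `insert` is deliberate, which is why the patch only
takes that path when a column name is actually duplicated.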
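
Note on the regression test: the left join with no join condition degenerates
into a cross join, which is a convenient way to produce a Spark DataFrame with
two columns both named `v`. A hedged standalone repro, assuming a local
Spark 2.4 installation (in 3.0+ cross joins are enabled by default, so the
config flag would be unnecessary):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder \
        .master("local[*]") \
        .config("spark.sql.crossJoin.enabled", "true") \
        .getOrCreate()

    df = spark.sql("""
        select t1.*, t2.* from (
            select explode(sequence(1, 3)) v
        ) t1 left join (
            select explode(sequence(1, 3)) v
        ) t2
    """)

    pdf = df.toPandas()
    print(pdf.dtypes)  # with the patch, both `v` columns come back as int32

Before the patch, the dtype-correction loop read `pdf[field.name]`, which for
the duplicated `v` resolves to a two-column DataFrame and made `toPandas()`
fail, per SPARK-31186.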