
Commit 6fea291

viirya authored and HyukjinKwon committed
[SPARK-31186][PYSPARK][SQL] toPandas should not fail on duplicate column names
### What changes were proposed in this pull request?

When the `toPandas` API works on duplicate column names produced by operators like join, we see an error like:

```
ValueError: The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().
```

This patch fixes the error in the `toPandas` API.

### Why are the changes needed?

To make `toPandas` work on dataframes with duplicate column names.

### Does this PR introduce any user-facing change?

Yes. Previously, calling the `toPandas` API on a dataframe with duplicate column names would fail. After this patch, it produces the correct result.

### How was this patch tested?

Unit test.

Closes #28025 from viirya/SPARK-31186.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
(cherry picked from commit 559d3e4)
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
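For context, a minimal repro sketch of the failure mode described above (not part of the commit; assumes a local `SparkSession` and pandas installed):

```python
# Minimal repro sketch (not from the commit): a join result whose two columns
# are both named "v". Assumes pyspark and pandas are installed.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

t1 = spark.range(1, 4).toDF("v")
t2 = spark.range(1, 4).toDF("v")
joined = t1.crossJoin(t2)

print(joined.columns)  # ['v', 'v'] -- duplicate column names

# Before this patch, the non-Arrow toPandas path raised:
#   ValueError: The truth value of a Series is ambiguous. ...
pdf = joined.toPandas()
print(pdf.dtypes)
```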
1 parent 6f30ff4 · commit 6fea291

File tree: 2 files changed (+56, -10 lines)


python/pyspark/sql/pandas/conversion.py

Lines changed: 38 additions & 10 deletions
```diff
@@ -21,6 +21,7 @@
     xrange = range
 else:
     from itertools import izip as zip
+    from collections import Counter
 
 from pyspark import since
 from pyspark.rdd import _load_from_socket
@@ -131,26 +132,53 @@ def toPandas(self):
 
         # Below is toPandas without Arrow optimization.
         pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
+        column_counter = Counter(self.columns)
+
+        dtype = [None] * len(self.schema)
+        for fieldIdx, field in enumerate(self.schema):
+            # For duplicate column names, we use `iloc` to access them.
+            if column_counter[field.name] > 1:
+                pandas_col = pdf.iloc[:, fieldIdx]
+            else:
+                pandas_col = pdf[field.name]
 
-        dtype = {}
-        for field in self.schema:
             pandas_type = PandasConversionMixin._to_corrected_pandas_type(field.dataType)
             # SPARK-21766: if an integer field is nullable and has null values, it can be
             # inferred by pandas as float column. Once we convert the column with NaN back
             # to integer type e.g., np.int16, we will hit exception. So we use the inferred
             # float type, not the corrected type from the schema in this case.
             if pandas_type is not None and \
                 not(isinstance(field.dataType, IntegralType) and field.nullable and
-                    pdf[field.name].isnull().any()):
-                dtype[field.name] = pandas_type
+                    pandas_col.isnull().any()):
+                dtype[fieldIdx] = pandas_type
             # Ensure we fall back to nullable numpy types, even when whole column is null:
-            if isinstance(field.dataType, IntegralType) and pdf[field.name].isnull().any():
-                dtype[field.name] = np.float64
-            if isinstance(field.dataType, BooleanType) and pdf[field.name].isnull().any():
-                dtype[field.name] = np.object
+            if isinstance(field.dataType, IntegralType) and pandas_col.isnull().any():
+                dtype[fieldIdx] = np.float64
+            if isinstance(field.dataType, BooleanType) and pandas_col.isnull().any():
+                dtype[fieldIdx] = np.object
+
+        df = pd.DataFrame()
+        for index, t in enumerate(dtype):
+            column_name = self.schema[index].name
+
+            # For duplicate column names, we use `iloc` to access them.
+            if column_counter[column_name] > 1:
+                series = pdf.iloc[:, index]
+            else:
+                series = pdf[column_name]
+
+            if t is not None:
+                series = series.astype(t, copy=False)
+
+            # The `insert` API makes a copy of the data; we only use it for Series with
+            # duplicate column names. `pdf.iloc[:, index] = pdf.iloc[:, index]...` doesn't
+            # always work because `iloc` could return a view or a copy depending on the context.
+            if column_counter[column_name] > 1:
+                df.insert(index, column_name, series, allow_duplicates=True)
+            else:
+                df[column_name] = series
 
-        for f, t in dtype.items():
-            pdf[f] = pdf[f].astype(t, copy=False)
+        pdf = df
 
         if timezone is None:
             return pdf
```
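The `iloc` comments in the diff come down to a pandas behavior: selecting a duplicated column label returns a DataFrame rather than a Series, so `.isnull().any()` yields a Series, and using that in a boolean context raises the ambiguity error. Positional `iloc` access always yields a single Series. A minimal pandas-only sketch of both halves of the fix (illustrative, not part of the patch):

```python
import numpy as np
import pandas as pd

pdf = pd.DataFrame([[1, None], [2, 2]], columns=["v", "v"])

# Label-based access on a duplicated name returns a 2-column DataFrame...
sub = pdf["v"]
print(type(sub))                 # <class 'pandas.core.frame.DataFrame'>
print(sub.isnull().any())        # a Series, not a scalar -> `if ...:` raises ValueError

# ...while positional access always returns a single Series.
col = pdf.iloc[:, 1]
print(type(col))                 # <class 'pandas.core.series.Series'>
print(bool(col.isnull().any()))  # True

# Writing back also needs care: `pdf.iloc[:, 1] = ...` may hit a view or a
# copy, so the patch rebuilds into a fresh frame with `insert`, which copies
# the data and accepts duplicate labels.
out = pd.DataFrame()
out.insert(0, "v", pdf.iloc[:, 0].astype(np.int32, copy=False))
out.insert(1, "v", pdf.iloc[:, 1], allow_duplicates=True)
print(out.dtypes)
```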

python/pyspark/sql/tests/test_dataframe.py

Lines changed: 18 additions & 0 deletions
```diff
@@ -529,6 +529,24 @@ def test_to_pandas(self):
         self.assertEquals(types[4], np.object)  # datetime.date
         self.assertEquals(types[5], 'datetime64[ns]')
 
+    @unittest.skipIf(not have_pandas, pandas_requirement_message)
+    def test_to_pandas_on_cross_join(self):
+        import numpy as np
+
+        sql = """
+        select t1.*, t2.* from (
+            select explode(sequence(1, 3)) v
+        ) t1 left join (
+            select explode(sequence(1, 3)) v
+        ) t2
+        """
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+            df = self.spark.sql(sql)
+            pdf = df.toPandas()
+            types = pdf.dtypes
+            self.assertEquals(types.iloc[0], np.int32)
+            self.assertEquals(types.iloc[1], np.int32)
+
     @unittest.skipIf(have_pandas, "Required Pandas was found.")
     def test_to_pandas_required_pandas_not_found(self):
         with QuietTest(self.sc):
```
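One detail worth noting in the test: with duplicate labels, even `pdf.dtypes['v']` returns two entries, which is why the assertions read dtypes positionally via `.iloc`. A sketch of the same scenario outside unittest (assumes an active `SparkSession` bound to `spark`):

```python
# Mirrors the new regression test interactively; assumes an active
# SparkSession bound to `spark` with pandas installed.
spark.conf.set("spark.sql.crossJoin.enabled", "true")

pdf = spark.sql("""
    select t1.*, t2.* from (
        select explode(sequence(1, 3)) v
    ) t1 left join (
        select explode(sequence(1, 3)) v
    ) t2
""").toPandas()

print(list(pdf.columns))   # ['v', 'v']
# `pdf.dtypes['v']` would return both entries, hence positional access:
print(pdf.dtypes.iloc[0])  # int32
print(pdf.dtypes.iloc[1])  # int32
```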
