48 changes: 38 additions & 10 deletions python/pyspark/sql/pandas/conversion.py
@@ -21,6 +21,7 @@
     xrange = range
 else:
     from itertools import izip as zip
+from collections import Counter
 
 from pyspark import since
 from pyspark.rdd import _load_from_socket
@@ -131,26 +132,53 @@ def toPandas(self):

         # Below is toPandas without Arrow optimization.
         pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
+        column_counter = Counter(self.columns)
 
-        dtype = {}
-        for field in self.schema:
+        dtype = [None] * len(self.schema)
+        for fieldIdx, field in enumerate(self.schema):
+            # For duplicate column name, we use `iloc` to access it.
+            if column_counter[field.name] > 1:
+                pandas_col = pdf.iloc[:, fieldIdx]
+            else:
+                pandas_col = pdf[field.name]
+
             pandas_type = PandasConversionMixin._to_corrected_pandas_type(field.dataType)
             # SPARK-21766: if an integer field is nullable and has null values, it can be
             # inferred by pandas as float column. Once we convert the column with NaN back
             # to integer type e.g., np.int16, we will hit exception. So we use the inferred
             # float type, not the corrected type from the schema in this case.
             if pandas_type is not None and \
                 not(isinstance(field.dataType, IntegralType) and field.nullable and
-                    pdf[field.name].isnull().any()):
-                dtype[field.name] = pandas_type
+                    pandas_col.isnull().any()):
+                dtype[fieldIdx] = pandas_type
             # Ensure we fall back to nullable numpy types, even when whole column is null:
-            if isinstance(field.dataType, IntegralType) and pdf[field.name].isnull().any():
-                dtype[field.name] = np.float64
-            if isinstance(field.dataType, BooleanType) and pdf[field.name].isnull().any():
-                dtype[field.name] = np.object
+            if isinstance(field.dataType, IntegralType) and pandas_col.isnull().any():
+                dtype[fieldIdx] = np.float64
+            if isinstance(field.dataType, BooleanType) and pandas_col.isnull().any():
+                dtype[fieldIdx] = np.object
 
-        for f, t in dtype.items():
-            pdf[f] = pdf[f].astype(t, copy=False)
+        df = pd.DataFrame()
+        for index, t in enumerate(dtype):
+            column_name = self.schema[index].name
+
+            # For duplicate column name, we use `iloc` to access it.
+            if column_counter[column_name] > 1:
+                series = pdf.iloc[:, index]
+            else:
+                series = pdf[column_name]
+
+            if t is not None:
+                series = series.astype(t, copy=False)
+
+            # `insert` API makes copy of data, we only do it for Series of duplicate column names.
+            # `pdf.iloc[:, index] = pdf.iloc[:, index]...` doesn't always work because `iloc` could
+            # return a view or a copy depending by context.
+            if column_counter[column_name] > 1:
+                df.insert(index, column_name, series, allow_duplicates=True)
+            else:
+                df[column_name] = series
+
+        pdf = df
 
         if timezone is None:
             return pdf
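The `iloc` and `insert(..., allow_duplicates=True)` comments in the hunk above hinge on how pandas handles duplicate column labels. Below is a minimal standalone sketch of that behavior, assuming only pandas and numpy and using toy data; it is an illustration, not code from the PR.

# Sketch: why positional access is needed once column names repeat.
# The toy frame mimics what `pd.DataFrame.from_records(rows, columns=['v', 'v'])`
# produces for the self-join in the new test; it is not Spark output.
import numpy as np
import pandas as pd

pdf = pd.DataFrame.from_records([(1, 1), (2, 2), (3, 3)], columns=['v', 'v'])

# Label-based access returns a DataFrame containing *all* matching columns,
# so per-column casts keyed by name would be ambiguous.
print(type(pdf['v']))        # <class 'pandas.core.frame.DataFrame'>

# Positional access returns the single intended column as a Series.
print(type(pdf.iloc[:, 0]))  # <class 'pandas.core.series.Series'>

# Rebuilding the frame with `insert(..., allow_duplicates=True)` keeps both
# columns and lets each one be cast independently, as the new loop does.
df = pd.DataFrame()
for i in range(pdf.shape[1]):
    series = pdf.iloc[:, i].astype(np.int32, copy=False)
    df.insert(i, 'v', series, allow_duplicates=True)
print(list(df.columns), df.dtypes.iloc[0], df.dtypes.iloc[1])  # ['v', 'v'] int32 int32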
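The SPARK-21766 comment in the same hunk (nullable integral columns that contain nulls keep their inferred float dtype) can likewise be reproduced with plain pandas. A small sketch with toy data, not Spark output:

# Sketch: why a nullable int column that holds nulls must stay float.
import numpy as np
import pandas as pd

s = pd.Series([1, 2, None])
print(s.dtype)  # float64 -- NumPy integers cannot represent NaN, so pandas infers float

try:
    s.astype(np.int16)  # forcing the "corrected" schema type back onto the column
except ValueError as e:
    print("cannot cast a NaN-bearing column to int16:", e)

# Hence the conversion keeps the inferred float dtype whenever a nullable
# IntegralType column actually contains nulls.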
18 changes: 18 additions & 0 deletions python/pyspark/sql/tests/test_dataframe.py
@@ -529,6 +529,24 @@ def test_to_pandas(self):
         self.assertEquals(types[4], np.object)  # datetime.date
         self.assertEquals(types[5], 'datetime64[ns]')
 
+    @unittest.skipIf(not have_pandas, pandas_requirement_message)
+    def test_to_pandas_on_cross_join(self):
+        import numpy as np
+
+        sql = """
+        select t1.*, t2.* from (
+        select explode(sequence(1, 3)) v
+        ) t1 left join (
+        select explode(sequence(1, 3)) v
+        ) t2
+        """
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+            df = self.spark.sql(sql)
+            pdf = df.toPandas()
+            types = pdf.dtypes
+            self.assertEquals(types.iloc[0], np.int32)
+            self.assertEquals(types.iloc[1], np.int32)
+
     @unittest.skipIf(have_pandas, "Required Pandas was found.")
     def test_to_pandas_required_pandas_not_found(self):
         with QuietTest(self.sc):
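A side note on the new test's assertions: `pdf.dtypes` is a pandas Series indexed by column name, so with the duplicated `v` columns a label lookup would return both entries; that is why the test reads the dtypes positionally with `iloc`. A toy illustration using pandas only, not run against Spark:

import numpy as np
import pandas as pd

pdf = pd.DataFrame.from_records([(1, 1), (2, 2)], columns=['v', 'v']).astype(np.int32)
types = pdf.dtypes

print(types['v'])     # a two-element Series, one dtype per 'v' column
print(types.iloc[0])  # int32 -- unambiguous, matching the assertions above
print(types.iloc[1])  # int32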