3 changes: 3 additions & 0 deletions python/pyspark/sql/connect/client.py
@@ -545,7 +545,10 @@ def to_pandas(self, plan: pb2.Plan) -> "pd.DataFrame":
         req = self._execute_plan_request_with_metadata()
         req.plan.CopyFrom(plan)
         table, metrics = self._execute_and_fetch(req)
+        column_names = table.column_names
+        table = table.rename_columns([f"col_{i}" for i in range(len(column_names))])
         pdf = table.to_pandas()
+        pdf.columns = column_names
         if len(metrics) > 0:
             pdf.attrs["metrics"] = metrics
         return pdf
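The three added lines keep column names unique while the Arrow table is converted, then restore the original (possibly duplicated) names on the pandas result, where plain label assignment is unambiguous. A minimal, self-contained sketch of that rename-and-restore trick (the toy table below is illustrative, not part of the patch):

    import pyarrow as pa

    # Build a table whose two int32 columns share the name "v", mirroring the
    # duplicated-column case the patch targets.
    table = pa.Table.from_arrays(
        [pa.array([1, 2], type=pa.int32()), pa.array([3, 4], type=pa.int32())],
        names=["v", "v"],
    )

    column_names = table.column_names
    # Convert under unique placeholder names, then put the duplicated names
    # back on the pandas side.
    table = table.rename_columns([f"col_{i}" for i in range(len(column_names))])
    pdf = table.to_pandas()
    pdf.columns = column_names

    print(pdf.columns.tolist())  # ['v', 'v']
    print(list(pdf.dtypes))      # [dtype('int32'), dtype('int32')]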
8 changes: 2 additions & 6 deletions python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -112,15 +112,11 @@ def test_to_pandas_for_array_of_struct(self):
     def test_to_pandas_from_null_dataframe(self):
         super().test_to_pandas_from_null_dataframe()
 
-    # TODO(SPARK-41834): Implement SparkSession.conf
-    @unittest.skip("Fails in Spark Connect, should enable.")
     def test_to_pandas_on_cross_join(self):
-        super().test_to_pandas_on_cross_join()
+        self.check_to_pandas_on_cross_join()
 
-    # TODO(SPARK-41834): Implement SparkSession.conf
-    @unittest.skip("Fails in Spark Connect, should enable.")
     def test_to_pandas_with_duplicated_column_names(self):
-        super().test_to_pandas_with_duplicated_column_names()
+        self.check_to_pandas_with_duplicated_column_names()
 
     # TODO(SPARK-42367): DataFrame.drop should handle duplicated columns properly
     @unittest.skip("Fails in Spark Connect, should enable.")
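The skip decorators above can go because the base suite (next file) now splits each test into a test_*/check_* pair, and the parity override reuses the shared check_* body under Spark Connect. A hypothetical toy version of that pattern (all names invented), runnable on its own:

    import unittest

    class BaseSuite(unittest.TestCase):
        def test_feature(self):
            # the base suite owns the config sweep ...
            for flag in [False, True]:
                self.check_feature(flag)

        def check_feature(self, flag):
            # ... while the assertions live in a check_* method that a
            # subclass can invoke directly under its own session setup
            self.assertIn(flag, (False, True))

    class ParitySuite(BaseSuite):
        def test_feature(self):
            # the parity suite reuses the checks, skipping the sweep its
            # backend does not support
            self.check_feature(True)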
40 changes: 21 additions & 19 deletions python/pyspark/sql/tests/test_dataframe.py
@@ -1129,19 +1129,27 @@ def test_to_pandas(self):

     @unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
     def test_to_pandas_with_duplicated_column_names(self):
+        for arrow_enabled in [False, True]:
+            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+                self.check_to_pandas_with_duplicated_column_names()
+
+    def check_to_pandas_with_duplicated_column_names(self):
         import numpy as np
 
         sql = "select 1 v, 1 v"
-        for arrowEnabled in [False, True]:
-            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrowEnabled}):
-                df = self.spark.sql(sql)
-                pdf = df.toPandas()
-                types = pdf.dtypes
-                self.assertEqual(types.iloc[0], np.int32)
-                self.assertEqual(types.iloc[1], np.int32)
+        df = self.spark.sql(sql)
+        pdf = df.toPandas()
+        types = pdf.dtypes
+        self.assertEqual(types.iloc[0], np.int32)
+        self.assertEqual(types.iloc[1], np.int32)
 
     @unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
     def test_to_pandas_on_cross_join(self):
+        for arrow_enabled in [False, True]:
+            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+                self.check_to_pandas_on_cross_join()
+
+    def check_to_pandas_on_cross_join(self):
         import numpy as np
 
         sql = """
@@ -1151,18 +1159,12 @@ def test_to_pandas_on_cross_join(self):
         select explode(sequence(1, 3)) v
         ) t2
         """
-        for arrowEnabled in [False, True]:
-            with self.sql_conf(
-                {
-                    "spark.sql.crossJoin.enabled": True,
-                    "spark.sql.execution.arrow.pyspark.enabled": arrowEnabled,
-                }
-            ):
-                df = self.spark.sql(sql)
-                pdf = df.toPandas()
-                types = pdf.dtypes
-                self.assertEqual(types.iloc[0], np.int32)
-                self.assertEqual(types.iloc[1], np.int32)
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+            df = self.spark.sql(sql)
+            pdf = df.toPandas()
+            types = pdf.dtypes
+            self.assertEqual(types.iloc[0], np.int32)
+            self.assertEqual(types.iloc[1], np.int32)
 
     @unittest.skipIf(have_pandas, "Required Pandas was found.")
     def test_to_pandas_required_pandas_not_found(self):
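Assuming a local SparkSession, a sketch of what the refactored cross-join test exercises (the SQL is paraphrased; the start of the original statement sits in the collapsed hunk above):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    # A join with no condition is an implicit cartesian product, which is why
    # the test enables spark.sql.crossJoin.enabled around it.
    spark.conf.set("spark.sql.crossJoin.enabled", "true")
    pdf = spark.sql(
        """
        select t1.v, t2.v from (select explode(sequence(1, 3)) v) t1
        join (select explode(sequence(1, 3)) v) t2
        """
    ).toPandas()

    print(len(pdf))          # expected: 9 rows (3 x 3)
    print(list(pdf.dtypes))  # expected: two int32 columns
    spark.stop()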