diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index d6a1df6ba93b..91d2f96aee1d 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -545,7 +545,10 @@ def to_pandas(self, plan: pb2.Plan) -> "pd.DataFrame":
         req = self._execute_plan_request_with_metadata()
         req.plan.CopyFrom(plan)
         table, metrics = self._execute_and_fetch(req)
+        column_names = table.column_names
+        table = table.rename_columns([f"col_{i}" for i in range(len(column_names))])
         pdf = table.to_pandas()
+        pdf.columns = column_names
         if len(metrics) > 0:
             pdf.attrs["metrics"] = metrics
         return pdf
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index 800fe4a22980..97c0f473ce8c 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -112,15 +112,11 @@ def test_to_pandas_for_array_of_struct(self):
     def test_to_pandas_from_null_dataframe(self):
         super().test_to_pandas_from_null_dataframe()
 
-    # TODO(SPARK-41834): Implement SparkSession.conf
-    @unittest.skip("Fails in Spark Connect, should enable.")
     def test_to_pandas_on_cross_join(self):
-        super().test_to_pandas_on_cross_join()
+        self.check_to_pandas_on_cross_join()
 
-    # TODO(SPARK-41834): Implement SparkSession.conf
-    @unittest.skip("Fails in Spark Connect, should enable.")
     def test_to_pandas_with_duplicated_column_names(self):
-        super().test_to_pandas_with_duplicated_column_names()
+        self.check_to_pandas_with_duplicated_column_names()
 
     # TODO(SPARK-42367): DataFrame.drop should handle duplicated columns properly
     @unittest.skip("Fails in Spark Connect, should enable.")
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 610edc0926dd..e686fa9e929f 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1129,19 +1129,27 @@ def test_to_pandas(self):
 
     @unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
     def test_to_pandas_with_duplicated_column_names(self):
+        for arrow_enabled in [False, True]:
+            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+                self.check_to_pandas_with_duplicated_column_names()
+
+    def check_to_pandas_with_duplicated_column_names(self):
         import numpy as np
 
         sql = "select 1 v, 1 v"
-        for arrowEnabled in [False, True]:
-            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrowEnabled}):
-                df = self.spark.sql(sql)
-                pdf = df.toPandas()
-                types = pdf.dtypes
-                self.assertEqual(types.iloc[0], np.int32)
-                self.assertEqual(types.iloc[1], np.int32)
+        df = self.spark.sql(sql)
+        pdf = df.toPandas()
+        types = pdf.dtypes
+        self.assertEqual(types.iloc[0], np.int32)
+        self.assertEqual(types.iloc[1], np.int32)
 
     @unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
     def test_to_pandas_on_cross_join(self):
+        for arrow_enabled in [False, True]:
+            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+                self.check_to_pandas_on_cross_join()
+
+    def check_to_pandas_on_cross_join(self):
         import numpy as np
 
         sql = """
@@ -1151,18 +1159,12 @@ def test_to_pandas_on_cross_join(self):
         select explode(sequence(1, 3)) v
         ) t2
         """
-        for arrowEnabled in [False, True]:
-            with self.sql_conf(
-                {
-                    "spark.sql.crossJoin.enabled": True,
-                    "spark.sql.execution.arrow.pyspark.enabled": arrowEnabled,
-                }
-            ):
-                df = self.spark.sql(sql)
-                pdf = df.toPandas()
-                types = pdf.dtypes
-                self.assertEqual(types.iloc[0], np.int32)
-                self.assertEqual(types.iloc[1], np.int32)
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+            df = self.spark.sql(sql)
+            pdf = df.toPandas()
+            types = pdf.dtypes
+            self.assertEqual(types.iloc[0], np.int32)
+            self.assertEqual(types.iloc[1], np.int32)
 
     @unittest.skipIf(have_pandas, "Required Pandas was found.")
     def test_to_pandas_required_pandas_not_found(self):
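For context, a minimal standalone sketch of the rename-and-restore trick the `client.py` change relies on. The toy Arrow table (and the duplicated name `v`) is hypothetical, chosen to mirror the `select 1 v, 1 v` regression test; only `pyarrow` is assumed:

```python
import pyarrow as pa

# An Arrow table with duplicated column names, mirroring the result of
# "select 1 v, 1 v" from the regression test.
table = pa.table(
    [pa.array([1], type=pa.int32()), pa.array([1], type=pa.int32())],
    names=["v", "v"],
)

# Table.to_pandas() can fail or mangle duplicated column names in some
# pyarrow versions, so first rename every column to a unique placeholder ...
column_names = table.column_names
table = table.rename_columns([f"col_{i}" for i in range(len(column_names))])
pdf = table.to_pandas()

# ... then restore the original, duplicated names on the pandas side,
# where a DataFrame accepts them as-is.
pdf.columns = column_names
assert list(pdf.columns) == ["v", "v"]
```

Doing the restore on the pandas `DataFrame` rather than the Arrow table keeps the conversion itself operating only on unique names, which is why the tests above can drop their `@unittest.skip` markers and run under Spark Connect.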