diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index 0a723f2977a6d..4dbdb5db21252 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -105,7 +105,9 @@ def convert_struct(value: Any) -> Any:
                 if value is None:
                     return None
                 else:
-                    assert isinstance(value, (tuple, dict)), f"{type(value)} {value}"
+                    assert isinstance(value, (tuple, dict)) or hasattr(
+                        value, "__dict__"
+                    ), f"{type(value)} {value}"

                     _dict = {}
                     if isinstance(value, dict):
@@ -116,6 +118,10 @@ def convert_struct(value: Any) -> Any:
                         for k, v in value.asDict(recursive=False).items():
                             assert isinstance(k, str)
                             _dict[k] = field_convs[k](v)
+                    elif not isinstance(value, Row) and hasattr(value, "__dict__"):
+                        for k, v in value.__dict__.items():
+                            assert isinstance(k, str)
+                            _dict[k] = field_convs[k](v)
                     else:
                         i = 0
                         for v in value:
@@ -253,6 +259,9 @@ def convert(data: Sequence[Any], schema: StructType) -> "pa.Table":
             elif isinstance(item, Row) and hasattr(item, "__fields__"):
                 for col, value in item.asDict(recursive=False).items():
                     _dict[col] = column_convs[col](value)
+            elif not isinstance(item, Row) and hasattr(item, "__dict__"):
+                for col, value in item.__dict__.items():
+                    _dict[col] = column_convs[col](value)
             else:
                 i = 0
                 for value in item:
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index b93bbffc999f4..898baa45b03ce 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -294,7 +294,9 @@ def createDataFrame(
                 # For dictionaries, we sort the schema in alphabetical order.
                 _data = [dict(sorted(d.items())) for d in _data]

-            elif not isinstance(_data[0], (Row, tuple, list, dict)):
+            elif not isinstance(_data[0], (Row, tuple, list, dict)) and not hasattr(
+                _data[0], "__dict__"
+            ):
                 # input data can be [1, 2, 3]
                 # we need to convert it to [[1], [2], [3]] to be able to infer schema.
                 _data = [[d] for d in _data]
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 9068d6f56354e..a9beb71545d04 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -37,6 +37,7 @@
 )

 from pyspark.testing.sqlutils import (
+    MyObject,
     SQLTestUtils,
     PythonOnlyUDT,
     ExamplePoint,
@@ -840,6 +841,22 @@ def test_nested_type_create_from_rows(self):
         self.assertEqual(cdf.schema, sdf.schema)
         self.assertEqual(cdf.collect(), sdf.collect())

+    def test_create_df_from_objects(self):
+        data = [MyObject(1, "1"), MyObject(2, "2")]
+
+        # +---+-----+
+        # |key|value|
+        # +---+-----+
+        # |  1|    1|
+        # |  2|    2|
+        # +---+-----+
+
+        cdf = self.connect.createDataFrame(data)
+        sdf = self.spark.createDataFrame(data)
+
+        self.assertEqual(cdf.schema, sdf.schema)
+        self.assertEqual(cdf.collect(), sdf.collect())
+
     def test_simple_explain_string(self):
         df = self.connect.read.table(self.tbl_name).limit(10)
         result = df._explain_string()
diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py b/python/pyspark/sql/tests/connect/test_parity_types.py
index 025e64f2bf069..e966986c15213 100644
--- a/python/pyspark/sql/tests/connect/test_parity_types.py
+++ b/python/pyspark/sql/tests/connect/test_parity_types.py
@@ -54,11 +54,6 @@ def test_cast_to_udt_with_udt(self):
     def test_complex_nested_udt_in_df(self):
         super().test_complex_nested_udt_in_df()

-    # TODO(SPARK-42020): createDataFrame with UDT
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_create_dataframe_from_objects(self):
-        super().test_create_dataframe_from_objects()
-
     @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
     def test_create_dataframe_schema_mismatch(self):
         super().test_create_dataframe_schema_mismatch()
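
A minimal usage sketch of the behavior this patch enables, kept separate from the patch itself. It assumes a Spark Connect server is reachable at sc://localhost, and the MyObject class below is a local stand-in for the plain Python object helper in pyspark.testing.sqlutils that the new test imports:

    # Sketch only: assumes a running Spark Connect server at sc://localhost.
    # A plain Python object exposes its attributes via __dict__, which the
    # patched conversion path now accepts for schema inference and row conversion.
    from pyspark.sql import SparkSession


    class MyObject:
        def __init__(self, key: int, value: str):
            self.key = key
            self.value = value


    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

    # Each object's __dict__ ({"key": ..., "value": ...}) becomes one row.
    df = spark.createDataFrame([MyObject(1, "1"), MyObject(2, "2")])
    df.show()
    # +---+-----+
    # |key|value|
    # +---+-----+
    # |  1|    1|
    # |  2|    2|
    # +---+-----+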