From a211003a50f0cfa88500e4915411b5d8fe831ecb Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Thu, 14 Nov 2024 17:38:46 +0800 Subject: [PATCH 1/4] verifySchema on Connect --- python/pyspark/sql/connect/conversion.py | 6 ++--- python/pyspark/sql/connect/session.py | 28 ++++++++++++++++++------ 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py index d803f37c5b9f..f689c439f5f6 100644 --- a/python/pyspark/sql/connect/conversion.py +++ b/python/pyspark/sql/connect/conversion.py @@ -322,7 +322,7 @@ def convert_other(value: Any) -> Any: return lambda value: value @staticmethod - def convert(data: Sequence[Any], schema: StructType) -> "pa.Table": + def convert(data: Sequence[Any], schema: StructType, verifySchema: bool = False) -> "pa.Table": assert isinstance(data, list) and len(data) > 0 assert schema is not None and isinstance(schema, StructType) @@ -372,8 +372,8 @@ def convert(data: Sequence[Any], schema: StructType) -> "pa.Table": ] ) ) - - return pa.Table.from_arrays(pylist, schema=pa_schema) + table = pa.Table.from_arrays(pylist, schema=pa_schema) + return table.cast(pa_schema, safe=verifySchema) class ArrowTableToRowsConversion: diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 83b0496a8427..eb7214353943 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -50,6 +50,7 @@ ) import urllib +from pyspark._globals import _NoValue, _NoValueType from pyspark.sql.connect.dataframe import DataFrame from pyspark.sql.dataframe import DataFrame as ParentDataFrame from pyspark.sql.connect.logging import logger @@ -449,7 +450,7 @@ def createDataFrame( data: Union["pd.DataFrame", "np.ndarray", "pa.Table", Iterable[Any]], schema: Optional[Union[AtomicType, StructType, str, List[str], Tuple[str, ...]]] = None, samplingRatio: Optional[float] = None, - verifySchema: Optional[bool] = None, + verifySchema: Union[_NoValueType, bool] = _NoValue, ) -> "ParentDataFrame": assert data is not None if isinstance(data, DataFrame): @@ -461,9 +462,6 @@ def createDataFrame( if samplingRatio is not None: warnings.warn("'samplingRatio' is ignored. It is not supported with Spark Connect.") - if verifySchema is not None: - warnings.warn("'verifySchema' is ignored. 
It is not supported with Spark Connect.") - _schema: Optional[Union[AtomicType, StructType]] = None _cols: Optional[List[str]] = None _num_cols: Optional[int] = None @@ -576,7 +574,10 @@ def createDataFrame( "spark.sql.session.timeZone", "spark.sql.execution.pandas.convertToArrowArraySafely" ) - ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true") + if verifySchema is _NoValue: + verifySchema = safecheck == "true" + + ser = ArrowStreamPandasSerializer(cast(str, timezone), verifySchema) _table = pa.Table.from_batches( [ @@ -596,6 +597,9 @@ def createDataFrame( ).cast(arrow_schema) elif isinstance(data, pa.Table): + if verifySchema is _NoValue: + verifySchema = False + prefer_timestamp_ntz = is_timestamp_ntz_preferred() (timezone,) = self._client.get_configs("spark.sql.session.timeZone") @@ -613,7 +617,10 @@ def createDataFrame( _table = ( _check_arrow_table_timestamps_localize(data, schema, True, timezone) - .cast(to_arrow_schema(schema, error_on_duplicated_field_names_in_struct=True)) + .cast( + to_arrow_schema(schema, error_on_duplicated_field_names_in_struct=True), + safe=verifySchema, + ) .rename_columns(schema.names) ) @@ -648,10 +655,14 @@ def createDataFrame( _table = pa.Table.from_arrays( [pa.array(data[::, i]) for i in range(0, data.shape[1])], _cols ) + _table.cast() # The _table should already have the proper column names. _cols = None + if verifySchema is _NoValue: + verifySchema = True + else: _data = list(data) @@ -683,12 +694,15 @@ def createDataFrame( errorClass="CANNOT_DETERMINE_TYPE", messageParameters={} ) + if verifySchema is _NoValue: + verifySchema = True + from pyspark.sql.connect.conversion import LocalDataToArrowConversion # Spark Connect will try its best to build the Arrow table with the # inferred schema in the client side, and then rename the columns and # cast the datatypes in the server side. - _table = LocalDataToArrowConversion.convert(_data, _schema) + _table = LocalDataToArrowConversion.convert(_data, _schema, verifySchema) # TODO: Beside the validation on number of columns, we should also check # whether the Arrow Schema is compatible with the user provided Schema. From d524d1f2bef2710ab7d32f5959d202a6aa9efaef Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Fri, 15 Nov 2024 09:23:26 +0800 Subject: [PATCH 2/4] fix; test --- python/pyspark/sql/connect/session.py | 10 ++++++---- python/pyspark/sql/tests/connect/test_parity_arrow.py | 1 - 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index eb7214353943..e7292bf8804f 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -655,13 +655,15 @@ def createDataFrame( _table = pa.Table.from_arrays( [pa.array(data[::, i]) for i in range(0, data.shape[1])], _cols ) - _table.cast() # The _table should already have the proper column names. _cols = None - if verifySchema is _NoValue: - verifySchema = True + if verifySchema is not _NoValue: + warnings.warn( + "'verifySchema' is ignored. It is not supported" + " with np.ndarray input on Spark Connect." + ) else: _data = list(data) @@ -702,7 +704,7 @@ def createDataFrame( # Spark Connect will try its best to build the Arrow table with the # inferred schema in the client side, and then rename the columns and # cast the datatypes in the server side. 
- _table = LocalDataToArrowConversion.convert(_data, _schema, verifySchema) + _table = LocalDataToArrowConversion.convert(_data, _schema, cast(bool, verifySchema)) # TODO: Beside the validation on number of columns, we should also check # whether the Arrow Schema is compatible with the user provided Schema. diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow.py b/python/pyspark/sql/tests/connect/test_parity_arrow.py index d47a367a5460..d92cb908523f 100644 --- a/python/pyspark/sql/tests/connect/test_parity_arrow.py +++ b/python/pyspark/sql/tests/connect/test_parity_arrow.py @@ -137,7 +137,6 @@ def test_toPandas_udt(self): def test_create_dataframe_namedtuples(self): self.check_create_dataframe_namedtuples(True) - @unittest.skip("Spark Connect does not support verifySchema.") def test_createDataFrame_verifySchema(self): super().test_createDataFrame_verifySchema() From 5949de380a2304e2bfa78941a8b3169c39ecb2f3 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Fri, 15 Nov 2024 11:50:28 +0800 Subject: [PATCH 3/4] tests --- .../sql/tests/connect/test_parity_arrow.py | 2 +- python/pyspark/sql/tests/test_arrow.py | 44 +++++++++++-------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow.py b/python/pyspark/sql/tests/connect/test_parity_arrow.py index d92cb908523f..99d03ad1a440 100644 --- a/python/pyspark/sql/tests/connect/test_parity_arrow.py +++ b/python/pyspark/sql/tests/connect/test_parity_arrow.py @@ -138,7 +138,7 @@ def test_create_dataframe_namedtuples(self): self.check_create_dataframe_namedtuples(True) def test_createDataFrame_verifySchema(self): - super().test_createDataFrame_verifySchema() + self.check_createDataFrame_verifySchema(True) if __name__ == "__main__": diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 19d0db989431..b83eadb4394e 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -533,6 +533,11 @@ def test_createDataFrame_arrow_pandas(self): self.assertEqual(df_arrow.collect(), df_pandas.collect()) def test_createDataFrame_verifySchema(self): + for arrow_enabled in [True, False]: + with self.subTest(arrow_enabled=arrow_enabled): + self.check_createDataFrame_verifySchema(arrow_enabled) + + def check_createDataFrame_verifySchema(self, arrow_enabled): data = {"id": [1, 2, 3], "value": [100000000000, 200000000000, 300000000000]} # data.value should fail schema validation when verifySchema is True schema = StructType( @@ -547,29 +552,30 @@ def test_createDataFrame_verifySchema(self): table = pa.table(data) df = self.spark.createDataFrame(table, schema=schema) self.assertEqual(df.collect(), expected) - with self.assertRaises(Exception): self.spark.createDataFrame(table, schema=schema, verifySchema=True) - # pandas DataFrame with Arrow optimization - pdf = pd.DataFrame(data) - df = self.spark.createDataFrame(pdf, schema=schema) - # verifySchema defaults to `spark.sql.execution.pandas.convertToArrowArraySafely`, - # which is false by default - self.assertEqual(df.collect(), expected) - with self.assertRaises(Exception): - with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": True}): + if arrow_enabled: + with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}): + # pandas DataFrame with Arrow optimization + pdf = pd.DataFrame(data) df = self.spark.createDataFrame(pdf, schema=schema) - with self.assertRaises(Exception): - df = self.spark.createDataFrame(pdf, schema=schema, 
verifySchema=True) - - # pandas DataFrame without Arrow optimization - with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}): - pdf = pd.DataFrame(data) - with self.assertRaises(Exception): - df = self.spark.createDataFrame(pdf, schema=schema) # verifySchema defaults to True - df = self.spark.createDataFrame(pdf, schema=schema, verifySchema=False) - self.assertEqual(df.collect(), expected) + # verifySchema defaults to `spark.sql.execution.pandas.convertToArrowArraySafely`, + # which is false by default + self.assertEqual(df.collect(), expected) + with self.assertRaises(Exception): + with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": True}): + df = self.spark.createDataFrame(pdf, schema=schema) + with self.assertRaises(Exception): + df = self.spark.createDataFrame(pdf, schema=schema, verifySchema=True) + else: + # pandas DataFrame without Arrow optimization + with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}): + pdf = pd.DataFrame(data) + with self.assertRaises(Exception): + self.spark.createDataFrame(pdf, schema=schema) # verifySchema defaults to True + df = self.spark.createDataFrame(pdf, schema=schema, verifySchema=False) + self.assertEqual(df.collect(), expected) def _createDataFrame_toggle(self, data, schema=None): with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}): From 48f5e0c9a0bba29f1e7e1d6553081c1d75877de1 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Fri, 15 Nov 2024 14:25:08 +0800 Subject: [PATCH 4/4] fmt --- python/pyspark/sql/tests/test_arrow.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index b83eadb4394e..99149d1a23d3 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -564,7 +564,9 @@ def check_createDataFrame_verifySchema(self, arrow_enabled): # which is false by default self.assertEqual(df.collect(), expected) with self.assertRaises(Exception): - with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": True}): + with self.sql_conf( + {"spark.sql.execution.pandas.convertToArrowArraySafely": True} + ): df = self.spark.createDataFrame(pdf, schema=schema) with self.assertRaises(Exception): df = self.spark.createDataFrame(pdf, schema=schema, verifySchema=True)
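
Usage sketch (not part of the patches above): the series wires createDataFrame's verifySchema flag through to the Arrow casts on Spark Connect, with the pyarrow.Table path defaulting to False, the pandas path defaulting to spark.sql.execution.pandas.convertToArrowArraySafely, and local Python data defaulting to True. The snippet below mirrors the check_createDataFrame_verifySchema test added in PATCH 3/4; the remote URL, session setup, and printed output are illustrative assumptions only, not part of the change.

    # Minimal sketch of the new behavior on a Spark Connect session.
    # Assumes a Connect server is reachable at sc://localhost (illustrative only).
    import pyarrow as pa
    from pyspark.sql import SparkSession
    from pyspark.sql.types import StructType, StructField, IntegerType

    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

    # "value" overflows IntegerType, so it only survives an unchecked cast.
    data = {"id": [1, 2, 3], "value": [100000000000, 200000000000, 300000000000]}
    schema = StructType(
        [StructField("id", IntegerType(), True), StructField("value", IntegerType(), True)]
    )
    table = pa.table(data)

    # pyarrow.Table input defaults to verifySchema=False in this series:
    # the cast runs with safe=False and the oversized values wrap silently.
    df = spark.createDataFrame(table, schema=schema)
    print(df.collect())

    # verifySchema=True casts with safe=True, so the overflow raises instead.
    try:
        spark.createDataFrame(table, schema=schema, verifySchema=True)
    except Exception as exc:
        print("schema verification failed as expected:", type(exc).__name__)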