diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index d803f37c5b9f..f689c439f5f6 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -322,7 +322,7 @@ def convert_other(value: Any) -> Any:
             return lambda value: value
 
     @staticmethod
-    def convert(data: Sequence[Any], schema: StructType) -> "pa.Table":
+    def convert(data: Sequence[Any], schema: StructType, verifySchema: bool = False) -> "pa.Table":
         assert isinstance(data, list) and len(data) > 0
 
         assert schema is not None and isinstance(schema, StructType)
@@ -372,8 +372,8 @@ def convert(data: Sequence[Any], schema: StructType) -> "pa.Table":
                 ]
             )
         )
-
-        return pa.Table.from_arrays(pylist, schema=pa_schema)
+        table = pa.Table.from_arrays(pylist, schema=pa_schema)
+        return table.cast(pa_schema, safe=verifySchema)
 
 
 class ArrowTableToRowsConversion:
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 83b0496a8427..e7292bf8804f 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -50,6 +50,7 @@
 )
 import urllib
 
+from pyspark._globals import _NoValue, _NoValueType
 from pyspark.sql.connect.dataframe import DataFrame
 from pyspark.sql.dataframe import DataFrame as ParentDataFrame
 from pyspark.sql.connect.logging import logger
@@ -449,7 +450,7 @@ def createDataFrame(
         data: Union["pd.DataFrame", "np.ndarray", "pa.Table", Iterable[Any]],
         schema: Optional[Union[AtomicType, StructType, str, List[str], Tuple[str, ...]]] = None,
         samplingRatio: Optional[float] = None,
-        verifySchema: Optional[bool] = None,
+        verifySchema: Union[_NoValueType, bool] = _NoValue,
     ) -> "ParentDataFrame":
         assert data is not None
         if isinstance(data, DataFrame):
@@ -461,9 +462,6 @@ def createDataFrame(
         if samplingRatio is not None:
             warnings.warn("'samplingRatio' is ignored. It is not supported with Spark Connect.")
 
-        if verifySchema is not None:
-            warnings.warn("'verifySchema' is ignored. It is not supported with Spark Connect.")
-
         _schema: Optional[Union[AtomicType, StructType]] = None
         _cols: Optional[List[str]] = None
         _num_cols: Optional[int] = None
@@ -576,7 +574,10 @@ def createDataFrame(
                 "spark.sql.session.timeZone", "spark.sql.execution.pandas.convertToArrowArraySafely"
             )
 
-            ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true")
+            if verifySchema is _NoValue:
+                verifySchema = safecheck == "true"
+
+            ser = ArrowStreamPandasSerializer(cast(str, timezone), verifySchema)
 
             _table = pa.Table.from_batches(
                 [
@@ -596,6 +597,9 @@
             ).cast(arrow_schema)
 
         elif isinstance(data, pa.Table):
+            if verifySchema is _NoValue:
+                verifySchema = False
+
             prefer_timestamp_ntz = is_timestamp_ntz_preferred()
 
             (timezone,) = self._client.get_configs("spark.sql.session.timeZone")
@@ -613,7 +617,10 @@
 
             _table = (
                 _check_arrow_table_timestamps_localize(data, schema, True, timezone)
-                .cast(to_arrow_schema(schema, error_on_duplicated_field_names_in_struct=True))
+                .cast(
+                    to_arrow_schema(schema, error_on_duplicated_field_names_in_struct=True),
+                    safe=verifySchema,
+                )
                 .rename_columns(schema.names)
             )
 
@@ -652,6 +659,12 @@
             # The _table should already have the proper column names.
             _cols = None
 
+            if verifySchema is not _NoValue:
+                warnings.warn(
+                    "'verifySchema' is ignored. It is not supported"
+                    " with np.ndarray input on Spark Connect."
+                )
+
         else:
             _data = list(data)
 
@@ -683,12 +696,15 @@
                     errorClass="CANNOT_DETERMINE_TYPE", messageParameters={}
                 )
 
+            if verifySchema is _NoValue:
+                verifySchema = True
+
             from pyspark.sql.connect.conversion import LocalDataToArrowConversion
 
             # Spark Connect will try its best to build the Arrow table with the
             # inferred schema in the client side, and then rename the columns and
             # cast the datatypes in the server side.
-            _table = LocalDataToArrowConversion.convert(_data, _schema)
+            _table = LocalDataToArrowConversion.convert(_data, _schema, cast(bool, verifySchema))
 
             # TODO: Beside the validation on number of columns, we should also check
             # whether the Arrow Schema is compatible with the user provided Schema.
diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow.py b/python/pyspark/sql/tests/connect/test_parity_arrow.py
index d47a367a5460..99d03ad1a440 100644
--- a/python/pyspark/sql/tests/connect/test_parity_arrow.py
+++ b/python/pyspark/sql/tests/connect/test_parity_arrow.py
@@ -137,9 +137,8 @@ def test_toPandas_udt(self):
     def test_create_dataframe_namedtuples(self):
         self.check_create_dataframe_namedtuples(True)
 
-    @unittest.skip("Spark Connect does not support verifySchema.")
     def test_createDataFrame_verifySchema(self):
-        super().test_createDataFrame_verifySchema()
+        self.check_createDataFrame_verifySchema(True)
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 19d0db989431..99149d1a23d3 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -533,6 +533,11 @@ def test_createDataFrame_arrow_pandas(self):
         self.assertEqual(df_arrow.collect(), df_pandas.collect())
 
     def test_createDataFrame_verifySchema(self):
+        for arrow_enabled in [True, False]:
+            with self.subTest(arrow_enabled=arrow_enabled):
+                self.check_createDataFrame_verifySchema(arrow_enabled)
+
+    def check_createDataFrame_verifySchema(self, arrow_enabled):
         data = {"id": [1, 2, 3], "value": [100000000000, 200000000000, 300000000000]}
         # data.value should fail schema validation when verifySchema is True
         schema = StructType(
@@ -547,29 +552,32 @@ def test_createDataFrame_verifySchema(self):
         table = pa.table(data)
         df = self.spark.createDataFrame(table, schema=schema)
         self.assertEqual(df.collect(), expected)
-
         with self.assertRaises(Exception):
             self.spark.createDataFrame(table, schema=schema, verifySchema=True)
 
-        # pandas DataFrame with Arrow optimization
-        pdf = pd.DataFrame(data)
-        df = self.spark.createDataFrame(pdf, schema=schema)
-        # verifySchema defaults to `spark.sql.execution.pandas.convertToArrowArraySafely`,
-        # which is false by default
-        self.assertEqual(df.collect(), expected)
-        with self.assertRaises(Exception):
-            with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": True}):
+        if arrow_enabled:
+            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}):
+                # pandas DataFrame with Arrow optimization
+                pdf = pd.DataFrame(data)
                 df = self.spark.createDataFrame(pdf, schema=schema)
-        with self.assertRaises(Exception):
-            df = self.spark.createDataFrame(pdf, schema=schema, verifySchema=True)
-
-        # pandas DataFrame without Arrow optimization
-        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
-            pdf = pd.DataFrame(data)
-            with self.assertRaises(Exception):
-                df = self.spark.createDataFrame(pdf, schema=schema)  # verifySchema defaults to True
-            df = self.spark.createDataFrame(pdf, schema=schema, verifySchema=False)
-            self.assertEqual(df.collect(), expected)
+                # verifySchema defaults to `spark.sql.execution.pandas.convertToArrowArraySafely`,
+                # which is false by default
+                self.assertEqual(df.collect(), expected)
+                with self.assertRaises(Exception):
+                    with self.sql_conf(
+                        {"spark.sql.execution.pandas.convertToArrowArraySafely": True}
+                    ):
+                        df = self.spark.createDataFrame(pdf, schema=schema)
+                with self.assertRaises(Exception):
+                    df = self.spark.createDataFrame(pdf, schema=schema, verifySchema=True)
+        else:
+            # pandas DataFrame without Arrow optimization
+            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
+                pdf = pd.DataFrame(data)
+                with self.assertRaises(Exception):
+                    self.spark.createDataFrame(pdf, schema=schema)  # verifySchema defaults to True
+                df = self.spark.createDataFrame(pdf, schema=schema, verifySchema=False)
+                self.assertEqual(df.collect(), expected)
 
     def _createDataFrame_toggle(self, data, schema=None):
         with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
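
Usage sketch (not part of the patch): the snippet below mirrors test_createDataFrame_verifySchema against a Spark Connect session to show the behavior the change enables for pyarrow Table input. The sc://localhost endpoint and the LongType/IntegerType column types are illustrative assumptions; the exact schema used by the test sits outside the hunks shown above.

# Illustrative sketch only, assuming a reachable Spark Connect server.
import pyarrow as pa
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, LongType, IntegerType

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()  # hypothetical endpoint

# int64 values that do not fit the IntegerType column declared below.
table = pa.table({"id": [1, 2, 3], "value": [100000000000, 200000000000, 300000000000]})
schema = StructType(
    [
        StructField("id", LongType(), True),
        StructField("value", IntegerType(), True),  # narrower than the data on purpose
    ]
)

# verifySchema still defaults to False for pa.Table input, so the unchecked cast succeeds.
df = spark.createDataFrame(table, schema=schema)
df.show()

# Opting in to verification surfaces the overflow as an error, as the test asserts.
try:
    spark.createDataFrame(table, schema=schema, verifySchema=True)
except Exception as err:
    print("schema verification failed:", err)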