diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index 172a4fc4b234..0c612bf4eae3 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -27,6 +27,7 @@
 )
 from warnings import warn

+from pyspark._globals import _NoValue, _NoValueType
 from pyspark.errors.exceptions.captured import unwrap_spark_exception
 from pyspark.loose_version import LooseVersion
 from pyspark.util import _load_from_socket
@@ -352,7 +353,7 @@ def createDataFrame(
         self,
         data: "PandasDataFrameLike",
         schema: Union[StructType, str],
-        verifySchema: bool = ...,
+        verifySchema: Union[_NoValueType, bool] = ...,
     ) -> "DataFrame":
         ...

@@ -361,7 +362,7 @@ def createDataFrame(
         self,
         data: "pa.Table",
         schema: Union[StructType, str],
-        verifySchema: bool = ...,
+        verifySchema: Union[_NoValueType, bool] = ...,
     ) -> "DataFrame":
         ...

@@ -370,7 +371,7 @@ def createDataFrame(  # type: ignore[misc]
         data: Union["PandasDataFrameLike", "pa.Table"],
         schema: Optional[Union[StructType, List[str]]] = None,
         samplingRatio: Optional[float] = None,
-        verifySchema: bool = True,
+        verifySchema: Union[_NoValueType, bool] = _NoValue,
     ) -> "DataFrame":
         from pyspark.sql import SparkSession

@@ -392,7 +393,7 @@ def createDataFrame(  # type: ignore[misc]
             if schema is None:
                 schema = data.schema.names

-            return self._create_from_arrow_table(data, schema, timezone)
+            return self._create_from_arrow_table(data, schema, timezone, verifySchema)

         # `data` is a PandasDataFrameLike object
         from pyspark.sql.pandas.utils import require_minimum_pandas_version
@@ -405,7 +406,7 @@ def createDataFrame(  # type: ignore[misc]

         if self._jconf.arrowPySparkEnabled() and len(data) > 0:
             try:
-                return self._create_from_pandas_with_arrow(data, schema, timezone)
+                return self._create_from_pandas_with_arrow(data, schema, timezone, verifySchema)
             except Exception as e:
                 if self._jconf.arrowPySparkFallbackEnabled():
                     msg = (
@@ -624,7 +625,11 @@ def _get_numpy_record_dtype(self, rec: "np.recarray") -> Optional["np.dtype"]:
         return np.dtype(record_type_list) if has_rec_fix else None

     def _create_from_pandas_with_arrow(
-        self, pdf: "PandasDataFrameLike", schema: Union[StructType, List[str]], timezone: str
+        self,
+        pdf: "PandasDataFrameLike",
+        schema: Union[StructType, List[str]],
+        timezone: str,
+        verifySchema: Union[_NoValueType, bool],
     ) -> "DataFrame":
         """
         Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
@@ -657,6 +662,10 @@ def _create_from_pandas_with_arrow(
         )
         import pyarrow as pa

+        if verifySchema is _NoValue:
+            # (With Arrow optimization) createDataFrame with `pandas.DataFrame`
+            verifySchema = self._jconf.arrowSafeTypeConversion()
+
         infer_pandas_dict_as_map = (
             str(self.conf.get("spark.sql.execution.pandas.inferPandasDictAsMap")).lower() == "true"
         )
@@ -725,8 +734,7 @@ def _create_from_pandas_with_arrow(

         jsparkSession = self._jsparkSession

-        safecheck = self._jconf.arrowSafeTypeConversion()
-        ser = ArrowStreamPandasSerializer(timezone, safecheck)
+        ser = ArrowStreamPandasSerializer(timezone, verifySchema)

         @no_type_check
         def reader_func(temp_filename):
@@ -745,7 +753,11 @@ def create_iter_server():
         return df

     def _create_from_arrow_table(
-        self, table: "pa.Table", schema: Union[StructType, List[str]], timezone: str
+        self,
+        table: "pa.Table",
+        schema: Union[StructType, List[str]],
+        timezone: str,
+        verifySchema: Union[_NoValueType, bool],
     ) -> "DataFrame":
         """
         Create a DataFrame from a given pyarrow.Table by slicing it into partitions then
@@ -767,6 +779,10 @@ def _create_from_arrow_table(

         require_minimum_pyarrow_version()

+        if verifySchema is _NoValue:
+            # createDataFrame with `pyarrow.Table`
+            verifySchema = False
+
         prefer_timestamp_ntz = is_timestamp_ntz_preferred()

         # Create the Spark schema from list of names passed in with Arrow types
@@ -786,7 +802,8 @@ def _create_from_arrow_table(
             schema = from_arrow_schema(table.schema, prefer_timestamp_ntz=prefer_timestamp_ntz)

         table = _check_arrow_table_timestamps_localize(table, schema, True, timezone).cast(
-            to_arrow_schema(schema, error_on_duplicated_field_names_in_struct=True)
+            to_arrow_schema(schema, error_on_duplicated_field_names_in_struct=True),
+            safe=verifySchema,
         )

         # Chunk the Arrow Table into RecordBatches
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 4979ce712673..ef8750b6e72d 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -38,6 +38,7 @@
     TYPE_CHECKING,
 )

+from pyspark._globals import _NoValue, _NoValueType
 from pyspark.conf import SparkConf
 from pyspark.util import is_remote_only
 from pyspark.sql.conf import RuntimeConfig
@@ -1265,7 +1266,7 @@ def createDataFrame(
         data: Iterable["RowLike"],
         schema: Union[StructType, str],
         *,
-        verifySchema: bool = ...,
+        verifySchema: Union[_NoValueType, bool] = ...,
     ) -> DataFrame:
         ...

@@ -1275,7 +1276,7 @@ def createDataFrame(
         data: "RDD[RowLike]",
         schema: Union[StructType, str],
         *,
-        verifySchema: bool = ...,
+        verifySchema: Union[_NoValueType, bool] = ...,
     ) -> DataFrame:
         ...

@@ -1284,7 +1285,7 @@ def createDataFrame(
         self,
         data: "RDD[AtomicValue]",
         schema: Union[AtomicType, str],
-        verifySchema: bool = ...,
+        verifySchema: Union[_NoValueType, bool] = ...,
     ) -> DataFrame:
         ...

@@ -1293,7 +1294,7 @@ def createDataFrame(
         self,
         data: Iterable["AtomicValue"],
         schema: Union[AtomicType, str],
-        verifySchema: bool = ...,
+        verifySchema: Union[_NoValueType, bool] = ...,
     ) -> DataFrame:
         ...

@@ -1312,7 +1313,7 @@ def createDataFrame(
         self,
         data: "PandasDataFrameLike",
         schema: Union[StructType, str],
-        verifySchema: bool = ...,
+        verifySchema: Union[_NoValueType, bool] = ...,
     ) -> DataFrame:
         ...

@@ -1321,7 +1322,7 @@ def createDataFrame(
         self,
         data: "pa.Table",
         schema: Union[StructType, str],
-        verifySchema: bool = ...,
+        verifySchema: Union[_NoValueType, bool] = ...,
     ) -> DataFrame:
         ...

@@ -1330,7 +1331,7 @@ def createDataFrame(  # type: ignore[misc]
         data: Union["RDD[Any]", Iterable[Any], "PandasDataFrameLike", "ArrayLike", "pa.Table"],
         schema: Optional[Union[AtomicType, StructType, str]] = None,
         samplingRatio: Optional[float] = None,
-        verifySchema: bool = True,
+        verifySchema: Union[_NoValueType, bool] = _NoValue,
     ) -> DataFrame:
         """
         Creates a :class:`DataFrame` from an :class:`RDD`, a list, a :class:`pandas.DataFrame`,
@@ -1374,11 +1375,14 @@ def createDataFrame(  # type: ignore[misc]
             if ``samplingRatio`` is ``None``. This option is effective only when the input is
             :class:`RDD`.
         verifySchema : bool, optional
-            verify data types of every row against schema. Enabled by default.
-            When the input is :class:`pyarrow.Table` or when the input class is
-            :class:`pandas.DataFrame` and `spark.sql.execution.arrow.pyspark.enabled` is enabled,
-            this option is not effective. It follows Arrow type coercion. This option is not
-            supported with Spark Connect.
+            verify data types of every row against schema.
+            If not provided, the default depends on the input:
+            - pyarrow.Table: verifySchema defaults to False
+            - pandas.DataFrame with Arrow optimization: verifySchema defaults to
+              `spark.sql.execution.pandas.convertToArrowArraySafely`
+            - pandas.DataFrame without Arrow optimization: verifySchema defaults to True
+            - regular Python instances: verifySchema defaults to True
+            Arrow optimization is enabled or disabled via `spark.sql.execution.arrow.pyspark.enabled`.

             .. versionadded:: 2.1.0
@@ -1578,8 +1582,13 @@ def _create_dataframe(
         data: Union["RDD[Any]", Iterable[Any]],
         schema: Optional[Union[DataType, List[str]]],
         samplingRatio: Optional[float],
-        verifySchema: bool,
+        verifySchema: Union[_NoValueType, bool],
     ) -> DataFrame:
+        if verifySchema is _NoValue:
+            # createDataFrame with regular Python instances
+            # or (without Arrow optimization) createDataFrame with a pandas DataFrame
+            verifySchema = True
+
         if isinstance(schema, StructType):
             verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True
diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow.py b/python/pyspark/sql/tests/connect/test_parity_arrow.py
index 885b3001b1db..d47a367a5460 100644
--- a/python/pyspark/sql/tests/connect/test_parity_arrow.py
+++ b/python/pyspark/sql/tests/connect/test_parity_arrow.py
@@ -137,6 +137,10 @@ def test_toPandas_udt(self):
     def test_create_dataframe_namedtuples(self):
         self.check_create_dataframe_namedtuples(True)

+    @unittest.skip("Spark Connect does not support verifySchema.")
+    def test_createDataFrame_verifySchema(self):
+        super().test_createDataFrame_verifySchema()
+

 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_parity_arrow import *  # noqa: F401
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index b71bdb1eece2..19d0db989431 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -532,6 +532,45 @@ def test_createDataFrame_arrow_pandas(self):
         df_pandas = self.spark.createDataFrame(pdf)
         self.assertEqual(df_arrow.collect(), df_pandas.collect())

+    def test_createDataFrame_verifySchema(self):
+        data = {"id": [1, 2, 3], "value": [100000000000, 200000000000, 300000000000]}
+        # data.value should fail schema validation when verifySchema is True
+        schema = StructType(
+            [StructField("id", IntegerType(), True), StructField("value", IntegerType(), True)]
+        )
+        expected = [
+            Row(id=1, value=1215752192),
+            Row(id=2, value=-1863462912),
+            Row(id=3, value=-647710720),
+        ]
+        # Arrow table
+        table = pa.table(data)
+        df = self.spark.createDataFrame(table, schema=schema)
+        self.assertEqual(df.collect(), expected)
+
+        with self.assertRaises(Exception):
+            self.spark.createDataFrame(table, schema=schema, verifySchema=True)
+
+        # pandas DataFrame with Arrow optimization
+        pdf = pd.DataFrame(data)
+        df = self.spark.createDataFrame(pdf, schema=schema)
+        # verifySchema defaults to `spark.sql.execution.pandas.convertToArrowArraySafely`,
+        # which is false by default
+        self.assertEqual(df.collect(), expected)
+        with self.assertRaises(Exception):
+            with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": True}):
+                df = self.spark.createDataFrame(pdf, schema=schema)
+        with self.assertRaises(Exception):
+            df = self.spark.createDataFrame(pdf, schema=schema, verifySchema=True)
+
+        # pandas DataFrame without Arrow optimization
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
+            pdf = pd.DataFrame(data)
+            with self.assertRaises(Exception):
+                df = self.spark.createDataFrame(pdf, schema=schema)  # verifySchema defaults to True
+            df = self.spark.createDataFrame(pdf, schema=schema, verifySchema=False)
+            self.assertEqual(df.collect(), expected)
+
     def _createDataFrame_toggle(self, data, schema=None):
         with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
             df_no_arrow = self.spark.createDataFrame(data, schema=schema)
diff --git a/python/pyspark/sql/tests/typing/test_session.yml b/python/pyspark/sql/tests/typing/test_session.yml
index d6eee82a7678..98587458efe8 100644
--- a/python/pyspark/sql/tests/typing/test_session.yml
+++ b/python/pyspark/sql/tests/typing/test_session.yml
@@ -17,6 +17,7 @@

 - case: createDataFrameStructsValid
   main: |
+    from pyspark._globals import _NoValueType
     from pyspark.sql import SparkSession
     from pyspark.sql.types import StructType, StructField, StringType, IntegerType

@@ -78,14 +79,14 @@
     main:18: note: Possible overload variants:
     main:18: note:     def [RowLike in (list[Any], tuple[Any, ...], Row)] createDataFrame(self, data: Iterable[RowLike], schema: Union[list[str], tuple[str, ...]] = ..., samplingRatio: Optional[float] = ...) -> DataFrame
     main:18: note:     def [RowLike in (list[Any], tuple[Any, ...], Row)] createDataFrame(self, data: RDD[RowLike], schema: Union[list[str], tuple[str, ...]] = ..., samplingRatio: Optional[float] = ...) -> DataFrame
-    main:18: note:     def [RowLike in (list[Any], tuple[Any, ...], Row)] createDataFrame(self, data: Iterable[RowLike], schema: Union[StructType, str], *, verifySchema: bool = ...) -> DataFrame
-    main:18: note:     def [RowLike in (list[Any], tuple[Any, ...], Row)] createDataFrame(self, data: RDD[RowLike], schema: Union[StructType, str], *, verifySchema: bool = ...) -> DataFrame
-    main:18: note:     def [AtomicValue in (datetime, date, Decimal, bool, str, int, float)] createDataFrame(self, data: RDD[AtomicValue], schema: Union[AtomicType, str], verifySchema: bool = ...) -> DataFrame
-    main:18: note:     def [AtomicValue in (datetime, date, Decimal, bool, str, int, float)] createDataFrame(self, data: Iterable[AtomicValue], schema: Union[AtomicType, str], verifySchema: bool = ...) -> DataFrame
+    main:18: note:     def [RowLike in (list[Any], tuple[Any, ...], Row)] createDataFrame(self, data: Iterable[RowLike], schema: Union[StructType, str], *, verifySchema: Union[_NoValueType, bool] = ...) -> DataFrame
+    main:18: note:     def [RowLike in (list[Any], tuple[Any, ...], Row)] createDataFrame(self, data: RDD[RowLike], schema: Union[StructType, str], *, verifySchema: Union[_NoValueType, bool] = ...) -> DataFrame
+    main:18: note:     def [AtomicValue in (datetime, date, Decimal, bool, str, int, float)] createDataFrame(self, data: RDD[AtomicValue], schema: Union[AtomicType, str], verifySchema: Union[_NoValueType, bool] = ...) -> DataFrame
+    main:18: note:     def [AtomicValue in (datetime, date, Decimal, bool, str, int, float)] createDataFrame(self, data: Iterable[AtomicValue], schema: Union[AtomicType, str], verifySchema: Union[_NoValueType, bool] = ...) -> DataFrame
     main:18: note:     def createDataFrame(self, data: DataFrame, samplingRatio: Optional[float] = ...) -> DataFrame
     main:18: note:     def createDataFrame(self, data: Any, samplingRatio: Optional[float] = ...) -> DataFrame
-    main:18: note:     def createDataFrame(self, data: DataFrame, schema: Union[StructType, str], verifySchema: bool = ...) -> DataFrame
-    main:18: note:     def createDataFrame(self, data: Any, schema: Union[StructType, str], verifySchema: bool = ...) -> DataFrame
+    main:18: note:     def createDataFrame(self, data: DataFrame, schema: Union[StructType, str], verifySchema: Union[_NoValueType, bool] = ...) -> DataFrame
+    main:18: note:     def createDataFrame(self, data: Any, schema: Union[StructType, str], verifySchema: Union[_NoValueType, bool] = ...) -> DataFrame

 - case: createDataFrameFromEmptyRdd
   main: |
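Below is a minimal usage sketch of the new default behaviour, mirroring what `test_createDataFrame_verifySchema` above exercises. It assumes a local `SparkSession` with PyArrow installed; the session setup and the printed message are illustrative only and not part of this change.

```python
import pyarrow as pa
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType

spark = SparkSession.builder.master("local[1]").getOrCreate()

schema = StructType([StructField("id", IntegerType()), StructField("value", IntegerType())])
# 64-bit values that do not fit into IntegerType (int32)
table = pa.table({"id": [1, 2, 3], "value": [100000000000, 200000000000, 300000000000]})

# pyarrow.Table input: verifySchema now defaults to False, so the cast to the
# target schema is unsafe and the values are silently truncated to int32.
spark.createDataFrame(table, schema=schema).show()

# Opting in to verification makes the same overflow surface as an error.
try:
    spark.createDataFrame(table, schema=schema, verifySchema=True).show()
except Exception as e:
    print("verifySchema=True rejected the cast:", type(e).__name__)
```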