32 changes: 30 additions & 2 deletions python/pyspark/sql/pandas/serializers.py
@@ -347,13 +347,23 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
int_to_decimal_coercion_enabled : bool
If True, applies additional coercions in Python before converting to Arrow.
This has performance penalties.
pandas_backend : str
(Experimental) The dtype backend applied to the pandas DataFrame or Series.
Supported options are: numpy and pyarrow.
"""

def __init__(self, timezone, safecheck, int_to_decimal_coercion_enabled):
def __init__(
self,
timezone,
safecheck,
int_to_decimal_coercion_enabled,
pandas_backend,
):
super().__init__()
self._timezone = timezone
self._safecheck = safecheck
self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled
self._pandas_backend = pandas_backend

def arrow_to_pandas(
self, arrow_column, idx, struct_in_pandas="dict", ndarray_as_list=False, spark_type=None
@@ -366,6 +376,17 @@ def arrow_to_pandas(
"date_as_object": True,
"coerce_temporal_nanoseconds": True,
}

if self._pandas_backend == "pyarrow":
import pandas as pd

pandas_options.update(
{
"types_mapper": pd.ArrowDtype,
"zero_copy_only": True, # Raise an ArrowException if copy the underlying data
}
)

s = arrow_column.to_pandas(**pandas_options)

converter = _create_converter_to_pandas(
@@ -375,6 +396,7 @@ def arrow_to_pandas(
struct_in_pandas=struct_in_pandas,
error_on_duplicated_field_names=True,
ndarray_as_list=ndarray_as_list,
pandas_backend=self._pandas_backend,
)
return converter(s)
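For context, a minimal sketch (not part of the diff; sample data is illustrative) of the `to_pandas()` options used above: passing `types_mapper=pd.ArrowDtype` asks PyArrow to produce an ArrowDtype-backed pandas Series instead of converting the values to NumPy.

```python
# Illustrative only: shows the effect of types_mapper=pd.ArrowDtype on to_pandas().
import pandas as pd
import pyarrow as pa

arr = pa.chunked_array([[1, 2, 3]])
numpy_backed = arr.to_pandas()                            # dtype: int64 (NumPy-backed)
arrow_backed = arr.to_pandas(types_mapper=pd.ArrowDtype)  # dtype: int64[pyarrow]

assert isinstance(arrow_backed.dtype, pd.ArrowDtype)
print(numpy_backed.dtype, arrow_backed.dtype)
```

With `zero_copy_only=True`, `to_pandas()` additionally raises an ArrowException rather than silently copying the underlying buffers.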

@@ -411,6 +433,7 @@ def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False):
timezone=self._timezone,
error_on_duplicated_field_names=False,
int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled,
pandas_backend=self._pandas_backend,
)
series = conv(series)

@@ -531,8 +554,9 @@ def __init__(
arrow_cast=False,
input_types=None,
int_to_decimal_coercion_enabled=False,
pandas_backend="numpy",
):
super().__init__(timezone, safecheck, int_to_decimal_coercion_enabled)
super().__init__(timezone, safecheck, int_to_decimal_coercion_enabled, pandas_backend)
self._assign_cols_by_name = assign_cols_by_name
self._df_for_struct = df_for_struct
self._struct_in_pandas = struct_in_pandas
@@ -1143,6 +1167,7 @@ def __init__(
safecheck,
assign_cols_by_name,
int_to_decimal_coercion_enabled,
pandas_backend,
):
super().__init__(
timezone=timezone,
@@ -1154,6 +1179,7 @@ def __init__(
arrow_cast=True,
input_types=None,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
pandas_backend=pandas_backend,
)

def load_stream(self, stream):
@@ -1197,6 +1223,7 @@ def __init__(
safecheck,
assign_cols_by_name,
int_to_decimal_coercion_enabled,
pandas_backend,
):
super().__init__(
timezone=timezone,
@@ -1208,6 +1235,7 @@ def __init__(
arrow_cast=True,
input_types=None,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
pandas_backend=pandas_backend,
)

def load_stream(self, stream):
27 changes: 27 additions & 0 deletions python/pyspark/sql/pandas/types.py
@@ -867,6 +867,7 @@ def _create_converter_to_pandas(
error_on_duplicated_field_names: bool = True,
timestamp_utc_localized: bool = True,
ndarray_as_list: bool = False,
pandas_backend: str = "numpy",
) -> Callable[["pd.Series"], "pd.Series"]:
"""
Create a converter of pandas Series that is created from Spark's Python objects,
@@ -895,6 +896,9 @@
whereas the ones from `df.collect()` are localized to the local timezone.
ndarray_as_list : bool, optional
Whether `np.ndarray` is converted to a list or not (default ``False``).
pandas_backend : str, optional
(Experimental) The dtype backend applied to the pandas DataFrame or Series.
Supported options are: numpy and pyarrow (default ``numpy``).

Returns
-------
@@ -903,6 +907,15 @@
import numpy as np
import pandas as pd

if pandas_backend == "pyarrow":

def convert_pyarrow(pser: pd.Series) -> pd.Series:
# Assert that the result series is PyArrow-backed
assert isinstance(pser.dtype, pd.ArrowDtype), pser.dtype
return pser

return convert_pyarrow

pandas_type = _to_corrected_pandas_type(data_type)

if pandas_type is not None:
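As a quick illustration of the identity `convert_pyarrow` above (sample values are illustrative, not from the PR): a Series created with a `pd.ArrowDtype` already carries the PyArrow backend, so the converter only asserts the dtype and passes the Series through.

```python
# Illustrative only: an ArrowDtype-backed Series needs no further conversion.
import pandas as pd
import pyarrow as pa

pser = pd.Series([1, 2, None], dtype=pd.ArrowDtype(pa.int64()))
assert isinstance(pser.dtype, pd.ArrowDtype)
print(pser.dtype)  # int64[pyarrow]
```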
@@ -1227,6 +1240,7 @@ def _create_converter_from_pandas(
error_on_duplicated_field_names: bool = True,
ignore_unexpected_complex_type_values: bool = False,
int_to_decimal_coercion_enabled: bool = False,
pandas_backend: str = "numpy",
) -> Callable[["pd.Series"], "pd.Series"]:
"""
Create a converter of pandas Series to create Spark DataFrame with Arrow optimization.
@@ -1251,13 +1265,26 @@
and raise an AssertionError when the given value is not the expected type.
If ``True``, just ignore and return the given value.
(default ``False``)
pandas_backend : str, optional
(Experimental) The dtype backend applied to the pandas DataFrame or Series.
Supported options are: numpy and pyarrow (default ``numpy``).

Returns
-------
The converter of `pandas.Series`
"""
import pandas as pd

if pandas_backend == "pyarrow":
arrow_type = to_arrow_type(data_type, error_on_duplicated_field_names)

def convert_pyarrow(pser: pd.Series) -> pd.Series:
arrow_pser = pser.astype(pd.ArrowDtype(arrow_type))
assert isinstance(arrow_pser.dtype, pd.ArrowDtype), arrow_pser.dtype
return arrow_pser

return convert_pyarrow

if isinstance(data_type, TimestampType):
assert timezone is not None

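A hedged sketch of the cast performed by `convert_pyarrow` in `_create_converter_from_pandas` (sample values and types are illustrative): a NumPy-backed Series is moved to the PyArrow backend by casting to a `pd.ArrowDtype` built from the target Arrow type.

```python
# Illustrative only: mirrors pser.astype(pd.ArrowDtype(arrow_type)) above.
import pandas as pd
import pyarrow as pa

pser = pd.Series([1.5, 2.5, None])                     # float64 (NumPy-backed)
arrow_pser = pser.astype(pd.ArrowDtype(pa.float64()))  # double[pyarrow]

assert isinstance(arrow_pser.dtype, pd.ArrowDtype)
print(arrow_pser.dtype)
```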
8 changes: 8 additions & 0 deletions python/pyspark/worker.py
@@ -174,6 +174,10 @@ def binary_as_bytes(self) -> bool:
def safecheck(self) -> bool:
return self.get("spark.sql.execution.pandas.convertToArrowArraySafely", "false") == "true"

@property
def pandas_backend(self) -> str:
return self.get("spark.sql.execution.pandas.backend", "numpy")

@property
def int_to_decimal_coercion_enabled(self) -> bool:
return (
@@ -2760,6 +2764,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
runner_conf.safecheck,
runner_conf.assign_cols_by_name,
runner_conf.int_to_decimal_coercion_enabled,
pandas_backend=runner_conf.pandas_backend,
)
elif (
eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
@@ -2770,6 +2775,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
runner_conf.safecheck,
runner_conf.assign_cols_by_name,
runner_conf.int_to_decimal_coercion_enabled,
pandas_backend=runner_conf.pandas_backend,
)
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
ser = CogroupArrowUDFSerializer(runner_conf.assign_cols_by_name)
@@ -2780,6 +2786,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
runner_conf.assign_cols_by_name,
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
arrow_cast=True,
pandas_backend=runner_conf.pandas_backend,
)
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
ser = ApplyInPandasWithStateSerializer(
@@ -2869,6 +2876,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
True,
input_types,
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
pandas_backend=runner_conf.pandas_backend,
)
else:
batch_size = int(os.environ.get("PYTHON_UDF_BATCH_SIZE", "100"))
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4246,6 +4246,15 @@ object SQLConf {
.checkValues(Set("legacy", "row", "dict"))
.createWithDefaultString("legacy")

val PANDAS_EXECUTION_BACKEND =
buildConf("spark.sql.execution.pandas.backend")
.doc("(Experimental) The backend of Pandas Series/DataFrame in Python execution. " +
"This configuration applies to Pandas UDFs.")
.version("4.2.0")
.stringConf
.checkValues(Set("numpy", "pyarrow"))
.createWithDefaultString("numpy")

val PYSPARK_HIDE_TRACEBACK =
buildConf("spark.sql.execution.pyspark.udf.hideTraceback.enabled")
.doc(
@@ -7627,6 +7636,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

def pandasStructHandlingMode: String = getConf(PANDAS_STRUCT_HANDLING_MODE)

def pandasExecutionBackend: String = getConf(PANDAS_EXECUTION_BACKEND)

def pysparkHideTraceback: Boolean = getConf(PYSPARK_HIDE_TRACEBACK)

def pysparkSimplifiedTraceback: Boolean = getConf(PYSPARK_SIMPLIFIED_TRACEBACK)
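A possible end-to-end usage sketch of the new configuration (not part of the PR; assumes an existing SparkSession and the runner wiring below): with `spark.sql.execution.pandas.backend` set to `pyarrow`, Pandas UDFs would receive ArrowDtype-backed Series.

```python
# Illustrative usage sketch of spark.sql.execution.pandas.backend=pyarrow.
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf
import pandas as pd

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.execution.pandas.backend", "pyarrow")

@pandas_udf("long")
def plus_one(s: pd.Series) -> pd.Series:
    # With the pyarrow backend, s.dtype is expected to be a pd.ArrowDtype.
    return s + 1

spark.range(3).select(plus_one("id")).show()
```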
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -151,6 +151,8 @@ object ArrowPythonRunner {
val timeZoneConf = Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone)
val pandasColsByName = Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key ->
conf.pandasGroupedMapAssignColumnsByName.toString)
val pandasBackend = Seq(
SQLConf.PANDAS_EXECUTION_BACKEND.key -> conf.pandasExecutionBackend)
val arrowSafeTypeCheck = Seq(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key ->
conf.arrowSafeTypeConversion.toString)
val arrowAyncParallelism = conf.pythonUDFArrowConcurrencyLevel.map(v =>
@@ -170,7 +172,8 @@
val binaryAsBytes = Seq(
SQLConf.PYSPARK_BINARY_AS_BYTES.key ->
conf.pysparkBinaryAsBytes.toString)
Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck ++
Map(timeZoneConf ++ pandasColsByName ++
pandasBackend ++ arrowSafeTypeCheck ++
arrowAyncParallelism ++ useLargeVarTypes ++
intToDecimalCoercion ++ binaryAsBytes ++
legacyPandasConversion ++ legacyPandasConversionUDF: _*)