Skip to content

Commit

Permalink
Use single wrap for Dask
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreyPavlenko committed Mar 7, 2024
1 parent 4176f38 commit 8013bb2
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions modin/core/dataframe/pandas/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4649,9 +4649,15 @@ def case_when(self, caselist):
-------
PandasDataframe
"""
# For Dask the callables must wrapped for each partition, otherwise
# the execution could fail with CancelledError.
single_wrap = Engine.get() != "Dask"
cls = type(self)
wrapper_put = self._partition_mgr_cls._execution_wrapper.put
if (remote_fn := getattr(cls, "_CASE_WHEN_FN", None)) is None:
if (
not single_wrap
or (remote_fn := getattr(cls, "_CASE_WHEN_FN", None)) is None
):

def case_when(df, caselist): # pragma: no cover
caselist = [
Expand All @@ -4664,7 +4670,10 @@ def case_when(df, caselist): # pragma: no cover
series = df.iloc[:, 0]
return pandas.DataFrame({series.name: series.case_when(caselist)})

cls._CASE_WHEN_FN = remote_fn = wrapper_put(case_when)
if single_wrap:
cls._CASE_WHEN_FN = remote_fn = wrapper_put(case_when)
else:
remote_fn = case_when

parts_len = len(self._partitions)
parts = None
Expand Down Expand Up @@ -4709,21 +4718,18 @@ def copartition(df, fill_value):
)
return df

# For Dask the callables are wrapped for each partition in the map_data() function.
# If the same callable is wrapped only once for all partitions, CancelledError is raised.
wrap_callable = Engine.get() != "Dask"
use_map = wrap_callable
use_map = single_wrap
new_caselist = []
for condition, replacement in caselist:
if callable(condition):
if wrap_callable:
if single_wrap:
condition = wrapper_put(condition)
else:
use_map = False
if isinstance(condition, cls):
condition = copartition(condition, True)
if callable(replacement):
if wrap_callable:
if single_wrap:
replacement = wrapper_put(replacement)
elif use_map and is_list_like(replacement):
use_map = False
Expand Down Expand Up @@ -4765,7 +4771,7 @@ def map_data(
)

# As mentioned above, this is required for Dask
if not wrap_callable and callable(data):
if not single_wrap and callable(data):
return wrapper_put(data)

return (
Expand All @@ -4791,7 +4797,7 @@ def map_data(
]
new_parts.append(
part.add_to_apply_calls(
remote_fn,
remote_fn if single_wrap else wrapper_put(remote_fn),
cases,
length=part_len,
width=1,
Expand Down

0 comments on commit 8013bb2

Please sign in to comment.