Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Dmitry Chigarev <dmitry.chigarev@intel.com>
  • Loading branch information
AndreyPavlenko and dchigarev committed Mar 7, 2024
1 parent 8013bb2 commit 1dda05f
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 113 deletions.
166 changes: 81 additions & 85 deletions modin/core/dataframe/pandas/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4659,145 +4659,141 @@ def case_when(self, caselist):
or (remote_fn := getattr(cls, "_CASE_WHEN_FN", None)) is None
):

def case_when(df, caselist): # pragma: no cover
def case_when(df, name, caselist): # pragma: no cover
caselist = [
tuple(
data.iloc[:, 0] if isinstance(data, pandas.DataFrame) else data
data.squeeze(1) if isinstance(data, pandas.DataFrame) else data
for data in case_tuple
)
for case_tuple in caselist
]
series = df.iloc[:, 0]
return pandas.DataFrame({series.name: series.case_when(caselist)})
return pandas.DataFrame({name: df.squeeze(1).case_when(caselist)})

if single_wrap:
cls._CASE_WHEN_FN = remote_fn = wrapper_put(case_when)
else:
remote_fn = case_when

parts_len = len(self._partitions)
parts = None
lengths = None
name = self.columns[0]
use_map = single_wrap
is_trivial_idx = None

def copartition(df, fill_value):
nonlocal is_trivial_idx
copartition = False
if is_trivial_idx is None:
is_trivial_idx = is_trivial_index(self.index)
if (
is_trivial_idx != is_trivial_index(df.index)
or (not is_trivial_idx and any(self.index != df.index))
or (parts_len < len(df._partitions))
# Lists of modin frames: first for conditions, second for replacements
modin_lists = [[], []]
# Fill values for conditions and replacements respectively
fill_values = [True, None]
new_caselist = []
for case_tuple in caselist:
new_case = []
for data, modin_list, fill_value in zip(
case_tuple, modin_lists, fill_values
):
copartition = True
else:
nonlocal parts, lengths
if parts is None:
parts = [p[0] for p in self._partitions]
lengths = self._get_lengths(parts, Axis.ROW_WISE)

df_parts = [p[0] for p in df._partitions]
df_part_lengths = df._get_lengths(df_parts, Axis.ROW_WISE)
if any(lengths[i] != df_part_lengths[i] for i in range(len(df_parts))):
copartition = True
if copartition:
if isinstance(data, cls):
modin_list.append(data)
elif callable(data):
if single_wrap:
data = wrapper_put(data)
elif isinstance(data, pandas.Series):
use_map = False
if is_trivial_idx is None:
self_idx = self.index
length = len(self_idx)
is_trivial_idx = is_trivial_index(self_idx)
if is_trivial_idx and is_trivial_index(data.index):
data = data[:length]
diff = length - len(data)
if diff > 0:
data = pandas.concat(
[data, pandas.Series([fill_value] * diff)]
)
else:
data = data.reindex(self_idx, fill_value=fill_value)
elif use_map and is_list_like(data):
use_map = False
new_case.append(data)
new_caselist.append(tuple(new_case))

if modin_lists[0] or modin_lists[1]:
# Copartition modin frames
use_map = False
columns = self.columns
column_widths = [1]
for modin_list, fill_value in zip(modin_lists, fill_values):
_, list_of_right_parts, joined_index, row_lengths = self._copartition(
Axis.ROW_WISE.value,
df,
modin_list,
how="left",
sort=False,
fill_value=fill_value,
)
df = self.__constructor__(
list_of_right_parts[0],
joined_index,
df.columns,
row_lengths,
df.column_widths,
modin_list.clear()
modin_list.extend(
self.__constructor__(
part,
joined_index,
columns,
row_lengths,
column_widths,
)
for part in list_of_right_parts
)
return df

use_map = single_wrap
new_caselist = []
for condition, replacement in caselist:
if callable(condition):
if single_wrap:
condition = wrapper_put(condition)
else:
use_map = False
if isinstance(condition, cls):
condition = copartition(condition, True)
if callable(replacement):
if single_wrap:
replacement = wrapper_put(replacement)
elif use_map and is_list_like(replacement):
use_map = False
if isinstance(replacement, cls):
replacement = copartition(replacement, None)
new_caselist.append((condition, replacement))
# Replace modin frames with copartitioned
caselist = new_caselist
new_caselist = []
for i in range(2):
modin_lists[i] = iter(modin_lists[i])
for case_tuple in caselist:
new_case = tuple(
next(modin_list) if isinstance(data, cls) else data
for data, modin_list in zip(case_tuple, modin_lists)
)
new_caselist.append(new_case)

# If all the conditions are callable and the replacements are either
# callable or scalar, use map().
if use_map:
return self.map(func=remote_fn, func_args=[new_caselist], lazy=True)
return self.map(func=remote_fn, func_args=[name, new_caselist], lazy=True)

# Get the chunk of data corresponding the the specified partition
def map_data(
part_offset,
part_idx,
part_len,
data,
data_offset,
fill_value,
):
if isinstance(data, cls):
if part_offset < len(data._partitions):
return data._partitions[part_offset][0]._data
else:
return [fill_value] * part_len

return data._partitions[part_idx][0]._data
if isinstance(data, pandas.Series):
nonlocal is_trivial_idx
if is_trivial_idx is None:
is_trivial_idx = is_trivial_index(self.index)
if is_trivial_idx and is_trivial_index(data.index):
data = data[data_offset : data_offset + part_len]
diff = part_len - len(data)
if diff > 0:
data = pandas.concat((data, pandas.Series([fill_value] * diff)))
return data
return data.reindex(
self.index[data_offset : data_offset + part_len],
fill_value=fill_value,
)

return data[data_offset : data_offset + part_len]
# As mentioned above, this is required for Dask
if not single_wrap and callable(data):
return wrapper_put(data)

return (
data[data_offset : data_offset + part_len]
if is_list_like(data)
else data
)

if parts is None:
parts = [p[0] for p in self._partitions]
lengths = self._get_lengths(parts, Axis.ROW_WISE)
parts = [p[0] for p in self._partitions]
lengths = self._get_lengths(parts, Axis.ROW_WISE)
new_parts = []
data_offset = 0
for i in range(0, parts_len):
part = parts[i]
part_len = lengths[i]

# Split the data and apply the remote function to each partition
# with the corresponding chunk of data
for i, part, part_len in zip(range(0, len(parts)), parts, lengths):
cases = [
tuple(
map_data(i, part_len, d[0], data_offset, d[1])
for d in zip(c, (True, None))
map_data(i, part_len, data, data_offset, fill_value)
for data, fill_value in zip(c, (True, None))
)
for c in new_caselist
]
new_parts.append(
part.add_to_apply_calls(
remote_fn if single_wrap else wrapper_put(remote_fn),
name,
cases,
length=part_len,
width=1,
Expand Down
9 changes: 4 additions & 5 deletions modin/core/dataframe/pandas/partitioning/partition_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ def preprocess_func(cls, map_func):
`map_func` if the `apply` method of the `PandasDataframePartition` object
you are using does not require any modification to a given function.
"""
if cls._execution_wrapper.is_future(map_func):
return map_func # Has already been preprocessed

old_value = PersistentPickle.get()
# When performing a function with Modin objects, it is more profitable to
# do the conversion to pandas once on the main process than several times
Expand Down Expand Up @@ -657,11 +660,7 @@ def lazy_map_partitions(
NumPy array
An array of partitions
"""
preprocessed_map_func = (
map_func
if cls._execution_wrapper.is_future(map_func)
else cls.preprocess_func(map_func)
)
preprocessed_map_func = cls.preprocess_func(map_func)
return np.array(
[
[
Expand Down
3 changes: 2 additions & 1 deletion modin/core/storage_formats/base/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6706,7 +6706,8 @@ def case_when(self, caselist): # noqa: PR01, RT01, D200
"""
Replace values where the conditions are True.
"""
qc_type = type(self)
# A workaround for https://github.com/modin-project/modin/issues/7041
qc_type = BaseQueryCompiler
caselist = [
tuple(
data.to_pandas().squeeze(axis=1) if isinstance(data, qc_type) else data
Expand Down
20 changes: 9 additions & 11 deletions modin/core/storage_formats/pandas/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4531,14 +4531,12 @@ def compare(self, other, **kwargs):
)

def case_when(self, caselist):
if impl := getattr(self._modin_frame, "case_when", None):
qc_type = type(self)
caselist = [
tuple(
data._modin_frame if isinstance(data, qc_type) else data
for data in case_tuple
)
for case_tuple in caselist
]
return self.__constructor__(impl(caselist))
return super().case_when(caselist)
qc_type = BaseQueryCompiler
caselist = [
tuple(
data._modin_frame if isinstance(data, qc_type) else data
for data in case_tuple
)
for case_tuple in caselist
]
return self.__constructor__(self._modin_frame.case_when(caselist))
26 changes: 15 additions & 11 deletions modin/pandas/test/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import datetime
import itertools
import json
import math
import unittest.mock as mock

import matplotlib
Expand All @@ -30,11 +29,12 @@
from pandas.errors import SpecificationError

import modin.pandas as pd
from modin.config import MinPartitionSize, NPartitions, StorageFormat
from modin.config import NPartitions, StorageFormat
from modin.pandas.io import to_pandas
from modin.test.test_utils import warns_that_defaulting_to_pandas
from modin.utils import get_current_execution, try_cast_to_pandas

from ...test.storage_formats.pandas.test_internals import construct_modin_df_by_scheme
from .utils import (
RAND_HIGH,
RAND_LOW,
Expand Down Expand Up @@ -4550,15 +4550,19 @@ def permutations(values):
)
def test_case_when(base, caselist):
pd_result = base.case_when(caselist)
nparts = NPartitions.get()
part_size = MinPartitionSize.get()
new_nparts = max(1, min(math.ceil(len(base) / part_size), part_size)) + 1
NPartitions.put(new_nparts)
MinPartitionSize.put(math.ceil(len(base) / new_nparts))
base_repart = pd.Series(base)
NPartitions.put(nparts)
MinPartitionSize.put(part_size)
for df in (pd.Series(base), base_repart):
# 'base' and serieses from 'caselist' must have equal lengths, however in this test we want
# to verify that 'case_when' works correctly even if partitioning of 'base' and 'caselist' isn't equal
modin_base = pd.Series(base)
modin_base_repart = construct_modin_df_by_scheme(
base.to_frame(),
partitioning_scheme={"row_lengths": [14, 14, 12], "column_widths": [1]},
).squeeze(axis=1)
modin_base_repart.name = base.name
assert (
modin_base._query_compiler._modin_frame._partitions.shape
!= modin_base_repart._query_compiler._modin_frame._partitions.shape
)
for df in (modin_base, modin_base_repart):
df_equals(pd_result, df.case_when(caselist))
if any(
isinstance(data, pandas.Series)
Expand Down

0 comments on commit 1dda05f

Please sign in to comment.