From 1dda05fe942e1147ddde6efcb886e06fc20fd7fe Mon Sep 17 00:00:00 2001 From: Andrey Pavlenko Date: Thu, 7 Mar 2024 16:33:15 +0300 Subject: [PATCH] Apply suggestions from code review Co-authored-by: Dmitry Chigarev --- .../dataframe/pandas/dataframe/dataframe.py | 166 +++++++++--------- .../pandas/partitioning/partition_manager.py | 9 +- .../storage_formats/base/query_compiler.py | 3 +- .../storage_formats/pandas/query_compiler.py | 20 +-- modin/pandas/test/test_series.py | 26 +-- 5 files changed, 111 insertions(+), 113 deletions(-) diff --git a/modin/core/dataframe/pandas/dataframe/dataframe.py b/modin/core/dataframe/pandas/dataframe/dataframe.py index 171fe0adfa4..95dbb57bc05 100644 --- a/modin/core/dataframe/pandas/dataframe/dataframe.py +++ b/modin/core/dataframe/pandas/dataframe/dataframe.py @@ -4659,145 +4659,141 @@ def case_when(self, caselist): or (remote_fn := getattr(cls, "_CASE_WHEN_FN", None)) is None ): - def case_when(df, caselist): # pragma: no cover + def case_when(df, name, caselist): # pragma: no cover caselist = [ tuple( - data.iloc[:, 0] if isinstance(data, pandas.DataFrame) else data + data.squeeze(1) if isinstance(data, pandas.DataFrame) else data for data in case_tuple ) for case_tuple in caselist ] - series = df.iloc[:, 0] - return pandas.DataFrame({series.name: series.case_when(caselist)}) + return pandas.DataFrame({name: df.squeeze(1).case_when(caselist)}) if single_wrap: cls._CASE_WHEN_FN = remote_fn = wrapper_put(case_when) else: remote_fn = case_when - parts_len = len(self._partitions) - parts = None - lengths = None + name = self.columns[0] + use_map = single_wrap is_trivial_idx = None - - def copartition(df, fill_value): - nonlocal is_trivial_idx - copartition = False - if is_trivial_idx is None: - is_trivial_idx = is_trivial_index(self.index) - if ( - is_trivial_idx != is_trivial_index(df.index) - or (not is_trivial_idx and any(self.index != df.index)) - or (parts_len < len(df._partitions)) + # Lists of modin frames: first for conditions, second for replacements + modin_lists = [[], []] + # Fill values for conditions and replacements respectively + fill_values = [True, None] + new_caselist = [] + for case_tuple in caselist: + new_case = [] + for data, modin_list, fill_value in zip( + case_tuple, modin_lists, fill_values ): - copartition = True - else: - nonlocal parts, lengths - if parts is None: - parts = [p[0] for p in self._partitions] - lengths = self._get_lengths(parts, Axis.ROW_WISE) - - df_parts = [p[0] for p in df._partitions] - df_part_lengths = df._get_lengths(df_parts, Axis.ROW_WISE) - if any(lengths[i] != df_part_lengths[i] for i in range(len(df_parts))): - copartition = True - if copartition: + if isinstance(data, cls): + modin_list.append(data) + elif callable(data): + if single_wrap: + data = wrapper_put(data) + elif isinstance(data, pandas.Series): + use_map = False + if is_trivial_idx is None: + self_idx = self.index + length = len(self_idx) + is_trivial_idx = is_trivial_index(self_idx) + if is_trivial_idx and is_trivial_index(data.index): + data = data[:length] + diff = length - len(data) + if diff > 0: + data = pandas.concat( + [data, pandas.Series([fill_value] * diff)] + ) + else: + data = data.reindex(self_idx, fill_value=fill_value) + elif use_map and is_list_like(data): + use_map = False + new_case.append(data) + new_caselist.append(tuple(new_case)) + + if modin_lists[0] or modin_lists[1]: + # Copartition modin frames + use_map = False + columns = self.columns + column_widths = [1] + for modin_list, fill_value in zip(modin_lists, fill_values): _, list_of_right_parts, joined_index, row_lengths = self._copartition( Axis.ROW_WISE.value, - df, + modin_list, how="left", sort=False, fill_value=fill_value, ) - df = self.__constructor__( - list_of_right_parts[0], - joined_index, - df.columns, - row_lengths, - df.column_widths, + modin_list.clear() + modin_list.extend( + self.__constructor__( + part, + joined_index, + columns, + row_lengths, + column_widths, + ) + for part in list_of_right_parts ) - return df - use_map = single_wrap - new_caselist = [] - for condition, replacement in caselist: - if callable(condition): - if single_wrap: - condition = wrapper_put(condition) - else: - use_map = False - if isinstance(condition, cls): - condition = copartition(condition, True) - if callable(replacement): - if single_wrap: - replacement = wrapper_put(replacement) - elif use_map and is_list_like(replacement): - use_map = False - if isinstance(replacement, cls): - replacement = copartition(replacement, None) - new_caselist.append((condition, replacement)) + # Replace modin frames with copartitioned + caselist = new_caselist + new_caselist = [] + for i in range(2): + modin_lists[i] = iter(modin_lists[i]) + for case_tuple in caselist: + new_case = tuple( + next(modin_list) if isinstance(data, cls) else data + for data, modin_list in zip(case_tuple, modin_lists) + ) + new_caselist.append(new_case) # If all the conditions are callable and the replacements are either # callable or scalar, use map(). if use_map: - return self.map(func=remote_fn, func_args=[new_caselist], lazy=True) + return self.map(func=remote_fn, func_args=[name, new_caselist], lazy=True) + # Get the chunk of data corresponding the the specified partition def map_data( - part_offset, + part_idx, part_len, data, data_offset, fill_value, ): if isinstance(data, cls): - if part_offset < len(data._partitions): - return data._partitions[part_offset][0]._data - else: - return [fill_value] * part_len - + return data._partitions[part_idx][0]._data if isinstance(data, pandas.Series): - nonlocal is_trivial_idx - if is_trivial_idx is None: - is_trivial_idx = is_trivial_index(self.index) - if is_trivial_idx and is_trivial_index(data.index): - data = data[data_offset : data_offset + part_len] - diff = part_len - len(data) - if diff > 0: - data = pandas.concat((data, pandas.Series([fill_value] * diff))) - return data - return data.reindex( - self.index[data_offset : data_offset + part_len], - fill_value=fill_value, - ) - + return data[data_offset : data_offset + part_len] # As mentioned above, this is required for Dask if not single_wrap and callable(data): return wrapper_put(data) - return ( data[data_offset : data_offset + part_len] if is_list_like(data) else data ) - if parts is None: - parts = [p[0] for p in self._partitions] - lengths = self._get_lengths(parts, Axis.ROW_WISE) + parts = [p[0] for p in self._partitions] + lengths = self._get_lengths(parts, Axis.ROW_WISE) new_parts = [] data_offset = 0 - for i in range(0, parts_len): - part = parts[i] - part_len = lengths[i] + + # Split the data and apply the remote function to each partition + # with the corresponding chunk of data + for i, part, part_len in zip(range(0, len(parts)), parts, lengths): cases = [ tuple( - map_data(i, part_len, d[0], data_offset, d[1]) - for d in zip(c, (True, None)) + map_data(i, part_len, data, data_offset, fill_value) + for data, fill_value in zip(c, (True, None)) ) for c in new_caselist ] new_parts.append( part.add_to_apply_calls( remote_fn if single_wrap else wrapper_put(remote_fn), + name, cases, length=part_len, width=1, diff --git a/modin/core/dataframe/pandas/partitioning/partition_manager.py b/modin/core/dataframe/pandas/partitioning/partition_manager.py index 077917a7407..052aaea8d2b 100644 --- a/modin/core/dataframe/pandas/partitioning/partition_manager.py +++ b/modin/core/dataframe/pandas/partitioning/partition_manager.py @@ -158,6 +158,9 @@ def preprocess_func(cls, map_func): `map_func` if the `apply` method of the `PandasDataframePartition` object you are using does not require any modification to a given function. """ + if cls._execution_wrapper.is_future(map_func): + return map_func # Has already been preprocessed + old_value = PersistentPickle.get() # When performing a function with Modin objects, it is more profitable to # do the conversion to pandas once on the main process than several times @@ -657,11 +660,7 @@ def lazy_map_partitions( NumPy array An array of partitions """ - preprocessed_map_func = ( - map_func - if cls._execution_wrapper.is_future(map_func) - else cls.preprocess_func(map_func) - ) + preprocessed_map_func = cls.preprocess_func(map_func) return np.array( [ [ diff --git a/modin/core/storage_formats/base/query_compiler.py b/modin/core/storage_formats/base/query_compiler.py index 68a669ac183..34685aae03f 100644 --- a/modin/core/storage_formats/base/query_compiler.py +++ b/modin/core/storage_formats/base/query_compiler.py @@ -6706,7 +6706,8 @@ def case_when(self, caselist): # noqa: PR01, RT01, D200 """ Replace values where the conditions are True. """ - qc_type = type(self) + # A workaround for https://github.com/modin-project/modin/issues/7041 + qc_type = BaseQueryCompiler caselist = [ tuple( data.to_pandas().squeeze(axis=1) if isinstance(data, qc_type) else data diff --git a/modin/core/storage_formats/pandas/query_compiler.py b/modin/core/storage_formats/pandas/query_compiler.py index ac50d5c50f6..2cef8040ba5 100644 --- a/modin/core/storage_formats/pandas/query_compiler.py +++ b/modin/core/storage_formats/pandas/query_compiler.py @@ -4531,14 +4531,12 @@ def compare(self, other, **kwargs): ) def case_when(self, caselist): - if impl := getattr(self._modin_frame, "case_when", None): - qc_type = type(self) - caselist = [ - tuple( - data._modin_frame if isinstance(data, qc_type) else data - for data in case_tuple - ) - for case_tuple in caselist - ] - return self.__constructor__(impl(caselist)) - return super().case_when(caselist) + qc_type = BaseQueryCompiler + caselist = [ + tuple( + data._modin_frame if isinstance(data, qc_type) else data + for data in case_tuple + ) + for case_tuple in caselist + ] + return self.__constructor__(self._modin_frame.case_when(caselist)) diff --git a/modin/pandas/test/test_series.py b/modin/pandas/test/test_series.py index a8960bd33ac..c4179e25ff2 100644 --- a/modin/pandas/test/test_series.py +++ b/modin/pandas/test/test_series.py @@ -16,7 +16,6 @@ import datetime import itertools import json -import math import unittest.mock as mock import matplotlib @@ -30,11 +29,12 @@ from pandas.errors import SpecificationError import modin.pandas as pd -from modin.config import MinPartitionSize, NPartitions, StorageFormat +from modin.config import NPartitions, StorageFormat from modin.pandas.io import to_pandas from modin.test.test_utils import warns_that_defaulting_to_pandas from modin.utils import get_current_execution, try_cast_to_pandas +from ...test.storage_formats.pandas.test_internals import construct_modin_df_by_scheme from .utils import ( RAND_HIGH, RAND_LOW, @@ -4550,15 +4550,19 @@ def permutations(values): ) def test_case_when(base, caselist): pd_result = base.case_when(caselist) - nparts = NPartitions.get() - part_size = MinPartitionSize.get() - new_nparts = max(1, min(math.ceil(len(base) / part_size), part_size)) + 1 - NPartitions.put(new_nparts) - MinPartitionSize.put(math.ceil(len(base) / new_nparts)) - base_repart = pd.Series(base) - NPartitions.put(nparts) - MinPartitionSize.put(part_size) - for df in (pd.Series(base), base_repart): + # 'base' and serieses from 'caselist' must have equal lengths, however in this test we want + # to verify that 'case_when' works correctly even if partitioning of 'base' and 'caselist' isn't equal + modin_base = pd.Series(base) + modin_base_repart = construct_modin_df_by_scheme( + base.to_frame(), + partitioning_scheme={"row_lengths": [14, 14, 12], "column_widths": [1]}, + ).squeeze(axis=1) + modin_base_repart.name = base.name + assert ( + modin_base._query_compiler._modin_frame._partitions.shape + != modin_base_repart._query_compiler._modin_frame._partitions.shape + ) + for df in (modin_base, modin_base_repart): df_equals(pd_result, df.case_when(caselist)) if any( isinstance(data, pandas.Series)