From 885a1c49f1e78fb34978f52ce907bbcbe83ba863 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Thu, 11 Nov 2021 09:50:22 -0800 Subject: [PATCH] ENH: implement EA._putmask (#44387) --- pandas/core/arrays/_mixins.py | 2 +- pandas/core/arrays/base.py | 27 +++++++++++++++++++++++++++ pandas/core/arrays/interval.py | 7 ++++--- pandas/core/indexes/base.py | 6 ++---- pandas/core/internals/blocks.py | 8 +++----- 5 files changed, 37 insertions(+), 13 deletions(-) diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py index 8deeb44f65188..674379f6d65f8 100644 --- a/pandas/core/arrays/_mixins.py +++ b/pandas/core/arrays/_mixins.py @@ -310,7 +310,7 @@ def _wrap_reduction_result(self, axis: int | None, result): # ------------------------------------------------------------------------ # __array_function__ methods - def putmask(self, mask: np.ndarray, value) -> None: + def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None: """ Analogue to np.putmask(self, mask, value) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 70841197761a9..a64aef64ab49f 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -1409,6 +1409,33 @@ def insert(self: ExtensionArrayT, loc: int, item) -> ExtensionArrayT: return type(self)._concat_same_type([self[:loc], item_arr, self[loc:]]) + def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None: + """ + Analogue to np.putmask(self, mask, value) + + Parameters + ---------- + mask : np.ndarray[bool] + value : scalar or listlike + If listlike, must be arraylike with same length as self. + + Returns + ------- + None + + Notes + ----- + Unlike np.putmask, we do not repeat listlike values with mismatched length. + 'value' should either be a scalar or an arraylike with the same length + as self. + """ + if is_list_like(value): + val = value[mask] + else: + val = value + + self[mask] = val + def _where( self: ExtensionArrayT, mask: npt.NDArray[np.bool_], value ) -> ExtensionArrayT: diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index d5718d59bf8b0..01bf5ec0633b5 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -36,6 +36,7 @@ PositionalIndexer, ScalarIndexer, SequenceIndexer, + npt, ) from pandas.compat.numpy import function as nv from pandas.util._decorators import Appender @@ -1482,15 +1483,15 @@ def to_tuples(self, na_tuple=True) -> np.ndarray: # --------------------------------------------------------------------- - def putmask(self, mask: np.ndarray, value) -> None: + def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None: value_left, value_right = self._validate_setitem_value(value) if isinstance(self._left, np.ndarray): np.putmask(self._left, mask, value_left) np.putmask(self._right, mask, value_right) else: - self._left.putmask(mask, value_left) - self._right.putmask(mask, value_right) + self._left._putmask(mask, value_left) + self._right._putmask(mask, value_right) def insert(self: IntervalArrayT, loc: int, item: Interval) -> IntervalArrayT: """ diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index ba7dde7d2a4d8..2514702b036dd 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -4444,8 +4444,7 @@ def _join_non_unique( if isinstance(join_array, np.ndarray): np.putmask(join_array, mask, right) else: - # error: "ExtensionArray" has no attribute "putmask" - join_array.putmask(mask, right) # type: ignore[attr-defined] + join_array._putmask(mask, right) join_index = self._wrap_joined_index(join_array, other) @@ -5051,8 +5050,7 @@ def putmask(self, mask, value) -> Index: else: # Note: we use the original value here, not converted, as # _validate_fill_value is not idempotent - # error: "ExtensionArray" has no attribute "putmask" - values.putmask(mask, value) # type: ignore[attr-defined] + values._putmask(mask, value) return self._shallow_copy(values) diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index 2589015e0f0b1..66a40b962e183 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -1415,15 +1415,13 @@ def putmask(self, mask, new) -> list[Block]: new_values = self.values - if isinstance(new, (np.ndarray, ExtensionArray)) and len(new) == len(mask): - new = new[mask] - if mask.ndim == new_values.ndim + 1: # TODO(EA2D): unnecessary with 2D EAs mask = mask.reshape(new_values.shape) try: - new_values[mask] = new + # Caller is responsible for ensuring matching lengths + new_values._putmask(mask, new) except TypeError: if not is_interval_dtype(self.dtype): # Discussion about what we want to support in the general @@ -1704,7 +1702,7 @@ def putmask(self, mask, new) -> list[Block]: return self.coerce_to_target_dtype(new).putmask(mask, new) arr = self.values - arr.T.putmask(mask, new) + arr.T._putmask(mask, new) return [self] def where(self, other, cond) -> list[Block]: