Skip to content

Commit

Permalink
ENH: implement EA._putmask (pandas-dev#44387)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored and nickleus27 committed Nov 28, 2021
1 parent c4316b5 commit 885a1c4
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 13 deletions.
2 changes: 1 addition & 1 deletion pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def _wrap_reduction_result(self, axis: int | None, result):
# ------------------------------------------------------------------------
# __array_function__ methods

def putmask(self, mask: np.ndarray, value) -> None:
def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None:
"""
Analogue to np.putmask(self, mask, value)
Expand Down
27 changes: 27 additions & 0 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1409,6 +1409,33 @@ def insert(self: ExtensionArrayT, loc: int, item) -> ExtensionArrayT:

return type(self)._concat_same_type([self[:loc], item_arr, self[loc:]])

def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None:
"""
Analogue to np.putmask(self, mask, value)
Parameters
----------
mask : np.ndarray[bool]
value : scalar or listlike
If listlike, must be arraylike with same length as self.
Returns
-------
None
Notes
-----
Unlike np.putmask, we do not repeat listlike values with mismatched length.
'value' should either be a scalar or an arraylike with the same length
as self.
"""
if is_list_like(value):
val = value[mask]
else:
val = value

self[mask] = val

def _where(
self: ExtensionArrayT, mask: npt.NDArray[np.bool_], value
) -> ExtensionArrayT:
Expand Down
7 changes: 4 additions & 3 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
PositionalIndexer,
ScalarIndexer,
SequenceIndexer,
npt,
)
from pandas.compat.numpy import function as nv
from pandas.util._decorators import Appender
Expand Down Expand Up @@ -1482,15 +1483,15 @@ def to_tuples(self, na_tuple=True) -> np.ndarray:

# ---------------------------------------------------------------------

def putmask(self, mask: np.ndarray, value) -> None:
def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None:
value_left, value_right = self._validate_setitem_value(value)

if isinstance(self._left, np.ndarray):
np.putmask(self._left, mask, value_left)
np.putmask(self._right, mask, value_right)
else:
self._left.putmask(mask, value_left)
self._right.putmask(mask, value_right)
self._left._putmask(mask, value_left)
self._right._putmask(mask, value_right)

def insert(self: IntervalArrayT, loc: int, item: Interval) -> IntervalArrayT:
"""
Expand Down
6 changes: 2 additions & 4 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4444,8 +4444,7 @@ def _join_non_unique(
if isinstance(join_array, np.ndarray):
np.putmask(join_array, mask, right)
else:
# error: "ExtensionArray" has no attribute "putmask"
join_array.putmask(mask, right) # type: ignore[attr-defined]
join_array._putmask(mask, right)

join_index = self._wrap_joined_index(join_array, other)

Expand Down Expand Up @@ -5051,8 +5050,7 @@ def putmask(self, mask, value) -> Index:
else:
# Note: we use the original value here, not converted, as
# _validate_fill_value is not idempotent
# error: "ExtensionArray" has no attribute "putmask"
values.putmask(mask, value) # type: ignore[attr-defined]
values._putmask(mask, value)

return self._shallow_copy(values)

Expand Down
8 changes: 3 additions & 5 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1415,15 +1415,13 @@ def putmask(self, mask, new) -> list[Block]:

new_values = self.values

if isinstance(new, (np.ndarray, ExtensionArray)) and len(new) == len(mask):
new = new[mask]

if mask.ndim == new_values.ndim + 1:
# TODO(EA2D): unnecessary with 2D EAs
mask = mask.reshape(new_values.shape)

try:
new_values[mask] = new
# Caller is responsible for ensuring matching lengths
new_values._putmask(mask, new)
except TypeError:
if not is_interval_dtype(self.dtype):
# Discussion about what we want to support in the general
Expand Down Expand Up @@ -1704,7 +1702,7 @@ def putmask(self, mask, new) -> list[Block]:
return self.coerce_to_target_dtype(new).putmask(mask, new)

arr = self.values
arr.T.putmask(mask, new)
arr.T._putmask(mask, new)
return [self]

def where(self, other, cond) -> list[Block]:
Expand Down

0 comments on commit 885a1c4

Please sign in to comment.