Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: implement EA._putmask #44387

Merged
merged 2 commits into from
Nov 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def _wrap_reduction_result(self, axis: int | None, result):
# ------------------------------------------------------------------------
# __array_function__ methods

def putmask(self, mask: np.ndarray, value) -> None:
def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None:
"""
Analogue to np.putmask(self, mask, value)

Expand Down
27 changes: 27 additions & 0 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1409,6 +1409,33 @@ def insert(self: ExtensionArrayT, loc: int, item) -> ExtensionArrayT:

return type(self)._concat_same_type([self[:loc], item_arr, self[loc:]])

def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None:
"""
Analogue to np.putmask(self, mask, value)

Parameters
----------
mask : np.ndarray[bool]
value : scalar or listlike
If listlike, must be arraylike with same length as self.

Returns
-------
None

Notes
-----
Unlike np.putmask, we do not repeat listlike values with mismatched length.
'value' should either be a scalar or an arraylike with the same length
as self.
"""
if is_list_like(value):
val = value[mask]
else:
val = value

self[mask] = val

def _where(
self: ExtensionArrayT, mask: npt.NDArray[np.bool_], value
) -> ExtensionArrayT:
Expand Down
7 changes: 4 additions & 3 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
PositionalIndexer,
ScalarIndexer,
SequenceIndexer,
npt,
)
from pandas.compat.numpy import function as nv
from pandas.util._decorators import Appender
Expand Down Expand Up @@ -1482,15 +1483,15 @@ def to_tuples(self, na_tuple=True) -> np.ndarray:

# ---------------------------------------------------------------------

def putmask(self, mask: np.ndarray, value) -> None:
def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None:
value_left, value_right = self._validate_setitem_value(value)

if isinstance(self._left, np.ndarray):
np.putmask(self._left, mask, value_left)
np.putmask(self._right, mask, value_right)
else:
self._left.putmask(mask, value_left)
self._right.putmask(mask, value_right)
self._left._putmask(mask, value_left)
self._right._putmask(mask, value_right)

def insert(self: IntervalArrayT, loc: int, item: Interval) -> IntervalArrayT:
"""
Expand Down
6 changes: 2 additions & 4 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4444,8 +4444,7 @@ def _join_non_unique(
if isinstance(join_array, np.ndarray):
np.putmask(join_array, mask, right)
else:
# error: "ExtensionArray" has no attribute "putmask"
join_array.putmask(mask, right) # type: ignore[attr-defined]
join_array._putmask(mask, right)

join_index = self._wrap_joined_index(join_array, other)

Expand Down Expand Up @@ -5051,8 +5050,7 @@ def putmask(self, mask, value) -> Index:
else:
# Note: we use the original value here, not converted, as
# _validate_fill_value is not idempotent
# error: "ExtensionArray" has no attribute "putmask"
values.putmask(mask, value) # type: ignore[attr-defined]
values._putmask(mask, value)

return self._shallow_copy(values)

Expand Down
8 changes: 3 additions & 5 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1415,15 +1415,13 @@ def putmask(self, mask, new) -> list[Block]:

new_values = self.values

if isinstance(new, (np.ndarray, ExtensionArray)) and len(new) == len(mask):
new = new[mask]

if mask.ndim == new_values.ndim + 1:
# TODO(EA2D): unnecessary with 2D EAs
mask = mask.reshape(new_values.shape)

try:
new_values[mask] = new
# Caller is responsible for ensuring matching lengths
new_values._putmask(mask, new)
except TypeError:
if not is_interval_dtype(self.dtype):
# Discussion about what we want to support in the general
Expand Down Expand Up @@ -1704,7 +1702,7 @@ def putmask(self, mask, new) -> list[Block]:
return self.coerce_to_target_dtype(new).putmask(mask, new)

arr = self.values
arr.T.putmask(mask, new)
arr.T._putmask(mask, new)
return [self]

def where(self, other, cond) -> list[Block]:
Expand Down