Skip to content

Commit

Permalink
REF: simplify Block.where (and fix subtle alignment bug) (#44691)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Dec 1, 2021
1 parent bd6eb7e commit ede6234
Showing 1 changed file with 30 additions and 28 deletions.
58 changes: 30 additions & 28 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,17 +1139,11 @@ def shift(self, periods: int, axis: int = 0, fill_value: Any = None) -> list[Blo
# Convert integer to float if necessary. We need to do a lot more than
# that, e.g. also handle boolean etc.

# error: Value of type variable "NumpyArrayT" of "maybe_upcast" cannot be
# "Union[ndarray[Any, Any], ExtensionArray]"
new_values, fill_value = maybe_upcast(
self.values, fill_value # type: ignore[type-var]
)
values = cast(np.ndarray, self.values)

# error: Argument 1 to "shift" has incompatible type "Union[ndarray[Any, Any],
# ExtensionArray]"; expected "ndarray[Any, Any]"
new_values = shift(
new_values, periods, axis, fill_value # type: ignore[arg-type]
)
new_values, fill_value = maybe_upcast(values, fill_value)

new_values = shift(new_values, periods, axis, fill_value)

return [self.make_block(new_values)]

Expand All @@ -1171,7 +1165,8 @@ def where(self, other, cond) -> list[Block]:

transpose = self.ndim == 2

values = self.values
# EABlocks override where
values = cast(np.ndarray, self.values)
orig_other = other
if transpose:
values = values.T
Expand All @@ -1185,22 +1180,15 @@ def where(self, other, cond) -> list[Block]:
# TODO: avoid the downcasting at the end in this case?
# GH-39595: Always return a copy
result = values.copy()

elif not self._can_hold_element(other):
# we cannot coerce, return a compat dtype
block = self.coerce_to_target_dtype(other)
blocks = block.where(orig_other, cond)
return self._maybe_downcast(blocks, "infer")

else:
# see if we can operate on the entire block, or need item-by-item
# or if we are a single block (ndim == 1)
if not self._can_hold_element(other):
# we cannot coerce, return a compat dtype
block = self.coerce_to_target_dtype(other)
blocks = block.where(orig_other, cond)
return self._maybe_downcast(blocks, "infer")

# error: Argument 1 to "setitem_datetimelike_compat" has incompatible type
# "Union[ndarray, ExtensionArray]"; expected "ndarray"
# error: Argument 2 to "setitem_datetimelike_compat" has incompatible type
# "number[Any]"; expected "int"
alt = setitem_datetimelike_compat(
values, icond.sum(), other # type: ignore[arg-type]
)
alt = setitem_datetimelike_compat(values, icond.sum(), other)
if alt is not other:
if is_list_like(other) and len(other) < len(values):
# call np.where with other to get the appropriate ValueError
Expand All @@ -1215,6 +1203,19 @@ def where(self, other, cond) -> list[Block]:
else:
# By the time we get here, we should have all Series/Index
# args extracted to ndarray
if (
is_list_like(other)
and not isinstance(other, np.ndarray)
and len(other) == self.shape[-1]
):
# If we don't do this broadcasting here, then expressions.where
# will broadcast a 1D other to be row-like instead of
# column-like.
other = np.array(other).reshape(values.shape)
# If lengths don't match (or len(other)==1), we will raise
# inside expressions.where, see test_series_where

# Note: expressions.where may upcast.
result = expressions.where(~icond, values, other)

if self._can_hold_na or self.ndim == 1:
Expand All @@ -1233,7 +1234,6 @@ def where(self, other, cond) -> list[Block]:
result_blocks: list[Block] = []
for m in [mask, ~mask]:
if m.any():
result = cast(np.ndarray, result) # EABlock overrides where
taken = result.take(m.nonzero()[0], axis=axis)
r = maybe_downcast_numeric(taken, self.dtype)
nb = self.make_block(r.T, placement=self._mgr_locs[m])
Expand Down Expand Up @@ -1734,7 +1734,9 @@ def where(self, other, cond) -> list[Block]:
try:
res_values = arr.T._where(cond, other).T
except (ValueError, TypeError):
return Block.where(self, other, cond)
blk = self.coerce_to_target_dtype(other)
nbs = blk.where(other, cond)
return self._maybe_downcast(nbs, "infer")

nb = self.make_block_same_class(res_values)
return [nb]
Expand Down

0 comments on commit ede6234

Please sign in to comment.