Skip to content
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ Indexing
- Bug in :meth:`Series.loc.__getitem__` with a non-unique :class:`MultiIndex` and an empty-list indexer (:issue:`13691`)
- Bug in indexing on a :class:`Series` or :class:`DataFrame` with a :class:`MultiIndex` with a level named "0" (:issue:`37194`)
- Bug in :meth:`Series.__getitem__` when using an unsigned integer array as an indexer giving incorrect results or segfaulting instead of raising ``KeyError`` (:issue:`37218`)
- Bug in :meth:`Index.where` incorrectly casting numeric values to strings (:issue:`37591`)

Missing
^^^^^^^
Expand Down
18 changes: 5 additions & 13 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
ensure_int64,
ensure_object,
ensure_platform_int,
is_bool,
is_bool_dtype,
is_categorical_dtype,
is_datetime64_any_dtype,
Expand Down Expand Up @@ -4079,23 +4078,16 @@ def where(self, cond, other=None):
if other is None:
other = self._na_value

dtype = self.dtype
values = self.values

if is_bool(other) or is_bool_dtype(other):

# bools force casting
values = values.astype(object)
dtype = None
try:
self._validate_fill_value(other)
except (ValueError, TypeError):
return self.astype(object).where(cond, other)

values = np.where(cond, values, other)

if self._is_numeric_dtype and np.any(isna(values)):
# We can't coerce to the numeric dtype of "self" (unless
# it's float) if there are NaN values in our output.
dtype = None

return Index(values, dtype=dtype, name=self.name)
return Index(values, name=self.name)

# construction helpers
@final
Expand Down
11 changes: 2 additions & 9 deletions pandas/core/indexes/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,16 +482,9 @@ def isin(self, values, level=None):

@Appender(Index.where.__doc__)
def where(self, cond, other=None):
values = self._data._ndarray
other = self._data._validate_setitem_value(other)

try:
other = self._data._validate_setitem_value(other)
except (TypeError, ValueError) as err:
# Includes tzawareness mismatch and IncompatibleFrequencyError
oth = getattr(other, "dtype", other)
raise TypeError(f"Where requires matching dtype, not {oth}") from err

result = np.where(cond, values, other)
result = np.where(cond, self._data._ndarray, other)
arr = self._data._from_backing_data(result)
return type(self)._simple_new(arr, name=self.name)

Expand Down
2 changes: 2 additions & 0 deletions pandas/core/indexes/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def _validate_fill_value(self, value):
# force conversion to object
# so we don't lose the bools
raise TypeError
if isinstance(value, str):
raise TypeError

return value

Expand Down
13 changes: 13 additions & 0 deletions pandas/tests/indexes/base_class/test_where.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import numpy as np

from pandas import Index
import pandas._testing as tm


class TestWhere:
    """Regression tests for :meth:`Index.where` on a base-class integer Index."""

    def test_where_intlike_str_doesnt_cast_ints(self):
        """A digit-only string ``other`` must not turn the kept ints into strings.

        GH#37591: ``Index.where`` incorrectly cast numeric values to strings
        when the replacement value was a string.
        """
        idx = Index(range(3))
        mask = np.array([True, False, True])

        res = idx.where(mask, "2")

        # Spell out the object dtype so the expectation does not depend on
        # constructor inference: positions 0 and 2 stay the integers 0 and 2,
        # and only the masked-out position 1 holds the string "2".
        expected = Index([0, "2", 2], dtype=object)
        tm.assert_index_equal(res, expected)
3 changes: 1 addition & 2 deletions pandas/tests/indexes/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,9 @@ def test_where_cast_str(self):
result = index.where(mask, [str(index[0])])
tm.assert_index_equal(result, expected)

msg = "Where requires matching dtype, not foo"
msg = "value should be a '.*', 'NaT', or array of those"
with pytest.raises(TypeError, match=msg):
index.where(mask, "foo")

msg = r"Where requires matching dtype, not \['foo'\]"
with pytest.raises(TypeError, match=msg):
index.where(mask, ["foo"])
16 changes: 9 additions & 7 deletions pandas/tests/indexes/datetimes/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,24 +177,26 @@ def test_where_invalid_dtypes(self):

i2 = Index([pd.NaT, pd.NaT] + dti[2:].tolist())

with pytest.raises(TypeError, match="Where requires matching dtype"):
msg = "value should be a 'Timestamp', 'NaT', or array of those. Got"
msg2 = "Cannot compare tz-naive and tz-aware datetime-like objects"
with pytest.raises(TypeError, match=msg2):
# passing tz-naive ndarray to tzaware DTI
dti.where(notna(i2), i2.values)

with pytest.raises(TypeError, match="Where requires matching dtype"):
with pytest.raises(TypeError, match=msg2):
# passing tz-aware DTI to tznaive DTI
dti.tz_localize(None).where(notna(i2), i2)

with pytest.raises(TypeError, match="Where requires matching dtype"):
with pytest.raises(TypeError, match=msg):
dti.where(notna(i2), i2.tz_localize(None).to_period("D"))

with pytest.raises(TypeError, match="Where requires matching dtype"):
with pytest.raises(TypeError, match=msg):
dti.where(notna(i2), i2.asi8.view("timedelta64[ns]"))

with pytest.raises(TypeError, match="Where requires matching dtype"):
with pytest.raises(TypeError, match=msg):
dti.where(notna(i2), i2.asi8)

with pytest.raises(TypeError, match="Where requires matching dtype"):
with pytest.raises(TypeError, match=msg):
# non-matching scalar
dti.where(notna(i2), pd.Timedelta(days=4))

Expand All @@ -203,7 +205,7 @@ def test_where_mismatched_nat(self, tz_aware_fixture):
dti = pd.date_range("2013-01-01", periods=3, tz=tz)
cond = np.array([True, False, True])

msg = "Where requires matching dtype"
msg = "value should be a 'Timestamp', 'NaT', or array of those. Got"
with pytest.raises(TypeError, match=msg):
# wrong-dtyped NaT
dti.where(cond, np.timedelta64("NaT", "ns"))
Expand Down
11 changes: 6 additions & 5 deletions pandas/tests/indexes/period/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,24 +545,25 @@ def test_where_invalid_dtypes(self):

i2 = PeriodIndex([NaT, NaT] + pi[2:].tolist(), freq="D")

with pytest.raises(TypeError, match="Where requires matching dtype"):
msg = "value should be a 'Period', 'NaT', or array of those"
with pytest.raises(TypeError, match=msg):
pi.where(notna(i2), i2.asi8)

with pytest.raises(TypeError, match="Where requires matching dtype"):
with pytest.raises(TypeError, match=msg):
pi.where(notna(i2), i2.asi8.view("timedelta64[ns]"))

with pytest.raises(TypeError, match="Where requires matching dtype"):
with pytest.raises(TypeError, match=msg):
pi.where(notna(i2), i2.to_timestamp("S"))

with pytest.raises(TypeError, match="Where requires matching dtype"):
with pytest.raises(TypeError, match=msg):
# non-matching scalar
pi.where(notna(i2), Timedelta(days=4))

def test_where_mismatched_nat(self):
pi = period_range("20130101", periods=5, freq="D")
cond = np.array([True, False, True, True, False])

msg = "Where requires matching dtype"
msg = "value should be a 'Period', 'NaT', or array of those"
with pytest.raises(TypeError, match=msg):
# wrong-dtyped NaT
pi.where(cond, np.timedelta64("NaT", "ns"))
Expand Down
11 changes: 6 additions & 5 deletions pandas/tests/indexes/timedeltas/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,24 +150,25 @@ def test_where_invalid_dtypes(self):

i2 = Index([pd.NaT, pd.NaT] + tdi[2:].tolist())

with pytest.raises(TypeError, match="Where requires matching dtype"):
msg = "value should be a 'Timedelta', 'NaT', or array of those"
with pytest.raises(TypeError, match=msg):
tdi.where(notna(i2), i2.asi8)

with pytest.raises(TypeError, match="Where requires matching dtype"):
with pytest.raises(TypeError, match=msg):
tdi.where(notna(i2), i2 + pd.Timestamp.now())

with pytest.raises(TypeError, match="Where requires matching dtype"):
with pytest.raises(TypeError, match=msg):
tdi.where(notna(i2), (i2 + pd.Timestamp.now()).to_period("D"))

with pytest.raises(TypeError, match="Where requires matching dtype"):
with pytest.raises(TypeError, match=msg):
# non-matching scalar
tdi.where(notna(i2), pd.Timestamp.now())

def test_where_mismatched_nat(self):
tdi = timedelta_range("1 day", periods=3, freq="D", name="idx")
cond = np.array([True, False, False])

msg = "Where requires matching dtype"
msg = "value should be a 'Timedelta', 'NaT', or array of those"
with pytest.raises(TypeError, match=msg):
# wrong-dtyped NaT
tdi.where(cond, np.datetime64("NaT", "ns"))
Expand Down
7 changes: 4 additions & 3 deletions pandas/tests/indexing/test_coercion.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ def test_where_index_timedelta64(self, value):
result = tdi.where(cond, value)
tm.assert_index_equal(result, expected)

msg = "Where requires matching dtype"
msg = "value should be a 'Timedelta', 'NaT', or array of thos"
with pytest.raises(TypeError, match=msg):
# wrong-dtyped NaT
tdi.where(cond, np.datetime64("NaT", "ns"))
Expand All @@ -804,11 +804,12 @@ def test_where_index_period(self):
tm.assert_index_equal(result, expected)

# Passing a mismatched scalar
msg = "Where requires matching dtype"
msg = "value should be a 'Period', 'NaT', or array of those"
with pytest.raises(TypeError, match=msg):
pi.where(cond, pd.Timedelta(days=4))

with pytest.raises(TypeError, match=msg):
msg = r"Input has different freq=D from PeriodArray\(freq=Q-DEC\)"
with pytest.raises(ValueError, match=msg):
pi.where(cond, pd.Period("2020-04-21", "D"))


Expand Down