diff --git a/pandas-stubs/core/indexes/base.pyi b/pandas-stubs/core/indexes/base.pyi index 25c975eff..bb51faef9 100644 --- a/pandas-stubs/core/indexes/base.pyi +++ b/pandas-stubs/core/indexes/base.pyi @@ -40,6 +40,7 @@ from pandas import ( Series, TimedeltaIndex, ) +from pandas.core.arrays.boolean import BooleanArray from pandas.core.base import ( ElementOpsMixin, IndexOpsMixin, @@ -457,7 +458,18 @@ class Index(IndexOpsMixin[S1], ElementOpsMixin[S1]): @property def values(self) -> np_1darray: ... def memory_usage(self, deep: bool = False): ... - def where(self, cond, other: Scalar | ArrayLike | None = None): ... + @overload + def where( + self, + cond: Sequence[bool] | np_ndarray_bool | BooleanArray | IndexOpsMixin[bool], + other: S1 | Series[S1] | Self, + ) -> Self: ... + @overload + def where( + self, + cond: Sequence[bool] | np_ndarray_bool | BooleanArray | IndexOpsMixin[bool], + other: Scalar | AnyArrayLike | None = None, + ) -> Index: ... def __contains__(self, key) -> bool: ... @final def __setitem__(self, key, value) -> None: ... diff --git a/pandas-stubs/core/indexes/range.pyi b/pandas-stubs/core/indexes/range.pyi index c55e7f844..e96204c36 100644 --- a/pandas-stubs/core/indexes/range.pyi +++ b/pandas-stubs/core/indexes/range.pyi @@ -9,6 +9,8 @@ from typing import ( ) import numpy as np +from pandas.core.arrays.boolean import BooleanArray +from pandas.core.base import IndexOpsMixin from pandas.core.indexes.base import ( Index, _IndexSubclassBase, @@ -16,11 +18,14 @@ from pandas.core.indexes.base import ( from typing_extensions import Self from pandas._typing import ( + AnyArrayLike, Dtype, HashableT, MaskType, + Scalar, np_1darray, np_ndarray_anyint, + np_ndarray_bool, ) class RangeIndex(_IndexSubclassBase[int, np.int64]): @@ -82,3 +87,8 @@ class RangeIndex(_IndexSubclassBase[int, np.int64]): def __getitem__( # pyright: ignore[reportIncompatibleMethodOverride] self, idx: int ) -> int: ... + def where( # type: ignore[override] + self, + cond: Sequence[bool] | np_ndarray_bool | BooleanArray | IndexOpsMixin[bool], + other: Scalar | AnyArrayLike | None = None, + ) -> Index: ... diff --git a/tests/indexes/test_indexes.py b/tests/indexes/test_indexes.py index 045409474..6c25d5e13 100644 --- a/tests/indexes/test_indexes.py +++ b/tests/indexes/test_indexes.py @@ -19,6 +19,7 @@ from pandas.core.arrays.timedeltas import TimedeltaArray from pandas.core.indexes.base import Index from pandas.core.indexes.category import CategoricalIndex +from pandas.core.indexes.datetimes import DatetimeIndex from typing_extensions import ( Never, assert_type, @@ -1541,3 +1542,39 @@ def test_multiindex_swaplevel() -> None: """Test that MultiIndex.swaplevel returns MultiIndex""" mi = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=["let", "num"]) check(assert_type(mi.swaplevel(0, 1), "pd.MultiIndex"), pd.MultiIndex) + + +def test_index_where() -> None: + """Test Index.where with multiple types of other GH1419.""" + idx = pd.Index(range(48)) + mask = np.ones(48, dtype=bool) + val_idx = idx.where(mask, idx) + check(assert_type(val_idx, "pd.Index[int]"), pd.Index, int) + + val_sr = idx.where(mask, (idx).to_series()) + check(assert_type(val_sr, "pd.Index[int]"), pd.Index, int) + + +def test_datetimeindex_where() -> None: + """Test DatetimeIndex.where with multiple types of other GH1419.""" + datetime_index = pd.date_range(start="2025-01-01", freq="h", periods=48) + mask = np.ones(48, dtype=bool) + val_idx = datetime_index.where(mask, datetime_index - pd.Timedelta(days=1)) + check(assert_type(val_idx, DatetimeIndex), DatetimeIndex) + + val_sr = datetime_index.where( + mask, (datetime_index - pd.Timedelta(days=1)).to_series() + ) + check(assert_type(val_sr, DatetimeIndex), DatetimeIndex) + + val_idx_scalar = datetime_index.where(mask, pd.Index([0, 1])) + check(assert_type(val_idx_scalar, pd.Index), pd.Index) + + val_sr_scalar = datetime_index.where(mask, pd.Series([0, 1])) + check(assert_type(val_sr_scalar, pd.Index), pd.Index) + + val_scalar = datetime_index.where(mask, 1) + check(assert_type(val_scalar, pd.Index), pd.Index) + + val_range = pd.RangeIndex(2).where(pd.Series([True, False]), 3) + check(assert_type(val_range, pd.Index), pd.RangeIndex)