diff --git a/pandas/core/indexing.py b/pandas/core/indexing.py index 1d4ea54ef0d70..9acafb89b856b 100755 --- a/pandas/core/indexing.py +++ b/pandas/core/indexing.py @@ -1,5 +1,5 @@ import textwrap -from typing import Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple, Union import warnings import numpy as np @@ -25,10 +25,15 @@ from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries from pandas.core.dtypes.missing import _infer_fill_value, isna +from pandas._typing import Axis import pandas.core.common as com from pandas.core.index import Index, InvalidIndexError, MultiIndex from pandas.core.indexers import is_list_like_indexer, length_of_indexer +if TYPE_CHECKING: + from pandas.core.generic import NDFrame + from pandas import DataFrame, Series, DatetimeArray # noqa: F401 + # the supported indexers def get_indexers_list(): @@ -104,7 +109,7 @@ class _NDFrameIndexer(_NDFrameIndexerBase): _exception = Exception axis = None - def __call__(self, axis=None): + def __call__(self, axis: Optional[Axis] = None) -> "_NDFrameIndexer": # we need to return a copy of ourselves new_self = self.__class__(self.name, self.obj) @@ -193,7 +198,7 @@ def _get_setitem_indexer(self, key): raise raise IndexingError(key) - def __setitem__(self, key, value): + def __setitem__(self, key, value) -> None: if isinstance(key, tuple): key = tuple(com.apply_if_callable(x, self.obj) for x in key) else: @@ -260,7 +265,7 @@ def _convert_tuple(self, key, is_setter: bool = False): keyidx.append(idx) return tuple(keyidx) - def _convert_range(self, key: range, is_setter: bool = False): + def _convert_range(self, key: range, is_setter: bool = False) -> List[int]: """ convert a range argument """ return list(key) @@ -638,7 +643,9 @@ def _setitem_with_indexer_missing(self, indexer, value): self.obj._maybe_update_cacher(clear=True) return self.obj - def _align_series(self, indexer, ser: ABCSeries, multiindex_indexer: bool = False): + def _align_series( + self, indexer, ser: "Series", multiindex_indexer: bool = False + ) -> Union[np.ndarray, "DatetimeArray"]: """ Parameters ---------- @@ -730,7 +737,7 @@ def ravel(i): raise ValueError("Incompatible indexer with Series") - def _align_frame(self, indexer, df: ABCDataFrame): + def _align_frame(self, indexer, df: "DataFrame") -> np.ndarray: is_frame = self.obj.ndim == 2 if isinstance(indexer, tuple): @@ -860,7 +867,7 @@ def _handle_lowerdim_multi_index_axis0(self, tup: Tuple): axis = self.axis or 0 try: # fast path for series or for tup devoid of slices - return self._get_label(tup, axis=axis) + return self._get_label(tup, axis=axis) # type: ignore except TypeError: # slices are unhashable pass @@ -959,7 +966,7 @@ def _getitem_nested_tuple(self, tup: Tuple): # this is a series with a multi-index specified a tuple of # selectors axis = self.axis or 0 - return self._getitem_axis(tup, axis=axis) + return self._getitem_axis(tup, axis=axis) # type: ignore # handle the multi-axis by taking sections and reducing # this is iterative @@ -1324,12 +1331,12 @@ class _IXIndexer(_NDFrameIndexer): http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#ix-indexer-is-deprecated""" # noqa: E501 ) - def __init__(self, name, obj): + def __init__(self, name: str, obj: "NDFrame"): warnings.warn(self._ix_deprecation_warning, FutureWarning, stacklevel=2) super().__init__(name, obj) @Appender(_NDFrameIndexer._validate_key.__doc__) - def _validate_key(self, key, axis: int): + def _validate_key(self, key, axis: int) -> bool: if isinstance(key, slice): return True @@ -1345,7 +1352,7 @@ def _validate_key(self, key, axis: int): return True - def _convert_for_reindex(self, key, axis: int): + def _convert_for_reindex(self, key, axis: int) -> Union[Index, np.ndarray]: """ Transform a list of keys into a new array ready to be used as axis of the object we return (e.g. including NaNs). @@ -1414,7 +1421,7 @@ def _getitem_scalar(self, key): def _getitem_axis(self, key, axis: int): raise NotImplementedError() - def _getbool_axis(self, key, axis: int): + def _getbool_axis(self, key, axis: int) -> "NDFrame": # caller is responsible for ensuring non-None axis labels = self.obj._get_axis(axis) key = check_bool_indexer(labels, key) @@ -1424,7 +1431,7 @@ def _getbool_axis(self, key, axis: int): except Exception as detail: raise self._exception(detail) - def _get_slice_axis(self, slice_obj: slice, axis: int): + def _get_slice_axis(self, slice_obj: slice, axis: int) -> "NDFrame": """ this is pretty simple as we just have to deal with labels """ # caller is responsible for ensuring non-None axis obj = self.obj @@ -1690,7 +1697,7 @@ class _LocIndexer(_LocationIndexer): _exception = KeyError @Appender(_NDFrameIndexer._validate_key.__doc__) - def _validate_key(self, key, axis: int): + def _validate_key(self, key, axis: int) -> None: # valid for a collection of labels (we check their presence later) # slice of labels (where start-end in labels) @@ -1706,7 +1713,7 @@ def _validate_key(self, key, axis: int): if not is_list_like_indexer(key): self._convert_scalar_indexer(key, axis) - def _is_scalar_access(self, key: Tuple): + def _is_scalar_access(self, key: Tuple) -> bool: # this is a shortcut accessor to both .loc and .iloc # that provide the equivalent access of .at and .iat # a) avoid getting things via sections and (to minimize dtype changes) @@ -1733,14 +1740,18 @@ def _getitem_scalar(self, key): values = self.obj._get_value(*key) return values - def _get_partial_string_timestamp_match_key(self, key, labels): + def _get_partial_string_timestamp_match_key(self, key, labels: Index): """Translate any partial string timestamp matches in key, returning the new key (GH 10331)""" if isinstance(labels, MultiIndex): if isinstance(key, str) and labels.levels[0].is_all_dates: # Convert key '2016-01-01' to # ('2016-01-01'[, slice(None, None, None)]+) - key = tuple([key] + [slice(None)] * (len(labels.levels) - 1)) + key = tuple( + # https://github.com/python/mypy/issues/5492 + [key] + + [slice(None)] * (len(labels.levels) - 1) # type: ignore + ) if isinstance(key, tuple): # Convert (..., '2016-01-01', ...) in tuple to @@ -1748,7 +1759,9 @@ def _get_partial_string_timestamp_match_key(self, key, labels): new_key = [] for i, component in enumerate(key): if isinstance(component, str) and labels.levels[i].is_all_dates: - new_key.append(slice(component, component, None)) + new_key.append( + slice(component, component, None) # type: ignore + ) else: new_key.append(component) key = tuple(new_key)