From 9b3c301ffd9b636e657c97f202f07e35cc37e7c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?= <6618166+twoertwein@users.noreply.github.com> Date: Tue, 6 Feb 2024 23:55:25 -0500 Subject: [PATCH] TYP: misc return types (#57285) --- pandas/_typing.py | 8 ++++++- pandas/core/_numba/extensions.py | 14 +++++++---- pandas/core/arrays/_mixins.py | 2 +- pandas/core/arrays/arrow/array.py | 4 ++-- pandas/core/dtypes/dtypes.py | 2 +- pandas/core/dtypes/generic.py | 2 +- pandas/core/groupby/generic.py | 2 +- pandas/core/indexing.py | 2 +- pandas/core/interchange/from_dataframe.py | 2 +- pandas/core/internals/blocks.py | 2 +- pandas/core/internals/managers.py | 3 ++- pandas/core/ops/array_ops.py | 2 +- pandas/core/ops/mask_ops.py | 11 ++++++--- pandas/core/resample.py | 29 ++++++++++++++++------- pandas/core/reshape/pivot.py | 2 +- pandas/core/reshape/tile.py | 2 +- pandas/core/strings/base.py | 11 +++++---- pandas/core/tools/datetimes.py | 4 +++- 18 files changed, 67 insertions(+), 37 deletions(-) diff --git a/pandas/_typing.py b/pandas/_typing.py index 8cc53335f6ce9..00135bffbd435 100644 --- a/pandas/_typing.py +++ b/pandas/_typing.py @@ -46,7 +46,12 @@ from pandas.core.dtypes.dtypes import ExtensionDtype - from pandas import Interval + from pandas import ( + DatetimeIndex, + Interval, + PeriodIndex, + TimedeltaIndex, + ) from pandas.arrays import ( DatetimeArray, TimedeltaArray, @@ -190,6 +195,7 @@ def __reversed__(self) -> Iterator[_T_co]: NDFrameT = TypeVar("NDFrameT", bound="NDFrame") IndexT = TypeVar("IndexT", bound="Index") +FreqIndexT = TypeVar("FreqIndexT", "DatetimeIndex", "PeriodIndex", "TimedeltaIndex") NumpyIndexT = TypeVar("NumpyIndexT", np.ndarray, "Index") AxisInt = int diff --git a/pandas/core/_numba/extensions.py b/pandas/core/_numba/extensions.py index ee09c9380fb0f..e6f0427de2a3a 100644 --- a/pandas/core/_numba/extensions.py +++ b/pandas/core/_numba/extensions.py @@ -12,6 +12,7 @@ from contextlib import contextmanager import operator +from typing import TYPE_CHECKING import numba from numba import types @@ -40,6 +41,9 @@ from pandas.core.internals import SingleBlockManager from pandas.core.series import Series +if TYPE_CHECKING: + from pandas._typing import Self + # Helper function to hack around fact that Index casts numpy string dtype to object # @@ -84,7 +88,7 @@ def key(self): def as_array(self): return types.Array(self.dtype, 1, self.layout) - def copy(self, dtype=None, ndim: int = 1, layout=None): + def copy(self, dtype=None, ndim: int = 1, layout=None) -> Self: assert ndim == 1 if dtype is None: dtype = self.dtype @@ -114,7 +118,7 @@ def key(self): def as_array(self): return self.values - def copy(self, dtype=None, ndim: int = 1, layout: str = "C"): + def copy(self, dtype=None, ndim: int = 1, layout: str = "C") -> Self: assert ndim == 1 assert layout == "C" if dtype is None: @@ -123,7 +127,7 @@ def copy(self, dtype=None, ndim: int = 1, layout: str = "C"): @typeof_impl.register(Index) -def typeof_index(val, c): +def typeof_index(val, c) -> IndexType: """ This will assume that only strings are in object dtype index. @@ -136,7 +140,7 @@ def typeof_index(val, c): @typeof_impl.register(Series) -def typeof_series(val, c): +def typeof_series(val, c) -> SeriesType: index = typeof_impl(val.index, c) arrty = typeof_impl(val.values, c) namety = typeof_impl(val.name, c) @@ -532,7 +536,7 @@ def key(self): @typeof_impl.register(_iLocIndexer) -def typeof_iloc(val, c): +def typeof_iloc(val, c) -> IlocType: objtype = typeof_impl(val.obj, c) return IlocType(objtype) diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py index adbb971f777b4..83d2b6f1ca84f 100644 --- a/pandas/core/arrays/_mixins.py +++ b/pandas/core/arrays/_mixins.py @@ -250,7 +250,7 @@ def searchsorted( return self._ndarray.searchsorted(npvalue, side=side, sorter=sorter) @doc(ExtensionArray.shift) - def shift(self, periods: int = 1, fill_value=None): + def shift(self, periods: int = 1, fill_value=None) -> Self: # NB: shift is always along axis=0 axis = 0 fill_value = self._validate_scalar(fill_value) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 32044d1fc233a..d12250a863ac9 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -204,7 +204,7 @@ def floordiv_compat( from pandas.core.arrays.timedeltas import TimedeltaArray -def get_unit_from_pa_dtype(pa_dtype): +def get_unit_from_pa_dtype(pa_dtype) -> str: # https://github.com/pandas-dev/pandas/pull/50998#discussion_r1100344804 if pa_version_under11p0: unit = str(pa_dtype).split("[", 1)[-1][:-1] @@ -1966,7 +1966,7 @@ def _rank( na_option: str = "keep", ascending: bool = True, pct: bool = False, - ): + ) -> Self: """ See Series.rank.__doc__. """ diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index a6a5f142faf1c..5afb77b89c8d5 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -2337,7 +2337,7 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None: except NotImplementedError: return None - def __from_arrow__(self, array: pa.Array | pa.ChunkedArray): + def __from_arrow__(self, array: pa.Array | pa.ChunkedArray) -> ArrowExtensionArray: """ Construct IntegerArray/FloatingArray from pyarrow Array/ChunkedArray. """ diff --git a/pandas/core/dtypes/generic.py b/pandas/core/dtypes/generic.py index 9718ad600cb80..8abde2ab7010f 100644 --- a/pandas/core/dtypes/generic.py +++ b/pandas/core/dtypes/generic.py @@ -33,7 +33,7 @@ # define abstract base classes to enable isinstance type checking on our # objects -def create_pandas_abc_type(name, attr, comp): +def create_pandas_abc_type(name, attr, comp) -> type: def _check(inst) -> bool: return getattr(inst, attr, "_typ") in comp diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index db626752e9eff..448d052ed9531 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1788,7 +1788,7 @@ def _choose_path(self, fast_path: Callable, slow_path: Callable, group: DataFram return path, res - def filter(self, func, dropna: bool = True, *args, **kwargs): + def filter(self, func, dropna: bool = True, *args, **kwargs) -> DataFrame: """ Filter elements from groups that don't satisfy a criterion. diff --git a/pandas/core/indexing.py b/pandas/core/indexing.py index ab06dd3ea5af0..91e9d6fd602a6 100644 --- a/pandas/core/indexing.py +++ b/pandas/core/indexing.py @@ -1750,7 +1750,7 @@ def _get_slice_axis(self, slice_obj: slice, axis: AxisInt): labels._validate_positional_slice(slice_obj) return self.obj._slice(slice_obj, axis=axis) - def _convert_to_indexer(self, key, axis: AxisInt): + def _convert_to_indexer(self, key: T, axis: AxisInt) -> T: """ Much simpler as we only have to deal with our valid types. """ diff --git a/pandas/core/interchange/from_dataframe.py b/pandas/core/interchange/from_dataframe.py index 390f5e0d0d5ae..ba2d275e88b32 100644 --- a/pandas/core/interchange/from_dataframe.py +++ b/pandas/core/interchange/from_dataframe.py @@ -75,7 +75,7 @@ def from_dataframe(df, allow_copy: bool = True) -> pd.DataFrame: ) -def _from_dataframe(df: DataFrameXchg, allow_copy: bool = True): +def _from_dataframe(df: DataFrameXchg, allow_copy: bool = True) -> pd.DataFrame: """ Build a ``pd.DataFrame`` from the DataFrame interchange object. diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index 910de45d3e89f..b4b1de5730833 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -2537,7 +2537,7 @@ def get_block_type(dtype: DtypeObj) -> type[Block]: def new_block_2d( values: ArrayLike, placement: BlockPlacement, refs: BlockValuesRefs | None = None -): +) -> Block: # new_block specialized to case with # ndim=2 # isinstance(placement, BlockPlacement) diff --git a/pandas/core/internals/managers.py b/pandas/core/internals/managers.py index cda5575a2b04e..1d3772d224d89 100644 --- a/pandas/core/internals/managers.py +++ b/pandas/core/internals/managers.py @@ -10,6 +10,7 @@ Any, Callable, Literal, + NoReturn, cast, final, ) @@ -2349,7 +2350,7 @@ def raise_construction_error( block_shape: Shape, axes: list[Index], e: ValueError | None = None, -): +) -> NoReturn: """raise a helpful message about our construction""" passed = tuple(map(int, [tot_items] + list(block_shape))) # Correcting the user facing error message during dataframe construction diff --git a/pandas/core/ops/array_ops.py b/pandas/core/ops/array_ops.py index 8ccd7c84cb05c..034a231f04488 100644 --- a/pandas/core/ops/array_ops.py +++ b/pandas/core/ops/array_ops.py @@ -130,7 +130,7 @@ def comp_method_OBJECT_ARRAY(op, x, y): return result.reshape(x.shape) -def _masked_arith_op(x: np.ndarray, y, op): +def _masked_arith_op(x: np.ndarray, y, op) -> np.ndarray: """ If the given arithmetic operation fails, attempt it again on only the non-null elements of the input array(s). diff --git a/pandas/core/ops/mask_ops.py b/pandas/core/ops/mask_ops.py index adc1f63c568bf..e5d0626ad9119 100644 --- a/pandas/core/ops/mask_ops.py +++ b/pandas/core/ops/mask_ops.py @@ -3,6 +3,8 @@ """ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np from pandas._libs import ( @@ -10,13 +12,16 @@ missing as libmissing, ) +if TYPE_CHECKING: + from pandas._typing import npt + def kleene_or( left: bool | np.ndarray | libmissing.NAType, right: bool | np.ndarray | libmissing.NAType, left_mask: np.ndarray | None, right_mask: np.ndarray | None, -): +) -> tuple[npt.NDArray[np.bool_], npt.NDArray[np.bool_]]: """ Boolean ``or`` using Kleene logic. @@ -78,7 +83,7 @@ def kleene_xor( right: bool | np.ndarray | libmissing.NAType, left_mask: np.ndarray | None, right_mask: np.ndarray | None, -): +) -> tuple[npt.NDArray[np.bool_], npt.NDArray[np.bool_]]: """ Boolean ``xor`` using Kleene logic. @@ -131,7 +136,7 @@ def kleene_and( right: bool | libmissing.NAType | np.ndarray, left_mask: np.ndarray | None, right_mask: np.ndarray | None, -): +) -> tuple[npt.NDArray[np.bool_], npt.NDArray[np.bool_]]: """ Boolean ``and`` using Kleene logic. diff --git a/pandas/core/resample.py b/pandas/core/resample.py index 24c9b10c2b3a6..2a36c0f1ef549 100644 --- a/pandas/core/resample.py +++ b/pandas/core/resample.py @@ -101,6 +101,7 @@ AnyArrayLike, Axis, Concatenate, + FreqIndexT, Frequency, IndexLabel, InterpolateOptions, @@ -1690,7 +1691,7 @@ class DatetimeIndexResampler(Resampler): ax: DatetimeIndex @property - def _resampler_for_grouping(self): + def _resampler_for_grouping(self) -> type[DatetimeIndexResamplerGroupby]: return DatetimeIndexResamplerGroupby def _get_binner_for_time(self): @@ -2483,17 +2484,28 @@ def _set_grouper( return obj, ax, indexer +@overload def _take_new_index( - obj: NDFrameT, + obj: DataFrame, indexer: npt.NDArray[np.intp], new_index: Index +) -> DataFrame: + ... + + +@overload +def _take_new_index( + obj: Series, indexer: npt.NDArray[np.intp], new_index: Index +) -> Series: + ... + + +def _take_new_index( + obj: DataFrame | Series, indexer: npt.NDArray[np.intp], new_index: Index, -) -> NDFrameT: +) -> DataFrame | Series: if isinstance(obj, ABCSeries): new_values = algos.take_nd(obj._values, indexer) - # error: Incompatible return value type (got "Series", expected "NDFrameT") - return obj._constructor( # type: ignore[return-value] - new_values, index=new_index, name=obj.name - ) + return obj._constructor(new_values, index=new_index, name=obj.name) elif isinstance(obj, ABCDataFrame): new_mgr = obj._mgr.reindex_indexer(new_axis=new_index, indexer=indexer, axis=1) return obj._constructor_from_mgr(new_mgr, axes=new_mgr.axes) @@ -2788,7 +2800,7 @@ def asfreq( return new_obj -def _asfreq_compat(index: DatetimeIndex | PeriodIndex | TimedeltaIndex, freq): +def _asfreq_compat(index: FreqIndexT, freq) -> FreqIndexT: """ Helper to mimic asfreq on (empty) DatetimeIndex and TimedeltaIndex. @@ -2806,7 +2818,6 @@ def _asfreq_compat(index: DatetimeIndex | PeriodIndex | TimedeltaIndex, freq): raise ValueError( "Can only set arbitrary freq for empty DatetimeIndex or TimedeltaIndex" ) - new_index: Index if isinstance(index, PeriodIndex): new_index = index.asfreq(freq=freq) elif isinstance(index, DatetimeIndex): diff --git a/pandas/core/reshape/pivot.py b/pandas/core/reshape/pivot.py index 3abc1408584a0..927f2305045ae 100644 --- a/pandas/core/reshape/pivot.py +++ b/pandas/core/reshape/pivot.py @@ -830,7 +830,7 @@ def _normalize( return table -def _get_names(arrs, names, prefix: str = "row"): +def _get_names(arrs, names, prefix: str = "row") -> list: if names is None: names = [] for i, arr in enumerate(arrs): diff --git a/pandas/core/reshape/tile.py b/pandas/core/reshape/tile.py index ecbac32366028..4aecc9794384a 100644 --- a/pandas/core/reshape/tile.py +++ b/pandas/core/reshape/tile.py @@ -548,7 +548,7 @@ def _format_labels( precision: int, right: bool = True, include_lowest: bool = False, -): +) -> IntervalIndex: """based on the dtype, return our labels""" closed: IntervalLeftRight = "right" if right else "left" diff --git a/pandas/core/strings/base.py b/pandas/core/strings/base.py index 96b0352666b41..c1f94abff428a 100644 --- a/pandas/core/strings/base.py +++ b/pandas/core/strings/base.py @@ -13,9 +13,10 @@ from collections.abc import Sequence import re - from pandas._typing import Scalar - - from pandas import Series + from pandas._typing import ( + Scalar, + Self, + ) class BaseStringArrayMethods(abc.ABC): @@ -240,11 +241,11 @@ def _str_rstrip(self, to_strip=None): pass @abc.abstractmethod - def _str_removeprefix(self, prefix: str) -> Series: + def _str_removeprefix(self, prefix: str) -> Self: pass @abc.abstractmethod - def _str_removesuffix(self, suffix: str) -> Series: + def _str_removesuffix(self, suffix: str) -> Self: pass @abc.abstractmethod diff --git a/pandas/core/tools/datetimes.py b/pandas/core/tools/datetimes.py index e937d5c399820..8e0a96e508516 100644 --- a/pandas/core/tools/datetimes.py +++ b/pandas/core/tools/datetimes.py @@ -1134,7 +1134,9 @@ def to_datetime( } -def _assemble_from_unit_mappings(arg, errors: DateTimeErrorChoices, utc: bool): +def _assemble_from_unit_mappings( + arg, errors: DateTimeErrorChoices, utc: bool +) -> Series: """ assemble the unit specified fields from the arg (DataFrame) Return a Series for actual parsing