Skip to content

Commit

Permalink
TYP: misc return types (#57285)
Browse files Browse the repository at this point in the history
  • Loading branch information
twoertwein authored Feb 7, 2024
1 parent 99e3afe commit 9b3c301
Show file tree
Hide file tree
Showing 18 changed files with 67 additions and 37 deletions.
8 changes: 7 additions & 1 deletion pandas/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@

from pandas.core.dtypes.dtypes import ExtensionDtype

from pandas import Interval
from pandas import (
DatetimeIndex,
Interval,
PeriodIndex,
TimedeltaIndex,
)
from pandas.arrays import (
DatetimeArray,
TimedeltaArray,
Expand Down Expand Up @@ -190,6 +195,7 @@ def __reversed__(self) -> Iterator[_T_co]:
NDFrameT = TypeVar("NDFrameT", bound="NDFrame")

IndexT = TypeVar("IndexT", bound="Index")
FreqIndexT = TypeVar("FreqIndexT", "DatetimeIndex", "PeriodIndex", "TimedeltaIndex")
NumpyIndexT = TypeVar("NumpyIndexT", np.ndarray, "Index")

AxisInt = int
Expand Down
14 changes: 9 additions & 5 deletions pandas/core/_numba/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from contextlib import contextmanager
import operator
from typing import TYPE_CHECKING

import numba
from numba import types
Expand Down Expand Up @@ -40,6 +41,9 @@
from pandas.core.internals import SingleBlockManager
from pandas.core.series import Series

if TYPE_CHECKING:
from pandas._typing import Self


# Helper function to hack around fact that Index casts numpy string dtype to object
#
Expand Down Expand Up @@ -84,7 +88,7 @@ def key(self):
def as_array(self):
return types.Array(self.dtype, 1, self.layout)

def copy(self, dtype=None, ndim: int = 1, layout=None):
def copy(self, dtype=None, ndim: int = 1, layout=None) -> Self:
assert ndim == 1
if dtype is None:
dtype = self.dtype
Expand Down Expand Up @@ -114,7 +118,7 @@ def key(self):
def as_array(self):
return self.values

def copy(self, dtype=None, ndim: int = 1, layout: str = "C"):
def copy(self, dtype=None, ndim: int = 1, layout: str = "C") -> Self:
assert ndim == 1
assert layout == "C"
if dtype is None:
Expand All @@ -123,7 +127,7 @@ def copy(self, dtype=None, ndim: int = 1, layout: str = "C"):


@typeof_impl.register(Index)
def typeof_index(val, c):
def typeof_index(val, c) -> IndexType:
"""
This will assume that only strings are in object dtype
index.
Expand All @@ -136,7 +140,7 @@ def typeof_index(val, c):


@typeof_impl.register(Series)
def typeof_series(val, c):
def typeof_series(val, c) -> SeriesType:
index = typeof_impl(val.index, c)
arrty = typeof_impl(val.values, c)
namety = typeof_impl(val.name, c)
Expand Down Expand Up @@ -532,7 +536,7 @@ def key(self):


@typeof_impl.register(_iLocIndexer)
def typeof_iloc(val, c):
def typeof_iloc(val, c) -> IlocType:
objtype = typeof_impl(val.obj, c)
return IlocType(objtype)

Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def searchsorted(
return self._ndarray.searchsorted(npvalue, side=side, sorter=sorter)

@doc(ExtensionArray.shift)
def shift(self, periods: int = 1, fill_value=None):
def shift(self, periods: int = 1, fill_value=None) -> Self:
# NB: shift is always along axis=0
axis = 0
fill_value = self._validate_scalar(fill_value)
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def floordiv_compat(
from pandas.core.arrays.timedeltas import TimedeltaArray


def get_unit_from_pa_dtype(pa_dtype):
def get_unit_from_pa_dtype(pa_dtype) -> str:
# https://github.com/pandas-dev/pandas/pull/50998#discussion_r1100344804
if pa_version_under11p0:
unit = str(pa_dtype).split("[", 1)[-1][:-1]
Expand Down Expand Up @@ -1966,7 +1966,7 @@ def _rank(
na_option: str = "keep",
ascending: bool = True,
pct: bool = False,
):
) -> Self:
"""
See Series.rank.__doc__.
"""
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2337,7 +2337,7 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
except NotImplementedError:
return None

def __from_arrow__(self, array: pa.Array | pa.ChunkedArray):
def __from_arrow__(self, array: pa.Array | pa.ChunkedArray) -> ArrowExtensionArray:
"""
Construct IntegerArray/FloatingArray from pyarrow Array/ChunkedArray.
"""
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/dtypes/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

# define abstract base classes to enable isinstance type checking on our
# objects
def create_pandas_abc_type(name, attr, comp):
def create_pandas_abc_type(name, attr, comp) -> type:
def _check(inst) -> bool:
return getattr(inst, attr, "_typ") in comp

Expand Down
2 changes: 1 addition & 1 deletion pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1788,7 +1788,7 @@ def _choose_path(self, fast_path: Callable, slow_path: Callable, group: DataFram

return path, res

def filter(self, func, dropna: bool = True, *args, **kwargs):
def filter(self, func, dropna: bool = True, *args, **kwargs) -> DataFrame:
"""
Filter elements from groups that don't satisfy a criterion.
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1750,7 +1750,7 @@ def _get_slice_axis(self, slice_obj: slice, axis: AxisInt):
labels._validate_positional_slice(slice_obj)
return self.obj._slice(slice_obj, axis=axis)

def _convert_to_indexer(self, key, axis: AxisInt):
def _convert_to_indexer(self, key: T, axis: AxisInt) -> T:
"""
Much simpler as we only have to deal with our valid types.
"""
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/interchange/from_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def from_dataframe(df, allow_copy: bool = True) -> pd.DataFrame:
)


def _from_dataframe(df: DataFrameXchg, allow_copy: bool = True):
def _from_dataframe(df: DataFrameXchg, allow_copy: bool = True) -> pd.DataFrame:
"""
Build a ``pd.DataFrame`` from the DataFrame interchange object.
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2537,7 +2537,7 @@ def get_block_type(dtype: DtypeObj) -> type[Block]:

def new_block_2d(
values: ArrayLike, placement: BlockPlacement, refs: BlockValuesRefs | None = None
):
) -> Block:
# new_block specialized to case with
# ndim=2
# isinstance(placement, BlockPlacement)
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/internals/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Any,
Callable,
Literal,
NoReturn,
cast,
final,
)
Expand Down Expand Up @@ -2349,7 +2350,7 @@ def raise_construction_error(
block_shape: Shape,
axes: list[Index],
e: ValueError | None = None,
):
) -> NoReturn:
"""raise a helpful message about our construction"""
passed = tuple(map(int, [tot_items] + list(block_shape)))
# Correcting the user facing error message during dataframe construction
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/ops/array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def comp_method_OBJECT_ARRAY(op, x, y):
return result.reshape(x.shape)


def _masked_arith_op(x: np.ndarray, y, op):
def _masked_arith_op(x: np.ndarray, y, op) -> np.ndarray:
"""
If the given arithmetic operation fails, attempt it again on
only the non-null elements of the input array(s).
Expand Down
11 changes: 8 additions & 3 deletions pandas/core/ops/mask_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,25 @@
"""
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from pandas._libs import (
lib,
missing as libmissing,
)

if TYPE_CHECKING:
from pandas._typing import npt


def kleene_or(
left: bool | np.ndarray | libmissing.NAType,
right: bool | np.ndarray | libmissing.NAType,
left_mask: np.ndarray | None,
right_mask: np.ndarray | None,
):
) -> tuple[npt.NDArray[np.bool_], npt.NDArray[np.bool_]]:
"""
Boolean ``or`` using Kleene logic.
Expand Down Expand Up @@ -78,7 +83,7 @@ def kleene_xor(
right: bool | np.ndarray | libmissing.NAType,
left_mask: np.ndarray | None,
right_mask: np.ndarray | None,
):
) -> tuple[npt.NDArray[np.bool_], npt.NDArray[np.bool_]]:
"""
Boolean ``xor`` using Kleene logic.
Expand Down Expand Up @@ -131,7 +136,7 @@ def kleene_and(
right: bool | libmissing.NAType | np.ndarray,
left_mask: np.ndarray | None,
right_mask: np.ndarray | None,
):
) -> tuple[npt.NDArray[np.bool_], npt.NDArray[np.bool_]]:
"""
Boolean ``and`` using Kleene logic.
Expand Down
29 changes: 20 additions & 9 deletions pandas/core/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
AnyArrayLike,
Axis,
Concatenate,
FreqIndexT,
Frequency,
IndexLabel,
InterpolateOptions,
Expand Down Expand Up @@ -1690,7 +1691,7 @@ class DatetimeIndexResampler(Resampler):
ax: DatetimeIndex

@property
def _resampler_for_grouping(self):
def _resampler_for_grouping(self) -> type[DatetimeIndexResamplerGroupby]:
return DatetimeIndexResamplerGroupby

def _get_binner_for_time(self):
Expand Down Expand Up @@ -2483,17 +2484,28 @@ def _set_grouper(
return obj, ax, indexer


@overload
def _take_new_index(
obj: NDFrameT,
obj: DataFrame, indexer: npt.NDArray[np.intp], new_index: Index
) -> DataFrame:
...


@overload
def _take_new_index(
obj: Series, indexer: npt.NDArray[np.intp], new_index: Index
) -> Series:
...


def _take_new_index(
obj: DataFrame | Series,
indexer: npt.NDArray[np.intp],
new_index: Index,
) -> NDFrameT:
) -> DataFrame | Series:
if isinstance(obj, ABCSeries):
new_values = algos.take_nd(obj._values, indexer)
# error: Incompatible return value type (got "Series", expected "NDFrameT")
return obj._constructor( # type: ignore[return-value]
new_values, index=new_index, name=obj.name
)
return obj._constructor(new_values, index=new_index, name=obj.name)
elif isinstance(obj, ABCDataFrame):
new_mgr = obj._mgr.reindex_indexer(new_axis=new_index, indexer=indexer, axis=1)
return obj._constructor_from_mgr(new_mgr, axes=new_mgr.axes)
Expand Down Expand Up @@ -2788,7 +2800,7 @@ def asfreq(
return new_obj


def _asfreq_compat(index: DatetimeIndex | PeriodIndex | TimedeltaIndex, freq):
def _asfreq_compat(index: FreqIndexT, freq) -> FreqIndexT:
"""
Helper to mimic asfreq on (empty) DatetimeIndex and TimedeltaIndex.
Expand All @@ -2806,7 +2818,6 @@ def _asfreq_compat(index: DatetimeIndex | PeriodIndex | TimedeltaIndex, freq):
raise ValueError(
"Can only set arbitrary freq for empty DatetimeIndex or TimedeltaIndex"
)
new_index: Index
if isinstance(index, PeriodIndex):
new_index = index.asfreq(freq=freq)
elif isinstance(index, DatetimeIndex):
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/reshape/pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,7 +830,7 @@ def _normalize(
return table


def _get_names(arrs, names, prefix: str = "row"):
def _get_names(arrs, names, prefix: str = "row") -> list:
if names is None:
names = []
for i, arr in enumerate(arrs):
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/reshape/tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ def _format_labels(
precision: int,
right: bool = True,
include_lowest: bool = False,
):
) -> IntervalIndex:
"""based on the dtype, return our labels"""
closed: IntervalLeftRight = "right" if right else "left"

Expand Down
11 changes: 6 additions & 5 deletions pandas/core/strings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
from collections.abc import Sequence
import re

from pandas._typing import Scalar

from pandas import Series
from pandas._typing import (
Scalar,
Self,
)


class BaseStringArrayMethods(abc.ABC):
Expand Down Expand Up @@ -240,11 +241,11 @@ def _str_rstrip(self, to_strip=None):
pass

@abc.abstractmethod
def _str_removeprefix(self, prefix: str) -> Series:
def _str_removeprefix(self, prefix: str) -> Self:
pass

@abc.abstractmethod
def _str_removesuffix(self, suffix: str) -> Series:
def _str_removesuffix(self, suffix: str) -> Self:
pass

@abc.abstractmethod
Expand Down
4 changes: 3 additions & 1 deletion pandas/core/tools/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,7 +1134,9 @@ def to_datetime(
}


def _assemble_from_unit_mappings(arg, errors: DateTimeErrorChoices, utc: bool):
def _assemble_from_unit_mappings(
arg, errors: DateTimeErrorChoices, utc: bool
) -> Series:
"""
assemble the unit specified fields from the arg (DataFrame)
Return a Series for actual parsing
Expand Down

0 comments on commit 9b3c301

Please sign in to comment.