Skip to content

Commit

Permalink
TYP: tighten Axis (#48612)
Browse files Browse the repository at this point in the history
* TYP: tighten Axis

* allow 'rows'
  • Loading branch information
twoertwein authored Sep 20, 2022
1 parent abdaea7 commit ba63562
Show file tree
Hide file tree
Showing 46 changed files with 385 additions and 259 deletions.
3 changes: 2 additions & 1 deletion pandas/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@

NumpyIndexT = TypeVar("NumpyIndexT", np.ndarray, "Index")

Axis = Union[str, int]
AxisInt = int
Axis = Union[AxisInt, Literal["index", "columns", "rows"]]
IndexLabel = Union[Hashable, Sequence[Hashable]]
Level = Hashable
Shape = Tuple[int, ...]
Expand Down
7 changes: 5 additions & 2 deletions pandas/compat/numpy/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
is_bool,
is_integer,
)
from pandas._typing import Axis
from pandas._typing import (
Axis,
AxisInt,
)
from pandas.errors import UnsupportedFunctionCall
from pandas.util._validators import (
validate_args,
Expand Down Expand Up @@ -413,7 +416,7 @@ def validate_resampler_func(method: str, args, kwargs) -> None:
raise TypeError("too many arguments passed in")


def validate_minmax_axis(axis: int | None, ndim: int = 1) -> None:
def validate_minmax_axis(axis: AxisInt | None, ndim: int = 1) -> None:
"""
Ensure that the axis argument passed to min, max, argmin, or argmax is zero
or None, as otherwise it will be incorrectly ignored.
Expand Down
7 changes: 4 additions & 3 deletions pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pandas._typing import (
AnyArrayLike,
ArrayLike,
AxisInt,
DtypeObj,
IndexLabel,
TakeIndexer,
Expand Down Expand Up @@ -1105,7 +1106,7 @@ def mode(

def rank(
values: ArrayLike,
axis: int = 0,
axis: AxisInt = 0,
method: str = "average",
na_option: str = "keep",
ascending: bool = True,
Expand Down Expand Up @@ -1483,7 +1484,7 @@ def get_indexer(current_indexer, other_indexer):
def take(
arr,
indices: TakeIndexer,
axis: int = 0,
axis: AxisInt = 0,
allow_fill: bool = False,
fill_value=None,
):
Expand Down Expand Up @@ -1675,7 +1676,7 @@ def searchsorted(
_diff_special = {"float64", "float32", "int64", "int32", "int16", "int8"}


def diff(arr, n: int, axis: int = 0):
def diff(arr, n: int, axis: AxisInt = 0):
"""
difference of n between self,
analogous to s-s.shift(n)
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
AggFuncTypeDict,
AggObjType,
Axis,
AxisInt,
NDFrameT,
npt,
)
Expand Down Expand Up @@ -104,7 +105,7 @@ def frame_apply(


class Apply(metaclass=abc.ABCMeta):
axis: int
axis: AxisInt

def __init__(
self,
Expand Down
21 changes: 12 additions & 9 deletions pandas/core/array_algos/masked_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
import numpy as np

from pandas._libs import missing as libmissing
from pandas._typing import npt
from pandas._typing import (
AxisInt,
npt,
)

from pandas.core.nanops import check_below_min_count

Expand All @@ -21,7 +24,7 @@ def _reductions(
*,
skipna: bool = True,
min_count: int = 0,
axis: int | None = None,
axis: AxisInt | None = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -62,7 +65,7 @@ def sum(
*,
skipna: bool = True,
min_count: int = 0,
axis: int | None = None,
axis: AxisInt | None = None,
):
return _reductions(
np.sum, values=values, mask=mask, skipna=skipna, min_count=min_count, axis=axis
Expand All @@ -75,7 +78,7 @@ def prod(
*,
skipna: bool = True,
min_count: int = 0,
axis: int | None = None,
axis: AxisInt | None = None,
):
return _reductions(
np.prod, values=values, mask=mask, skipna=skipna, min_count=min_count, axis=axis
Expand All @@ -88,7 +91,7 @@ def _minmax(
mask: npt.NDArray[np.bool_],
*,
skipna: bool = True,
axis: int | None = None,
axis: AxisInt | None = None,
):
"""
Reduction for 1D masked array.
Expand Down Expand Up @@ -125,7 +128,7 @@ def min(
mask: npt.NDArray[np.bool_],
*,
skipna: bool = True,
axis: int | None = None,
axis: AxisInt | None = None,
):
return _minmax(np.min, values=values, mask=mask, skipna=skipna, axis=axis)

Expand All @@ -135,7 +138,7 @@ def max(
mask: npt.NDArray[np.bool_],
*,
skipna: bool = True,
axis: int | None = None,
axis: AxisInt | None = None,
):
return _minmax(np.max, values=values, mask=mask, skipna=skipna, axis=axis)

Expand All @@ -145,7 +148,7 @@ def mean(
mask: npt.NDArray[np.bool_],
*,
skipna: bool = True,
axis: int | None = None,
axis: AxisInt | None = None,
):
if not values.size or mask.all():
return libmissing.NA
Expand All @@ -157,7 +160,7 @@ def var(
mask: npt.NDArray[np.bool_],
*,
skipna: bool = True,
axis: int | None = None,
axis: AxisInt | None = None,
ddof: int = 1,
):
if not values.size or mask.all():
Expand Down
19 changes: 12 additions & 7 deletions pandas/core/array_algos/take.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from pandas._typing import (
ArrayLike,
AxisInt,
npt,
)

Expand All @@ -36,7 +37,7 @@
def take_nd(
arr: np.ndarray,
indexer,
axis: int = ...,
axis: AxisInt = ...,
fill_value=...,
allow_fill: bool = ...,
) -> np.ndarray:
Expand All @@ -47,7 +48,7 @@ def take_nd(
def take_nd(
arr: ExtensionArray,
indexer,
axis: int = ...,
axis: AxisInt = ...,
fill_value=...,
allow_fill: bool = ...,
) -> ArrayLike:
Expand All @@ -57,7 +58,7 @@ def take_nd(
def take_nd(
arr: ArrayLike,
indexer,
axis: int = 0,
axis: AxisInt = 0,
fill_value=lib.no_default,
allow_fill: bool = True,
) -> ArrayLike:
Expand Down Expand Up @@ -120,7 +121,7 @@ def take_nd(
def _take_nd_ndarray(
arr: np.ndarray,
indexer: npt.NDArray[np.intp] | None,
axis: int,
axis: AxisInt,
fill_value,
allow_fill: bool,
) -> np.ndarray:
Expand Down Expand Up @@ -287,7 +288,7 @@ def take_2d_multi(

@functools.lru_cache(maxsize=128)
def _get_take_nd_function_cached(
ndim: int, arr_dtype: np.dtype, out_dtype: np.dtype, axis: int
ndim: int, arr_dtype: np.dtype, out_dtype: np.dtype, axis: AxisInt
):
"""
Part of _get_take_nd_function below that doesn't need `mask_info` and thus
Expand Down Expand Up @@ -324,7 +325,11 @@ def _get_take_nd_function_cached(


def _get_take_nd_function(
ndim: int, arr_dtype: np.dtype, out_dtype: np.dtype, axis: int = 0, mask_info=None
ndim: int,
arr_dtype: np.dtype,
out_dtype: np.dtype,
axis: AxisInt = 0,
mask_info=None,
):
"""
Get the appropriate "take" implementation for the given dimension, axis
Expand Down Expand Up @@ -503,7 +508,7 @@ def _take_nd_object(
arr: np.ndarray,
indexer: npt.NDArray[np.intp],
out: np.ndarray,
axis: int,
axis: AxisInt,
fill_value,
mask_info,
) -> None:
Expand Down
4 changes: 3 additions & 1 deletion pandas/core/array_algos/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

import numpy as np

from pandas._typing import AxisInt

def shift(values: np.ndarray, periods: int, axis: int, fill_value) -> np.ndarray:

def shift(values: np.ndarray, periods: int, axis: AxisInt, fill_value) -> np.ndarray:
new_values = values

if periods == 0 or values.size == 0:
Expand Down
11 changes: 6 additions & 5 deletions pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pandas._libs.arrays import NDArrayBacked
from pandas._typing import (
ArrayLike,
AxisInt,
Dtype,
F,
PositionalIndexer2D,
Expand Down Expand Up @@ -157,7 +158,7 @@ def take(
*,
allow_fill: bool = False,
fill_value: Any = None,
axis: int = 0,
axis: AxisInt = 0,
) -> NDArrayBackedExtensionArrayT:
if allow_fill:
fill_value = self._validate_scalar(fill_value)
Expand Down Expand Up @@ -192,15 +193,15 @@ def _values_for_factorize(self):
return self._ndarray, self._internal_fill_value

# Signature of "argmin" incompatible with supertype "ExtensionArray"
def argmin(self, axis: int = 0, skipna: bool = True): # type: ignore[override]
def argmin(self, axis: AxisInt = 0, skipna: bool = True): # type: ignore[override]
# override base class by adding axis keyword
validate_bool_kwarg(skipna, "skipna")
if not skipna and self._hasna:
raise NotImplementedError
return nargminmax(self, "argmin", axis=axis)

# Signature of "argmax" incompatible with supertype "ExtensionArray"
def argmax(self, axis: int = 0, skipna: bool = True): # type: ignore[override]
def argmax(self, axis: AxisInt = 0, skipna: bool = True): # type: ignore[override]
# override base class by adding axis keyword
validate_bool_kwarg(skipna, "skipna")
if not skipna and self._hasna:
Expand All @@ -216,7 +217,7 @@ def unique(self: NDArrayBackedExtensionArrayT) -> NDArrayBackedExtensionArrayT:
def _concat_same_type(
cls: type[NDArrayBackedExtensionArrayT],
to_concat: Sequence[NDArrayBackedExtensionArrayT],
axis: int = 0,
axis: AxisInt = 0,
) -> NDArrayBackedExtensionArrayT:
dtypes = {str(x.dtype) for x in to_concat}
if len(dtypes) != 1:
Expand Down Expand Up @@ -351,7 +352,7 @@ def fillna(
# ------------------------------------------------------------------------
# Reductions

def _wrap_reduction_result(self, axis: int | None, result):
def _wrap_reduction_result(self, axis: AxisInt | None, result):
if axis is None or self.ndim == 1:
return self._box_func(result)
return self._from_backing_data(result)
Expand Down
5 changes: 3 additions & 2 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from pandas._typing import (
ArrayLike,
AstypeArg,
AxisInt,
Dtype,
FillnaOptions,
PositionalIndexer,
Expand Down Expand Up @@ -1137,7 +1138,7 @@ def factorize(
@Substitution(klass="ExtensionArray")
@Appender(_extension_array_shared_docs["repeat"])
def repeat(
self: ExtensionArrayT, repeats: int | Sequence[int], axis: int | None = None
self: ExtensionArrayT, repeats: int | Sequence[int], axis: AxisInt | None = None
) -> ExtensionArrayT:
nv.validate_repeat((), {"axis": axis})
ind = np.arange(len(self)).repeat(repeats)
Expand Down Expand Up @@ -1567,7 +1568,7 @@ def _fill_mask_inplace(
def _rank(
self,
*,
axis: int = 0,
axis: AxisInt = 0,
method: str = "average",
na_option: str = "keep",
ascending: bool = True,
Expand Down
5 changes: 3 additions & 2 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from pandas._typing import (
ArrayLike,
AstypeArg,
AxisInt,
Dtype,
NpDtype,
Ordered,
Expand Down Expand Up @@ -1988,7 +1989,7 @@ def sort_values(
def _rank(
self,
*,
axis: int = 0,
axis: AxisInt = 0,
method: str = "average",
na_option: str = "keep",
ascending: bool = True,
Expand Down Expand Up @@ -2464,7 +2465,7 @@ def equals(self, other: object) -> bool:

@classmethod
def _concat_same_type(
cls: type[CategoricalT], to_concat: Sequence[CategoricalT], axis: int = 0
cls: type[CategoricalT], to_concat: Sequence[CategoricalT], axis: AxisInt = 0
) -> CategoricalT:
from pandas.core.dtypes.concat import union_categoricals

Expand Down
Loading

0 comments on commit ba63562

Please sign in to comment.