diff --git a/src/numpy-stubs/lib/_index_tricks_impl.pyi b/src/numpy-stubs/lib/_index_tricks_impl.pyi index 16e52e3b..a757bd26 100644 --- a/src/numpy-stubs/lib/_index_tricks_impl.pyi +++ b/src/numpy-stubs/lib/_index_tricks_impl.pyi @@ -1,12 +1,13 @@ from _typeshed import Incomplete from collections.abc import Sequence -from typing import Any, Final, Generic, Literal, SupportsIndex, overload +from typing import Any, Final, Generic, Literal as L, SupportsIndex, final, overload from typing_extensions import TypeVar import numpy as np +from _numtype import Array, Is, Matrix, Sequence_nd, _ToArray1_nd from numpy import ndenumerate, ndindex # noqa: ICN003 from numpy._core.multiarray import ravel_multi_index, unravel_index -from numpy._typing import ArrayLike, DTypeLike, NDArray, _FiniteNestedSequence, _NestedSequence, _Shape, _SupportsDType +from numpy._typing import ArrayLike, DTypeLike, _DTypeLike, _SupportsDType as _HasDType __all__ = [ "c_", @@ -27,90 +28,133 @@ __all__ = [ _T = TypeVar("_T") _DTypeT = TypeVar("_DTypeT", bound=np.dtype[Any]) -_BoolT_co = TypeVar("_BoolT_co", bound=bool, default=bool, covariant=True) +_ScalarT = TypeVar("_ScalarT", bound=np.generic) _TupleT = TypeVar("_TupleT", bound=tuple[object, ...]) -_ArrayT = TypeVar("_ArrayT", bound=NDArray[Any]) +_ArrayT = TypeVar("_ArrayT", bound=Array) -@overload -def ix_(*args: _FiniteNestedSequence[_SupportsDType[_DTypeT]]) -> tuple[np.ndarray[_Shape, _DTypeT], ...]: ... -@overload -def ix_(*args: str | _NestedSequence[str]) -> tuple[NDArray[np.str_], ...]: ... -@overload -def ix_(*args: bytes | _NestedSequence[bytes]) -> tuple[NDArray[np.bytes_], ...]: ... -@overload -def ix_(*args: bool | _NestedSequence[bool]) -> tuple[NDArray[np.bool], ...]: ... -@overload -def ix_(*args: int | _NestedSequence[int]) -> tuple[NDArray[np.int_], ...]: ... -@overload -def ix_(*args: float | _NestedSequence[float]) -> tuple[NDArray[np.float64], ...]: ... -@overload -def ix_(*args: complex | _NestedSequence[complex]) -> tuple[NDArray[np.complex128], ...]: ... +_BoolT_co = TypeVar("_BoolT_co", bound=bool, default=bool, covariant=True) + +_AxisT_co = TypeVar("_AxisT_co", bound=int, default=L[0], covariant=True) +_MatrixT_co = TypeVar("_MatrixT_co", bound=bool, default=L[False], covariant=True) +_NDMinT_co = TypeVar("_NDMinT_co", bound=int, default=L[1], covariant=True) +_Trans1DT_co = TypeVar("_Trans1DT_co", bound=int, default=L[-1], covariant=True) + +### class nd_grid(Generic[_BoolT_co]): sparse: _BoolT_co def __init__(self, sparse: _BoolT_co = ...) -> None: ... @overload - def __getitem__(self: nd_grid[Literal[False]], key: slice | Sequence[slice]) -> NDArray[Any]: ... + def __getitem__(self: nd_grid[L[False]], key: slice | Sequence[slice]) -> Array: ... @overload - def __getitem__(self: nd_grid[Literal[True]], key: slice | Sequence[slice]) -> tuple[NDArray[Any], ...]: ... + def __getitem__(self: nd_grid[L[True]], key: slice | Sequence[slice]) -> tuple[Array, ...]: ... -class MGridClass(nd_grid[Literal[False]]): +@final +class MGridClass(nd_grid[L[False]]): def __init__(self) -> None: ... -mgrid: Final[MGridClass] = ... - -class OGridClass(nd_grid[Literal[True]]): +@final +class OGridClass(nd_grid[L[True]]): def __init__(self) -> None: ... -ogrid: Final[OGridClass] = ... +class AxisConcatenator(Generic[_AxisT_co, _MatrixT_co, _NDMinT_co, _Trans1DT_co]): + __slots__ = "axis", "matrix", "ndmin", "trans1d" + + axis: _AxisT_co + matrix: _MatrixT_co + trans1d: _Trans1DT_co + ndmin: _NDMinT_co + + # + def __init__( + self, + /, + axis: _AxisT_co = ..., + matrix: _MatrixT_co = ..., + ndmin: _NDMinT_co = ..., + trans1d: _Trans1DT_co = ..., + ) -> None: ... + + # TODO(jorenham): annotate this + def __getitem__(self, key: Incomplete, /) -> Incomplete: ... + def __len__(self, /) -> L[0]: ... -class AxisConcatenator: - axis: int - matrix: bool - ndmin: int - trans1d: int - def __init__(self, axis: int = ..., matrix: bool = ..., ndmin: int = ..., trans1d: int = ...) -> None: ... + # @staticmethod @overload - def concatenate(*a: ArrayLike, axis: SupportsIndex = ..., out: None = ...) -> NDArray[Any]: ... + def concatenate(*a: _ToArray1_nd[_ScalarT], axis: SupportsIndex | None = 0, out: None = None) -> Array[_ScalarT]: ... @staticmethod @overload - def concatenate(*a: ArrayLike, axis: SupportsIndex = ..., out: _ArrayT) -> _ArrayT: ... + def concatenate(*a: ArrayLike, axis: SupportsIndex | None = 0, out: _ArrayT) -> _ArrayT: ... @staticmethod - def makemat(data: ArrayLike, dtype: DTypeLike = ..., copy: bool = ...) -> np.matrix[Any, Any]: ... - def __getitem__(self, key: Incomplete, /) -> Incomplete: ... - -class RClass(AxisConcatenator): - axis: Literal[0] - matrix: Literal[False] - ndmin: Literal[1] - trans1d: Literal[-1] - def __init__(self) -> None: ... + @overload + def concatenate(*a: ArrayLike, axis: SupportsIndex | None = 0, out: None = None) -> Array: ... -r_: Final[RClass] = ... + # + @staticmethod + @overload + def makemat(data: _ToArray1_nd[_ScalarT], dtype: None = None, copy: bool = True) -> Matrix[_ScalarT]: ... + @staticmethod + @overload + def makemat(data: ArrayLike, dtype: _DTypeLike[_ScalarT], copy: bool = True) -> Matrix[_ScalarT]: ... + @staticmethod + @overload + def makemat(data: ArrayLike, dtype: DTypeLike | None = None, copy: bool = True) -> Matrix: ... -class CClass(AxisConcatenator): - axis: Literal[-1] - matrix: Literal[False] - ndmin: Literal[2] - trans1d: Literal[0] - def __init__(self) -> None: ... +@final +class RClass(AxisConcatenator[L[0], L[False], L[1], L[-1]]): + def __init__(self, /) -> None: ... -c_: Final[CClass] = ... +@final +class CClass(AxisConcatenator[L[-1], L[False], L[2], L[0]]): + def __init__(self, /) -> None: ... class IndexExpression(Generic[_BoolT_co]): maketuple: _BoolT_co - def __init__(self, maketuple: _BoolT_co) -> None: ... + def __init__(self, /, maketuple: _BoolT_co) -> None: ... + # + @overload + def __getitem__(self, item: _TupleT, /) -> _TupleT: ... # type: ignore[overload-overlap] @overload - def __getitem__(self, item: _TupleT) -> _TupleT: ... # type: ignore[misc] + def __getitem__(self: IndexExpression[L[False]], item: _T, /) -> _T: ... @overload - def __getitem__(self: IndexExpression[Literal[True]], item: _T) -> tuple[_T]: ... + def __getitem__(self: IndexExpression[L[True]], item: _T, /) -> tuple[_T]: ... @overload - def __getitem__(self: IndexExpression[Literal[False]], item: _T) -> _T: ... + def __getitem__(self, item: _T, /) -> _T | tuple[_T]: ... -index_exp: Final[IndexExpression[Literal[True]]] = ... -s_: Final[IndexExpression[Literal[False]]] = ... +@overload +def ix_(*args: _HasDType[_DTypeT] | Sequence_nd[_HasDType[_DTypeT]]) -> tuple[np.ndarray[tuple[int, ...], _DTypeT], ...]: ... +@overload +def ix_(*args: str | Sequence_nd[str]) -> tuple[Array[np.str_], ...]: ... +@overload +def ix_(*args: bytes | Sequence_nd[bytes]) -> tuple[Array[np.bytes_], ...]: ... +@overload +def ix_(*args: bool | Sequence_nd[bool]) -> tuple[Array[np.bool], ...]: ... +@overload +def ix_(*args: Is[int] | Sequence_nd[Is[int]]) -> tuple[Array[np.intp], ...]: ... +@overload +def ix_(*args: Is[float] | Sequence_nd[Is[float]]) -> tuple[Array[np.float64], ...]: ... +@overload +def ix_(*args: Is[complex] | Sequence_nd[Is[complex]]) -> tuple[Array[np.complex128], ...]: ... +@overload +def ix_(*args: int | Sequence_nd[int]) -> tuple[Array[np.intp | np.bool], ...]: ... # type: ignore[overload-cannot-match] # pyright: ignore[reportOverlappingOverload] +@overload +def ix_(*args: float | Sequence_nd[float]) -> tuple[Array[np.float64 | np.intp | np.bool], ...]: ... # type: ignore[overload-cannot-match] +@overload +def ix_(*args: complex | Sequence_nd[complex]) -> tuple[Array[np.complex128 | np.float64 | np.intp | np.bool], ...]: ... # type: ignore[overload-cannot-match] + +# +def fill_diagonal(a: Array, val: Any, wrap: bool = ...) -> None: ... +def diag_indices(n: int, ndim: int = ...) -> tuple[Array[np.int_], ...]: ... +def diag_indices_from(arr: ArrayLike) -> tuple[Array[np.int_], ...]: ... + +### + +mgrid: Final[MGridClass] = ... +ogrid: Final[OGridClass] = ... + +r_: Final[RClass] = ... +c_: Final[CClass] = ... -def fill_diagonal(a: NDArray[Any], val: Any, wrap: bool = ...) -> None: ... -def diag_indices(n: int, ndim: int = ...) -> tuple[NDArray[np.int_], ...]: ... -def diag_indices_from(arr: ArrayLike) -> tuple[NDArray[np.int_], ...]: ... +index_exp: Final[IndexExpression[L[True]]] = ... +s_: Final[IndexExpression[L[False]]] = ... diff --git a/test/static/accept/index_tricks.pyi b/test/static/accept/index_tricks.pyi index 681a386a..27668dcd 100644 --- a/test/static/accept/index_tricks.pyi +++ b/test/static/accept/index_tricks.pyi @@ -68,7 +68,7 @@ assert_type(np.s_[i:i, None:i], tuple[slice[int, int, None], slice[None, int, No assert_type(np.s_[i, i:i, ..., [i, i, i]], tuple[int, slice[int, int, None], EllipsisType, list[int]]) assert_type(np.ix_(AR_LIKE_b), tuple[npt.NDArray[np.bool], ...]) -assert_type(np.ix_(AR_LIKE_i, AR_LIKE_f), tuple[npt.NDArray[np.float64], ...]) +assert_type(np.ix_(AR_LIKE_i, AR_LIKE_f), tuple[npt.NDArray[np.float64 | np.intp | np.bool], ...]) assert_type(np.ix_(AR_i8), tuple[npt.NDArray[np.int64], ...]) assert_type(np.fill_diagonal(AR_i8, 5), None)