From b9b020660ea7e29e4c0fbcc3668cfda4b1691a1b Mon Sep 17 00:00:00 2001
From: jorenham <jhammudoglu@gmail.com>
Date: Sat, 22 Mar 2025 20:27:40 +0100
Subject: [PATCH 01/24] TYP: annotate `_internal.get_xp` (and curse at
 `ParamSpec` for being so useless)

---
 array_api_compat/_internal.py | 23 +++++++++++++++++------
 1 file changed, 17 insertions(+), 6 deletions(-)

diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py
index 170a1ff9..7a973567 100644
--- a/array_api_compat/_internal.py
+++ b/array_api_compat/_internal.py
@@ -4,8 +4,19 @@
 
 from functools import wraps
 from inspect import signature
+from typing import TYPE_CHECKING
 
-def get_xp(xp):
+__all__ = ["get_xp"]
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+    from types import ModuleType
+    from typing import TypeVar
+
+    _T = TypeVar("_T")
+
+
+def get_xp(xp: "ModuleType") -> "Callable[[Callable[..., _T]], Callable[..., _T]]":
     """
     Decorator to automatically replace xp with the corresponding array module.
 
@@ -22,14 +33,14 @@ def func(x, /, xp, kwarg=None):
 
     """
 
-    def inner(f):
+    def inner(f: "Callable[..., _T]", /) -> "Callable[..., _T]":
         @wraps(f)
-        def wrapped_f(*args, **kwargs):
+        def wrapped_f(*args: object, **kwargs: object) -> object:
             return f(*args, xp=xp, **kwargs)
 
         sig = signature(f)
         new_sig = sig.replace(
-            parameters=[sig.parameters[i] for i in sig.parameters if i != "xp"]
+            parameters=[par for i, par in sig.parameters.items() if i != "xp"]
         )
 
         if wrapped_f.__doc__ is None:
@@ -40,7 +51,7 @@ def wrapped_f(*args, **kwargs):
 specification for more details.
 
 """
-        wrapped_f.__signature__ = new_sig
-        return wrapped_f
+        wrapped_f.__signature__ = new_sig  # pyright: ignore[reportAttributeAccessIssue]
+        return wrapped_f  # pyright: ignore[reportReturnType]
 
     return inner

From 6a17007a32cf38f9328453b27ff6a59eeca85f53 Mon Sep 17 00:00:00 2001
From: jorenham <jhammudoglu@gmail.com>
Date: Sat, 22 Mar 2025 21:56:31 +0100
Subject: [PATCH 02/24] TYP: fix (or ignore) typing errors in `common._helpers`
 (and curse at cupy)

---
 array_api_compat/common/_helpers.py | 282 ++++++++++++++++++----------
 array_api_compat/common/_typing.py  |  14 +-
 2 files changed, 198 insertions(+), 98 deletions(-)

diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py
index 67c619b8..9f016f05 100644
--- a/array_api_compat/common/_helpers.py
+++ b/array_api_compat/common/_helpers.py
@@ -5,33 +5,75 @@
 that are in __all__ are intended as additional helper functions for use by end
 users of the compat library.
 """
+
 from __future__ import annotations
 
-import sys
-import math
 import inspect
+import math
+import sys
 import warnings
-from typing import Optional, Union, Any
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Literal,
+    SupportsIndex,
+    TypeAlias,
+    TypeGuard,
+    cast,
+    overload,
+)
+
+from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace
+
+if TYPE_CHECKING:
+    from collections.abc import Collection
+
+    import dask.array as da
+    import jax
+    import ndonnx as ndx
+    import numpy as np
+    import numpy.typing as npt
+    import sparse  # pyright: ignore[reportMissingTypeStubs]
+    import torch
+    from typing_extensions import TypeIs, TypeVar
 
-from ._typing import Array, Device, Namespace
+    _SizeT = TypeVar("_SizeT", bound=int | None)
 
+    _ZeroGradientArray: TypeAlias = npt.NDArray[np.void]
+    _CupyArray: TypeAlias = Any  # cupy has no py.typed
 
-def _is_jax_zero_gradient_array(x: object) -> bool:
+    _ArrayApiObj: TypeAlias = (
+        npt.NDArray[Any]
+        | da.Array
+        | jax.Array
+        | ndx.Array
+        | sparse.SparseArray
+        | torch.Tensor
+        | SupportsArrayNamespace[Any]
+        | _CupyArray
+    )
+
+
+def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
     """Return True if `x` is a zero-gradient array.
 
     These arrays are a design quirk of Jax that may one day be removed.
     See https://github.com/google/jax/issues/20620.
     """
-    if 'numpy' not in sys.modules or 'jax' not in sys.modules:
+    if "numpy" not in sys.modules or "jax" not in sys.modules:
         return False
 
-    import numpy as np
     import jax
+    import numpy as np
 
-    return isinstance(x, np.ndarray) and x.dtype == jax.float0
+    jax_float0 = cast("np.dtype[np.void]", jax.float0)
+    return (
+        isinstance(x, np.ndarray)
+        and cast("npt.NDArray[np.void]", x).dtype == jax_float0
+    )
 
 
-def is_numpy_array(x: object) -> bool:
+def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]:
     """
     Return True if `x` is a NumPy array.
 
@@ -53,14 +95,14 @@ def is_numpy_array(x: object) -> bool:
     is_pydata_sparse_array
     """
     # Avoid importing NumPy if it isn't already
-    if 'numpy' not in sys.modules:
+    if "numpy" not in sys.modules:
         return False
 
     import numpy as np
 
     # TODO: Should we reject ndarray subclasses?
     return (isinstance(x, (np.ndarray, np.generic))
-            and not _is_jax_zero_gradient_array(x))
+            and not _is_jax_zero_gradient_array(x))  # pyright: ignore[reportUnknownArgumentType]  # fmt: skip
 
 
 def is_cupy_array(x: object) -> bool:
@@ -85,16 +127,16 @@ def is_cupy_array(x: object) -> bool:
     is_pydata_sparse_array
     """
     # Avoid importing CuPy if it isn't already
-    if 'cupy' not in sys.modules:
+    if "cupy" not in sys.modules:
         return False
 
-    import cupy as cp
+    import cupy as cp  # pyright: ignore[reportMissingTypeStubs]
 
     # TODO: Should we reject ndarray subclasses?
-    return isinstance(x, cp.ndarray)
+    return isinstance(x, cp.ndarray)  # pyright: ignore[reportUnknownMemberType]
 
 
-def is_torch_array(x: object) -> bool:
+def is_torch_array(x: object) -> TypeIs[torch.Tensor]:
     """
     Return True if `x` is a PyTorch tensor.
 
@@ -113,7 +155,7 @@ def is_torch_array(x: object) -> bool:
     is_pydata_sparse_array
     """
     # Avoid importing torch if it isn't already
-    if 'torch' not in sys.modules:
+    if "torch" not in sys.modules:
         return False
 
     import torch
@@ -122,7 +164,7 @@ def is_torch_array(x: object) -> bool:
     return isinstance(x, torch.Tensor)
 
 
-def is_ndonnx_array(x: object) -> bool:
+def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]:
     """
     Return True if `x` is a ndonnx Array.
 
@@ -142,7 +184,7 @@ def is_ndonnx_array(x: object) -> bool:
     is_pydata_sparse_array
     """
     # Avoid importing torch if it isn't already
-    if 'ndonnx' not in sys.modules:
+    if "ndonnx" not in sys.modules:
         return False
 
     import ndonnx as ndx
@@ -150,7 +192,7 @@ def is_ndonnx_array(x: object) -> bool:
     return isinstance(x, ndx.Array)
 
 
-def is_dask_array(x: object) -> bool:
+def is_dask_array(x: object) -> TypeIs[da.Array]:
     """
     Return True if `x` is a dask.array Array.
 
@@ -170,7 +212,7 @@ def is_dask_array(x: object) -> bool:
     is_pydata_sparse_array
     """
     # Avoid importing dask if it isn't already
-    if 'dask.array' not in sys.modules:
+    if "dask.array" not in sys.modules:
         return False
 
     import dask.array
@@ -178,7 +220,7 @@ def is_dask_array(x: object) -> bool:
     return isinstance(x, dask.array.Array)
 
 
-def is_jax_array(x: object) -> bool:
+def is_jax_array(x: object) -> TypeIs[jax.Array]:
     """
     Return True if `x` is a JAX array.
 
@@ -199,7 +241,7 @@ def is_jax_array(x: object) -> bool:
     is_pydata_sparse_array
     """
     # Avoid importing jax if it isn't already
-    if 'jax' not in sys.modules:
+    if "jax" not in sys.modules:
         return False
 
     import jax
@@ -207,7 +249,7 @@ def is_jax_array(x: object) -> bool:
     return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
 
 
-def is_pydata_sparse_array(x) -> bool:
+def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
     """
     Return True if `x` is an array from the `sparse` package.
 
@@ -228,16 +270,16 @@ def is_pydata_sparse_array(x) -> bool:
     is_jax_array
     """
     # Avoid importing jax if it isn't already
-    if 'sparse' not in sys.modules:
+    if "sparse" not in sys.modules:
         return False
 
-    import sparse
+    import sparse  # pyright: ignore[reportMissingTypeStubs]
 
     # TODO: Account for other backends.
     return isinstance(x, sparse.SparseArray)
 
 
-def is_array_api_obj(x: object) -> bool:
+def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]:  # pyright: ignore[reportUnknownParameterType]
     """
     Return True if `x` is an array API compatible array object.
 
@@ -252,18 +294,20 @@ def is_array_api_obj(x: object) -> bool:
     is_dask_array
     is_jax_array
     """
-    return is_numpy_array(x) \
-        or is_cupy_array(x) \
-        or is_torch_array(x) \
-        or is_dask_array(x) \
-        or is_jax_array(x) \
-        or is_pydata_sparse_array(x) \
-        or hasattr(x, '__array_namespace__')
+    return (
+        is_numpy_array(x)
+        or is_cupy_array(x)
+        or is_torch_array(x)
+        or is_dask_array(x)
+        or is_jax_array(x)
+        or is_pydata_sparse_array(x)
+        or hasattr(x, "__array_namespace__")
+    )
 
 
 def _compat_module_name() -> str:
-    assert __name__.endswith('.common._helpers')
-    return __name__.removesuffix('.common._helpers')
+    assert __name__.endswith(".common._helpers")
+    return __name__.removesuffix(".common._helpers")
 
 
 def is_numpy_namespace(xp: Namespace) -> bool:
@@ -284,7 +328,7 @@ def is_numpy_namespace(xp: Namespace) -> bool:
     is_pydata_sparse_namespace
     is_array_api_strict_namespace
     """
-    return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}
+    return xp.__name__ in {"numpy", _compat_module_name() + ".numpy"}
 
 
 def is_cupy_namespace(xp: Namespace) -> bool:
@@ -305,7 +349,7 @@ def is_cupy_namespace(xp: Namespace) -> bool:
     is_pydata_sparse_namespace
     is_array_api_strict_namespace
     """
-    return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}
+    return xp.__name__ in {"cupy", _compat_module_name() + ".cupy"}
 
 
 def is_torch_namespace(xp: Namespace) -> bool:
@@ -326,7 +370,7 @@ def is_torch_namespace(xp: Namespace) -> bool:
     is_pydata_sparse_namespace
     is_array_api_strict_namespace
     """
-    return xp.__name__ in {'torch', _compat_module_name() + '.torch'}
+    return xp.__name__ in {"torch", _compat_module_name() + ".torch"}
 
 
 def is_ndonnx_namespace(xp: Namespace) -> bool:
@@ -345,7 +389,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool:
     is_pydata_sparse_namespace
     is_array_api_strict_namespace
     """
-    return xp.__name__ == 'ndonnx'
+    return xp.__name__ == "ndonnx"
 
 
 def is_dask_namespace(xp: Namespace) -> bool:
@@ -366,7 +410,7 @@ def is_dask_namespace(xp: Namespace) -> bool:
     is_pydata_sparse_namespace
     is_array_api_strict_namespace
     """
-    return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}
+    return xp.__name__ in {"dask.array", _compat_module_name() + ".dask.array"}
 
 
 def is_jax_namespace(xp: Namespace) -> bool:
@@ -388,7 +432,7 @@ def is_jax_namespace(xp: Namespace) -> bool:
     is_pydata_sparse_namespace
     is_array_api_strict_namespace
     """
-    return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}
+    return xp.__name__ in {"jax.numpy", "jax.experimental.array_api"}
 
 
 def is_pydata_sparse_namespace(xp: Namespace) -> bool:
@@ -407,7 +451,7 @@ def is_pydata_sparse_namespace(xp: Namespace) -> bool:
     is_jax_namespace
     is_array_api_strict_namespace
     """
-    return xp.__name__ == 'sparse'
+    return xp.__name__ == "sparse"
 
 
 def is_array_api_strict_namespace(xp: Namespace) -> bool:
@@ -426,21 +470,29 @@ def is_array_api_strict_namespace(xp: Namespace) -> bool:
     is_jax_namespace
     is_pydata_sparse_namespace
     """
-    return xp.__name__ == 'array_api_strict'
+    return xp.__name__ == "array_api_strict"
 
 
-def _check_api_version(api_version: str) -> None:
-    if api_version in ['2021.12', '2022.12', '2023.12']:
-        warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12")
-    elif api_version is not None and api_version not in ['2021.12', '2022.12',
-                                                         '2023.12', '2024.12']:
-        raise ValueError("Only the 2024.12 version of the array API specification is currently supported")
+def _check_api_version(api_version: str | None) -> None:
+    if api_version in ["2021.12", "2022.12", "2023.12"]:
+        warnings.warn(
+            f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12"
+        )
+    elif api_version is not None and api_version not in [
+        "2021.12",
+        "2022.12",
+        "2023.12",
+        "2024.12",
+    ]:
+        raise ValueError(
+            "Only the 2024.12 version of the array API specification is currently supported"
+        )
 
 
 def array_namespace(
-    *xs: Union[Array, bool, int, float, complex, None],
-    api_version: Optional[str] = None,
-    use_compat: Optional[bool] = None,
+    *xs: Array | complex | None,
+    api_version: str | None = None,
+    use_compat: bool | None = None,
 ) -> Namespace:
     """
     Get the array API compatible namespace for the arrays `xs`.
@@ -510,11 +562,13 @@ def your_function(x, y):
 
     _use_compat = use_compat in [None, True]
 
-    namespaces = set()
+    namespaces: set[Namespace] = set()
     for x in xs:
         if is_numpy_array(x):
-            from .. import numpy as numpy_namespace
             import numpy as np
+
+            from .. import numpy as numpy_namespace
+
             if use_compat is True:
                 _check_api_version(api_version)
                 namespaces.add(numpy_namespace)
@@ -528,25 +582,31 @@ def your_function(x, y):
             if _use_compat:
                 _check_api_version(api_version)
                 from .. import cupy as cupy_namespace
+
                 namespaces.add(cupy_namespace)
             else:
-                import cupy as cp
+                import cupy as cp  # pyright: ignore[reportMissingTypeStubs]
+
                 namespaces.add(cp)
         elif is_torch_array(x):
             if _use_compat:
                 _check_api_version(api_version)
                 from .. import torch as torch_namespace
+
                 namespaces.add(torch_namespace)
             else:
                 import torch
+
                 namespaces.add(torch)
         elif is_dask_array(x):
             if _use_compat:
                 _check_api_version(api_version)
                 from ..dask import array as dask_namespace
+
                 namespaces.add(dask_namespace)
             else:
                 import dask.array as da
+
                 namespaces.add(da)
         elif is_jax_array(x):
             if use_compat is True:
@@ -558,23 +618,27 @@ def your_function(x, y):
                 # JAX v0.4.32 and newer implements the array API directly in jax.numpy.
                 # For older JAX versions, it is available via jax.experimental.array_api.
                 import jax.numpy
+
                 if hasattr(jax.numpy, "__array_api_version__"):
                     jnp = jax.numpy
                 else:
-                    import jax.experimental.array_api as jnp
+                    import jax.experimental.array_api as jnp  # pyright: ignore[reportMissingImports]
             namespaces.add(jnp)
         elif is_pydata_sparse_array(x):
             if use_compat is True:
                 _check_api_version(api_version)
                 raise ValueError("`sparse` does not have an array-api-compat wrapper")
             else:
-                import sparse
+                import sparse  # pyright: ignore[reportMissingTypeStubs]
             # `sparse` is already an array namespace. We do not have a wrapper
             # submodule for it.
             namespaces.add(sparse)
-        elif hasattr(x, '__array_namespace__'):
+        elif hasattr(x, "__array_namespace__"):
             if use_compat is True:
-                raise ValueError("The given array does not have an array-api-compat wrapper")
+                raise ValueError(
+                    "The given array does not have an array-api-compat wrapper"
+                )
+            x = cast("SupportsArrayNamespace[Any]", x)
             namespaces.add(x.__array_namespace__(api_version=api_version))
         elif isinstance(x, (bool, int, float, complex, type(None))):
             continue
@@ -588,15 +652,16 @@ def your_function(x, y):
     if len(namespaces) != 1:
         raise TypeError(f"Multiple namespaces for array inputs: {namespaces}")
 
-    xp, = namespaces
+    (xp,) = namespaces
 
     return xp
 
+
 # backwards compatibility alias
 get_namespace = array_namespace
 
 
-def _check_device(bare_xp, device):
+def _check_device(bare_xp: Namespace, device: Device) -> None:  # pyright: ignore[reportUnusedFunction]
     """
     Validate dummy device on device-less array backends.
 
@@ -609,11 +674,11 @@ def _check_device(bare_xp, device):
 
     https://github.com/data-apis/array-api-compat/pull/293
     """
-    if bare_xp is sys.modules.get('numpy'):
+    if bare_xp is sys.modules.get("numpy"):
         if device not in ("cpu", None):
             raise ValueError(f"Unsupported device for NumPy: {device!r}")
 
-    elif bare_xp is sys.modules.get('dask.array'):
+    elif bare_xp is sys.modules.get("dask.array"):
         if device not in ("cpu", _DASK_DEVICE, None):
             raise ValueError(f"Unsupported device for Dask: {device!r}")
 
@@ -622,18 +687,20 @@ def _check_device(bare_xp, device):
 # when the array backend is not the CPU.
 # (since it is not easy to tell which device a dask array is on)
 class _dask_device:
-    def __repr__(self):
+    def __repr__(self) -> Literal["DASK_DEVICE"]:
         return "DASK_DEVICE"
 
+
 _DASK_DEVICE = _dask_device()
 
+
 # device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray
 # or cupy.ndarray. They are not included in array objects of this library
 # because this library just reuses the respective ndarray classes without
 # wrapping or subclassing them. These helper functions can be used instead of
 # the wrapper functions for libraries that need to support both NumPy/CuPy and
 # other libraries that use devices.
-def device(x: Array, /) -> Device:
+def device(x: _ArrayApiObj, /) -> Device:
     """
     Hardware device the array data resides on.
 
@@ -669,7 +736,7 @@ def device(x: Array, /) -> Device:
         return "cpu"
     elif is_dask_array(x):
         # Peek at the metadata of the Dask array to determine type
-        if is_numpy_array(x._meta):
+        if is_numpy_array(x._meta):  # pyright: ignore
             # Must be on CPU since backed by numpy
             return "cpu"
         return _DASK_DEVICE
@@ -679,7 +746,7 @@ def device(x: Array, /) -> Device:
         #       Return None in this case. Note that this workaround breaks
         #       the standard and will result in new arrays being created on the
         #       default device instead of the same device as the input array(s).
-        x_device = getattr(x, 'device', None)
+        x_device = getattr(x, "device", None)
         # Older JAX releases had .device() as a method, which has been replaced
         # with a property in accordance with the standard.
         if inspect.ismethod(x_device):
@@ -688,27 +755,34 @@ def device(x: Array, /) -> Device:
             return x_device
     elif is_pydata_sparse_array(x):
         # `sparse` will gain `.device`, so check for this first.
-        x_device = getattr(x, 'device', None)
+        x_device = getattr(x, "device", None)
         if x_device is not None:
             return x_device
         # Everything but DOK has this attr.
         try:
-            inner = x.data
+            inner = x.data  # pyright: ignore
         except AttributeError:
             return "cpu"
         # Return the device of the constituent array
-        return device(inner)
-    return x.device
+        return device(inner)  # pyright: ignore
+    return x.device  # pyright: ignore
+
 
 # Prevent shadowing, used below
 _device = device
 
+
 # Based on cupy.array_api.Array.to_device
-def _cupy_to_device(x, device, /, stream=None):
-    import cupy as cp
-    from cupy.cuda import Device as _Device
-    from cupy.cuda import stream as stream_module
-    from cupy_backends.cuda.api import runtime
+def _cupy_to_device(
+    x: _CupyArray,
+    device: Device,
+    /,
+    stream: int | Any | None = None,
+) -> _CupyArray:
+    import cupy as cp  # pyright: ignore[reportMissingTypeStubs]
+    from cupy.cuda import Device as _Device  # pyright: ignore
+    from cupy.cuda import stream as stream_module  # pyright: ignore
+    from cupy_backends.cuda.api import runtime  # pyright: ignore
 
     if device == x.device:
         return x
@@ -721,33 +795,40 @@ def _cupy_to_device(x, device, /, stream=None):
         raise ValueError(f"Unsupported device {device!r}")
     else:
         # see cupy/cupy#5985 for the reason how we handle device/stream here
-        prev_device = runtime.getDevice()
-        prev_stream: stream_module.Stream = None
+        prev_device: Any = runtime.getDevice()  # pyright: ignore[reportUnknownMemberType]
+        prev_stream = None
         if stream is not None:
-            prev_stream = stream_module.get_current_stream()
+            prev_stream: Any = stream_module.get_current_stream()  # pyright: ignore
             # stream can be an int as specified in __dlpack__, or a CuPy stream
             if isinstance(stream, int):
-                stream = cp.cuda.ExternalStream(stream)
-            elif isinstance(stream, cp.cuda.Stream):
+                stream = cp.cuda.ExternalStream(stream)  # pyright: ignore
+            elif isinstance(stream, cp.cuda.Stream):  # pyright: ignore[reportUnknownMemberType]
                 pass
             else:
-                raise ValueError('the input stream is not recognized')
-            stream.use()
+                raise ValueError("the input stream is not recognized")
+            stream.use()  # pyright: ignore[reportUnknownMemberType]
         try:
-            runtime.setDevice(device.id)
+            runtime.setDevice(device.id)  # pyright: ignore[reportUnknownMemberType]
             arr = x.copy()
         finally:
-            runtime.setDevice(prev_device)
+            runtime.setDevice(prev_device)  # pyright: ignore[reportUnknownMemberType]
             if stream is not None:
                 prev_stream.use()
         return arr
 
-def _torch_to_device(x, device, /, stream=None):
+
+def _torch_to_device(
+    x: torch.Tensor,
+    device: torch.device | str | int,
+    /,
+    stream: None = None,
+) -> torch.Tensor:
     if stream is not None:
         raise NotImplementedError
     return x.to(device)
 
-def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array:
+
+def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -> Array:
     """
     Copy the array from the device on which it currently resides to the specified ``device``.
 
@@ -767,7 +848,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
         a ``device`` object (see the `Device Support <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__
         section of the array API specification).
 
-    stream: Optional[Union[int, Any]]
+    stream: int | Any | None
         stream object to use during copy. In addition to the types supported
         in ``array.__dlpack__``, implementations may choose to support any
         library-specific stream object with the caveat that any code using
@@ -799,25 +880,26 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
     if is_numpy_array(x):
         if stream is not None:
             raise ValueError("The stream argument to to_device() is not supported")
-        if device == 'cpu':
+        if device == "cpu":
             return x
         raise ValueError(f"Unsupported device {device!r}")
     elif is_cupy_array(x):
         # cupy does not yet have to_device
         return _cupy_to_device(x, device, stream=stream)
     elif is_torch_array(x):
-        return _torch_to_device(x, device, stream=stream)
+        return _torch_to_device(x, device, stream=stream)  # pyright: ignore[reportArgumentType]
     elif is_dask_array(x):
         if stream is not None:
             raise ValueError("The stream argument to to_device() is not supported")
         # TODO: What if our array is on the GPU already?
-        if device == 'cpu':
+        if device == "cpu":
             return x
         raise ValueError(f"Unsupported device {device!r}")
     elif is_jax_array(x):
         if not hasattr(x, "__array_namespace__"):
             # In JAX v0.4.31 and older, this import adds to_device method to x...
-            import jax.experimental.array_api # noqa: F401
+            import jax.experimental.array_api  # noqa: F401  # pyright: ignore
+
             # ... but only on eager JAX. It won't work inside jax.jit.
             if not hasattr(x, "to_device"):
                 return x
@@ -826,10 +908,16 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
         # Perform trivial check to return the same array if
         # device is same instead of err-ing.
         return x
-    return x.to_device(device, stream=stream)
+    return x.to_device(device, stream=stream)  # pyright: ignore
 
 
-def size(x: Array) -> int | None:
+@overload
+def size(x: HasShape[Collection[SupportsIndex]]) -> int: ...
+@overload
+def size(x: HasShape[Collection[None]]) -> None: ...
+@overload
+def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ...
+def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
     """
     Return the total number of elements of x.
 
@@ -844,7 +932,7 @@ def size(x: Array) -> int | None:
     # Lazy API compliant arrays, such as ndonnx, can contain None in their shape
     if None in x.shape:
         return None
-    out = math.prod(x.shape)
+    out = math.prod(cast("Collection[SupportsIndex]", x.shape))
     # dask.array.Array.shape can contain NaN
     return None if math.isnan(out) else out
 
@@ -907,7 +995,7 @@ def is_lazy_array(x: object) -> bool:
     # on __bool__ (dask is one such example, which however is special-cased above).
 
     # Select a single point of the array
-    s = size(x)
+    s = size(cast("HasShape[Collection[SupportsIndex | None]]", x))
     if s is None:
         return True
     xp = array_namespace(x)
@@ -952,4 +1040,4 @@ def is_lazy_array(x: object) -> bool:
     "to_device",
 ]
 
-_all_ignore = ['sys', 'math', 'inspect', 'warnings']
+_all_ignore = ["sys", "math", "inspect", "warnings"]
diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py
index 4c3b356b..a38d083b 100644
--- a/array_api_compat/common/_typing.py
+++ b/array_api_compat/common/_typing.py
@@ -1,11 +1,14 @@
 from __future__ import annotations
+
 from types import ModuleType as Namespace
-from typing import Any, TypeVar, Protocol
+from typing import Any, Protocol, TypeVar
 
 __all__ = [
     "Array",
+    "SupportsArrayNamespace",
     "DType",
     "Device",
+    "HasShape",
     "Namespace",
     "NestedSequence",
     "SupportsBufferProtocol",
@@ -18,6 +21,15 @@ def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
     def __len__(self, /) -> int: ...
 
 
+class SupportsArrayNamespace(Protocol[_T_co]):
+    def __array_namespace__(self, /, *, api_version: str | None) -> _T_co: ...
+
+
+class HasShape(Protocol[_T_co]):
+    @property
+    def shape(self, /) -> _T_co: ...
+
+
 SupportsBufferProtocol = Any
 Array = Any
 Device = Any

From 3b134b0c5d835f816b05729de4485ec2575520df Mon Sep 17 00:00:00 2001
From: jorenham <jhammudoglu@gmail.com>
Date: Sat, 22 Mar 2025 22:00:36 +0100
Subject: [PATCH 03/24] TYP: fix typing errors in `common._fft`

---
 array_api_compat/common/_fft.py | 68 ++++++++++++++++-----------------
 1 file changed, 34 insertions(+), 34 deletions(-)

diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py
index bd2a4e1a..49f948ce 100644
--- a/array_api_compat/common/_fft.py
+++ b/array_api_compat/common/_fft.py
@@ -1,9 +1,9 @@
-from __future__ import annotations
-
 from collections.abc import Sequence
-from typing import Union, Optional, Literal
+from typing import Literal, TypeAlias
+
+from ._typing import Array, Device, DType, Namespace
 
-from ._typing import Device, Array, DType, Namespace
+_Norm: TypeAlias = Literal["backward", "ortho", "forward"]
 
 # Note: NumPy fft functions improperly upcast float32 and complex64 to
 # complex128, which is why we require wrapping them all here.
@@ -13,9 +13,9 @@ def fft(
     /,
     xp: Namespace,
     *,
-    n: Optional[int] = None,
+    n: int | None = None,
     axis: int = -1,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.fft(x, n=n, axis=axis, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
@@ -27,9 +27,9 @@ def ifft(
     /,
     xp: Namespace,
     *,
-    n: Optional[int] = None,
+    n: int | None = None,
     axis: int = -1,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.ifft(x, n=n, axis=axis, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
@@ -41,9 +41,9 @@ def fftn(
     /,
     xp: Namespace,
     *,
-    s: Sequence[int] = None,
-    axes: Sequence[int] = None,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    s: Sequence[int] | None = None,
+    axes: Sequence[int] | None = None,
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.fftn(x, s=s, axes=axes, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
@@ -55,9 +55,9 @@ def ifftn(
     /,
     xp: Namespace,
     *,
-    s: Sequence[int] = None,
-    axes: Sequence[int] = None,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    s: Sequence[int] | None = None,
+    axes: Sequence[int] | None = None,
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
@@ -69,9 +69,9 @@ def rfft(
     /,
     xp: Namespace,
     *,
-    n: Optional[int] = None,
+    n: int | None = None,
     axis: int = -1,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.rfft(x, n=n, axis=axis, norm=norm)
     if x.dtype == xp.float32:
@@ -83,9 +83,9 @@ def irfft(
     /,
     xp: Namespace,
     *,
-    n: Optional[int] = None,
+    n: int | None = None,
     axis: int = -1,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.irfft(x, n=n, axis=axis, norm=norm)
     if x.dtype == xp.complex64:
@@ -97,9 +97,9 @@ def rfftn(
     /,
     xp: Namespace,
     *,
-    s: Sequence[int] = None,
-    axes: Sequence[int] = None,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    s: Sequence[int] | None = None,
+    axes: Sequence[int] | None = None,
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm)
     if x.dtype == xp.float32:
@@ -111,9 +111,9 @@ def irfftn(
     /,
     xp: Namespace,
     *,
-    s: Sequence[int] = None,
-    axes: Sequence[int] = None,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    s: Sequence[int] | None = None,
+    axes: Sequence[int] | None = None,
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm)
     if x.dtype == xp.complex64:
@@ -125,9 +125,9 @@ def hfft(
     /,
     xp: Namespace,
     *,
-    n: Optional[int] = None,
+    n: int | None = None,
     axis: int = -1,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.hfft(x, n=n, axis=axis, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
@@ -139,9 +139,9 @@ def ihfft(
     /,
     xp: Namespace,
     *,
-    n: Optional[int] = None,
+    n: int | None = None,
     axis: int = -1,
-    norm: Literal["backward", "ortho", "forward"] = "backward",
+    norm: _Norm = "backward",
 ) -> Array:
     res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
@@ -154,8 +154,8 @@ def fftfreq(
     xp: Namespace,
     *,
     d: float = 1.0,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
+    dtype: DType | None = None,
+    device: Device | None = None,
 ) -> Array:
     if device not in ["cpu", None]:
         raise ValueError(f"Unsupported device {device!r}")
@@ -170,8 +170,8 @@ def rfftfreq(
     xp: Namespace,
     *,
     d: float = 1.0,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
+    dtype: DType | None = None,
+    device: Device | None = None,
 ) -> Array:
     if device not in ["cpu", None]:
         raise ValueError(f"Unsupported device {device!r}")
@@ -181,12 +181,12 @@ def rfftfreq(
     return res
 
 def fftshift(
-    x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None
+    x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None
 ) -> Array:
     return xp.fft.fftshift(x, axes=axes)
 
 def ifftshift(
-    x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None
+    x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None
 ) -> Array:
     return xp.fft.ifftshift(x, axes=axes)
 

From 344ac1ece797de7880edffadd6af68ba91d7cd11 Mon Sep 17 00:00:00 2001
From: jorenham <jhammudoglu@gmail.com>
Date: Sat, 22 Mar 2025 22:11:33 +0100
Subject: [PATCH 04/24] TYP: fix typing errors in `common._aliases`

---
 array_api_compat/common/_aliases.py | 327 ++++++++++++++++++----------
 1 file changed, 206 insertions(+), 121 deletions(-)

diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py
index 351b5bd6..cd4b6af8 100644
--- a/array_api_compat/common/_aliases.py
+++ b/array_api_compat/common/_aliases.py
@@ -5,158 +5,169 @@
 from __future__ import annotations
 
 import inspect
-from typing import Any, NamedTuple, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast
 
+from ._helpers import _check_device, array_namespace
+from ._helpers import device as _get_device
+from ._helpers import is_cupy_namespace as _is_cupy_namespace
 from ._typing import Array, Device, DType, Namespace
-from ._helpers import (
-    array_namespace,
-    _check_device,
-    device as _get_device,
-    is_cupy_namespace as _is_cupy_namespace
-)
 
+if TYPE_CHECKING:
+    from typing_extensions import TypeIs
 
 # These functions are modified from the NumPy versions.
 
 # Creation functions add the device keyword (which does nothing for NumPy and Dask)
 
+
 def arange(
-    start: Union[int, float],
+    start: float,
     /,
-    stop: Optional[Union[int, float]] = None,
-    step: Union[int, float] = 1,
+    stop: float | None = None,
+    step: float = 1,
     *,
     xp: Namespace,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
 
+
 def empty(
-    shape: Union[int, Tuple[int, ...]],
+    shape: int | tuple[int, ...],
     xp: Namespace,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.empty(shape, dtype=dtype, **kwargs)
 
+
 def empty_like(
     x: Array,
     /,
     xp: Namespace,
     *,
-    dtype: Optional[DType] = None, 
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.empty_like(x, dtype=dtype, **kwargs)
 
+
 def eye(
     n_rows: int,
-    n_cols: Optional[int] = None,
+    n_cols: int | None = None,
     /,
     *,
     xp: Namespace,
     k: int = 0,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
 
+
 def full(
-    shape: Union[int, Tuple[int, ...]],
-    fill_value: bool | int | float | complex,
+    shape: int | tuple[int, ...],
+    fill_value: complex,
     xp: Namespace,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.full(shape, fill_value, dtype=dtype, **kwargs)
 
+
 def full_like(
     x: Array,
     /,
-    fill_value: bool | int | float | complex,
+    fill_value: complex,
     *,
     xp: Namespace,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
 
+
 def linspace(
-    start: Union[int, float],
-    stop: Union[int, float],
+    start: float,
+    stop: float,
     /,
     num: int,
     *,
     xp: Namespace,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
+    dtype: DType | None = None,
+    device: Device | None = None,
     endpoint: bool = True,
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
 
+
 def ones(
-    shape: Union[int, Tuple[int, ...]],
+    shape: int | tuple[int, ...],
     xp: Namespace,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.ones(shape, dtype=dtype, **kwargs)
 
+
 def ones_like(
     x: Array,
     /,
     xp: Namespace,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.ones_like(x, dtype=dtype, **kwargs)
 
+
 def zeros(
-    shape: Union[int, Tuple[int, ...]],
+    shape: int | tuple[int, ...],
     xp: Namespace,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.zeros(shape, dtype=dtype, **kwargs)
 
+
 def zeros_like(
     x: Array,
     /,
     xp: Namespace,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     _check_device(xp, device)
     return xp.zeros_like(x, dtype=dtype, **kwargs)
 
+
 # np.unique() is split into four functions in the array API:
 # unique_all, unique_counts, unique_inverse, and unique_values (this is done
 # to remove polymorphic return types).
@@ -164,6 +175,7 @@ def zeros_like(
 # The functions here return namedtuples (np.unique() returns a normal
 # tuple).
 
+
 # Note that these named tuples aren't actually part of the standard namespace,
 # but I don't see any issue with exporting the names here regardless.
 class UniqueAllResult(NamedTuple):
@@ -188,10 +200,11 @@ def _unique_kwargs(xp: Namespace) -> dict[str, bool]:
     # trying to parse version numbers, just check if equal_nan is in the
     # signature.
     s = inspect.signature(xp.unique)
-    if 'equal_nan' in s.parameters:
-        return {'equal_nan': False}
+    if "equal_nan" in s.parameters:
+        return {"equal_nan": False}
     return {}
 
+
 def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult:
     kwargs = _unique_kwargs(xp)
     values, indices, inverse_indices, counts = xp.unique(
@@ -215,11 +228,7 @@ def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult:
 def unique_counts(x: Array, /, xp: Namespace) -> UniqueCountsResult:
     kwargs = _unique_kwargs(xp)
     res = xp.unique(
-        x,
-        return_counts=True,
-        return_index=False,
-        return_inverse=False,
-        **kwargs
+        x, return_counts=True, return_index=False, return_inverse=False, **kwargs
     )
 
     return UniqueCountsResult(*res)
@@ -250,51 +259,58 @@ def unique_values(x: Array, /, xp: Namespace) -> Array:
         **kwargs,
     )
 
+
 # These functions have different keyword argument names
 
+
 def std(
     x: Array,
     /,
     xp: Namespace,
     *,
-    axis: Optional[Union[int, Tuple[int, ...]]] = None,
-    correction: Union[int, float] = 0.0,  # correction instead of ddof
+    axis: int | tuple[int, ...] | None = None,
+    correction: float = 0.0,  # correction instead of ddof
     keepdims: bool = False,
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
 
+
 def var(
     x: Array,
     /,
     xp: Namespace,
     *,
-    axis: Optional[Union[int, Tuple[int, ...]]] = None,
-    correction: Union[int, float] = 0.0,  # correction instead of ddof
+    axis: int | tuple[int, ...] | None = None,
+    correction: float = 0.0,  # correction instead of ddof
     keepdims: bool = False,
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
 
+
 # cumulative_sum is renamed from cumsum, and adds the include_initial keyword
 # argument
 
+
 def cumulative_sum(
     x: Array,
     /,
     xp: Namespace,
     *,
-    axis: Optional[int] = None,
-    dtype: Optional[DType] = None,
+    axis: int | None = None,
+    dtype: DType | None = None,
     include_initial: bool = False,
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     wrapped_xp = array_namespace(x)
 
     # TODO: The standard is not clear about what should happen when x.ndim == 0.
     if axis is None:
         if x.ndim > 1:
-            raise ValueError("axis must be specified in cumulative_sum for more than one dimension")
+            raise ValueError(
+                "axis must be specified in cumulative_sum for more than one dimension"
+            )
         axis = 0
 
     res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs)
@@ -304,7 +320,12 @@ def cumulative_sum(
         initial_shape = list(x.shape)
         initial_shape[axis] = 1
         res = xp.concatenate(
-            [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res],
+            [
+                wrapped_xp.zeros(
+                    shape=initial_shape, dtype=res.dtype, device=_get_device(res)
+                ),
+                res,
+            ],
             axis=axis,
         )
     return res
@@ -315,16 +336,18 @@ def cumulative_prod(
     /,
     xp: Namespace,
     *,
-    axis: Optional[int] = None,
-    dtype: Optional[DType] = None,
+    axis: int | None = None,
+    dtype: DType | None = None,
     include_initial: bool = False,
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     wrapped_xp = array_namespace(x)
 
     if axis is None:
         if x.ndim > 1:
-            raise ValueError("axis must be specified in cumulative_prod for more than one dimension")
+            raise ValueError(
+                "axis must be specified in cumulative_prod for more than one dimension"
+            )
         axis = 0
 
     res = xp.cumprod(x, axis=axis, dtype=dtype, **kwargs)
@@ -334,24 +357,30 @@ def cumulative_prod(
         initial_shape = list(x.shape)
         initial_shape[axis] = 1
         res = xp.concatenate(
-            [wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res],
+            [
+                wrapped_xp.ones(
+                    shape=initial_shape, dtype=res.dtype, device=_get_device(res)
+                ),
+                res,
+            ],
             axis=axis,
         )
     return res
 
+
 # The min and max argument names in clip are different and not optional in numpy, and type
 # promotion behavior is different.
 def clip(
     x: Array,
     /,
-    min: Optional[Union[int, float, Array]] = None,
-    max: Optional[Union[int, float, Array]] = None,
+    min: float | Array | None = None,
+    max: float | Array | None = None,
     *,
     xp: Namespace,
     # TODO: np.clip has other ufunc kwargs
-    out: Optional[Array] = None,
+    out: Array | None = None,
 ) -> Array:
-    def _isscalar(a):
+    def _isscalar(a: object) -> TypeIs[int | float | None]:
         return isinstance(a, (int, float, type(None)))
 
     min_shape = () if _isscalar(min) else min.shape
@@ -378,7 +407,6 @@ def _isscalar(a):
     # but an answer of 0 might be preferred. See
     # https://github.com/numpy/numpy/issues/24976 for more discussion on this issue.
 
-
     # At least handle the case of Python integers correctly (see
     # https://github.com/numpy/numpy/pull/26892).
     if wrapped_xp.isdtype(x.dtype, "integral"):
@@ -390,6 +418,7 @@ def _isscalar(a):
     dev = _get_device(x)
     if out is None:
         out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev)
+    assert out is not None  # workaround for a type-narrowing issue in pyright
     out[()] = x
 
     if min is not None:
@@ -407,19 +436,21 @@ def _isscalar(a):
     # Return a scalar for 0-D
     return out[()]
 
+
 # Unlike transpose(), the axes argument to permute_dims() is required.
-def permute_dims(x: Array, /, axes: Tuple[int, ...], xp: Namespace) -> Array:
+def permute_dims(x: Array, /, axes: tuple[int, ...], xp: Namespace) -> Array:
     return xp.transpose(x, axes)
 
+
 # np.reshape calls the keyword argument 'newshape' instead of 'shape'
 def reshape(
     x: Array,
     /,
-    shape: Tuple[int, ...],
+    shape: tuple[int, ...],
     xp: Namespace,
     *,
     copy: Optional[bool] = None,
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     if copy is True:
         x = x.copy()
@@ -429,6 +460,7 @@ def reshape(
         return y
     return xp.reshape(x, shape, **kwargs)
 
+
 # The descending keyword is new in sort and argsort, and 'kind' replaced with
 # 'stable'
 def argsort(
@@ -439,13 +471,13 @@ def argsort(
     axis: int = -1,
     descending: bool = False,
     stable: bool = True,
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     # Note: this keyword argument is different, and the default is different.
     # We set it in kwargs like this because numpy.sort uses kind='quicksort'
     # as the default whereas cupy.sort uses kind=None.
     if stable:
-        kwargs['kind'] = "stable"
+        kwargs["kind"] = "stable"
     if not descending:
         res = xp.argsort(x, axis=axis, **kwargs)
     else:
@@ -462,6 +494,7 @@ def argsort(
         res = max_i - res
     return res
 
+
 def sort(
     x: Array,
     /,
@@ -470,68 +503,78 @@ def sort(
     axis: int = -1,
     descending: bool = False,
     stable: bool = True,
-    **kwargs,
+    **kwargs: object,
 ) -> Array:
     # Note: this keyword argument is different, and the default is different.
     # We set it in kwargs like this because numpy.sort uses kind='quicksort'
     # as the default whereas cupy.sort uses kind=None.
     if stable:
-        kwargs['kind'] = "stable"
+        kwargs["kind"] = "stable"
     res = xp.sort(x, axis=axis, **kwargs)
     if descending:
         res = xp.flip(res, axis=axis)
     return res
 
+
 # nonzero should error for zero-dimensional arrays
-def nonzero(x: Array, /, xp: Namespace, **kwargs) -> Tuple[Array, ...]:
+def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]:
     if x.ndim == 0:
         raise ValueError("nonzero() does not support zero-dimensional arrays")
     return xp.nonzero(x, **kwargs)
 
+
 # ceil, floor, and trunc return integers for integer inputs
 
-def ceil(x: Array, /, xp: Namespace, **kwargs) -> Array:
+
+def ceil(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
     if xp.issubdtype(x.dtype, xp.integer):
         return x
     return xp.ceil(x, **kwargs)
 
-def floor(x: Array, /, xp: Namespace, **kwargs) -> Array:
+
+def floor(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
     if xp.issubdtype(x.dtype, xp.integer):
         return x
     return xp.floor(x, **kwargs)
 
-def trunc(x: Array, /, xp: Namespace, **kwargs) -> Array:
+
+def trunc(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
     if xp.issubdtype(x.dtype, xp.integer):
         return x
     return xp.trunc(x, **kwargs)
 
+
 # linear algebra functions
 
-def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array:
+
+def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array:
     return xp.matmul(x1, x2, **kwargs)
 
+
 # Unlike transpose, matrix_transpose only transposes the last two axes.
 def matrix_transpose(x: Array, /, xp: Namespace) -> Array:
     if x.ndim < 2:
         raise ValueError("x must be at least 2-dimensional for matrix_transpose")
     return xp.swapaxes(x, -1, -2)
 
+
 def tensordot(
     x1: Array,
     x2: Array,
     /,
     xp: Namespace,
     *,
-    axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
-    **kwargs,
+    axes: int | tuple[Sequence[int], Sequence[int]] = 2,
+    **kwargs: object,
 ) -> Array:
     return xp.tensordot(x1, x2, axes=axes, **kwargs)
 
+
 def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array:
     if x1.shape[axis] != x2.shape[axis]:
         raise ValueError("x1 and x2 must have the same size along the given axis")
 
-    if hasattr(xp, 'broadcast_tensors'):
+    if hasattr(xp, "broadcast_tensors"):
         _broadcast = xp.broadcast_tensors
     else:
         _broadcast = xp.broadcast_arrays
@@ -543,14 +586,16 @@ def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array:
     res = xp.conj(x1_[..., None, :]) @ x2_[..., None]
     return res[..., 0, 0]
 
+
 # isdtype is a new function in the 2022.12 array API specification.
 
+
 def isdtype(
     dtype: DType,
-    kind: Union[DType, str, Tuple[Union[DType, str], ...]],
+    kind: DType | str | tuple[DType | str, ...],
     xp: Namespace,
     *,
-    _tuple: bool = True, # Disallow nested tuples
+    _tuple: bool = True,  # Disallow nested tuples
 ) -> bool:
     """
     Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
@@ -563,21 +608,24 @@ def isdtype(
     for more details
     """
     if isinstance(kind, tuple) and _tuple:
-        return any(isdtype(dtype, k, xp, _tuple=False) for k in kind)
+        return any(
+            isdtype(dtype, k, xp, _tuple=False)
+            for k in cast("tuple[DType | str, ...]", kind)
+        )
     elif isinstance(kind, str):
-        if kind == 'bool':
+        if kind == "bool":
             return dtype == xp.bool_
-        elif kind == 'signed integer':
+        elif kind == "signed integer":
             return xp.issubdtype(dtype, xp.signedinteger)
-        elif kind == 'unsigned integer':
+        elif kind == "unsigned integer":
             return xp.issubdtype(dtype, xp.unsignedinteger)
-        elif kind == 'integral':
+        elif kind == "integral":
             return xp.issubdtype(dtype, xp.integer)
-        elif kind == 'real floating':
+        elif kind == "real floating":
             return xp.issubdtype(dtype, xp.floating)
-        elif kind == 'complex floating':
+        elif kind == "complex floating":
             return xp.issubdtype(dtype, xp.complexfloating)
-        elif kind == 'numeric':
+        elif kind == "numeric":
             return xp.issubdtype(dtype, xp.number)
         else:
             raise ValueError(f"Unrecognized data type kind: {kind!r}")
@@ -588,24 +636,27 @@ def isdtype(
         # array_api_strict implementation will be very strict.
         return dtype == kind
 
+
 # unstack is a new function in the 2023.12 array API standard
-def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> Tuple[Array, ...]:
+def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> tuple[Array, ...]:
     if x.ndim == 0:
         raise ValueError("Input array must be at least 1-d.")
     return tuple(xp.moveaxis(x, axis, 0))
 
+
 # numpy 1.26 does not use the standard definition for sign on complex numbers
 
-def sign(x: Array, /, xp: Namespace, **kwargs) -> Array:
-    if isdtype(x.dtype, 'complex floating', xp=xp):
-        out = (x/xp.abs(x, **kwargs))[...]
+
+def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
+    if isdtype(x.dtype, "complex floating", xp=xp):
+        out = (x / xp.abs(x, **kwargs))[...]
         # sign(0) = 0 but the above formula would give nan
-        out[x == 0+0j] = 0+0j
+        out[x == 0 + 0j] = 0 + 0j
     else:
         out = xp.sign(x, **kwargs)
     # CuPy sign() does not propagate nans. See
     # https://github.com/data-apis/array-api-compat/issues/136
-    if _is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp):
+    if _is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp):
         out[xp.isnan(x)] = xp.nan
     return out[()]
 
@@ -626,13 +677,47 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any:
         return xp.iinfo(type_.dtype)
 
 
-__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
-           'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
-           'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
-           'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
-           'std', 'var', 'cumulative_sum', 'cumulative_prod','clip', 'permute_dims',
-           'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
-           'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
-           'unstack', 'sign', 'finfo', 'iinfo']
-
-_all_ignore = ['inspect', 'array_namespace', 'NamedTuple']
+__all__ = [
+    "arange",
+    "empty",
+    "empty_like",
+    "eye",
+    "full",
+    "full_like",
+    "linspace",
+    "ones",
+    "ones_like",
+    "zeros",
+    "zeros_like",
+    "UniqueAllResult",
+    "UniqueCountsResult",
+    "UniqueInverseResult",
+    "unique_all",
+    "unique_counts",
+    "unique_inverse",
+    "unique_values",
+    "std",
+    "var",
+    "cumulative_sum",
+    "cumulative_prod",
+    "clip",
+    "permute_dims",
+    "reshape",
+    "argsort",
+    "sort",
+    "nonzero",
+    "ceil",
+    "floor",
+    "trunc",
+    "matmul",
+    "matrix_transpose",
+    "tensordot",
+    "vecdot",
+    "isdtype",
+    "unstack",
+    "sign",
+    "finfo",
+    "iinfo",
+]
+
+_all_ignore = ["inspect", "array_namespace", "NamedTuple"]

From cbec5f360e142c6a73c3509c2f864a31b8be510f Mon Sep 17 00:00:00 2001
From: jorenham <jhammudoglu@gmail.com>
Date: Sat, 22 Mar 2025 22:20:11 +0100
Subject: [PATCH 05/24] TYP: fix typing errors in `common._linalg`

---
 array_api_compat/common/_linalg.py | 106 +++++++++++++++++++++--------
 1 file changed, 78 insertions(+), 28 deletions(-)

diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py
index d1e7ebd8..5a6f070e 100644
--- a/array_api_compat/common/_linalg.py
+++ b/array_api_compat/common/_linalg.py
@@ -1,23 +1,33 @@
 from __future__ import annotations
 
 import math
-from typing import Literal, NamedTuple, Optional, Tuple, Union
+from typing import Literal, NamedTuple, cast
 
 import numpy as np
+
 if np.__version__[0] == "2":
     from numpy.lib.array_utils import normalize_axis_tuple
 else:
     from numpy.core.numeric import normalize_axis_tuple
 
-from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype
 from .._internal import get_xp
-from ._typing import Array, Namespace
+from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot
+from ._typing import Array, DType, Namespace
+
 
 # These are in the main NumPy namespace but not in numpy.linalg
-def cross(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1, **kwargs) -> Array:
+def cross(
+    x1: Array,
+    x2: Array,
+    /,
+    xp: Namespace,
+    *,
+    axis: int = -1,
+    **kwargs: object,
+) -> Array:
     return xp.cross(x1, x2, axis=axis, **kwargs)
 
-def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array:
+def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array:
     return xp.outer(x1, x2, **kwargs)
 
 class EighResult(NamedTuple):
@@ -39,46 +49,66 @@ class SVDResult(NamedTuple):
 
 # These functions are the same as their NumPy counterparts except they return
 # a namedtuple.
-def eigh(x: Array, /, xp: Namespace, **kwargs) -> EighResult:
+def eigh(x: Array, /, xp: Namespace, **kwargs: object) -> EighResult:
     return EighResult(*xp.linalg.eigh(x, **kwargs))
 
-def qr(x: Array, /, xp: Namespace, *, mode: Literal['reduced', 'complete'] = 'reduced',
-       **kwargs) -> QRResult:
+def qr(
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    mode: Literal["reduced", "complete"] = "reduced",
+    **kwargs: object,
+) -> QRResult:
     return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs))
 
-def slogdet(x: Array, /, xp: Namespace, **kwargs) -> SlogdetResult:
+def slogdet(x: Array, /, xp: Namespace, **kwargs: object) -> SlogdetResult:
     return SlogdetResult(*xp.linalg.slogdet(x, **kwargs))
 
 def svd(
-    x: Array, /, xp: Namespace, *, full_matrices: bool = True, **kwargs
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    full_matrices: bool = True,
+    **kwargs: object,
 ) -> SVDResult:
     return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs))
 
 # These functions have additional keyword arguments
 
 # The upper keyword argument is new from NumPy
-def cholesky(x: Array, /, xp: Namespace, *, upper: bool = False, **kwargs) -> Array:
+def cholesky(
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    upper: bool = False,
+    **kwargs: object,
+) -> Array:
     L = xp.linalg.cholesky(x, **kwargs)
     if upper:
         U = get_xp(xp)(matrix_transpose)(L)
         if get_xp(xp)(isdtype)(U.dtype, 'complex floating'):
-            U = xp.conj(U)
+            U = xp.conj(U)  # pyright: ignore[reportConstantRedefinition]
         return U
     return L
 
 # The rtol keyword argument of matrix_rank() and pinv() is new from NumPy.
 # Note that it has a different semantic meaning from tol and rcond.
-def matrix_rank(x: Array,
-                /,
-                xp: Namespace,
-                *,
-                rtol: Optional[Union[float, Array]] = None,
-                **kwargs) -> Array:
+def matrix_rank(
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    rtol: float | Array | None = None,
+    **kwargs: object,
+) -> Array:
     # this is different from xp.linalg.matrix_rank, which supports 1
     # dimensional arrays.
     if x.ndim < 2:
         raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
-    S = get_xp(xp)(svdvals)(x, **kwargs)
+    S: Array = get_xp(xp)(svdvals)(x, **kwargs)
     if rtol is None:
         tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps
     else:
@@ -88,7 +118,12 @@ def matrix_rank(x: Array,
     return xp.count_nonzero(S > tol, axis=-1)
 
 def pinv(
-    x: Array, /, xp: Namespace, *, rtol: Optional[Union[float, Array]] = None, **kwargs
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    rtol: float | Array | None = None,
+    **kwargs: object,
 ) -> Array:
     # this is different from xp.linalg.pinv, which does not multiply the
     # default tolerance by max(M, N).
@@ -104,13 +139,13 @@ def matrix_norm(
     xp: Namespace,
     *,
     keepdims: bool = False,
-    ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro',
+    ord: float | Literal["fro", "nuc"] | None = "fro",
 ) -> Array:
     return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)
 
 # svdvals is not in NumPy (but it is in SciPy). It is equivalent to
 # xp.linalg.svd(compute_uv=False).
-def svdvals(x: Array, /, xp: Namespace) -> Union[Array, Tuple[Array, ...]]:
+def svdvals(x: Array, /, xp: Namespace) -> Array | tuple[Array, ...]:
     return xp.linalg.svd(x, compute_uv=False)
 
 def vector_norm(
@@ -118,9 +153,9 @@ def vector_norm(
     /,
     xp: Namespace,
     *,
-    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    axis: int | tuple[int, ...] | None = None,
     keepdims: bool = False,
-    ord: Optional[Union[int, float]] = 2,
+    ord: float = 2,
 ) -> Array:
     # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
     # when axis=None and the input is 2-D, so to force a vector norm, we make
@@ -133,7 +168,10 @@ def vector_norm(
     elif isinstance(axis, tuple):
         # Note: The axis argument supports any number of axes, whereas
         # xp.linalg.norm() only supports a single axis for vector norm.
-        normalized_axis = normalize_axis_tuple(axis, x.ndim)
+        normalized_axis = cast(
+            "tuple[int, ...]",
+            normalize_axis_tuple(axis, x.ndim),  # pyright: ignore[reportCallIssue]
+        )
         rest = tuple(i for i in range(x.ndim) if i not in normalized_axis)
         newshape = axis + rest
         _x = xp.transpose(x, newshape).reshape(
@@ -149,7 +187,13 @@ def vector_norm(
         # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks
         # above to avoid matrix norm logic.
         shape = list(x.shape)
-        _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim)
+        _axis = cast(
+            "tuple[int, ...]",
+            normalize_axis_tuple(  # pyright: ignore[reportCallIssue]
+                range(x.ndim) if axis is None else axis,
+                x.ndim,
+            ),
+        )
         for i in _axis:
             shape[i] = 1
         res = xp.reshape(res, tuple(shape))
@@ -159,11 +203,17 @@ def vector_norm(
 # xp.diagonal and xp.trace operate on the first two axes whereas these
 # operates on the last two
 
-def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs) -> Array:
+def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs: object) -> Array:
     return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)
 
 def trace(
-    x: Array, /, xp: Namespace, *, offset: int = 0, dtype=None, **kwargs
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    offset: int = 0,
+    dtype: DType | None = None,
+    **kwargs: object,
 ) -> Array:
     return xp.asarray(
         xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)

From dc79e3f1731b966294142067416e9c76cf5bb607 Mon Sep 17 00:00:00 2001
From: jorenham <jhammudoglu@gmail.com>
Date: Sat, 22 Mar 2025 22:26:36 +0100
Subject: [PATCH 06/24] TYP: fix/ignore typing errors in `numpy.__init__`

---
 array_api_compat/numpy/__init__.py | 28 +++++++++++++++++++++-------
 1 file changed, 21 insertions(+), 7 deletions(-)

diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py
index 6a5d9867..caaf67bb 100644
--- a/array_api_compat/numpy/__init__.py
+++ b/array_api_compat/numpy/__init__.py
@@ -1,10 +1,16 @@
-from numpy import * # noqa: F403
+# ruff: noqa: PLC0414
+from typing import Final
+
+from numpy import *  # noqa: F403  # pyright: ignore[reportWildcardImportFromLibrary]
 
 # from numpy import * doesn't overwrite these builtin names
-from numpy import abs, max, min, round # noqa: F401
+from numpy import abs as abs
+from numpy import max as max
+from numpy import min as min
+from numpy import round as round
 
 # These imports may overwrite names from the import * above.
-from ._aliases import * # noqa: F403
+from ._aliases import *  # noqa: F403
 
 # Don't know why, but we have to do an absolute import to import linalg. If we
 # instead do
@@ -13,9 +19,17 @@
 #
 # It doesn't overwrite np.linalg from above. The import is generated
 # dynamically so that the library can be vendored.
-__import__(__package__ + '.linalg')
-__import__(__package__ + '.fft')
+__import__(__package__ + ".linalg")
+
+__import__(__package__ + ".fft")
+
+from ..common._helpers import *  # noqa: F403
+from .linalg import matrix_transpose, vecdot  # noqa: F401
 
-from .linalg import matrix_transpose, vecdot # noqa: F401
+try:
+    # Used in asarray(). Not present in older versions.
+    from numpy import _CopyMode  # noqa: F401
+except ImportError:
+    pass
 
-__array_api_version__ = '2024.12'
+__array_api_version__ = "2024.12"

From 9643256ae1b7fd7af65b3255c992ce760dc6a4b5 Mon Sep 17 00:00:00 2001
From: jorenham <jhammudoglu@gmail.com>
Date: Sat, 22 Mar 2025 22:31:55 +0100
Subject: [PATCH 07/24] TYP: fix typing errors in `numpy._typing`

---
 array_api_compat/numpy/_typing.py | 21 +++++++--------------
 1 file changed, 7 insertions(+), 14 deletions(-)

diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py
index a6c96924..1f7b247c 100644
--- a/array_api_compat/numpy/_typing.py
+++ b/array_api_compat/numpy/_typing.py
@@ -3,29 +3,22 @@
 __all__ = ["Array", "DType", "Device"]
 _all_ignore = ["np"]
 
-from typing import Literal, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Literal, TypeAlias
 
 import numpy as np
-from numpy import ndarray as Array
 
-Device = Literal["cpu"]
+Device: TypeAlias = Literal["cpu"]
 if TYPE_CHECKING:
     # NumPy 1.x on Python 3.10 fails to parse np.dtype[]
-    DType = np.dtype[
-        np.intp
-        | np.int8
-        | np.int16
-        | np.int32
-        | np.int64
-        | np.uint8
-        | np.uint16
-        | np.uint32
-        | np.uint64
+    DType: TypeAlias = np.dtype[
+        np.bool_
+        | np.integer[Any]
         | np.float32
         | np.float64
         | np.complex64
         | np.complex128
-        | np.bool
     ]
+    Array: TypeAlias = np.ndarray[Any, DType]
 else:
     DType = np.dtype
+    Array = np.ndarray

From 18870dcf4fe7d345119cdc082550f679bd7542b3 Mon Sep 17 00:00:00 2001
From: jorenham <jhammudoglu@gmail.com>
Date: Sat, 22 Mar 2025 22:50:35 +0100
Subject: [PATCH 08/24] TYP: fix typing errors in `numpy._aliases`

---
 array_api_compat/numpy/_aliases.py | 81 +++++++++++++++++++-----------
 1 file changed, 52 insertions(+), 29 deletions(-)

diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py
index d1fd46a1..10d84fe7 100644
--- a/array_api_compat/numpy/_aliases.py
+++ b/array_api_compat/numpy/_aliases.py
@@ -1,6 +1,10 @@
+# pyright: reportPrivateUsage=false
 from __future__ import annotations
 
-from typing import Optional, Union
+from builtins import bool as py_bool
+from typing import TYPE_CHECKING, cast
+
+import numpy as np
 
 from .._internal import get_xp
 from ..common import _aliases, _helpers
@@ -8,7 +12,12 @@
 from ._info import __array_namespace_info__
 from ._typing import Array, Device, DType
 
-import numpy as np
+if TYPE_CHECKING:
+    from typing import Any, Literal, TypeAlias
+
+    from typing_extensions import Buffer, TypeIs
+
+    _Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode
 
 bool = np.bool_
 
@@ -65,9 +74,9 @@
 iinfo = get_xp(np)(_aliases.iinfo)
 
 
-def _supports_buffer_protocol(obj):
+def _supports_buffer_protocol(obj: object) -> TypeIs[Buffer]:  # pyright: ignore[reportUnusedFunction]
     try:
-        memoryview(obj)
+        memoryview(obj)  # pyright: ignore[reportArgumentType]
     except TypeError:
         return False
     return True
@@ -78,18 +87,13 @@ def _supports_buffer_protocol(obj):
 # complicated enough that it's easier to define it separately for each module
 # rather than trying to combine everything into one function in common/
 def asarray(
-    obj: (
-        Array 
-        | bool | int | float | complex 
-        | NestedSequence[bool | int | float | complex] 
-        | SupportsBufferProtocol
-    ),
+    obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
     /,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    copy: Optional[Union[bool, np._CopyMode]] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    copy: _Copy | None = None,
+    **kwargs: Any,
 ) -> Array:
     """
     Array API compatibility wrapper for asarray().
@@ -106,7 +110,7 @@ def asarray(
     elif copy is True:
         copy = np._CopyMode.ALWAYS
 
-    return np.array(obj, copy=copy, dtype=dtype, **kwargs)
+    return np.array(obj, copy=copy, dtype=dtype, **kwargs)  # pyright: ignore
 
 
 def astype(
@@ -114,8 +118,8 @@ def astype(
     dtype: DType,
     /,
     *,
-    copy: bool = True,
-    device: Optional[Device] = None,
+    copy: py_bool = True,
+    device: Device | None = None,
 ) -> Array:
     _helpers._check_device(np, device)
     return x.astype(dtype=dtype, copy=copy)
@@ -123,8 +127,12 @@ def astype(
 
 # count_nonzero returns a python int for axis=None and keepdims=False
 # https://github.com/numpy/numpy/issues/17562
-def count_nonzero(x: Array, axis=None, keepdims=False) -> Array:
-    result = np.count_nonzero(x, axis=axis, keepdims=keepdims)
+def count_nonzero(
+    x: Array,
+    axis: int | tuple[int, ...] | None = None,
+    keepdims: py_bool = False,
+) -> Array:
+    result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims))  # pyright: ignore
     if axis is None and not keepdims:
         return np.asarray(result)
     return result
@@ -132,25 +140,40 @@ def count_nonzero(x: Array, axis=None, keepdims=False) -> Array:
 
 # These functions are completely new here. If the library already has them
 # (i.e., numpy 2.0), use the library version instead of our wrapper.
-if hasattr(np, 'vecdot'):
+if hasattr(np, "vecdot"):
     vecdot = np.vecdot
 else:
     vecdot = get_xp(np)(_aliases.vecdot)
 
-if hasattr(np, 'isdtype'):
+if hasattr(np, "isdtype"):
     isdtype = np.isdtype
 else:
     isdtype = get_xp(np)(_aliases.isdtype)
 
-if hasattr(np, 'unstack'):
+if hasattr(np, "unstack"):
     unstack = np.unstack
 else:
     unstack = get_xp(np)(_aliases.unstack)
 
-__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype',
-                              'acos', 'acosh', 'asin', 'asinh', 'atan',
-                              'atan2', 'atanh', 'bitwise_left_shift',
-                              'bitwise_invert', 'bitwise_right_shift',
-                              'bool', 'concat', 'count_nonzero', 'pow']
-
-_all_ignore = ['np', 'get_xp']
+__all__ = [
+    "__array_namespace_info__",
+    "asarray",
+    "astype",
+    "acos",
+    "acosh",
+    "asin",
+    "asinh",
+    "atan",
+    "atan2",
+    "atanh",
+    "bitwise_left_shift",
+    "bitwise_invert",
+    "bitwise_right_shift",
+    "bool",
+    "concat",
+    "count_nonzero",
+    "pow",
+]
+__all__ += _aliases.__all__
+
+_all_ignore = ["np", "get_xp"]

From 1fb929bed064bbf04217566e3dff931fec49cd3e Mon Sep 17 00:00:00 2001
From: jorenham <jhammudoglu@gmail.com>
Date: Sat, 22 Mar 2025 22:55:40 +0100
Subject: [PATCH 09/24] TYP: fix typing errors in `numpy._info`

---
 array_api_compat/numpy/_info.py | 33 ++++++++++++++++++++++-----------
 1 file changed, 22 insertions(+), 11 deletions(-)

diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py
index 365855b8..ec7c39f5 100644
--- a/array_api_compat/numpy/_info.py
+++ b/array_api_compat/numpy/_info.py
@@ -7,24 +7,26 @@
 more details.
 
 """
+from numpy import bool_ as bool
 from numpy import (
+    complex64,
+    complex128,
     dtype,
-    bool_ as bool,
-    intp,
+    float32,
+    float64,
     int8,
     int16,
     int32,
     int64,
+    intp,
     uint8,
     uint16,
     uint32,
     uint64,
-    float32,
-    float64,
-    complex64,
-    complex128,
 )
 
+from ._typing import Device, DType
+
 
 class __array_namespace_info__:
     """
@@ -131,7 +133,11 @@ def default_device(self):
         """
         return "cpu"
 
-    def default_dtypes(self, *, device=None):
+    def default_dtypes(
+        self,
+        *,
+        device: Device | None = None,
+    ) -> dict[str, dtype[intp | float64 | complex128]]:
         """
         The default data types used for new NumPy arrays.
 
@@ -183,7 +189,12 @@ def default_dtypes(self, *, device=None):
             "indexing": dtype(intp),
         }
 
-    def dtypes(self, *, device=None, kind=None):
+    def dtypes(
+        self,
+        *,
+        device: Device | None = None,
+        kind: str | tuple[str, ...] | None = None,
+    ) -> dict[str, DType]:
         """
         The array API data types supported by NumPy.
 
@@ -260,7 +271,7 @@ def dtypes(self, *, device=None, kind=None):
                 "complex128": dtype(complex128),
             }
         if kind == "bool":
-            return {"bool": bool}
+            return {"bool": dtype(bool)}
         if kind == "signed integer":
             return {
                 "int8": dtype(int8),
@@ -312,13 +323,13 @@ def dtypes(self, *, device=None, kind=None):
                 "complex128": dtype(complex128),
             }
         if isinstance(kind, tuple):
-            res = {}
+            res: dict[str, DType] = {}
             for k in kind:
                 res.update(self.dtypes(kind=k))
             return res
         raise ValueError(f"unsupported kind: {kind!r}")
 
-    def devices(self):
+    def devices(self) -> list[Device]:
         """
         The devices supported by NumPy.
 

From 014385fb0514015b0e40958c2a8ccee732e8e396 Mon Sep 17 00:00:00 2001
From: jorenham <jhammudoglu@gmail.com>
Date: Sat, 22 Mar 2025 23:01:54 +0100
Subject: [PATCH 10/24] TYP: fix typing errors in `numpy._fft`

---
 array_api_compat/numpy/fft.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py
index 28667594..5423bd01 100644
--- a/array_api_compat/numpy/fft.py
+++ b/array_api_compat/numpy/fft.py
@@ -1,10 +1,9 @@
-from numpy.fft import * # noqa: F403
+import numpy as np
 from numpy.fft import __all__ as fft_all
+from numpy.fft import fft2, ifft2, irfft2, rfft2
 
-from ..common import _fft
 from .._internal import get_xp
-
-import numpy as np
+from ..common import _fft
 
 fft = get_xp(np)(_fft.fft)
 ifft = get_xp(np)(_fft.ifft)
@@ -21,7 +20,8 @@
 fftshift = get_xp(np)(_fft.fftshift)
 ifftshift = get_xp(np)(_fft.ifftshift)
 
-__all__ = fft_all + _fft.__all__
+__all__ = ["rfft2", "irfft2", "fft2", "ifft2"]
+__all__ += _fft.__all__
 
 del get_xp
 del np

From ec72825c8e32de076a6f228105ccaf4bb25344a9 Mon Sep 17 00:00:00 2001
From: jorenham <jhammudoglu@gmail.com>
Date: Sat, 22 Mar 2025 23:18:53 +0100
Subject: [PATCH 11/24] TYP: it's a bad idea to import `TypeAlias` from
 `typing` on `python<3.10`

---
 array_api_compat/common/_fft.py     |  4 ++--
 array_api_compat/common/_helpers.py | 13 ++-----------
 array_api_compat/numpy/_aliases.py  |  4 ++--
 array_api_compat/numpy/_typing.py   |  6 ++++--
 4 files changed, 10 insertions(+), 17 deletions(-)

diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py
index 49f948ce..fb95eafc 100644
--- a/array_api_compat/common/_fft.py
+++ b/array_api_compat/common/_fft.py
@@ -1,9 +1,9 @@
 from collections.abc import Sequence
-from typing import Literal, TypeAlias
+from typing import Literal
 
 from ._typing import Array, Device, DType, Namespace
 
-_Norm: TypeAlias = Literal["backward", "ortho", "forward"]
+_Norm = Literal["backward", "ortho", "forward"]
 
 # Note: NumPy fft functions improperly upcast float32 and complex64 to
 # complex128, which is why we require wrapping them all here.
diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py
index 9f016f05..76c9d92c 100644
--- a/array_api_compat/common/_helpers.py
+++ b/array_api_compat/common/_helpers.py
@@ -12,16 +12,7 @@
 import math
 import sys
 import warnings
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Literal,
-    SupportsIndex,
-    TypeAlias,
-    TypeGuard,
-    cast,
-    overload,
-)
+from typing import TYPE_CHECKING, Any, Literal, SupportsIndex, TypeGuard, cast, overload
 
 from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace
 
@@ -35,7 +26,7 @@
     import numpy.typing as npt
     import sparse  # pyright: ignore[reportMissingTypeStubs]
     import torch
-    from typing_extensions import TypeIs, TypeVar
+    from typing_extensions import TypeAlias, TypeIs, TypeVar
 
     _SizeT = TypeVar("_SizeT", bound=int | None)
 
diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py
index 10d84fe7..39771f4e 100644
--- a/array_api_compat/numpy/_aliases.py
+++ b/array_api_compat/numpy/_aliases.py
@@ -13,9 +13,9 @@
 from ._typing import Array, Device, DType
 
 if TYPE_CHECKING:
-    from typing import Any, Literal, TypeAlias
+    from typing import Any, Literal
 
-    from typing_extensions import Buffer, TypeIs
+    from typing_extensions import Buffer, TypeAlias, TypeIs
 
     _Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode
 
diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py
index 1f7b247c..d4a285b9 100644
--- a/array_api_compat/numpy/_typing.py
+++ b/array_api_compat/numpy/_typing.py
@@ -3,12 +3,14 @@
 __all__ = ["Array", "DType", "Device"]
 _all_ignore = ["np"]
 
-from typing import TYPE_CHECKING, Any, Literal, TypeAlias
+from typing import TYPE_CHECKING, Any, Literal
 
 import numpy as np
 
-Device: TypeAlias = Literal["cpu"]
+Device = Literal["cpu"]
 if TYPE_CHECKING:
+    from typing_extensions import TypeAlias
+
     # NumPy 1.x on Python 3.10 fails to parse np.dtype[]
     DType: TypeAlias = np.dtype[
         np.bool_

From ccd9bc6e3e6cc1fd12318c92df03c5d979916579 Mon Sep 17 00:00:00 2001
From: jorenham <jhammudoglu@gmail.com>
Date: Sat, 22 Mar 2025 23:20:54 +0100
Subject: [PATCH 12/24] TYP: it's also a bad idea to import `TypeGuard` from
 `typing` on `python<3.10`

---
 array_api_compat/common/_helpers.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py
index 76c9d92c..b14be849 100644
--- a/array_api_compat/common/_helpers.py
+++ b/array_api_compat/common/_helpers.py
@@ -12,7 +12,7 @@
 import math
 import sys
 import warnings
-from typing import TYPE_CHECKING, Any, Literal, SupportsIndex, TypeGuard, cast, overload
+from typing import TYPE_CHECKING, Any, Literal, SupportsIndex, cast, overload
 
 from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace
 
@@ -26,7 +26,7 @@
     import numpy.typing as npt
     import sparse  # pyright: ignore[reportMissingTypeStubs]
     import torch
-    from typing_extensions import TypeAlias, TypeIs, TypeVar
+    from typing_extensions import TypeAlias, TypeGuard, TypeIs, TypeVar
 
     _SizeT = TypeVar("_SizeT", bound=int | None)
 

From b8c78832fb9235f4fe1b674aa63364aa0ae2a87b Mon Sep 17 00:00:00 2001
From: jorenham <jhammudoglu@gmail.com>
Date: Sat, 22 Mar 2025 23:24:05 +0100
Subject: [PATCH 13/24] TYP: don't scare the prehistoric `dtype` from numpy
 1.21

---
 array_api_compat/numpy/_info.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py
index ec7c39f5..22570a7c 100644
--- a/array_api_compat/numpy/_info.py
+++ b/array_api_compat/numpy/_info.py
@@ -137,7 +137,7 @@ def default_dtypes(
         self,
         *,
         device: Device | None = None,
-    ) -> dict[str, dtype[intp | float64 | complex128]]:
+    ) -> dict[str, "dtype[intp | float64 | complex128]"]:
         """
         The default data types used for new NumPy arrays.
 

From ef066d1225887144a6310d78feecceb447f61a17 Mon Sep 17 00:00:00 2001
From: jorenham <jhammudoglu@gmail.com>
Date: Sat, 22 Mar 2025 23:25:50 +0100
Subject: [PATCH 14/24] TYP: dust off the DeLorean

---
 array_api_compat/numpy/_info.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py
index 22570a7c..ca019b2c 100644
--- a/array_api_compat/numpy/_info.py
+++ b/array_api_compat/numpy/_info.py
@@ -7,6 +7,8 @@
 more details.
 
 """
+from __future__ import annotations
+
 from numpy import bool_ as bool
 from numpy import (
     complex64,
@@ -137,7 +139,7 @@ def default_dtypes(
         self,
         *,
         device: Device | None = None,
-    ) -> dict[str, "dtype[intp | float64 | complex128]"]:
+    ) -> dict[str, dtype[intp | float64 | complex128]]:
         """
         The default data types used for new NumPy arrays.
 

From a522dbc9dbae73ac31f294c897cad5d59276dfda Mon Sep 17 00:00:00 2001
From: jorenham <jhammudoglu@gmail.com>
Date: Sat, 22 Mar 2025 23:31:48 +0100
Subject: [PATCH 15/24] TYP: figure out how to drive a DeLorean

---
 array_api_compat/common/_fft.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py
index fb95eafc..6fe834dc 100644
--- a/array_api_compat/common/_fft.py
+++ b/array_api_compat/common/_fft.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from collections.abc import Sequence
 from typing import Literal
 

From bca9c0cedabd73c18cd471f3137dc2455fd9de13 Mon Sep 17 00:00:00 2001
From: jorenham <jhammudoglu@gmail.com>
Date: Tue, 15 Apr 2025 14:24:36 +0200
Subject: [PATCH 16/24] TYP: apply review suggestions

Co-authored-by: crusaderky <crusaderky@gmail.com>
---
 array_api_compat/_internal.py       | 15 ++++++---------
 array_api_compat/common/_aliases.py |  1 +
 array_api_compat/common/_helpers.py | 18 +++++++++++++++---
 array_api_compat/numpy/_aliases.py  |  8 +++-----
 4 files changed, 25 insertions(+), 17 deletions(-)

diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py
index 7a973567..6df63097 100644
--- a/array_api_compat/_internal.py
+++ b/array_api_compat/_internal.py
@@ -2,21 +2,18 @@
 Internal helpers
 """
 
+from collections.abc import Callable
 from functools import wraps
 from inspect import signature
-from typing import TYPE_CHECKING
+from types import ModuleType
+from typing import TypeVar
 
 __all__ = ["get_xp"]
 
-if TYPE_CHECKING:
-    from collections.abc import Callable
-    from types import ModuleType
-    from typing import TypeVar
+_T = TypeVar("_T")
 
-    _T = TypeVar("_T")
 
-
-def get_xp(xp: "ModuleType") -> "Callable[[Callable[..., _T]], Callable[..., _T]]":
+def get_xp(xp: ModuleType) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
     """
     Decorator to automatically replace xp with the corresponding array module.
 
@@ -33,7 +30,7 @@ def func(x, /, xp, kwarg=None):
 
     """
 
-    def inner(f: "Callable[..., _T]", /) -> "Callable[..., _T]":
+    def inner(f: Callable[..., _T], /) -> Callable[..., _T]:
         @wraps(f)
         def wrapped_f(*args: object, **kwargs: object) -> object:
             return f(*args, xp=xp, **kwargs)
diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py
index cd4b6af8..8b106d72 100644
--- a/array_api_compat/common/_aliases.py
+++ b/array_api_compat/common/_aliases.py
@@ -13,6 +13,7 @@
 from ._typing import Array, Device, DType, Namespace
 
 if TYPE_CHECKING:
+    # TODO: import from typing (requires Python >=3.13)
     from typing_extensions import TypeIs
 
 # These functions are modified from the NumPy versions.
diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py
index b14be849..551d09f7 100644
--- a/array_api_compat/common/_helpers.py
+++ b/array_api_compat/common/_helpers.py
@@ -12,12 +12,22 @@
 import math
 import sys
 import warnings
-from typing import TYPE_CHECKING, Any, Literal, SupportsIndex, cast, overload
+from collections.abc import Collection
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Literal,
+    SupportsIndex,
+    TypeAlias,
+    TypeGuard,
+    TypeVar,
+    cast,
+    overload,
+)
 
 from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace
 
 if TYPE_CHECKING:
-    from collections.abc import Collection
 
     import dask.array as da
     import jax
@@ -26,7 +36,9 @@
     import numpy.typing as npt
     import sparse  # pyright: ignore[reportMissingTypeStubs]
     import torch
-    from typing_extensions import TypeAlias, TypeGuard, TypeIs, TypeVar
+
+    # TODO: import from typing (requires Python >=3.13)
+    from typing_extensions import TypeIs, TypeVar
 
     _SizeT = TypeVar("_SizeT", bound=int | None)
 
diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py
index 39771f4e..41d02dd0 100644
--- a/array_api_compat/numpy/_aliases.py
+++ b/array_api_compat/numpy/_aliases.py
@@ -2,7 +2,7 @@
 from __future__ import annotations
 
 from builtins import bool as py_bool
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast
 
 import numpy as np
 
@@ -13,11 +13,9 @@
 from ._typing import Array, Device, DType
 
 if TYPE_CHECKING:
-    from typing import Any, Literal
+    from typing_extensions import Buffer, TypeIs
 
-    from typing_extensions import Buffer, TypeAlias, TypeIs
-
-    _Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode
+_Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode
 
 bool = np.bool_
 

From 0dd925f79185d13aed000f98a00b272a3bd6a066 Mon Sep 17 00:00:00 2001
From: jorenham <jhammudoglu@gmail.com>
Date: Tue, 15 Apr 2025 14:25:10 +0200
Subject: [PATCH 17/24] TYP: sprinkle some `TypeAlias`es and `Final`s around

---
 array_api_compat/common/_fft.py    |  4 ++--
 array_api_compat/common/_typing.py | 10 +++++-----
 array_api_compat/numpy/__init__.py |  2 +-
 array_api_compat/numpy/_typing.py  | 10 +++++-----
 4 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py
index 6fe834dc..5e75e85f 100644
--- a/array_api_compat/common/_fft.py
+++ b/array_api_compat/common/_fft.py
@@ -1,11 +1,11 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
-from typing import Literal
+from typing import Literal, TypeAlias
 
 from ._typing import Array, Device, DType, Namespace
 
-_Norm = Literal["backward", "ortho", "forward"]
+_Norm: TypeAlias = Literal["backward", "ortho", "forward"]
 
 # Note: NumPy fft functions improperly upcast float32 and complex64 to
 # complex128, which is why we require wrapping them all here.
diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py
index a38d083b..cf9a00d7 100644
--- a/array_api_compat/common/_typing.py
+++ b/array_api_compat/common/_typing.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from types import ModuleType as Namespace
-from typing import Any, Protocol, TypeVar
+from typing import Any, Protocol, TypeAlias, TypeVar
 
 __all__ = [
     "Array",
@@ -30,7 +30,7 @@ class HasShape(Protocol[_T_co]):
     def shape(self, /) -> _T_co: ...
 
 
-SupportsBufferProtocol = Any
-Array = Any
-Device = Any
-DType = Any
+SupportsBufferProtocol: TypeAlias = Any
+Array: TypeAlias = Any
+Device: TypeAlias = Any
+DType: TypeAlias = Any
diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py
index caaf67bb..f7b558ba 100644
--- a/array_api_compat/numpy/__init__.py
+++ b/array_api_compat/numpy/__init__.py
@@ -32,4 +32,4 @@
 except ImportError:
     pass
 
-__array_api_version__ = "2024.12"
+__array_api_version__: Final = "2024.12"
diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py
index d4a285b9..0796f7d0 100644
--- a/array_api_compat/numpy/_typing.py
+++ b/array_api_compat/numpy/_typing.py
@@ -3,13 +3,13 @@
 __all__ = ["Array", "DType", "Device"]
 _all_ignore = ["np"]
 
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Any, Literal, TypeAlias
 
 import numpy as np
 
-Device = Literal["cpu"]
+Device: TypeAlias = Literal["cpu"]
+
 if TYPE_CHECKING:
-    from typing_extensions import TypeAlias
 
     # NumPy 1.x on Python 3.10 fails to parse np.dtype[]
     DType: TypeAlias = np.dtype[
@@ -22,5 +22,5 @@
     ]
     Array: TypeAlias = np.ndarray[Any, DType]
 else:
-    DType = np.dtype
-    Array = np.ndarray
+    DType: TypeAlias = np.dtype
+    Array: TypeAlias = np.ndarray

From 953d7c051be9aca5b269ae5e45eadd35ff8b9dea Mon Sep 17 00:00:00 2001
From: jorenham <jhammudoglu@gmail.com>
Date: Tue, 15 Apr 2025 14:41:59 +0200
Subject: [PATCH 18/24] TYP: `__dir__`

---
 array_api_compat/_internal.py       |  9 +++++++--
 array_api_compat/common/__init__.py |  2 +-
 array_api_compat/common/_aliases.py |  5 ++++-
 array_api_compat/common/_fft.py     |  3 +++
 array_api_compat/common/_helpers.py |  3 +++
 array_api_compat/common/_linalg.py  |  4 ++++
 array_api_compat/common/_typing.py  | 27 ++++++++++++++++-----------
 array_api_compat/numpy/_aliases.py  |  5 ++++-
 array_api_compat/numpy/_info.py     |  7 +++++++
 array_api_compat/numpy/_typing.py   | 10 +++++++---
 array_api_compat/numpy/fft.py       |  6 ++++++
 11 files changed, 62 insertions(+), 19 deletions(-)

diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py
index 6df63097..cd8d939f 100644
--- a/array_api_compat/_internal.py
+++ b/array_api_compat/_internal.py
@@ -8,8 +8,6 @@
 from types import ModuleType
 from typing import TypeVar
 
-__all__ = ["get_xp"]
-
 _T = TypeVar("_T")
 
 
@@ -52,3 +50,10 @@ def wrapped_f(*args: object, **kwargs: object) -> object:
         return wrapped_f  # pyright: ignore[reportReturnType]
 
     return inner
+
+
+__all__ = ["get_xp"]
+
+
+def __dir__() -> list[str]:
+    return __all__
diff --git a/array_api_compat/common/__init__.py b/array_api_compat/common/__init__.py
index 91ab1c40..82360807 100644
--- a/array_api_compat/common/__init__.py
+++ b/array_api_compat/common/__init__.py
@@ -1 +1 @@
-from ._helpers import * # noqa: F403
+from ._helpers import *  # noqa: F403
diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py
index 8b106d72..7f3cf914 100644
--- a/array_api_compat/common/_aliases.py
+++ b/array_api_compat/common/_aliases.py
@@ -720,5 +720,8 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any:
     "finfo",
     "iinfo",
 ]
-
 _all_ignore = ["inspect", "array_namespace", "NamedTuple"]
+
+
+def __dir__() -> list[str]:
+    return __all__
diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py
index 5e75e85f..18839d37 100644
--- a/array_api_compat/common/_fft.py
+++ b/array_api_compat/common/_fft.py
@@ -208,3 +208,6 @@ def ifftshift(
     "fftshift",
     "ifftshift",
 ]
+
+def __dir__() -> list[str]:
+    return __all__
diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py
index 551d09f7..0ce6edf0 100644
--- a/array_api_compat/common/_helpers.py
+++ b/array_api_compat/common/_helpers.py
@@ -1044,3 +1044,6 @@ def is_lazy_array(x: object) -> bool:
 ]
 
 _all_ignore = ["sys", "math", "inspect", "warnings"]
+
+def __dir__() -> list[str]:
+    return __all__
diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py
index 5a6f070e..7e002aed 100644
--- a/array_api_compat/common/_linalg.py
+++ b/array_api_compat/common/_linalg.py
@@ -226,3 +226,7 @@ def trace(
            'trace']
 
 _all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype']
+
+
+def __dir__() -> list[str]:
+    return __all__
diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py
index cf9a00d7..477b6a4e 100644
--- a/array_api_compat/common/_typing.py
+++ b/array_api_compat/common/_typing.py
@@ -3,17 +3,6 @@
 from types import ModuleType as Namespace
 from typing import Any, Protocol, TypeAlias, TypeVar
 
-__all__ = [
-    "Array",
-    "SupportsArrayNamespace",
-    "DType",
-    "Device",
-    "HasShape",
-    "Namespace",
-    "NestedSequence",
-    "SupportsBufferProtocol",
-]
-
 _T_co = TypeVar("_T_co", covariant=True)
 
 class NestedSequence(Protocol[_T_co]):
@@ -34,3 +23,19 @@ def shape(self, /) -> _T_co: ...
 Array: TypeAlias = Any
 Device: TypeAlias = Any
 DType: TypeAlias = Any
+
+
+__all__ = [
+    "Array",
+    "SupportsArrayNamespace",
+    "DType",
+    "Device",
+    "HasShape",
+    "Namespace",
+    "NestedSequence",
+    "SupportsBufferProtocol",
+]
+
+
+def __dir__() -> list[str]:
+    return __all__
diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py
index 41d02dd0..2c03041f 100644
--- a/array_api_compat/numpy/_aliases.py
+++ b/array_api_compat/numpy/_aliases.py
@@ -173,5 +173,8 @@ def count_nonzero(
     "pow",
 ]
 __all__ += _aliases.__all__
-
 _all_ignore = ["np", "get_xp"]
+
+
+def __dir__() -> list[str]:
+    return __all__
diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py
index ca019b2c..f307f62c 100644
--- a/array_api_compat/numpy/_info.py
+++ b/array_api_compat/numpy/_info.py
@@ -357,3 +357,10 @@ def devices(self) -> list[Device]:
 
         """
         return ["cpu"]
+
+
+__all__ = ["__array_namespace_info__"]
+
+
+def __dir__() -> list[str]:
+    return __all__
diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py
index 0796f7d0..e771c788 100644
--- a/array_api_compat/numpy/_typing.py
+++ b/array_api_compat/numpy/_typing.py
@@ -1,8 +1,5 @@
 from __future__ import annotations
 
-__all__ = ["Array", "DType", "Device"]
-_all_ignore = ["np"]
-
 from typing import TYPE_CHECKING, Any, Literal, TypeAlias
 
 import numpy as np
@@ -24,3 +21,10 @@
 else:
     DType: TypeAlias = np.dtype
     Array: TypeAlias = np.ndarray
+
+__all__ = ["Array", "DType", "Device"]
+_all_ignore = ["np"]
+
+
+def __dir__() -> list[str]:
+    return __all__
diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py
index 5423bd01..06875f00 100644
--- a/array_api_compat/numpy/fft.py
+++ b/array_api_compat/numpy/fft.py
@@ -20,9 +20,15 @@
 fftshift = get_xp(np)(_fft.fftshift)
 ifftshift = get_xp(np)(_fft.ifftshift)
 
+
 __all__ = ["rfft2", "irfft2", "fft2", "ifft2"]
 __all__ += _fft.__all__
 
+
+def __dir__() -> list[str]:
+    return __all__
+
+
 del get_xp
 del np
 del fft_all

From c66b750684f85ed92b355b5a9b90ec74da2e4b39 Mon Sep 17 00:00:00 2001
From: jorenham <jhammudoglu@gmail.com>
Date: Tue, 15 Apr 2025 15:21:03 +0200
Subject: [PATCH 19/24] TYP: fix typing errors in `numpy.linalg`

---
 array_api_compat/numpy/linalg.py | 97 ++++++++++++++++++++++++--------
 1 file changed, 75 insertions(+), 22 deletions(-)

diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py
index 8f01593b..2d3e731d 100644
--- a/array_api_compat/numpy/linalg.py
+++ b/array_api_compat/numpy/linalg.py
@@ -1,14 +1,35 @@
-from numpy.linalg import * # noqa: F403
-from numpy.linalg import __all__ as linalg_all
-import numpy as _np
+# pyright: reportAttributeAccessIssue=false
+# pyright: reportUnknownArgumentType=false
+# pyright: reportUnknownMemberType=false
+# pyright: reportUnknownVariableType=false
+
+from __future__ import annotations
+
+import numpy as np
+
+# intersection of `np.linalg.__all__` on numpy 1.22 and 2.2, minus `_linalg.__all__`
+from numpy.linalg import (
+    LinAlgError,
+    cond,
+    det,
+    eig,
+    eigvals,
+    eigvalsh,
+    inv,
+    lstsq,
+    matrix_power,
+    multi_dot,
+    norm,
+    tensorinv,
+    tensorsolve,
+)
 
-from ..common import _linalg
 from .._internal import get_xp
+from ..common import _linalg
 
 # These functions are in both the main and linalg namespaces
-from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401
-
-import numpy as np
+from ._aliases import matmul, matrix_transpose, tensordot, vecdot  # noqa: F401
+from ._typing import Array
 
 cross = get_xp(np)(_linalg.cross)
 outer = get_xp(np)(_linalg.outer)
@@ -38,19 +59,28 @@
 # To workaround this, the below is the code from np.linalg.solve except
 # only calling solve1 in the exactly 1D case.
 
+
 # This code is here instead of in common because it is numpy specific. Also
 # note that CuPy's solve() does not currently support broadcasting (see
 # https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43).
-def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray:
+def solve(x1: Array, x2: Array, /) -> Array:
     try:
         from numpy.linalg._linalg import (
-        _makearray, _assert_stacked_2d, _assert_stacked_square,
-        _commonType, isComplexType, _raise_linalgerror_singular
+            _assert_stacked_2d,
+            _assert_stacked_square,
+            _commonType,
+            _makearray,
+            _raise_linalgerror_singular,
+            isComplexType,
         )
     except ImportError:
         from numpy.linalg.linalg import (
-        _makearray, _assert_stacked_2d, _assert_stacked_square,
-        _commonType, isComplexType, _raise_linalgerror_singular
+            _assert_stacked_2d,
+            _assert_stacked_square,
+            _commonType,
+            _makearray,
+            _raise_linalgerror_singular,
+            isComplexType,
         )
     from numpy.linalg import _umath_linalg
 
@@ -61,6 +91,7 @@ def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray:
     t, result_t = _commonType(x1, x2)
 
     # This part is different from np.linalg.solve
+    gufunc: np.ufunc
     if x2.ndim == 1:
         gufunc = _umath_linalg.solve1
     else:
@@ -68,23 +99,45 @@ def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray:
 
     # This does nothing currently but is left in because it will be relevant
     # when complex dtype support is added to the spec in 2022.
-    signature = 'DD->D' if isComplexType(t) else 'dd->d'
-    with _np.errstate(call=_raise_linalgerror_singular, invalid='call',
-                      over='ignore', divide='ignore', under='ignore'):
-        r = gufunc(x1, x2, signature=signature)
+    signature = "DD->D" if isComplexType(t) else "dd->d"
+    with np.errstate(
+        call=_raise_linalgerror_singular,
+        invalid="call",
+        over="ignore",
+        divide="ignore",
+        under="ignore",
+    ):
+        r: Array = gufunc(x1, x2, signature=signature)
 
     return wrap(r.astype(result_t, copy=False))
 
+
 # These functions are completely new here. If the library already has them
 # (i.e., numpy 2.0), use the library version instead of our wrapper.
-if hasattr(np.linalg, 'vector_norm'):
+if hasattr(np.linalg, "vector_norm"):
     vector_norm = np.linalg.vector_norm
 else:
     vector_norm = get_xp(np)(_linalg.vector_norm)
 
-__all__ = linalg_all + _linalg.__all__ + ['solve']
 
-del get_xp
-del np
-del linalg_all
-del _linalg
+__all__ = [
+    "LinAlgError",
+    "cond",
+    "det",
+    "eig",
+    "eigvals",
+    "eigvalsh",
+    "inv",
+    "lstsq",
+    "matrix_power",
+    "multi_dot",
+    "norm",
+    "tensorinv",
+    "tensorsolve",
+]
+__all__ += _linalg.__all__
+__all__ += ["solve", "vector_norm"]
+
+
+def __dir__() -> list[str]:
+    return __all__

From 9acba462fcae27c9db4027440386c2cdacfef78a Mon Sep 17 00:00:00 2001
From: jorenham <jhammudoglu@gmail.com>
Date: Tue, 15 Apr 2025 15:45:24 +0200
Subject: [PATCH 20/24] TYP: add a `common._typing.Capabilities` typed dict
 type

---
 array_api_compat/common/_typing.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py
index 477b6a4e..331e5a8f 100644
--- a/array_api_compat/common/_typing.py
+++ b/array_api_compat/common/_typing.py
@@ -1,10 +1,11 @@
 from __future__ import annotations
 
 from types import ModuleType as Namespace
-from typing import Any, Protocol, TypeAlias, TypeVar
+from typing import Any, Protocol, TypeAlias, TypedDict, TypeVar
 
 _T_co = TypeVar("_T_co", covariant=True)
 
+
 class NestedSequence(Protocol[_T_co]):
     def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
     def __len__(self, /) -> int: ...
@@ -19,6 +20,16 @@ class HasShape(Protocol[_T_co]):
     def shape(self, /) -> _T_co: ...
 
 
+Capabilities = TypedDict(
+    "Capabilities",
+    {
+        "boolean indexing": bool,
+        "data-dependent shapes": bool,
+        "max dimensions": int,
+    },
+)
+
+
 SupportsBufferProtocol: TypeAlias = Any
 Array: TypeAlias = Any
 Device: TypeAlias = Any
@@ -27,12 +38,13 @@ def shape(self, /) -> _T_co: ...
 
 __all__ = [
     "Array",
-    "SupportsArrayNamespace",
+    "Capabilities",
     "DType",
     "Device",
     "HasShape",
     "Namespace",
     "NestedSequence",
+    "SupportsArrayNamespace",
     "SupportsBufferProtocol",
 ]
 

From ba0b4e58069cf9cd5a31df8f8483d867fbf33ea1 Mon Sep 17 00:00:00 2001
From: jorenham <jhammudoglu@gmail.com>
Date: Tue, 15 Apr 2025 16:26:34 +0200
Subject: [PATCH 21/24] TYP: `__array_namespace_info__` helper types

---
 array_api_compat/common/_typing.py | 107 +++++++++++++++++++++++++++--
 1 file changed, 102 insertions(+), 5 deletions(-)

diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py
index 331e5a8f..d7deade1 100644
--- a/array_api_compat/common/_typing.py
+++ b/array_api_compat/common/_typing.py
@@ -1,7 +1,22 @@
 from __future__ import annotations
 
+from collections.abc import Mapping
 from types import ModuleType as Namespace
-from typing import Any, Protocol, TypeAlias, TypedDict, TypeVar
+from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, TypeVar
+
+if TYPE_CHECKING:
+    from _typeshed import Incomplete
+
+    SupportsBufferProtocol: TypeAlias = Incomplete
+    Array: TypeAlias = Incomplete
+    Device: TypeAlias = Incomplete
+    DType: TypeAlias = Incomplete
+else:
+    SupportsBufferProtocol = object
+    Array = object
+    Device = object
+    DType = object
+
 
 _T_co = TypeVar("_T_co", covariant=True)
 
@@ -20,6 +35,7 @@ class HasShape(Protocol[_T_co]):
     def shape(self, /) -> _T_co: ...
 
 
+# Return type of `__array_namespace_info__.default_dtypes`
 Capabilities = TypedDict(
     "Capabilities",
     {
@@ -29,17 +45,98 @@ def shape(self, /) -> _T_co: ...
     },
 )
 
+# Return type of `__array_namespace_info__.default_dtypes`
+DefaultDTypes = TypedDict(
+    "DefaultDTypes",
+    {
+        "real floating": DType,
+        "complex floating": DType,
+        "integral": DType,
+        "indexing": DType,
+    },
+)
+
+
+_DTypeKind: TypeAlias = Literal[
+    "bool",
+    "signed integer",
+    "unsigned integer",
+    "integral",
+    "real floating",
+    "complex floating",
+    "numeric",
+]
+# Type of the `kind` parameter in `__array_namespace_info__.dtypes`
+DTypeKind: TypeAlias = _DTypeKind | tuple[_DTypeKind, ...]
+
+
+# `__array_namespace_info__.dtypes(kind="bool")`
+class DTypesBool(TypedDict):
+    bool: DType
+
+
+# `__array_namespace_info__.dtypes(kind="signed integer")`
+class DTypesSigned(TypedDict):
+    int8: DType
+    int16: DType
+    int32: DType
+    int64: DType
+
+
+# `__array_namespace_info__.dtypes(kind="unsigned integer")`
+class DTypesUnsigned(TypedDict):
+    uint8: DType
+    uint16: DType
+    uint32: DType
+    uint64: DType
+
+
+# `__array_namespace_info__.dtypes(kind="integral")`
+class DTypesIntegral(DTypesSigned, DTypesUnsigned):
+    pass
+
+
+# `__array_namespace_info__.dtypes(kind="real floating")`
+class DTypesReal(TypedDict):
+    float32: DType
+    float64: DType
+
+
+# `__array_namespace_info__.dtypes(kind="complex floating")`
+class DTypesComplex(TypedDict):
+    complex64: DType
+    complex128: DType
+
+
+# `__array_namespace_info__.dtypes(kind="numeric")`
+class DTypesNumeric(DTypesIntegral, DTypesReal, DTypesComplex):
+    pass
+
+
+# `__array_namespace_info__.dtypes(kind=None)` (default)
+class DTypesAll(DTypesBool, DTypesNumeric):
+    pass
+
 
-SupportsBufferProtocol: TypeAlias = Any
-Array: TypeAlias = Any
-Device: TypeAlias = Any
-DType: TypeAlias = Any
+# `__array_namespace_info__.dtypes(kind=?)` (fallback)
+DTypesAny: TypeAlias = Mapping[str, DType]
 
 
 __all__ = [
     "Array",
     "Capabilities",
     "DType",
+    "DTypeKind",
+    "DTypesAny",
+    "DTypesAll",
+    "DTypesBool",
+    "DTypesNumeric",
+    "DTypesIntegral",
+    "DTypesSigned",
+    "DTypesUnsigned",
+    "DTypesReal",
+    "DTypesComplex",
+    "DefaultDTypes",
     "Device",
     "HasShape",
     "Namespace",

From 4278dfba75c3b2a3063abfc35245acc9212c6ff5 Mon Sep 17 00:00:00 2001
From: jorenham <jhammudoglu@gmail.com>
Date: Tue, 15 Apr 2025 16:42:35 +0200
Subject: [PATCH 22/24] TYP: `dask.array` typing fixes and improvements

---
 array_api_compat/dask/array/__init__.py |   8 +-
 array_api_compat/dask/array/_aliases.py | 162 ++++++++++++++----------
 array_api_compat/dask/array/_info.py    |  96 +++++++++++---
 array_api_compat/dask/array/linalg.py   |  22 ++--
 4 files changed, 188 insertions(+), 100 deletions(-)

diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py
index bb649306..1e47b960 100644
--- a/array_api_compat/dask/array/__init__.py
+++ b/array_api_compat/dask/array/__init__.py
@@ -1,9 +1,11 @@
-from dask.array import * # noqa: F403
+from typing import Final
+
+from dask.array import *  # noqa: F403
 
 # These imports may overwrite names from the import * above.
-from ._aliases import * # noqa: F403
+from ._aliases import *  # noqa: F403
 
-__array_api_version__ = '2024.12'
+__array_api_version__: Final = "2024.12"
 
 # See the comment in the numpy __init__.py
 __import__(__package__ + '.linalg')
diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py
index e7ddde78..9687a9cd 100644
--- a/array_api_compat/dask/array/_aliases.py
+++ b/array_api_compat/dask/array/_aliases.py
@@ -1,28 +1,38 @@
+# pyright: reportPrivateUsage=false
+# pyright: reportUnknownArgumentType=false
+# pyright: reportUnknownMemberType=false
+# pyright: reportUnknownVariableType=false
+
 from __future__ import annotations
 
-from typing import Callable, Optional, Union
+from builtins import bool as py_bool
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from typing_extensions import TypeIs
 
+import dask.array as da
 import numpy as np
+from numpy import bool_ as bool
 from numpy import (
-    # dtypes
-    bool_ as bool,
+    can_cast,
+    complex64,
+    complex128,
     float32,
     float64,
     int8,
     int16,
     int32,
     int64,
+    result_type,
     uint8,
     uint16,
     uint32,
     uint64,
-    complex64,
-    complex128,
-    can_cast,
-    result_type,
 )
-import dask.array as da
 
+from ..._internal import get_xp
 from ...common import _aliases, _helpers, array_namespace
 from ...common._typing import (
     Array,
@@ -31,7 +41,6 @@
     NestedSequence,
     SupportsBufferProtocol,
 )
-from ..._internal import get_xp
 from ._info import __array_namespace_info__
 
 isdtype = get_xp(np)(_aliases.isdtype)
@@ -44,8 +53,8 @@ def astype(
     dtype: DType,
     /,
     *,
-    copy: bool = True,
-    device: Optional[Device] = None,
+    copy: py_bool = True,
+    device: Device | None = None,
 ) -> Array:
     """
     Array API compatibility wrapper for astype().
@@ -69,14 +78,14 @@ def astype(
 # not pass stop/step as keyword arguments, which will cause
 # an error with dask
 def arange(
-    start: Union[int, float],
+    start: float,
     /,
-    stop: Optional[Union[int, float]] = None,
-    step: Union[int, float] = 1,
+    stop: float | None = None,
+    step: float = 1,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    **kwargs: object,
 ) -> Array:
     """
     Array API compatibility wrapper for arange().
@@ -87,7 +96,7 @@ def arange(
     # TODO: respect device keyword?
     _helpers._check_device(da, device)
 
-    args = [start]
+    args: list[Any] = [start]
     if stop is not None:
         args.append(stop)
     else:
@@ -137,18 +146,13 @@ def arange(
 
 # asarray also adds the copy keyword, which is not present in numpy 1.0.
 def asarray(
-    obj: (
-        Array 
-        | bool | int | float | complex 
-        | NestedSequence[bool | int | float | complex] 
-        | SupportsBufferProtocol
-    ),
+    obj: complex | NestedSequence[complex] | Array | SupportsBufferProtocol,
     /,
     *,
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
-    copy: Optional[bool] = None,
-    **kwargs,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    copy: py_bool | None = None,
+    **kwargs: object,
 ) -> Array:
     """
     Array API compatibility wrapper for asarray().
@@ -164,7 +168,7 @@ def asarray(
             if copy is False:
                 raise ValueError("Unable to avoid copy when changing dtype")
             obj = obj.astype(dtype)
-        return obj.copy() if copy else obj
+        return obj.copy() if copy else obj  # pyright: ignore[reportAttributeAccessIssue]
 
     if copy is False:
         raise NotImplementedError(
@@ -177,22 +181,21 @@ def asarray(
     return da.from_array(obj)
 
 
-from dask.array import (
-    # Element wise aliases
-    arccos as acos,
-    arccosh as acosh,
-    arcsin as asin,
-    arcsinh as asinh,
-    arctan as atan,
-    arctan2 as atan2,
-    arctanh as atanh,
-    left_shift as bitwise_left_shift,
-    right_shift as bitwise_right_shift,
-    invert as bitwise_invert,
-    power as pow,
-    # Other
-    concatenate as concat,
-)
+# Element wise aliases
+from dask.array import arccos as acos
+from dask.array import arccosh as acosh
+from dask.array import arcsin as asin
+from dask.array import arcsinh as asinh
+from dask.array import arctan as atan
+from dask.array import arctan2 as atan2
+from dask.array import arctanh as atanh
+
+# Other
+from dask.array import concatenate as concat
+from dask.array import invert as bitwise_invert
+from dask.array import left_shift as bitwise_left_shift
+from dask.array import power as pow
+from dask.array import right_shift as bitwise_right_shift
 
 
 # dask.array.clip does not work unless all three arguments are provided.
@@ -202,8 +205,8 @@ def asarray(
 def clip(
     x: Array,
     /,
-    min: Optional[Union[int, float, Array]] = None,
-    max: Optional[Union[int, float, Array]] = None,
+    min: float | Array | None = None,
+    max: float | Array | None = None,
 ) -> Array:
     """
     Array API compatibility wrapper for clip().
@@ -212,8 +215,8 @@ def clip(
     specification for more details.
     """
 
-    def _isscalar(a):
-        return isinstance(a, (int, float, type(None)))
+    def _isscalar(a: float | Array | None, /) -> TypeIs[float | None]:
+        return a is None or isinstance(a, (int, float))
 
     min_shape = () if _isscalar(min) else min.shape
     max_shape = () if _isscalar(max) else max.shape
@@ -266,7 +269,12 @@ def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array],
 
 
 def sort(
-    x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
+    x: Array,
+    /,
+    *,
+    axis: int = -1,
+    descending: py_bool = False,
+    stable: py_bool = True,
 ) -> Array:
     """
     Array API compatibility layer around the lack of sort() in Dask.
@@ -296,7 +304,12 @@ def sort(
 
 
 def argsort(
-    x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
+    x: Array,
+    /,
+    *,
+    axis: int = -1,
+    descending: py_bool = False,
+    stable: py_bool = True,
 ) -> Array:
     """
     Array API compatibility layer around the lack of argsort() in Dask.
@@ -330,25 +343,34 @@ def argsort(
 # dask.array.count_nonzero does not have keepdims
 def count_nonzero(
     x: Array,
-    axis=None,
-    keepdims=False
+    axis: int | None = None,
+    keepdims: py_bool = False,
 ) -> Array:
-   result = da.count_nonzero(x, axis)
-   if keepdims:
-       if axis is None:
-            return da.reshape(result, [1]*x.ndim)
-       return da.expand_dims(result, axis)
-   return result
-
-
+    result = da.count_nonzero(x, axis)
+    if keepdims:
+        if axis is None:
+            return da.reshape(result, [1] * x.ndim)
+        return da.expand_dims(result, axis)
+    return result
+
+
+__all__ = [
+    "__array_namespace_info__",
+    "count_nonzero",
+    "bool",
+    "int8", "int16", "int32", "int64",
+    "uint8", "uint16", "uint32", "uint64",
+    "float32", "float64",
+    "complex64", "complex128",
+    "asarray", "astype", "can_cast", "result_type",
+    "pow",
+    "concat",
+    "acos", "acosh", "asin", "asinh", "atan", "atan2", "atanh",
+    "bitwise_left_shift", "bitwise_right_shift", "bitwise_invert",
+]  # fmt: skip
+__all__ += _aliases.__all__
+_all_ignore = ["array_namespace", "get_xp", "da", "np"]
 
-__all__ = _aliases.__all__ + [
-                    '__array_namespace_info__', 'asarray', 'astype', 'acos',
-                    'acosh', 'asin', 'asinh', 'atan', 'atan2',
-                    'atanh', 'bitwise_left_shift', 'bitwise_invert',
-                    'bitwise_right_shift', 'concat', 'pow', 'can_cast',
-                    'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64',
-                    'uint8', 'uint16', 'uint32', 'uint64', 'complex64', 'complex128',
-                    'can_cast', 'count_nonzero', 'result_type']
 
-_all_ignore = ["array_namespace", "get_xp", "da", "np"]
+def __dir__() -> list[str]:
+    return __all__
diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py
index 614f43d9..9e4d736f 100644
--- a/array_api_compat/dask/array/_info.py
+++ b/array_api_compat/dask/array/_info.py
@@ -7,25 +7,51 @@
 more details.
 
 """
+
+# pyright: reportPrivateUsage=false
+
+from __future__ import annotations
+
+from typing import Literal as L
+from typing import TypeAlias, overload
+
+from numpy import bool_ as bool
 from numpy import (
+    complex64,
+    complex128,
     dtype,
-    bool_ as bool,
-    intp,
+    float32,
+    float64,
     int8,
     int16,
     int32,
     int64,
+    intp,
     uint8,
     uint16,
     uint32,
     uint64,
-    float32,
-    float64,
-    complex64,
-    complex128,
 )
 
-from ...common._helpers import _DASK_DEVICE
+from ...common._helpers import _DASK_DEVICE, _dask_device
+from ...common._typing import (
+    Capabilities,
+    DefaultDTypes,
+    DType,
+    DTypeKind,
+    DTypesAll,
+    DTypesAny,
+    DTypesBool,
+    DTypesComplex,
+    DTypesIntegral,
+    DTypesNumeric,
+    DTypesReal,
+    DTypesSigned,
+    DTypesUnsigned,
+)
+
+_Device: TypeAlias = L["cpu"] | _dask_device
+
 
 class __array_namespace_info__:
     """
@@ -59,9 +85,9 @@ class __array_namespace_info__:
 
     """
 
-    __module__ = 'dask.array'
+    __module__ = "dask.array"
 
-    def capabilities(self):
+    def capabilities(self) -> Capabilities:
         """
         Return a dictionary of array API library capabilities.
 
@@ -116,7 +142,7 @@ def capabilities(self):
             "max dimensions": 64,
         }
 
-    def default_device(self):
+    def default_device(self) -> L["cpu"]:
         """
         The default device used for new Dask arrays.
 
@@ -143,7 +169,7 @@ def default_device(self):
         """
         return "cpu"
 
-    def default_dtypes(self, *, device=None):
+    def default_dtypes(self, /, *, device: _Device | None = None) -> DefaultDTypes:
         """
         The default data types used for new Dask arrays.
 
@@ -184,8 +210,8 @@ def default_dtypes(self, *, device=None):
         """
         if device not in ["cpu", _DASK_DEVICE, None]:
             raise ValueError(
-                'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:'
-                f' {device}'
+                f'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, '
+                f"but received: {device!r}"
             )
         return {
             "real floating": dtype(float64),
@@ -194,7 +220,41 @@ def default_dtypes(self, *, device=None):
             "indexing": dtype(intp),
         }
 
-    def dtypes(self, *, device=None, kind=None):
+    @overload
+    def dtypes(
+        self, /, *, device: _Device | None = None, kind: None = None
+    ) -> DTypesAll: ...
+    @overload
+    def dtypes(
+        self, /, *, device: _Device | None = None, kind: L["bool"]
+    ) -> DTypesBool: ...
+    @overload
+    def dtypes(
+        self, /, *, device: _Device | None = None, kind: L["signed integer"]
+    ) -> DTypesSigned: ...
+    @overload
+    def dtypes(
+        self, /, *, device: _Device | None = None, kind: L["unsigned integer"]
+    ) -> DTypesUnsigned: ...
+    @overload
+    def dtypes(
+        self, /, *, device: _Device | None = None, kind: L["integral"]
+    ) -> DTypesIntegral: ...
+    @overload
+    def dtypes(
+        self, /, *, device: _Device | None = None, kind: L["real floating"]
+    ) -> DTypesReal: ...
+    @overload
+    def dtypes(
+        self, /, *, device: _Device | None = None, kind: L["complex floating"]
+    ) -> DTypesComplex: ...
+    @overload
+    def dtypes(
+        self, /, *, device: _Device | None = None, kind: L["numeric"]
+    ) -> DTypesNumeric: ...
+    def dtypes(
+        self, /, *, device: _Device | None = None, kind: DTypeKind | None = None
+    ) -> DTypesAny:
         """
         The array API data types supported by Dask.
 
@@ -251,7 +311,7 @@ def dtypes(self, *, device=None, kind=None):
         if device not in ["cpu", _DASK_DEVICE, None]:
             raise ValueError(
                 'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:'
-                f' {device}'
+                f" {device}"
             )
         if kind is None:
             return {
@@ -321,14 +381,14 @@ def dtypes(self, *, device=None, kind=None):
                 "complex64": dtype(complex64),
                 "complex128": dtype(complex128),
             }
-        if isinstance(kind, tuple):
-            res = {}
+        if isinstance(kind, tuple):  # type: ignore[reportUnnecessaryIsinstanceCall]
+            res: dict[str, DType] = {}
             for k in kind:
                 res.update(self.dtypes(kind=k))
             return res
         raise ValueError(f"unsupported kind: {kind!r}")
 
-    def devices(self):
+    def devices(self) -> list[_Device]:
         """
         The devices supported by Dask.
 
diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py
index bd53f0df..fd509a38 100644
--- a/array_api_compat/dask/array/linalg.py
+++ b/array_api_compat/dask/array/linalg.py
@@ -3,15 +3,16 @@
 from typing import Literal
 
 import dask.array as da
-# Exports
-from dask.array.linalg import * # noqa: F403
-from dask.array import outer
+
 # These functions are in both the main and linalg namespaces
-from dask.array import matmul, tensordot
+from dask.array import matmul, outer, tensordot
+
+# Exports
+from dask.array.linalg import *  # noqa: F403
 
 from ..._internal import get_xp
 from ...common import _linalg
-from ...common._typing import Array
+from ...common._typing import Array as _Array
 from ._aliases import matrix_transpose, vecdot
 
 # dask.array.linalg doesn't have __all__. If it is added, replace this with
@@ -32,8 +33,11 @@
 # supports the mode keyword on QR
 # https://github.com/dask/dask/issues/10388
 #qr = get_xp(da)(_linalg.qr)
-def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced',
-       **kwargs) -> QRResult:
+def qr(
+    x: _Array,
+    mode: Literal["reduced", "complete"] = "reduced",
+    **kwargs: object,
+) -> QRResult:
     if mode != "reduced":
         raise ValueError("dask arrays only support using mode='reduced'")
     return QRResult(*da.linalg.qr(x, **kwargs))
@@ -46,12 +50,12 @@ def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced',
 # Wrap the svd functions to not pass full_matrices to dask
 # when full_matrices=False (as that is the default behavior for dask),
 # and dask doesn't have the full_matrices keyword
-def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult:
+def svd(x: _Array, full_matrices: bool = True, **kwargs) -> SVDResult:
     if full_matrices:
         raise ValueError("full_matrics=True is not supported by dask.")
     return da.linalg.svd(x, coerce_signs=False, **kwargs)
 
-def svdvals(x: Array) -> Array:
+def svdvals(x: _Array) -> _Array:
     # TODO: can't avoid computing U or V for dask
     _, s, _ =  svd(x)
     return s

From 4f6ef6d46c43284d69c59cff66e9fae88d7498d7 Mon Sep 17 00:00:00 2001
From: Joren Hammudoglu <jhammudoglu@gmail.com>
Date: Thu, 17 Apr 2025 18:10:17 +0200
Subject: [PATCH 23/24] STY: give the `=` some breathing room

Co-authored-by: Lucas Colley <lucas.colley8@gmail.com>
---
 array_api_compat/common/_helpers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py
index 0ce6edf0..8596c534 100644
--- a/array_api_compat/common/_helpers.py
+++ b/array_api_compat/common/_helpers.py
@@ -40,7 +40,7 @@
     # TODO: import from typing (requires Python >=3.13)
     from typing_extensions import TypeIs, TypeVar
 
-    _SizeT = TypeVar("_SizeT", bound=int | None)
+    _SizeT = TypeVar("_SizeT", bound = int | None)
 
     _ZeroGradientArray: TypeAlias = npt.NDArray[np.void]
     _CupyArray: TypeAlias = Any  # cupy has no py.typed

From d758d6fdfd539816a0c4542804aa877e0b9a2be2 Mon Sep 17 00:00:00 2001
From: jorenham <jhammudoglu@gmail.com>
Date: Thu, 17 Apr 2025 18:50:21 +0200
Subject: [PATCH 24/24] STY: apply review suggestions

Co-authored-by: lucascolley <lucas.colley8@gmail.com>
---
 array_api_compat/common/_aliases.py   |  2 +-
 array_api_compat/common/_helpers.py   | 13 ++++++-------
 array_api_compat/dask/array/linalg.py |  2 +-
 array_api_compat/numpy/_aliases.py    |  6 +++++-
 4 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py
index 7f3cf914..8ea9162a 100644
--- a/array_api_compat/common/_aliases.py
+++ b/array_api_compat/common/_aliases.py
@@ -652,7 +652,7 @@ def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
     if isdtype(x.dtype, "complex floating", xp=xp):
         out = (x / xp.abs(x, **kwargs))[...]
         # sign(0) = 0 but the above formula would give nan
-        out[x == 0 + 0j] = 0 + 0j
+        out[x == 0j] = 0j
     else:
         out = xp.sign(x, **kwargs)
     # CuPy sign() does not propagate nans. See
diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py
index 8596c534..db3e4cd7 100644
--- a/array_api_compat/common/_helpers.py
+++ b/array_api_compat/common/_helpers.py
@@ -16,6 +16,7 @@
 from typing import (
     TYPE_CHECKING,
     Any,
+    Final,
     Literal,
     SupportsIndex,
     TypeAlias,
@@ -56,6 +57,9 @@
         | _CupyArray
     )
 
+_API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"})
+_API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"})
+
 
 def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
     """Return True if `x` is a zero-gradient array.
@@ -477,16 +481,11 @@ def is_array_api_strict_namespace(xp: Namespace) -> bool:
 
 
 def _check_api_version(api_version: str | None) -> None:
-    if api_version in ["2021.12", "2022.12", "2023.12"]:
+    if api_version in _API_VERSIONS_OLD:
         warnings.warn(
             f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12"
         )
-    elif api_version is not None and api_version not in [
-        "2021.12",
-        "2022.12",
-        "2023.12",
-        "2024.12",
-    ]:
+    elif api_version is not None and api_version not in _API_VERSIONS:
         raise ValueError(
             "Only the 2024.12 version of the array API specification is currently supported"
         )
diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py
index fd509a38..0825386e 100644
--- a/array_api_compat/dask/array/linalg.py
+++ b/array_api_compat/dask/array/linalg.py
@@ -4,7 +4,7 @@
 
 import dask.array as da
 
-# These functions are in both the main and linalg namespaces
+# The `matmul` and `tensordot` functions are in both the main and linalg namespaces
 from dask.array import matmul, outer, tensordot
 
 # Exports
diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py
index 2c03041f..d8792611 100644
--- a/array_api_compat/numpy/_aliases.py
+++ b/array_api_compat/numpy/_aliases.py
@@ -15,6 +15,8 @@
 if TYPE_CHECKING:
     from typing_extensions import Buffer, TypeIs
 
+# The values of the `_CopyMode` enum can be either `False`, `True`, or `2`:
+# https://github.com/numpy/numpy/blob/5a8a6a79d9c2fff8f07dcab5d41e14f8508d673f/numpy/_globals.pyi#L7-L10
 _Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode
 
 bool = np.bool_
@@ -130,7 +132,9 @@ def count_nonzero(
     axis: int | tuple[int, ...] | None = None,
     keepdims: py_bool = False,
 ) -> Array:
-    result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims))  # pyright: ignore
+    # NOTE: this is currently incorrectly typed in numpy, but will be fixed in
+    # numpy 2.2.5 and 2.3.0: https://github.com/numpy/numpy/pull/28750
+    result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims))  # pyright: ignore[reportArgumentType, reportCallIssue]
     if axis is None and not keepdims:
         return np.asarray(result)
     return result