From b9b020660ea7e29e4c0fbcc3668cfda4b1691a1b Mon Sep 17 00:00:00 2001 From: jorenham <jhammudoglu@gmail.com> Date: Sat, 22 Mar 2025 20:27:40 +0100 Subject: [PATCH 01/24] TYP: annotate `_internal.get_xp` (and curse at `ParamSpec` for being so useless) --- array_api_compat/_internal.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index 170a1ff9..7a973567 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -4,8 +4,19 @@ from functools import wraps from inspect import signature +from typing import TYPE_CHECKING -def get_xp(xp): +__all__ = ["get_xp"] + +if TYPE_CHECKING: + from collections.abc import Callable + from types import ModuleType + from typing import TypeVar + + _T = TypeVar("_T") + + +def get_xp(xp: "ModuleType") -> "Callable[[Callable[..., _T]], Callable[..., _T]]": """ Decorator to automatically replace xp with the corresponding array module. @@ -22,14 +33,14 @@ def func(x, /, xp, kwarg=None): """ - def inner(f): + def inner(f: "Callable[..., _T]", /) -> "Callable[..., _T]": @wraps(f) - def wrapped_f(*args, **kwargs): + def wrapped_f(*args: object, **kwargs: object) -> object: return f(*args, xp=xp, **kwargs) sig = signature(f) new_sig = sig.replace( - parameters=[sig.parameters[i] for i in sig.parameters if i != "xp"] + parameters=[par for i, par in sig.parameters.items() if i != "xp"] ) if wrapped_f.__doc__ is None: @@ -40,7 +51,7 @@ def wrapped_f(*args, **kwargs): specification for more details. """ - wrapped_f.__signature__ = new_sig - return wrapped_f + wrapped_f.__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue] + return wrapped_f # pyright: ignore[reportReturnType] return inner From 6a17007a32cf38f9328453b27ff6a59eeca85f53 Mon Sep 17 00:00:00 2001 From: jorenham <jhammudoglu@gmail.com> Date: Sat, 22 Mar 2025 21:56:31 +0100 Subject: [PATCH 02/24] TYP: fix (or ignore) typing errors in `common._helpers` (and curse at cupy) --- array_api_compat/common/_helpers.py | 282 ++++++++++++++++++---------- array_api_compat/common/_typing.py | 14 +- 2 files changed, 198 insertions(+), 98 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 67c619b8..9f016f05 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -5,33 +5,75 @@ that are in __all__ are intended as additional helper functions for use by end users of the compat library. """ + from __future__ import annotations -import sys -import math import inspect +import math +import sys import warnings -from typing import Optional, Union, Any +from typing import ( + TYPE_CHECKING, + Any, + Literal, + SupportsIndex, + TypeAlias, + TypeGuard, + cast, + overload, +) + +from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace + +if TYPE_CHECKING: + from collections.abc import Collection + + import dask.array as da + import jax + import ndonnx as ndx + import numpy as np + import numpy.typing as npt + import sparse # pyright: ignore[reportMissingTypeStubs] + import torch + from typing_extensions import TypeIs, TypeVar -from ._typing import Array, Device, Namespace + _SizeT = TypeVar("_SizeT", bound=int | None) + _ZeroGradientArray: TypeAlias = npt.NDArray[np.void] + _CupyArray: TypeAlias = Any # cupy has no py.typed -def _is_jax_zero_gradient_array(x: object) -> bool: + _ArrayApiObj: TypeAlias = ( + npt.NDArray[Any] + | da.Array + | jax.Array + | ndx.Array + | sparse.SparseArray + | torch.Tensor + | SupportsArrayNamespace[Any] + | _CupyArray + ) + + +def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: """Return True if `x` is a zero-gradient array. These arrays are a design quirk of Jax that may one day be removed. See https://github.com/google/jax/issues/20620. """ - if 'numpy' not in sys.modules or 'jax' not in sys.modules: + if "numpy" not in sys.modules or "jax" not in sys.modules: return False - import numpy as np import jax + import numpy as np - return isinstance(x, np.ndarray) and x.dtype == jax.float0 + jax_float0 = cast("np.dtype[np.void]", jax.float0) + return ( + isinstance(x, np.ndarray) + and cast("npt.NDArray[np.void]", x).dtype == jax_float0 + ) -def is_numpy_array(x: object) -> bool: +def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]: """ Return True if `x` is a NumPy array. @@ -53,14 +95,14 @@ def is_numpy_array(x: object) -> bool: is_pydata_sparse_array """ # Avoid importing NumPy if it isn't already - if 'numpy' not in sys.modules: + if "numpy" not in sys.modules: return False import numpy as np # TODO: Should we reject ndarray subclasses? return (isinstance(x, (np.ndarray, np.generic)) - and not _is_jax_zero_gradient_array(x)) + and not _is_jax_zero_gradient_array(x)) # pyright: ignore[reportUnknownArgumentType] # fmt: skip def is_cupy_array(x: object) -> bool: @@ -85,16 +127,16 @@ def is_cupy_array(x: object) -> bool: is_pydata_sparse_array """ # Avoid importing CuPy if it isn't already - if 'cupy' not in sys.modules: + if "cupy" not in sys.modules: return False - import cupy as cp + import cupy as cp # pyright: ignore[reportMissingTypeStubs] # TODO: Should we reject ndarray subclasses? - return isinstance(x, cp.ndarray) + return isinstance(x, cp.ndarray) # pyright: ignore[reportUnknownMemberType] -def is_torch_array(x: object) -> bool: +def is_torch_array(x: object) -> TypeIs[torch.Tensor]: """ Return True if `x` is a PyTorch tensor. @@ -113,7 +155,7 @@ def is_torch_array(x: object) -> bool: is_pydata_sparse_array """ # Avoid importing torch if it isn't already - if 'torch' not in sys.modules: + if "torch" not in sys.modules: return False import torch @@ -122,7 +164,7 @@ def is_torch_array(x: object) -> bool: return isinstance(x, torch.Tensor) -def is_ndonnx_array(x: object) -> bool: +def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]: """ Return True if `x` is a ndonnx Array. @@ -142,7 +184,7 @@ def is_ndonnx_array(x: object) -> bool: is_pydata_sparse_array """ # Avoid importing torch if it isn't already - if 'ndonnx' not in sys.modules: + if "ndonnx" not in sys.modules: return False import ndonnx as ndx @@ -150,7 +192,7 @@ def is_ndonnx_array(x: object) -> bool: return isinstance(x, ndx.Array) -def is_dask_array(x: object) -> bool: +def is_dask_array(x: object) -> TypeIs[da.Array]: """ Return True if `x` is a dask.array Array. @@ -170,7 +212,7 @@ def is_dask_array(x: object) -> bool: is_pydata_sparse_array """ # Avoid importing dask if it isn't already - if 'dask.array' not in sys.modules: + if "dask.array" not in sys.modules: return False import dask.array @@ -178,7 +220,7 @@ def is_dask_array(x: object) -> bool: return isinstance(x, dask.array.Array) -def is_jax_array(x: object) -> bool: +def is_jax_array(x: object) -> TypeIs[jax.Array]: """ Return True if `x` is a JAX array. @@ -199,7 +241,7 @@ def is_jax_array(x: object) -> bool: is_pydata_sparse_array """ # Avoid importing jax if it isn't already - if 'jax' not in sys.modules: + if "jax" not in sys.modules: return False import jax @@ -207,7 +249,7 @@ def is_jax_array(x: object) -> bool: return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x) -def is_pydata_sparse_array(x) -> bool: +def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: """ Return True if `x` is an array from the `sparse` package. @@ -228,16 +270,16 @@ def is_pydata_sparse_array(x) -> bool: is_jax_array """ # Avoid importing jax if it isn't already - if 'sparse' not in sys.modules: + if "sparse" not in sys.modules: return False - import sparse + import sparse # pyright: ignore[reportMissingTypeStubs] # TODO: Account for other backends. return isinstance(x, sparse.SparseArray) -def is_array_api_obj(x: object) -> bool: +def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType] """ Return True if `x` is an array API compatible array object. @@ -252,18 +294,20 @@ def is_array_api_obj(x: object) -> bool: is_dask_array is_jax_array """ - return is_numpy_array(x) \ - or is_cupy_array(x) \ - or is_torch_array(x) \ - or is_dask_array(x) \ - or is_jax_array(x) \ - or is_pydata_sparse_array(x) \ - or hasattr(x, '__array_namespace__') + return ( + is_numpy_array(x) + or is_cupy_array(x) + or is_torch_array(x) + or is_dask_array(x) + or is_jax_array(x) + or is_pydata_sparse_array(x) + or hasattr(x, "__array_namespace__") + ) def _compat_module_name() -> str: - assert __name__.endswith('.common._helpers') - return __name__.removesuffix('.common._helpers') + assert __name__.endswith(".common._helpers") + return __name__.removesuffix(".common._helpers") def is_numpy_namespace(xp: Namespace) -> bool: @@ -284,7 +328,7 @@ def is_numpy_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'} + return xp.__name__ in {"numpy", _compat_module_name() + ".numpy"} def is_cupy_namespace(xp: Namespace) -> bool: @@ -305,7 +349,7 @@ def is_cupy_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'} + return xp.__name__ in {"cupy", _compat_module_name() + ".cupy"} def is_torch_namespace(xp: Namespace) -> bool: @@ -326,7 +370,7 @@ def is_torch_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'torch', _compat_module_name() + '.torch'} + return xp.__name__ in {"torch", _compat_module_name() + ".torch"} def is_ndonnx_namespace(xp: Namespace) -> bool: @@ -345,7 +389,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ == 'ndonnx' + return xp.__name__ == "ndonnx" def is_dask_namespace(xp: Namespace) -> bool: @@ -366,7 +410,7 @@ def is_dask_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'} + return xp.__name__ in {"dask.array", _compat_module_name() + ".dask.array"} def is_jax_namespace(xp: Namespace) -> bool: @@ -388,7 +432,7 @@ def is_jax_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'} + return xp.__name__ in {"jax.numpy", "jax.experimental.array_api"} def is_pydata_sparse_namespace(xp: Namespace) -> bool: @@ -407,7 +451,7 @@ def is_pydata_sparse_namespace(xp: Namespace) -> bool: is_jax_namespace is_array_api_strict_namespace """ - return xp.__name__ == 'sparse' + return xp.__name__ == "sparse" def is_array_api_strict_namespace(xp: Namespace) -> bool: @@ -426,21 +470,29 @@ def is_array_api_strict_namespace(xp: Namespace) -> bool: is_jax_namespace is_pydata_sparse_namespace """ - return xp.__name__ == 'array_api_strict' + return xp.__name__ == "array_api_strict" -def _check_api_version(api_version: str) -> None: - if api_version in ['2021.12', '2022.12', '2023.12']: - warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12") - elif api_version is not None and api_version not in ['2021.12', '2022.12', - '2023.12', '2024.12']: - raise ValueError("Only the 2024.12 version of the array API specification is currently supported") +def _check_api_version(api_version: str | None) -> None: + if api_version in ["2021.12", "2022.12", "2023.12"]: + warnings.warn( + f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12" + ) + elif api_version is not None and api_version not in [ + "2021.12", + "2022.12", + "2023.12", + "2024.12", + ]: + raise ValueError( + "Only the 2024.12 version of the array API specification is currently supported" + ) def array_namespace( - *xs: Union[Array, bool, int, float, complex, None], - api_version: Optional[str] = None, - use_compat: Optional[bool] = None, + *xs: Array | complex | None, + api_version: str | None = None, + use_compat: bool | None = None, ) -> Namespace: """ Get the array API compatible namespace for the arrays `xs`. @@ -510,11 +562,13 @@ def your_function(x, y): _use_compat = use_compat in [None, True] - namespaces = set() + namespaces: set[Namespace] = set() for x in xs: if is_numpy_array(x): - from .. import numpy as numpy_namespace import numpy as np + + from .. import numpy as numpy_namespace + if use_compat is True: _check_api_version(api_version) namespaces.add(numpy_namespace) @@ -528,25 +582,31 @@ def your_function(x, y): if _use_compat: _check_api_version(api_version) from .. import cupy as cupy_namespace + namespaces.add(cupy_namespace) else: - import cupy as cp + import cupy as cp # pyright: ignore[reportMissingTypeStubs] + namespaces.add(cp) elif is_torch_array(x): if _use_compat: _check_api_version(api_version) from .. import torch as torch_namespace + namespaces.add(torch_namespace) else: import torch + namespaces.add(torch) elif is_dask_array(x): if _use_compat: _check_api_version(api_version) from ..dask import array as dask_namespace + namespaces.add(dask_namespace) else: import dask.array as da + namespaces.add(da) elif is_jax_array(x): if use_compat is True: @@ -558,23 +618,27 @@ def your_function(x, y): # JAX v0.4.32 and newer implements the array API directly in jax.numpy. # For older JAX versions, it is available via jax.experimental.array_api. import jax.numpy + if hasattr(jax.numpy, "__array_api_version__"): jnp = jax.numpy else: - import jax.experimental.array_api as jnp + import jax.experimental.array_api as jnp # pyright: ignore[reportMissingImports] namespaces.add(jnp) elif is_pydata_sparse_array(x): if use_compat is True: _check_api_version(api_version) raise ValueError("`sparse` does not have an array-api-compat wrapper") else: - import sparse + import sparse # pyright: ignore[reportMissingTypeStubs] # `sparse` is already an array namespace. We do not have a wrapper # submodule for it. namespaces.add(sparse) - elif hasattr(x, '__array_namespace__'): + elif hasattr(x, "__array_namespace__"): if use_compat is True: - raise ValueError("The given array does not have an array-api-compat wrapper") + raise ValueError( + "The given array does not have an array-api-compat wrapper" + ) + x = cast("SupportsArrayNamespace[Any]", x) namespaces.add(x.__array_namespace__(api_version=api_version)) elif isinstance(x, (bool, int, float, complex, type(None))): continue @@ -588,15 +652,16 @@ def your_function(x, y): if len(namespaces) != 1: raise TypeError(f"Multiple namespaces for array inputs: {namespaces}") - xp, = namespaces + (xp,) = namespaces return xp + # backwards compatibility alias get_namespace = array_namespace -def _check_device(bare_xp, device): +def _check_device(bare_xp: Namespace, device: Device) -> None: # pyright: ignore[reportUnusedFunction] """ Validate dummy device on device-less array backends. @@ -609,11 +674,11 @@ def _check_device(bare_xp, device): https://github.com/data-apis/array-api-compat/pull/293 """ - if bare_xp is sys.modules.get('numpy'): + if bare_xp is sys.modules.get("numpy"): if device not in ("cpu", None): raise ValueError(f"Unsupported device for NumPy: {device!r}") - elif bare_xp is sys.modules.get('dask.array'): + elif bare_xp is sys.modules.get("dask.array"): if device not in ("cpu", _DASK_DEVICE, None): raise ValueError(f"Unsupported device for Dask: {device!r}") @@ -622,18 +687,20 @@ def _check_device(bare_xp, device): # when the array backend is not the CPU. # (since it is not easy to tell which device a dask array is on) class _dask_device: - def __repr__(self): + def __repr__(self) -> Literal["DASK_DEVICE"]: return "DASK_DEVICE" + _DASK_DEVICE = _dask_device() + # device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray # or cupy.ndarray. They are not included in array objects of this library # because this library just reuses the respective ndarray classes without # wrapping or subclassing them. These helper functions can be used instead of # the wrapper functions for libraries that need to support both NumPy/CuPy and # other libraries that use devices. -def device(x: Array, /) -> Device: +def device(x: _ArrayApiObj, /) -> Device: """ Hardware device the array data resides on. @@ -669,7 +736,7 @@ def device(x: Array, /) -> Device: return "cpu" elif is_dask_array(x): # Peek at the metadata of the Dask array to determine type - if is_numpy_array(x._meta): + if is_numpy_array(x._meta): # pyright: ignore # Must be on CPU since backed by numpy return "cpu" return _DASK_DEVICE @@ -679,7 +746,7 @@ def device(x: Array, /) -> Device: # Return None in this case. Note that this workaround breaks # the standard and will result in new arrays being created on the # default device instead of the same device as the input array(s). - x_device = getattr(x, 'device', None) + x_device = getattr(x, "device", None) # Older JAX releases had .device() as a method, which has been replaced # with a property in accordance with the standard. if inspect.ismethod(x_device): @@ -688,27 +755,34 @@ def device(x: Array, /) -> Device: return x_device elif is_pydata_sparse_array(x): # `sparse` will gain `.device`, so check for this first. - x_device = getattr(x, 'device', None) + x_device = getattr(x, "device", None) if x_device is not None: return x_device # Everything but DOK has this attr. try: - inner = x.data + inner = x.data # pyright: ignore except AttributeError: return "cpu" # Return the device of the constituent array - return device(inner) - return x.device + return device(inner) # pyright: ignore + return x.device # pyright: ignore + # Prevent shadowing, used below _device = device + # Based on cupy.array_api.Array.to_device -def _cupy_to_device(x, device, /, stream=None): - import cupy as cp - from cupy.cuda import Device as _Device - from cupy.cuda import stream as stream_module - from cupy_backends.cuda.api import runtime +def _cupy_to_device( + x: _CupyArray, + device: Device, + /, + stream: int | Any | None = None, +) -> _CupyArray: + import cupy as cp # pyright: ignore[reportMissingTypeStubs] + from cupy.cuda import Device as _Device # pyright: ignore + from cupy.cuda import stream as stream_module # pyright: ignore + from cupy_backends.cuda.api import runtime # pyright: ignore if device == x.device: return x @@ -721,33 +795,40 @@ def _cupy_to_device(x, device, /, stream=None): raise ValueError(f"Unsupported device {device!r}") else: # see cupy/cupy#5985 for the reason how we handle device/stream here - prev_device = runtime.getDevice() - prev_stream: stream_module.Stream = None + prev_device: Any = runtime.getDevice() # pyright: ignore[reportUnknownMemberType] + prev_stream = None if stream is not None: - prev_stream = stream_module.get_current_stream() + prev_stream: Any = stream_module.get_current_stream() # pyright: ignore # stream can be an int as specified in __dlpack__, or a CuPy stream if isinstance(stream, int): - stream = cp.cuda.ExternalStream(stream) - elif isinstance(stream, cp.cuda.Stream): + stream = cp.cuda.ExternalStream(stream) # pyright: ignore + elif isinstance(stream, cp.cuda.Stream): # pyright: ignore[reportUnknownMemberType] pass else: - raise ValueError('the input stream is not recognized') - stream.use() + raise ValueError("the input stream is not recognized") + stream.use() # pyright: ignore[reportUnknownMemberType] try: - runtime.setDevice(device.id) + runtime.setDevice(device.id) # pyright: ignore[reportUnknownMemberType] arr = x.copy() finally: - runtime.setDevice(prev_device) + runtime.setDevice(prev_device) # pyright: ignore[reportUnknownMemberType] if stream is not None: prev_stream.use() return arr -def _torch_to_device(x, device, /, stream=None): + +def _torch_to_device( + x: torch.Tensor, + device: torch.device | str | int, + /, + stream: None = None, +) -> torch.Tensor: if stream is not None: raise NotImplementedError return x.to(device) -def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array: + +def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -> Array: """ Copy the array from the device on which it currently resides to the specified ``device``. @@ -767,7 +848,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] a ``device`` object (see the `Device Support <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__ section of the array API specification). - stream: Optional[Union[int, Any]] + stream: int | Any | None stream object to use during copy. In addition to the types supported in ``array.__dlpack__``, implementations may choose to support any library-specific stream object with the caveat that any code using @@ -799,25 +880,26 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] if is_numpy_array(x): if stream is not None: raise ValueError("The stream argument to to_device() is not supported") - if device == 'cpu': + if device == "cpu": return x raise ValueError(f"Unsupported device {device!r}") elif is_cupy_array(x): # cupy does not yet have to_device return _cupy_to_device(x, device, stream=stream) elif is_torch_array(x): - return _torch_to_device(x, device, stream=stream) + return _torch_to_device(x, device, stream=stream) # pyright: ignore[reportArgumentType] elif is_dask_array(x): if stream is not None: raise ValueError("The stream argument to to_device() is not supported") # TODO: What if our array is on the GPU already? - if device == 'cpu': + if device == "cpu": return x raise ValueError(f"Unsupported device {device!r}") elif is_jax_array(x): if not hasattr(x, "__array_namespace__"): # In JAX v0.4.31 and older, this import adds to_device method to x... - import jax.experimental.array_api # noqa: F401 + import jax.experimental.array_api # noqa: F401 # pyright: ignore + # ... but only on eager JAX. It won't work inside jax.jit. if not hasattr(x, "to_device"): return x @@ -826,10 +908,16 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] # Perform trivial check to return the same array if # device is same instead of err-ing. return x - return x.to_device(device, stream=stream) + return x.to_device(device, stream=stream) # pyright: ignore -def size(x: Array) -> int | None: +@overload +def size(x: HasShape[Collection[SupportsIndex]]) -> int: ... +@overload +def size(x: HasShape[Collection[None]]) -> None: ... +@overload +def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ... +def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: """ Return the total number of elements of x. @@ -844,7 +932,7 @@ def size(x: Array) -> int | None: # Lazy API compliant arrays, such as ndonnx, can contain None in their shape if None in x.shape: return None - out = math.prod(x.shape) + out = math.prod(cast("Collection[SupportsIndex]", x.shape)) # dask.array.Array.shape can contain NaN return None if math.isnan(out) else out @@ -907,7 +995,7 @@ def is_lazy_array(x: object) -> bool: # on __bool__ (dask is one such example, which however is special-cased above). # Select a single point of the array - s = size(x) + s = size(cast("HasShape[Collection[SupportsIndex | None]]", x)) if s is None: return True xp = array_namespace(x) @@ -952,4 +1040,4 @@ def is_lazy_array(x: object) -> bool: "to_device", ] -_all_ignore = ['sys', 'math', 'inspect', 'warnings'] +_all_ignore = ["sys", "math", "inspect", "warnings"] diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index 4c3b356b..a38d083b 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -1,11 +1,14 @@ from __future__ import annotations + from types import ModuleType as Namespace -from typing import Any, TypeVar, Protocol +from typing import Any, Protocol, TypeVar __all__ = [ "Array", + "SupportsArrayNamespace", "DType", "Device", + "HasShape", "Namespace", "NestedSequence", "SupportsBufferProtocol", @@ -18,6 +21,15 @@ def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... def __len__(self, /) -> int: ... +class SupportsArrayNamespace(Protocol[_T_co]): + def __array_namespace__(self, /, *, api_version: str | None) -> _T_co: ... + + +class HasShape(Protocol[_T_co]): + @property + def shape(self, /) -> _T_co: ... + + SupportsBufferProtocol = Any Array = Any Device = Any From 3b134b0c5d835f816b05729de4485ec2575520df Mon Sep 17 00:00:00 2001 From: jorenham <jhammudoglu@gmail.com> Date: Sat, 22 Mar 2025 22:00:36 +0100 Subject: [PATCH 03/24] TYP: fix typing errors in `common._fft` --- array_api_compat/common/_fft.py | 68 ++++++++++++++++----------------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py index bd2a4e1a..49f948ce 100644 --- a/array_api_compat/common/_fft.py +++ b/array_api_compat/common/_fft.py @@ -1,9 +1,9 @@ -from __future__ import annotations - from collections.abc import Sequence -from typing import Union, Optional, Literal +from typing import Literal, TypeAlias + +from ._typing import Array, Device, DType, Namespace -from ._typing import Device, Array, DType, Namespace +_Norm: TypeAlias = Literal["backward", "ortho", "forward"] # Note: NumPy fft functions improperly upcast float32 and complex64 to # complex128, which is why we require wrapping them all here. @@ -13,9 +13,9 @@ def fft( /, xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", + norm: _Norm = "backward", ) -> Array: res = xp.fft.fft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: @@ -27,9 +27,9 @@ def ifft( /, xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", + norm: _Norm = "backward", ) -> Array: res = xp.fft.ifft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: @@ -41,9 +41,9 @@ def fftn( /, xp: Namespace, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, - norm: Literal["backward", "ortho", "forward"] = "backward", + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", ) -> Array: res = xp.fft.fftn(x, s=s, axes=axes, norm=norm) if x.dtype in [xp.float32, xp.complex64]: @@ -55,9 +55,9 @@ def ifftn( /, xp: Namespace, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, - norm: Literal["backward", "ortho", "forward"] = "backward", + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", ) -> Array: res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm) if x.dtype in [xp.float32, xp.complex64]: @@ -69,9 +69,9 @@ def rfft( /, xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", + norm: _Norm = "backward", ) -> Array: res = xp.fft.rfft(x, n=n, axis=axis, norm=norm) if x.dtype == xp.float32: @@ -83,9 +83,9 @@ def irfft( /, xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", + norm: _Norm = "backward", ) -> Array: res = xp.fft.irfft(x, n=n, axis=axis, norm=norm) if x.dtype == xp.complex64: @@ -97,9 +97,9 @@ def rfftn( /, xp: Namespace, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, - norm: Literal["backward", "ortho", "forward"] = "backward", + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", ) -> Array: res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm) if x.dtype == xp.float32: @@ -111,9 +111,9 @@ def irfftn( /, xp: Namespace, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, - norm: Literal["backward", "ortho", "forward"] = "backward", + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", ) -> Array: res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm) if x.dtype == xp.complex64: @@ -125,9 +125,9 @@ def hfft( /, xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", + norm: _Norm = "backward", ) -> Array: res = xp.fft.hfft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: @@ -139,9 +139,9 @@ def ihfft( /, xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", + norm: _Norm = "backward", ) -> Array: res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: @@ -154,8 +154,8 @@ def fftfreq( xp: Namespace, *, d: float = 1.0, - dtype: Optional[DType] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, ) -> Array: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") @@ -170,8 +170,8 @@ def rfftfreq( xp: Namespace, *, d: float = 1.0, - dtype: Optional[DType] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, ) -> Array: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") @@ -181,12 +181,12 @@ def rfftfreq( return res def fftshift( - x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None + x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None ) -> Array: return xp.fft.fftshift(x, axes=axes) def ifftshift( - x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None + x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None ) -> Array: return xp.fft.ifftshift(x, axes=axes) From 344ac1ece797de7880edffadd6af68ba91d7cd11 Mon Sep 17 00:00:00 2001 From: jorenham <jhammudoglu@gmail.com> Date: Sat, 22 Mar 2025 22:11:33 +0100 Subject: [PATCH 04/24] TYP: fix typing errors in `common._aliases` --- array_api_compat/common/_aliases.py | 327 ++++++++++++++++++---------- 1 file changed, 206 insertions(+), 121 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 351b5bd6..cd4b6af8 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -5,158 +5,169 @@ from __future__ import annotations import inspect -from typing import Any, NamedTuple, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast +from ._helpers import _check_device, array_namespace +from ._helpers import device as _get_device +from ._helpers import is_cupy_namespace as _is_cupy_namespace from ._typing import Array, Device, DType, Namespace -from ._helpers import ( - array_namespace, - _check_device, - device as _get_device, - is_cupy_namespace as _is_cupy_namespace -) +if TYPE_CHECKING: + from typing_extensions import TypeIs # These functions are modified from the NumPy versions. # Creation functions add the device keyword (which does nothing for NumPy and Dask) + def arange( - start: Union[int, float], + start: float, /, - stop: Optional[Union[int, float]] = None, - step: Union[int, float] = 1, + stop: float | None = None, + step: float = 1, *, xp: Namespace, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs) + def empty( - shape: Union[int, Tuple[int, ...]], + shape: int | tuple[int, ...], xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.empty(shape, dtype=dtype, **kwargs) + def empty_like( x: Array, /, xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.empty_like(x, dtype=dtype, **kwargs) + def eye( n_rows: int, - n_cols: Optional[int] = None, + n_cols: int | None = None, /, *, xp: Namespace, k: int = 0, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs) + def full( - shape: Union[int, Tuple[int, ...]], - fill_value: bool | int | float | complex, + shape: int | tuple[int, ...], + fill_value: complex, xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.full(shape, fill_value, dtype=dtype, **kwargs) + def full_like( x: Array, /, - fill_value: bool | int | float | complex, + fill_value: complex, *, xp: Namespace, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.full_like(x, fill_value, dtype=dtype, **kwargs) + def linspace( - start: Union[int, float], - stop: Union[int, float], + start: float, + stop: float, /, num: int, *, xp: Namespace, - dtype: Optional[DType] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, endpoint: bool = True, - **kwargs, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs) + def ones( - shape: Union[int, Tuple[int, ...]], + shape: int | tuple[int, ...], xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.ones(shape, dtype=dtype, **kwargs) + def ones_like( x: Array, /, xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.ones_like(x, dtype=dtype, **kwargs) + def zeros( - shape: Union[int, Tuple[int, ...]], + shape: int | tuple[int, ...], xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.zeros(shape, dtype=dtype, **kwargs) + def zeros_like( x: Array, /, xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.zeros_like(x, dtype=dtype, **kwargs) + # np.unique() is split into four functions in the array API: # unique_all, unique_counts, unique_inverse, and unique_values (this is done # to remove polymorphic return types). @@ -164,6 +175,7 @@ def zeros_like( # The functions here return namedtuples (np.unique() returns a normal # tuple). + # Note that these named tuples aren't actually part of the standard namespace, # but I don't see any issue with exporting the names here regardless. class UniqueAllResult(NamedTuple): @@ -188,10 +200,11 @@ def _unique_kwargs(xp: Namespace) -> dict[str, bool]: # trying to parse version numbers, just check if equal_nan is in the # signature. s = inspect.signature(xp.unique) - if 'equal_nan' in s.parameters: - return {'equal_nan': False} + if "equal_nan" in s.parameters: + return {"equal_nan": False} return {} + def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult: kwargs = _unique_kwargs(xp) values, indices, inverse_indices, counts = xp.unique( @@ -215,11 +228,7 @@ def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult: def unique_counts(x: Array, /, xp: Namespace) -> UniqueCountsResult: kwargs = _unique_kwargs(xp) res = xp.unique( - x, - return_counts=True, - return_index=False, - return_inverse=False, - **kwargs + x, return_counts=True, return_index=False, return_inverse=False, **kwargs ) return UniqueCountsResult(*res) @@ -250,51 +259,58 @@ def unique_values(x: Array, /, xp: Namespace) -> Array: **kwargs, ) + # These functions have different keyword argument names + def std( x: Array, /, xp: Namespace, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof + axis: int | tuple[int, ...] | None = None, + correction: float = 0.0, # correction instead of ddof keepdims: bool = False, - **kwargs, + **kwargs: object, ) -> Array: return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) + def var( x: Array, /, xp: Namespace, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof + axis: int | tuple[int, ...] | None = None, + correction: float = 0.0, # correction instead of ddof keepdims: bool = False, - **kwargs, + **kwargs: object, ) -> Array: return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) + # cumulative_sum is renamed from cumsum, and adds the include_initial keyword # argument + def cumulative_sum( x: Array, /, xp: Namespace, *, - axis: Optional[int] = None, - dtype: Optional[DType] = None, + axis: int | None = None, + dtype: DType | None = None, include_initial: bool = False, - **kwargs, + **kwargs: object, ) -> Array: wrapped_xp = array_namespace(x) # TODO: The standard is not clear about what should happen when x.ndim == 0. if axis is None: if x.ndim > 1: - raise ValueError("axis must be specified in cumulative_sum for more than one dimension") + raise ValueError( + "axis must be specified in cumulative_sum for more than one dimension" + ) axis = 0 res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs) @@ -304,7 +320,12 @@ def cumulative_sum( initial_shape = list(x.shape) initial_shape[axis] = 1 res = xp.concatenate( - [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res], + [ + wrapped_xp.zeros( + shape=initial_shape, dtype=res.dtype, device=_get_device(res) + ), + res, + ], axis=axis, ) return res @@ -315,16 +336,18 @@ def cumulative_prod( /, xp: Namespace, *, - axis: Optional[int] = None, - dtype: Optional[DType] = None, + axis: int | None = None, + dtype: DType | None = None, include_initial: bool = False, - **kwargs, + **kwargs: object, ) -> Array: wrapped_xp = array_namespace(x) if axis is None: if x.ndim > 1: - raise ValueError("axis must be specified in cumulative_prod for more than one dimension") + raise ValueError( + "axis must be specified in cumulative_prod for more than one dimension" + ) axis = 0 res = xp.cumprod(x, axis=axis, dtype=dtype, **kwargs) @@ -334,24 +357,30 @@ def cumulative_prod( initial_shape = list(x.shape) initial_shape[axis] = 1 res = xp.concatenate( - [wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res], + [ + wrapped_xp.ones( + shape=initial_shape, dtype=res.dtype, device=_get_device(res) + ), + res, + ], axis=axis, ) return res + # The min and max argument names in clip are different and not optional in numpy, and type # promotion behavior is different. def clip( x: Array, /, - min: Optional[Union[int, float, Array]] = None, - max: Optional[Union[int, float, Array]] = None, + min: float | Array | None = None, + max: float | Array | None = None, *, xp: Namespace, # TODO: np.clip has other ufunc kwargs - out: Optional[Array] = None, + out: Array | None = None, ) -> Array: - def _isscalar(a): + def _isscalar(a: object) -> TypeIs[int | float | None]: return isinstance(a, (int, float, type(None))) min_shape = () if _isscalar(min) else min.shape @@ -378,7 +407,6 @@ def _isscalar(a): # but an answer of 0 might be preferred. See # https://github.com/numpy/numpy/issues/24976 for more discussion on this issue. - # At least handle the case of Python integers correctly (see # https://github.com/numpy/numpy/pull/26892). if wrapped_xp.isdtype(x.dtype, "integral"): @@ -390,6 +418,7 @@ def _isscalar(a): dev = _get_device(x) if out is None: out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev) + assert out is not None # workaround for a type-narrowing issue in pyright out[()] = x if min is not None: @@ -407,19 +436,21 @@ def _isscalar(a): # Return a scalar for 0-D return out[()] + # Unlike transpose(), the axes argument to permute_dims() is required. -def permute_dims(x: Array, /, axes: Tuple[int, ...], xp: Namespace) -> Array: +def permute_dims(x: Array, /, axes: tuple[int, ...], xp: Namespace) -> Array: return xp.transpose(x, axes) + # np.reshape calls the keyword argument 'newshape' instead of 'shape' def reshape( x: Array, /, - shape: Tuple[int, ...], + shape: tuple[int, ...], xp: Namespace, *, copy: Optional[bool] = None, - **kwargs, + **kwargs: object, ) -> Array: if copy is True: x = x.copy() @@ -429,6 +460,7 @@ def reshape( return y return xp.reshape(x, shape, **kwargs) + # The descending keyword is new in sort and argsort, and 'kind' replaced with # 'stable' def argsort( @@ -439,13 +471,13 @@ def argsort( axis: int = -1, descending: bool = False, stable: bool = True, - **kwargs, + **kwargs: object, ) -> Array: # Note: this keyword argument is different, and the default is different. # We set it in kwargs like this because numpy.sort uses kind='quicksort' # as the default whereas cupy.sort uses kind=None. if stable: - kwargs['kind'] = "stable" + kwargs["kind"] = "stable" if not descending: res = xp.argsort(x, axis=axis, **kwargs) else: @@ -462,6 +494,7 @@ def argsort( res = max_i - res return res + def sort( x: Array, /, @@ -470,68 +503,78 @@ def sort( axis: int = -1, descending: bool = False, stable: bool = True, - **kwargs, + **kwargs: object, ) -> Array: # Note: this keyword argument is different, and the default is different. # We set it in kwargs like this because numpy.sort uses kind='quicksort' # as the default whereas cupy.sort uses kind=None. if stable: - kwargs['kind'] = "stable" + kwargs["kind"] = "stable" res = xp.sort(x, axis=axis, **kwargs) if descending: res = xp.flip(res, axis=axis) return res + # nonzero should error for zero-dimensional arrays -def nonzero(x: Array, /, xp: Namespace, **kwargs) -> Tuple[Array, ...]: +def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]: if x.ndim == 0: raise ValueError("nonzero() does not support zero-dimensional arrays") return xp.nonzero(x, **kwargs) + # ceil, floor, and trunc return integers for integer inputs -def ceil(x: Array, /, xp: Namespace, **kwargs) -> Array: + +def ceil(x: Array, /, xp: Namespace, **kwargs: object) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.ceil(x, **kwargs) -def floor(x: Array, /, xp: Namespace, **kwargs) -> Array: + +def floor(x: Array, /, xp: Namespace, **kwargs: object) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.floor(x, **kwargs) -def trunc(x: Array, /, xp: Namespace, **kwargs) -> Array: + +def trunc(x: Array, /, xp: Namespace, **kwargs: object) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.trunc(x, **kwargs) + # linear algebra functions -def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array: + +def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array: return xp.matmul(x1, x2, **kwargs) + # Unlike transpose, matrix_transpose only transposes the last two axes. def matrix_transpose(x: Array, /, xp: Namespace) -> Array: if x.ndim < 2: raise ValueError("x must be at least 2-dimensional for matrix_transpose") return xp.swapaxes(x, -1, -2) + def tensordot( x1: Array, x2: Array, /, xp: Namespace, *, - axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, - **kwargs, + axes: int | tuple[Sequence[int], Sequence[int]] = 2, + **kwargs: object, ) -> Array: return xp.tensordot(x1, x2, axes=axes, **kwargs) + def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array: if x1.shape[axis] != x2.shape[axis]: raise ValueError("x1 and x2 must have the same size along the given axis") - if hasattr(xp, 'broadcast_tensors'): + if hasattr(xp, "broadcast_tensors"): _broadcast = xp.broadcast_tensors else: _broadcast = xp.broadcast_arrays @@ -543,14 +586,16 @@ def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array: res = xp.conj(x1_[..., None, :]) @ x2_[..., None] return res[..., 0, 0] + # isdtype is a new function in the 2022.12 array API specification. + def isdtype( dtype: DType, - kind: Union[DType, str, Tuple[Union[DType, str], ...]], + kind: DType | str | tuple[DType | str, ...], xp: Namespace, *, - _tuple: bool = True, # Disallow nested tuples + _tuple: bool = True, # Disallow nested tuples ) -> bool: """ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. @@ -563,21 +608,24 @@ def isdtype( for more details """ if isinstance(kind, tuple) and _tuple: - return any(isdtype(dtype, k, xp, _tuple=False) for k in kind) + return any( + isdtype(dtype, k, xp, _tuple=False) + for k in cast("tuple[DType | str, ...]", kind) + ) elif isinstance(kind, str): - if kind == 'bool': + if kind == "bool": return dtype == xp.bool_ - elif kind == 'signed integer': + elif kind == "signed integer": return xp.issubdtype(dtype, xp.signedinteger) - elif kind == 'unsigned integer': + elif kind == "unsigned integer": return xp.issubdtype(dtype, xp.unsignedinteger) - elif kind == 'integral': + elif kind == "integral": return xp.issubdtype(dtype, xp.integer) - elif kind == 'real floating': + elif kind == "real floating": return xp.issubdtype(dtype, xp.floating) - elif kind == 'complex floating': + elif kind == "complex floating": return xp.issubdtype(dtype, xp.complexfloating) - elif kind == 'numeric': + elif kind == "numeric": return xp.issubdtype(dtype, xp.number) else: raise ValueError(f"Unrecognized data type kind: {kind!r}") @@ -588,24 +636,27 @@ def isdtype( # array_api_strict implementation will be very strict. return dtype == kind + # unstack is a new function in the 2023.12 array API standard -def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> Tuple[Array, ...]: +def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> tuple[Array, ...]: if x.ndim == 0: raise ValueError("Input array must be at least 1-d.") return tuple(xp.moveaxis(x, axis, 0)) + # numpy 1.26 does not use the standard definition for sign on complex numbers -def sign(x: Array, /, xp: Namespace, **kwargs) -> Array: - if isdtype(x.dtype, 'complex floating', xp=xp): - out = (x/xp.abs(x, **kwargs))[...] + +def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array: + if isdtype(x.dtype, "complex floating", xp=xp): + out = (x / xp.abs(x, **kwargs))[...] # sign(0) = 0 but the above formula would give nan - out[x == 0+0j] = 0+0j + out[x == 0 + 0j] = 0 + 0j else: out = xp.sign(x, **kwargs) # CuPy sign() does not propagate nans. See # https://github.com/data-apis/array-api-compat/issues/136 - if _is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp): + if _is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp): out[xp.isnan(x)] = xp.nan return out[()] @@ -626,13 +677,47 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any: return xp.iinfo(type_.dtype) -__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', - 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', - 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', - 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', - 'std', 'var', 'cumulative_sum', 'cumulative_prod','clip', 'permute_dims', - 'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', - 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype', - 'unstack', 'sign', 'finfo', 'iinfo'] - -_all_ignore = ['inspect', 'array_namespace', 'NamedTuple'] +__all__ = [ + "arange", + "empty", + "empty_like", + "eye", + "full", + "full_like", + "linspace", + "ones", + "ones_like", + "zeros", + "zeros_like", + "UniqueAllResult", + "UniqueCountsResult", + "UniqueInverseResult", + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + "std", + "var", + "cumulative_sum", + "cumulative_prod", + "clip", + "permute_dims", + "reshape", + "argsort", + "sort", + "nonzero", + "ceil", + "floor", + "trunc", + "matmul", + "matrix_transpose", + "tensordot", + "vecdot", + "isdtype", + "unstack", + "sign", + "finfo", + "iinfo", +] + +_all_ignore = ["inspect", "array_namespace", "NamedTuple"] From cbec5f360e142c6a73c3509c2f864a31b8be510f Mon Sep 17 00:00:00 2001 From: jorenham <jhammudoglu@gmail.com> Date: Sat, 22 Mar 2025 22:20:11 +0100 Subject: [PATCH 05/24] TYP: fix typing errors in `common._linalg` --- array_api_compat/common/_linalg.py | 106 +++++++++++++++++++++-------- 1 file changed, 78 insertions(+), 28 deletions(-) diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index d1e7ebd8..5a6f070e 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -1,23 +1,33 @@ from __future__ import annotations import math -from typing import Literal, NamedTuple, Optional, Tuple, Union +from typing import Literal, NamedTuple, cast import numpy as np + if np.__version__[0] == "2": from numpy.lib.array_utils import normalize_axis_tuple else: from numpy.core.numeric import normalize_axis_tuple -from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype from .._internal import get_xp -from ._typing import Array, Namespace +from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot +from ._typing import Array, DType, Namespace + # These are in the main NumPy namespace but not in numpy.linalg -def cross(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1, **kwargs) -> Array: +def cross( + x1: Array, + x2: Array, + /, + xp: Namespace, + *, + axis: int = -1, + **kwargs: object, +) -> Array: return xp.cross(x1, x2, axis=axis, **kwargs) -def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array: +def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array: return xp.outer(x1, x2, **kwargs) class EighResult(NamedTuple): @@ -39,46 +49,66 @@ class SVDResult(NamedTuple): # These functions are the same as their NumPy counterparts except they return # a namedtuple. -def eigh(x: Array, /, xp: Namespace, **kwargs) -> EighResult: +def eigh(x: Array, /, xp: Namespace, **kwargs: object) -> EighResult: return EighResult(*xp.linalg.eigh(x, **kwargs)) -def qr(x: Array, /, xp: Namespace, *, mode: Literal['reduced', 'complete'] = 'reduced', - **kwargs) -> QRResult: +def qr( + x: Array, + /, + xp: Namespace, + *, + mode: Literal["reduced", "complete"] = "reduced", + **kwargs: object, +) -> QRResult: return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs)) -def slogdet(x: Array, /, xp: Namespace, **kwargs) -> SlogdetResult: +def slogdet(x: Array, /, xp: Namespace, **kwargs: object) -> SlogdetResult: return SlogdetResult(*xp.linalg.slogdet(x, **kwargs)) def svd( - x: Array, /, xp: Namespace, *, full_matrices: bool = True, **kwargs + x: Array, + /, + xp: Namespace, + *, + full_matrices: bool = True, + **kwargs: object, ) -> SVDResult: return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs)) # These functions have additional keyword arguments # The upper keyword argument is new from NumPy -def cholesky(x: Array, /, xp: Namespace, *, upper: bool = False, **kwargs) -> Array: +def cholesky( + x: Array, + /, + xp: Namespace, + *, + upper: bool = False, + **kwargs: object, +) -> Array: L = xp.linalg.cholesky(x, **kwargs) if upper: U = get_xp(xp)(matrix_transpose)(L) if get_xp(xp)(isdtype)(U.dtype, 'complex floating'): - U = xp.conj(U) + U = xp.conj(U) # pyright: ignore[reportConstantRedefinition] return U return L # The rtol keyword argument of matrix_rank() and pinv() is new from NumPy. # Note that it has a different semantic meaning from tol and rcond. -def matrix_rank(x: Array, - /, - xp: Namespace, - *, - rtol: Optional[Union[float, Array]] = None, - **kwargs) -> Array: +def matrix_rank( + x: Array, + /, + xp: Namespace, + *, + rtol: float | Array | None = None, + **kwargs: object, +) -> Array: # this is different from xp.linalg.matrix_rank, which supports 1 # dimensional arrays. if x.ndim < 2: raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional") - S = get_xp(xp)(svdvals)(x, **kwargs) + S: Array = get_xp(xp)(svdvals)(x, **kwargs) if rtol is None: tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps else: @@ -88,7 +118,12 @@ def matrix_rank(x: Array, return xp.count_nonzero(S > tol, axis=-1) def pinv( - x: Array, /, xp: Namespace, *, rtol: Optional[Union[float, Array]] = None, **kwargs + x: Array, + /, + xp: Namespace, + *, + rtol: float | Array | None = None, + **kwargs: object, ) -> Array: # this is different from xp.linalg.pinv, which does not multiply the # default tolerance by max(M, N). @@ -104,13 +139,13 @@ def matrix_norm( xp: Namespace, *, keepdims: bool = False, - ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro', + ord: float | Literal["fro", "nuc"] | None = "fro", ) -> Array: return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord) # svdvals is not in NumPy (but it is in SciPy). It is equivalent to # xp.linalg.svd(compute_uv=False). -def svdvals(x: Array, /, xp: Namespace) -> Union[Array, Tuple[Array, ...]]: +def svdvals(x: Array, /, xp: Namespace) -> Array | tuple[Array, ...]: return xp.linalg.svd(x, compute_uv=False) def vector_norm( @@ -118,9 +153,9 @@ def vector_norm( /, xp: Namespace, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - ord: Optional[Union[int, float]] = 2, + ord: float = 2, ) -> Array: # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or # when axis=None and the input is 2-D, so to force a vector norm, we make @@ -133,7 +168,10 @@ def vector_norm( elif isinstance(axis, tuple): # Note: The axis argument supports any number of axes, whereas # xp.linalg.norm() only supports a single axis for vector norm. - normalized_axis = normalize_axis_tuple(axis, x.ndim) + normalized_axis = cast( + "tuple[int, ...]", + normalize_axis_tuple(axis, x.ndim), # pyright: ignore[reportCallIssue] + ) rest = tuple(i for i in range(x.ndim) if i not in normalized_axis) newshape = axis + rest _x = xp.transpose(x, newshape).reshape( @@ -149,7 +187,13 @@ def vector_norm( # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks # above to avoid matrix norm logic. shape = list(x.shape) - _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim) + _axis = cast( + "tuple[int, ...]", + normalize_axis_tuple( # pyright: ignore[reportCallIssue] + range(x.ndim) if axis is None else axis, + x.ndim, + ), + ) for i in _axis: shape[i] = 1 res = xp.reshape(res, tuple(shape)) @@ -159,11 +203,17 @@ def vector_norm( # xp.diagonal and xp.trace operate on the first two axes whereas these # operates on the last two -def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs) -> Array: +def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs: object) -> Array: return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs) def trace( - x: Array, /, xp: Namespace, *, offset: int = 0, dtype=None, **kwargs + x: Array, + /, + xp: Namespace, + *, + offset: int = 0, + dtype: DType | None = None, + **kwargs: object, ) -> Array: return xp.asarray( xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs) From dc79e3f1731b966294142067416e9c76cf5bb607 Mon Sep 17 00:00:00 2001 From: jorenham <jhammudoglu@gmail.com> Date: Sat, 22 Mar 2025 22:26:36 +0100 Subject: [PATCH 06/24] TYP: fix/ignore typing errors in `numpy.__init__` --- array_api_compat/numpy/__init__.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 6a5d9867..caaf67bb 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -1,10 +1,16 @@ -from numpy import * # noqa: F403 +# ruff: noqa: PLC0414 +from typing import Final + +from numpy import * # noqa: F403 # pyright: ignore[reportWildcardImportFromLibrary] # from numpy import * doesn't overwrite these builtin names -from numpy import abs, max, min, round # noqa: F401 +from numpy import abs as abs +from numpy import max as max +from numpy import min as min +from numpy import round as round # These imports may overwrite names from the import * above. -from ._aliases import * # noqa: F403 +from ._aliases import * # noqa: F403 # Don't know why, but we have to do an absolute import to import linalg. If we # instead do @@ -13,9 +19,17 @@ # # It doesn't overwrite np.linalg from above. The import is generated # dynamically so that the library can be vendored. -__import__(__package__ + '.linalg') -__import__(__package__ + '.fft') +__import__(__package__ + ".linalg") + +__import__(__package__ + ".fft") + +from ..common._helpers import * # noqa: F403 +from .linalg import matrix_transpose, vecdot # noqa: F401 -from .linalg import matrix_transpose, vecdot # noqa: F401 +try: + # Used in asarray(). Not present in older versions. + from numpy import _CopyMode # noqa: F401 +except ImportError: + pass -__array_api_version__ = '2024.12' +__array_api_version__ = "2024.12" From 9643256ae1b7fd7af65b3255c992ce760dc6a4b5 Mon Sep 17 00:00:00 2001 From: jorenham <jhammudoglu@gmail.com> Date: Sat, 22 Mar 2025 22:31:55 +0100 Subject: [PATCH 07/24] TYP: fix typing errors in `numpy._typing` --- array_api_compat/numpy/_typing.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index a6c96924..1f7b247c 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -3,29 +3,22 @@ __all__ = ["Array", "DType", "Device"] _all_ignore = ["np"] -from typing import Literal, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Literal, TypeAlias import numpy as np -from numpy import ndarray as Array -Device = Literal["cpu"] +Device: TypeAlias = Literal["cpu"] if TYPE_CHECKING: # NumPy 1.x on Python 3.10 fails to parse np.dtype[] - DType = np.dtype[ - np.intp - | np.int8 - | np.int16 - | np.int32 - | np.int64 - | np.uint8 - | np.uint16 - | np.uint32 - | np.uint64 + DType: TypeAlias = np.dtype[ + np.bool_ + | np.integer[Any] | np.float32 | np.float64 | np.complex64 | np.complex128 - | np.bool ] + Array: TypeAlias = np.ndarray[Any, DType] else: DType = np.dtype + Array = np.ndarray From 18870dcf4fe7d345119cdc082550f679bd7542b3 Mon Sep 17 00:00:00 2001 From: jorenham <jhammudoglu@gmail.com> Date: Sat, 22 Mar 2025 22:50:35 +0100 Subject: [PATCH 08/24] TYP: fix typing errors in `numpy._aliases` --- array_api_compat/numpy/_aliases.py | 81 +++++++++++++++++++----------- 1 file changed, 52 insertions(+), 29 deletions(-) diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index d1fd46a1..10d84fe7 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -1,6 +1,10 @@ +# pyright: reportPrivateUsage=false from __future__ import annotations -from typing import Optional, Union +from builtins import bool as py_bool +from typing import TYPE_CHECKING, cast + +import numpy as np from .._internal import get_xp from ..common import _aliases, _helpers @@ -8,7 +12,12 @@ from ._info import __array_namespace_info__ from ._typing import Array, Device, DType -import numpy as np +if TYPE_CHECKING: + from typing import Any, Literal, TypeAlias + + from typing_extensions import Buffer, TypeIs + + _Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode bool = np.bool_ @@ -65,9 +74,9 @@ iinfo = get_xp(np)(_aliases.iinfo) -def _supports_buffer_protocol(obj): +def _supports_buffer_protocol(obj: object) -> TypeIs[Buffer]: # pyright: ignore[reportUnusedFunction] try: - memoryview(obj) + memoryview(obj) # pyright: ignore[reportArgumentType] except TypeError: return False return True @@ -78,18 +87,13 @@ def _supports_buffer_protocol(obj): # complicated enough that it's easier to define it separately for each module # rather than trying to combine everything into one function in common/ def asarray( - obj: ( - Array - | bool | int | float | complex - | NestedSequence[bool | int | float | complex] - | SupportsBufferProtocol - ), + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - copy: Optional[Union[bool, np._CopyMode]] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + copy: _Copy | None = None, + **kwargs: Any, ) -> Array: """ Array API compatibility wrapper for asarray(). @@ -106,7 +110,7 @@ def asarray( elif copy is True: copy = np._CopyMode.ALWAYS - return np.array(obj, copy=copy, dtype=dtype, **kwargs) + return np.array(obj, copy=copy, dtype=dtype, **kwargs) # pyright: ignore def astype( @@ -114,8 +118,8 @@ def astype( dtype: DType, /, *, - copy: bool = True, - device: Optional[Device] = None, + copy: py_bool = True, + device: Device | None = None, ) -> Array: _helpers._check_device(np, device) return x.astype(dtype=dtype, copy=copy) @@ -123,8 +127,12 @@ def astype( # count_nonzero returns a python int for axis=None and keepdims=False # https://github.com/numpy/numpy/issues/17562 -def count_nonzero(x: Array, axis=None, keepdims=False) -> Array: - result = np.count_nonzero(x, axis=axis, keepdims=keepdims) +def count_nonzero( + x: Array, + axis: int | tuple[int, ...] | None = None, + keepdims: py_bool = False, +) -> Array: + result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims)) # pyright: ignore if axis is None and not keepdims: return np.asarray(result) return result @@ -132,25 +140,40 @@ def count_nonzero(x: Array, axis=None, keepdims=False) -> Array: # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. -if hasattr(np, 'vecdot'): +if hasattr(np, "vecdot"): vecdot = np.vecdot else: vecdot = get_xp(np)(_aliases.vecdot) -if hasattr(np, 'isdtype'): +if hasattr(np, "isdtype"): isdtype = np.isdtype else: isdtype = get_xp(np)(_aliases.isdtype) -if hasattr(np, 'unstack'): +if hasattr(np, "unstack"): unstack = np.unstack else: unstack = get_xp(np)(_aliases.unstack) -__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype', - 'acos', 'acosh', 'asin', 'asinh', 'atan', - 'atan2', 'atanh', 'bitwise_left_shift', - 'bitwise_invert', 'bitwise_right_shift', - 'bool', 'concat', 'count_nonzero', 'pow'] - -_all_ignore = ['np', 'get_xp'] +__all__ = [ + "__array_namespace_info__", + "asarray", + "astype", + "acos", + "acosh", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_left_shift", + "bitwise_invert", + "bitwise_right_shift", + "bool", + "concat", + "count_nonzero", + "pow", +] +__all__ += _aliases.__all__ + +_all_ignore = ["np", "get_xp"] From 1fb929bed064bbf04217566e3dff931fec49cd3e Mon Sep 17 00:00:00 2001 From: jorenham <jhammudoglu@gmail.com> Date: Sat, 22 Mar 2025 22:55:40 +0100 Subject: [PATCH 09/24] TYP: fix typing errors in `numpy._info` --- array_api_compat/numpy/_info.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py index 365855b8..ec7c39f5 100644 --- a/array_api_compat/numpy/_info.py +++ b/array_api_compat/numpy/_info.py @@ -7,24 +7,26 @@ more details. """ +from numpy import bool_ as bool from numpy import ( + complex64, + complex128, dtype, - bool_ as bool, - intp, + float32, + float64, int8, int16, int32, int64, + intp, uint8, uint16, uint32, uint64, - float32, - float64, - complex64, - complex128, ) +from ._typing import Device, DType + class __array_namespace_info__: """ @@ -131,7 +133,11 @@ def default_device(self): """ return "cpu" - def default_dtypes(self, *, device=None): + def default_dtypes( + self, + *, + device: Device | None = None, + ) -> dict[str, dtype[intp | float64 | complex128]]: """ The default data types used for new NumPy arrays. @@ -183,7 +189,12 @@ def default_dtypes(self, *, device=None): "indexing": dtype(intp), } - def dtypes(self, *, device=None, kind=None): + def dtypes( + self, + *, + device: Device | None = None, + kind: str | tuple[str, ...] | None = None, + ) -> dict[str, DType]: """ The array API data types supported by NumPy. @@ -260,7 +271,7 @@ def dtypes(self, *, device=None, kind=None): "complex128": dtype(complex128), } if kind == "bool": - return {"bool": bool} + return {"bool": dtype(bool)} if kind == "signed integer": return { "int8": dtype(int8), @@ -312,13 +323,13 @@ def dtypes(self, *, device=None, kind=None): "complex128": dtype(complex128), } if isinstance(kind, tuple): - res = {} + res: dict[str, DType] = {} for k in kind: res.update(self.dtypes(kind=k)) return res raise ValueError(f"unsupported kind: {kind!r}") - def devices(self): + def devices(self) -> list[Device]: """ The devices supported by NumPy. From 014385fb0514015b0e40958c2a8ccee732e8e396 Mon Sep 17 00:00:00 2001 From: jorenham <jhammudoglu@gmail.com> Date: Sat, 22 Mar 2025 23:01:54 +0100 Subject: [PATCH 10/24] TYP: fix typing errors in `numpy._fft` --- array_api_compat/numpy/fft.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py index 28667594..5423bd01 100644 --- a/array_api_compat/numpy/fft.py +++ b/array_api_compat/numpy/fft.py @@ -1,10 +1,9 @@ -from numpy.fft import * # noqa: F403 +import numpy as np from numpy.fft import __all__ as fft_all +from numpy.fft import fft2, ifft2, irfft2, rfft2 -from ..common import _fft from .._internal import get_xp - -import numpy as np +from ..common import _fft fft = get_xp(np)(_fft.fft) ifft = get_xp(np)(_fft.ifft) @@ -21,7 +20,8 @@ fftshift = get_xp(np)(_fft.fftshift) ifftshift = get_xp(np)(_fft.ifftshift) -__all__ = fft_all + _fft.__all__ +__all__ = ["rfft2", "irfft2", "fft2", "ifft2"] +__all__ += _fft.__all__ del get_xp del np From ec72825c8e32de076a6f228105ccaf4bb25344a9 Mon Sep 17 00:00:00 2001 From: jorenham <jhammudoglu@gmail.com> Date: Sat, 22 Mar 2025 23:18:53 +0100 Subject: [PATCH 11/24] TYP: it's a bad idea to import `TypeAlias` from `typing` on `python<3.10` --- array_api_compat/common/_fft.py | 4 ++-- array_api_compat/common/_helpers.py | 13 ++----------- array_api_compat/numpy/_aliases.py | 4 ++-- array_api_compat/numpy/_typing.py | 6 ++++-- 4 files changed, 10 insertions(+), 17 deletions(-) diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py index 49f948ce..fb95eafc 100644 --- a/array_api_compat/common/_fft.py +++ b/array_api_compat/common/_fft.py @@ -1,9 +1,9 @@ from collections.abc import Sequence -from typing import Literal, TypeAlias +from typing import Literal from ._typing import Array, Device, DType, Namespace -_Norm: TypeAlias = Literal["backward", "ortho", "forward"] +_Norm = Literal["backward", "ortho", "forward"] # Note: NumPy fft functions improperly upcast float32 and complex64 to # complex128, which is why we require wrapping them all here. diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 9f016f05..76c9d92c 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -12,16 +12,7 @@ import math import sys import warnings -from typing import ( - TYPE_CHECKING, - Any, - Literal, - SupportsIndex, - TypeAlias, - TypeGuard, - cast, - overload, -) +from typing import TYPE_CHECKING, Any, Literal, SupportsIndex, TypeGuard, cast, overload from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace @@ -35,7 +26,7 @@ import numpy.typing as npt import sparse # pyright: ignore[reportMissingTypeStubs] import torch - from typing_extensions import TypeIs, TypeVar + from typing_extensions import TypeAlias, TypeIs, TypeVar _SizeT = TypeVar("_SizeT", bound=int | None) diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 10d84fe7..39771f4e 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -13,9 +13,9 @@ from ._typing import Array, Device, DType if TYPE_CHECKING: - from typing import Any, Literal, TypeAlias + from typing import Any, Literal - from typing_extensions import Buffer, TypeIs + from typing_extensions import Buffer, TypeAlias, TypeIs _Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index 1f7b247c..d4a285b9 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -3,12 +3,14 @@ __all__ = ["Array", "DType", "Device"] _all_ignore = ["np"] -from typing import TYPE_CHECKING, Any, Literal, TypeAlias +from typing import TYPE_CHECKING, Any, Literal import numpy as np -Device: TypeAlias = Literal["cpu"] +Device = Literal["cpu"] if TYPE_CHECKING: + from typing_extensions import TypeAlias + # NumPy 1.x on Python 3.10 fails to parse np.dtype[] DType: TypeAlias = np.dtype[ np.bool_ From ccd9bc6e3e6cc1fd12318c92df03c5d979916579 Mon Sep 17 00:00:00 2001 From: jorenham <jhammudoglu@gmail.com> Date: Sat, 22 Mar 2025 23:20:54 +0100 Subject: [PATCH 12/24] TYP: it's also a bad idea to import `TypeGuard` from `typing` on `python<3.10` --- array_api_compat/common/_helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 76c9d92c..b14be849 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -12,7 +12,7 @@ import math import sys import warnings -from typing import TYPE_CHECKING, Any, Literal, SupportsIndex, TypeGuard, cast, overload +from typing import TYPE_CHECKING, Any, Literal, SupportsIndex, cast, overload from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace @@ -26,7 +26,7 @@ import numpy.typing as npt import sparse # pyright: ignore[reportMissingTypeStubs] import torch - from typing_extensions import TypeAlias, TypeIs, TypeVar + from typing_extensions import TypeAlias, TypeGuard, TypeIs, TypeVar _SizeT = TypeVar("_SizeT", bound=int | None) From b8c78832fb9235f4fe1b674aa63364aa0ae2a87b Mon Sep 17 00:00:00 2001 From: jorenham <jhammudoglu@gmail.com> Date: Sat, 22 Mar 2025 23:24:05 +0100 Subject: [PATCH 13/24] TYP: don't scare the prehistoric `dtype` from numpy 1.21 --- array_api_compat/numpy/_info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py index ec7c39f5..22570a7c 100644 --- a/array_api_compat/numpy/_info.py +++ b/array_api_compat/numpy/_info.py @@ -137,7 +137,7 @@ def default_dtypes( self, *, device: Device | None = None, - ) -> dict[str, dtype[intp | float64 | complex128]]: + ) -> dict[str, "dtype[intp | float64 | complex128]"]: """ The default data types used for new NumPy arrays. From ef066d1225887144a6310d78feecceb447f61a17 Mon Sep 17 00:00:00 2001 From: jorenham <jhammudoglu@gmail.com> Date: Sat, 22 Mar 2025 23:25:50 +0100 Subject: [PATCH 14/24] TYP: dust off the DeLorean --- array_api_compat/numpy/_info.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py index 22570a7c..ca019b2c 100644 --- a/array_api_compat/numpy/_info.py +++ b/array_api_compat/numpy/_info.py @@ -7,6 +7,8 @@ more details. """ +from __future__ import annotations + from numpy import bool_ as bool from numpy import ( complex64, @@ -137,7 +139,7 @@ def default_dtypes( self, *, device: Device | None = None, - ) -> dict[str, "dtype[intp | float64 | complex128]"]: + ) -> dict[str, dtype[intp | float64 | complex128]]: """ The default data types used for new NumPy arrays. From a522dbc9dbae73ac31f294c897cad5d59276dfda Mon Sep 17 00:00:00 2001 From: jorenham <jhammudoglu@gmail.com> Date: Sat, 22 Mar 2025 23:31:48 +0100 Subject: [PATCH 15/24] TYP: figure out how to drive a DeLorean --- array_api_compat/common/_fft.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py index fb95eafc..6fe834dc 100644 --- a/array_api_compat/common/_fft.py +++ b/array_api_compat/common/_fft.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Sequence from typing import Literal From bca9c0cedabd73c18cd471f3137dc2455fd9de13 Mon Sep 17 00:00:00 2001 From: jorenham <jhammudoglu@gmail.com> Date: Tue, 15 Apr 2025 14:24:36 +0200 Subject: [PATCH 16/24] TYP: apply review suggestions Co-authored-by: crusaderky <crusaderky@gmail.com> --- array_api_compat/_internal.py | 15 ++++++--------- array_api_compat/common/_aliases.py | 1 + array_api_compat/common/_helpers.py | 18 +++++++++++++++--- array_api_compat/numpy/_aliases.py | 8 +++----- 4 files changed, 25 insertions(+), 17 deletions(-) diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index 7a973567..6df63097 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -2,21 +2,18 @@ Internal helpers """ +from collections.abc import Callable from functools import wraps from inspect import signature -from typing import TYPE_CHECKING +from types import ModuleType +from typing import TypeVar __all__ = ["get_xp"] -if TYPE_CHECKING: - from collections.abc import Callable - from types import ModuleType - from typing import TypeVar +_T = TypeVar("_T") - _T = TypeVar("_T") - -def get_xp(xp: "ModuleType") -> "Callable[[Callable[..., _T]], Callable[..., _T]]": +def get_xp(xp: ModuleType) -> Callable[[Callable[..., _T]], Callable[..., _T]]: """ Decorator to automatically replace xp with the corresponding array module. @@ -33,7 +30,7 @@ def func(x, /, xp, kwarg=None): """ - def inner(f: "Callable[..., _T]", /) -> "Callable[..., _T]": + def inner(f: Callable[..., _T], /) -> Callable[..., _T]: @wraps(f) def wrapped_f(*args: object, **kwargs: object) -> object: return f(*args, xp=xp, **kwargs) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index cd4b6af8..8b106d72 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -13,6 +13,7 @@ from ._typing import Array, Device, DType, Namespace if TYPE_CHECKING: + # TODO: import from typing (requires Python >=3.13) from typing_extensions import TypeIs # These functions are modified from the NumPy versions. diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index b14be849..551d09f7 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -12,12 +12,22 @@ import math import sys import warnings -from typing import TYPE_CHECKING, Any, Literal, SupportsIndex, cast, overload +from collections.abc import Collection +from typing import ( + TYPE_CHECKING, + Any, + Literal, + SupportsIndex, + TypeAlias, + TypeGuard, + TypeVar, + cast, + overload, +) from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace if TYPE_CHECKING: - from collections.abc import Collection import dask.array as da import jax @@ -26,7 +36,9 @@ import numpy.typing as npt import sparse # pyright: ignore[reportMissingTypeStubs] import torch - from typing_extensions import TypeAlias, TypeGuard, TypeIs, TypeVar + + # TODO: import from typing (requires Python >=3.13) + from typing_extensions import TypeIs, TypeVar _SizeT = TypeVar("_SizeT", bound=int | None) diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 39771f4e..41d02dd0 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -2,7 +2,7 @@ from __future__ import annotations from builtins import bool as py_bool -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast import numpy as np @@ -13,11 +13,9 @@ from ._typing import Array, Device, DType if TYPE_CHECKING: - from typing import Any, Literal + from typing_extensions import Buffer, TypeIs - from typing_extensions import Buffer, TypeAlias, TypeIs - - _Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode +_Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode bool = np.bool_ From 0dd925f79185d13aed000f98a00b272a3bd6a066 Mon Sep 17 00:00:00 2001 From: jorenham <jhammudoglu@gmail.com> Date: Tue, 15 Apr 2025 14:25:10 +0200 Subject: [PATCH 17/24] TYP: sprinkle some `TypeAlias`es and `Final`s around --- array_api_compat/common/_fft.py | 4 ++-- array_api_compat/common/_typing.py | 10 +++++----- array_api_compat/numpy/__init__.py | 2 +- array_api_compat/numpy/_typing.py | 10 +++++----- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py index 6fe834dc..5e75e85f 100644 --- a/array_api_compat/common/_fft.py +++ b/array_api_compat/common/_fft.py @@ -1,11 +1,11 @@ from __future__ import annotations from collections.abc import Sequence -from typing import Literal +from typing import Literal, TypeAlias from ._typing import Array, Device, DType, Namespace -_Norm = Literal["backward", "ortho", "forward"] +_Norm: TypeAlias = Literal["backward", "ortho", "forward"] # Note: NumPy fft functions improperly upcast float32 and complex64 to # complex128, which is why we require wrapping them all here. diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index a38d083b..cf9a00d7 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -1,7 +1,7 @@ from __future__ import annotations from types import ModuleType as Namespace -from typing import Any, Protocol, TypeVar +from typing import Any, Protocol, TypeAlias, TypeVar __all__ = [ "Array", @@ -30,7 +30,7 @@ class HasShape(Protocol[_T_co]): def shape(self, /) -> _T_co: ... -SupportsBufferProtocol = Any -Array = Any -Device = Any -DType = Any +SupportsBufferProtocol: TypeAlias = Any +Array: TypeAlias = Any +Device: TypeAlias = Any +DType: TypeAlias = Any diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index caaf67bb..f7b558ba 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -32,4 +32,4 @@ except ImportError: pass -__array_api_version__ = "2024.12" +__array_api_version__: Final = "2024.12" diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index d4a285b9..0796f7d0 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -3,13 +3,13 @@ __all__ = ["Array", "DType", "Device"] _all_ignore = ["np"] -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, TypeAlias import numpy as np -Device = Literal["cpu"] +Device: TypeAlias = Literal["cpu"] + if TYPE_CHECKING: - from typing_extensions import TypeAlias # NumPy 1.x on Python 3.10 fails to parse np.dtype[] DType: TypeAlias = np.dtype[ @@ -22,5 +22,5 @@ ] Array: TypeAlias = np.ndarray[Any, DType] else: - DType = np.dtype - Array = np.ndarray + DType: TypeAlias = np.dtype + Array: TypeAlias = np.ndarray From 953d7c051be9aca5b269ae5e45eadd35ff8b9dea Mon Sep 17 00:00:00 2001 From: jorenham <jhammudoglu@gmail.com> Date: Tue, 15 Apr 2025 14:41:59 +0200 Subject: [PATCH 18/24] TYP: `__dir__` --- array_api_compat/_internal.py | 9 +++++++-- array_api_compat/common/__init__.py | 2 +- array_api_compat/common/_aliases.py | 5 ++++- array_api_compat/common/_fft.py | 3 +++ array_api_compat/common/_helpers.py | 3 +++ array_api_compat/common/_linalg.py | 4 ++++ array_api_compat/common/_typing.py | 27 ++++++++++++++++----------- array_api_compat/numpy/_aliases.py | 5 ++++- array_api_compat/numpy/_info.py | 7 +++++++ array_api_compat/numpy/_typing.py | 10 +++++++--- array_api_compat/numpy/fft.py | 6 ++++++ 11 files changed, 62 insertions(+), 19 deletions(-) diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index 6df63097..cd8d939f 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -8,8 +8,6 @@ from types import ModuleType from typing import TypeVar -__all__ = ["get_xp"] - _T = TypeVar("_T") @@ -52,3 +50,10 @@ def wrapped_f(*args: object, **kwargs: object) -> object: return wrapped_f # pyright: ignore[reportReturnType] return inner + + +__all__ = ["get_xp"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/__init__.py b/array_api_compat/common/__init__.py index 91ab1c40..82360807 100644 --- a/array_api_compat/common/__init__.py +++ b/array_api_compat/common/__init__.py @@ -1 +1 @@ -from ._helpers import * # noqa: F403 +from ._helpers import * # noqa: F403 diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 8b106d72..7f3cf914 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -720,5 +720,8 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any: "finfo", "iinfo", ] - _all_ignore = ["inspect", "array_namespace", "NamedTuple"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py index 5e75e85f..18839d37 100644 --- a/array_api_compat/common/_fft.py +++ b/array_api_compat/common/_fft.py @@ -208,3 +208,6 @@ def ifftshift( "fftshift", "ifftshift", ] + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 551d09f7..0ce6edf0 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -1044,3 +1044,6 @@ def is_lazy_array(x: object) -> bool: ] _all_ignore = ["sys", "math", "inspect", "warnings"] + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 5a6f070e..7e002aed 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -226,3 +226,7 @@ def trace( 'trace'] _all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype'] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index cf9a00d7..477b6a4e 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -3,17 +3,6 @@ from types import ModuleType as Namespace from typing import Any, Protocol, TypeAlias, TypeVar -__all__ = [ - "Array", - "SupportsArrayNamespace", - "DType", - "Device", - "HasShape", - "Namespace", - "NestedSequence", - "SupportsBufferProtocol", -] - _T_co = TypeVar("_T_co", covariant=True) class NestedSequence(Protocol[_T_co]): @@ -34,3 +23,19 @@ def shape(self, /) -> _T_co: ... Array: TypeAlias = Any Device: TypeAlias = Any DType: TypeAlias = Any + + +__all__ = [ + "Array", + "SupportsArrayNamespace", + "DType", + "Device", + "HasShape", + "Namespace", + "NestedSequence", + "SupportsBufferProtocol", +] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 41d02dd0..2c03041f 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -173,5 +173,8 @@ def count_nonzero( "pow", ] __all__ += _aliases.__all__ - _all_ignore = ["np", "get_xp"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py index ca019b2c..f307f62c 100644 --- a/array_api_compat/numpy/_info.py +++ b/array_api_compat/numpy/_info.py @@ -357,3 +357,10 @@ def devices(self) -> list[Device]: """ return ["cpu"] + + +__all__ = ["__array_namespace_info__"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index 0796f7d0..e771c788 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -1,8 +1,5 @@ from __future__ import annotations -__all__ = ["Array", "DType", "Device"] -_all_ignore = ["np"] - from typing import TYPE_CHECKING, Any, Literal, TypeAlias import numpy as np @@ -24,3 +21,10 @@ else: DType: TypeAlias = np.dtype Array: TypeAlias = np.ndarray + +__all__ = ["Array", "DType", "Device"] +_all_ignore = ["np"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py index 5423bd01..06875f00 100644 --- a/array_api_compat/numpy/fft.py +++ b/array_api_compat/numpy/fft.py @@ -20,9 +20,15 @@ fftshift = get_xp(np)(_fft.fftshift) ifftshift = get_xp(np)(_fft.ifftshift) + __all__ = ["rfft2", "irfft2", "fft2", "ifft2"] __all__ += _fft.__all__ + +def __dir__() -> list[str]: + return __all__ + + del get_xp del np del fft_all From c66b750684f85ed92b355b5a9b90ec74da2e4b39 Mon Sep 17 00:00:00 2001 From: jorenham <jhammudoglu@gmail.com> Date: Tue, 15 Apr 2025 15:21:03 +0200 Subject: [PATCH 19/24] TYP: fix typing errors in `numpy.linalg` --- array_api_compat/numpy/linalg.py | 97 ++++++++++++++++++++++++-------- 1 file changed, 75 insertions(+), 22 deletions(-) diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index 8f01593b..2d3e731d 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -1,14 +1,35 @@ -from numpy.linalg import * # noqa: F403 -from numpy.linalg import __all__ as linalg_all -import numpy as _np +# pyright: reportAttributeAccessIssue=false +# pyright: reportUnknownArgumentType=false +# pyright: reportUnknownMemberType=false +# pyright: reportUnknownVariableType=false + +from __future__ import annotations + +import numpy as np + +# intersection of `np.linalg.__all__` on numpy 1.22 and 2.2, minus `_linalg.__all__` +from numpy.linalg import ( + LinAlgError, + cond, + det, + eig, + eigvals, + eigvalsh, + inv, + lstsq, + matrix_power, + multi_dot, + norm, + tensorinv, + tensorsolve, +) -from ..common import _linalg from .._internal import get_xp +from ..common import _linalg # These functions are in both the main and linalg namespaces -from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 - -import numpy as np +from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 +from ._typing import Array cross = get_xp(np)(_linalg.cross) outer = get_xp(np)(_linalg.outer) @@ -38,19 +59,28 @@ # To workaround this, the below is the code from np.linalg.solve except # only calling solve1 in the exactly 1D case. + # This code is here instead of in common because it is numpy specific. Also # note that CuPy's solve() does not currently support broadcasting (see # https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43). -def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray: +def solve(x1: Array, x2: Array, /) -> Array: try: from numpy.linalg._linalg import ( - _makearray, _assert_stacked_2d, _assert_stacked_square, - _commonType, isComplexType, _raise_linalgerror_singular + _assert_stacked_2d, + _assert_stacked_square, + _commonType, + _makearray, + _raise_linalgerror_singular, + isComplexType, ) except ImportError: from numpy.linalg.linalg import ( - _makearray, _assert_stacked_2d, _assert_stacked_square, - _commonType, isComplexType, _raise_linalgerror_singular + _assert_stacked_2d, + _assert_stacked_square, + _commonType, + _makearray, + _raise_linalgerror_singular, + isComplexType, ) from numpy.linalg import _umath_linalg @@ -61,6 +91,7 @@ def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray: t, result_t = _commonType(x1, x2) # This part is different from np.linalg.solve + gufunc: np.ufunc if x2.ndim == 1: gufunc = _umath_linalg.solve1 else: @@ -68,23 +99,45 @@ def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray: # This does nothing currently but is left in because it will be relevant # when complex dtype support is added to the spec in 2022. - signature = 'DD->D' if isComplexType(t) else 'dd->d' - with _np.errstate(call=_raise_linalgerror_singular, invalid='call', - over='ignore', divide='ignore', under='ignore'): - r = gufunc(x1, x2, signature=signature) + signature = "DD->D" if isComplexType(t) else "dd->d" + with np.errstate( + call=_raise_linalgerror_singular, + invalid="call", + over="ignore", + divide="ignore", + under="ignore", + ): + r: Array = gufunc(x1, x2, signature=signature) return wrap(r.astype(result_t, copy=False)) + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. -if hasattr(np.linalg, 'vector_norm'): +if hasattr(np.linalg, "vector_norm"): vector_norm = np.linalg.vector_norm else: vector_norm = get_xp(np)(_linalg.vector_norm) -__all__ = linalg_all + _linalg.__all__ + ['solve'] -del get_xp -del np -del linalg_all -del _linalg +__all__ = [ + "LinAlgError", + "cond", + "det", + "eig", + "eigvals", + "eigvalsh", + "inv", + "lstsq", + "matrix_power", + "multi_dot", + "norm", + "tensorinv", + "tensorsolve", +] +__all__ += _linalg.__all__ +__all__ += ["solve", "vector_norm"] + + +def __dir__() -> list[str]: + return __all__ From 9acba462fcae27c9db4027440386c2cdacfef78a Mon Sep 17 00:00:00 2001 From: jorenham <jhammudoglu@gmail.com> Date: Tue, 15 Apr 2025 15:45:24 +0200 Subject: [PATCH 20/24] TYP: add a `common._typing.Capabilities` typed dict type --- array_api_compat/common/_typing.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index 477b6a4e..331e5a8f 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -1,10 +1,11 @@ from __future__ import annotations from types import ModuleType as Namespace -from typing import Any, Protocol, TypeAlias, TypeVar +from typing import Any, Protocol, TypeAlias, TypedDict, TypeVar _T_co = TypeVar("_T_co", covariant=True) + class NestedSequence(Protocol[_T_co]): def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... def __len__(self, /) -> int: ... @@ -19,6 +20,16 @@ class HasShape(Protocol[_T_co]): def shape(self, /) -> _T_co: ... +Capabilities = TypedDict( + "Capabilities", + { + "boolean indexing": bool, + "data-dependent shapes": bool, + "max dimensions": int, + }, +) + + SupportsBufferProtocol: TypeAlias = Any Array: TypeAlias = Any Device: TypeAlias = Any @@ -27,12 +38,13 @@ def shape(self, /) -> _T_co: ... __all__ = [ "Array", - "SupportsArrayNamespace", + "Capabilities", "DType", "Device", "HasShape", "Namespace", "NestedSequence", + "SupportsArrayNamespace", "SupportsBufferProtocol", ] From ba0b4e58069cf9cd5a31df8f8483d867fbf33ea1 Mon Sep 17 00:00:00 2001 From: jorenham <jhammudoglu@gmail.com> Date: Tue, 15 Apr 2025 16:26:34 +0200 Subject: [PATCH 21/24] TYP: `__array_namespace_info__` helper types --- array_api_compat/common/_typing.py | 107 +++++++++++++++++++++++++++-- 1 file changed, 102 insertions(+), 5 deletions(-) diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index 331e5a8f..d7deade1 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -1,7 +1,22 @@ from __future__ import annotations +from collections.abc import Mapping from types import ModuleType as Namespace -from typing import Any, Protocol, TypeAlias, TypedDict, TypeVar +from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, TypeVar + +if TYPE_CHECKING: + from _typeshed import Incomplete + + SupportsBufferProtocol: TypeAlias = Incomplete + Array: TypeAlias = Incomplete + Device: TypeAlias = Incomplete + DType: TypeAlias = Incomplete +else: + SupportsBufferProtocol = object + Array = object + Device = object + DType = object + _T_co = TypeVar("_T_co", covariant=True) @@ -20,6 +35,7 @@ class HasShape(Protocol[_T_co]): def shape(self, /) -> _T_co: ... +# Return type of `__array_namespace_info__.default_dtypes` Capabilities = TypedDict( "Capabilities", { @@ -29,17 +45,98 @@ def shape(self, /) -> _T_co: ... }, ) +# Return type of `__array_namespace_info__.default_dtypes` +DefaultDTypes = TypedDict( + "DefaultDTypes", + { + "real floating": DType, + "complex floating": DType, + "integral": DType, + "indexing": DType, + }, +) + + +_DTypeKind: TypeAlias = Literal[ + "bool", + "signed integer", + "unsigned integer", + "integral", + "real floating", + "complex floating", + "numeric", +] +# Type of the `kind` parameter in `__array_namespace_info__.dtypes` +DTypeKind: TypeAlias = _DTypeKind | tuple[_DTypeKind, ...] + + +# `__array_namespace_info__.dtypes(kind="bool")` +class DTypesBool(TypedDict): + bool: DType + + +# `__array_namespace_info__.dtypes(kind="signed integer")` +class DTypesSigned(TypedDict): + int8: DType + int16: DType + int32: DType + int64: DType + + +# `__array_namespace_info__.dtypes(kind="unsigned integer")` +class DTypesUnsigned(TypedDict): + uint8: DType + uint16: DType + uint32: DType + uint64: DType + + +# `__array_namespace_info__.dtypes(kind="integral")` +class DTypesIntegral(DTypesSigned, DTypesUnsigned): + pass + + +# `__array_namespace_info__.dtypes(kind="real floating")` +class DTypesReal(TypedDict): + float32: DType + float64: DType + + +# `__array_namespace_info__.dtypes(kind="complex floating")` +class DTypesComplex(TypedDict): + complex64: DType + complex128: DType + + +# `__array_namespace_info__.dtypes(kind="numeric")` +class DTypesNumeric(DTypesIntegral, DTypesReal, DTypesComplex): + pass + + +# `__array_namespace_info__.dtypes(kind=None)` (default) +class DTypesAll(DTypesBool, DTypesNumeric): + pass + -SupportsBufferProtocol: TypeAlias = Any -Array: TypeAlias = Any -Device: TypeAlias = Any -DType: TypeAlias = Any +# `__array_namespace_info__.dtypes(kind=?)` (fallback) +DTypesAny: TypeAlias = Mapping[str, DType] __all__ = [ "Array", "Capabilities", "DType", + "DTypeKind", + "DTypesAny", + "DTypesAll", + "DTypesBool", + "DTypesNumeric", + "DTypesIntegral", + "DTypesSigned", + "DTypesUnsigned", + "DTypesReal", + "DTypesComplex", + "DefaultDTypes", "Device", "HasShape", "Namespace", From 4278dfba75c3b2a3063abfc35245acc9212c6ff5 Mon Sep 17 00:00:00 2001 From: jorenham <jhammudoglu@gmail.com> Date: Tue, 15 Apr 2025 16:42:35 +0200 Subject: [PATCH 22/24] TYP: `dask.array` typing fixes and improvements --- array_api_compat/dask/array/__init__.py | 8 +- array_api_compat/dask/array/_aliases.py | 162 ++++++++++++++---------- array_api_compat/dask/array/_info.py | 96 +++++++++++--- array_api_compat/dask/array/linalg.py | 22 ++-- 4 files changed, 188 insertions(+), 100 deletions(-) diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index bb649306..1e47b960 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -1,9 +1,11 @@ -from dask.array import * # noqa: F403 +from typing import Final + +from dask.array import * # noqa: F403 # These imports may overwrite names from the import * above. -from ._aliases import * # noqa: F403 +from ._aliases import * # noqa: F403 -__array_api_version__ = '2024.12' +__array_api_version__: Final = "2024.12" # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index e7ddde78..9687a9cd 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -1,28 +1,38 @@ +# pyright: reportPrivateUsage=false +# pyright: reportUnknownArgumentType=false +# pyright: reportUnknownMemberType=false +# pyright: reportUnknownVariableType=false + from __future__ import annotations -from typing import Callable, Optional, Union +from builtins import bool as py_bool +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from typing_extensions import TypeIs +import dask.array as da import numpy as np +from numpy import bool_ as bool from numpy import ( - # dtypes - bool_ as bool, + can_cast, + complex64, + complex128, float32, float64, int8, int16, int32, int64, + result_type, uint8, uint16, uint32, uint64, - complex64, - complex128, - can_cast, - result_type, ) -import dask.array as da +from ..._internal import get_xp from ...common import _aliases, _helpers, array_namespace from ...common._typing import ( Array, @@ -31,7 +41,6 @@ NestedSequence, SupportsBufferProtocol, ) -from ..._internal import get_xp from ._info import __array_namespace_info__ isdtype = get_xp(np)(_aliases.isdtype) @@ -44,8 +53,8 @@ def astype( dtype: DType, /, *, - copy: bool = True, - device: Optional[Device] = None, + copy: py_bool = True, + device: Device | None = None, ) -> Array: """ Array API compatibility wrapper for astype(). @@ -69,14 +78,14 @@ def astype( # not pass stop/step as keyword arguments, which will cause # an error with dask def arange( - start: Union[int, float], + start: float, /, - stop: Optional[Union[int, float]] = None, - step: Union[int, float] = 1, + stop: float | None = None, + step: float = 1, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: """ Array API compatibility wrapper for arange(). @@ -87,7 +96,7 @@ def arange( # TODO: respect device keyword? _helpers._check_device(da, device) - args = [start] + args: list[Any] = [start] if stop is not None: args.append(stop) else: @@ -137,18 +146,13 @@ def arange( # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: ( - Array - | bool | int | float | complex - | NestedSequence[bool | int | float | complex] - | SupportsBufferProtocol - ), + obj: complex | NestedSequence[complex] | Array | SupportsBufferProtocol, /, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - copy: Optional[bool] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + copy: py_bool | None = None, + **kwargs: object, ) -> Array: """ Array API compatibility wrapper for asarray(). @@ -164,7 +168,7 @@ def asarray( if copy is False: raise ValueError("Unable to avoid copy when changing dtype") obj = obj.astype(dtype) - return obj.copy() if copy else obj + return obj.copy() if copy else obj # pyright: ignore[reportAttributeAccessIssue] if copy is False: raise NotImplementedError( @@ -177,22 +181,21 @@ def asarray( return da.from_array(obj) -from dask.array import ( - # Element wise aliases - arccos as acos, - arccosh as acosh, - arcsin as asin, - arcsinh as asinh, - arctan as atan, - arctan2 as atan2, - arctanh as atanh, - left_shift as bitwise_left_shift, - right_shift as bitwise_right_shift, - invert as bitwise_invert, - power as pow, - # Other - concatenate as concat, -) +# Element wise aliases +from dask.array import arccos as acos +from dask.array import arccosh as acosh +from dask.array import arcsin as asin +from dask.array import arcsinh as asinh +from dask.array import arctan as atan +from dask.array import arctan2 as atan2 +from dask.array import arctanh as atanh + +# Other +from dask.array import concatenate as concat +from dask.array import invert as bitwise_invert +from dask.array import left_shift as bitwise_left_shift +from dask.array import power as pow +from dask.array import right_shift as bitwise_right_shift # dask.array.clip does not work unless all three arguments are provided. @@ -202,8 +205,8 @@ def asarray( def clip( x: Array, /, - min: Optional[Union[int, float, Array]] = None, - max: Optional[Union[int, float, Array]] = None, + min: float | Array | None = None, + max: float | Array | None = None, ) -> Array: """ Array API compatibility wrapper for clip(). @@ -212,8 +215,8 @@ def clip( specification for more details. """ - def _isscalar(a): - return isinstance(a, (int, float, type(None))) + def _isscalar(a: float | Array | None, /) -> TypeIs[float | None]: + return a is None or isinstance(a, (int, float)) min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape @@ -266,7 +269,12 @@ def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array], def sort( - x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True + x: Array, + /, + *, + axis: int = -1, + descending: py_bool = False, + stable: py_bool = True, ) -> Array: """ Array API compatibility layer around the lack of sort() in Dask. @@ -296,7 +304,12 @@ def sort( def argsort( - x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True + x: Array, + /, + *, + axis: int = -1, + descending: py_bool = False, + stable: py_bool = True, ) -> Array: """ Array API compatibility layer around the lack of argsort() in Dask. @@ -330,25 +343,34 @@ def argsort( # dask.array.count_nonzero does not have keepdims def count_nonzero( x: Array, - axis=None, - keepdims=False + axis: int | None = None, + keepdims: py_bool = False, ) -> Array: - result = da.count_nonzero(x, axis) - if keepdims: - if axis is None: - return da.reshape(result, [1]*x.ndim) - return da.expand_dims(result, axis) - return result - - + result = da.count_nonzero(x, axis) + if keepdims: + if axis is None: + return da.reshape(result, [1] * x.ndim) + return da.expand_dims(result, axis) + return result + + +__all__ = [ + "__array_namespace_info__", + "count_nonzero", + "bool", + "int8", "int16", "int32", "int64", + "uint8", "uint16", "uint32", "uint64", + "float32", "float64", + "complex64", "complex128", + "asarray", "astype", "can_cast", "result_type", + "pow", + "concat", + "acos", "acosh", "asin", "asinh", "atan", "atan2", "atanh", + "bitwise_left_shift", "bitwise_right_shift", "bitwise_invert", +] # fmt: skip +__all__ += _aliases.__all__ +_all_ignore = ["array_namespace", "get_xp", "da", "np"] -__all__ = _aliases.__all__ + [ - '__array_namespace_info__', 'asarray', 'astype', 'acos', - 'acosh', 'asin', 'asinh', 'atan', 'atan2', - 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow', 'can_cast', - 'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', - 'uint8', 'uint16', 'uint32', 'uint64', 'complex64', 'complex128', - 'can_cast', 'count_nonzero', 'result_type'] -_all_ignore = ["array_namespace", "get_xp", "da", "np"] +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py index 614f43d9..9e4d736f 100644 --- a/array_api_compat/dask/array/_info.py +++ b/array_api_compat/dask/array/_info.py @@ -7,25 +7,51 @@ more details. """ + +# pyright: reportPrivateUsage=false + +from __future__ import annotations + +from typing import Literal as L +from typing import TypeAlias, overload + +from numpy import bool_ as bool from numpy import ( + complex64, + complex128, dtype, - bool_ as bool, - intp, + float32, + float64, int8, int16, int32, int64, + intp, uint8, uint16, uint32, uint64, - float32, - float64, - complex64, - complex128, ) -from ...common._helpers import _DASK_DEVICE +from ...common._helpers import _DASK_DEVICE, _dask_device +from ...common._typing import ( + Capabilities, + DefaultDTypes, + DType, + DTypeKind, + DTypesAll, + DTypesAny, + DTypesBool, + DTypesComplex, + DTypesIntegral, + DTypesNumeric, + DTypesReal, + DTypesSigned, + DTypesUnsigned, +) + +_Device: TypeAlias = L["cpu"] | _dask_device + class __array_namespace_info__: """ @@ -59,9 +85,9 @@ class __array_namespace_info__: """ - __module__ = 'dask.array' + __module__ = "dask.array" - def capabilities(self): + def capabilities(self) -> Capabilities: """ Return a dictionary of array API library capabilities. @@ -116,7 +142,7 @@ def capabilities(self): "max dimensions": 64, } - def default_device(self): + def default_device(self) -> L["cpu"]: """ The default device used for new Dask arrays. @@ -143,7 +169,7 @@ def default_device(self): """ return "cpu" - def default_dtypes(self, *, device=None): + def default_dtypes(self, /, *, device: _Device | None = None) -> DefaultDTypes: """ The default data types used for new Dask arrays. @@ -184,8 +210,8 @@ def default_dtypes(self, *, device=None): """ if device not in ["cpu", _DASK_DEVICE, None]: raise ValueError( - 'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:' - f' {device}' + f'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, ' + f"but received: {device!r}" ) return { "real floating": dtype(float64), @@ -194,7 +220,41 @@ def default_dtypes(self, *, device=None): "indexing": dtype(intp), } - def dtypes(self, *, device=None, kind=None): + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: None = None + ) -> DTypesAll: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["bool"] + ) -> DTypesBool: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["signed integer"] + ) -> DTypesSigned: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["unsigned integer"] + ) -> DTypesUnsigned: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["integral"] + ) -> DTypesIntegral: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["real floating"] + ) -> DTypesReal: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["complex floating"] + ) -> DTypesComplex: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["numeric"] + ) -> DTypesNumeric: ... + def dtypes( + self, /, *, device: _Device | None = None, kind: DTypeKind | None = None + ) -> DTypesAny: """ The array API data types supported by Dask. @@ -251,7 +311,7 @@ def dtypes(self, *, device=None, kind=None): if device not in ["cpu", _DASK_DEVICE, None]: raise ValueError( 'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:' - f' {device}' + f" {device}" ) if kind is None: return { @@ -321,14 +381,14 @@ def dtypes(self, *, device=None, kind=None): "complex64": dtype(complex64), "complex128": dtype(complex128), } - if isinstance(kind, tuple): - res = {} + if isinstance(kind, tuple): # type: ignore[reportUnnecessaryIsinstanceCall] + res: dict[str, DType] = {} for k in kind: res.update(self.dtypes(kind=k)) return res raise ValueError(f"unsupported kind: {kind!r}") - def devices(self): + def devices(self) -> list[_Device]: """ The devices supported by Dask. diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index bd53f0df..fd509a38 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -3,15 +3,16 @@ from typing import Literal import dask.array as da -# Exports -from dask.array.linalg import * # noqa: F403 -from dask.array import outer + # These functions are in both the main and linalg namespaces -from dask.array import matmul, tensordot +from dask.array import matmul, outer, tensordot + +# Exports +from dask.array.linalg import * # noqa: F403 from ..._internal import get_xp from ...common import _linalg -from ...common._typing import Array +from ...common._typing import Array as _Array from ._aliases import matrix_transpose, vecdot # dask.array.linalg doesn't have __all__. If it is added, replace this with @@ -32,8 +33,11 @@ # supports the mode keyword on QR # https://github.com/dask/dask/issues/10388 #qr = get_xp(da)(_linalg.qr) -def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced', - **kwargs) -> QRResult: +def qr( + x: _Array, + mode: Literal["reduced", "complete"] = "reduced", + **kwargs: object, +) -> QRResult: if mode != "reduced": raise ValueError("dask arrays only support using mode='reduced'") return QRResult(*da.linalg.qr(x, **kwargs)) @@ -46,12 +50,12 @@ def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced', # Wrap the svd functions to not pass full_matrices to dask # when full_matrices=False (as that is the default behavior for dask), # and dask doesn't have the full_matrices keyword -def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult: +def svd(x: _Array, full_matrices: bool = True, **kwargs) -> SVDResult: if full_matrices: raise ValueError("full_matrics=True is not supported by dask.") return da.linalg.svd(x, coerce_signs=False, **kwargs) -def svdvals(x: Array) -> Array: +def svdvals(x: _Array) -> _Array: # TODO: can't avoid computing U or V for dask _, s, _ = svd(x) return s From 4f6ef6d46c43284d69c59cff66e9fae88d7498d7 Mon Sep 17 00:00:00 2001 From: Joren Hammudoglu <jhammudoglu@gmail.com> Date: Thu, 17 Apr 2025 18:10:17 +0200 Subject: [PATCH 23/24] STY: give the `=` some breathing room Co-authored-by: Lucas Colley <lucas.colley8@gmail.com> --- array_api_compat/common/_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 0ce6edf0..8596c534 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -40,7 +40,7 @@ # TODO: import from typing (requires Python >=3.13) from typing_extensions import TypeIs, TypeVar - _SizeT = TypeVar("_SizeT", bound=int | None) + _SizeT = TypeVar("_SizeT", bound = int | None) _ZeroGradientArray: TypeAlias = npt.NDArray[np.void] _CupyArray: TypeAlias = Any # cupy has no py.typed From d758d6fdfd539816a0c4542804aa877e0b9a2be2 Mon Sep 17 00:00:00 2001 From: jorenham <jhammudoglu@gmail.com> Date: Thu, 17 Apr 2025 18:50:21 +0200 Subject: [PATCH 24/24] STY: apply review suggestions Co-authored-by: lucascolley <lucas.colley8@gmail.com> --- array_api_compat/common/_aliases.py | 2 +- array_api_compat/common/_helpers.py | 13 ++++++------- array_api_compat/dask/array/linalg.py | 2 +- array_api_compat/numpy/_aliases.py | 6 +++++- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 7f3cf914..8ea9162a 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -652,7 +652,7 @@ def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array: if isdtype(x.dtype, "complex floating", xp=xp): out = (x / xp.abs(x, **kwargs))[...] # sign(0) = 0 but the above formula would give nan - out[x == 0 + 0j] = 0 + 0j + out[x == 0j] = 0j else: out = xp.sign(x, **kwargs) # CuPy sign() does not propagate nans. See diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 8596c534..db3e4cd7 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -16,6 +16,7 @@ from typing import ( TYPE_CHECKING, Any, + Final, Literal, SupportsIndex, TypeAlias, @@ -56,6 +57,9 @@ | _CupyArray ) +_API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"}) +_API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"}) + def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: """Return True if `x` is a zero-gradient array. @@ -477,16 +481,11 @@ def is_array_api_strict_namespace(xp: Namespace) -> bool: def _check_api_version(api_version: str | None) -> None: - if api_version in ["2021.12", "2022.12", "2023.12"]: + if api_version in _API_VERSIONS_OLD: warnings.warn( f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12" ) - elif api_version is not None and api_version not in [ - "2021.12", - "2022.12", - "2023.12", - "2024.12", - ]: + elif api_version is not None and api_version not in _API_VERSIONS: raise ValueError( "Only the 2024.12 version of the array API specification is currently supported" ) diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index fd509a38..0825386e 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -4,7 +4,7 @@ import dask.array as da -# These functions are in both the main and linalg namespaces +# The `matmul` and `tensordot` functions are in both the main and linalg namespaces from dask.array import matmul, outer, tensordot # Exports diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 2c03041f..d8792611 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -15,6 +15,8 @@ if TYPE_CHECKING: from typing_extensions import Buffer, TypeIs +# The values of the `_CopyMode` enum can be either `False`, `True`, or `2`: +# https://github.com/numpy/numpy/blob/5a8a6a79d9c2fff8f07dcab5d41e14f8508d673f/numpy/_globals.pyi#L7-L10 _Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode bool = np.bool_ @@ -130,7 +132,9 @@ def count_nonzero( axis: int | tuple[int, ...] | None = None, keepdims: py_bool = False, ) -> Array: - result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims)) # pyright: ignore + # NOTE: this is currently incorrectly typed in numpy, but will be fixed in + # numpy 2.2.5 and 2.3.0: https://github.com/numpy/numpy/pull/28750 + result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims)) # pyright: ignore[reportArgumentType, reportCallIssue] if axis is None and not keepdims: return np.asarray(result) return result