diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml new file mode 100644 index 00000000..9a09ffd7 --- /dev/null +++ b/.github/workflows/ruff.yml @@ -0,0 +1,19 @@ +name: CI +on: [push, pull_request] +jobs: + check-ruff: + runs-on: ubuntu-latest + continue-on-error: true + steps: + - uses: actions/checkout@v3 + - name: Install Python + uses: actions/setup-python@v4 + with: + python-version: "3.11" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install ruff + # Update output format to enable automatic inline annotations. + - name: Run Ruff + run: ruff check --output-format=github --select F822,PLC0414,RUF022 --preview . diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index 28ffc7e7..29e7be04 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -19,4 +19,4 @@ """ __version__ = '1.4.1' -from .common import * +from .common import * # noqa: F401, F403 diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index 553c0356..9073dd52 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -5,6 +5,7 @@ from functools import wraps from inspect import signature + def get_xp(xp): """ Decorator to automatically replace xp with the corresponding array module. @@ -21,13 +22,16 @@ def func(x, /, xp, kwarg=None): arguments. """ + def inner(f): @wraps(f) def wrapped_f(*args, **kwargs): return f(*args, xp=xp, **kwargs) sig = signature(f) - new_sig = sig.replace(parameters=[sig.parameters[i] for i in sig.parameters if i != 'xp']) + new_sig = sig.replace( + parameters=[sig.parameters[i] for i in sig.parameters if i != "xp"] + ) if wrapped_f.__doc__ is None: wrapped_f.__doc__ = f"""\ @@ -41,3 +45,31 @@ def wrapped_f(*args, **kwargs): return wrapped_f return inner + + +def _get_all_public_members(module, exclude=None, extend_all=False): + """Get all public members of a module. + + Parameters + ---------- + module : module + The module to get members from. + exclude : callable, optional + A callable that takes a name and returns True if the name should be + excluded from the list of members. + extend_all : bool, optional + If True, extend the module's __all__ attribute with the members of the + module derived from dir(module). To be used for libraries that do not have a complete __all__ list. + """ + members = getattr(module, "__all__", []) + + if members and not extend_all: + return members + + if exclude is None: + exclude = lambda name: name.startswith("_") # noqa: E731 + + members = members + [_ for _ in dir(module) if not exclude(_)] + + # remove duplicates + return list(set(members)) diff --git a/array_api_compat/common/__init__.py b/array_api_compat/common/__init__.py index ce3f44dd..b941a31e 100644 --- a/array_api_compat/common/__init__.py +++ b/array_api_compat/common/__init__.py @@ -1 +1,17 @@ -from ._helpers import * +from ._helpers import ( + array_namespace, + device, + get_namespace, + is_array_api_obj, + size, + to_device, +) + +__all__ = [ + "array_namespace", + "device", + "get_namespace", + "is_array_api_obj", + "size", + "to_device", +] diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 7713213e..b58fb0ca 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -6,7 +6,8 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Optional, Sequence, Tuple, Union, List + import numpy as np + from typing import Optional, Sequence, Tuple, Union from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol from typing import NamedTuple @@ -544,11 +545,3 @@ def isdtype( # more strict here to match the type annotation? Note that the # numpy.array_api implementation will be very strict. return dtype == kind - -__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', - 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', - 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', - 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', - 'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort', - 'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc', - 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype'] diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 82bf47c1..ac866551 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -7,6 +7,12 @@ """ from __future__ import annotations +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Optional, Union, Any + from ._typing import Array, Device + import sys import math @@ -142,7 +148,7 @@ def _check_device(xp, device): # wrapping or subclassing them. These helper functions can be used instead of # the wrapper functions for libraries that need to support both NumPy/CuPy and # other libraries that use devices. -def device(x: "Array", /) -> "Device": +def device(x: Array, /) -> Device: """ Hardware device the array data resides on. @@ -204,7 +210,7 @@ def _torch_to_device(x, device, /, stream=None): raise NotImplementedError return x.to(device) -def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, Any]]" = None) -> "Array": +def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array: """ Copy the array from the device on which it currently resides to the specified ``device``. @@ -252,5 +258,3 @@ def size(x): if None in x.shape: return None return math.prod(x.shape) - -__all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device', 'to_device', 'size'] diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 9f0c993f..0708b76a 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, NamedTuple if TYPE_CHECKING: - from typing import Literal, Optional, Sequence, Tuple, Union + from typing import Literal, Optional, Tuple, Union from ._typing import ndarray import numpy as np @@ -11,7 +11,7 @@ else: from numpy.core.numeric import normalize_axis_tuple -from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype +from ._aliases import matrix_transpose, isdtype from .._internal import get_xp # These are in the main NumPy namespace but not in numpy.linalg @@ -149,10 +149,4 @@ def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarra dtype = xp.float64 elif x.dtype == xp.complex64: dtype = xp.complex128 - return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)) - -__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult', - 'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet', - 'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm', - 'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal', - 'trace'] + return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)) \ No newline at end of file diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index 3f178060..f20d085c 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -18,3 +18,6 @@ def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... def __len__(self, /) -> int: ... SupportsBufferProtocol = Any + +Array = Any +Device = Any \ No newline at end of file diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index ec113f9d..b5eb5eea 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -1,16 +1,153 @@ -from cupy import * +import cupy as _cp +from cupy import * # noqa: F401, F403 # from cupy import * doesn't overwrite these builtin names from cupy import abs, max, min, round +from .._internal import _get_all_public_members +from ..common._helpers import ( + array_namespace, + device, + get_namespace, + is_array_api_obj, + size, + to_device, +) + # These imports may overwrite names from the import * above. -from ._aliases import * +from ._aliases import ( + UniqueAllResult, + UniqueCountsResult, + UniqueInverseResult, + acos, + acosh, + arange, + argsort, + asarray, + asarray_cupy, + asin, + asinh, + astype, + atan, + atan2, + atanh, + bitwise_invert, + bitwise_left_shift, + bitwise_right_shift, + bool, + ceil, + concat, + empty, + empty_like, + eye, + floor, + full, + full_like, + isdtype, + linspace, + matmul, + matrix_transpose, + nonzero, + ones, + ones_like, + permute_dims, + pow, + prod, + reshape, + sort, + std, + sum, + tensordot, + trunc, + unique_all, + unique_counts, + unique_inverse, + unique_values, + var, + vecdot, + zeros, + zeros_like, +) -# See the comment in the numpy __init__.py -__import__(__package__ + '.linalg') +__all__ = [] + +__all__ += _get_all_public_members(_cp) + +__all__ += [ + "abs", + "max", + "min", + "round", +] -from .linalg import matrix_transpose, vecdot +__all__ += [ + "array_namespace", + "device", + "get_namespace", + "is_array_api_obj", + "size", + "to_device", +] -from ..common._helpers import * +__all__ += [ + "UniqueAllResult", + "UniqueCountsResult", + "UniqueInverseResult", + "acos", + "acosh", + "arange", + "argsort", + "asarray", + "asarray_cupy", + "asin", + "asinh", + "astype", + "atan", + "atan2", + "atanh", + "bitwise_invert", + "bitwise_left_shift", + "bitwise_right_shift", + "bool", + "ceil", + "concat", + "empty", + "empty_like", + "eye", + "floor", + "full", + "full_like", + "isdtype", + "linspace", + "matmul", + "matrix_transpose", + "nonzero", + "ones", + "ones_like", + "permute_dims", + "pow", + "prod", + "reshape", + "sort", + "std", + "sum", + "tensordot", + "trunc", + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + "var", + "zeros", + "zeros_like", +] + +__all__ += [ + "matrix_transpose", + "vecdot", +] + +# See the comment in the numpy __init__.py +__import__(__package__ + ".linalg") -__array_api_version__ = '2022.12' +__array_api_version__ = "2022.12" diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index d1d3cda9..71ffadce 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -2,15 +2,16 @@ from functools import partial +import cupy as cp + from ..common import _aliases +from ..common import _linalg from .._internal import get_xp asarray = asarray_cupy = partial(_aliases._asarray, namespace='cupy') asarray.__doc__ = _aliases._asarray.__doc__ -del partial -import cupy as cp bool = cp.bool_ # Basic renames @@ -73,7 +74,28 @@ else: isdtype = get_xp(cp)(_aliases.isdtype) -__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos', - 'acosh', 'asin', 'asinh', 'atan', 'atan2', - 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow'] + +cross = get_xp(cp)(_linalg.cross) +outer = get_xp(cp)(_linalg.outer) +EighResult = _linalg.EighResult +QRResult = _linalg.QRResult +SlogdetResult = _linalg.SlogdetResult +SVDResult = _linalg.SVDResult +eigh = get_xp(cp)(_linalg.eigh) +qr = get_xp(cp)(_linalg.qr) +slogdet = get_xp(cp)(_linalg.slogdet) +svd = get_xp(cp)(_linalg.svd) +cholesky = get_xp(cp)(_linalg.cholesky) +matrix_rank = get_xp(cp)(_linalg.matrix_rank) +pinv = get_xp(cp)(_linalg.pinv) +matrix_norm = get_xp(cp)(_linalg.matrix_norm) +svdvals = get_xp(cp)(_linalg.svdvals) +diagonal = get_xp(cp)(_linalg.diagonal) +trace = get_xp(cp)(_linalg.trace) + +# These functions are completely new here. If the library already has them +# (i.e., numpy 2.0), use the library version instead of our wrapper. +if hasattr(cp.linalg, 'vector_norm'): + vector_norm = cp.linalg.vector_norm +else: + vector_norm = get_xp(cp)(_linalg.vector_norm) \ No newline at end of file diff --git a/array_api_compat/cupy/_typing.py b/array_api_compat/cupy/_typing.py index f3d9aab6..b867ca94 100644 --- a/array_api_compat/cupy/_typing.py +++ b/array_api_compat/cupy/_typing.py @@ -1,9 +1,9 @@ from __future__ import annotations __all__ = [ - "ndarray", "Device", "Dtype", + "ndarray", ] import sys diff --git a/array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py index 84752e1a..cef74183 100644 --- a/array_api_compat/cupy/linalg.py +++ b/array_api_compat/cupy/linalg.py @@ -1,47 +1,62 @@ -from cupy.linalg import * -# cupy.linalg doesn't have __all__. If it is added, replace this with -# -# from cupy.linalg import __all__ as linalg_all -_n = {} -exec('from cupy.linalg import *', _n) -del _n['__builtins__'] -linalg_all = list(_n) -del _n +import cupy as _cp -from ..common import _linalg -from .._internal import get_xp -from ._aliases import (matmul, matrix_transpose, tensordot, vecdot) +from .._internal import _get_all_public_members -import cupy as cp +_cupy_linalg_all = _get_all_public_members(_cp.linalg) -cross = get_xp(cp)(_linalg.cross) -outer = get_xp(cp)(_linalg.outer) -EighResult = _linalg.EighResult -QRResult = _linalg.QRResult -SlogdetResult = _linalg.SlogdetResult -SVDResult = _linalg.SVDResult -eigh = get_xp(cp)(_linalg.eigh) -qr = get_xp(cp)(_linalg.qr) -slogdet = get_xp(cp)(_linalg.slogdet) -svd = get_xp(cp)(_linalg.svd) -cholesky = get_xp(cp)(_linalg.cholesky) -matrix_rank = get_xp(cp)(_linalg.matrix_rank) -pinv = get_xp(cp)(_linalg.pinv) -matrix_norm = get_xp(cp)(_linalg.matrix_norm) -svdvals = get_xp(cp)(_linalg.svdvals) -diagonal = get_xp(cp)(_linalg.diagonal) -trace = get_xp(cp)(_linalg.trace) +for _name in _cupy_linalg_all: + globals()[_name] = getattr(_cp.linalg, _name) -# These functions are completely new here. If the library already has them -# (i.e., numpy 2.0), use the library version instead of our wrapper. -if hasattr(cp.linalg, 'vector_norm'): - vector_norm = cp.linalg.vector_norm -else: - vector_norm = get_xp(cp)(_linalg.vector_norm) +from ._aliases import ( # noqa: E402 + EighResult, + QRResult, + SlogdetResult, + SVDResult, + cholesky, + cross, + diagonal, + eigh, + matmul, + matrix_norm, + matrix_rank, + matrix_transpose, + outer, + pinv, + qr, + slogdet, + svd, + svdvals, + tensordot, + trace, + vecdot, + vector_norm, +) -__all__ = linalg_all + _linalg.__all__ +__all__ = [] -del get_xp -del cp -del linalg_all -del _linalg +__all__ += _cupy_linalg_all + +__all__ += [ + "EighResult", + "QRResult", + "SVDResult", + "SlogdetResult", + "cholesky", + "cross", + "diagonal", + "eigh", + "matmul", + "matrix_norm", + "matrix_rank", + "matrix_transpose", + "outer", + "pinv", + "qr", + "slogdet", + "svd", + "svdvals", + "tensordot", + "trace", + "vecdot", + "vector_norm", +] diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index a7c0b22e..d6b5e94e 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -1,8 +1,210 @@ -from dask.array import * +import dask.array as _da +from dask.array import * # noqa: F401, F403 +from dask.array import ( + # Element wise aliases + arccos as acos, +) +from dask.array import ( + arccosh as acosh, +) +from dask.array import ( + arcsin as asin, +) +from dask.array import ( + arcsinh as asinh, +) +from dask.array import ( + arctan as atan, +) +from dask.array import ( + arctan2 as atan2, +) +from dask.array import ( + arctanh as atanh, +) +from dask.array import ( + bool_ as bool, +) +from dask.array import ( + # Other + concatenate as concat, +) +from dask.array import ( + invert as bitwise_invert, +) +from dask.array import ( + left_shift as bitwise_left_shift, +) +from dask.array import ( + power as pow, +) +from dask.array import ( + right_shift as bitwise_right_shift, +) # These imports may overwrite names from the import * above. -from ._aliases import * +from numpy import ( + can_cast, + complex64, + complex128, + e, + finfo, + float32, + float64, + iinfo, + inf, + int8, + int16, + int32, + int64, + nan, + newaxis, + pi, + result_type, + uint8, + uint16, + uint32, + uint64, +) -__array_api_version__ = '2022.12' +from ..common._helpers import ( + array_namespace, + device, + get_namespace, + is_array_api_obj, + size, + to_device, +) +from ..internal import _get_all_public_members +from ._aliases import ( + UniqueAllResult, + UniqueCountsResult, + UniqueInverseResult, + arange, + asarray, + astype, + ceil, + empty, + empty_like, + eye, + floor, + full, + full_like, + isdtype, + linspace, + matmul, + matrix_transpose, + nonzero, + ones, + ones_like, + permute_dims, + prod, + reshape, + std, + sum, + tensordot, + trunc, + unique_all, + unique_counts, + unique_inverse, + unique_values, + var, + vecdot, + zeros, + zeros_like, +) -__import__(__package__ + '.linalg') +__all__ = [] + +__all__ += _get_all_public_members(_da) + +__all__ += [ + "can_cast", + "complex64", + "complex128", + "e", + "finfo", + "float32", + "float64", + "iinfo", + "inf", + "int8", + "int16", + "int32", + "int64", + "nan", + "newaxis", + "pi", + "result_type", + "uint8", + "uint16", + "uint32", + "uint64", +] + +__all__ += [ + "array_namespace", + "device", + "get_namespace", + "is_array_api_obj", + "size", + "to_device", +] + +# 'sort', 'argsort' are unsupported by dask.array + +__all__ += [ + "UniqueAllResult", + "UniqueCountsResult", + "UniqueInverseResult", + "acos", + "acosh", + "arange", + "asarray", + "asin", + "asinh", + "astype", + "atan", + "atan2", + "atanh", + "bitwise_invert", + "bitwise_left_shift", + "bitwise_right_shift", + "bool", + "ceil", + "concat", + "empty", + "empty_like", + "eye", + "floor", + "full", + "full_like", + "isdtype", + "linspace", + "matmul", + "matrix_transpose", + "nonzero", + "ones", + "ones_like", + "permute_dims", + "pow", + "prod", + "reshape", + "std", + "sum", + "tensordot", + "trunc", + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + "var", + "vecdot", + "zeros", + "zeros_like", +] + + +__array_api_version__ = "2022.12" + +__import__(__package__ + ".linalg") diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index ef9ea356..14b27070 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -1,42 +1,18 @@ from __future__ import annotations -from ...common import _aliases -from ...common._helpers import _check_device - -from ..._internal import get_xp +from functools import partial +from typing import TYPE_CHECKING import numpy as np -from numpy import ( - # Constants - e, - inf, - nan, - pi, - newaxis, - # Dtypes - bool_ as bool, - float32, - float64, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - complex64, - complex128, - iinfo, - finfo, - can_cast, - result_type, -) -from typing import TYPE_CHECKING +from ..._internal import get_xp +from ...common import _aliases, _linalg +from ...common._helpers import _check_device + if TYPE_CHECKING: - from typing import Optional, Union - from ...common._typing import ndarray, Device, Dtype + from typing import Optional, Tuple, Union + + from ...common._typing import Device, Dtype, ndarray import dask.array as da @@ -49,6 +25,7 @@ # not pass stop/step as keyword arguments, which will cause # an error with dask + # TODO: delete the xp stuff, it shouldn't be necessary def dask_arange( start: Union[int, float], @@ -59,7 +36,7 @@ def dask_arange( xp, dtype: Optional[Dtype] = None, device: Optional[Device] = None, - **kwargs + **kwargs, ) -> ndarray: _check_device(xp, device) args = [start] @@ -72,11 +49,11 @@ def dask_arange( args.append(step) return xp.arange(*args, dtype=dtype, **kwargs) + arange = get_xp(da)(dask_arange) eye = get_xp(da)(_aliases.eye) -from functools import partial -asarray = partial(_aliases._asarray, namespace='dask.array') +asarray = partial(_aliases._asarray, namespace="dask.array") asarray.__doc__ = _aliases._asarray.__doc__ linspace = get_xp(da)(_aliases.linspace) @@ -112,34 +89,22 @@ def dask_arange( matmul = get_xp(np)(_aliases.matmul) tensordot = get_xp(np)(_aliases.tensordot) -from dask.array import ( - # Element wise aliases - arccos as acos, - arccosh as acosh, - arcsin as asin, - arcsinh as asinh, - arctan as atan, - arctan2 as atan2, - arctanh as atanh, - left_shift as bitwise_left_shift, - right_shift as bitwise_right_shift, - invert as bitwise_invert, - power as pow, - # Other - concatenate as concat, -) - -# exclude these from all since -_da_unsupported = ['sort', 'argsort'] - -common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported] - -__all__ = common_aliases + ['asarray', 'bool', 'acos', - 'acosh', 'asin', 'asinh', 'atan', 'atan2', - 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow', - 'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8', - 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64', - 'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type'] - -del da, partial, common_aliases, _da_unsupported, + +EighResult = _linalg.EighResult +QRResult = _linalg.QRResult +SlogdetResult = _linalg.SlogdetResult +SVDResult = _linalg.SVDResult +qr = get_xp(da)(_linalg.qr) +cholesky = get_xp(da)(_linalg.cholesky) +matrix_rank = get_xp(da)(_linalg.matrix_rank) +matrix_norm = get_xp(da)(_linalg.matrix_norm) + + +def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]: + # TODO: can't avoid computing U or V for dask + _, s, _ = da.linalg.svd(x) + return s + + +vector_norm = get_xp(da)(_linalg.vector_norm) +diagonal = get_xp(da)(_linalg.diagonal) diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index c8aa7c9f..cc9ac880 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -1,48 +1,50 @@ -from __future__ import annotations - -from dask.array.linalg import * -from ...common import _linalg -from ..._internal import get_xp -from dask.array import matmul, tensordot, trace, outer -from ._aliases import matrix_transpose, vecdot - -import dask.array as da - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Union, Tuple - from ...common._typing import ndarray - -# cupy.linalg doesn't have __all__. If it is added, replace this with -# -# from cupy.linalg import __all__ as linalg_all -_n = {} -exec('from dask.array.linalg import *', _n) -del _n['__builtins__'] -linalg_all = list(_n) -del _n - -EighResult = _linalg.EighResult -QRResult = _linalg.QRResult -SlogdetResult = _linalg.SlogdetResult -SVDResult = _linalg.SVDResult -qr = get_xp(da)(_linalg.qr) -cholesky = get_xp(da)(_linalg.cholesky) -matrix_rank = get_xp(da)(_linalg.matrix_rank) -matrix_norm = get_xp(da)(_linalg.matrix_norm) - -def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]: - # TODO: can't avoid computing U or V for dask - _, s, _ = svd(x) - return s - -vector_norm = get_xp(da)(_linalg.vector_norm) -diagonal = get_xp(da)(_linalg.diagonal) - -__all__ = linalg_all + ["EighResult", "QRResult", "SlogdetResult", - "SVDResult", "qr", "cholesky", "matrix_rank", "matrix_norm", - "svdvals", "vector_norm", "diagonal"] - -del get_xp -del da -del _linalg +import dask.array as _da +from dask.array import ( + matmul, + outer, + tensordot, + trace, +) +from dask.array.linalg import * # noqa: F401, F403 + +from .._internal import _get_all_public_members +from ._aliases import ( + EighResult, + QRResult, + SlogdetResult, + SVDResult, + cholesky, + diagonal, + matrix_norm, + matrix_rank, + matrix_transpose, + qr, + svdvals, + vecdot, + vector_norm, +) + +__all__ = [ + "matmul", + "outer", + "tensordot", + "trace", +] + +__all__ += _get_all_public_members(_da.linalg) + +__all__ += [ + "EighResult", + "QRResult", + "SVDResult", + "SlogdetResult", + "cholesky", + "diagonal", + "matrix_norm", + "matrix_rank", + "matrix_transpose", + "qr", + "svdvals", + "vecdot", + "vector_norm", +] diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 4a49f2f1..8ee9f711 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -1,10 +1,150 @@ -from numpy import * +from numpy import * # noqa: F401, F403 +from numpy import __all__ as _numpy_all # from numpy import * doesn't overwrite these builtin names from numpy import abs, max, min, round +from ..common._helpers import ( + array_namespace, + device, + get_namespace, + is_array_api_obj, + size, + to_device, +) + # These imports may overwrite names from the import * above. -from ._aliases import * +from ._aliases import ( + UniqueAllResult, + UniqueCountsResult, + UniqueInverseResult, + acos, + acosh, + arange, + argsort, + asarray, + asarray_numpy, + asin, + asinh, + astype, + atan, + atan2, + atanh, + bitwise_invert, + bitwise_left_shift, + bitwise_right_shift, + bool, + ceil, + concat, + empty, + empty_like, + eye, + floor, + full, + full_like, + isdtype, + linspace, + matmul, + matrix_transpose, + nonzero, + ones, + ones_like, + permute_dims, + pow, + prod, + reshape, + sort, + std, + sum, + tensordot, + trunc, + unique_all, + unique_counts, + unique_inverse, + unique_values, + var, + vecdot, + zeros, + zeros_like, +) + +__all__ = [] + +__all__ += _numpy_all + +__all__ += [ + "abs", + "max", + "min", + "round", +] + +__all__ += [ + "array_namespace", + "device", + "get_namespace", + "is_array_api_obj", + "size", + "to_device", +] + +__all__ += [ + "UniqueAllResult", + "UniqueCountsResult", + "UniqueInverseResult", + "acos", + "acosh", + "arange", + "argsort", + "asarray", + "asarray_numpy", + "asin", + "asinh", + "astype", + "atan", + "atan2", + "atanh", + "bitwise_invert", + "bitwise_left_shift", + "bitwise_right_shift", + "bool", + "ceil", + "concat", + "empty", + "empty_like", + "eye", + "floor", + "full", + "full_like", + "isdtype", + "linspace", + "matmul", + "matrix_transpose", + "nonzero", + "ones", + "ones_like", + "permute_dims", + "pow", + "prod", + "reshape", + "sort", + "std", + "sum", + "tensordot", + "trunc", + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + "var", + "zeros", + "zeros_like", +] + +__all__ += [ + "matrix_transpose", + "vecdot", +] # Don't know why, but we have to do an absolute import to import linalg. If we # instead do @@ -13,10 +153,6 @@ # # It doesn't overwrite np.linalg from above. The import is generated # dynamically so that the library can be vendored. -__import__(__package__ + '.linalg') - -from .linalg import matrix_transpose, vecdot - -from ..common._helpers import * +__import__(__package__ + ".linalg") -__array_api_version__ = '2022.12' +__array_api_version__ = "2022.12" diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index e7d4a1be..ee1c1557 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -2,15 +2,14 @@ from functools import partial -from ..common import _aliases +import numpy as np from .._internal import get_xp +from ..common import _aliases, _linalg -asarray = asarray_numpy = partial(_aliases._asarray, namespace='numpy') +asarray = asarray_numpy = partial(_aliases._asarray, namespace="numpy") asarray.__doc__ = _aliases._asarray.__doc__ -del partial -import numpy as np bool = np.bool_ # Basic renames @@ -64,16 +63,37 @@ # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. -if hasattr(np, 'vecdot'): +if hasattr(np, "vecdot"): vecdot = np.vecdot else: vecdot = get_xp(np)(_aliases.vecdot) -if hasattr(np, 'isdtype'): +if hasattr(np, "isdtype"): isdtype = np.isdtype else: isdtype = get_xp(np)(_aliases.isdtype) -__all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'acos', - 'acosh', 'asin', 'asinh', 'atan', 'atan2', - 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow'] + +cross = get_xp(np)(_linalg.cross) +outer = get_xp(np)(_linalg.outer) +EighResult = _linalg.EighResult +QRResult = _linalg.QRResult +SlogdetResult = _linalg.SlogdetResult +SVDResult = _linalg.SVDResult +eigh = get_xp(np)(_linalg.eigh) +qr = get_xp(np)(_linalg.qr) +slogdet = get_xp(np)(_linalg.slogdet) +svd = get_xp(np)(_linalg.svd) +cholesky = get_xp(np)(_linalg.cholesky) +matrix_rank = get_xp(np)(_linalg.matrix_rank) +pinv = get_xp(np)(_linalg.pinv) +matrix_norm = get_xp(np)(_linalg.matrix_norm) +svdvals = get_xp(np)(_linalg.svdvals) +diagonal = get_xp(np)(_linalg.diagonal) +trace = get_xp(np)(_linalg.trace) + +# These functions are completely new here. If the library already has them +# (i.e., numpy 2.0), use the library version instead of our wrapper. +if hasattr(np.linalg, "vector_norm"): + vector_norm = np.linalg.vector_norm +else: + vector_norm = get_xp(np)(_linalg.vector_norm) diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index c5ebb5ab..53585f70 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -1,9 +1,9 @@ from __future__ import annotations __all__ = [ - "ndarray", "Device", "Dtype", + "ndarray", ] import sys diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index 39997df8..fbe37cd7 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -1,40 +1,63 @@ -from numpy.linalg import * -from numpy.linalg import __all__ as linalg_all - -from ..common import _linalg -from .._internal import get_xp -from ._aliases import (matmul, matrix_transpose, tensordot, vecdot) - -import numpy as np - -cross = get_xp(np)(_linalg.cross) -outer = get_xp(np)(_linalg.outer) -EighResult = _linalg.EighResult -QRResult = _linalg.QRResult -SlogdetResult = _linalg.SlogdetResult -SVDResult = _linalg.SVDResult -eigh = get_xp(np)(_linalg.eigh) -qr = get_xp(np)(_linalg.qr) -slogdet = get_xp(np)(_linalg.slogdet) -svd = get_xp(np)(_linalg.svd) -cholesky = get_xp(np)(_linalg.cholesky) -matrix_rank = get_xp(np)(_linalg.matrix_rank) -pinv = get_xp(np)(_linalg.pinv) -matrix_norm = get_xp(np)(_linalg.matrix_norm) -svdvals = get_xp(np)(_linalg.svdvals) -diagonal = get_xp(np)(_linalg.diagonal) -trace = get_xp(np)(_linalg.trace) - -# These functions are completely new here. If the library already has them -# (i.e., numpy 2.0), use the library version instead of our wrapper. -if hasattr(np.linalg, 'vector_norm'): - vector_norm = np.linalg.vector_norm -else: - vector_norm = get_xp(np)(_linalg.vector_norm) - -__all__ = linalg_all + _linalg.__all__ - -del get_xp -del np -del linalg_all -del _linalg +import numpy as _np + +from .._internal import _get_all_public_members + +_numpy_linalg_all = _get_all_public_members(_np.linalg) + +for _name in _numpy_linalg_all: + globals()[_name] = getattr(_np.linalg, _name) + + +from ._aliases import ( # noqa: E402 + EighResult, + QRResult, + SlogdetResult, + SVDResult, + cholesky, + cross, + diagonal, + eigh, + matmul, + matrix_norm, + matrix_rank, + matrix_transpose, + outer, + pinv, + qr, + slogdet, + svd, + svdvals, + tensordot, + trace, + vecdot, + vector_norm, +) + +__all__ = [] + +__all__ += _numpy_linalg_all + +__all__ += [ + "EighResult", + "QRResult", + "SVDResult", + "SlogdetResult", + "cholesky", + "cross", + "diagonal", + "eigh", + "matmul", + "matrix_norm", + "matrix_rank", + "matrix_transpose", + "outer", + "pinv", + "qr", + "slogdet", + "svd", + "svdvals", + "tensordot", + "trace", + "vecdot", + "vector_norm", +] diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py index ae53ec52..6492839f 100644 --- a/array_api_compat/torch/__init__.py +++ b/array_api_compat/torch/__init__.py @@ -1,22 +1,189 @@ -from torch import * - # Several names are not included in the above import * -import torch -for n in dir(torch): - if (n.startswith('_') - or n.endswith('_') - or 'cuda' in n - or 'cpu' in n - or 'backward' in n): - continue - exec(n + ' = torch.' + n) +import torch as _torch +from torch import * # noqa: F401, F403 + +from .._internal import _get_all_public_members + + +def exlcude(name): + if ( + name.startswith("_") + or name.endswith("_") + or "cuda" in name + or "cpu" in name + or "backward" in name + ): + return True + return False + + +_torch_all = _get_all_public_members(_torch, exclude=exlcude, extend_all=True) + +for _name in _torch_all: + globals()[_name] = getattr(_torch, _name) + + +from ..common._helpers import ( # noqa: E402 + array_namespace, + device, + get_namespace, + is_array_api_obj, + size, + to_device, +) # These imports may overwrite names from the import * above. -from ._aliases import * +from ._aliases import ( # noqa: E402 + add, + all, + any, + arange, + astype, + atan2, + bitwise_and, + bitwise_invert, + bitwise_left_shift, + bitwise_or, + bitwise_right_shift, + bitwise_xor, + broadcast_arrays, + broadcast_to, + can_cast, + concat, + divide, + empty, + equal, + expand_dims, + eye, + flip, + floor_divide, + full, + greater, + greater_equal, + isdtype, + less, + less_equal, + linspace, + logaddexp, + matmul, + matrix_transpose, + max, + mean, + min, + multiply, + newaxis, + nonzero, + not_equal, + ones, + permute_dims, + pow, + prod, + remainder, + reshape, + result_type, + roll, + sort, + squeeze, + std, + subtract, + sum, + take, + tensordot, + tril, + triu, + unique_all, + unique_counts, + unique_inverse, + unique_values, + var, + vecdot, + where, + zeros, +) -# See the comment in the numpy __init__.py -__import__(__package__ + '.linalg') +__all__ = [] + +__all__ += _torch_all + +__all__ += [ + "array_namespace", + "device", + "get_namespace", + "is_array_api_obj", + "size", + "to_device", +] -from ..common._helpers import * +__all__ += [ + "add", + "all", + "any", + "arange", + "astype", + "atan2", + "bitwise_and", + "bitwise_invert", + "bitwise_left_shift", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "broadcast_arrays", + "broadcast_to", + "can_cast", + "concat", + "divide", + "empty", + "equal", + "expand_dims", + "eye", + "flip", + "floor_divide", + "full", + "greater", + "greater_equal", + "isdtype", + "less", + "less_equal", + "linspace", + "logaddexp", + "matmul", + "matrix_transpose", + "max", + "mean", + "min", + "multiply", + "newaxis", + "nonzero", + "not_equal", + "ones", + "permute_dims", + "pow", + "prod", + "remainder", + "reshape", + "result_type", + "roll", + "sort", + "squeeze", + "std", + "subtract", + "sum", + "take", + "tensordot", + "tril", + "triu", + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + "var", + "vecdot", + "where", + "zeros", +] + + +# See the comment in the numpy __init__.py +__import__(__package__ + ".linalg") -__array_api_version__ = '2022.12' +__array_api_version__ = "2022.12" diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 929d31aa..23cd5219 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -1,22 +1,24 @@ from __future__ import annotations +from builtins import all as builtin_all +from builtins import any as builtin_any from functools import wraps -from builtins import all as builtin_all, any as builtin_any - -from ..common._aliases import (UniqueAllResult, UniqueCountsResult, - UniqueInverseResult, - matrix_transpose as _aliases_matrix_transpose, - vecdot as _aliases_vecdot) -from .._internal import get_xp +from typing import TYPE_CHECKING import torch -from typing import TYPE_CHECKING +from .._internal import get_xp +from ..common._aliases import UniqueAllResult, UniqueCountsResult, UniqueInverseResult +from ..common._aliases import matrix_transpose as _aliases_matrix_transpose +from ..common._aliases import vecdot as _aliases_vecdot + if TYPE_CHECKING: from typing import List, Optional, Sequence, Tuple, Union - from ..common._typing import Device + from torch import dtype as Dtype + from ..common._typing import Device + array = torch.Tensor _int_dtypes = { @@ -693,15 +695,42 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - axis = 0 return torch.index_select(x, axis, indices, **kwargs) -__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis', - 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', - 'bitwise_right_shift', 'bitwise_xor', 'divide', 'equal', - 'floor_divide', 'greater', 'greater_equal', 'less', 'less_equal', - 'logaddexp', 'multiply', 'not_equal', 'pow', 'remainder', - 'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all', - 'mean', 'std', 'var', 'concat', 'squeeze', 'broadcast_to', 'flip', 'roll', - 'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full', - 'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype', - 'broadcast_arrays', 'unique_all', 'unique_counts', - 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', - 'vecdot', 'tensordot', 'isdtype', 'take'] + + +# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the +# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743 +def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: + x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + return torch.linalg.cross(x1, x2, dim=axis) + +def vecdot_linalg(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array: + from ._aliases import isdtype + + x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + + # torch.linalg.vecdot doesn't support integer dtypes + if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'): + if kwargs: + raise RuntimeError("vecdot kwargs not supported for integral dtypes") + ndim = max(x1.ndim, x2.ndim) + x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) + x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) + if x1_shape[axis] != x2_shape[axis]: + raise ValueError("x1 and x2 must have the same size along the given axis") + + x1_, x2_ = torch.broadcast_tensors(x1, x2) + x1_ = torch.moveaxis(x1_, axis, -1) + x2_ = torch.moveaxis(x2_, axis, -1) + + res = x1_[..., None, :] @ x2_[..., None] + return res[..., 0, 0] + return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs) + +def solve(x1: array, x2: array, /, **kwargs) -> array: + x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + return torch.linalg.solve(x1, x2, **kwargs) + +# torch.trace doesn't support the offset argument and doesn't support stacking +def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array: + # Use our wrapped sum to make sure it does upcasting correctly + return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype) \ No newline at end of file diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 52667391..160f074b 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -1,62 +1,34 @@ -from __future__ import annotations +import torch as _torch -from typing import TYPE_CHECKING -if TYPE_CHECKING: - import torch - array = torch.Tensor - from torch import dtype as Dtype - from typing import Optional +from .._internal import _get_all_public_members -from torch.linalg import * +_torch_linalg_all = _get_all_public_members(_torch.linalg) -# torch.linalg doesn't define __all__ -# from torch.linalg import __all__ as linalg_all -from torch import linalg as torch_linalg -linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')] +for _name in _torch_linalg_all: + globals()[_name] = getattr(_torch.linalg, _name) # outer is implemented in torch but aren't in the linalg namespace -from torch import outer -from ._aliases import _fix_promotion, matrix_transpose, tensordot, sum - -# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the -# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743 -def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: - x1, x2 = _fix_promotion(x1, x2, only_scalar=False) - return torch_linalg.cross(x1, x2, dim=axis) - -def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array: - from ._aliases import isdtype - - x1, x2 = _fix_promotion(x1, x2, only_scalar=False) - - # torch.linalg.vecdot doesn't support integer dtypes - if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'): - if kwargs: - raise RuntimeError("vecdot kwargs not supported for integral dtypes") - ndim = max(x1.ndim, x2.ndim) - x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) - x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) - if x1_shape[axis] != x2_shape[axis]: - raise ValueError("x1 and x2 must have the same size along the given axis") - - x1_, x2_ = torch.broadcast_tensors(x1, x2) - x1_ = torch.moveaxis(x1_, axis, -1) - x2_ = torch.moveaxis(x2_, axis, -1) - - res = x1_[..., None, :] @ x2_[..., None] - return res[..., 0, 0] - return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs) - -def solve(x1: array, x2: array, /, **kwargs) -> array: - x1, x2 = _fix_promotion(x1, x2, only_scalar=False) - return torch.linalg.solve(x1, x2, **kwargs) - -# torch.trace doesn't support the offset argument and doesn't support stacking -def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array: - # Use our wrapped sum to make sure it does upcasting correctly - return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype) - -__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot', - 'vecdot', 'solve'] - -del linalg_all +outer = _torch.outer + +from ._aliases import ( # noqa: E402 + matrix_transpose, + solve, + sum, + tensordot, + trace, + vecdot_linalg as vecdot, +) + +__all__ = [] + +__all__ += _torch_linalg_all + +__all__ += [ + "matrix_transpose", + "outer", + "solve", + "sum", + "tensordot", + "trace", + "vecdot", +] diff --git a/tests/_helpers.py b/tests/_helpers.py index 4066d07a..69952118 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -2,7 +2,8 @@ import pytest -def import_(library): - if 'cupy' in library: + +def import_or_skip_cupy(library): + if "cupy" in library: return pytest.importorskip(library) return import_module(library) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 0becfc3d..2c596d70 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -1,20 +1,23 @@ +import numpy as np +import pytest +import torch + import array_api_compat from array_api_compat import array_namespace -from ._helpers import import_ +from ._helpers import import_or_skip_cupy -import pytest @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) -@pytest.mark.parametrize("api_version", [None, '2021.12']) +@pytest.mark.parametrize("api_version", [None, "2021.12"]) def test_array_namespace(library, api_version): - lib = import_(library) + xp = import_or_skip_cupy(library) - array = lib.asarray([1.0, 2.0, 3.0]) + array = xp.asarray([1.0, 2.0, 3.0]) namespace = array_api_compat.array_namespace(array, api_version=api_version) - if 'array_api' in library: - assert namespace == lib + if "array_api" in library: + assert namespace == xp else: if library == "dask.array": assert namespace == array_api_compat.dask.array @@ -26,18 +29,16 @@ def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace([1])) pytest.raises(TypeError, lambda: array_namespace()) - import numpy as np x = np.asarray([1, 2]) - pytest.raises(TypeError, lambda: array_namespace((x, x))) pytest.raises(TypeError, lambda: array_namespace(x, (x, x))) - import torch - y = torch.asarray([1, 2]) +def test_array_namespace_errors_torch(): + y = torch.asarray([1, 2]) + x = np.asarray([1, 2]) pytest.raises(TypeError, lambda: array_namespace(x, y)) - - pytest.raises(ValueError, lambda: array_namespace(x, api_version='2022.12')) + pytest.raises(ValueError, lambda: array_namespace(x, api_version="2022.12")) def test_get_namespace(): diff --git a/tests/test_common.py b/tests/test_common.py index f98a717a..bfaf58d2 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,17 +1,20 @@ -from ._helpers import import_ -from array_api_compat import to_device, device - -import pytest import numpy as np +import pytest from numpy.testing import assert_allclose +from array_api_compat import to_device + +from ._helpers import import_or_skip_cupy + + @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) def test_to_device_host(library): # different libraries have different semantics # for DtoH transfers; ensure that we support a portable # shim for common array libs # see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919 - xp = import_('array_api_compat.' + library) + xp = import_or_skip_cupy("array_api_compat." + library) + expected = np.array([1, 2, 3]) x = xp.asarray([1, 2, 3]) x = to_device(x, "cpu") diff --git a/tests/test_isdtype.py b/tests/test_isdtype.py index 77e7ce72..c27334da 100644 --- a/tests/test_isdtype.py +++ b/tests/test_isdtype.py @@ -3,10 +3,10 @@ non-spec dtypes """ -from ._helpers import import_ - import pytest +from ._helpers import import_or_skip_cupy + # Check the known dtypes by their string names def _spec_dtypes(library): @@ -61,12 +61,12 @@ def isdtype_(dtype_, kind): res = dtype_categories[kind](dtype_) else: res = dtype_ == kind - assert type(res) is bool + assert type(res) is bool # noqa: E721 return res @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) def test_isdtype_spec_dtypes(library): - xp = import_('array_api_compat.' + library) + xp = import_or_skip_cupy('array_api_compat.' + library) isdtype = xp.isdtype @@ -101,7 +101,7 @@ def test_isdtype_spec_dtypes(library): @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) @pytest.mark.parametrize("dtype_", additional_dtypes) def test_isdtype_additional_dtypes(library, dtype_): - xp = import_('array_api_compat.' + library) + xp = import_or_skip_cupy('array_api_compat.' + library) isdtype = xp.isdtype diff --git a/tests/test_vendoring.py b/tests/test_vendoring.py index 873b233a..66fc6984 100644 --- a/tests/test_vendoring.py +++ b/tests/test_vendoring.py @@ -1,21 +1,23 @@ -from pytest import skip +import pytest + def test_vendoring_numpy(): from vendor_test import uses_numpy + uses_numpy._test_numpy() def test_vendoring_cupy(): - try: - import cupy - except ImportError: - skip("CuPy is not installed") + pytest.importorskip("cupy") from vendor_test import uses_cupy + uses_cupy._test_cupy() + def test_vendoring_torch(): from vendor_test import uses_torch + uses_torch._test_torch() def test_vendoring_dask():