From 42e2a8b80a5e4fba8fdd200b043795ae6baeb068 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Thu, 14 Dec 2023 14:40:24 -0500 Subject: [PATCH 01/37] Add ruff to ci setup --- .github/workflows/ruff.yml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 .github/workflows/ruff.yml diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml new file mode 100644 index 00000000..c9f44886 --- /dev/null +++ b/.github/workflows/ruff.yml @@ -0,0 +1,19 @@ +name: CI +on: [push, pull_request] +jobs: + check-ruff: + runs-on: ubuntu-latest + continue-on-error: true + steps: + - uses: actions/checkout@v3 + - name: Install Python + uses: actions/setup-python@v4 + with: + python-version: "3.11" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install ruff + # Update output format to enable automatic inline annotations. + - name: Run Ruff + run: ruff check --output-format=github . From 4d0ccb97489411fe30b657b58d1ab8db445cc66c Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 14:52:10 -0500 Subject: [PATCH 02/37] Fix ruff errors in common/ --- array_api_compat/common/__init__.py | 18 +++++++++++++++++- array_api_compat/common/_aliases.py | 3 ++- array_api_compat/common/_helpers.py | 9 +++++---- array_api_compat/common/_linalg.py | 2 +- 4 files changed, 25 insertions(+), 7 deletions(-) diff --git a/array_api_compat/common/__init__.py b/array_api_compat/common/__init__.py index ce3f44dd..2bd64c1f 100644 --- a/array_api_compat/common/__init__.py +++ b/array_api_compat/common/__init__.py @@ -1 +1,17 @@ -from ._helpers import * +from ._helpers import ( + array_namespace, + device, + get_namespace, + is_array_api_obj, + size, + to_device, +) + +__all__ = [ + "is_array_api_obj", + "array_namespace", + "get_namespace", + "device", + "to_device", + "size", +] diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 7713213e..fc8189d3 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -6,7 +6,8 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Optional, Sequence, Tuple, Union, List + import numpy as np + from typing import Optional, Sequence, Tuple, Union from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol from typing import NamedTuple diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 82bf47c1..89a74e1b 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -7,6 +7,9 @@ """ from __future__ import annotations +from typing import Optional, Union, Any +from ._typing import Array, Device + import sys import math @@ -142,7 +145,7 @@ def _check_device(xp, device): # wrapping or subclassing them. These helper functions can be used instead of # the wrapper functions for libraries that need to support both NumPy/CuPy and # other libraries that use devices. -def device(x: "Array", /) -> "Device": +def device(x: Array, /) -> Device: """ Hardware device the array data resides on. @@ -204,7 +207,7 @@ def _torch_to_device(x, device, /, stream=None): raise NotImplementedError return x.to(device) -def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, Any]]" = None) -> "Array": +def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array: """ Copy the array from the device on which it currently resides to the specified ``device``. @@ -252,5 +255,3 @@ def size(x): if None in x.shape: return None return math.prod(x.shape) - -__all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device', 'to_device', 'size'] diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 9f0c993f..3b17417d 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, NamedTuple if TYPE_CHECKING: - from typing import Literal, Optional, Sequence, Tuple, Union + from typing import Literal, Optional, Tuple, Union from ._typing import ndarray import numpy as np From fbe6bd43921b74e9744189768b7bff20ac2b6d18 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 15:06:25 -0500 Subject: [PATCH 03/37] Fix ruff errors in cupy/ --- array_api_compat/cupy/__init__.py | 77 ++++++++++++++++++++++++++++--- array_api_compat/cupy/_aliases.py | 8 +--- 2 files changed, 72 insertions(+), 13 deletions(-) diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index ec113f9d..44ea227d 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -1,16 +1,79 @@ -from cupy import * +from cupy import * # noqa: F401, F403 # from cupy import * doesn't overwrite these builtin names from cupy import abs, max, min, round +from ..common._helpers import ( + array_namespace, + device, + get_namespace, + is_array_api_obj, + size, + to_device, +) + # These imports may overwrite names from the import * above. -from ._aliases import * +from ._aliases import ( + acos, + acosh, + asarray, + asarray_cupy, + asin, + asinh, + atan, + atan2, + atanh, + bitwise_invert, + bitwise_left_shift, + bitwise_right_shift, + bool, + concat, + pow, +) +from .linalg import matrix_transpose, vecdot -# See the comment in the numpy __init__.py -__import__(__package__ + '.linalg') +__all__ = [] -from .linalg import matrix_transpose, vecdot +__all__ += [ + "abs", + "max", + "min", + "round", +] -from ..common._helpers import * +__all__ += [ + "is_array_api_obj", + "array_namespace", + "get_namespace", + "device", + "to_device", + "size", +] + +__all__ += [ + "acos", + "acosh", + "asarray", + "asarray_cupy", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_invert", + "bitwise_left_shift", + "bitwise_right_shift", + "bool", + "concat", + "pow", +] + +__all__ += [ + "matrix_transpose", + "vecdot", +] + +# See the comment in the numpy __init__.py +__import__(__package__ + ".linalg") -__array_api_version__ = '2022.12' +__array_api_version__ = "2022.12" diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index d1d3cda9..eaab450b 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -2,15 +2,15 @@ from functools import partial +import cupy as cp + from ..common import _aliases from .._internal import get_xp asarray = asarray_cupy = partial(_aliases._asarray, namespace='cupy') asarray.__doc__ = _aliases._asarray.__doc__ -del partial -import cupy as cp bool = cp.bool_ # Basic renames @@ -73,7 +73,3 @@ else: isdtype = get_xp(cp)(_aliases.isdtype) -__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos', - 'acosh', 'asin', 'asinh', 'atan', 'atan2', - 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow'] From c7c27be1fc8a5b16182b3e0205ceaf0d131ebb39 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 15:10:30 -0500 Subject: [PATCH 04/37] Fix ruff errors in numpy/ --- array_api_compat/numpy/__init__.py | 79 +++++++++++++++++++++++++++--- array_api_compat/numpy/_aliases.py | 9 +--- 2 files changed, 73 insertions(+), 15 deletions(-) diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 4a49f2f1..85a4784b 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -1,10 +1,77 @@ -from numpy import * +from numpy import * # noqa: F401, F403 # from numpy import * doesn't overwrite these builtin names from numpy import abs, max, min, round +from ..common._helpers import ( + array_namespace, + device, + get_namespace, + is_array_api_obj, + size, + to_device, +) + # These imports may overwrite names from the import * above. -from ._aliases import * +from ._aliases import ( + acos, + acosh, + asarray, + asarray_numpy, + asin, + asinh, + atan, + atan2, + atanh, + bitwise_invert, + bitwise_left_shift, + bitwise_right_shift, + bool, + concat, + pow, +) +from .linalg import matrix_transpose, vecdot + +__all__ = [] + +__all__ += [ + "abs", + "max", + "min", + "round", +] + +__all__ += [ + "is_array_api_obj", + "array_namespace", + "get_namespace", + "device", + "to_device", + "size", +] + +__all__ += [ + "acos", + "acosh", + "asarray", + "asarray_numpy", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_invert", + "bitwise_left_shift", + "bitwise_right_shift", + "bool", + "concat", + "pow", +] + +__all__ += [ + "matrix_transpose", + "vecdot", +] # Don't know why, but we have to do an absolute import to import linalg. If we # instead do @@ -13,10 +80,6 @@ # # It doesn't overwrite np.linalg from above. The import is generated # dynamically so that the library can be vendored. -__import__(__package__ + '.linalg') - -from .linalg import matrix_transpose, vecdot - -from ..common._helpers import * +__import__(__package__ + ".linalg") -__array_api_version__ = '2022.12' +__array_api_version__ = "2022.12" diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index e7d4a1be..d44fcbd7 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -2,15 +2,15 @@ from functools import partial +import numpy as np + from ..common import _aliases from .._internal import get_xp asarray = asarray_numpy = partial(_aliases._asarray, namespace='numpy') asarray.__doc__ = _aliases._asarray.__doc__ -del partial -import numpy as np bool = np.bool_ # Basic renames @@ -72,8 +72,3 @@ isdtype = np.isdtype else: isdtype = get_xp(np)(_aliases.isdtype) - -__all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'acos', - 'acosh', 'asin', 'asinh', 'atan', 'atan2', - 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow'] From d3d57b9e6d75035a92df651c2aad18665eeb6ed6 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 15:18:44 -0500 Subject: [PATCH 05/37] Fix ruff errors in torch/ --- array_api_compat/torch/__init__.py | 180 ++++++++++++++++++++++++++--- 1 file changed, 167 insertions(+), 13 deletions(-) diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py index ae53ec52..64aa52cb 100644 --- a/array_api_compat/torch/__init__.py +++ b/array_api_compat/torch/__init__.py @@ -1,22 +1,176 @@ -from torch import * - # Several names are not included in the above import * import torch +from torch import * # noqa: F401, F403 + for n in dir(torch): - if (n.startswith('_') - or n.endswith('_') - or 'cuda' in n - or 'cpu' in n - or 'backward' in n): + if ( + n.startswith("_") + or n.endswith("_") + or "cuda" in n + or "cpu" in n + or "backward" in n + ): continue - exec(n + ' = torch.' + n) + exec(n + " = torch." + n) + + +from ..common._helpers import ( + array_namespace, + device, + get_namespace, + is_array_api_obj, + size, + to_device, +) # These imports may overwrite names from the import * above. -from ._aliases import * +from ._aliases import ( + add, + all, + any, + arange, + astype, + atan2, + bitwise_and, + bitwise_invert, + bitwise_left_shift, + bitwise_or, + bitwise_right_shift, + bitwise_xor, + broadcast_arrays, + broadcast_to, + can_cast, + concat, + divide, + empty, + equal, + expand_dims, + eye, + flip, + floor_divide, + full, + greater, + greater_equal, + isdtype, + less, + less_equal, + linspace, + logaddexp, + matmul, + matrix_transpose, + max, + mean, + min, + multiply, + newaxis, + nonzero, + not_equal, + ones, + permute_dims, + pow, + prod, + remainder, + reshape, + result_type, + roll, + sort, + squeeze, + std, + subtract, + sum, + take, + tensordot, + tril, + triu, + unique_all, + unique_counts, + unique_inverse, + unique_values, + var, + vecdot, + where, + zeros, +) -# See the comment in the numpy __init__.py -__import__(__package__ + '.linalg') +__all__ = [ + "is_array_api_obj", + "array_namespace", + "get_namespace", + "device", + "to_device", + "size", +] + +__all__ += [ + "result_type", + "can_cast", + "permute_dims", + "bitwise_invert", + "newaxis", + "add", + "atan2", + "bitwise_and", + "bitwise_left_shift", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "divide", + "equal", + "floor_divide", + "greater", + "greater_equal", + "less", + "less_equal", + "logaddexp", + "multiply", + "not_equal", + "pow", + "remainder", + "subtract", + "max", + "min", + "sort", + "prod", + "sum", + "any", + "all", + "mean", + "std", + "var", + "concat", + "squeeze", + "broadcast_to", + "flip", + "roll", + "nonzero", + "where", + "reshape", + "arange", + "eye", + "linspace", + "full", + "ones", + "zeros", + "empty", + "tril", + "triu", + "expand_dims", + "astype", + "broadcast_arrays", + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + "matmul", + "matrix_transpose", + "vecdot", + "tensordot", + "isdtype", + "take", +] -from ..common._helpers import * -__array_api_version__ = '2022.12' +# See the comment in the numpy __init__.py +__import__(__package__ + ".linalg") + +__array_api_version__ = "2022.12" From 0d437cdd07adb1acafc937cbd08eb002ca1e84d7 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 15:33:24 -0500 Subject: [PATCH 06/37] Fix ruff errors in tests/ --- tests/test_common.py | 2 +- tests/test_isdtype.py | 2 +- tests/test_vendoring.py | 8 +++----- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/test_common.py b/tests/test_common.py index f98a717a..5537ade1 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,5 +1,5 @@ from ._helpers import import_ -from array_api_compat import to_device, device +from array_api_compat import to_device import pytest import numpy as np diff --git a/tests/test_isdtype.py b/tests/test_isdtype.py index 77e7ce72..1cda6089 100644 --- a/tests/test_isdtype.py +++ b/tests/test_isdtype.py @@ -61,7 +61,7 @@ def isdtype_(dtype_, kind): res = dtype_categories[kind](dtype_) else: res = dtype_ == kind - assert type(res) is bool + assert isinstance(res, bool) return res @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) diff --git a/tests/test_vendoring.py b/tests/test_vendoring.py index 873b233a..93f42057 100644 --- a/tests/test_vendoring.py +++ b/tests/test_vendoring.py @@ -1,4 +1,5 @@ -from pytest import skip +import pytest + def test_vendoring_numpy(): from vendor_test import uses_numpy @@ -6,10 +7,7 @@ def test_vendoring_numpy(): def test_vendoring_cupy(): - try: - import cupy - except ImportError: - skip("CuPy is not installed") + pytest.importorskip("cupy") from vendor_test import uses_cupy uses_cupy._test_cupy() From afccf2987e8bd65c9803f32bc2d104ce7ed58a8a Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 15:35:37 -0500 Subject: [PATCH 07/37] Fix ruff errors in array_api_compat/__init__.py --- array_api_compat/__init__.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index 28ffc7e7..20e7eda2 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -17,6 +17,20 @@ this implementation for the default when working with NumPy arrays. """ -__version__ = '1.4.1' +__version__ = "1.4.1" -from .common import * +from .common import ( + array_namespace, + get_namespace, + is_array_api_obj, + size, + to_device, +) + +__all__ = [ + "array_namespace", + "get_namespace", + "is_array_api_obj", + "size", + "to_device", +] From 2395ea04df04319179e4e55752b567ebd0aa87d1 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 16:01:52 -0500 Subject: [PATCH 08/37] Implement _get_all_public_members --- array_api_compat/_internal.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index 553c0356..af178eaf 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -5,6 +5,7 @@ from functools import wraps from inspect import signature + def get_xp(xp): """ Decorator to automatically replace xp with the corresponding array module. @@ -41,3 +42,16 @@ def wrapped_f(*args, **kwargs): return wrapped_f return inner + + +def get_all_public_members(module, filter_=None): + """Get all public members of a module.""" + try: + return getattr(module, '__all__') + except AttributeError: + pass + + if filter_ is None: + filter_ = lambda name: name.startswith('_') # noqa: E731 + + return map(dir(module), filter_) \ No newline at end of file From 645cef2dbcb82380bd381df52dd15836b64a3265 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 16:02:24 -0500 Subject: [PATCH 09/37] Move linalg aliases to _aliases --- array_api_compat/cupy/_aliases.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index eaab450b..71ffadce 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -5,6 +5,7 @@ import cupy as cp from ..common import _aliases +from ..common import _linalg from .._internal import get_xp @@ -73,3 +74,28 @@ else: isdtype = get_xp(cp)(_aliases.isdtype) + +cross = get_xp(cp)(_linalg.cross) +outer = get_xp(cp)(_linalg.outer) +EighResult = _linalg.EighResult +QRResult = _linalg.QRResult +SlogdetResult = _linalg.SlogdetResult +SVDResult = _linalg.SVDResult +eigh = get_xp(cp)(_linalg.eigh) +qr = get_xp(cp)(_linalg.qr) +slogdet = get_xp(cp)(_linalg.slogdet) +svd = get_xp(cp)(_linalg.svd) +cholesky = get_xp(cp)(_linalg.cholesky) +matrix_rank = get_xp(cp)(_linalg.matrix_rank) +pinv = get_xp(cp)(_linalg.pinv) +matrix_norm = get_xp(cp)(_linalg.matrix_norm) +svdvals = get_xp(cp)(_linalg.svdvals) +diagonal = get_xp(cp)(_linalg.diagonal) +trace = get_xp(cp)(_linalg.trace) + +# These functions are completely new here. If the library already has them +# (i.e., numpy 2.0), use the library version instead of our wrapper. +if hasattr(cp.linalg, 'vector_norm'): + vector_norm = cp.linalg.vector_norm +else: + vector_norm = get_xp(cp)(_linalg.vector_norm) \ No newline at end of file From 5a6f411e7e89aea2f5c38c658ab37dfa48284dd1 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 16:02:58 -0500 Subject: [PATCH 10/37] Fix ruff errors in cupy/linalg --- array_api_compat/cupy/linalg.py | 96 +++++++++++++++++++-------------- 1 file changed, 55 insertions(+), 41 deletions(-) diff --git a/array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py index 84752e1a..d8b877fd 100644 --- a/array_api_compat/cupy/linalg.py +++ b/array_api_compat/cupy/linalg.py @@ -1,47 +1,61 @@ -from cupy.linalg import * -# cupy.linalg doesn't have __all__. If it is added, replace this with -# -# from cupy.linalg import __all__ as linalg_all -_n = {} -exec('from cupy.linalg import *', _n) -del _n['__builtins__'] -linalg_all = list(_n) -del _n +import cupy as cp +from .._internal import _get_all_public_members -from ..common import _linalg -from .._internal import get_xp -from ._aliases import (matmul, matrix_transpose, tensordot, vecdot) +_cupy_linalg_all = _get_all_public_members(cp.linalg) -import cupy as cp +for name in _cupy_linalg_all: + globals()[name] = getattr(cp.linalg, name) -cross = get_xp(cp)(_linalg.cross) -outer = get_xp(cp)(_linalg.outer) -EighResult = _linalg.EighResult -QRResult = _linalg.QRResult -SlogdetResult = _linalg.SlogdetResult -SVDResult = _linalg.SVDResult -eigh = get_xp(cp)(_linalg.eigh) -qr = get_xp(cp)(_linalg.qr) -slogdet = get_xp(cp)(_linalg.slogdet) -svd = get_xp(cp)(_linalg.svd) -cholesky = get_xp(cp)(_linalg.cholesky) -matrix_rank = get_xp(cp)(_linalg.matrix_rank) -pinv = get_xp(cp)(_linalg.pinv) -matrix_norm = get_xp(cp)(_linalg.matrix_norm) -svdvals = get_xp(cp)(_linalg.svdvals) -diagonal = get_xp(cp)(_linalg.diagonal) -trace = get_xp(cp)(_linalg.trace) +from ._aliases import ( # noqa: E402 + EighResult, + QRResult, + SlogdetResult, + SVDResult, + cholesky, + cross, + diagonal, + eigh, + matmul, + matrix_norm, + matrix_rank, + matrix_transpose, + outer, + pinv, + qr, + slogdet, + svd, + svdvals, + tensordot, + trace, + vecdot, + vector_norm, +) -# These functions are completely new here. If the library already has them -# (i.e., numpy 2.0), use the library version instead of our wrapper. -if hasattr(cp.linalg, 'vector_norm'): - vector_norm = cp.linalg.vector_norm -else: - vector_norm = get_xp(cp)(_linalg.vector_norm) +__all__ = [] -__all__ = linalg_all + _linalg.__all__ +__all__ += _cupy_linalg_all -del get_xp -del cp -del linalg_all -del _linalg +__all__ += [ + "cross", + "matmul", + "outer", + "tensordot", + "EighResult", + "QRResult", + "SlogdetResult", + "SVDResult", + "eigh", + "qr", + "slogdet", + "svd", + "cholesky", + "matrix_rank", + "pinv", + "matrix_norm", + "matrix_transpose", + "svdvals", + "vecdot", + "vector_norm", + "diagonal", + "trace", +] From 0ff0836080334497bed933be9146effa0c22b142 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 16:13:00 -0500 Subject: [PATCH 11/37] Move linalg aliases to numpy/_aliases --- array_api_compat/numpy/_aliases.py | 36 +++++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index d44fcbd7..66c543c7 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -4,11 +4,11 @@ import numpy as np -from ..common import _aliases - from .._internal import get_xp +from ..common import _aliases +from ..common import _linalg -asarray = asarray_numpy = partial(_aliases._asarray, namespace='numpy') +asarray = asarray_numpy = partial(_aliases._asarray, namespace="numpy") asarray.__doc__ = _aliases._asarray.__doc__ bool = np.bool_ @@ -64,11 +64,37 @@ # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. -if hasattr(np, 'vecdot'): +if hasattr(np, "vecdot"): vecdot = np.vecdot else: vecdot = get_xp(np)(_aliases.vecdot) -if hasattr(np, 'isdtype'): +if hasattr(np, "isdtype"): isdtype = np.isdtype else: isdtype = get_xp(np)(_aliases.isdtype) + + +cross = get_xp(np)(_linalg.cross) +outer = get_xp(np)(_linalg.outer) +EighResult = _linalg.EighResult +QRResult = _linalg.QRResult +SlogdetResult = _linalg.SlogdetResult +SVDResult = _linalg.SVDResult +eigh = get_xp(np)(_linalg.eigh) +qr = get_xp(np)(_linalg.qr) +slogdet = get_xp(np)(_linalg.slogdet) +svd = get_xp(np)(_linalg.svd) +cholesky = get_xp(np)(_linalg.cholesky) +matrix_rank = get_xp(np)(_linalg.matrix_rank) +pinv = get_xp(np)(_linalg.pinv) +matrix_norm = get_xp(np)(_linalg.matrix_norm) +svdvals = get_xp(np)(_linalg.svdvals) +diagonal = get_xp(np)(_linalg.diagonal) +trace = get_xp(np)(_linalg.trace) + +# These functions are completely new here. If the library already has them +# (i.e., numpy 2.0), use the library version instead of our wrapper. +if hasattr(np.linalg, "vector_norm"): + vector_norm = np.linalg.vector_norm +else: + vector_norm = get_xp(np)(_linalg.vector_norm) From f4d78c7c4d09b70e18547f3e2f5b4d9b6fb12d6b Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 16:13:17 -0500 Subject: [PATCH 12/37] Fix ruff errors in numpy/linalg --- array_api_compat/numpy/linalg.py | 103 +++++++++++++++++++------------ 1 file changed, 63 insertions(+), 40 deletions(-) diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index 39997df8..0da1899e 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -1,40 +1,63 @@ -from numpy.linalg import * -from numpy.linalg import __all__ as linalg_all - -from ..common import _linalg -from .._internal import get_xp -from ._aliases import (matmul, matrix_transpose, tensordot, vecdot) - -import numpy as np - -cross = get_xp(np)(_linalg.cross) -outer = get_xp(np)(_linalg.outer) -EighResult = _linalg.EighResult -QRResult = _linalg.QRResult -SlogdetResult = _linalg.SlogdetResult -SVDResult = _linalg.SVDResult -eigh = get_xp(np)(_linalg.eigh) -qr = get_xp(np)(_linalg.qr) -slogdet = get_xp(np)(_linalg.slogdet) -svd = get_xp(np)(_linalg.svd) -cholesky = get_xp(np)(_linalg.cholesky) -matrix_rank = get_xp(np)(_linalg.matrix_rank) -pinv = get_xp(np)(_linalg.pinv) -matrix_norm = get_xp(np)(_linalg.matrix_norm) -svdvals = get_xp(np)(_linalg.svdvals) -diagonal = get_xp(np)(_linalg.diagonal) -trace = get_xp(np)(_linalg.trace) - -# These functions are completely new here. If the library already has them -# (i.e., numpy 2.0), use the library version instead of our wrapper. -if hasattr(np.linalg, 'vector_norm'): - vector_norm = np.linalg.vector_norm -else: - vector_norm = get_xp(np)(_linalg.vector_norm) - -__all__ = linalg_all + _linalg.__all__ - -del get_xp -del np -del linalg_all -del _linalg +import numpy as _np + +from .._internal import _get_all_public_members + +_numpy_linalg_all = _get_all_public_members(_np.linalg) + +for _name in _numpy_linalg_all: + globals()[_name] = getattr(_np.linalg, _name) + + +from ._aliases import ( # noqa: E402 + EighResult, + QRResult, + SlogdetResult, + SVDResult, + cholesky, + cross, + diagonal, + eigh, + matmul, + matrix_norm, + matrix_rank, + matrix_transpose, + outer, + pinv, + qr, + slogdet, + svd, + svdvals, + tensordot, + trace, + vecdot, + vector_norm, +) + +__all__ = [] + +__all__ += _numpy_linalg_all + +__all__ += [ + "EighResult", + "QRResult", + "SlogdetResult", + "SVDResult", + "cholesky", + "cross", + "diagonal", + "eigh", + "matmul", + "matrix_norm", + "matrix_rank", + "matrix_transpose", + "outer", + "pinv", + "qr", + "slogdet", + "svd", + "svdvals", + "tensordot", + "trace", + "vecdot", + "vector_norm", +] From 31bbbfa323d38d8d5c2cff97f63bee8516c0daee Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 16:14:43 -0500 Subject: [PATCH 13/37] Hide helper variables in cupy/linalg.py --- array_api_compat/cupy/linalg.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py index d8b877fd..31fe58b3 100644 --- a/array_api_compat/cupy/linalg.py +++ b/array_api_compat/cupy/linalg.py @@ -1,12 +1,13 @@ -import cupy as cp +import cupy as _cp + from .._internal import _get_all_public_members -_cupy_linalg_all = _get_all_public_members(cp.linalg) +_cupy_linalg_all = _get_all_public_members(_cp.linalg) -for name in _cupy_linalg_all: - globals()[name] = getattr(cp.linalg, name) +for _name in _cupy_linalg_all: + globals()[_name] = getattr(_cp.linalg, _name) -from ._aliases import ( # noqa: E402 +from ._aliases import ( # noqa: E402 EighResult, QRResult, SlogdetResult, From 5ecc7b50dfac3a62a9d03643c780b5f0633aad06 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 16:26:20 -0500 Subject: [PATCH 14/37] Move linalg aliases to torch/_aliases --- array_api_compat/torch/_aliases.py | 71 +++++++++++++++++++++--------- 1 file changed, 50 insertions(+), 21 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 929d31aa..23cd5219 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -1,22 +1,24 @@ from __future__ import annotations +from builtins import all as builtin_all +from builtins import any as builtin_any from functools import wraps -from builtins import all as builtin_all, any as builtin_any - -from ..common._aliases import (UniqueAllResult, UniqueCountsResult, - UniqueInverseResult, - matrix_transpose as _aliases_matrix_transpose, - vecdot as _aliases_vecdot) -from .._internal import get_xp +from typing import TYPE_CHECKING import torch -from typing import TYPE_CHECKING +from .._internal import get_xp +from ..common._aliases import UniqueAllResult, UniqueCountsResult, UniqueInverseResult +from ..common._aliases import matrix_transpose as _aliases_matrix_transpose +from ..common._aliases import vecdot as _aliases_vecdot + if TYPE_CHECKING: from typing import List, Optional, Sequence, Tuple, Union - from ..common._typing import Device + from torch import dtype as Dtype + from ..common._typing import Device + array = torch.Tensor _int_dtypes = { @@ -693,15 +695,42 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - axis = 0 return torch.index_select(x, axis, indices, **kwargs) -__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis', - 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', - 'bitwise_right_shift', 'bitwise_xor', 'divide', 'equal', - 'floor_divide', 'greater', 'greater_equal', 'less', 'less_equal', - 'logaddexp', 'multiply', 'not_equal', 'pow', 'remainder', - 'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all', - 'mean', 'std', 'var', 'concat', 'squeeze', 'broadcast_to', 'flip', 'roll', - 'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full', - 'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype', - 'broadcast_arrays', 'unique_all', 'unique_counts', - 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', - 'vecdot', 'tensordot', 'isdtype', 'take'] + + +# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the +# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743 +def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: + x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + return torch.linalg.cross(x1, x2, dim=axis) + +def vecdot_linalg(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array: + from ._aliases import isdtype + + x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + + # torch.linalg.vecdot doesn't support integer dtypes + if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'): + if kwargs: + raise RuntimeError("vecdot kwargs not supported for integral dtypes") + ndim = max(x1.ndim, x2.ndim) + x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) + x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) + if x1_shape[axis] != x2_shape[axis]: + raise ValueError("x1 and x2 must have the same size along the given axis") + + x1_, x2_ = torch.broadcast_tensors(x1, x2) + x1_ = torch.moveaxis(x1_, axis, -1) + x2_ = torch.moveaxis(x2_, axis, -1) + + res = x1_[..., None, :] @ x2_[..., None] + return res[..., 0, 0] + return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs) + +def solve(x1: array, x2: array, /, **kwargs) -> array: + x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + return torch.linalg.solve(x1, x2, **kwargs) + +# torch.trace doesn't support the offset argument and doesn't support stacking +def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array: + # Use our wrapped sum to make sure it does upcasting correctly + return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype) \ No newline at end of file From 8e4e9ca4b89614b6be68881157a5415960d22e04 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 16:26:46 -0500 Subject: [PATCH 15/37] Fix ruff errors in torch/linalg --- array_api_compat/torch/linalg.py | 86 +++++++++++--------------------- 1 file changed, 29 insertions(+), 57 deletions(-) diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 52667391..3f9811ed 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -1,62 +1,34 @@ -from __future__ import annotations +import torch as _torch -from typing import TYPE_CHECKING -if TYPE_CHECKING: - import torch - array = torch.Tensor - from torch import dtype as Dtype - from typing import Optional +from .._internal import _get_all_public_members -from torch.linalg import * +_torch_linalg_all = _get_all_public_members(_torch.linalg) -# torch.linalg doesn't define __all__ -# from torch.linalg import __all__ as linalg_all -from torch import linalg as torch_linalg -linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')] +for _name in _torch_linalg_all: + globals()[_name] = getattr(_torch.linalg, _name) # outer is implemented in torch but aren't in the linalg namespace -from torch import outer -from ._aliases import _fix_promotion, matrix_transpose, tensordot, sum - -# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the -# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743 -def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: - x1, x2 = _fix_promotion(x1, x2, only_scalar=False) - return torch_linalg.cross(x1, x2, dim=axis) - -def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array: - from ._aliases import isdtype - - x1, x2 = _fix_promotion(x1, x2, only_scalar=False) - - # torch.linalg.vecdot doesn't support integer dtypes - if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'): - if kwargs: - raise RuntimeError("vecdot kwargs not supported for integral dtypes") - ndim = max(x1.ndim, x2.ndim) - x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) - x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) - if x1_shape[axis] != x2_shape[axis]: - raise ValueError("x1 and x2 must have the same size along the given axis") - - x1_, x2_ = torch.broadcast_tensors(x1, x2) - x1_ = torch.moveaxis(x1_, axis, -1) - x2_ = torch.moveaxis(x2_, axis, -1) - - res = x1_[..., None, :] @ x2_[..., None] - return res[..., 0, 0] - return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs) - -def solve(x1: array, x2: array, /, **kwargs) -> array: - x1, x2 = _fix_promotion(x1, x2, only_scalar=False) - return torch.linalg.solve(x1, x2, **kwargs) - -# torch.trace doesn't support the offset argument and doesn't support stacking -def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array: - # Use our wrapped sum to make sure it does upcasting correctly - return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype) - -__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot', - 'vecdot', 'solve'] - -del linalg_all +outer = _torch.outer + +from ._aliases import ( # noqa: E402 + matrix_transpose, + solve, + sum, + tensordot, + trace, + vecdot_linalg as vecdot, +) + +__all__ = [] + +__all__ += _torch_linalg_all + +__all__ += [ + "matrix_transpose", + "solve", + "sum", + "outer", + "trace", + "tensordot", + "vecdot", +] From 5c66efc12d3f217878c78f80081f1f099e3a48ac Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 16:29:57 -0500 Subject: [PATCH 16/37] Fix final ruff errors in array_api_compat/torch/__init__.py --- array_api_compat/torch/__init__.py | 35 ++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py index 64aa52cb..76a4ff6f 100644 --- a/array_api_compat/torch/__init__.py +++ b/array_api_compat/torch/__init__.py @@ -2,19 +2,28 @@ import torch from torch import * # noqa: F401, F403 -for n in dir(torch): +from .._internal import _get_all_public_members + + +def filter_(name): if ( - n.startswith("_") - or n.endswith("_") - or "cuda" in n - or "cpu" in n - or "backward" in n + name.startswith("_") + or name.endswith("_") + or "cuda" in name + or "cpu" in name + or "backward" in name ): - continue - exec(n + " = torch." + n) + return False + return True + + +_torch_all = _get_all_public_members(torch, filter_=filter_) +for _name in _torch_all: + globals()[_name] = getattr(torch, _name) -from ..common._helpers import ( + +from ..common._helpers import ( # noqa: E402 array_namespace, device, get_namespace, @@ -24,7 +33,7 @@ ) # These imports may overwrite names from the import * above. -from ._aliases import ( +from ._aliases import ( # noqa: E402 add, all, any, @@ -92,7 +101,11 @@ zeros, ) -__all__ = [ +__all__ = [] + +__all__ += _torch_all + +__all__ += [ "is_array_api_obj", "array_namespace", "get_namespace", From bca606dbfb95867752635bf3cd9fda8aa24c4b24 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 16:35:27 -0500 Subject: [PATCH 17/37] Expose public members from numpy an cupy in __all__ respectively --- array_api_compat/cupy/__init__.py | 4 ++++ array_api_compat/numpy/__init__.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index 44ea227d..d7f27303 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -1,8 +1,10 @@ +import cupy as _cp from cupy import * # noqa: F401, F403 # from cupy import * doesn't overwrite these builtin names from cupy import abs, max, min, round +from .._internal import _get_all_public_members from ..common._helpers import ( array_namespace, device, @@ -34,6 +36,8 @@ __all__ = [] +__all__ += _get_all_public_members(_cp) + __all__ += [ "abs", "max", diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 85a4784b..7afd6321 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -1,4 +1,5 @@ from numpy import * # noqa: F401, F403 +from numpy import __all__ as _numpy_all # from numpy import * doesn't overwrite these builtin names from numpy import abs, max, min, round @@ -34,6 +35,8 @@ __all__ = [] +__all__ += _numpy_all + __all__ += [ "abs", "max", From 890c4974541768005ce3af20581ccc3d7b546395 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 16:38:15 -0500 Subject: [PATCH 18/37] Clean up --- array_api_compat/__init__.py | 2 +- array_api_compat/_internal.py | 4 ++-- array_api_compat/common/_helpers.py | 7 +++++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index 20e7eda2..61e828a9 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -17,7 +17,7 @@ this implementation for the default when working with NumPy arrays. """ -__version__ = "1.4.1" +__version__ = '1.4.1' from .common import ( array_namespace, diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index af178eaf..98eeb23a 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -44,7 +44,7 @@ def wrapped_f(*args, **kwargs): return inner -def get_all_public_members(module, filter_=None): +def _get_all_public_members(module, filter_=None): """Get all public members of a module.""" try: return getattr(module, '__all__') @@ -54,4 +54,4 @@ def get_all_public_members(module, filter_=None): if filter_ is None: filter_ = lambda name: name.startswith('_') # noqa: E731 - return map(dir(module), filter_) \ No newline at end of file + return map(filter_, dir(module)) \ No newline at end of file diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 89a74e1b..ac866551 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -7,8 +7,11 @@ """ from __future__ import annotations -from typing import Optional, Union, Any -from ._typing import Array, Device +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Optional, Union, Any + from ._typing import Array, Device import sys import math From b2f9557454455211c58a0f53b5d4dd97f4bfb867 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 16:39:27 -0500 Subject: [PATCH 19/37] Add importorskip torch --- tests/test_vendoring.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_vendoring.py b/tests/test_vendoring.py index 93f42057..e3b38748 100644 --- a/tests/test_vendoring.py +++ b/tests/test_vendoring.py @@ -13,6 +13,8 @@ def test_vendoring_cupy(): uses_cupy._test_cupy() def test_vendoring_torch(): + pytest.importorskip("torch") + from vendor_test import uses_torch uses_torch._test_torch() From 52ef9ee81a59a219f7322cff47de53bd544d97dd Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 16:49:23 -0500 Subject: [PATCH 20/37] Use importorskip --- tests/test_array_namespace.py | 24 ++++++++++++------------ tests/test_common.py | 3 +-- tests/test_isdtype.py | 8 ++++---- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 0becfc3d..99c67630 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -1,19 +1,18 @@ +import pytest + import array_api_compat from array_api_compat import array_namespace -from ._helpers import import_ - -import pytest @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) -@pytest.mark.parametrize("api_version", [None, '2021.12']) +@pytest.mark.parametrize("api_version", [None, "2021.12"]) def test_array_namespace(library, api_version): - lib = import_(library) + lib = pytest.importorskip(library) array = lib.asarray([1.0, 2.0, 3.0]) namespace = array_api_compat.array_namespace(array, api_version=api_version) - if 'array_api' in library: + if "array_api" in library: assert namespace == lib else: if library == "dask.array": @@ -23,21 +22,22 @@ def test_array_namespace(library, api_version): def test_array_namespace_errors(): + np = pytest.importorskip("numpy") + pytest.raises(TypeError, lambda: array_namespace([1])) pytest.raises(TypeError, lambda: array_namespace()) - import numpy as np x = np.asarray([1, 2]) - pytest.raises(TypeError, lambda: array_namespace((x, x))) pytest.raises(TypeError, lambda: array_namespace(x, (x, x))) - import torch - y = torch.asarray([1, 2]) - pytest.raises(TypeError, lambda: array_namespace(x, y)) +def test_array_namespace_errors_torch(): + torch = pytest.importorskip("torch") - pytest.raises(ValueError, lambda: array_namespace(x, api_version='2022.12')) + y = torch.asarray([1, 2]) + pytest.raises(TypeError, lambda: array_namespace(x, y)) + pytest.raises(ValueError, lambda: array_namespace(x, api_version="2022.12")) def test_get_namespace(): diff --git a/tests/test_common.py b/tests/test_common.py index 5537ade1..8c3a6ce2 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,4 +1,3 @@ -from ._helpers import import_ from array_api_compat import to_device import pytest @@ -11,7 +10,7 @@ def test_to_device_host(library): # for DtoH transfers; ensure that we support a portable # shim for common array libs # see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919 - xp = import_('array_api_compat.' + library) + xp = pytest.importorskip('array_api_compat.' + library) expected = np.array([1, 2, 3]) x = xp.asarray([1, 2, 3]) x = to_device(x, "cpu") diff --git a/tests/test_isdtype.py b/tests/test_isdtype.py index 1cda6089..7d7241f0 100644 --- a/tests/test_isdtype.py +++ b/tests/test_isdtype.py @@ -3,10 +3,10 @@ non-spec dtypes """ -from ._helpers import import_ - import pytest +from ._helpers import import_ + # Check the known dtypes by their string names def _spec_dtypes(library): @@ -66,7 +66,7 @@ def isdtype_(dtype_, kind): @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) def test_isdtype_spec_dtypes(library): - xp = import_('array_api_compat.' + library) + xp = pytest.importorskip('array_api_compat.' + library) isdtype = xp.isdtype @@ -101,7 +101,7 @@ def test_isdtype_spec_dtypes(library): @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) @pytest.mark.parametrize("dtype_", additional_dtypes) def test_isdtype_additional_dtypes(library, dtype_): - xp = import_('array_api_compat.' + library) + xp = pytest.importorskip('array_api_compat.' + library) isdtype = xp.isdtype From b0692306a906457eaab2b15097ed6f0016788db1 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 16:49:48 -0500 Subject: [PATCH 21/37] Add missing isdtype --- array_api_compat/numpy/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 7afd6321..f63acb48 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -30,6 +30,7 @@ bool, concat, pow, + isdtype, ) from .linalg import matrix_transpose, vecdot @@ -69,6 +70,7 @@ "bool", "concat", "pow", + "isdtype", ] __all__ += [ From ff510154050663201c98440cd924f472a4a730be Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 17:05:17 -0500 Subject: [PATCH 22/37] Fix tests --- array_api_compat/_internal.py | 33 ++++++++++++++++++++++-------- array_api_compat/torch/__init__.py | 12 +++++------ tests/test_array_namespace.py | 2 ++ tests/test_isdtype.py | 2 -- 4 files changed, 32 insertions(+), 17 deletions(-) diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index 98eeb23a..eafca7bd 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -44,14 +44,29 @@ def wrapped_f(*args, **kwargs): return inner -def _get_all_public_members(module, filter_=None): - """Get all public members of a module.""" - try: - return getattr(module, '__all__') - except AttributeError: - pass +def _get_all_public_members(module, exclude=None, extend_all=False): + """Get all public members of a module. - if filter_ is None: - filter_ = lambda name: name.startswith('_') # noqa: E731 + Parameters + ---------- + module : module + The module to get members from. + exclude : callable, optional + A callable that takes a name and returns True if the name should be + excluded from the list of members. + extend_all : bool, optional + If True, extend the module's __all__ attribute with the members of the + module derive from dir(module) + """ + members = getattr(module, '__all__', []) + + if members and not extend_all: + return members + + if exclude is None: + exclude = lambda name: name.startswith('_') # noqa: E731 + + members += [_ for _ in dir(module) if not exclude(_)] - return map(filter_, dir(module)) \ No newline at end of file + # remove duplicates + return list(set(members)) \ No newline at end of file diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py index 76a4ff6f..bc96eee5 100644 --- a/array_api_compat/torch/__init__.py +++ b/array_api_compat/torch/__init__.py @@ -1,11 +1,11 @@ # Several names are not included in the above import * -import torch +import torch as _torch from torch import * # noqa: F401, F403 from .._internal import _get_all_public_members -def filter_(name): +def exlcude(name): if ( name.startswith("_") or name.endswith("_") @@ -13,14 +13,14 @@ def filter_(name): or "cpu" in name or "backward" in name ): - return False - return True + return True + return False -_torch_all = _get_all_public_members(torch, filter_=filter_) +_torch_all = _get_all_public_members(_torch, exclude=exlcude, extend_all=True) for _name in _torch_all: - globals()[_name] = getattr(torch, _name) + globals()[_name] = getattr(_torch, _name) from ..common._helpers import ( # noqa: E402 diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 99c67630..25c76a32 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -34,8 +34,10 @@ def test_array_namespace_errors(): def test_array_namespace_errors_torch(): torch = pytest.importorskip("torch") + np = pytest.importorskip("numpy") y = torch.asarray([1, 2]) + x = np.asarray([1, 2]) pytest.raises(TypeError, lambda: array_namespace(x, y)) pytest.raises(ValueError, lambda: array_namespace(x, api_version="2022.12")) diff --git a/tests/test_isdtype.py b/tests/test_isdtype.py index 7d7241f0..1b0f6e19 100644 --- a/tests/test_isdtype.py +++ b/tests/test_isdtype.py @@ -5,8 +5,6 @@ import pytest -from ._helpers import import_ - # Check the known dtypes by their string names def _spec_dtypes(library): From b0a323da04bc6e1829b0de6a040c01f62b171eb5 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 22:01:35 -0500 Subject: [PATCH 23/37] Rename import_ to import_or_skip_cupy --- tests/_helpers.py | 5 +++-- tests/test_array_namespace.py | 15 +++++++-------- tests/test_common.py | 12 ++++++++---- tests/test_isdtype.py | 8 +++++--- tests/test_vendoring.py | 6 ++++-- 5 files changed, 27 insertions(+), 19 deletions(-) diff --git a/tests/_helpers.py b/tests/_helpers.py index 4066d07a..69952118 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -2,7 +2,8 @@ import pytest -def import_(library): - if 'cupy' in library: + +def import_or_skip_cupy(library): + if "cupy" in library: return pytest.importorskip(library) return import_module(library) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 25c76a32..2c596d70 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -1,19 +1,23 @@ +import numpy as np import pytest +import torch import array_api_compat from array_api_compat import array_namespace +from ._helpers import import_or_skip_cupy + @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) @pytest.mark.parametrize("api_version", [None, "2021.12"]) def test_array_namespace(library, api_version): - lib = pytest.importorskip(library) + xp = import_or_skip_cupy(library) - array = lib.asarray([1.0, 2.0, 3.0]) + array = xp.asarray([1.0, 2.0, 3.0]) namespace = array_api_compat.array_namespace(array, api_version=api_version) if "array_api" in library: - assert namespace == lib + assert namespace == xp else: if library == "dask.array": assert namespace == array_api_compat.dask.array @@ -22,8 +26,6 @@ def test_array_namespace(library, api_version): def test_array_namespace_errors(): - np = pytest.importorskip("numpy") - pytest.raises(TypeError, lambda: array_namespace([1])) pytest.raises(TypeError, lambda: array_namespace()) @@ -33,9 +35,6 @@ def test_array_namespace_errors(): def test_array_namespace_errors_torch(): - torch = pytest.importorskip("torch") - np = pytest.importorskip("numpy") - y = torch.asarray([1, 2]) x = np.asarray([1, 2]) pytest.raises(TypeError, lambda: array_namespace(x, y)) diff --git a/tests/test_common.py b/tests/test_common.py index 8c3a6ce2..bfaf58d2 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,16 +1,20 @@ -from array_api_compat import to_device - -import pytest import numpy as np +import pytest from numpy.testing import assert_allclose +from array_api_compat import to_device + +from ._helpers import import_or_skip_cupy + + @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) def test_to_device_host(library): # different libraries have different semantics # for DtoH transfers; ensure that we support a portable # shim for common array libs # see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919 - xp = pytest.importorskip('array_api_compat.' + library) + xp = import_or_skip_cupy("array_api_compat." + library) + expected = np.array([1, 2, 3]) x = xp.asarray([1, 2, 3]) x = to_device(x, "cpu") diff --git a/tests/test_isdtype.py b/tests/test_isdtype.py index 1b0f6e19..c27334da 100644 --- a/tests/test_isdtype.py +++ b/tests/test_isdtype.py @@ -5,6 +5,8 @@ import pytest +from ._helpers import import_or_skip_cupy + # Check the known dtypes by their string names def _spec_dtypes(library): @@ -59,12 +61,12 @@ def isdtype_(dtype_, kind): res = dtype_categories[kind](dtype_) else: res = dtype_ == kind - assert isinstance(res, bool) + assert type(res) is bool # noqa: E721 return res @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) def test_isdtype_spec_dtypes(library): - xp = pytest.importorskip('array_api_compat.' + library) + xp = import_or_skip_cupy('array_api_compat.' + library) isdtype = xp.isdtype @@ -99,7 +101,7 @@ def test_isdtype_spec_dtypes(library): @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) @pytest.mark.parametrize("dtype_", additional_dtypes) def test_isdtype_additional_dtypes(library, dtype_): - xp = pytest.importorskip('array_api_compat.' + library) + xp = import_or_skip_cupy('array_api_compat.' + library) isdtype = xp.isdtype diff --git a/tests/test_vendoring.py b/tests/test_vendoring.py index e3b38748..66fc6984 100644 --- a/tests/test_vendoring.py +++ b/tests/test_vendoring.py @@ -3,6 +3,7 @@ def test_vendoring_numpy(): from vendor_test import uses_numpy + uses_numpy._test_numpy() @@ -10,12 +11,13 @@ def test_vendoring_cupy(): pytest.importorskip("cupy") from vendor_test import uses_cupy + uses_cupy._test_cupy() + def test_vendoring_torch(): - pytest.importorskip("torch") - from vendor_test import uses_torch + uses_torch._test_torch() def test_vendoring_dask(): From 0ec2d8934b9f3c9cb1a38474b6f97e5a9bd86d92 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 22:16:53 -0500 Subject: [PATCH 24/37] Add missing imports and sort __all__ --- array_api_compat/cupy/__init__.py | 70 ++++++++++++++++++++++++++++ array_api_compat/numpy/__init__.py | 72 ++++++++++++++++++++++++++++- array_api_compat/torch/__init__.py | 74 +++++++++++++++--------------- 3 files changed, 177 insertions(+), 39 deletions(-) diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index d7f27303..29c1a70f 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -16,12 +16,18 @@ # These imports may overwrite names from the import * above. from ._aliases import ( + UniqueAllResult, + UniqueCountsResult, + UniqueInverseResult, acos, acosh, + arange, + argsort, asarray, asarray_cupy, asin, asinh, + astype, atan, atan2, atanh, @@ -29,8 +35,37 @@ bitwise_left_shift, bitwise_right_shift, bool, + ceil, concat, + empty, + empty_like, + eye, + floor, + full, + full_like, + isdtype, + linspace, + matmul, + matrix_transpose, + nonzero, + ones, + ones_like, + permute_dims, pow, + prod, + reshape, + sort, + std, + sum, + tensordot, + trunc, + unique_all, + unique_counts, + unique_inverse, + unique_values, + var, + zeros, + zeros_like, ) from .linalg import matrix_transpose, vecdot @@ -55,12 +90,18 @@ ] __all__ += [ + "UniqueAllResult", + "UniqueCountsResult", + "UniqueInverseResult", "acos", "acosh", + "arange", + "argsort", "asarray", "asarray_cupy", "asin", "asinh", + "astype", "atan", "atan2", "atanh", @@ -68,8 +109,37 @@ "bitwise_left_shift", "bitwise_right_shift", "bool", + "ceil", "concat", + "empty", + "empty_like", + "eye", + "floor", + "full", + "full_like", + "isdtype", + "linspace", + "matmul", + "matrix_transpose", + "nonzero", + "ones", + "ones_like", + "permute_dims", "pow", + "prod", + "reshape", + "sort", + "std", + "sum", + "tensordot", + "trunc", + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + "var", + "zeros", + "zeros_like", ] __all__ += [ diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index f63acb48..aca17387 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -15,12 +15,18 @@ # These imports may overwrite names from the import * above. from ._aliases import ( + UniqueAllResult, + UniqueCountsResult, + UniqueInverseResult, acos, acosh, + arange, + argsort, asarray, asarray_numpy, asin, asinh, + astype, atan, atan2, atanh, @@ -28,9 +34,37 @@ bitwise_left_shift, bitwise_right_shift, bool, + ceil, concat, - pow, + empty, + empty_like, + eye, + floor, + full, + full_like, isdtype, + linspace, + matmul, + matrix_transpose, + nonzero, + ones, + ones_like, + permute_dims, + pow, + prod, + reshape, + sort, + std, + sum, + tensordot, + trunc, + unique_all, + unique_counts, + unique_inverse, + unique_values, + var, + zeros, + zeros_like, ) from .linalg import matrix_transpose, vecdot @@ -55,12 +89,18 @@ ] __all__ += [ + "UniqueAllResult", + "UniqueCountsResult", + "UniqueInverseResult", "acos", "acosh", + "arange", + "argsort", "asarray", "asarray_numpy", "asin", "asinh", + "astype", "atan", "atan2", "atanh", @@ -68,9 +108,37 @@ "bitwise_left_shift", "bitwise_right_shift", "bool", + "ceil", "concat", - "pow", + "empty", + "empty_like", + "eye", + "floor", + "full", + "full_like", "isdtype", + "linspace", + "matmul", + "matrix_transpose", + "nonzero", + "ones", + "ones_like", + "permute_dims", + "pow", + "prod", + "reshape", + "sort", + "std", + "sum", + "tensordot", + "trunc", + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + "var", + "zeros", + "zeros_like", ] __all__ += [ diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py index bc96eee5..674227a1 100644 --- a/array_api_compat/torch/__init__.py +++ b/array_api_compat/torch/__init__.py @@ -115,71 +115,71 @@ def exlcude(name): ] __all__ += [ - "result_type", - "can_cast", - "permute_dims", - "bitwise_invert", - "newaxis", "add", + "all", + "any", + "arange", + "astype", "atan2", "bitwise_and", + "bitwise_invert", "bitwise_left_shift", "bitwise_or", "bitwise_right_shift", "bitwise_xor", + "broadcast_arrays", + "broadcast_to", + "can_cast", + "concat", "divide", + "empty", "equal", + "expand_dims", + "eye", + "flip", "floor_divide", + "full", "greater", "greater_equal", + "isdtype", "less", "less_equal", + "linspace", "logaddexp", + "matmul", + "matrix_transpose", + "max", + "mean", + "min", "multiply", + "newaxis", + "nonzero", "not_equal", + "ones", + "permute_dims", "pow", + "prod", "remainder", - "subtract", - "max", - "min", + "reshape", + "result_type", + "roll", "sort", - "prod", - "sum", - "any", - "all", - "mean", - "std", - "var", - "concat", "squeeze", - "broadcast_to", - "flip", - "roll", - "nonzero", - "where", - "reshape", - "arange", - "eye", - "linspace", - "full", - "ones", - "zeros", - "empty", + "std", + "subtract", + "sum", + "take", + "tensordot", "tril", "triu", - "expand_dims", - "astype", - "broadcast_arrays", "unique_all", "unique_counts", "unique_inverse", "unique_values", - "matmul", - "matrix_transpose", + "var", "vecdot", - "tensordot", - "isdtype", - "take", + "where", + "zeros", ] From 2baa4da3a4d1481cc862b2c75956f46b6856660c Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 22:20:14 -0500 Subject: [PATCH 25/37] More cleanup --- array_api_compat/common/_linalg.py | 10 ++-------- array_api_compat/common/_typing.py | 3 +++ 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 3b17417d..0708b76a 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -11,7 +11,7 @@ else: from numpy.core.numeric import normalize_axis_tuple -from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype +from ._aliases import matrix_transpose, isdtype from .._internal import get_xp # These are in the main NumPy namespace but not in numpy.linalg @@ -149,10 +149,4 @@ def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarra dtype = xp.float64 elif x.dtype == xp.complex64: dtype = xp.complex128 - return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)) - -__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult', - 'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet', - 'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm', - 'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal', - 'trace'] + return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)) \ No newline at end of file diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index 3f178060..f20d085c 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -18,3 +18,6 @@ def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... def __len__(self, /) -> int: ... SupportsBufferProtocol = Any + +Array = Any +Device = Any \ No newline at end of file From efd745c34789e9726f7942130ea7e32aebe0e8a2 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 26 Jan 2024 22:31:55 -0500 Subject: [PATCH 26/37] Remove redefinitions --- array_api_compat/cupy/__init__.py | 2 +- array_api_compat/numpy/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index 29c1a70f..cab2e8cf 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -64,10 +64,10 @@ unique_inverse, unique_values, var, + vecdot, zeros, zeros_like, ) -from .linalg import matrix_transpose, vecdot __all__ = [] diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index aca17387..7c863417 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -63,10 +63,10 @@ unique_inverse, unique_values, var, + vecdot, zeros, zeros_like, ) -from .linalg import matrix_transpose, vecdot __all__ = [] From 49f2b7a73f4f74e71bb9fd3cc8fca3d469eb5b79 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Mon, 29 Jan 2024 14:58:28 -0500 Subject: [PATCH 27/37] Add ruff select F822 option [skip ci] --- .github/workflows/ruff.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index c9f44886..2ce2bbdc 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -16,4 +16,4 @@ jobs: pip install ruff # Update output format to enable automatic inline annotations. - name: Run Ruff - run: ruff check --output-format=github . + run: ruff check --output-format=github --select F822 . From a748bfaefaf4201f5f266c1ced4c2e5f28a9ef94 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Mon, 29 Jan 2024 15:28:10 -0500 Subject: [PATCH 28/37] Add PLC0414 error code as well --- .github/workflows/ruff.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 2ce2bbdc..1f0846f8 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -16,4 +16,4 @@ jobs: pip install ruff # Update output format to enable automatic inline annotations. - name: Run Ruff - run: ruff check --output-format=github --select F822 . + run: ruff check --output-format=github --select F822,PLC0414 . From 6b4e92cb21bce5602ee9516aea50aa088a552132 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Wed, 31 Jan 2024 20:39:28 -0500 Subject: [PATCH 29/37] Avoid in place modification of __all__ in _get_all_public_members --- array_api_compat/_internal.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index eafca7bd..97a2a01a 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -22,13 +22,16 @@ def func(x, /, xp, kwarg=None): arguments. """ + def inner(f): @wraps(f) def wrapped_f(*args, **kwargs): return f(*args, xp=xp, **kwargs) sig = signature(f) - new_sig = sig.replace(parameters=[sig.parameters[i] for i in sig.parameters if i != 'xp']) + new_sig = sig.replace( + parameters=[sig.parameters[i] for i in sig.parameters if i != "xp"] + ) if wrapped_f.__doc__ is None: wrapped_f.__doc__ = f"""\ @@ -46,7 +49,7 @@ def wrapped_f(*args, **kwargs): def _get_all_public_members(module, exclude=None, extend_all=False): """Get all public members of a module. - + Parameters ---------- module : module @@ -58,15 +61,15 @@ def _get_all_public_members(module, exclude=None, extend_all=False): If True, extend the module's __all__ attribute with the members of the module derive from dir(module) """ - members = getattr(module, '__all__', []) + members = getattr(module, "__all__", []) if members and not extend_all: return members if exclude is None: - exclude = lambda name: name.startswith('_') # noqa: E731 + exclude = lambda name: name.startswith("_") # noqa: E731 - members += [_ for _ in dir(module) if not exclude(_)] + members = members + [_ for _ in dir(module) if not exclude(_)] # remove duplicates - return list(set(members)) \ No newline at end of file + return list(set(members)) From a92f6408e518ccc082bc9dfacdc4fb9e542cd78a Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 2 Feb 2024 11:18:06 -0500 Subject: [PATCH 30/37] Add sort check for __all__ --- .github/workflows/ruff.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 1f0846f8..9a09ffd7 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -16,4 +16,4 @@ jobs: pip install ruff # Update output format to enable automatic inline annotations. - name: Run Ruff - run: ruff check --output-format=github --select F822,PLC0414 . + run: ruff check --output-format=github --select F822,PLC0414,RUF022 --preview . From 9b1110bf3af756eb52ccf221a989e473b4ba1e94 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 2 Feb 2024 11:21:22 -0500 Subject: [PATCH 31/37] Sort __all__ lists --- array_api_compat/common/__init__.py | 6 +++--- array_api_compat/common/_aliases.py | 8 -------- array_api_compat/cupy/__init__.py | 6 +++--- array_api_compat/cupy/_typing.py | 2 +- array_api_compat/cupy/linalg.py | 24 ++++++++++++------------ array_api_compat/numpy/__init__.py | 6 +++--- array_api_compat/numpy/_aliases.py | 3 +-- array_api_compat/numpy/_typing.py | 2 +- array_api_compat/numpy/linalg.py | 2 +- array_api_compat/torch/__init__.py | 6 +++--- array_api_compat/torch/linalg.py | 4 ++-- 11 files changed, 30 insertions(+), 39 deletions(-) diff --git a/array_api_compat/common/__init__.py b/array_api_compat/common/__init__.py index 2bd64c1f..b941a31e 100644 --- a/array_api_compat/common/__init__.py +++ b/array_api_compat/common/__init__.py @@ -8,10 +8,10 @@ ) __all__ = [ - "is_array_api_obj", "array_namespace", - "get_namespace", "device", - "to_device", + "get_namespace", + "is_array_api_obj", "size", + "to_device", ] diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index fc8189d3..b58fb0ca 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -545,11 +545,3 @@ def isdtype( # more strict here to match the type annotation? Note that the # numpy.array_api implementation will be very strict. return dtype == kind - -__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', - 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', - 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', - 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', - 'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort', - 'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc', - 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype'] diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index cab2e8cf..b5eb5eea 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -81,12 +81,12 @@ ] __all__ += [ - "is_array_api_obj", "array_namespace", - "get_namespace", "device", - "to_device", + "get_namespace", + "is_array_api_obj", "size", + "to_device", ] __all__ += [ diff --git a/array_api_compat/cupy/_typing.py b/array_api_compat/cupy/_typing.py index f3d9aab6..b867ca94 100644 --- a/array_api_compat/cupy/_typing.py +++ b/array_api_compat/cupy/_typing.py @@ -1,9 +1,9 @@ from __future__ import annotations __all__ = [ - "ndarray", "Device", "Dtype", + "ndarray", ] import sys diff --git a/array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py index 31fe58b3..cef74183 100644 --- a/array_api_compat/cupy/linalg.py +++ b/array_api_compat/cupy/linalg.py @@ -37,26 +37,26 @@ __all__ += _cupy_linalg_all __all__ += [ - "cross", - "matmul", - "outer", - "tensordot", "EighResult", "QRResult", - "SlogdetResult", "SVDResult", + "SlogdetResult", + "cholesky", + "cross", + "diagonal", "eigh", + "matmul", + "matrix_norm", + "matrix_rank", + "matrix_transpose", + "outer", + "pinv", "qr", "slogdet", "svd", - "cholesky", - "matrix_rank", - "pinv", - "matrix_norm", - "matrix_transpose", "svdvals", + "tensordot", + "trace", "vecdot", "vector_norm", - "diagonal", - "trace", ] diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 7c863417..8ee9f711 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -80,12 +80,12 @@ ] __all__ += [ - "is_array_api_obj", "array_namespace", - "get_namespace", "device", - "to_device", + "get_namespace", + "is_array_api_obj", "size", + "to_device", ] __all__ += [ diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 66c543c7..ee1c1557 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -5,8 +5,7 @@ import numpy as np from .._internal import get_xp -from ..common import _aliases -from ..common import _linalg +from ..common import _aliases, _linalg asarray = asarray_numpy = partial(_aliases._asarray, namespace="numpy") asarray.__doc__ = _aliases._asarray.__doc__ diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index c5ebb5ab..53585f70 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -1,9 +1,9 @@ from __future__ import annotations __all__ = [ - "ndarray", "Device", "Dtype", + "ndarray", ] import sys diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index 0da1899e..fbe37cd7 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -40,8 +40,8 @@ __all__ += [ "EighResult", "QRResult", - "SlogdetResult", "SVDResult", + "SlogdetResult", "cholesky", "cross", "diagonal", diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py index 674227a1..6492839f 100644 --- a/array_api_compat/torch/__init__.py +++ b/array_api_compat/torch/__init__.py @@ -106,12 +106,12 @@ def exlcude(name): __all__ += _torch_all __all__ += [ - "is_array_api_obj", "array_namespace", - "get_namespace", "device", - "to_device", + "get_namespace", + "is_array_api_obj", "size", + "to_device", ] __all__ += [ diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 3f9811ed..160f074b 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -25,10 +25,10 @@ __all__ += [ "matrix_transpose", + "outer", "solve", "sum", - "outer", - "trace", "tensordot", + "trace", "vecdot", ] From 68c788fe2d11137ae3ef5b5c236ba948e66f3a3a Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Fri, 2 Feb 2024 11:24:21 -0500 Subject: [PATCH 32/37] Use * import for array_api_compat/__init__.py --- array_api_compat/__init__.py | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index 61e828a9..29e7be04 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -19,18 +19,4 @@ """ __version__ = '1.4.1' -from .common import ( - array_namespace, - get_namespace, - is_array_api_obj, - size, - to_device, -) - -__all__ = [ - "array_namespace", - "get_namespace", - "is_array_api_obj", - "size", - "to_device", -] +from .common import * # noqa: F401, F403 From 1720fb6e0f2e9039f0e50e6266a747b4e4a5e898 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Wed, 7 Feb 2024 08:49:22 -0500 Subject: [PATCH 33/37] Update array_api_compat/_internal.py Co-authored-by: Aaron Meurer --- array_api_compat/_internal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index 97a2a01a..9073dd52 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -59,7 +59,7 @@ def _get_all_public_members(module, exclude=None, extend_all=False): excluded from the list of members. extend_all : bool, optional If True, extend the module's __all__ attribute with the members of the - module derive from dir(module) + module derived from dir(module). To be used for libraries that do not have a complete __all__ list. """ members = getattr(module, "__all__", []) From 5cd47dffcb3c14e3356eb078240884ceadf4ba34 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Wed, 7 Feb 2024 15:45:16 -0500 Subject: [PATCH 34/37] Adapt dask --- array_api_compat/dask/array/__init__.py | 186 +++++++++++++++++++++++- array_api_compat/dask/array/_aliases.py | 59 +------- 2 files changed, 183 insertions(+), 62 deletions(-) diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index a7c0b22e..237cc270 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -1,8 +1,186 @@ -from dask.array import * +from dask.array import * # noqa: F401, F403 +from dask.array import __all__ as _dask_array_all + +from dask.array import ( + # Element wise aliases + arccos as acos, + arccosh as acosh, + arcsin as asin, + arcsinh as asinh, + arctan as atan, + arctan2 as atan2, + arctanh as atanh, + # Other + concatenate as concat, + invert as bitwise_invert, + left_shift as bitwise_left_shift, + power as pow, + right_shift as bitwise_right_shift, + bool_ as bool, +) # These imports may overwrite names from the import * above. -from ._aliases import * +from numpy import ( + can_cast, + complex64, + complex128, + e, + finfo, + float32, + float64, + iinfo, + inf, + int8, + int16, + int32, + int64, + nan, + newaxis, + pi, + result_type, + uint8, + uint16, + uint32, + uint64, +) + +from ..common._helpers import ( + array_namespace, + device, + get_namespace, + is_array_api_obj, + size, + to_device, +) +from ._aliases import ( + UniqueAllResult, + UniqueCountsResult, + UniqueInverseResult, + arange, + asarray, + astype, + ceil, + empty, + empty_like, + eye, + floor, + full, + full_like, + isdtype, + linspace, + matmul, + matrix_transpose, + nonzero, + ones, + ones_like, + permute_dims, + prod, + reshape, + std, + sum, + tensordot, + trunc, + unique_all, + unique_counts, + unique_inverse, + unique_values, + var, + vecdot, + zeros, + zeros_like, +) + +__all__ = [] + +__all__ += _dask_array_all + +__all__ += [ + "can_cast", + "complex64", + "complex128", + "e", + "finfo", + "float32", + "float64", + "iinfo", + "inf", + "int8", + "int16", + "int32", + "int64", + "nan", + "newaxis", + "pi", + "result_type", + "uint8", + "uint16", + "uint32", + "uint64", +] + +__all__ += [ + "array_namespace", + "device", + "get_namespace", + "is_array_api_obj", + "size", + "to_device", +] + +# 'sort', 'argsort' are unsupported by dask.array + +__all__ += [ + "UniqueAllResult", + "UniqueCountsResult", + "UniqueInverseResult", + "acos", + "acosh", + "arange", + "asarray", + "asin", + "asinh", + "astype", + "atan", + "atan2", + "atanh", + "bitwise_invert", + "bitwise_left_shift", + "bitwise_right_shift", + "bool", + "ceil", + "concat", + "empty", + "empty_like", + "eye", + "floor", + "full", + "full_like", + "isdtype", + "linspace", + "matmul", + "matrix_transpose", + "nonzero", + "ones", + "ones_like", + "permute_dims", + "pow", + "prod", + "reshape", + "std", + "sum", + "tensordot", + "trunc", + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + "var", + "vecdot", + "zeros", + "zeros_like", +] + -__array_api_version__ = '2022.12' +__array_api_version__ = "2022.12" -__import__(__package__ + '.linalg') +__import__(__package__ + ".linalg") diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index ef9ea356..9397633e 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -1,4 +1,5 @@ from __future__ import annotations +from functools import partial from ...common import _aliases from ...common._helpers import _check_device @@ -6,32 +7,6 @@ from ..._internal import get_xp import numpy as np -from numpy import ( - # Constants - e, - inf, - nan, - pi, - newaxis, - # Dtypes - bool_ as bool, - float32, - float64, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - complex64, - complex128, - iinfo, - finfo, - can_cast, - result_type, -) from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -75,7 +50,6 @@ def dask_arange( arange = get_xp(da)(dask_arange) eye = get_xp(da)(_aliases.eye) -from functools import partial asarray = partial(_aliases._asarray, namespace='dask.array') asarray.__doc__ = _aliases._asarray.__doc__ @@ -112,34 +86,3 @@ def dask_arange( matmul = get_xp(np)(_aliases.matmul) tensordot = get_xp(np)(_aliases.tensordot) -from dask.array import ( - # Element wise aliases - arccos as acos, - arccosh as acosh, - arcsin as asin, - arcsinh as asinh, - arctan as atan, - arctan2 as atan2, - arctanh as atanh, - left_shift as bitwise_left_shift, - right_shift as bitwise_right_shift, - invert as bitwise_invert, - power as pow, - # Other - concatenate as concat, -) - -# exclude these from all since -_da_unsupported = ['sort', 'argsort'] - -common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported] - -__all__ = common_aliases + ['asarray', 'bool', 'acos', - 'acosh', 'asin', 'asinh', 'atan', 'atan2', - 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow', - 'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8', - 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64', - 'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type'] - -del da, partial, common_aliases, _da_unsupported, From c5d55ae674a1109e78f9bb93feb1af7dc993f573 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Wed, 7 Feb 2024 15:56:06 -0500 Subject: [PATCH 35/37] Fix ruff errors for dask/array/linalg --- array_api_compat/dask/array/_aliases.py | 40 +++++++--- array_api_compat/dask/array/linalg.py | 98 +++++++++++++------------ 2 files changed, 81 insertions(+), 57 deletions(-) diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 9397633e..14b27070 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -1,17 +1,18 @@ from __future__ import annotations + from functools import partial +from typing import TYPE_CHECKING -from ...common import _aliases -from ...common._helpers import _check_device +import numpy as np from ..._internal import get_xp +from ...common import _aliases, _linalg +from ...common._helpers import _check_device -import numpy as np - -from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Optional, Union - from ...common._typing import ndarray, Device, Dtype + from typing import Optional, Tuple, Union + + from ...common._typing import Device, Dtype, ndarray import dask.array as da @@ -24,6 +25,7 @@ # not pass stop/step as keyword arguments, which will cause # an error with dask + # TODO: delete the xp stuff, it shouldn't be necessary def dask_arange( start: Union[int, float], @@ -34,7 +36,7 @@ def dask_arange( xp, dtype: Optional[Dtype] = None, device: Optional[Device] = None, - **kwargs + **kwargs, ) -> ndarray: _check_device(xp, device) args = [start] @@ -47,10 +49,11 @@ def dask_arange( args.append(step) return xp.arange(*args, dtype=dtype, **kwargs) + arange = get_xp(da)(dask_arange) eye = get_xp(da)(_aliases.eye) -asarray = partial(_aliases._asarray, namespace='dask.array') +asarray = partial(_aliases._asarray, namespace="dask.array") asarray.__doc__ = _aliases._asarray.__doc__ linspace = get_xp(da)(_aliases.linspace) @@ -86,3 +89,22 @@ def dask_arange( matmul = get_xp(np)(_aliases.matmul) tensordot = get_xp(np)(_aliases.tensordot) + +EighResult = _linalg.EighResult +QRResult = _linalg.QRResult +SlogdetResult = _linalg.SlogdetResult +SVDResult = _linalg.SVDResult +qr = get_xp(da)(_linalg.qr) +cholesky = get_xp(da)(_linalg.cholesky) +matrix_rank = get_xp(da)(_linalg.matrix_rank) +matrix_norm = get_xp(da)(_linalg.matrix_norm) + + +def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]: + # TODO: can't avoid computing U or V for dask + _, s, _ = da.linalg.svd(x) + return s + + +vector_norm = get_xp(da)(_linalg.vector_norm) +diagonal = get_xp(da)(_linalg.diagonal) diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index c8aa7c9f..321a4a5f 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -1,48 +1,50 @@ -from __future__ import annotations - -from dask.array.linalg import * -from ...common import _linalg -from ..._internal import get_xp -from dask.array import matmul, tensordot, trace, outer -from ._aliases import matrix_transpose, vecdot - -import dask.array as da - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Union, Tuple - from ...common._typing import ndarray - -# cupy.linalg doesn't have __all__. If it is added, replace this with -# -# from cupy.linalg import __all__ as linalg_all -_n = {} -exec('from dask.array.linalg import *', _n) -del _n['__builtins__'] -linalg_all = list(_n) -del _n - -EighResult = _linalg.EighResult -QRResult = _linalg.QRResult -SlogdetResult = _linalg.SlogdetResult -SVDResult = _linalg.SVDResult -qr = get_xp(da)(_linalg.qr) -cholesky = get_xp(da)(_linalg.cholesky) -matrix_rank = get_xp(da)(_linalg.matrix_rank) -matrix_norm = get_xp(da)(_linalg.matrix_norm) - -def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]: - # TODO: can't avoid computing U or V for dask - _, s, _ = svd(x) - return s - -vector_norm = get_xp(da)(_linalg.vector_norm) -diagonal = get_xp(da)(_linalg.diagonal) - -__all__ = linalg_all + ["EighResult", "QRResult", "SlogdetResult", - "SVDResult", "qr", "cholesky", "matrix_rank", "matrix_norm", - "svdvals", "vector_norm", "diagonal"] - -del get_xp -del da -del _linalg +import dask.array as _da +from dask.array import ( + matmul, + outer, + tensordot, + trace, +) +from dask.array.linalg import * # noqa: F401, F403 + +from .._internal import _get_all_public_members +from ._aliases import ( + EighResult, + QRResult, + SlogdetResult, + SVDResult, + cholesky, + diagonal, + matrix_norm, + matrix_rank, + matrix_transpose, + qr, + svdvals, + vecdot, + vector_norm, +) + +__all__ = [ + "matmul", + "outer", + "tensordot", + "trace", +] + +__all__ += _get_all_public_members(_da.linalg) + +__all__ += [ + "EighResult", + "QRResult", + "SlogdetResult", + "SVDResult", + "qr", + "cholesky", + "matrix_rank", + "matrix_norm", + "matrix_transpose", + "vecdot", + "svdvals", + "vector_norm", + "diagonal", +] From 49851b563b46ddb57498e8c7f2a732b3dda1d6d8 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Wed, 7 Feb 2024 15:59:45 -0500 Subject: [PATCH 36/37] Fix __all__ order in dask linalg --- array_api_compat/dask/array/linalg.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 321a4a5f..cc9ac880 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -36,15 +36,15 @@ __all__ += [ "EighResult", "QRResult", - "SlogdetResult", "SVDResult", - "qr", + "SlogdetResult", "cholesky", - "matrix_rank", + "diagonal", "matrix_norm", + "matrix_rank", "matrix_transpose", - "vecdot", + "qr", "svdvals", + "vecdot", "vector_norm", - "diagonal", ] From 2db3d6ab64db47fdf66f2a07b2285ad6ab8a685e Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Wed, 7 Feb 2024 16:12:40 -0500 Subject: [PATCH 37/37] Fix import of __all__ in dask/array/__init__.py --- array_api_compat/dask/array/__init__.py | 32 +++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index 237cc270..d6b5e94e 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -1,22 +1,45 @@ +import dask.array as _da from dask.array import * # noqa: F401, F403 -from dask.array import __all__ as _dask_array_all - from dask.array import ( # Element wise aliases arccos as acos, +) +from dask.array import ( arccosh as acosh, +) +from dask.array import ( arcsin as asin, +) +from dask.array import ( arcsinh as asinh, +) +from dask.array import ( arctan as atan, +) +from dask.array import ( arctan2 as atan2, +) +from dask.array import ( arctanh as atanh, +) +from dask.array import ( + bool_ as bool, +) +from dask.array import ( # Other concatenate as concat, +) +from dask.array import ( invert as bitwise_invert, +) +from dask.array import ( left_shift as bitwise_left_shift, +) +from dask.array import ( power as pow, +) +from dask.array import ( right_shift as bitwise_right_shift, - bool_ as bool, ) # These imports may overwrite names from the import * above. @@ -52,6 +75,7 @@ size, to_device, ) +from ..internal import _get_all_public_members from ._aliases import ( UniqueAllResult, UniqueCountsResult, @@ -92,7 +116,7 @@ __all__ = [] -__all__ += _dask_array_all +__all__ += _get_all_public_members(_da) __all__ += [ "can_cast",