From 38551c6d323fa610a6b18612ba5793e4d4ac2a87 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 15 Oct 2024 14:34:58 -0600 Subject: [PATCH 1/4] Remove __array__ This makes it raise an exception, since it isn't supported by the standard (if we just leave it unimplemented, then np.asarray() returns an object dtype array, which is not good). There is one issue here from the test suite, which is that this breaks the logic in asarray() for converting lists of array_api_strict 0-D arrays. I'm not yet sure what to do about that. Fixes #67. --- array_api_strict/_array_object.py | 31 +++++++-------------- array_api_strict/tests/test_array_object.py | 7 ----- 2 files changed, 10 insertions(+), 28 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index d8ed018..e7b53d9 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -125,28 +125,17 @@ def __repr__(self: Array, /) -> str: mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix) return prefix + mid + suffix - # This function is not required by the spec, but we implement it here for - # convenience so that np.asarray(array_api_strict.Array) will work. + # Disallow __array__, meaning calling `np.func()` on an array_api_strict + # array will give an error. If we don't explicitly disallow it, NumPy + # defaults to creating an object dtype array, which would lead to + # confusing error messages at best and surprising bugs at worst. + # + # The alternative of course is to just support __array__, which is what we + # used to do. But this isn't actually supported by the standard, so it can + # lead to code assuming np.asarray(other_array) would always work in the + # standard. def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None) -> npt.NDArray[Any]: - """ - Warning: this method is NOT part of the array API spec. Implementers - of other libraries need not include it, and users should not assume it - will be present in other implementations. - - """ - # copy keyword is new in 2.0.0; for older versions don't use it - # retry without that keyword. - if np.__version__[0] < '2': - return np.asarray(self._array, dtype=dtype) - elif np.__version__.startswith('2.0.0-dev0'): - # Handle dev version for which we can't know based on version - # number whether or not the copy keyword is supported. - try: - return np.asarray(self._array, dtype=dtype, copy=copy) - except TypeError: - return np.asarray(self._array, dtype=dtype) - else: - return np.asarray(self._array, dtype=dtype, copy=copy) + raise ValueError("Conversion from an array_api_strict array to a NumPy ndarray is not supported") # These are various helper functions to make the array behavior match the # spec in places where it either deviates from or is more strict than diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index dad6696..dd3f6c2 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -342,13 +342,6 @@ def test_array_properties(): assert isinstance(b.mT, Array) assert b.mT.shape == (3, 2) -def test___array__(): - a = ones((2, 3), dtype=int16) - assert np.asarray(a) is a._array - b = np.asarray(a, dtype=np.float64) - assert np.all(np.equal(b, np.ones((2, 3), dtype=np.float64))) - assert b.dtype == np.float64 - def test_allow_newaxis(): a = ones(5) indexed_a = a[None, :] From 6c3b7d6c3baf8092ac5040d41a917f806a7aa30d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 28 Oct 2024 12:32:10 -0600 Subject: [PATCH 2/4] Temporarily enable __array__ in asarray so that parsing list of lists of Array can work --- array_api_strict/_array_object.py | 18 ++++++++++++++++++ array_api_strict/_creation_functions.py | 18 ++++++++++++++++-- .../tests/test_creation_functions.py | 15 ++++++++++++++- 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index e7b53d9..f72b68e 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -53,6 +53,8 @@ def __repr__(self): _default = object() +_allow_array = False + class Array: """ n-d array object for the array API namespace. @@ -135,6 +137,22 @@ def __repr__(self: Array, /) -> str: # lead to code assuming np.asarray(other_array) would always work in the # standard. def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None) -> npt.NDArray[Any]: + # We have to allow this to be internally enabled as there's no other + # easy way to parse a list of Array objects in asarray(). + if _allow_array: + # copy keyword is new in 2.0.0; for older versions don't use it + # retry without that keyword. + if np.__version__[0] < '2': + return np.asarray(self._array, dtype=dtype) + elif np.__version__.startswith('2.0.0-dev0'): + # Handle dev version for which we can't know based on version + # number whether or not the copy keyword is supported. + try: + return np.asarray(self._array, dtype=dtype, copy=copy) + except TypeError: + return np.asarray(self._array, dtype=dtype) + else: + return np.asarray(self._array, dtype=dtype, copy=copy) raise ValueError("Conversion from an array_api_strict array to a NumPy ndarray is not supported") # These are various helper functions to make the array behavior match the diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 67ba67c..52f9389 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -1,6 +1,6 @@ from __future__ import annotations - +from contextlib import contextmanager from typing import TYPE_CHECKING, List, Optional, Tuple, Union if TYPE_CHECKING: @@ -16,6 +16,19 @@ import numpy as np +@contextmanager +def allow_array(): + """ + Temporarily enable Array.__array__. This is needed for np.array to parse + list of lists of Array objects. + """ + from . import _array_object + original_value = _array_object._allow_array + try: + _array_object._allow_array = True + yield + finally: + _array_object._allow_array = original_value def _check_valid_dtype(dtype): # Note: Only spelling dtypes as the dtype objects is supported. @@ -94,7 +107,8 @@ def asarray( # Give a better error message in this case. NumPy would convert this # to an object array. TODO: This won't handle large integers in lists. raise OverflowError("Integer out of bounds for array dtypes") - res = np.array(obj, dtype=_np_dtype, copy=copy) + with allow_array(): + res = np.array(obj, dtype=_np_dtype, copy=copy) return Array._new(res) diff --git a/array_api_strict/tests/test_creation_functions.py b/array_api_strict/tests/test_creation_functions.py index 819afad..bb486b1 100644 --- a/array_api_strict/tests/test_creation_functions.py +++ b/array_api_strict/tests/test_creation_functions.py @@ -22,7 +22,7 @@ zeros, zeros_like, ) -from .._dtypes import float32, float64 +from .._dtypes import int16, float32, float64 from .._array_object import Array, CPU_DEVICE from .._flags import set_array_api_strict_flags @@ -97,6 +97,19 @@ def test_asarray_copy(): a[0] = 0 assert all(b[0] == 0) +def test_asarray_list_of_lists(): + a = asarray(1, dtype=int16) + b = asarray([1], dtype=int16) + res = asarray([a, a]) + assert res.shape == (2,) + assert res.dtype == int16 + assert all(res == asarray([1, 1])) + + res = asarray([b, b]) + assert res.shape == (2, 1) + assert res.dtype == int16 + assert all(res == asarray([[1], [1]])) + def test_arange_errors(): arange(1, device=CPU_DEVICE) # Doesn't error assert_raises(ValueError, lambda: arange(1, device="cpu")) From bb2816745191892e2114bf3cb3c929ecee9c826f Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 31 Oct 2024 13:51:48 -0600 Subject: [PATCH 3/4] Fix incorrect merge conflict resolution --- array_api_strict/_array_object.py | 4 ++-- array_api_strict/tests/test_array_object.py | 9 ++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 8e12ef4..9416e38 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -160,8 +160,8 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None # We have to allow this to be internally enabled as there's no other # easy way to parse a list of Array objects in asarray(). if _allow_array: - if self._device != CPU_DEVICE: - raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.") + if self._device != CPU_DEVICE: + raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.") # copy keyword is new in 2.0.0; for older versions don't use it # retry without that keyword. if np.__version__[0] < '2': diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index c59aa54..c7781d7 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -353,14 +353,17 @@ def test_array_properties(): def test_array_conversion(): # Check that arrays on the CPU device can be converted to NumPy - # but arrays on other devices can't + # but arrays on other devices can't. Note this is testing the logic in + # __array__, which is only used in asarray when converting lists of + # arrays. a = ones((2, 3)) - np.asarray(a) + asarray([a]) for device in ("device1", "device2"): a = ones((2, 3), device=array_api_strict.Device(device)) with pytest.raises(RuntimeError, match="Can not convert array"): - np.asarray(a) + asarray([a]) + def test_allow_newaxis(): a = ones(5) indexed_a = a[None, :] From d630ee5f5a7c4474838c1782d694fe1733008d68 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 31 Oct 2024 14:57:10 -0600 Subject: [PATCH 4/4] Fix the pinv function, which was implicitly using __array__ --- array_api_strict/_linalg.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/array_api_strict/_linalg.py b/array_api_strict/_linalg.py index d341277..7d379a0 100644 --- a/array_api_strict/_linalg.py +++ b/array_api_strict/_linalg.py @@ -267,6 +267,8 @@ def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: # default tolerance by max(M, N). if rtol is None: rtol = max(x.shape[-2:]) * finfo(x.dtype).eps + if isinstance(rtol, Array): + rtol = rtol._array return Array._new(np.linalg.pinv(x._array, rcond=rtol), device=x.device) @requires_extension('linalg')