diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 34caff7..9416e38 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -66,6 +66,8 @@ def __hash__(self): _default = object() +_allow_array = False + class Array: """ n-d array object for the array API namespace. @@ -145,30 +147,35 @@ def __repr__(self: Array, /) -> str: mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix) return prefix + mid + suffix - # This function is not required by the spec, but we implement it here for - # convenience so that np.asarray(array_api_strict.Array) will work. + # Disallow __array__, meaning calling `np.func()` on an array_api_strict + # array will give an error. If we don't explicitly disallow it, NumPy + # defaults to creating an object dtype array, which would lead to + # confusing error messages at best and surprising bugs at worst. + # + # The alternative of course is to just support __array__, which is what we + # used to do. But this isn't actually supported by the standard, so it can + # lead to code assuming np.asarray(other_array) would always work in the + # standard. def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None) -> npt.NDArray[Any]: - """ - Warning: this method is NOT part of the array API spec. Implementers - of other libraries need not include it, and users should not assume it - will be present in other implementations. - - """ - if self._device != CPU_DEVICE: - raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.") - # copy keyword is new in 2.0.0; for older versions don't use it - # retry without that keyword. - if np.__version__[0] < '2': - return np.asarray(self._array, dtype=dtype) - elif np.__version__.startswith('2.0.0-dev0'): - # Handle dev version for which we can't know based on version - # number whether or not the copy keyword is supported. - try: - return np.asarray(self._array, dtype=dtype, copy=copy) - except TypeError: + # We have to allow this to be internally enabled as there's no other + # easy way to parse a list of Array objects in asarray(). + if _allow_array: + if self._device != CPU_DEVICE: + raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.") + # copy keyword is new in 2.0.0; for older versions don't use it + # retry without that keyword. + if np.__version__[0] < '2': return np.asarray(self._array, dtype=dtype) - else: - return np.asarray(self._array, dtype=dtype, copy=copy) + elif np.__version__.startswith('2.0.0-dev0'): + # Handle dev version for which we can't know based on version + # number whether or not the copy keyword is supported. + try: + return np.asarray(self._array, dtype=dtype, copy=copy) + except TypeError: + return np.asarray(self._array, dtype=dtype) + else: + return np.asarray(self._array, dtype=dtype, copy=copy) + raise ValueError("Conversion from an array_api_strict array to a NumPy ndarray is not supported") # These are various helper functions to make the array behavior match the # spec in places where it either deviates from or is more strict than diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 7924a85..d6d3efa 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -1,6 +1,6 @@ from __future__ import annotations - +from contextlib import contextmanager from typing import TYPE_CHECKING, List, Optional, Tuple, Union if TYPE_CHECKING: @@ -16,6 +16,19 @@ import numpy as np +@contextmanager +def allow_array(): + """ + Temporarily enable Array.__array__. This is needed for np.array to parse + list of lists of Array objects. + """ + from . import _array_object + original_value = _array_object._allow_array + try: + _array_object._allow_array = True + yield + finally: + _array_object._allow_array = original_value def _check_valid_dtype(dtype): # Note: Only spelling dtypes as the dtype objects is supported. @@ -99,7 +112,8 @@ def asarray( # Give a better error message in this case. NumPy would convert this # to an object array. TODO: This won't handle large integers in lists. raise OverflowError("Integer out of bounds for array dtypes") - res = np.array(obj, dtype=_np_dtype, copy=copy) + with allow_array(): + res = np.array(obj, dtype=_np_dtype, copy=copy) return Array._new(res, device=device) diff --git a/array_api_strict/_linalg.py b/array_api_strict/_linalg.py index d341277..7d379a0 100644 --- a/array_api_strict/_linalg.py +++ b/array_api_strict/_linalg.py @@ -267,6 +267,8 @@ def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: # default tolerance by max(M, N). if rtol is None: rtol = max(x.shape[-2:]) * finfo(x.dtype).eps + if isinstance(rtol, Array): + rtol = rtol._array return Array._new(np.linalg.pinv(x._array, rcond=rtol), device=x.device) @requires_extension('linalg') diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index a9ea26d..c7781d7 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -350,23 +350,19 @@ def test_array_properties(): assert isinstance(b.mT, Array) assert b.mT.shape == (3, 2) -def test___array__(): - a = ones((2, 3), dtype=int16) - assert np.asarray(a) is a._array - b = np.asarray(a, dtype=np.float64) - assert np.all(np.equal(b, np.ones((2, 3), dtype=np.float64))) - assert b.dtype == np.float64 def test_array_conversion(): # Check that arrays on the CPU device can be converted to NumPy - # but arrays on other devices can't + # but arrays on other devices can't. Note this is testing the logic in + # __array__, which is only used in asarray when converting lists of + # arrays. a = ones((2, 3)) - np.asarray(a) + asarray([a]) for device in ("device1", "device2"): a = ones((2, 3), device=array_api_strict.Device(device)) with pytest.raises(RuntimeError, match="Can not convert array"): - np.asarray(a) + asarray([a]) def test_allow_newaxis(): a = ones(5) diff --git a/array_api_strict/tests/test_creation_functions.py b/array_api_strict/tests/test_creation_functions.py index 71fd76b..c93a08a 100644 --- a/array_api_strict/tests/test_creation_functions.py +++ b/array_api_strict/tests/test_creation_functions.py @@ -22,7 +22,7 @@ zeros, zeros_like, ) -from .._dtypes import float32, float64 +from .._dtypes import int16, float32, float64 from .._array_object import Array, CPU_DEVICE, Device from .._flags import set_array_api_strict_flags @@ -97,6 +97,19 @@ def test_asarray_copy(): a[0] = 0 assert all(b[0] == 0) +def test_asarray_list_of_lists(): + a = asarray(1, dtype=int16) + b = asarray([1], dtype=int16) + res = asarray([a, a]) + assert res.shape == (2,) + assert res.dtype == int16 + assert all(res == asarray([1, 1])) + + res = asarray([b, b]) + assert res.shape == (2, 1) + assert res.dtype == int16 + assert all(res == asarray([[1], [1]])) + def test_asarray_device_inference(): assert asarray([1, 2, 3]).device == CPU_DEVICE