Skip to content

Commit

Permalink
Merge pull request #69 from asmeurer/rm-__array__
Browse files Browse the repository at this point in the history
Remove __array__
  • Loading branch information
asmeurer authored Nov 2, 2024
2 parents ef71d85 + 5485345 commit b9e64e5
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 34 deletions.
51 changes: 29 additions & 22 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def __hash__(self):

_default = object()

_allow_array = False

class Array:
"""
n-d array object for the array API namespace.
Expand Down Expand Up @@ -145,30 +147,35 @@ def __repr__(self: Array, /) -> str:
mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
return prefix + mid + suffix

# This function is not required by the spec, but we implement it here for
# convenience so that np.asarray(array_api_strict.Array) will work.
# Disallow __array__, meaning calling `np.func()` on an array_api_strict
# array will give an error. If we don't explicitly disallow it, NumPy
# defaults to creating an object dtype array, which would lead to
# confusing error messages at best and surprising bugs at worst.
#
# The alternative of course is to just support __array__, which is what we
# used to do. But this isn't actually supported by the standard, so it can
# lead to code assuming np.asarray(other_array) would always work in the
# standard.
def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None) -> npt.NDArray[Any]:
"""
Warning: this method is NOT part of the array API spec. Implementers
of other libraries need not include it, and users should not assume it
will be present in other implementations.
"""
if self._device != CPU_DEVICE:
raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.")
# copy keyword is new in 2.0.0; for older versions don't use it
# retry without that keyword.
if np.__version__[0] < '2':
return np.asarray(self._array, dtype=dtype)
elif np.__version__.startswith('2.0.0-dev0'):
# Handle dev version for which we can't know based on version
# number whether or not the copy keyword is supported.
try:
return np.asarray(self._array, dtype=dtype, copy=copy)
except TypeError:
# We have to allow this to be internally enabled as there's no other
# easy way to parse a list of Array objects in asarray().
if _allow_array:
if self._device != CPU_DEVICE:
raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.")
# copy keyword is new in 2.0.0; for older versions don't use it
# retry without that keyword.
if np.__version__[0] < '2':
return np.asarray(self._array, dtype=dtype)
else:
return np.asarray(self._array, dtype=dtype, copy=copy)
elif np.__version__.startswith('2.0.0-dev0'):
# Handle dev version for which we can't know based on version
# number whether or not the copy keyword is supported.
try:
return np.asarray(self._array, dtype=dtype, copy=copy)
except TypeError:
return np.asarray(self._array, dtype=dtype)
else:
return np.asarray(self._array, dtype=dtype, copy=copy)
raise ValueError("Conversion from an array_api_strict array to a NumPy ndarray is not supported")

# These are various helper functions to make the array behavior match the
# spec in places where it either deviates from or is more strict than
Expand Down
18 changes: 16 additions & 2 deletions array_api_strict/_creation_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations


from contextlib import contextmanager
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

if TYPE_CHECKING:
Expand All @@ -16,6 +16,19 @@

import numpy as np

@contextmanager
def allow_array():
"""
Temporarily enable Array.__array__. This is needed for np.array to parse
list of lists of Array objects.
"""
from . import _array_object
original_value = _array_object._allow_array
try:
_array_object._allow_array = True
yield
finally:
_array_object._allow_array = original_value

def _check_valid_dtype(dtype):
# Note: Only spelling dtypes as the dtype objects is supported.
Expand Down Expand Up @@ -99,7 +112,8 @@ def asarray(
# Give a better error message in this case. NumPy would convert this
# to an object array. TODO: This won't handle large integers in lists.
raise OverflowError("Integer out of bounds for array dtypes")
res = np.array(obj, dtype=_np_dtype, copy=copy)
with allow_array():
res = np.array(obj, dtype=_np_dtype, copy=copy)
return Array._new(res, device=device)


Expand Down
2 changes: 2 additions & 0 deletions array_api_strict/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@ def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array:
# default tolerance by max(M, N).
if rtol is None:
rtol = max(x.shape[-2:]) * finfo(x.dtype).eps
if isinstance(rtol, Array):
rtol = rtol._array
return Array._new(np.linalg.pinv(x._array, rcond=rtol), device=x.device)

@requires_extension('linalg')
Expand Down
14 changes: 5 additions & 9 deletions array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,23 +350,19 @@ def test_array_properties():
assert isinstance(b.mT, Array)
assert b.mT.shape == (3, 2)

def test___array__():
a = ones((2, 3), dtype=int16)
assert np.asarray(a) is a._array
b = np.asarray(a, dtype=np.float64)
assert np.all(np.equal(b, np.ones((2, 3), dtype=np.float64)))
assert b.dtype == np.float64

def test_array_conversion():
# Check that arrays on the CPU device can be converted to NumPy
# but arrays on other devices can't
# but arrays on other devices can't. Note this is testing the logic in
# __array__, which is only used in asarray when converting lists of
# arrays.
a = ones((2, 3))
np.asarray(a)
asarray([a])

for device in ("device1", "device2"):
a = ones((2, 3), device=array_api_strict.Device(device))
with pytest.raises(RuntimeError, match="Can not convert array"):
np.asarray(a)
asarray([a])

def test_allow_newaxis():
a = ones(5)
Expand Down
15 changes: 14 additions & 1 deletion array_api_strict/tests/test_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
zeros,
zeros_like,
)
from .._dtypes import float32, float64
from .._dtypes import int16, float32, float64
from .._array_object import Array, CPU_DEVICE, Device
from .._flags import set_array_api_strict_flags

Expand Down Expand Up @@ -97,6 +97,19 @@ def test_asarray_copy():
a[0] = 0
assert all(b[0] == 0)

def test_asarray_list_of_lists():
a = asarray(1, dtype=int16)
b = asarray([1], dtype=int16)
res = asarray([a, a])
assert res.shape == (2,)
assert res.dtype == int16
assert all(res == asarray([1, 1]))

res = asarray([b, b])
assert res.shape == (2, 1)
assert res.dtype == int16
assert all(res == asarray([[1], [1]]))


def test_asarray_device_inference():
assert asarray([1, 2, 3]).device == CPU_DEVICE
Expand Down

0 comments on commit b9e64e5

Please sign in to comment.