Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Declare Array API 2023.12 support #651

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions .github/workflows/array-api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
uses: actions/checkout@v3
with:
repository: data-apis/array-api-tests
ref: 'db95e67b29235249e5776ca2b6bb4e77117e0690' # Latest commit as of 2024-08-08
ref: '1572b129c6682211abfe139e112592226c361a6c' # Latest commit as of 2024-12-04
path: array-api-tests
submodules: "true"
- name: Set up Python ${{ matrix.python-version }}
Expand Down Expand Up @@ -90,8 +90,7 @@ jobs:
array_api_tests/test_has_names.py

# signatures of items not implemented
array_api_tests/test_signatures.py::test_func_signature[std]
array_api_tests/test_signatures.py::test_func_signature[var]
array_api_tests/test_signatures.py::test_func_signature[cumulative_sum]
array_api_tests/test_signatures.py::test_func_signature[unique_all]
array_api_tests/test_signatures.py::test_func_signature[unique_counts]
array_api_tests/test_signatures.py::test_func_signature[unique_inverse]
Expand All @@ -110,13 +109,15 @@ jobs:
array_api_tests/test_linalg.py::test_vecdot
# (getitem with negative step size is not implemented)
array_api_tests/test_array_object.py::test_getitem
# test_searchsorted depends on sort which is not implemented
array_api_tests/test_searching_functions.py::test_searchsorted

# not implemented
array_api_tests/test_array_object.py::test_setitem
array_api_tests/test_array_object.py::test_setitem_masking
array_api_tests/test_manipulation_functions.py::test_repeat
array_api_tests/test_sorting_functions.py
array_api_tests/test_statistical_functions.py::test_std
array_api_tests/test_statistical_functions.py::test_var
array_api_tests/test_statistical_functions.py::test_cumulative_sum

# finfo(float32).eps returns float32 but should return float
array_api_tests/test_data_type_functions.py::test_finfo[float32]
Expand All @@ -126,6 +127,9 @@ jobs:
# https://github.com/numpy/numpy/issues/18881
array_api_tests/test_creation_functions.py::test_linspace

# https://github.com/numpy/numpy/issues/20870
#array_api_tests/test_data_type_functions.py::test_can_cast

EOF

pytest -v -rxXfEA --hypothesis-max-examples=2 --disable-data-dependent-shapes --disable-extension linalg --hypothesis-disable-deadline
6 changes: 2 additions & 4 deletions api_status.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
## Array API Coverage Implementation Status

Cubed supports version [2022.12](https://data-apis.org/array-api/2022.12/index.html) of the Python array API standard, with a few exceptions noted below. The [Fourier transform functions](https://data-apis.org/array-api/2022.12/extensions/fourier_transform_functions.html) are *not* supported.

Support for version [2023.12](https://data-apis.org/array-api/2023.12/index.html) is tracked in Cubed issue [#438](https://github.com/cubed-dev/cubed/issues/438).
Cubed supports version [2023.12](https://data-apis.org/array-api/2023.12/index.html) of the Python array API standard, with a few exceptions noted below. The [Fourier transform functions](https://data-apis.org/array-api/2023.12/extensions/fourier_transform_functions.html) are *not* supported.

This table shows which parts of the the [Array API](https://data-apis.org/array-api/latest/API_specification/index.html) have been implemented in Cubed, and which ones are missing. The version column shows the version when the feature was added to the standard, for version 2022.12 or later.

Expand Down Expand Up @@ -61,7 +59,7 @@ This table shows which parts of the the [Array API](https://data-apis.org/array-
| | `expand_dims` | :white_check_mark: | | |
| | `flip` | :white_check_mark: | | |
| | `permute_dims` | :white_check_mark: | | |
| | `repeat` | :white_check_mark: | | |
| | `repeat` | :white_check_mark: | 2023.12 | |
| | `reshape` | :white_check_mark: | | Partial implementation |
| | `roll` | :white_check_mark: | | |
| | `squeeze` | :white_check_mark: | | |
Expand Down
2 changes: 1 addition & 1 deletion cubed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@

# Array API

__array_api_version__ = "2022.12"
__array_api_version__ = "2023.12"

from .array_api.inspection import __array_namespace_info__

Expand Down
2 changes: 1 addition & 1 deletion cubed/array_api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = []

__array_api_version__ = "2022.12"
__array_api_version__ = "2023.12"

from .inspection import __array_namespace_info__

Expand Down
6 changes: 5 additions & 1 deletion cubed/array_api/array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,11 @@ def __abs__(self, /):
return elemwise(nxp.abs, self, dtype=dtype)

def __array_namespace__(self, /, *, api_version=None):
if api_version is not None and api_version not in ("2021.12", "2022.12"):
if api_version is not None and api_version not in (
"2021.12",
"2022.12",
"2023.12",
):
raise ValueError(f"Unrecognized array API version: {api_version!r}")
import cubed.array_api as array_api

Expand Down
2 changes: 1 addition & 1 deletion cubed/array_api/data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from cubed.core import CoreArray, map_blocks


def astype(x, dtype, /, *, copy=True):
def astype(x, dtype, /, *, copy=True, device=None):
if not copy and dtype == x.dtype:
return x
return map_blocks(_astype, x, dtype=dtype, astype_dtype=dtype)
Expand Down
9 changes: 7 additions & 2 deletions cubed/array_api/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,9 @@ def permute_dims(x, /, axes):


def repeat(x, repeats, /, *, axis=0):
if not isinstance(repeats, int):
raise ValueError("repeat only supports integral values for `repeats`")

if axis is None:
x = flatten(x)
axis = 0
Expand Down Expand Up @@ -599,8 +602,10 @@ def unstack(x, /, *, axis=0):

n_arrays = x.shape[axis]

if n_arrays == 1:
return (x,)
if n_arrays == 0:
return ()
elif n_arrays == 1:
return (squeeze(x, axis=axis),)

shape = x.shape[:axis] + x.shape[axis + 1 :]
dtype = x.dtype
Expand Down
12 changes: 0 additions & 12 deletions cubed/array_api/statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
_real_numeric_dtypes,
_signed_integer_dtypes,
_unsigned_integer_dtypes,
complex64,
complex128,
float32,
float64,
int64,
uint64,
)
Expand Down Expand Up @@ -128,10 +124,6 @@ def prod(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
dtype = int64
elif x.dtype in _unsigned_integer_dtypes:
dtype = uint64
elif x.dtype == float32:
dtype = float64
elif x.dtype == complex64:
dtype = complex128
else:
dtype = x.dtype
extra_func_kwargs = dict(dtype=dtype)
Expand Down Expand Up @@ -169,10 +161,6 @@ def sum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
dtype = int64
elif x.dtype in _unsigned_integer_dtypes:
dtype = uint64
elif x.dtype == float32:
dtype = float64
elif x.dtype == complex64:
dtype = complex128
else:
dtype = x.dtype
extra_func_kwargs = dict(dtype=dtype)
Expand Down
9 changes: 7 additions & 2 deletions cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,10 +722,15 @@ def test_unstack(spec, executor, chunks):
assert_array_equal(cu, np.full((4, 6), 3))


def test_unstack_noop(spec):
def test_unstack_zero_arrays(spec):
a = xp.full((0, 4, 6), 1, chunks=(1, 2, 3), spec=spec)
assert xp.unstack(a) == ()


def test_unstack_single_array(spec):
a = xp.full((1, 4, 6), 1, chunks=(1, 2, 3), spec=spec)
(b,) = xp.unstack(a)
assert a is b
assert_array_equal(b.compute(), np.full((4, 6), 1))


# Searching functions
Expand Down
4 changes: 1 addition & 3 deletions docs/array-api.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# Python Array API

Cubed implements version 2022.12 of the [Python Array API standard](https://data-apis.org/array-api/2022.12/index.html) in `cubed.array_api`, with a few exceptions listed on the [coverage status](https://github.com/cubed-dev/cubed/blob/main/api_status.md) page. The [Fourier transform functions](https://data-apis.org/array-api/2022.12/extensions/fourier_transform_functions.html) are *not* supported.

Support for version [2023.12](https://data-apis.org/array-api/2023.12/index.html) is tracked in Cubed issue [#438](https://github.com/cubed-dev/cubed/issues/438).
Cubed implements version 2023.12 of the [Python Array API standard](https://data-apis.org/array-api/2023.12/index.html) in `cubed.array_api`, with a few exceptions listed on the [coverage status](https://github.com/cubed-dev/cubed/blob/main/api_status.md) page. The [Fourier transform functions](https://data-apis.org/array-api/2023.12/extensions/fourier_transform_functions.html) are *not* supported.

## Differences between Cubed and the standard

Expand Down
Loading