Skip to content

Commit

Permalink
Array API tests fixes (#636)
Browse files Browse the repository at this point in the history
* Handle complex dtype in full

* Exclude test_getitem from array api tests since negative step sizes are not supported

* Fix broadcast_to when target shape has size 0
  • Loading branch information
tomwhite authored Dec 5, 2024
1 parent 2d7e6cc commit 83382e1
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 5 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/array-api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ jobs:
# edge case failures (https://github.com/cubed-dev/cubed/issues/420)
array_api_tests/test_linalg.py::test_tensordot
array_api_tests/test_linalg.py::test_vecdot
# (getitem with negative step size is not implemented)
array_api_tests/test_array_object.py::test_getitem
# not implemented
array_api_tests/test_array_object.py::test_setitem
Expand Down
2 changes: 2 additions & 0 deletions cubed/array_api/creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ def full(
dtype = nxp.int64
elif isinstance(fill_value, float):
dtype = nxp.float64
elif isinstance(fill_value, complex):
dtype = nxp.complex128
else:
raise TypeError("Invalid input to full")
chunksize = to_chunksize(normalize_chunks(chunks, shape=shape, dtype=dtype))
Expand Down
10 changes: 6 additions & 4 deletions cubed/array_api/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,16 @@ def broadcast_to(x, /, shape, *, chunks=None):
):
raise ValueError(f"cannot broadcast shape {x.shape} to shape {shape}")

# TODO: fix case where shape has a dimension of size zero

if chunks is None:
# New dimensions and broadcast dimensions have chunk size 1
# This behaviour differs from dask where it is the full dimension size
xchunks = normalize_chunks(x.chunks, x.shape, dtype=x.dtype)
chunks = tuple((1,) * s for s in shape[:ndim_new]) + tuple(
bd if old > 1 else ((1,) * new if new > 0 else (0,))

def chunklen(shapelen):
return (1,) * shapelen if shapelen > 0 else (0,)

chunks = tuple(chunklen(s) for s in shape[:ndim_new]) + tuple(
bd if old > 1 else chunklen(new)
for bd, old, new in zip(xchunks, x.shape, shape[ndim_new:])
)
else:
Expand Down
3 changes: 2 additions & 1 deletion cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,8 @@ def test_broadcast_arrays(executor):
@pytest.mark.parametrize(
"shape, chunks, new_shape, new_chunks, new_chunks_expected",
[
# ((5, 1, 6), (3, 1, 3), (5, 0, 6), None, ((3, 2), (0,), (3, 3))), # fails
((), (), (0,), None, ((0,),)),
((5, 1, 6), (3, 1, 3), (5, 0, 6), None, ((3, 2), (0,), (3, 3))),
((5, 1, 6), (3, 1, 3), (5, 4, 6), None, ((3, 2), (1, 1, 1, 1), (3, 3))),
((5, 1, 6), (3, 1, 3), (2, 5, 1, 6), None, ((1, 1), (3, 2), (1,), (3, 3))),
((5, 1, 6), (3, 1, 3), (5, 3, 6), (3, 3, 3), ((3, 2), (3,), (3, 3))),
Expand Down

0 comments on commit 83382e1

Please sign in to comment.