From 447421ed70de3fd38c3e95e55fedd5ebcc72ab35 Mon Sep 17 00:00:00 2001 From: Tom White Date: Sun, 19 May 2024 15:23:07 +0100 Subject: [PATCH] Implement roll (#465) --- .github/workflows/array-api-tests.yml | 1 - api_status.md | 10 +++--- cubed/array_api/__init__.py | 2 ++ cubed/array_api/manipulation_functions.py | 41 +++++++++++++++++++++++ cubed/tests/test_array_api.py | 31 +++++++++++++++++ 5 files changed, 79 insertions(+), 6 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 7aaf9fb1..813e667f 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -97,7 +97,6 @@ jobs: array_api_tests/test_array_object.py::test_setitem array_api_tests/test_array_object.py::test_setitem_masking array_api_tests/test_manipulation_functions.py::test_flip - array_api_tests/test_manipulation_functions.py::test_roll array_api_tests/test_sorting_functions.py array_api_tests/test_statistical_functions.py::test_std array_api_tests/test_statistical_functions.py::test_var diff --git a/api_status.md b/api_status.md index cf0ba4fd..80174a25 100644 --- a/api_status.md +++ b/api_status.md @@ -30,10 +30,10 @@ This table shows which parts of the the [Array API](https://data-apis.org/array- | | `zeros` | :white_check_mark: | | | | | `zeros_like` | :white_check_mark: | | | | Data Type Functions | `astype` | :white_check_mark: | | | -| | `can_cast` | :white_check_mark: | | Same as `numpy.array_api` | -| | `finfo` | :white_check_mark: | | Same as `numpy.array_api` | -| | `iinfo` | :white_check_mark: | | Same as `numpy.array_api` | -| | `result_type` | :white_check_mark: | | Same as `numpy.array_api` | +| | `can_cast` | :white_check_mark: | | | +| | `finfo` | :white_check_mark: | | | +| | `iinfo` | :white_check_mark: | | | +| | `result_type` | :white_check_mark: | | | | Data Types | `bool`, `int8`, ... | :white_check_mark: | | | | Elementwise Functions | `add` | :white_check_mark: | | Example of a binary function | | | `negative` | :white_check_mark: | | Example of a unary function | @@ -52,7 +52,7 @@ This table shows which parts of the the [Array API](https://data-apis.org/array- | | `flip` | :x: | 2 | Needs indexing with step=-1, [#114](https://github.com/cubed-dev/cubed/issues/114) | | | `permute_dims` | :white_check_mark: | | | | | `reshape` | :white_check_mark: | | Partial implementation | -| | `roll` | :x: | 2 | Use `concat` and `reshape`, [#115](https://github.com/cubed-dev/cubed/issues/115) | +| | `roll` | :white_check_mark: | | | | | `squeeze` | :white_check_mark: | | | | | `stack` | :white_check_mark: | | | | Searching Functions | `argmax` | :white_check_mark: | | | diff --git a/cubed/array_api/__init__.py b/cubed/array_api/__init__.py index 8ef442ef..dd709141 100644 --- a/cubed/array_api/__init__.py +++ b/cubed/array_api/__init__.py @@ -224,6 +224,7 @@ moveaxis, permute_dims, reshape, + roll, squeeze, stack, ) @@ -236,6 +237,7 @@ "moveaxis", "permute_dims", "reshape", + "roll", "squeeze", "stack", ] diff --git a/cubed/array_api/manipulation_functions.py b/cubed/array_api/manipulation_functions.py index 115dc1f3..aaceeba4 100644 --- a/cubed/array_api/manipulation_functions.py +++ b/cubed/array_api/manipulation_functions.py @@ -287,6 +287,47 @@ def _reshape_chunk(x, template): return nxp.reshape(x, template.shape) +def roll(x, /, shift, *, axis=None): + # based on dask roll + result = x + + if axis is None: + result = flatten(result) + + if not isinstance(shift, int): + raise TypeError("Expect `shift` to be an int when `axis` is None.") + + shift = (shift,) + axis = (0,) + else: + if not isinstance(shift, tuple): + shift = (shift,) + if not isinstance(axis, tuple): + axis = (axis,) + + if len(shift) != len(axis): + raise ValueError("Must have the same number of shifts as axes.") + + for i, s in zip(axis, shift): + shape = result.shape[i] + s = 0 if shape == 0 else -s % shape + + sl1 = result.ndim * [slice(None)] + sl2 = result.ndim * [slice(None)] + + sl1[i] = slice(s, None) + sl2[i] = slice(None, s) + + sl1 = tuple(sl1) + sl2 = tuple(sl2) + + # note we want the concatenated array to have the same chunking as input, + # not the chunking of result[sl1], which may be different + result = concat([result[sl1], result[sl2]], axis=i, chunks=result.chunks) + + return reshape(result, x.shape) + + def stack(arrays, /, *, axis=0): if not arrays: raise ValueError("Need array(s) to stack") diff --git a/cubed/tests/test_array_api.py b/cubed/tests/test_array_api.py index 8a62bd9c..e54c4843 100644 --- a/cubed/tests/test_array_api.py +++ b/cubed/tests/test_array_api.py @@ -507,6 +507,37 @@ def test_reshape_chunks_with_smaller_end_chunk(spec, executor): ) +def _maybe_len(a): + try: + return len(a) + except TypeError: + return 0 + + +@pytest.mark.parametrize( + "chunks, shift, axis", + [ + ((2, 6), 3, None), + ((2, 6), 3, 0), + ((2, 6), (3, 9), (0, 1)), + ((2, 6), (3, 9), None), + ((2, 6), (3, 9), 1), + ], +) +def test_roll(spec, executor, chunks, shift, axis): + x = np.arange(4 * 6).reshape((4, 6)) + a = cubed.from_array(x, chunks=chunks, spec=spec) + + if _maybe_len(shift) != _maybe_len(axis): + with pytest.raises(TypeError if axis is None else ValueError): + xp.roll(a, shift, axis=axis) + else: + assert_array_equal( + xp.roll(a, shift, axis=axis).compute(executor=executor), + np.roll(x, shift, axis), + ) + + def test_squeeze_1d(spec, executor): a = xp.asarray([[1, 2, 3]], chunks=(1, 2), spec=spec) b = xp.squeeze(a, 0)