From ce26cf0c3db037fb1662686ddda0b18484989c5f Mon Sep 17 00:00:00 2001 From: Anton <100830759+antonwolfy@users.noreply.github.com> Date: Wed, 10 Jul 2024 13:19:30 +0200 Subject: [PATCH] Update `dpnp.take` implementation to get rid of limitations for input arguments (#1909) * Remove limitations from dpnp.take implementation * Add more test to cover specail cases and increase code coverage * Applied pre-commit hook * Corrected test_over_index * Update docsctrings with resolving typos * Use dpnp.reshape() to change shape and create dpnp array from usm_ndarray result * Remove data syncronization from dpnp.get_result_array() * Update dpnp/dpnp_iface_indexing.py Co-authored-by: vtavana <120411540+vtavana@users.noreply.github.com> * Applied review comments --------- Co-authored-by: vtavana <120411540+vtavana@users.noreply.github.com> --- dpnp/dpnp_array.py | 2 +- dpnp/dpnp_iface_indexing.py | 122 ++-- tests/test_indexing.py | 132 ++-- .../cupy/core_tests/test_ndarray.py | 690 ++++++++++++++++++ .../cupy/indexing_tests/test_indexing.py | 6 +- 5 files changed, 857 insertions(+), 95 deletions(-) create mode 100644 tests/third_party/cupy/core_tests/test_ndarray.py diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py index 60adb091ed7..1738e1dfcaf 100644 --- a/dpnp/dpnp_array.py +++ b/dpnp/dpnp_array.py @@ -1399,7 +1399,7 @@ def swapaxes(self, axis1, axis2): return dpnp.swapaxes(self, axis1=axis1, axis2=axis2) - def take(self, indices, /, *, axis=None, out=None, mode="wrap"): + def take(self, indices, axis=None, out=None, mode="wrap"): """ Take elements from an array along an axis. diff --git a/dpnp/dpnp_iface_indexing.py b/dpnp/dpnp_iface_indexing.py index 20a046c82c1..e0afa9427d4 100644 --- a/dpnp/dpnp_iface_indexing.py +++ b/dpnp/dpnp_iface_indexing.py @@ -37,6 +37,8 @@ """ +import operator + import dpctl.tensor as dpt import numpy from numpy.core.numeric import normalize_axis_index @@ -1247,41 +1249,55 @@ def select(condlist, choicelist, default=0): # pylint: disable=redefined-outer-name -def take(x, indices, /, *, axis=None, out=None, mode="wrap"): +def take(a, indices, /, *, axis=None, out=None, mode="wrap"): """ Take elements from an array along an axis. + When `axis` is not ``None``, this function does the same thing as "fancy" + indexing (indexing arrays using arrays); however, it can be easier to use + if you need elements along a given axis. A call such as + ``dpnp.take(a, indices, axis=3)`` is equivalent to + ``a[:, :, :, indices, ...]``. + For full documentation refer to :obj:`numpy.take`. + Parameters + ---------- + a : {dpnp.ndarray, usm_ndarray}, (Ni..., M, Nk...) + The source array. + indices : {array_like, scalars}, (Nj...) + The indices of the values to extract. + Also allow scalars for `indices`. + axis : {None, int, bool, 0-d array of integer dtype}, optional + The axis over which to select values. By default, the flattened + input array is used. + Default: ``None``. + out : {None, dpnp.ndarray, usm_ndarray}, optional (Ni..., Nj..., Nk...) + If provided, the result will be placed in this array. It should + be of the appropriate shape and dtype. + Default: ``None``. + mode : {"wrap", "clip"}, optional + Specifies how out-of-bounds indices will be handled. Possible values + are: + + - ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps + negative indices. + - ``"clip"``: clips indices to (``0 <= i < n``). + + Default: ``"wrap"``. + Returns ------- - out : dpnp.ndarray - An array with shape x.shape[:axis] + indices.shape + x.shape[axis + 1:] - filled with elements from `x`. - - Limitations - ----------- - Parameters `x` and `indices` are supported either as :class:`dpnp.ndarray` - or :class:`dpctl.tensor.usm_ndarray`. - Parameter `indices` is supported as 1-D array of integer data type. - Parameter `out` is supported only with default value. - Parameter `mode` is supported with ``wrap``, the default, and ``clip`` - values. - Providing parameter `axis` is optional when `x` is a 1-D array. - Otherwise the function will be executed sequentially on CPU. + out : dpnp.ndarray, (Ni..., Nj..., Nk...) + The returned array has the same type as `a`. See Also -------- :obj:`dpnp.compress` : Take elements using a boolean mask. + :obj:`dpnp.ndarray.take` : Equivalent method. :obj:`dpnp.take_along_axis` : Take elements by matching the array and the index arrays. - Notes - ----- - How out-of-bounds indices will be handled. - "wrap" - clamps indices to (-n <= i < n), then wraps negative indices. - "clip" - clips indices to (0 <= i < n) - Examples -------- >>> import dpnp as np @@ -1302,29 +1318,53 @@ def take(x, indices, /, *, axis=None, out=None, mode="wrap"): >>> np.take(x, indices, mode="clip") array([4, 4, 4, 8, 8]) + If `indices` is not one dimensional, the output also has these dimensions. + + >>> np.take(x, [[0, 1], [2, 3]]) + array([[4, 3], + [5, 7]]) + """ - if dpnp.is_supported_array_type(x) and dpnp.is_supported_array_type( - indices - ): - if indices.ndim != 1 or not dpnp.issubdtype( - indices.dtype, dpnp.integer - ): - pass - elif axis is None and x.ndim > 1: - pass - elif out is not None: - pass - elif mode not in ("clip", "wrap"): - pass - else: - dpt_array = dpnp.get_usm_ndarray(x) - dpt_indices = dpnp.get_usm_ndarray(indices) - return dpnp_array._create_from_usm_ndarray( - dpt.take(dpt_array, dpt_indices, axis=axis, mode=mode) - ) + if mode not in ("wrap", "clip"): + raise ValueError(f"`mode` must be 'wrap' or 'clip', but got `{mode}`.") + + usm_a = dpnp.get_usm_ndarray(a) + if not dpnp.is_supported_array_type(indices): + usm_ind = dpt.asarray( + indices, usm_type=a.usm_type, sycl_queue=a.sycl_queue + ) + else: + usm_ind = dpnp.get_usm_ndarray(indices) + + a_ndim = a.ndim + if axis is None: + res_shape = usm_ind.shape + + if a_ndim > 1: + # dpt.take requires flattened input array + usm_a = dpt.reshape(usm_a, -1) + elif a_ndim == 0: + axis = normalize_axis_index(operator.index(axis), 1) + res_shape = usm_ind.shape + else: + axis = normalize_axis_index(operator.index(axis), a_ndim) + a_sh = a.shape + res_shape = a_sh[:axis] + usm_ind.shape + a_sh[axis + 1 :] + + if usm_ind.ndim != 1: + # dpt.take supports only 1-D array of indices + usm_ind = dpt.reshape(usm_ind, -1) + + if not dpnp.issubdtype(usm_ind.dtype, dpnp.integer): + # dpt.take supports only integer dtype for array of indices + usm_ind = dpt.astype(usm_ind, dpnp.intp, copy=False, casting="safe") + + usm_res = dpt.take(usm_a, usm_ind, axis=axis, mode=mode) - return call_origin(numpy.take, x, indices, axis, out, mode) + # need to reshape the result if shape of indices array was changed + result = dpnp.reshape(usm_res, res_shape) + return dpnp.get_result_array(result, out) def take_along_axis(a, indices, axis): diff --git a/tests/test_indexing.py b/tests/test_indexing.py index 8b54bc482ce..c8e6e37d0da 100644 --- a/tests/test_indexing.py +++ b/tests/test_indexing.py @@ -535,6 +535,90 @@ def test_broadcast(self, arr_dt, idx_dt): assert_array_equal(np_a, dp_a) +class TestTake: + @pytest.mark.parametrize("a_dt", get_all_dtypes(no_none=True)) + @pytest.mark.parametrize("ind_dt", get_all_dtypes(no_none=True)) + @pytest.mark.parametrize( + "indices", [[-2, 2], [-5, 4]], ids=["[-2, 2]", "[-5, 4]"] + ) + @pytest.mark.parametrize("mode", ["clip", "wrap"]) + def test_1d(self, a_dt, ind_dt, indices, mode): + a = numpy.array([-2, -1, 0, 1, 2], dtype=a_dt) + ind = numpy.array(indices, dtype=ind_dt) + ia, iind = dpnp.array(a), dpnp.array(ind) + + if numpy.can_cast(ind_dt, numpy.intp, casting="safe"): + result = dpnp.take(ia, iind, mode=mode) + expected = numpy.take(a, ind, mode=mode) + assert_array_equal(result, expected) + else: + assert_raises(TypeError, ia.take, iind, mode=mode) + assert_raises(TypeError, a.take, ind, mode=mode) + + @pytest.mark.parametrize("a_dt", get_all_dtypes(no_none=True)) + @pytest.mark.parametrize("ind_dt", get_integer_dtypes()) + @pytest.mark.parametrize( + "indices", [[-1, 0], [-3, 2]], ids=["[-1, 0]", "[-3, 2]"] + ) + @pytest.mark.parametrize("mode", ["clip", "wrap"]) + @pytest.mark.parametrize("axis", [0, 1], ids=["0", "1"]) + def test_2d(self, a_dt, ind_dt, indices, mode, axis): + a = numpy.array([[-1, 0, 1], [-2, -3, -4], [2, 3, 4]], dtype=a_dt) + ind = numpy.array(indices, dtype=ind_dt) + ia, iind = dpnp.array(a), dpnp.array(ind) + + result = ia.take(iind, axis=axis, mode=mode) + expected = a.take(ind, axis=axis, mode=mode) + assert_array_equal(result, expected) + + @pytest.mark.parametrize("a_dt", get_all_dtypes(no_none=True)) + @pytest.mark.parametrize("indices", [[-5, 5]], ids=["[-5, 5]"]) + @pytest.mark.parametrize("mode", ["clip", "wrap"]) + def test_over_index(self, a_dt, indices, mode): + a = dpnp.array([-2, -1, 0, 1, 2], dtype=a_dt) + ind = dpnp.array(indices, dtype=numpy.intp) + + result = dpnp.take(a, ind, mode=mode) + expected = dpnp.array([-2, 2], dtype=a.dtype) + assert_array_equal(result, expected) + + @pytest.mark.parametrize("xp", [numpy, dpnp]) + @pytest.mark.parametrize("indices", [[0], [1]], ids=["[0]", "[1]"]) + @pytest.mark.parametrize("mode", ["clip", "wrap"]) + def test_index_error(self, xp, indices, mode): + # take from a 0-length dimension + a = xp.empty((2, 3, 0, 4)) + assert_raises(IndexError, a.take, indices, axis=2, mode=mode) + + def test_bool_axis(self): + a = numpy.array([[[1]]]) + ia = dpnp.array(a) + + result = ia.take([0], axis=False) + expected = a.take([0], axis=0) # numpy raises an error for bool axis + assert_array_equal(result, expected) + + def test_axis_as_array(self): + a = numpy.array([[[1]]]) + ia = dpnp.array(a) + + result = ia.take([0], axis=ia) + expected = a.take( + [0], axis=1 + ) # numpy raises an error for axis as array + assert_array_equal(result, expected) + + def test_mode_raise(self): + a = dpnp.array([[1, 2], [3, 4]]) + assert_raises(ValueError, a.take, [-1, 4], mode="raise") + + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_unicode_mode(self, xp): + a = xp.arange(10) + k = b"\xc3\xa4".decode("UTF8") + assert_raises(ValueError, a.take, 5, mode=k) + + class TestTakeAlongAxis: @pytest.mark.parametrize( "func, argfunc, kwargs", @@ -964,54 +1048,6 @@ def test_select(): assert_array_equal(expected, result) -@pytest.mark.parametrize("array_type", get_all_dtypes()) -@pytest.mark.parametrize( - "indices_type", [numpy.int32, numpy.int64], ids=["int32", "int64"] -) -@pytest.mark.parametrize( - "indices", [[-2, 2], [-5, 4]], ids=["[-2, 2]", "[-5, 4]"] -) -@pytest.mark.parametrize("mode", ["clip", "wrap"], ids=["clip", "wrap"]) -def test_take_1d(indices, array_type, indices_type, mode): - a = numpy.array([-2, -1, 0, 1, 2], dtype=array_type) - ind = numpy.array(indices, dtype=indices_type) - ia = dpnp.array(a) - iind = dpnp.array(ind) - expected = numpy.take(a, ind, mode=mode) - result = dpnp.take(ia, iind, mode=mode) - assert_array_equal(expected, result) - - -@pytest.mark.parametrize("array_type", get_all_dtypes()) -@pytest.mark.parametrize( - "indices_type", [numpy.int32, numpy.int64], ids=["int32", "int64"] -) -@pytest.mark.parametrize( - "indices", [[-1, 0], [-3, 2]], ids=["[-1, 0]", "[-3, 2]"] -) -@pytest.mark.parametrize("mode", ["clip", "wrap"], ids=["clip", "wrap"]) -@pytest.mark.parametrize("axis", [0, 1], ids=["0", "1"]) -def test_take_2d(indices, array_type, indices_type, axis, mode): - a = numpy.array([[-1, 0, 1], [-2, -3, -4], [2, 3, 4]], dtype=array_type) - ind = numpy.array(indices, dtype=indices_type) - ia = dpnp.array(a) - iind = dpnp.array(ind) - expected = numpy.take(a, ind, axis=axis, mode=mode) - result = dpnp.take(ia, iind, axis=axis, mode=mode) - assert_array_equal(expected, result) - - -@pytest.mark.parametrize("array_type", get_all_dtypes()) -@pytest.mark.parametrize("indices", [[-5, 5]], ids=["[-5, 5]"]) -@pytest.mark.parametrize("mode", ["clip", "wrap"], ids=["clip", "wrap"]) -def test_take_over_index(indices, array_type, mode): - a = dpnp.array([-2, -1, 0, 1, 2], dtype=array_type) - ind = dpnp.array(indices, dtype=dpnp.int64) - expected = dpnp.array([-2, 2], dtype=a.dtype) - result = dpnp.take(a, ind, mode=mode) - assert_array_equal(expected, result) - - @pytest.mark.parametrize( "m", [None, 0, 1, 2, 3, 4], ids=["None", "0", "1", "2", "3", "4"] ) diff --git a/tests/third_party/cupy/core_tests/test_ndarray.py b/tests/third_party/cupy/core_tests/test_ndarray.py new file mode 100644 index 00000000000..4d39767daa6 --- /dev/null +++ b/tests/third_party/cupy/core_tests/test_ndarray.py @@ -0,0 +1,690 @@ +import copy +import unittest + +import dpctl +import numpy +import pytest + +import dpnp as cupy +from tests.third_party.cupy import testing + + +def get_array_module(*args): + for arg in args: + if isinstance(arg, cupy.ndarray): + return cupy + return numpy + + +def wrap_take(array, *args, **kwargs): + if get_array_module(array) == numpy: + kwargs["mode"] = "wrap" + + return array.take(*args, **kwargs) + + +class TestNdarrayInit(unittest.TestCase): + @pytest.mark.skip("passing 'None' into shape arguments is not supported") + def test_shape_none(self): + with testing.assert_warns(DeprecationWarning): + a = cupy.ndarray(None) + assert a.shape == () + + def test_shape_int(self): + a = cupy.ndarray(3) + assert a.shape == (3,) + + def test_shape_not_integer(self): + for xp in (numpy, cupy): + with pytest.raises(TypeError): + xp.ndarray(1.0) + with pytest.raises(TypeError): + xp.ndarray((1.0,)) + + @pytest.mark.skip("passing buffer as dpnp array is not supported") + def test_shape_int_with_strides(self): + dummy = cupy.ndarray(3) + a = cupy.ndarray(3, strides=(0,), buffer=dummy) + assert a.shape == (3,) + assert a.strides == (0,) + + @pytest.mark.skip("passing buffer as dpnp array is not supported") + def test_memptr(self): + a = cupy.arange(6).astype(numpy.float32).reshape((2, 3)) + memptr = a + + b = cupy.ndarray((2, 3), numpy.float32, buffer=memptr) + testing.assert_array_equal(a, b) + + b += 1 + testing.assert_array_equal(a, b) + + @pytest.mark.skip("self-overlapping strides are not supported") + def test_memptr_with_strides(self): + buf = cupy.ndarray(20, numpy.uint8) + memptr = buf + + # self-overlapping strides + a = cupy.ndarray((2, 3), numpy.float32, buffer=memptr, strides=(8, 4)) + assert a.strides == (8, 4) + + a[:] = 1 + a[0, 2] = 4 + assert float(a[1, 0]) == 4 + + @pytest.mark.skip("no exception raised by dpctl") + def test_strides_without_memptr(self): + for xp in (numpy, cupy): + with pytest.raises(ValueError): + xp.ndarray((2, 3), numpy.float32, strides=(20, 4)) + + @pytest.mark.skip("passing buffer as dpnp array is not supported") + def test_strides_is_given_and_order_is_ignored(self): + buf = cupy.ndarray(20, numpy.uint8) + a = cupy.ndarray((2, 3), numpy.float32, buf, strides=(2, 1), order="C") + assert a.strides == (2, 1) + + @pytest.mark.skip("dpctl-1724 issue") + @testing.with_requires("numpy>=1.19") + def test_strides_is_given_but_order_is_invalid(self): + for xp in (numpy, cupy): + with pytest.raises(ValueError): + xp.ndarray((2, 3), numpy.float32, strides=(2, 1), order="!") + + def test_order(self): + shape = (2, 3, 4) + a = cupy.ndarray(shape, order="F") + a_cpu = numpy.ndarray(shape, order="F", dtype=a.dtype) + assert all( + i * a.itemsize == j for i, j in zip(a.strides, a_cpu.strides) + ) + assert a.flags.f_contiguous + assert not a.flags.c_contiguous + + @pytest.mark.skip("passing 'None' into order arguments is not supported") + def test_order_none(self): + shape = (2, 3, 4) + a = cupy.ndarray(shape, order=None) + a_cpu = numpy.ndarray(shape, order=None, dtype=a.dtype) + assert a.flags.c_contiguous == a_cpu.flags.c_contiguous + assert a.flags.f_contiguous == a_cpu.flags.f_contiguous + assert all( + i * a.itemsize == j for i, j in zip(a.strides, a_cpu.strides) + ) + + @pytest.mark.skip("__slots__ is not supported") + def test_slots(self): + # Test for #7883. + a = cupy.ndarray((2, 3)) + with pytest.raises(AttributeError): + a.custom_attr = 100 + + class UserNdarray(cupy.ndarray): + pass + + b = UserNdarray((2, 3)) + b.custom_attr = 100 + + +@testing.parameterize( + *testing.product( + { + "shape": [(), (1,), (1, 2), (1, 2, 3)], + "order": ["C", "F"], + "dtype": [ + numpy.uint8, # itemsize=1 + numpy.uint16, # itemsize=2 + ], + } + ) +) +@pytest.mark.skip("strides may vary") +class TestNdarrayInitStrides(unittest.TestCase): + # Check the strides given shape, itemsize and order. + @testing.numpy_cupy_equal() + def test_strides(self, xp): + arr = xp.ndarray(self.shape, dtype=self.dtype, order=self.order) + return (arr.strides, arr.flags.c_contiguous, arr.flags.f_contiguous) + + +class TestNdarrayInitRaise(unittest.TestCase): + def test_unsupported_type(self): + arr = numpy.ndarray((2, 3), dtype=object) + with pytest.raises(TypeError): + cupy.array(arr) + + @pytest.mark.skip("no ndim limit") + def test_excessive_ndim(self): + for xp in (numpy, cupy): + with pytest.raises(ValueError): + xp.ndarray(shape=[1 for i in range(33)], dtype=xp.int8) + + +@testing.parameterize( + *testing.product( + { + "shape": [(), (0,), (1,), (0, 0, 2), (2, 3)], + } + ) +) +@pytest.mark.skip("deepcopy() is not supported") +class TestNdarrayDeepCopy(unittest.TestCase): + def _check_deepcopy(self, arr, arr2): + assert arr.data is not arr2.data + assert arr.shape == arr2.shape + assert arr.size == arr2.size + assert arr.dtype == arr2.dtype + assert arr.strides == arr2.strides + testing.assert_array_equal(arr, arr2) + + def test_deepcopy(self): + arr = _core.ndarray(self.shape) + arr2 = copy.deepcopy(arr) + self._check_deepcopy(arr, arr2) + + @testing.multi_gpu(2) + def test_deepcopy_multi_device(self): + arr = _core.ndarray(self.shape) + with cuda.Device(1): + arr2 = copy.deepcopy(arr) + self._check_deepcopy(arr, arr2) + assert arr2.device == arr.device + + +_test_copy_multi_device_with_stream_src = r""" +extern "C" __global__ +void wait_and_write(long long *x) { + clock_t start = clock(); + clock_t now; + for (;;) { + now = clock(); + clock_t cycles = now > start ? now - start : now + (0xffffffff - start); + if (cycles >= 1000000000) { + break; + } + } + x[0] = 1; + x[1] = now; // in case the compiler optimizing away the entire loop +} +""" + + +@pytest.mark.skip() +class TestNdarrayCopy: + @testing.multi_gpu(2) + @testing.for_orders("CFA") + def test_copy_multi_device_non_contiguous(self, order): + arr = cupy.ndarray((20,))[::2] + dev1 = dpctl.SyclDevice() + arr2 = arr.copy(order, device=dev1) + assert arr2.device == dev1 + testing.assert_array_equal(arr, arr2) + + @testing.multi_gpu(2) + def test_copy_multi_device_non_contiguous_K(self): + arr = _core.ndarray((20,))[::2] + with cuda.Device(1): + with pytest.raises(NotImplementedError): + arr.copy("K") + + # See cupy/cupy#5004 + @testing.multi_gpu(2) + def test_copy_multi_device_with_stream(self): + # Kernel that takes long enough then finally writes values. + src = _test_copy_multi_device_with_stream_src + if runtime.is_hip and driver.get_build_version() >= 5_00_00000: + src = "#include \n" + src + kern = cupy.RawKernel(src, "wait_and_write") + + # Allocates a memory and launches the kernel on a device with its + # stream. + with cuda.Device(0): + # Keep this stream alive over the D2D copy below for HIP + with cuda.Stream() as s1: + a = cupy.zeros((2,), dtype=numpy.uint64) + kern((1,), (1,), a) + + # D2D copy to another device with another stream should get the + # original values of the memory before the kernel on the first device + # finally makes the write. + with cuda.Device(1): + with cuda.Stream(): + b = a.copy() + testing.assert_array_equal( + b, numpy.array([0, 0], dtype=numpy.uint64) + ) + + +@pytest.mark.skip() +class TestNdarrayShape(unittest.TestCase): + @testing.numpy_cupy_array_equal() + def test_shape_set(self, xp): + arr = xp.ndarray((2, 3)) + arr.shape = (3, 2) + return xp.array(arr.shape) + + @testing.numpy_cupy_array_equal() + def test_shape_set_infer(self, xp): + arr = xp.ndarray((2, 3)) + arr.shape = (3, -1) + return xp.array(arr.shape) + + @testing.numpy_cupy_array_equal() + def test_shape_set_int(self, xp): + arr = xp.ndarray((2, 3)) + arr.shape = 6 + return xp.array(arr.shape) + + def test_shape_need_copy(self): + # from cupy/cupy#5470 + for xp in (numpy, cupy): + arr = xp.ndarray((2, 3), order="F") + with pytest.raises(AttributeError) as e: + arr.shape = (3, 2) + assert "incompatible shape" in str(e.value).lower() + + +@pytest.mark.skip("CUDA interface is not supported") +class TestNdarrayCudaInterface(unittest.TestCase): + def test_cuda_array_interface(self): + arr = cupy.zeros(shape=(2, 3), dtype=cupy.float64) + iface = arr.__cuda_array_interface__ + assert iface["version"] == 3 + assert set(iface.keys()) == set( + [ + "shape", + "typestr", + "data", + "version", + "descr", + "stream", + "strides", + ] + ) + assert iface["shape"] == (2, 3) + assert iface["typestr"] == "