From 5b082f81b7a0ce921490297b2dc90199544f468c Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Mon, 3 Apr 2023 06:51:49 -0500 Subject: [PATCH] Implement support of tuple key in __getitem__ and __setitem__ --- dpnp/dpnp_array.py | 24 +++++++++++++--- tests/test_indexing.py | 62 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 81 insertions(+), 5 deletions(-) diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py index 70ba6f44580..5d28697594f 100644 --- a/dpnp/dpnp_array.py +++ b/dpnp/dpnp_array.py @@ -29,6 +29,23 @@ import dpnp + +def _get_unwrapped_index_key(key): + """ + Return a key where each nested instance of DPNP array is unwrapped into USM ndarray + for futher processing in DPCTL advanced indexing functions. + + """ + + if isinstance(key, tuple): + if any(isinstance(x, dpnp_array) for x in key): + # create a new tuple from the input key with unwrapped DPNP arrays + return tuple(x.get_array() if isinstance(x, dpnp_array) else x for x in key) + elif isinstance(key, dpnp_array): + return key.get_array() + return key + + class dpnp_array: """ Multi-dimensional array object. @@ -176,8 +193,7 @@ def __ge__(self, other): # '__getattribute__', def __getitem__(self, key): - if isinstance(key, dpnp_array): - key = key.get_array() + key = _get_unwrapped_index_key(key) item = self._array_obj.__getitem__(key) if not isinstance(item, dpt.usm_ndarray): @@ -337,8 +353,8 @@ def __rxor__(self, other): # '__setattr__', def __setitem__(self, key, val): - if isinstance(key, dpnp_array): - key = key.get_array() + key = _get_unwrapped_index_key(key) + if isinstance(val, dpnp_array): val = val.get_array() diff --git a/tests/test_indexing.py b/tests/test_indexing.py index 41128fd70e2..fb49d8c8749 100644 --- a/tests/test_indexing.py +++ b/tests/test_indexing.py @@ -6,10 +6,70 @@ import numpy from numpy.testing import ( - assert_array_equal + assert_, + assert_array_equal, + assert_equal ) +class TestIndexing: + def test_ellipsis_index(self): + a = dpnp.array([[1, 2, 3], + [4, 5, 6], + [7, 8, 9]]) + assert_(a[...] is not a) + assert_equal(a[...], a) + + # test that slicing with ellipsis doesn't skip an arbitrary number of dimensions + assert_equal(a[0, ...], a[0]) + assert_equal(a[0, ...], a[0,:]) + assert_equal(a[..., 0], a[:, 0]) + + # test that slicing with ellipsis always results in an array + assert_equal(a[0, ..., 1], dpnp.array(2)) + + # assignment with `(Ellipsis,)` on 0-d arrays + b = dpnp.array(1) + b[(Ellipsis,)] = 2 + assert_equal(b, 2) + + def test_boolean_indexing_list(self): + a = dpnp.array([1, 2, 3]) + b = dpnp.array([True, False, True]) + + assert_equal(a[b], [1, 3]) + assert_equal(a[None, b], [[1, 3]]) + + def test_indexing_array_weird_strides(self): + np_x = numpy.ones(10) + dp_x = dpnp.ones(10) + + np_ind = numpy.arange(10)[:, None, None, None] + np_ind = numpy.broadcast_to(np_ind, (10, 55, 4, 4)) + + dp_ind = dpnp.arange(10)[:, None, None, None] + dp_ind = dpnp.broadcast_to(dp_ind, (10, 55, 4, 4)) + + # single advanced index case + assert_array_equal(dp_x[dp_ind], np_x[np_ind]) + + np_x2 = numpy.ones((10, 2)) + dp_x2 = dpnp.ones((10, 2)) + + np_zind = numpy.zeros(4, dtype=np_ind.dtype) + dp_zind = dpnp.zeros(4, dtype=dp_ind.dtype) + + # higher dimensional advanced index + assert_array_equal(dp_x2[dp_ind, dp_zind], np_x2[np_ind, np_zind]) + + def test_indexing_array_negative_strides(self): + arr = dpnp.zeros((4, 4))[::-1, ::-1] + + slices = (slice(None), dpnp.array([0, 1, 2, 3])) + arr[slices] = 10 + assert_array_equal(arr, 10.) + + @pytest.mark.usefixtures("allow_fall_back_on_numpy") def test_choose(): a = numpy.r_[:4]