Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement support of tuple key in __getitem__ and __setitem__ #1362

Merged
merged 2 commits into from
Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,23 @@

import dpnp


def _get_unwrapped_index_key(key):
"""
Return a key where each nested instance of DPNP array is unwrapped into USM ndarray
for futher processing in DPCTL advanced indexing functions.

"""

if isinstance(key, tuple):
if any(isinstance(x, dpnp_array) for x in key):
# create a new tuple from the input key with unwrapped DPNP arrays
return tuple(x.get_array() if isinstance(x, dpnp_array) else x for x in key)
elif isinstance(key, dpnp_array):
return key.get_array()
return key


class dpnp_array:
"""
Multi-dimensional array object.
Expand Down Expand Up @@ -176,8 +193,7 @@ def __ge__(self, other):
# '__getattribute__',

def __getitem__(self, key):
if isinstance(key, dpnp_array):
key = key.get_array()
key = _get_unwrapped_index_key(key)

item = self._array_obj.__getitem__(key)
if not isinstance(item, dpt.usm_ndarray):
Expand Down Expand Up @@ -340,8 +356,8 @@ def __rxor__(self, other):
# '__setattr__',

def __setitem__(self, key, val):
if isinstance(key, dpnp_array):
key = key.get_array()
key = _get_unwrapped_index_key(key)

if isinstance(val, dpnp_array):
val = val.get_array()

Expand Down
62 changes: 61 additions & 1 deletion tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,70 @@

import numpy
from numpy.testing import (
assert_array_equal
assert_,
assert_array_equal,
assert_equal
)


class TestIndexing:
def test_ellipsis_index(self):
a = dpnp.array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
assert_(a[...] is not a)
assert_equal(a[...], a)

# test that slicing with ellipsis doesn't skip an arbitrary number of dimensions
assert_equal(a[0, ...], a[0])
assert_equal(a[0, ...], a[0,:])
assert_equal(a[..., 0], a[:, 0])

# test that slicing with ellipsis always results in an array
assert_equal(a[0, ..., 1], dpnp.array(2))

# assignment with `(Ellipsis,)` on 0-d arrays
b = dpnp.array(1)
b[(Ellipsis,)] = 2
assert_equal(b, 2)

def test_boolean_indexing_list(self):
a = dpnp.array([1, 2, 3])
b = dpnp.array([True, False, True])

assert_equal(a[b], [1, 3])
assert_equal(a[None, b], [[1, 3]])

def test_indexing_array_weird_strides(self):
np_x = numpy.ones(10)
dp_x = dpnp.ones(10)

np_ind = numpy.arange(10)[:, None, None, None]
np_ind = numpy.broadcast_to(np_ind, (10, 55, 4, 4))

dp_ind = dpnp.arange(10)[:, None, None, None]
dp_ind = dpnp.broadcast_to(dp_ind, (10, 55, 4, 4))

# single advanced index case
assert_array_equal(dp_x[dp_ind], np_x[np_ind])

np_x2 = numpy.ones((10, 2))
dp_x2 = dpnp.ones((10, 2))

np_zind = numpy.zeros(4, dtype=np_ind.dtype)
dp_zind = dpnp.zeros(4, dtype=dp_ind.dtype)

# higher dimensional advanced index
assert_array_equal(dp_x2[dp_ind, dp_zind], np_x2[np_ind, np_zind])

def test_indexing_array_negative_strides(self):
arr = dpnp.zeros((4, 4))[::-1, ::-1]

slices = (slice(None), dpnp.array([0, 1, 2, 3]))
arr[slices] = 10
assert_array_equal(arr, 10.)


@pytest.mark.usefixtures("allow_fall_back_on_numpy")
def test_choose():
a = numpy.r_[:4]
Expand Down