Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,23 @@

import dpnp


def _get_unwrapped_index_key(key):
"""
Return a key where each nested instance of DPNP array is unwrapped into USM ndarray
for futher processing in DPCTL advanced indexing functions.

"""

if isinstance(key, tuple):
if any(isinstance(x, dpnp_array) for x in key):
# create a new tuple from the input key with unwrapped DPNP arrays
return tuple(x.get_array() if isinstance(x, dpnp_array) else x for x in key)
elif isinstance(key, dpnp_array):
return key.get_array()
return key


class dpnp_array:
"""
Multi-dimensional array object.
Expand Down Expand Up @@ -176,8 +193,7 @@ def __ge__(self, other):
# '__getattribute__',

def __getitem__(self, key):
if isinstance(key, dpnp_array):
key = key.get_array()
key = _get_unwrapped_index_key(key)

item = self._array_obj.__getitem__(key)
if not isinstance(item, dpt.usm_ndarray):
Expand Down Expand Up @@ -340,8 +356,8 @@ def __rxor__(self, other):
# '__setattr__',

def __setitem__(self, key, val):
if isinstance(key, dpnp_array):
key = key.get_array()
key = _get_unwrapped_index_key(key)

if isinstance(val, dpnp_array):
val = val.get_array()

Expand Down
62 changes: 61 additions & 1 deletion tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,70 @@

import numpy
from numpy.testing import (
assert_array_equal
assert_,
assert_array_equal,
assert_equal
)


class TestIndexing:
def test_ellipsis_index(self):
a = dpnp.array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
assert_(a[...] is not a)
assert_equal(a[...], a)

# test that slicing with ellipsis doesn't skip an arbitrary number of dimensions
assert_equal(a[0, ...], a[0])
assert_equal(a[0, ...], a[0,:])
assert_equal(a[..., 0], a[:, 0])

# test that slicing with ellipsis always results in an array
assert_equal(a[0, ..., 1], dpnp.array(2))

# assignment with `(Ellipsis,)` on 0-d arrays
b = dpnp.array(1)
b[(Ellipsis,)] = 2
assert_equal(b, 2)

def test_boolean_indexing_list(self):
a = dpnp.array([1, 2, 3])
b = dpnp.array([True, False, True])

assert_equal(a[b], [1, 3])
assert_equal(a[None, b], [[1, 3]])

def test_indexing_array_weird_strides(self):
np_x = numpy.ones(10)
dp_x = dpnp.ones(10)

np_ind = numpy.arange(10)[:, None, None, None]
np_ind = numpy.broadcast_to(np_ind, (10, 55, 4, 4))

dp_ind = dpnp.arange(10)[:, None, None, None]
dp_ind = dpnp.broadcast_to(dp_ind, (10, 55, 4, 4))

# single advanced index case
assert_array_equal(dp_x[dp_ind], np_x[np_ind])

np_x2 = numpy.ones((10, 2))
dp_x2 = dpnp.ones((10, 2))

np_zind = numpy.zeros(4, dtype=np_ind.dtype)
dp_zind = dpnp.zeros(4, dtype=dp_ind.dtype)

# higher dimensional advanced index
assert_array_equal(dp_x2[dp_ind, dp_zind], np_x2[np_ind, np_zind])

def test_indexing_array_negative_strides(self):
arr = dpnp.zeros((4, 4))[::-1, ::-1]

slices = (slice(None), dpnp.array([0, 1, 2, 3]))
arr[slices] = 10
assert_array_equal(arr, 10.)


@pytest.mark.usefixtures("allow_fall_back_on_numpy")
def test_choose():
a = numpy.r_[:4]
Expand Down