Skip to content

Commit ce26cf0

Browse files
antonwolfyvtavana
andauthored
Update dpnp.take implementation to get rid of limitations for input arguments (#1909)
* Remove limitations from dpnp.take implementation * Add more test to cover specail cases and increase code coverage * Applied pre-commit hook * Corrected test_over_index * Update docsctrings with resolving typos * Use dpnp.reshape() to change shape and create dpnp array from usm_ndarray result * Remove data syncronization from dpnp.get_result_array() * Update dpnp/dpnp_iface_indexing.py Co-authored-by: vtavana <120411540+vtavana@users.noreply.github.com> * Applied review comments --------- Co-authored-by: vtavana <120411540+vtavana@users.noreply.github.com>
1 parent eacaa5d commit ce26cf0

File tree

5 files changed

+857
-95
lines changed

5 files changed

+857
-95
lines changed

dpnp/dpnp_array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1399,7 +1399,7 @@ def swapaxes(self, axis1, axis2):
13991399

14001400
return dpnp.swapaxes(self, axis1=axis1, axis2=axis2)
14011401

1402-
def take(self, indices, /, *, axis=None, out=None, mode="wrap"):
1402+
def take(self, indices, axis=None, out=None, mode="wrap"):
14031403
"""
14041404
Take elements from an array along an axis.
14051405

dpnp/dpnp_iface_indexing.py

Lines changed: 81 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
3838
"""
3939

40+
import operator
41+
4042
import dpctl.tensor as dpt
4143
import numpy
4244
from numpy.core.numeric import normalize_axis_index
@@ -1247,41 +1249,55 @@ def select(condlist, choicelist, default=0):
12471249

12481250

12491251
# pylint: disable=redefined-outer-name
1250-
def take(x, indices, /, *, axis=None, out=None, mode="wrap"):
1252+
def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
12511253
"""
12521254
Take elements from an array along an axis.
12531255
1256+
When `axis` is not ``None``, this function does the same thing as "fancy"
1257+
indexing (indexing arrays using arrays); however, it can be easier to use
1258+
if you need elements along a given axis. A call such as
1259+
``dpnp.take(a, indices, axis=3)`` is equivalent to
1260+
``a[:, :, :, indices, ...]``.
1261+
12541262
For full documentation refer to :obj:`numpy.take`.
12551263
1264+
Parameters
1265+
----------
1266+
a : {dpnp.ndarray, usm_ndarray}, (Ni..., M, Nk...)
1267+
The source array.
1268+
indices : {array_like, scalars}, (Nj...)
1269+
The indices of the values to extract.
1270+
Also allow scalars for `indices`.
1271+
axis : {None, int, bool, 0-d array of integer dtype}, optional
1272+
The axis over which to select values. By default, the flattened
1273+
input array is used.
1274+
Default: ``None``.
1275+
out : {None, dpnp.ndarray, usm_ndarray}, optional (Ni..., Nj..., Nk...)
1276+
If provided, the result will be placed in this array. It should
1277+
be of the appropriate shape and dtype.
1278+
Default: ``None``.
1279+
mode : {"wrap", "clip"}, optional
1280+
Specifies how out-of-bounds indices will be handled. Possible values
1281+
are:
1282+
1283+
- ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
1284+
negative indices.
1285+
- ``"clip"``: clips indices to (``0 <= i < n``).
1286+
1287+
Default: ``"wrap"``.
1288+
12561289
Returns
12571290
-------
1258-
out : dpnp.ndarray
1259-
An array with shape x.shape[:axis] + indices.shape + x.shape[axis + 1:]
1260-
filled with elements from `x`.
1261-
1262-
Limitations
1263-
-----------
1264-
Parameters `x` and `indices` are supported either as :class:`dpnp.ndarray`
1265-
or :class:`dpctl.tensor.usm_ndarray`.
1266-
Parameter `indices` is supported as 1-D array of integer data type.
1267-
Parameter `out` is supported only with default value.
1268-
Parameter `mode` is supported with ``wrap``, the default, and ``clip``
1269-
values.
1270-
Providing parameter `axis` is optional when `x` is a 1-D array.
1271-
Otherwise the function will be executed sequentially on CPU.
1291+
out : dpnp.ndarray, (Ni..., Nj..., Nk...)
1292+
The returned array has the same type as `a`.
12721293
12731294
See Also
12741295
--------
12751296
:obj:`dpnp.compress` : Take elements using a boolean mask.
1297+
:obj:`dpnp.ndarray.take` : Equivalent method.
12761298
:obj:`dpnp.take_along_axis` : Take elements by matching the array and
12771299
the index arrays.
12781300
1279-
Notes
1280-
-----
1281-
How out-of-bounds indices will be handled.
1282-
"wrap" - clamps indices to (-n <= i < n), then wraps negative indices.
1283-
"clip" - clips indices to (0 <= i < n)
1284-
12851301
Examples
12861302
--------
12871303
>>> import dpnp as np
@@ -1302,29 +1318,53 @@ def take(x, indices, /, *, axis=None, out=None, mode="wrap"):
13021318
>>> np.take(x, indices, mode="clip")
13031319
array([4, 4, 4, 8, 8])
13041320
1321+
If `indices` is not one dimensional, the output also has these dimensions.
1322+
1323+
>>> np.take(x, [[0, 1], [2, 3]])
1324+
array([[4, 3],
1325+
[5, 7]])
1326+
13051327
"""
13061328

1307-
if dpnp.is_supported_array_type(x) and dpnp.is_supported_array_type(
1308-
indices
1309-
):
1310-
if indices.ndim != 1 or not dpnp.issubdtype(
1311-
indices.dtype, dpnp.integer
1312-
):
1313-
pass
1314-
elif axis is None and x.ndim > 1:
1315-
pass
1316-
elif out is not None:
1317-
pass
1318-
elif mode not in ("clip", "wrap"):
1319-
pass
1320-
else:
1321-
dpt_array = dpnp.get_usm_ndarray(x)
1322-
dpt_indices = dpnp.get_usm_ndarray(indices)
1323-
return dpnp_array._create_from_usm_ndarray(
1324-
dpt.take(dpt_array, dpt_indices, axis=axis, mode=mode)
1325-
)
1329+
if mode not in ("wrap", "clip"):
1330+
raise ValueError(f"`mode` must be 'wrap' or 'clip', but got `{mode}`.")
1331+
1332+
usm_a = dpnp.get_usm_ndarray(a)
1333+
if not dpnp.is_supported_array_type(indices):
1334+
usm_ind = dpt.asarray(
1335+
indices, usm_type=a.usm_type, sycl_queue=a.sycl_queue
1336+
)
1337+
else:
1338+
usm_ind = dpnp.get_usm_ndarray(indices)
1339+
1340+
a_ndim = a.ndim
1341+
if axis is None:
1342+
res_shape = usm_ind.shape
1343+
1344+
if a_ndim > 1:
1345+
# dpt.take requires flattened input array
1346+
usm_a = dpt.reshape(usm_a, -1)
1347+
elif a_ndim == 0:
1348+
axis = normalize_axis_index(operator.index(axis), 1)
1349+
res_shape = usm_ind.shape
1350+
else:
1351+
axis = normalize_axis_index(operator.index(axis), a_ndim)
1352+
a_sh = a.shape
1353+
res_shape = a_sh[:axis] + usm_ind.shape + a_sh[axis + 1 :]
1354+
1355+
if usm_ind.ndim != 1:
1356+
# dpt.take supports only 1-D array of indices
1357+
usm_ind = dpt.reshape(usm_ind, -1)
1358+
1359+
if not dpnp.issubdtype(usm_ind.dtype, dpnp.integer):
1360+
# dpt.take supports only integer dtype for array of indices
1361+
usm_ind = dpt.astype(usm_ind, dpnp.intp, copy=False, casting="safe")
1362+
1363+
usm_res = dpt.take(usm_a, usm_ind, axis=axis, mode=mode)
13261364

1327-
return call_origin(numpy.take, x, indices, axis, out, mode)
1365+
# need to reshape the result if shape of indices array was changed
1366+
result = dpnp.reshape(usm_res, res_shape)
1367+
return dpnp.get_result_array(result, out)
13281368

13291369

13301370
def take_along_axis(a, indices, axis):

tests/test_indexing.py

Lines changed: 84 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,90 @@ def test_broadcast(self, arr_dt, idx_dt):
535535
assert_array_equal(np_a, dp_a)
536536

537537

538+
class TestTake:
539+
@pytest.mark.parametrize("a_dt", get_all_dtypes(no_none=True))
540+
@pytest.mark.parametrize("ind_dt", get_all_dtypes(no_none=True))
541+
@pytest.mark.parametrize(
542+
"indices", [[-2, 2], [-5, 4]], ids=["[-2, 2]", "[-5, 4]"]
543+
)
544+
@pytest.mark.parametrize("mode", ["clip", "wrap"])
545+
def test_1d(self, a_dt, ind_dt, indices, mode):
546+
a = numpy.array([-2, -1, 0, 1, 2], dtype=a_dt)
547+
ind = numpy.array(indices, dtype=ind_dt)
548+
ia, iind = dpnp.array(a), dpnp.array(ind)
549+
550+
if numpy.can_cast(ind_dt, numpy.intp, casting="safe"):
551+
result = dpnp.take(ia, iind, mode=mode)
552+
expected = numpy.take(a, ind, mode=mode)
553+
assert_array_equal(result, expected)
554+
else:
555+
assert_raises(TypeError, ia.take, iind, mode=mode)
556+
assert_raises(TypeError, a.take, ind, mode=mode)
557+
558+
@pytest.mark.parametrize("a_dt", get_all_dtypes(no_none=True))
559+
@pytest.mark.parametrize("ind_dt", get_integer_dtypes())
560+
@pytest.mark.parametrize(
561+
"indices", [[-1, 0], [-3, 2]], ids=["[-1, 0]", "[-3, 2]"]
562+
)
563+
@pytest.mark.parametrize("mode", ["clip", "wrap"])
564+
@pytest.mark.parametrize("axis", [0, 1], ids=["0", "1"])
565+
def test_2d(self, a_dt, ind_dt, indices, mode, axis):
566+
a = numpy.array([[-1, 0, 1], [-2, -3, -4], [2, 3, 4]], dtype=a_dt)
567+
ind = numpy.array(indices, dtype=ind_dt)
568+
ia, iind = dpnp.array(a), dpnp.array(ind)
569+
570+
result = ia.take(iind, axis=axis, mode=mode)
571+
expected = a.take(ind, axis=axis, mode=mode)
572+
assert_array_equal(result, expected)
573+
574+
@pytest.mark.parametrize("a_dt", get_all_dtypes(no_none=True))
575+
@pytest.mark.parametrize("indices", [[-5, 5]], ids=["[-5, 5]"])
576+
@pytest.mark.parametrize("mode", ["clip", "wrap"])
577+
def test_over_index(self, a_dt, indices, mode):
578+
a = dpnp.array([-2, -1, 0, 1, 2], dtype=a_dt)
579+
ind = dpnp.array(indices, dtype=numpy.intp)
580+
581+
result = dpnp.take(a, ind, mode=mode)
582+
expected = dpnp.array([-2, 2], dtype=a.dtype)
583+
assert_array_equal(result, expected)
584+
585+
@pytest.mark.parametrize("xp", [numpy, dpnp])
586+
@pytest.mark.parametrize("indices", [[0], [1]], ids=["[0]", "[1]"])
587+
@pytest.mark.parametrize("mode", ["clip", "wrap"])
588+
def test_index_error(self, xp, indices, mode):
589+
# take from a 0-length dimension
590+
a = xp.empty((2, 3, 0, 4))
591+
assert_raises(IndexError, a.take, indices, axis=2, mode=mode)
592+
593+
def test_bool_axis(self):
594+
a = numpy.array([[[1]]])
595+
ia = dpnp.array(a)
596+
597+
result = ia.take([0], axis=False)
598+
expected = a.take([0], axis=0) # numpy raises an error for bool axis
599+
assert_array_equal(result, expected)
600+
601+
def test_axis_as_array(self):
602+
a = numpy.array([[[1]]])
603+
ia = dpnp.array(a)
604+
605+
result = ia.take([0], axis=ia)
606+
expected = a.take(
607+
[0], axis=1
608+
) # numpy raises an error for axis as array
609+
assert_array_equal(result, expected)
610+
611+
def test_mode_raise(self):
612+
a = dpnp.array([[1, 2], [3, 4]])
613+
assert_raises(ValueError, a.take, [-1, 4], mode="raise")
614+
615+
@pytest.mark.parametrize("xp", [numpy, dpnp])
616+
def test_unicode_mode(self, xp):
617+
a = xp.arange(10)
618+
k = b"\xc3\xa4".decode("UTF8")
619+
assert_raises(ValueError, a.take, 5, mode=k)
620+
621+
538622
class TestTakeAlongAxis:
539623
@pytest.mark.parametrize(
540624
"func, argfunc, kwargs",
@@ -964,54 +1048,6 @@ def test_select():
9641048
assert_array_equal(expected, result)
9651049

9661050

967-
@pytest.mark.parametrize("array_type", get_all_dtypes())
968-
@pytest.mark.parametrize(
969-
"indices_type", [numpy.int32, numpy.int64], ids=["int32", "int64"]
970-
)
971-
@pytest.mark.parametrize(
972-
"indices", [[-2, 2], [-5, 4]], ids=["[-2, 2]", "[-5, 4]"]
973-
)
974-
@pytest.mark.parametrize("mode", ["clip", "wrap"], ids=["clip", "wrap"])
975-
def test_take_1d(indices, array_type, indices_type, mode):
976-
a = numpy.array([-2, -1, 0, 1, 2], dtype=array_type)
977-
ind = numpy.array(indices, dtype=indices_type)
978-
ia = dpnp.array(a)
979-
iind = dpnp.array(ind)
980-
expected = numpy.take(a, ind, mode=mode)
981-
result = dpnp.take(ia, iind, mode=mode)
982-
assert_array_equal(expected, result)
983-
984-
985-
@pytest.mark.parametrize("array_type", get_all_dtypes())
986-
@pytest.mark.parametrize(
987-
"indices_type", [numpy.int32, numpy.int64], ids=["int32", "int64"]
988-
)
989-
@pytest.mark.parametrize(
990-
"indices", [[-1, 0], [-3, 2]], ids=["[-1, 0]", "[-3, 2]"]
991-
)
992-
@pytest.mark.parametrize("mode", ["clip", "wrap"], ids=["clip", "wrap"])
993-
@pytest.mark.parametrize("axis", [0, 1], ids=["0", "1"])
994-
def test_take_2d(indices, array_type, indices_type, axis, mode):
995-
a = numpy.array([[-1, 0, 1], [-2, -3, -4], [2, 3, 4]], dtype=array_type)
996-
ind = numpy.array(indices, dtype=indices_type)
997-
ia = dpnp.array(a)
998-
iind = dpnp.array(ind)
999-
expected = numpy.take(a, ind, axis=axis, mode=mode)
1000-
result = dpnp.take(ia, iind, axis=axis, mode=mode)
1001-
assert_array_equal(expected, result)
1002-
1003-
1004-
@pytest.mark.parametrize("array_type", get_all_dtypes())
1005-
@pytest.mark.parametrize("indices", [[-5, 5]], ids=["[-5, 5]"])
1006-
@pytest.mark.parametrize("mode", ["clip", "wrap"], ids=["clip", "wrap"])
1007-
def test_take_over_index(indices, array_type, mode):
1008-
a = dpnp.array([-2, -1, 0, 1, 2], dtype=array_type)
1009-
ind = dpnp.array(indices, dtype=dpnp.int64)
1010-
expected = dpnp.array([-2, 2], dtype=a.dtype)
1011-
result = dpnp.take(a, ind, mode=mode)
1012-
assert_array_equal(expected, result)
1013-
1014-
10151051
@pytest.mark.parametrize(
10161052
"m", [None, 0, 1, 2, 3, 4], ids=["None", "0", "1", "2", "3", "4"]
10171053
)

0 commit comments

Comments
 (0)