Skip to content

Commit

Permalink
Update dpnp.take implementation to get rid of limitations for input…
Browse files Browse the repository at this point in the history
… arguments (#1909)

* Remove limitations from dpnp.take implementation

* Add more tests to cover special cases and increase code coverage

* Applied pre-commit hook

* Corrected test_over_index

* Update docstrings and resolve typos

* Use dpnp.reshape() to change shape and create dpnp array from usm_ndarray result

* Remove data synchronization from dpnp.get_result_array()

* Update dpnp/dpnp_iface_indexing.py

Co-authored-by: vtavana <120411540+vtavana@users.noreply.github.com>

* Applied review comments

---------

Co-authored-by: vtavana <120411540+vtavana@users.noreply.github.com>
  • Loading branch information
antonwolfy and vtavana authored Jul 10, 2024
1 parent eacaa5d commit ce26cf0
Show file tree
Hide file tree
Showing 5 changed files with 857 additions and 95 deletions.
2 changes: 1 addition & 1 deletion dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1399,7 +1399,7 @@ def swapaxes(self, axis1, axis2):

return dpnp.swapaxes(self, axis1=axis1, axis2=axis2)

def take(self, indices, /, *, axis=None, out=None, mode="wrap"):
def take(self, indices, axis=None, out=None, mode="wrap"):
"""
Take elements from an array along an axis.
Expand Down
122 changes: 81 additions & 41 deletions dpnp/dpnp_iface_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
"""

import operator

import dpctl.tensor as dpt
import numpy
from numpy.core.numeric import normalize_axis_index
Expand Down Expand Up @@ -1247,41 +1249,55 @@ def select(condlist, choicelist, default=0):


# pylint: disable=redefined-outer-name
def take(x, indices, /, *, axis=None, out=None, mode="wrap"):
def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
"""
Take elements from an array along an axis.
When `axis` is not ``None``, this function does the same thing as "fancy"
indexing (indexing arrays using arrays); however, it can be easier to use
if you need elements along a given axis. A call such as
``dpnp.take(a, indices, axis=3)`` is equivalent to
``a[:, :, :, indices, ...]``.
For full documentation refer to :obj:`numpy.take`.
Parameters
----------
a : {dpnp.ndarray, usm_ndarray}, (Ni..., M, Nk...)
The source array.
indices : {array_like, scalars}, (Nj...)
The indices of the values to extract.
Also allow scalars for `indices`.
axis : {None, int, bool, 0-d array of integer dtype}, optional
The axis over which to select values. By default, the flattened
input array is used.
Default: ``None``.
out : {None, dpnp.ndarray, usm_ndarray}, optional (Ni..., Nj..., Nk...)
If provided, the result will be placed in this array. It should
be of the appropriate shape and dtype.
Default: ``None``.
mode : {"wrap", "clip"}, optional
Specifies how out-of-bounds indices will be handled. Possible values
are:
- ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
negative indices.
- ``"clip"``: clips indices to (``0 <= i < n``).
Default: ``"wrap"``.
Returns
-------
out : dpnp.ndarray
An array with shape x.shape[:axis] + indices.shape + x.shape[axis + 1:]
filled with elements from `x`.
Limitations
-----------
Parameters `x` and `indices` are supported either as :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`.
Parameter `indices` is supported as 1-D array of integer data type.
Parameter `out` is supported only with default value.
Parameter `mode` is supported with ``wrap``, the default, and ``clip``
values.
Providing parameter `axis` is optional when `x` is a 1-D array.
Otherwise the function will be executed sequentially on CPU.
out : dpnp.ndarray, (Ni..., Nj..., Nk...)
The returned array has the same type as `a`.
See Also
--------
:obj:`dpnp.compress` : Take elements using a boolean mask.
:obj:`dpnp.ndarray.take` : Equivalent method.
:obj:`dpnp.take_along_axis` : Take elements by matching the array and
the index arrays.
Notes
-----
How out-of-bounds indices will be handled.
"wrap" - clamps indices to (-n <= i < n), then wraps negative indices.
"clip" - clips indices to (0 <= i < n)
Examples
--------
>>> import dpnp as np
Expand All @@ -1302,29 +1318,53 @@ def take(x, indices, /, *, axis=None, out=None, mode="wrap"):
>>> np.take(x, indices, mode="clip")
array([4, 4, 4, 8, 8])
If `indices` is not one dimensional, the output also has these dimensions.
>>> np.take(x, [[0, 1], [2, 3]])
array([[4, 3],
[5, 7]])
"""

if dpnp.is_supported_array_type(x) and dpnp.is_supported_array_type(
indices
):
if indices.ndim != 1 or not dpnp.issubdtype(
indices.dtype, dpnp.integer
):
pass
elif axis is None and x.ndim > 1:
pass
elif out is not None:
pass
elif mode not in ("clip", "wrap"):
pass
else:
dpt_array = dpnp.get_usm_ndarray(x)
dpt_indices = dpnp.get_usm_ndarray(indices)
return dpnp_array._create_from_usm_ndarray(
dpt.take(dpt_array, dpt_indices, axis=axis, mode=mode)
)
if mode not in ("wrap", "clip"):
raise ValueError(f"`mode` must be 'wrap' or 'clip', but got `{mode}`.")

usm_a = dpnp.get_usm_ndarray(a)
if not dpnp.is_supported_array_type(indices):
usm_ind = dpt.asarray(
indices, usm_type=a.usm_type, sycl_queue=a.sycl_queue
)
else:
usm_ind = dpnp.get_usm_ndarray(indices)

a_ndim = a.ndim
if axis is None:
res_shape = usm_ind.shape

if a_ndim > 1:
# dpt.take requires flattened input array
usm_a = dpt.reshape(usm_a, -1)
elif a_ndim == 0:
axis = normalize_axis_index(operator.index(axis), 1)
res_shape = usm_ind.shape
else:
axis = normalize_axis_index(operator.index(axis), a_ndim)
a_sh = a.shape
res_shape = a_sh[:axis] + usm_ind.shape + a_sh[axis + 1 :]

if usm_ind.ndim != 1:
# dpt.take supports only 1-D array of indices
usm_ind = dpt.reshape(usm_ind, -1)

if not dpnp.issubdtype(usm_ind.dtype, dpnp.integer):
# dpt.take supports only integer dtype for array of indices
usm_ind = dpt.astype(usm_ind, dpnp.intp, copy=False, casting="safe")

usm_res = dpt.take(usm_a, usm_ind, axis=axis, mode=mode)

return call_origin(numpy.take, x, indices, axis, out, mode)
# need to reshape the result if shape of indices array was changed
result = dpnp.reshape(usm_res, res_shape)
return dpnp.get_result_array(result, out)


def take_along_axis(a, indices, axis):
Expand Down
132 changes: 84 additions & 48 deletions tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,90 @@ def test_broadcast(self, arr_dt, idx_dt):
assert_array_equal(np_a, dp_a)


class TestTake:
    """Tests of ``dpnp.take`` and ``dpnp.ndarray.take`` against NumPy."""

    @pytest.mark.parametrize("a_dt", get_all_dtypes(no_none=True))
    @pytest.mark.parametrize("ind_dt", get_all_dtypes(no_none=True))
    @pytest.mark.parametrize(
        "indices", [[-2, 2], [-5, 4]], ids=["[-2, 2]", "[-5, 4]"]
    )
    @pytest.mark.parametrize("mode", ["clip", "wrap"])
    def test_1d(self, a_dt, ind_dt, indices, mode):
        """Take from a 1-D array for every source/index dtype pairing."""
        np_src = numpy.array([-2, -1, 0, 1, 2], dtype=a_dt)
        np_idx = numpy.array(indices, dtype=ind_dt)
        dp_src = dpnp.array(np_src)
        dp_idx = dpnp.array(np_idx)

        if not numpy.can_cast(ind_dt, numpy.intp, casting="safe"):
            # an index dtype that is not safely castable to intp is expected
            # to be rejected by both libraries
            assert_raises(TypeError, dp_src.take, dp_idx, mode=mode)
            assert_raises(TypeError, np_src.take, np_idx, mode=mode)
            return

        result = dpnp.take(dp_src, dp_idx, mode=mode)
        expected = numpy.take(np_src, np_idx, mode=mode)
        assert_array_equal(result, expected)

    @pytest.mark.parametrize("a_dt", get_all_dtypes(no_none=True))
    @pytest.mark.parametrize("ind_dt", get_integer_dtypes())
    @pytest.mark.parametrize(
        "indices", [[-1, 0], [-3, 2]], ids=["[-1, 0]", "[-3, 2]"]
    )
    @pytest.mark.parametrize("mode", ["clip", "wrap"])
    @pytest.mark.parametrize("axis", [0, 1], ids=["0", "1"])
    def test_2d(self, a_dt, ind_dt, indices, mode, axis):
        """Take along each axis of a 3x3 array via the ``take`` method."""
        np_src = numpy.array(
            [[-1, 0, 1], [-2, -3, -4], [2, 3, 4]], dtype=a_dt
        )
        np_idx = numpy.array(indices, dtype=ind_dt)
        dp_src = dpnp.array(np_src)
        dp_idx = dpnp.array(np_idx)

        result = dp_src.take(dp_idx, axis=axis, mode=mode)
        expected = np_src.take(np_idx, axis=axis, mode=mode)
        assert_array_equal(result, expected)

    @pytest.mark.parametrize("a_dt", get_all_dtypes(no_none=True))
    @pytest.mark.parametrize("indices", [[-5, 5]], ids=["[-5, 5]"])
    @pytest.mark.parametrize("mode", ["clip", "wrap"])
    def test_over_index(self, a_dt, indices, mode):
        """Out-of-range indices resolve to the first and last elements."""
        dp_src = dpnp.array([-2, -1, 0, 1, 2], dtype=a_dt)
        dp_idx = dpnp.array(indices, dtype=numpy.intp)

        result = dpnp.take(dp_src, dp_idx, mode=mode)
        # the same outcome is expected for both "clip" and "wrap" modes
        expected = dpnp.array([-2, 2], dtype=dp_src.dtype)
        assert_array_equal(result, expected)

    @pytest.mark.parametrize("xp", [numpy, dpnp])
    @pytest.mark.parametrize("indices", [[0], [1]], ids=["[0]", "[1]"])
    @pytest.mark.parametrize("mode", ["clip", "wrap"])
    def test_index_error(self, xp, indices, mode):
        """Taking along a zero-length dimension raises ``IndexError``."""
        # axis=2 below is the dimension of length 0
        empty_arr = xp.empty((2, 3, 0, 4))
        assert_raises(IndexError, empty_arr.take, indices, axis=2, mode=mode)

    def test_bool_axis(self):
        """A boolean `axis` is accepted by dpnp and treated like ``0``."""
        np_src = numpy.array([[[1]]])
        dp_src = dpnp.array(np_src)

        result = dp_src.take([0], axis=False)
        # numpy raises an error for bool axis
        expected = np_src.take([0], axis=0)
        assert_array_equal(result, expected)

    def test_axis_as_array(self):
        """An array passed as `axis` is accepted by dpnp."""
        np_src = numpy.array([[[1]]])
        dp_src = dpnp.array(np_src)

        result = dp_src.take([0], axis=dp_src)
        # numpy raises an error for axis as array
        expected = np_src.take([0], axis=1)
        assert_array_equal(result, expected)

    def test_mode_raise(self):
        """``mode="raise"`` is rejected with ``ValueError``."""
        dp_src = dpnp.array([[1, 2], [3, 4]])
        assert_raises(ValueError, dp_src.take, [-1, 4], mode="raise")

    @pytest.mark.parametrize("xp", [numpy, dpnp])
    def test_unicode_mode(self, xp):
        """A non-ASCII `mode` string is rejected with ``ValueError``."""
        arr = xp.arange(10)
        bad_mode = b"\xc3\xa4".decode("UTF8")
        assert_raises(ValueError, arr.take, 5, mode=bad_mode)


class TestTakeAlongAxis:
@pytest.mark.parametrize(
"func, argfunc, kwargs",
Expand Down Expand Up @@ -964,54 +1048,6 @@ def test_select():
assert_array_equal(expected, result)


@pytest.mark.parametrize("array_type", get_all_dtypes())
@pytest.mark.parametrize(
    "indices_type", [numpy.int32, numpy.int64], ids=["int32", "int64"]
)
@pytest.mark.parametrize(
    "indices", [[-2, 2], [-5, 4]], ids=["[-2, 2]", "[-5, 4]"]
)
@pytest.mark.parametrize("mode", ["clip", "wrap"], ids=["clip", "wrap"])
def test_take_1d(indices, array_type, indices_type, mode):
    """Compare ``dpnp.take`` with ``numpy.take`` on a 1-D source array."""
    np_src = numpy.array([-2, -1, 0, 1, 2], dtype=array_type)
    np_idx = numpy.array(indices, dtype=indices_type)
    dp_src, dp_idx = dpnp.array(np_src), dpnp.array(np_idx)

    expected = numpy.take(np_src, np_idx, mode=mode)
    result = dpnp.take(dp_src, dp_idx, mode=mode)
    assert_array_equal(expected, result)


@pytest.mark.parametrize("array_type", get_all_dtypes())
@pytest.mark.parametrize(
    "indices_type", [numpy.int32, numpy.int64], ids=["int32", "int64"]
)
@pytest.mark.parametrize(
    "indices", [[-1, 0], [-3, 2]], ids=["[-1, 0]", "[-3, 2]"]
)
@pytest.mark.parametrize("mode", ["clip", "wrap"], ids=["clip", "wrap"])
@pytest.mark.parametrize("axis", [0, 1], ids=["0", "1"])
def test_take_2d(indices, array_type, indices_type, axis, mode):
    """Compare ``dpnp.take`` with ``numpy.take`` along each axis of a 3x3."""
    np_src = numpy.array(
        [[-1, 0, 1], [-2, -3, -4], [2, 3, 4]], dtype=array_type
    )
    np_idx = numpy.array(indices, dtype=indices_type)
    dp_src, dp_idx = dpnp.array(np_src), dpnp.array(np_idx)

    expected = numpy.take(np_src, np_idx, axis=axis, mode=mode)
    result = dpnp.take(dp_src, dp_idx, axis=axis, mode=mode)
    assert_array_equal(expected, result)


@pytest.mark.parametrize("array_type", get_all_dtypes())
@pytest.mark.parametrize("indices", [[-5, 5]], ids=["[-5, 5]"])
@pytest.mark.parametrize("mode", ["clip", "wrap"], ids=["clip", "wrap"])
def test_take_over_index(indices, array_type, mode):
    """Out-of-range indices resolve to the first and last elements."""
    dp_src = dpnp.array([-2, -1, 0, 1, 2], dtype=array_type)
    dp_idx = dpnp.array(indices, dtype=dpnp.int64)

    # the same outcome is expected for both "clip" and "wrap" modes
    expected = dpnp.array([-2, 2], dtype=dp_src.dtype)
    result = dpnp.take(dp_src, dp_idx, mode=mode)
    assert_array_equal(expected, result)


@pytest.mark.parametrize(
"m", [None, 0, 1, 2, 3, 4], ids=["None", "0", "1", "2", "3", "4"]
)
Expand Down
Loading

0 comments on commit ce26cf0

Please sign in to comment.