Skip to content

Commit

Permalink
Implement dpnp.select (#1977)
Browse files Browse the repository at this point in the history
* Implement dpnp.select

* Applied review comments

* Added test for select function

* removed old test for select function

---------

Co-authored-by: vtavana <120411540+vtavana@users.noreply.github.com>
  • Loading branch information
npolina4 and vtavana authored Aug 29, 2024
1 parent 838dbda commit f8378b0
Show file tree
Hide file tree
Showing 8 changed files with 214 additions and 95 deletions.
18 changes: 0 additions & 18 deletions dpnp/dpnp_algo/dpnp_algo_indexing.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ and the rest of the library
__all__ += [
"dpnp_choose",
"dpnp_putmask",
"dpnp_select",
]

ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_choose_t)(c_dpctl.DPCTLSyclQueueRef,
Expand Down Expand Up @@ -102,20 +101,3 @@ cpdef dpnp_putmask(utils.dpnp_descriptor arr, utils.dpnp_descriptor mask, utils.
for i in range(arr.size):
if mask_flatiter[i]:
arr_flatiter[i] = values_flatiter[i % values_size]


cpdef utils.dpnp_descriptor dpnp_select(list condlist, list choicelist, default):
cdef size_t size_ = condlist[0].size
cdef utils.dpnp_descriptor res_array = utils_py.create_output_descriptor_py(condlist[0].shape, choicelist[0].dtype, None)

pass_val = {a: default for a in range(size_)}
for i in range(len(condlist)):
for j in range(size_):
if (condlist[i])[j]:
res_array.get_pyobj()[j] = (choicelist[i])[j]
pass_val.pop(j)

for ind, val in pass_val.items():
res_array.get_pyobj()[ind] = val

return res_array
132 changes: 107 additions & 25 deletions dpnp/dpnp_iface_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,11 @@
from .dpnp_algo import (
dpnp_choose,
dpnp_putmask,
dpnp_select,
)
from .dpnp_array import dpnp_array
from .dpnp_utils import (
call_origin,
use_origin_backend,
get_usm_allocations,
)

__all__ = [
Expand Down Expand Up @@ -528,7 +527,7 @@ def extract(condition, a):
:obj:`dpnp.put` : Replaces specified elements of an array with given values.
:obj:`dpnp.copyto` : Copies values from one array to another, broadcasting
as necessary.
:obj:`dpnp.compress` : eturn selected slices of an array along given axis.
:obj:`dpnp.compress` : Return selected slices of an array along given axis.
:obj:`dpnp.place` : Change elements of an array based on conditional and
input values.
Expand Down Expand Up @@ -1356,31 +1355,114 @@ def select(condlist, choicelist, default=0):
For full documentation refer to :obj:`numpy.select`.
Limitations
-----------
Arrays of input lists are supported as :obj:`dpnp.ndarray`.
Parameter `default` is supported only with default values.
Parameters
----------
condlist : list of bool dpnp.ndarray or usm_ndarray
The list of conditions which determine from which array in `choicelist`
the output elements are taken. When multiple conditions are satisfied,
the first one encountered in `condlist` is used.
choicelist : list of dpnp.ndarray or usm_ndarray
The list of arrays from which the output elements are taken. It has
to be of the same length as `condlist`.
default : {scalar, dpnp.ndarray, usm_ndarray}, optional
The element inserted in `output` when all conditions evaluate to
``False``. Default: ``0``.
Returns
-------
out : dpnp.ndarray
The output at position m is the m-th element of the array in
`choicelist` where the m-th element of the corresponding array in
`condlist` is ``True``.
See Also
--------
:obj:`dpnp.where` : Return elements from one of two arrays depending on
condition.
:obj:`dpnp.take` : Take elements from an array along an axis.
:obj:`dpnp.choose` : Construct an array from an index array and a set of
arrays to choose from.
:obj:`dpnp.compress` : Return selected slices of an array along given axis.
:obj:`dpnp.diag` : Extract a diagonal or construct a diagonal array.
:obj:`dpnp.diagonal` : Return specified diagonals.
Examples
--------
>>> import dpnp as np
Beginning with an array of integers from 0 to 5 (inclusive),
elements less than ``3`` are negated, elements greater than ``3``
are squared, and elements not meeting either of these conditions
(exactly ``3``) are replaced with a `default` value of ``42``.
>>> x = np.arange(6)
>>> condlist = [x<3, x>3]
>>> choicelist = [x, x**2]
>>> np.select(condlist, choicelist, 42)
array([ 0, 1, 2, 42, 16, 25])
When multiple conditions are satisfied, the first one encountered in
`condlist` is used.
>>> condlist = [x<=4, x>3]
>>> choicelist = [x, x**2]
>>> np.select(condlist, choicelist, 55)
array([ 0, 1, 2, 3, 4, 25])
"""

if not use_origin_backend():
if not isinstance(condlist, list):
pass
elif not isinstance(choicelist, list):
pass
elif len(condlist) != len(choicelist):
pass
else:
val = True
size_ = condlist[0].size
for cond, choice in zip(condlist, choicelist):
if cond.size != size_ or choice.size != size_:
val = False
if not val:
pass
else:
return dpnp_select(condlist, choicelist, default).get_pyobj()
if len(condlist) != len(choicelist):
raise ValueError(
"list of cases must be same length as list of conditions"
)

if len(condlist) == 0:
raise ValueError("select with an empty condition list is not possible")

dpnp.check_supported_arrays_type(*condlist)
dpnp.check_supported_arrays_type(*choicelist)

if not dpnp.isscalar(default) and not (
dpnp.is_supported_array_type(default) and default.ndim == 0
):
raise TypeError(
"A default value must be any of scalar or 0-d supported array type"
)

dtype = dpnp.result_type(*choicelist, default)

usm_type_alloc, sycl_queue_alloc = get_usm_allocations(
condlist + choicelist + [default]
)

for i, cond in enumerate(condlist):
if not dpnp.issubdtype(cond, dpnp.bool):
raise TypeError(
f"invalid entry {i} in condlist: should be boolean ndarray"
)

# Convert conditions to arrays and broadcast conditions and choices
# as the shape is needed for the result
condlist = dpnp.broadcast_arrays(*condlist)
choicelist = dpnp.broadcast_arrays(*choicelist)

result_shape = dpnp.broadcast_arrays(condlist[0], choicelist[0])[0].shape

result = dpnp.full(
result_shape,
default,
dtype=dtype,
usm_type=usm_type_alloc,
sycl_queue=sycl_queue_alloc,
)

# Do in reverse order since the first choice should take precedence.
choicelist = choicelist[::-1]
condlist = condlist[::-1]
for choice, cond in zip(choicelist, condlist):
dpnp.where(cond, choice, result, out=result)

return call_origin(numpy.select, condlist, choicelist, default)
return result


# pylint: disable=redefined-outer-name
Expand Down
13 changes: 0 additions & 13 deletions tests/skipped_tests.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,6 @@ tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_i
tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_index
tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_order

tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_1D_choicelist
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_choicelist_condlist_broadcast
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_complex
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_default
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_default_complex
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_default_scalar
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_empty_lists
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_length_error
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_odd_shaped_broadcastable
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_odd_shaped_broadcastable_complex
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_odd_shaped_non_broadcastable

tests/third_party/cupy/indexing_tests/test_insert.py::TestPutmaskDifferentDtypes::test_putmask_differnt_dtypes_raises
tests/third_party/cupy/indexing_tests/test_insert.py::TestPutmask::test_putmask_non_equal_shape_raises

Expand Down
14 changes: 0 additions & 14 deletions tests/skipped_tests_gpu.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -143,20 +143,6 @@ tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_i
tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_index
tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_order

tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_1D_choicelist
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_choicelist_condlist_broadcast
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_complex
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_default
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_default_complex
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_default_scalar
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_empty_lists
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_length_error
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_odd_shaped_broadcastable
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_odd_shaped_broadcastable_complex
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_odd_shaped_non_broadcastable
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_type_error_condlist

tests/third_party/cupy/indexing_tests/test_insert.py::TestPutmaskDifferentDtypes::test_putmask_differnt_dtypes_raises
tests/third_party/cupy/indexing_tests/test_insert.py::TestPutmask::test_putmask_non_equal_shape_raises

Expand Down
108 changes: 85 additions & 23 deletions tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import dpnp

from .helper import get_all_dtypes, get_integer_dtypes
from .helper import get_all_dtypes, get_integer_dtypes, has_support_aspect64


def _add_keepdims(func):
Expand Down Expand Up @@ -949,28 +949,6 @@ def test_putmask3(arr, mask, vals):
assert_array_equal(a, ia)


def test_select():
cond_val1 = numpy.array(
[True, True, True, False, False, False, False, False, False, False]
)
cond_val2 = numpy.array(
[False, False, False, False, False, True, True, True, True, True]
)
icond_val1 = dpnp.array(cond_val1)
icond_val2 = dpnp.array(cond_val2)
condlist = [cond_val1, cond_val2]
icondlist = [icond_val1, icond_val2]
choice_val1 = numpy.full(10, -2)
choice_val2 = numpy.full(10, -1)
ichoice_val1 = dpnp.array(choice_val1)
ichoice_val2 = dpnp.array(choice_val2)
choicelist = [choice_val1, choice_val2]
ichoicelist = [ichoice_val1, ichoice_val2]
expected = numpy.select(condlist, choicelist)
result = dpnp.select(icondlist, ichoicelist)
assert_array_equal(expected, result)


@pytest.mark.parametrize(
"m", [None, 0, 1, 2, 3, 4], ids=["None", "0", "1", "2", "3", "4"]
)
Expand Down Expand Up @@ -1057,3 +1035,87 @@ def test_fill_diagonal_error():
arr = dpnp.ones((1, 2, 3))
with pytest.raises(ValueError):
dpnp.fill_diagonal(arr, 5)


class TestSelect:
choices_np = [
numpy.array([1, 2, 3]),
numpy.array([4, 5, 6]),
numpy.array([7, 8, 9]),
]
choices_dp = [
dpnp.array([1, 2, 3]),
dpnp.array([4, 5, 6]),
dpnp.array([7, 8, 9]),
]
conditions_np = [
numpy.array([False, False, False]),
numpy.array([False, True, False]),
numpy.array([False, False, True]),
]
conditions_dp = [
dpnp.array([False, False, False]),
dpnp.array([False, True, False]),
dpnp.array([False, False, True]),
]

def test_basic(self):
expected = numpy.select(self.conditions_np, self.choices_np, default=15)
result = dpnp.select(self.conditions_dp, self.choices_dp, default=15)
assert_array_equal(expected, result)

def test_broadcasting(self):
conditions_np = [numpy.array(True), numpy.array([False, True, False])]
conditions_dp = [dpnp.array(True), dpnp.array([False, True, False])]
choices_np = [numpy.array(1), numpy.arange(12).reshape(4, 3)]
choices_dp = [dpnp.array(1), dpnp.arange(12).reshape(4, 3)]
expected = numpy.select(conditions_np, choices_np)
result = dpnp.select(conditions_dp, choices_dp)
assert_array_equal(expected, result)

def test_return_dtype(self):
dtype = dpnp.complex128 if has_support_aspect64() else dpnp.complex64
assert_equal(
dpnp.select(self.conditions_dp, self.choices_dp, 1j).dtype, dtype
)

choices = [choice.astype(dpnp.int32) for choice in self.choices_dp]
assert_equal(dpnp.select(self.conditions_dp, choices).dtype, dpnp.int32)

def test_nan(self):
choice_np = numpy.array([1, 2, 3, numpy.nan, 5, 7])
choice_dp = dpnp.array([1, 2, 3, dpnp.nan, 5, 7])
condition_np = numpy.isnan(choice_np)
condition_dp = dpnp.isnan(choice_dp)
expected = numpy.select([condition_np], [choice_np])
result = dpnp.select([condition_dp], [choice_dp])
assert_array_equal(expected, result)

def test_many_arguments(self):
condition_np = [numpy.array([False])] * 100
condition_dp = [dpnp.array([False])] * 100
choice_np = [numpy.array([1])] * 100
choice_dp = [dpnp.array([1])] * 100
expected = numpy.select(condition_np, choice_np)
result = dpnp.select(condition_dp, choice_dp)
assert_array_equal(expected, result)

def test_deprecated_empty(self):
assert_raises(ValueError, dpnp.select, [], [], 3j)
assert_raises(ValueError, dpnp.select, [], [])

def test_non_bool_deprecation(self):
choices = self.choices_dp
conditions = self.conditions_dp[:]
conditions[0] = conditions[0].astype(dpnp.int64)
assert_raises(TypeError, dpnp.select, conditions, choices)

def test_error(self):
x0 = dpnp.array([True, False])
x1 = dpnp.array([1, 2])
with pytest.raises(ValueError):
dpnp.select([x0], [x1, x1])
with pytest.raises(TypeError):
dpnp.select([x0], [x1], default=x1)
with pytest.raises(TypeError):
dpnp.select([x1], [x1])
12 changes: 12 additions & 0 deletions tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -2461,6 +2461,18 @@ def test_astype(device_x, device_y):
assert_sycl_queue_equal(y.sycl_queue, sycl_queue)


@pytest.mark.parametrize(
"device",
valid_devices,
ids=[device.filter_string for device in valid_devices],
)
def test_select(device):
condlist = [dpnp.array([True, False], device=device)]
choicelist = [dpnp.array([1, 2], device=device)]
res = dpnp.select(condlist, choicelist)
assert_sycl_queue_equal(res.sycl_queue, condlist[0].sycl_queue)


@pytest.mark.parametrize("axis", [None, 0, -1])
@pytest.mark.parametrize(
"device",
Expand Down
9 changes: 9 additions & 0 deletions tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1435,6 +1435,15 @@ def test_histogram_bin_edges(usm_type_v, usm_type_w):
assert edges.usm_type == du.get_coerced_usm_type([usm_type_v, usm_type_w])


@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types)
@pytest.mark.parametrize("usm_type_y", list_of_usm_types, ids=list_of_usm_types)
def test_select(usm_type_x, usm_type_y):
condlist = [dp.array([True, False], usm_type=usm_type_x)]
choicelist = [dp.array([1, 2], usm_type=usm_type_y)]
res = dp.select(condlist, choicelist)
assert res.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])


@pytest.mark.parametrize("axis", [None, 0, -1])
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
def test_unique(axis, usm_type):
Expand Down
Loading

0 comments on commit f8378b0

Please sign in to comment.