diff --git a/dpnp/dpnp_algo/dpnp_algo_indexing.pxi b/dpnp/dpnp_algo/dpnp_algo_indexing.pxi index d96ebcd816f..4eab7f3e94c 100644 --- a/dpnp/dpnp_algo/dpnp_algo_indexing.pxi +++ b/dpnp/dpnp_algo/dpnp_algo_indexing.pxi @@ -38,7 +38,6 @@ and the rest of the library __all__ += [ "dpnp_choose", "dpnp_putmask", - "dpnp_select", ] ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_choose_t)(c_dpctl.DPCTLSyclQueueRef, @@ -102,20 +101,3 @@ cpdef dpnp_putmask(utils.dpnp_descriptor arr, utils.dpnp_descriptor mask, utils. for i in range(arr.size): if mask_flatiter[i]: arr_flatiter[i] = values_flatiter[i % values_size] - - -cpdef utils.dpnp_descriptor dpnp_select(list condlist, list choicelist, default): - cdef size_t size_ = condlist[0].size - cdef utils.dpnp_descriptor res_array = utils_py.create_output_descriptor_py(condlist[0].shape, choicelist[0].dtype, None) - - pass_val = {a: default for a in range(size_)} - for i in range(len(condlist)): - for j in range(size_): - if (condlist[i])[j]: - res_array.get_pyobj()[j] = (choicelist[i])[j] - pass_val.pop(j) - - for ind, val in pass_val.items(): - res_array.get_pyobj()[ind] = val - - return res_array diff --git a/dpnp/dpnp_iface_indexing.py b/dpnp/dpnp_iface_indexing.py index e8df6fec906..d76622d382e 100644 --- a/dpnp/dpnp_iface_indexing.py +++ b/dpnp/dpnp_iface_indexing.py @@ -49,12 +49,11 @@ from .dpnp_algo import ( dpnp_choose, dpnp_putmask, - dpnp_select, ) from .dpnp_array import dpnp_array from .dpnp_utils import ( call_origin, - use_origin_backend, + get_usm_allocations, ) __all__ = [ @@ -528,7 +527,7 @@ def extract(condition, a): :obj:`dpnp.put` : Replaces specified elements of an array with given values. :obj:`dpnp.copyto` : Copies values from one array to another, broadcasting as necessary. - :obj:`dpnp.compress` : eturn selected slices of an array along given axis. + :obj:`dpnp.compress` : Return selected slices of an array along given axis. :obj:`dpnp.place` : Change elements of an array based on conditional and input values. @@ -1356,31 +1355,114 @@ def select(condlist, choicelist, default=0): For full documentation refer to :obj:`numpy.select`. - Limitations - ----------- - Arrays of input lists are supported as :obj:`dpnp.ndarray`. - Parameter `default` is supported only with default values. + Parameters + ---------- + condlist : list of bool dpnp.ndarray or usm_ndarray + The list of conditions which determine from which array in `choicelist` + the output elements are taken. When multiple conditions are satisfied, + the first one encountered in `condlist` is used. + choicelist : list of dpnp.ndarray or usm_ndarray + The list of arrays from which the output elements are taken. It has + to be of the same length as `condlist`. + default : {scalar, dpnp.ndarray, usm_ndarray}, optional + The element inserted in `output` when all conditions evaluate to + ``False``. Default: ``0``. + + Returns + ------- + out : dpnp.ndarray + The output at position m is the m-th element of the array in + `choicelist` where the m-th element of the corresponding array in + `condlist` is ``True``. + + See Also + -------- + :obj:`dpnp.where` : Return elements from one of two arrays depending on + condition. + :obj:`dpnp.take` : Take elements from an array along an axis. + :obj:`dpnp.choose` : Construct an array from an index array and a set of + arrays to choose from. + :obj:`dpnp.compress` : Return selected slices of an array along given axis. + :obj:`dpnp.diag` : Extract a diagonal or construct a diagonal array. + :obj:`dpnp.diagonal` : Return specified diagonals. + + Examples + -------- + >>> import dpnp as np + + Beginning with an array of integers from 0 to 5 (inclusive), + elements less than ``3`` are negated, elements greater than ``3`` + are squared, and elements not meeting either of these conditions + (exactly ``3``) are replaced with a `default` value of ``42``. + + >>> x = np.arange(6) + >>> condlist = [x<3, x>3] + >>> choicelist = [x, x**2] + >>> np.select(condlist, choicelist, 42) + array([ 0, 1, 2, 42, 16, 25]) + + When multiple conditions are satisfied, the first one encountered in + `condlist` is used. + + >>> condlist = [x<=4, x>3] + >>> choicelist = [x, x**2] + >>> np.select(condlist, choicelist, 55) + array([ 0, 1, 2, 3, 4, 25]) + """ - if not use_origin_backend(): - if not isinstance(condlist, list): - pass - elif not isinstance(choicelist, list): - pass - elif len(condlist) != len(choicelist): - pass - else: - val = True - size_ = condlist[0].size - for cond, choice in zip(condlist, choicelist): - if cond.size != size_ or choice.size != size_: - val = False - if not val: - pass - else: - return dpnp_select(condlist, choicelist, default).get_pyobj() + if len(condlist) != len(choicelist): + raise ValueError( + "list of cases must be same length as list of conditions" + ) + + if len(condlist) == 0: + raise ValueError("select with an empty condition list is not possible") + + dpnp.check_supported_arrays_type(*condlist) + dpnp.check_supported_arrays_type(*choicelist) + + if not dpnp.isscalar(default) and not ( + dpnp.is_supported_array_type(default) and default.ndim == 0 + ): + raise TypeError( + "A default value must be any of scalar or 0-d supported array type" + ) + + dtype = dpnp.result_type(*choicelist, default) + + usm_type_alloc, sycl_queue_alloc = get_usm_allocations( + condlist + choicelist + [default] + ) + + for i, cond in enumerate(condlist): + if not dpnp.issubdtype(cond, dpnp.bool): + raise TypeError( + f"invalid entry {i} in condlist: should be boolean ndarray" + ) + + # Convert conditions to arrays and broadcast conditions and choices + # as the shape is needed for the result + condlist = dpnp.broadcast_arrays(*condlist) + choicelist = dpnp.broadcast_arrays(*choicelist) + + result_shape = dpnp.broadcast_arrays(condlist[0], choicelist[0])[0].shape + + result = dpnp.full( + result_shape, + default, + dtype=dtype, + usm_type=usm_type_alloc, + sycl_queue=sycl_queue_alloc, + ) + + # Do in reverse order since the first choice should take precedence. + choicelist = choicelist[::-1] + condlist = condlist[::-1] + for choice, cond in zip(choicelist, condlist): + dpnp.where(cond, choice, result, out=result) - return call_origin(numpy.select, condlist, choicelist, default) + return result # pylint: disable=redefined-outer-name diff --git a/tests/skipped_tests.tbl b/tests/skipped_tests.tbl index 02700e87c3f..3071f8f4813 100644 --- a/tests/skipped_tests.tbl +++ b/tests/skipped_tests.tbl @@ -90,19 +90,6 @@ tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_i tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_index tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_order -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_1D_choicelist -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_choicelist_condlist_broadcast -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_complex -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_default -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_default_complex -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_default_scalar -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_empty_lists -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_length_error -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_odd_shaped_broadcastable -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_odd_shaped_broadcastable_complex -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_odd_shaped_non_broadcastable - tests/third_party/cupy/indexing_tests/test_insert.py::TestPutmaskDifferentDtypes::test_putmask_differnt_dtypes_raises tests/third_party/cupy/indexing_tests/test_insert.py::TestPutmask::test_putmask_non_equal_shape_raises diff --git a/tests/skipped_tests_gpu.tbl b/tests/skipped_tests_gpu.tbl index 455c2bc58a3..389b708f26b 100644 --- a/tests/skipped_tests_gpu.tbl +++ b/tests/skipped_tests_gpu.tbl @@ -143,20 +143,6 @@ tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_i tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_index tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_order -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_1D_choicelist -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_choicelist_condlist_broadcast -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_complex -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_default -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_default_complex -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_default_scalar -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_empty_lists -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_length_error -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_odd_shaped_broadcastable -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_odd_shaped_broadcastable_complex -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_odd_shaped_non_broadcastable -tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_type_error_condlist - tests/third_party/cupy/indexing_tests/test_insert.py::TestPutmaskDifferentDtypes::test_putmask_differnt_dtypes_raises tests/third_party/cupy/indexing_tests/test_insert.py::TestPutmask::test_putmask_non_equal_shape_raises diff --git a/tests/test_indexing.py b/tests/test_indexing.py index bed48bce398..ffa3304e8f4 100644 --- a/tests/test_indexing.py +++ b/tests/test_indexing.py @@ -13,7 +13,7 @@ import dpnp -from .helper import get_all_dtypes, get_integer_dtypes +from .helper import get_all_dtypes, get_integer_dtypes, has_support_aspect64 def _add_keepdims(func): @@ -949,28 +949,6 @@ def test_putmask3(arr, mask, vals): assert_array_equal(a, ia) -def test_select(): - cond_val1 = numpy.array( - [True, True, True, False, False, False, False, False, False, False] - ) - cond_val2 = numpy.array( - [False, False, False, False, False, True, True, True, True, True] - ) - icond_val1 = dpnp.array(cond_val1) - icond_val2 = dpnp.array(cond_val2) - condlist = [cond_val1, cond_val2] - icondlist = [icond_val1, icond_val2] - choice_val1 = numpy.full(10, -2) - choice_val2 = numpy.full(10, -1) - ichoice_val1 = dpnp.array(choice_val1) - ichoice_val2 = dpnp.array(choice_val2) - choicelist = [choice_val1, choice_val2] - ichoicelist = [ichoice_val1, ichoice_val2] - expected = numpy.select(condlist, choicelist) - result = dpnp.select(icondlist, ichoicelist) - assert_array_equal(expected, result) - - @pytest.mark.parametrize( "m", [None, 0, 1, 2, 3, 4], ids=["None", "0", "1", "2", "3", "4"] ) @@ -1057,3 +1035,87 @@ def test_fill_diagonal_error(): arr = dpnp.ones((1, 2, 3)) with pytest.raises(ValueError): dpnp.fill_diagonal(arr, 5) + + +class TestSelect: + choices_np = [ + numpy.array([1, 2, 3]), + numpy.array([4, 5, 6]), + numpy.array([7, 8, 9]), + ] + choices_dp = [ + dpnp.array([1, 2, 3]), + dpnp.array([4, 5, 6]), + dpnp.array([7, 8, 9]), + ] + conditions_np = [ + numpy.array([False, False, False]), + numpy.array([False, True, False]), + numpy.array([False, False, True]), + ] + conditions_dp = [ + dpnp.array([False, False, False]), + dpnp.array([False, True, False]), + dpnp.array([False, False, True]), + ] + + def test_basic(self): + expected = numpy.select(self.conditions_np, self.choices_np, default=15) + result = dpnp.select(self.conditions_dp, self.choices_dp, default=15) + assert_array_equal(expected, result) + + def test_broadcasting(self): + conditions_np = [numpy.array(True), numpy.array([False, True, False])] + conditions_dp = [dpnp.array(True), dpnp.array([False, True, False])] + choices_np = [numpy.array(1), numpy.arange(12).reshape(4, 3)] + choices_dp = [dpnp.array(1), dpnp.arange(12).reshape(4, 3)] + expected = numpy.select(conditions_np, choices_np) + result = dpnp.select(conditions_dp, choices_dp) + assert_array_equal(expected, result) + + def test_return_dtype(self): + dtype = dpnp.complex128 if has_support_aspect64() else dpnp.complex64 + assert_equal( + dpnp.select(self.conditions_dp, self.choices_dp, 1j).dtype, dtype + ) + + choices = [choice.astype(dpnp.int32) for choice in self.choices_dp] + assert_equal(dpnp.select(self.conditions_dp, choices).dtype, dpnp.int32) + + def test_nan(self): + choice_np = numpy.array([1, 2, 3, numpy.nan, 5, 7]) + choice_dp = dpnp.array([1, 2, 3, dpnp.nan, 5, 7]) + condition_np = numpy.isnan(choice_np) + condition_dp = dpnp.isnan(choice_dp) + expected = numpy.select([condition_np], [choice_np]) + result = dpnp.select([condition_dp], [choice_dp]) + assert_array_equal(expected, result) + + def test_many_arguments(self): + condition_np = [numpy.array([False])] * 100 + condition_dp = [dpnp.array([False])] * 100 + choice_np = [numpy.array([1])] * 100 + choice_dp = [dpnp.array([1])] * 100 + expected = numpy.select(condition_np, choice_np) + result = dpnp.select(condition_dp, choice_dp) + assert_array_equal(expected, result) + + def test_deprecated_empty(self): + assert_raises(ValueError, dpnp.select, [], [], 3j) + assert_raises(ValueError, dpnp.select, [], []) + + def test_non_bool_deprecation(self): + choices = self.choices_dp + conditions = self.conditions_dp[:] + conditions[0] = conditions[0].astype(dpnp.int64) + assert_raises(TypeError, dpnp.select, conditions, choices) + + def test_error(self): + x0 = dpnp.array([True, False]) + x1 = dpnp.array([1, 2]) + with pytest.raises(ValueError): + dpnp.select([x0], [x1, x1]) + with pytest.raises(TypeError): + dpnp.select([x0], [x1], default=x1) + with pytest.raises(TypeError): + dpnp.select([x1], [x1]) diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index 7e9c35507c4..8ac3f8dc104 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -2461,6 +2461,18 @@ def test_astype(device_x, device_y): assert_sycl_queue_equal(y.sycl_queue, sycl_queue) +@pytest.mark.parametrize( + "device", + valid_devices, + ids=[device.filter_string for device in valid_devices], +) +def test_select(device): + condlist = [dpnp.array([True, False], device=device)] + choicelist = [dpnp.array([1, 2], device=device)] + res = dpnp.select(condlist, choicelist) + assert_sycl_queue_equal(res.sycl_queue, condlist[0].sycl_queue) + + @pytest.mark.parametrize("axis", [None, 0, -1]) @pytest.mark.parametrize( "device", diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index db9438d6a8c..c4e310882af 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -1435,6 +1435,15 @@ def test_histogram_bin_edges(usm_type_v, usm_type_w): assert edges.usm_type == du.get_coerced_usm_type([usm_type_v, usm_type_w]) +@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types) +@pytest.mark.parametrize("usm_type_y", list_of_usm_types, ids=list_of_usm_types) +def test_select(usm_type_x, usm_type_y): + condlist = [dp.array([True, False], usm_type=usm_type_x)] + choicelist = [dp.array([1, 2], usm_type=usm_type_y)] + res = dp.select(condlist, choicelist) + assert res.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y]) + + @pytest.mark.parametrize("axis", [None, 0, -1]) @pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types) def test_unique(axis, usm_type): diff --git a/tests/third_party/cupy/indexing_tests/test_indexing.py b/tests/third_party/cupy/indexing_tests/test_indexing.py index 94a9cc9a29e..ca2f9a9cc6c 100644 --- a/tests/third_party/cupy/indexing_tests/test_indexing.py +++ b/tests/third_party/cupy/indexing_tests/test_indexing.py @@ -355,10 +355,9 @@ def test_select_type_error_condlist(self, dtype): a = cupy.arange(10, dtype=dtype) condlist = [[3] * 10, [2] * 10] choicelist = [a, a**2] - with pytest.raises(AttributeError): + with pytest.raises(TypeError): cupy.select(condlist, choicelist) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.for_all_dtypes(no_bool=True) def test_select_type_error_choicelist(self, dtype): a, b = list(range(10)), list(range(-10, 0))