Skip to content

Commit

Permalink
Add new nanargmin/nanargmax tests (#2223)
Browse files Browse the repository at this point in the history
The PR proposes to extend third party tests with new nanargmin/nanargmax
tests added recently.
  • Loading branch information
antonwolfy authored Dec 10, 2024
1 parent 48babd0 commit b990bac
Showing 1 changed file with 34 additions and 2 deletions.
36 changes: 34 additions & 2 deletions dpnp/tests/third_party/cupy/sorting_tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,22 @@ def test_nanargmin_zero_size_axis1(self, xp, dtype):
a = testing.shaped_random((0, 1), xp, dtype)
return xp.nanargmin(a, axis=1)

@testing.for_all_dtypes(no_complex=True)
@testing.numpy_cupy_allclose()
def test_nanargmin_out_float_dtype(self, xp, dtype):
a = xp.array([[0.0]])
b = xp.empty((1), dtype="int64")
xp.nanargmin(a, axis=1, out=b)
return b

@testing.for_all_dtypes(no_complex=True)
@testing.numpy_cupy_array_equal()
def test_nanargmin_out_int_dtype(self, xp, dtype):
a = xp.array([1, 0])
b = xp.empty((), dtype="int64")
xp.nanargmin(a, out=b)
return b


class TestNanArgMax:

Expand Down Expand Up @@ -623,6 +639,22 @@ def test_nanargmax_zero_size_axis1(self, xp, dtype):
a = testing.shaped_random((0, 1), xp, dtype)
return xp.nanargmax(a, axis=1)

@testing.for_all_dtypes(no_complex=True)
@testing.numpy_cupy_allclose()
def test_nanargmax_out_float_dtype(self, xp, dtype):
a = xp.array([[0.0]])
b = xp.empty((1), dtype="int64")
xp.nanargmax(a, axis=1, out=b)
return b

@testing.for_all_dtypes(no_complex=True)
@testing.numpy_cupy_array_equal()
def test_nanargmax_out_int_dtype(self, xp, dtype):
a = xp.array([0, 1])
b = xp.empty((), dtype="int64")
xp.nanargmax(a, out=b)
return b


@testing.parameterize(
*testing.product(
Expand Down Expand Up @@ -771,7 +803,7 @@ def test_invalid_sorter(self):

def test_nonint_sorter(self):
for xp in (numpy, cupy):
x = testing.shaped_arange((12,), xp, xp.float32)
x = testing.shaped_arange((12,), xp, xp.float64)
bins = xp.array([10, 4, 2, 1, 8])
sorter = xp.array([], dtype=xp.float32)
with pytest.raises((TypeError, ValueError)):
Expand Down Expand Up @@ -865,7 +897,7 @@ def test_invalid_sorter(self):

def test_nonint_sorter(self):
for xp in (numpy, cupy):
x = testing.shaped_arange((12,), xp, xp.float32)
x = testing.shaped_arange((12,), xp, xp.float64)
bins = xp.array([10, 4, 2, 1, 8])
sorter = xp.array([], dtype=xp.float32)
with pytest.raises((TypeError, ValueError)):
Expand Down

0 comments on commit b990bac

Please sign in to comment.