diff --git a/dpctl/tests/test_usm_ndarray_sorting.py b/dpctl/tests/test_usm_ndarray_sorting.py index 69f2c62a04..e76ac667e1 100644 --- a/dpctl/tests/test_usm_ndarray_sorting.py +++ b/dpctl/tests/test_usm_ndarray_sorting.py @@ -14,6 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools + +import numpy as np import pytest import dpctl.tensor as dpt @@ -211,3 +214,69 @@ def test_argsort_0d_array(): x = dpt.asarray(1, dtype="i4") assert dpt.argsort(x) == 0 + + +@pytest.mark.parametrize( + "dtype", + [ + "f2", + "f4", + "f8", + ], +) +def test_sort_real_fp_nan(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + x = dpt.asarray( + [-0.0, 0.1, dpt.nan, 0.0, -0.1, dpt.nan, 0.2, -0.3], dtype=dtype + ) + s = dpt.sort(x) + + expected = dpt.asarray( + [-0.3, -0.1, -0.0, 0.0, 0.1, 0.2, dpt.nan, dpt.nan], dtype=dtype + ) + + assert dpt.allclose(s, expected, equal_nan=True) + + s = dpt.sort(x, descending=True) + + expected = dpt.asarray( + [dpt.nan, dpt.nan, 0.2, 0.1, -0.0, 0.0, -0.1, -0.3], dtype=dtype + ) + + assert dpt.allclose(s, expected, equal_nan=True) + + +@pytest.mark.parametrize( + "dtype", + [ + "c8", + "c16", + ], +) +def test_sort_complex_fp_nan(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + rvs = [-0.0, 0.1, 0.0, 0.2, -0.3, dpt.nan] + ivs = [-0.0, 0.1, 0.0, 0.2, -0.3, dpt.nan] + + cv = [] + for rv in rvs: + for iv in ivs: + cv.append(complex(rv, iv)) + + inp = dpt.asarray(cv, dtype=dtype) + s = dpt.sort(inp) + + expected = np.sort(dpt.asnumpy(inp)) + + assert np.allclose(dpt.asnumpy(s), expected, equal_nan=True) + + for i, j in itertools.permutations(range(inp.shape[0]), 2): + r1 = dpt.asnumpy(dpt.sort(inp[dpt.asarray([i, j])])) + r2 = np.sort(dpt.asnumpy(inp[dpt.asarray([i, j])])) + assert np.array_equal( + r1.view(np.int64), r2.view(np.int64) + ), f"Failed for {i} and {j}"