Skip to content

Commit

Permalink
Adding tests for sorting of FP arrays with NaNs
Browse files Browse the repository at this point in the history
  • Loading branch information
oleksandr-pavlyk committed Jan 9, 2024
1 parent a3d0d08 commit 5469832
Showing 1 changed file with 69 additions and 0 deletions.
69 changes: 69 additions & 0 deletions dpctl/tests/test_usm_ndarray_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools

import numpy as np
import pytest

import dpctl.tensor as dpt
Expand Down Expand Up @@ -211,3 +214,69 @@ def test_argsort_0d_array():

x = dpt.asarray(1, dtype="i4")
assert dpt.argsort(x) == 0


@pytest.mark.parametrize(
"dtype",
[
"f2",
"f4",
"f8",
],
)
def test_sort_real_fp_nan(dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

x = dpt.asarray(
[-0.0, 0.1, dpt.nan, 0.0, -0.1, dpt.nan, 0.2, -0.3], dtype=dtype
)
s = dpt.sort(x)

expected = dpt.asarray(
[-0.3, -0.1, -0.0, 0.0, 0.1, 0.2, dpt.nan, dpt.nan], dtype=dtype
)

assert dpt.allclose(s, expected, equal_nan=True)

s = dpt.sort(x, descending=True)

expected = dpt.asarray(
[dpt.nan, dpt.nan, 0.2, 0.1, -0.0, 0.0, -0.1, -0.3], dtype=dtype
)

assert dpt.allclose(s, expected, equal_nan=True)


@pytest.mark.parametrize(
"dtype",
[
"c8",
"c16",
],
)
def test_sort_complex_fp_nan(dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

rvs = [-0.0, 0.1, 0.0, 0.2, -0.3, dpt.nan]
ivs = [-0.0, 0.1, 0.0, 0.2, -0.3, dpt.nan]

cv = []
for rv in rvs:
for iv in ivs:
cv.append(complex(rv, iv))

inp = dpt.asarray(cv, dtype=dtype)
s = dpt.sort(inp)

expected = np.sort(dpt.asnumpy(inp))

assert np.allclose(dpt.asnumpy(s), expected, equal_nan=True)

for i, j in itertools.permutations(range(inp.shape[0]), 2):
r1 = dpt.asnumpy(dpt.sort(inp[dpt.asarray([i, j])]))
r2 = np.sort(dpt.asnumpy(inp[dpt.asarray([i, j])]))
assert np.array_equal(
r1.view(np.int64), r2.view(np.int64)
), f"Failed for {i} and {j}"

0 comments on commit 5469832

Please sign in to comment.