Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add proper handling of NaN values in dpnp.unique implementation with axis not None #1989

Merged
merged 6 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def _unique_build_sort_indices(a, index_sh):

"""

is_complex = dpnp.iscomplexobj(a)
is_inexact = dpnp.issubdtype(a, dpnp.inexact)
if dpnp.issubdtype(a.dtype, numpy.unsignedinteger):
ar_cmp = a.astype(dpnp.intp)
elif dpnp.issubdtype(a.dtype, dpnp.bool):
Expand All @@ -200,8 +200,27 @@ def compare_axis_elems(idx1, idx2):
comp = dpnp.trim_zeros(ar_cmp[idx1] - ar_cmp[idx2], "f")
if comp.shape[0] > 0:
diff = comp[0]
if is_complex and dpnp.isnan(diff):
return True
if is_inexact and dpnp.isnan(diff):
isnan1 = dpnp.isnan(ar_cmp[idx1])
if not isnan1.any(): # no NaN in ar_cmp[idx1]
return True # ar_cmp[idx1] goes to left

isnan2 = dpnp.isnan(ar_cmp[idx2])
if not isnan2.any(): # no NaN in ar_cmp[idx2]
return False # ar_cmp[idx1] goes to right

# for complex all NaNs are considered equivalent
if (isnan1 & isnan2).all(): # NaNs at the same places
return False # ar_cmp[idx1] goes to right

xor_nan_idx = dpnp.where(isnan1 ^ isnan2)[0]
if xor_nan_idx.size == 0:
return False

if dpnp.isnan(ar_cmp[idx2][xor_nan_idx[0]]):
# first NaN in XOR mask is from ar_cmp[idx2]
return True # ar_cmp[idx1] goes to left
return False
return diff < 0
return False

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ exclude-protected = ["_create_from_usm_ndarray"]
max-args = 11
max-locals = 30
max-branches = 15
max-returns = 7
max-returns = 8

[tool.pylint.format]
max-line-length = 80
Expand Down
60 changes: 45 additions & 15 deletions tests/test_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .helper import (
get_all_dtypes,
get_complex_dtypes,
get_float_complex_dtypes,
get_float_dtypes,
get_integer_dtypes,
has_support_aspect64,
Expand Down Expand Up @@ -88,21 +89,6 @@ def test_result_type_only_arrays():
assert dpnp.result_type(*X) == numpy.result_type(*X_np)


@pytest.mark.usefixtures("allow_fall_back_on_numpy")
@pytest.mark.parametrize(
"array",
[[1, 2, 3], [1, 2, 2, 1, 2, 4], [2, 2, 2, 2], []],
ids=["[1, 2, 3]", "[1, 2, 2, 1, 2, 4]", "[2, 2, 2, 2]", "[]"],
)
def test_unique(array):
np_a = numpy.array(array)
dpnp_a = dpnp.array(array)

expected = numpy.unique(np_a)
result = dpnp.unique(dpnp_a)
assert_array_equal(result, expected)


class TestRepeat:
@pytest.mark.parametrize(
"data",
Expand Down Expand Up @@ -748,3 +734,47 @@ def test_equal_nan(self, eq_nan_kwd):
result = dpnp.unique(ia, **eq_nan_kwd)
expected = numpy.unique(a, **eq_nan_kwd)
assert_array_equal(result, expected)

@pytest.mark.parametrize("dt", get_float_complex_dtypes())
@pytest.mark.parametrize(
"axis_kwd",
[
{},
{"axis": 0},
{"axis": 1},
],
)
@pytest.mark.parametrize(
"return_kwds",
[
{},
{
"return_index": True,
"return_inverse": True,
"return_counts": True,
},
],
)
@pytest.mark.parametrize(
"row", [[2, 3, 4], [2, numpy.nan, 4], [numpy.nan, 3, 4]]
)
def test_2d_axis_nans(self, dt, axis_kwd, return_kwds, row):
a = numpy.array(
[
[1, 0, 0],
[1, 0, 0],
[numpy.nan, numpy.nan, numpy.nan],
row,
[1, 0, 1],
[numpy.nan, numpy.nan, numpy.nan],
]
).astype(dt)
ia = dpnp.array(a)

result = dpnp.unique(ia, **axis_kwd, **return_kwds)
expected = numpy.unique(a, **axis_kwd, **return_kwds)
if len(return_kwds) == 0:
assert_array_equal(result, expected)
else:
for iv, v in zip(result, expected):
assert_array_equal(iv, v)
5 changes: 5 additions & 0 deletions tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -2393,6 +2393,11 @@ def test_astype(device_x, device_y):


@pytest.mark.parametrize("axis", [None, 0, -1])
@pytest.mark.parametrize(
"device",
valid_devices,
ids=[device.filter_string for device in valid_devices],
)
def test_unique(axis, device):
a = numpy.array([[1, 1], [2, 3]])
ia = dpnp.array(a, device=device)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def test_unique_equal_nan(self, xp, dtype, equal_nan):
[[2, xp.nan, 2], [xp.nan, 1, xp.nan], [xp.nan, 1, xp.nan]],
dtype=dtype,
)
return xp.unique(a, axis=0, equal_nan=equal_nan)
return xp.unique(a, axis=1, equal_nan=equal_nan)


@testing.parameterize(*testing.product({"trim": ["fb", "f", "b"]}))
Expand Down
Loading