Skip to content

Commit 7387990

Browse files
authored
Add proper handling of NaN values in dpnp.unique implementation with axis not None (#1989)
* Align handling of NaN values for axis not None with numpy
* Add missing parametrize for test_sycl_queue.py::test_unique
* Remove obsolete test
* Applied black formatting
* Add a test with NaNs and axis not None
* For complex dtype the result may vary
1 parent 028e44c commit 7387990

File tree

5 files changed

+74
-20
lines changed

5 files changed

+74
-20
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 22 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -188,7 +188,7 @@ def _unique_build_sort_indices(a, index_sh):
188188
189189
"""
190190

191-
is_complex = dpnp.iscomplexobj(a)
191+
is_inexact = dpnp.issubdtype(a, dpnp.inexact)
192192
if dpnp.issubdtype(a.dtype, numpy.unsignedinteger):
193193
ar_cmp = a.astype(dpnp.intp)
194194
elif dpnp.issubdtype(a.dtype, dpnp.bool):
@@ -200,8 +200,27 @@ def compare_axis_elems(idx1, idx2):
200200
comp = dpnp.trim_zeros(ar_cmp[idx1] - ar_cmp[idx2], "f")
201201
if comp.shape[0] > 0:
202202
diff = comp[0]
203-
if is_complex and dpnp.isnan(diff):
204-
return True
203+
if is_inexact and dpnp.isnan(diff):
204+
isnan1 = dpnp.isnan(ar_cmp[idx1])
205+
if not isnan1.any(): # no NaN in ar_cmp[idx1]
206+
return True # ar_cmp[idx1] goes to left
207+
208+
isnan2 = dpnp.isnan(ar_cmp[idx2])
209+
if not isnan2.any(): # no NaN in ar_cmp[idx2]
210+
return False # ar_cmp[idx1] goes to right
211+
212+
# for complex all NaNs are considered equivalent
213+
if (isnan1 & isnan2).all(): # NaNs at the same places
214+
return False # ar_cmp[idx1] goes to right
215+
216+
xor_nan_idx = dpnp.where(isnan1 ^ isnan2)[0]
217+
if xor_nan_idx.size == 0:
218+
return False
219+
220+
if dpnp.isnan(ar_cmp[idx2][xor_nan_idx[0]]):
221+
# first NaN in XOR mask is from ar_cmp[idx2]
222+
return True # ar_cmp[idx1] goes to left
223+
return False
205224
return diff < 0
206225
return False
207226

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@ exclude-protected = ["_create_from_usm_ndarray"]
1818
max-args = 11
1919
max-locals = 30
2020
max-branches = 15
21-
max-returns = 7
21+
max-returns = 8
2222

2323
[tool.pylint.format]
2424
max-line-length = 80

tests/test_manipulation.py

Lines changed: 45 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
from .helper import (
1010
get_all_dtypes,
1111
get_complex_dtypes,
12+
get_float_complex_dtypes,
1213
get_float_dtypes,
1314
get_integer_dtypes,
1415
has_support_aspect64,
@@ -88,21 +89,6 @@ def test_result_type_only_arrays():
8889
assert dpnp.result_type(*X) == numpy.result_type(*X_np)
8990

9091

91-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
92-
@pytest.mark.parametrize(
93-
"array",
94-
[[1, 2, 3], [1, 2, 2, 1, 2, 4], [2, 2, 2, 2], []],
95-
ids=["[1, 2, 3]", "[1, 2, 2, 1, 2, 4]", "[2, 2, 2, 2]", "[]"],
96-
)
97-
def test_unique(array):
98-
np_a = numpy.array(array)
99-
dpnp_a = dpnp.array(array)
100-
101-
expected = numpy.unique(np_a)
102-
result = dpnp.unique(dpnp_a)
103-
assert_array_equal(result, expected)
104-
105-
10692
class TestRepeat:
10793
@pytest.mark.parametrize(
10894
"data",
@@ -748,3 +734,47 @@ def test_equal_nan(self, eq_nan_kwd):
748734
result = dpnp.unique(ia, **eq_nan_kwd)
749735
expected = numpy.unique(a, **eq_nan_kwd)
750736
assert_array_equal(result, expected)
737+
738+
@pytest.mark.parametrize("dt", get_float_complex_dtypes())
739+
@pytest.mark.parametrize(
740+
"axis_kwd",
741+
[
742+
{},
743+
{"axis": 0},
744+
{"axis": 1},
745+
],
746+
)
747+
@pytest.mark.parametrize(
748+
"return_kwds",
749+
[
750+
{},
751+
{
752+
"return_index": True,
753+
"return_inverse": True,
754+
"return_counts": True,
755+
},
756+
],
757+
)
758+
@pytest.mark.parametrize(
759+
"row", [[2, 3, 4], [2, numpy.nan, 4], [numpy.nan, 3, 4]]
760+
)
761+
def test_2d_axis_nans(self, dt, axis_kwd, return_kwds, row):
762+
a = numpy.array(
763+
[
764+
[1, 0, 0],
765+
[1, 0, 0],
766+
[numpy.nan, numpy.nan, numpy.nan],
767+
row,
768+
[1, 0, 1],
769+
[numpy.nan, numpy.nan, numpy.nan],
770+
]
771+
).astype(dt)
772+
ia = dpnp.array(a)
773+
774+
result = dpnp.unique(ia, **axis_kwd, **return_kwds)
775+
expected = numpy.unique(a, **axis_kwd, **return_kwds)
776+
if len(return_kwds) == 0:
777+
assert_array_equal(result, expected)
778+
else:
779+
for iv, v in zip(result, expected):
780+
assert_array_equal(iv, v)

tests/test_sycl_queue.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2393,6 +2393,11 @@ def test_astype(device_x, device_y):
23932393

23942394

23952395
@pytest.mark.parametrize("axis", [None, 0, -1])
2396+
@pytest.mark.parametrize(
2397+
"device",
2398+
valid_devices,
2399+
ids=[device.filter_string for device in valid_devices],
2400+
)
23962401
def test_unique(axis, device):
23972402
a = numpy.array([[1, 1], [2, 3]])
23982403
ia = dpnp.array(a, device=device)

tests/third_party/cupy/manipulation_tests/test_add_remove.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -300,7 +300,7 @@ def test_unique_equal_nan(self, xp, dtype, equal_nan):
300300
[[2, xp.nan, 2], [xp.nan, 1, xp.nan], [xp.nan, 1, xp.nan]],
301301
dtype=dtype,
302302
)
303-
return xp.unique(a, axis=0, equal_nan=equal_nan)
303+
return xp.unique(a, axis=1, equal_nan=equal_nan)
304304

305305

306306
@testing.parameterize(*testing.product({"trim": ["fb", "f", "b"]}))

0 commit comments

Comments
 (0)