Skip to content

Commit

Permalink
unique_all and unique_inverse inverse indices shape fixed
Browse files Browse the repository at this point in the history
Were previously returning a 1D array of indices rather than an array with the same shape as input `x`
  • Loading branch information
ndgrigorian committed Dec 20, 2023
1 parent a231d56 commit 60c6ad6
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 23 deletions.
36 changes: 13 additions & 23 deletions dpctl/tensor/_set_functions_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,19 +293,14 @@ def unique_inverse(x):
exec_q = array_api_dev.sycl_queue
x_usm_type = x.usm_type
ind_dt = default_device_index_type(exec_q)
if x.ndim == 0:
return UniqueInverseResult(
dpt.reshape(x, (1,), order="C", copy=True),
dpt.zeros_like(x, ind_dt, usm_type=x_usm_type, sycl_queue=exec_q),
)
elif x.ndim == 1:
if x.ndim == 1:
fx = x
else:
fx = dpt.reshape(x, (x.size,), order="C")
sorting_ids = dpt.empty_like(fx, dtype=ind_dt, order="C")
unsorting_ids = dpt.empty_like(sorting_ids, dtype=ind_dt, order="C")
if fx.size == 0:
return UniqueInverseResult(fx, unsorting_ids)
return UniqueInverseResult(fx, dpt.reshape(unsorting_ids, x.shape))
host_tasks = []
if fx.flags.c_contiguous:
ht_ev, sort_ev = _argsort_ascending(
Expand Down Expand Up @@ -366,7 +361,7 @@ def unique_inverse(x):
)
if n_uniques == fx.size:
dpctl.SyclEvent.wait_for(host_tasks)
return UniqueInverseResult(s, unsorting_ids)
return UniqueInverseResult(s, dpt.reshape(unsorting_ids, x.shape))
unique_vals = dpt.empty(
n_uniques, dtype=x.dtype, usm_type=x_usm_type, sycl_queue=exec_q
)
Expand Down Expand Up @@ -422,7 +417,9 @@ def unique_inverse(x):
pos = pos_next
host_tasks.append(ht_ev)
dpctl.SyclEvent.wait_for(host_tasks)
return UniqueInverseResult(unique_vals, inv[unsorting_ids])
return UniqueInverseResult(
unique_vals, dpt.reshape(inv[unsorting_ids], x.shape)
)


def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
Expand Down Expand Up @@ -462,17 +459,7 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
exec_q = array_api_dev.sycl_queue
x_usm_type = x.usm_type
ind_dt = default_device_index_type(exec_q)
if x.ndim == 0:
uv = dpt.reshape(x, (1,), order="C", copy=True)
return UniqueAllResult(
uv,
dpt.zeros_like(uv, ind_dt, usm_type=x_usm_type, sycl_queue=exec_q),
dpt.zeros_like(x, ind_dt, usm_type=x_usm_type, sycl_queue=exec_q),
dpt.ones_like(
uv, dtype=ind_dt, usm_type=x_usm_type, sycl_queue=exec_q
),
)
elif x.ndim == 1:
if x.ndim == 1:
fx = x
else:
fx = dpt.reshape(x, (x.size,), order="C")
Expand All @@ -482,7 +469,10 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
# original array contains no data
# so it can be safely returned as values
return UniqueAllResult(
fx, sorting_ids, unsorting_ids, dpt.empty_like(fx, dtype=ind_dt)
fx,
sorting_ids,
dpt.reshape(unsorting_ids, x.shape),
dpt.empty_like(fx, dtype=ind_dt),
)
host_tasks = []
if fx.flags.c_contiguous:
Expand Down Expand Up @@ -550,7 +540,7 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
return UniqueAllResult(
s,
sorting_ids,
unsorting_ids,
dpt.reshape(unsorting_ids, x.shape),
_counts,
)
unique_vals = dpt.empty(
Expand Down Expand Up @@ -611,6 +601,6 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
return UniqueAllResult(
unique_vals,
sorting_ids[cum_unique_counts[:-1]],
inv[unsorting_ids],
dpt.reshape(inv[unsorting_ids], x.shape),
_counts,
)
2 changes: 2 additions & 0 deletions dpctl/tests/test_usm_ndarray_unique.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def test_unique_inverse(dtype):
uv, inv = dpt.unique_inverse(inp)
assert dpt.all(uv == dpt.arange(2, dtype=dtype))
assert dpt.all(inp == uv[inv])
assert inp.shape == inv.shape


@pytest.mark.parametrize(
Expand Down Expand Up @@ -151,6 +152,7 @@ def test_unique_all(dtype):
assert dpt.all(uv == dpt.arange(2, dtype=dtype))
assert dpt.all(uv == inp[ind])
assert dpt.all(inp == uv[inv])
assert inp.shape == inv.shape
assert dpt.all(uv_counts == dpt.full(2, n, dtype=uv_counts.dtype))


Expand Down

0 comments on commit 60c6ad6

Please sign in to comment.