jnp.unique: make return_inverse shape match NumPy 2.0 #19320

Merged (1 commit) · Jan 16, 2024
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -45,6 +45,9 @@ Remember to align the itemized text with the first line of an item within a list
`from jax.experimental import export`. The old way of importing will
continue to work for a deprecation period of 3 months.
* Added {func}`jax.scipy.stats.sem`.
* {func}`jax.numpy.unique` with `return_inverse = True` returns inverse indices
reshaped to the dimension of the input, following a similar change to
{func}`numpy.unique` in NumPy 2.0.

* Deprecations & Removals
* A number of previously deprecated functions have been removed, following a
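Not part of the diff: a minimal sketch of the behavior this changelog entry describes, with arbitrary example values, assuming the change in this PR is applied.

import jax.numpy as jnp

x = jnp.array([[1, 2], [2, 3]])
uniques, inv = jnp.unique(x, return_inverse=True)
print(inv.shape)                   # (2, 2) after this change; previously (4,)
assert (uniques[inv] == x).all()   # inverse indices now index directly back to x
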
21 changes: 15 additions & 6 deletions jax/_src/numpy/setops.py
@@ -35,6 +35,7 @@
from jax._src.numpy.reductions import any, cumsum
from jax._src.numpy.ufuncs import isnan
from jax._src.numpy.util import check_arraylike, _wraps
from jax._src.util import canonicalize_axis
from jax._src.typing import Array, ArrayLike


@@ -256,6 +257,8 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
"""
Find the unique elements of an array along a particular axis.
"""
axis = canonicalize_axis(axis, ar.ndim)

if ar.shape[axis] == 0 and size and fill_value is None:
raise ValueError(
"jnp.unique: for zero-sized input with nonzero size argument, fill_value must be specified")
@@ -289,6 +292,8 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
inv_idx = inv_idx.at[perm].set(imask)
else:
inv_idx = zeros(ar.shape[axis], dtype=int)
if ar.ndim > 1:
inv_idx = lax.expand_dims(inv_idx, [i for i in range(ar.ndim) if i != axis],)
ret += (inv_idx,)
if return_counts:
if aux.size:
@@ -332,12 +337,18 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
size = core.concrete_or_error(operator.index, size,
"The error arose for the size argument of jnp.unique(). " + UNIQUE_SIZE_HINT)
arr = asarray(ar)
arr_shape = arr.shape
if axis is None:
axis = 0
axis_int: int = 0
arr = arr.flatten()
axis_int: int = core.concrete_or_error(operator.index, axis, "axis argument of jnp.unique()")
return _unique(arr, axis_int, return_index, return_inverse,
return_counts, equal_nan=equal_nan, size=size, fill_value=fill_value)
else:
axis_int = canonicalize_axis(axis, arr.ndim)
result = _unique(arr, axis_int, return_index, return_inverse, return_counts,
equal_nan=equal_nan, size=size, fill_value=fill_value)
if return_inverse and axis is None:
idx = 2 if return_index else 1
result = (*result[:idx], result[idx].reshape(arr_shape), *result[idx + 1:])
return result


class _UniqueAllResult(NamedTuple):
@@ -362,7 +373,6 @@ def unique_all(x: ArrayLike, /) -> _UniqueAllResult:
check_arraylike("unique_all", x)
values, indices, inverse_indices, counts = unique(
x, return_index=True, return_inverse=True, return_counts=True, equal_nan=False)
inverse_indices = inverse_indices.reshape(np.shape(x))
return _UniqueAllResult(values=values, indices=indices, inverse_indices=inverse_indices, counts=counts)


@@ -377,7 +387,6 @@ def unique_counts(x: ArrayLike, /) -> _UniqueCountsResult:
def unique_inverse(x: ArrayLike, /) -> _UniqueInverseResult:
check_arraylike("unique_inverse", x)
values, inverse_indices = unique(x, return_inverse=True, equal_nan=False)
inverse_indices = inverse_indices.reshape(np.shape(x))
return _UniqueInverseResult(values=values, inverse_indices=inverse_indices)


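Not part of the diff: a rough sketch of the axis case handled by the new expand_dims call in _unique above. With an explicit axis, the inverse indices now keep the input's ndim, with singleton dimensions on every axis other than `axis` (example values are arbitrary).

import jax.numpy as jnp

x = jnp.array([[1, 2], [1, 2], [3, 4]])
u, inv = jnp.unique(x, return_inverse=True, axis=0)
print(u.shape, inv.shape)           # (2, 2) (3, 1) after this change; inv was (3,) before
assert (u[inv.ravel()] == x).all()  # unique rows plus inverse indices rebuild the input
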
3 changes: 3 additions & 0 deletions jax/experimental/sparse/bcoo.py
@@ -1453,6 +1453,9 @@ def _unique_indices_unbatched(indices, *, shape, return_inverse=False,
# TODO: check if `indices_sorted` is True.
out = _unique(indices, axis=0, return_inverse=return_inverse, return_index=return_index,
return_true_size=return_true_size, size=props.nse, fill_value=fill_value)
if return_inverse:
idx = 2 if return_index else 1
out = (*out[:idx], out[idx].ravel(), *out[idx + 1:])
if return_true_size:
nse = out[-1]
nse = nse - (indices == fill_value).any().astype(nse.dtype)
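Not part of the diff: why the ravel above is needed. The sparse BCOO code calls _unique with axis=0, and after this change the inverse indices for that case come back with a trailing singleton dimension; flattening restores the 1-D shape the rest of the BCOO pipeline expects. A rough sketch using the public API:

import jax.numpy as jnp

indices = jnp.array([[0, 1], [0, 1], [2, 0]])   # duplicated COO index rows
_, inv = jnp.unique(indices, axis=0, return_inverse=True)
print(inv.shape)          # (3, 1) after this change; (3,) before
inv = inv.ravel()         # back to the 1-D index array downstream code expects
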
30 changes: 21 additions & 9 deletions tests/lax_numpy_test.py
@@ -87,6 +87,23 @@
# uint64 is problematic because with any uint type it promotes to float:
int_dtypes_no_uint64 = [d for d in int_dtypes + unsigned_dtypes if d != np.uint64]

def np_unique_backport(ar, return_index=False, return_inverse=False, return_counts=False,
axis=None, **kwds):
# Wrapper for np.unique, handling the change to inverse_indices in numpy 2.0
result = np.unique(ar, return_index=return_index, return_inverse=return_inverse,
return_counts=return_counts, axis=axis, **kwds)
if jtu.numpy_version() >= (2, 0, 0) or np.ndim(ar) == 1 or not return_inverse:
return result

idx = 2 if return_index else 1
inverse_indices = result[idx]
if axis is None:
inverse_indices = inverse_indices.reshape(np.shape(ar))
else:
inverse_indices = np.expand_dims(inverse_indices, [i for i in range(np.ndim(ar)) if i != axis])
return (*result[:idx], inverse_indices, *result[idx + 1:])


def _indexer_with_default_outputs(indexer, use_defaults=True):
"""Like jtu.with_jax_dtype_defaults, but for __getitem__ APIs"""
class Indexer:
@@ -1818,7 +1835,7 @@ def testUnique(self, shape, dtype, axis, return_index, return_inverse, return_co
args_maker = lambda: [rng(shape, dtype)]
extra_args = (return_index, return_inverse, return_counts)
use_defaults = (False, *(True for arg in extra_args if arg)) if any(extra_args) else False
np_fun = jtu.with_jax_dtype_defaults(lambda x: np.unique(x, *extra_args, axis=axis), use_defaults)
np_fun = jtu.with_jax_dtype_defaults(lambda x: np_unique_backport(x, *extra_args, axis=axis), use_defaults)
jnp_fun = lambda x: jnp.unique(x, *extra_args, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)

Expand All @@ -1827,10 +1844,7 @@ def testUniqueAll(self, shape, dtype):
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
if jtu.numpy_version() < (2, 0, 0):
def np_fun(x):
values, indices, inverse_indices, counts = np.unique(
x, return_index=True, return_inverse=True, return_counts=True)
return values, indices, inverse_indices.reshape(np.shape(x)), counts
np_fun = partial(np_unique_backport, return_index=True, return_inverse=True, return_counts=True)
else:
np_fun = np.unique_all
self._CheckAgainstNumpy(jnp.unique_all, np_fun, args_maker)
@@ -1850,9 +1864,7 @@ def testUniqueInverse(self, shape, dtype):
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
if jtu.numpy_version() < (2, 0, 0):
def np_fun(x):
values, inverse_indices = np.unique(x, return_inverse=True)
return values, inverse_indices.reshape(np.shape(x))
np_fun = partial(np_unique_backport, return_inverse=True)
else:
np_fun = np.unique_inverse
self._CheckAgainstNumpy(jnp.unique_inverse, np_fun, args_maker)
@@ -1888,7 +1900,7 @@ def testUniqueSize(self, shape, dtype, axis, size, fill_value):

@partial(jtu.with_jax_dtype_defaults, use_defaults=(False, True, True, True))
def np_fun(x, fill_value=fill_value):
u, ind, inv, counts = np.unique(x, **kwds)
u, ind, inv, counts = np_unique_backport(x, **kwds)
axis = kwds['axis']
if axis is None:
x = x.ravel()
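Not part of the diff: a usage sketch of the np_unique_backport test helper above, assuming it is in scope. On NumPy < 2.0 it reshapes (or expands) the inverse indices so the reference matches the new jnp.unique behavior; on NumPy >= 2.0 it is a pass-through.

import numpy as np

x = np.array([[1, 2], [2, 3]])
values, inv = np_unique_backport(x, return_inverse=True)
assert inv.shape == x.shape                    # same shape as jnp.unique now returns
np.testing.assert_array_equal(values[inv], x)  # inverse indices reconstruct the input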