From fa6d3f26ff905e9ce7e746131a8c2e6336e0fd51 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 11 Jan 2024 12:15:52 -0800 Subject: [PATCH] jnp.unique: make return_inverse shape match NumPy 2.0 --- CHANGELOG.md | 3 +++ jax/_src/numpy/setops.py | 21 +++++++++++++++------ jax/experimental/sparse/bcoo.py | 3 +++ tests/lax_numpy_test.py | 30 +++++++++++++++++++++--------- 4 files changed, 42 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 71303d206b03..9ec2c140c99e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,6 +45,9 @@ Remember to align the itemized text with the first line of an item within a list `from jax.experimental import export`. The old way of importing will continue to work for a deprecation period of 3 months. * Added {func}`jax.scipy.stats.sem`. + * {func}`jax.numpy.unique` with `return_inverse = True` returns inverse indices + reshaped to the dimension of the input, following a similar change to + {func}`numpy.unique` in NumPy 2.0. * Deprecations & Removals * A number of previously deprecated functions have been removed, following a diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index 32d0ac06e59c..4caf608396b2 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -35,6 +35,7 @@ from jax._src.numpy.reductions import any, cumsum from jax._src.numpy.ufuncs import isnan from jax._src.numpy.util import check_arraylike, _wraps +from jax._src.util import canonicalize_axis from jax._src.typing import Array, ArrayLike @@ -256,6 +257,8 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo """ Find the unique elements of an array along a particular axis. """ + axis = canonicalize_axis(axis, ar.ndim) + if ar.shape[axis] == 0 and size and fill_value is None: raise ValueError( "jnp.unique: for zero-sized input with nonzero size argument, fill_value must be specified") @@ -289,6 +292,8 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo inv_idx = inv_idx.at[perm].set(imask) else: inv_idx = zeros(ar.shape[axis], dtype=int) + if ar.ndim > 1: + inv_idx = lax.expand_dims(inv_idx, [i for i in range(ar.ndim) if i != axis],) ret += (inv_idx,) if return_counts: if aux.size: @@ -332,12 +337,18 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal size = core.concrete_or_error(operator.index, size, "The error arose for the size argument of jnp.unique(). " + UNIQUE_SIZE_HINT) arr = asarray(ar) + arr_shape = arr.shape if axis is None: - axis = 0 + axis_int: int = 0 arr = arr.flatten() - axis_int: int = core.concrete_or_error(operator.index, axis, "axis argument of jnp.unique()") - return _unique(arr, axis_int, return_index, return_inverse, - return_counts, equal_nan=equal_nan, size=size, fill_value=fill_value) + else: + axis_int = canonicalize_axis(axis, arr.ndim) + result = _unique(arr, axis_int, return_index, return_inverse, return_counts, + equal_nan=equal_nan, size=size, fill_value=fill_value) + if return_inverse and axis is None: + idx = 2 if return_index else 1 + result = (*result[:idx], result[idx].reshape(arr_shape), *result[idx + 1:]) + return result class _UniqueAllResult(NamedTuple): @@ -362,7 +373,6 @@ def unique_all(x: ArrayLike, /) -> _UniqueAllResult: check_arraylike("unique_all", x) values, indices, inverse_indices, counts = unique( x, return_index=True, return_inverse=True, return_counts=True, equal_nan=False) - inverse_indices = inverse_indices.reshape(np.shape(x)) return _UniqueAllResult(values=values, indices=indices, inverse_indices=inverse_indices, counts=counts) @@ -377,7 +387,6 @@ def unique_counts(x: ArrayLike, /) -> _UniqueCountsResult: def unique_inverse(x: ArrayLike, /) -> _UniqueInverseResult: check_arraylike("unique_inverse", x) values, inverse_indices = unique(x, return_inverse=True, equal_nan=False) - inverse_indices = inverse_indices.reshape(np.shape(x)) return _UniqueInverseResult(values=values, inverse_indices=inverse_indices) diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index 9f639af0ce23..b6572b01cadf 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -1453,6 +1453,9 @@ def _unique_indices_unbatched(indices, *, shape, return_inverse=False, # TODO: check if `indices_sorted` is True. out = _unique(indices, axis=0, return_inverse=return_inverse, return_index=return_index, return_true_size=return_true_size, size=props.nse, fill_value=fill_value) + if return_inverse: + idx = 2 if return_index else 1 + out = (*out[:idx], out[idx].ravel(), *out[idx + 1:]) if return_true_size: nse = out[-1] nse = nse - (indices == fill_value).any().astype(nse.dtype) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 29db81e3d24f..b4303442faee 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -87,6 +87,23 @@ # uint64 is problematic because with any uint type it promotes to float: int_dtypes_no_uint64 = [d for d in int_dtypes + unsigned_dtypes if d != np.uint64] +def np_unique_backport(ar, return_index=False, return_inverse=False, return_counts=False, + axis=None, **kwds): + # Wrapper for np.unique, handling the change to inverse_indices in numpy 2.0 + result = np.unique(ar, return_index=return_index, return_inverse=return_inverse, + return_counts=return_counts, axis=axis, **kwds) + if jtu.numpy_version() >= (2, 0, 0) or np.ndim(ar) == 1 or not return_inverse: + return result + + idx = 2 if return_index else 1 + inverse_indices = result[idx] + if axis is None: + inverse_indices = inverse_indices.reshape(np.shape(ar)) + else: + inverse_indices = np.expand_dims(inverse_indices, [i for i in range(np.ndim(ar)) if i != axis]) + return (*result[:idx], inverse_indices, *result[idx + 1:]) + + def _indexer_with_default_outputs(indexer, use_defaults=True): """Like jtu.with_jax_dtype_defaults, but for __getitem__ APIs""" class Indexer: @@ -1818,7 +1835,7 @@ def testUnique(self, shape, dtype, axis, return_index, return_inverse, return_co args_maker = lambda: [rng(shape, dtype)] extra_args = (return_index, return_inverse, return_counts) use_defaults = (False, *(True for arg in extra_args if arg)) if any(extra_args) else False - np_fun = jtu.with_jax_dtype_defaults(lambda x: np.unique(x, *extra_args, axis=axis), use_defaults) + np_fun = jtu.with_jax_dtype_defaults(lambda x: np_unique_backport(x, *extra_args, axis=axis), use_defaults) jnp_fun = lambda x: jnp.unique(x, *extra_args, axis=axis) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) @@ -1827,10 +1844,7 @@ def testUniqueAll(self, shape, dtype): rng = jtu.rand_some_equal(self.rng()) args_maker = lambda: [rng(shape, dtype)] if jtu.numpy_version() < (2, 0, 0): - def np_fun(x): - values, indices, inverse_indices, counts = np.unique( - x, return_index=True, return_inverse=True, return_counts=True) - return values, indices, inverse_indices.reshape(np.shape(x)), counts + np_fun = partial(np_unique_backport, return_index=True, return_inverse=True, return_counts=True) else: np_fun = np.unique_all self._CheckAgainstNumpy(jnp.unique_all, np_fun, args_maker) @@ -1850,9 +1864,7 @@ def testUniqueInverse(self, shape, dtype): rng = jtu.rand_some_equal(self.rng()) args_maker = lambda: [rng(shape, dtype)] if jtu.numpy_version() < (2, 0, 0): - def np_fun(x): - values, inverse_indices = np.unique(x, return_inverse=True) - return values, inverse_indices.reshape(np.shape(x)) + np_fun = partial(np_unique_backport, return_inverse=True) else: np_fun = np.unique_inverse self._CheckAgainstNumpy(jnp.unique_inverse, np_fun, args_maker) @@ -1888,7 +1900,7 @@ def testUniqueSize(self, shape, dtype, axis, size, fill_value): @partial(jtu.with_jax_dtype_defaults, use_defaults=(False, True, True, True)) def np_fun(x, fill_value=fill_value): - u, ind, inv, counts = np.unique(x, **kwds) + u, ind, inv, counts = np_unique_backport(x, **kwds) axis = kwds['axis'] if axis is None: x = x.ravel()