From ba70bb1040da21cb9a203f2acbf35d41d1d5b7fe Mon Sep 17 00:00:00 2001 From: Alex Dragan <35031007+aldragan0@users.noreply.github.com> Date: Mon, 13 Jul 2020 08:32:41 +0300 Subject: [PATCH] Implement np.intersect1d (#3726) * Implement np.intersect1d * Add jitable helper to function * Fix argsort failing tests * Fix linter errors --- docs/jax.numpy.rst | 1 + jax/numpy/__init__.py | 2 +- jax/numpy/lax_numpy.py | 51 +++++++++++++++++++++++++++++++++++++++++ tests/lax_numpy_test.py | 22 ++++++++++++++++++ 4 files changed, 75 insertions(+), 1 deletion(-) diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index fb662338ce5f..f44be9f7054f 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -153,6 +153,7 @@ Not every function in NumPy is implemented; contributions are welcome! iscomplex isfinite isin + intersect1d isinf isnan isneginf diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index adb920b09aae..9a5fdc6de884 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -37,7 +37,7 @@ fmod, frexp, full, full_like, function, gcd, geomspace, gradient, greater, greater_equal, hamming, hanning, heaviside, histogram, histogram_bin_edges, hsplit, hstack, hypot, identity, iinfo, imag, - indices, inexact, in1d, inf, inner, int16, int32, int64, int8, int_, integer, + indices, inexact, in1d, inf, inner, int16, int32, int64, int8, int_, integer, intersect1d, isclose, iscomplex, iscomplexobj, isfinite, isin, isinf, isnan, isneginf, isposinf, isreal, isrealobj, isscalar, issubdtype, issubsctype, iterable, ix_, kaiser, kron, lcm, ldexp, left_shift, less, less_equal, linspace, diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 02b0d8b77206..660de682345e 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -1243,6 +1243,57 @@ def in1d(ar1, ar2, assume_unique=False, invert=False): else: return (ar1[:, None] == ar2).any(-1) +@partial(jit, static_argnums=2) +def _intersect1d_sorted_mask(ar1, ar2, return_indices=False): + """ + Helper function for intersect1d which is jit-able + """ + ar = concatenate((ar1, ar2)) + + if return_indices: + indices = argsort(ar) + aux = ar[indices] + else: + aux = sort(ar) + + mask = aux[1:] == aux[:-1] + if return_indices: + return aux, mask, indices + else: + return aux, mask + +@_wraps(np.intersect1d) +def intersect1d(ar1, ar2, assume_unique=False, return_indices=False): + + if not assume_unique: + if return_indices: + ar1, ind1 = unique(ar1, return_index=True) + ar2, ind2 = unique(ar2, return_index=True) + else: + ar1 = unique(ar1) + ar2 = unique(ar2) + else: + ar1 = ravel(ar1) + ar2 = ravel(ar2) + + if return_indices: + aux, mask, aux_sort_indices = _intersect1d_sorted_mask(ar1, ar2, return_indices) + else: + aux, mask = _intersect1d_sorted_mask(ar1, ar2, return_indices) + + int1d = aux[:-1][mask] + + if return_indices: + ar1_indices = aux_sort_indices[:-1][mask] + ar2_indices = aux_sort_indices[1:][mask] - ar1.size + if not assume_unique: + ar1_indices = ind1[ar1_indices] + ar2_indices = ind2[ar2_indices] + + return int1d, ar1_indices, ar2_indices + else: + return int1d + @_wraps(np.isin, lax_description=""" In the JAX version, the `assume_unique` argument is not referenced. diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 7b6ca4202062..8246fe43fafc 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -957,6 +957,28 @@ def testIn1d(self, element_shape, test_shape, dtype, invert): self._CompileAndCheck(jnp_fun, args_maker) + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_{}_assume_unique={}_return_indices={}".format( + jtu.format_shape_dtype_string(shape1, dtype1), + jtu.format_shape_dtype_string(shape2, dtype2), + assume_unique, + return_indices), + "shape1": shape1, "dtype1": dtype1, "shape2": shape2, "dtype2": dtype2, + "assume_unique": assume_unique, "return_indices": return_indices} + for dtype1 in [s for s in default_dtypes if s != jnp.bfloat16] + for dtype2 in [s for s in default_dtypes if s != jnp.bfloat16] + for shape1 in all_shapes + for shape2 in all_shapes + for assume_unique in [False, True] + for return_indices in [False, True])) + def testIntersect1d(self, shape1, dtype1, shape2, dtype2, assume_unique, return_indices): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + jnp_fun = lambda ar1, ar2: jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) + np_fun = lambda ar1, ar2: np.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_{}".format( jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),