From ed9855c97dbb2d704aa5b17a83b19f7b421ee9bd Mon Sep 17 00:00:00 2001 From: Jake Vanderplas Date: Wed, 6 May 2020 16:58:09 -0700 Subject: [PATCH] Add implementation of np.searchsorted (#2938) --- docs/jax.numpy.rst | 1 + jax/numpy/lax_numpy.py | 35 +++++++++++++++++++++++++++++++++++ tests/lax_numpy_test.py | 20 ++++++++++++++++++++ 3 files changed, 56 insertions(+) diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index 60fa05b26397..4609f085cef0 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -223,6 +223,7 @@ Not every function in NumPy is implemented; contributions are welcome! rot90 round row_stack + searchsorted select sign signbit diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 8d208706a0d5..946bb26cb373 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -29,6 +29,7 @@ import collections from collections.abc import Sequence import itertools +import operator import os import re import string @@ -40,6 +41,7 @@ import opt_einsum from jax import jit, device_put, custom_jvp +from .vectorize import vectorize from ._util import _wraps from .. import core from .. import dtypes @@ -3753,6 +3755,39 @@ def _quantile(a, q, axis, interpolation, keepdims): return lax.convert_element_type(result, a.dtype) +@partial(jit, static_argnums=2) +@partial(vectorize, excluded={0, 2}) +def _searchsorted(a, v, side): + op = operator.le if side == 'left' else operator.lt + + def cond_fun(state): + start, stop = state + return stop - start > 1 + + def body_fun(state): + start, stop = state + mid = (start + stop) // 2 + go_left = op(v, a[mid]) + return where(go_left, start, mid), where(go_left, mid, stop) + + result = lax.while_loop(cond_fun, body_fun, (0, a.shape[0])) + return where(op(v, a[0]), 0, result[1]) + + +@_wraps(onp.searchsorted) +def searchsorted(a, v, side='left', sorter=None): + assert side in ['left', 'right'] + if sorter is not None: + raise NotImplementedError("sorter is not implemented") + a = asarray(a) + v = asarray(v) + if ndim(a) != 1: + raise ValueError("a should be 1-dimensional") + if size(a) == 0: + return zeros_like(v, dtype=int) + return _searchsorted(a, v, side) + + @_wraps(onp.percentile) def percentile(a, q, axis=None, out=None, overwrite_input=False, interpolation="linear", keepdims=False): diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 551b32319f10..7eb2a0c506fe 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1411,6 +1411,26 @@ def onp_fun(arg): self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_a={}_v={}_side={}".format( + jtu.format_shape_dtype_string(ashape, dtype), + jtu.format_shape_dtype_string(vshape, dtype), + side), "ashape": ashape, "vshape": vshape, "side": side, + "dtype": dtype, "rng_factory": rng_factory} + for ashape in [(20,)] + for vshape in [(), (5,), (5, 5)] + for side in ['left', 'right'] + for dtype in default_dtypes + for rng_factory in [jtu.rand_default] + )) + def testSearchsorted(self, ashape, vshape, side, dtype, rng_factory): + rng = rng_factory(self.rng()) + args_maker = lambda: [jnp.sort(rng(ashape, dtype)), rng(vshape, dtype)] + onp_fun = lambda a, v: onp.searchsorted(a, v, side=side) + jnp_fun = lambda a, v: jnp.searchsorted(a, v, side=side) + self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True) + self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_axis={}".format( jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), axis),