Skip to content

Commit

Permalink
Add implementation of np.searchsorted (jax-ml#2938)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp authored and Jamie Townsend committed May 14, 2020
1 parent d0724e9 commit ed9855c
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ Not every function in NumPy is implemented; contributions are welcome!
rot90
round
row_stack
searchsorted
select
sign
signbit
Expand Down
35 changes: 35 additions & 0 deletions jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import collections
from collections.abc import Sequence
import itertools
import operator
import os
import re
import string
Expand All @@ -40,6 +41,7 @@
import opt_einsum

from jax import jit, device_put, custom_jvp
from .vectorize import vectorize
from ._util import _wraps
from .. import core
from .. import dtypes
Expand Down Expand Up @@ -3753,6 +3755,39 @@ def _quantile(a, q, axis, interpolation, keepdims):
return lax.convert_element_type(result, a.dtype)


@partial(jit, static_argnums=2)
@partial(vectorize, excluded={0, 2})
def _searchsorted(a, v, side):
op = operator.le if side == 'left' else operator.lt

def cond_fun(state):
start, stop = state
return stop - start > 1

def body_fun(state):
start, stop = state
mid = (start + stop) // 2
go_left = op(v, a[mid])
return where(go_left, start, mid), where(go_left, mid, stop)

result = lax.while_loop(cond_fun, body_fun, (0, a.shape[0]))
return where(op(v, a[0]), 0, result[1])


@_wraps(onp.searchsorted)
def searchsorted(a, v, side='left', sorter=None):
assert side in ['left', 'right']
if sorter is not None:
raise NotImplementedError("sorter is not implemented")
a = asarray(a)
v = asarray(v)
if ndim(a) != 1:
raise ValueError("a should be 1-dimensional")
if size(a) == 0:
return zeros_like(v, dtype=int)
return _searchsorted(a, v, side)


@_wraps(onp.percentile)
def percentile(a, q, axis=None, out=None, overwrite_input=False,
interpolation="linear", keepdims=False):
Expand Down
20 changes: 20 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1411,6 +1411,26 @@ def onp_fun(arg):
self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_a={}_v={}_side={}".format(
jtu.format_shape_dtype_string(ashape, dtype),
jtu.format_shape_dtype_string(vshape, dtype),
side), "ashape": ashape, "vshape": vshape, "side": side,
"dtype": dtype, "rng_factory": rng_factory}
for ashape in [(20,)]
for vshape in [(), (5,), (5, 5)]
for side in ['left', 'right']
for dtype in default_dtypes
for rng_factory in [jtu.rand_default]
))
def testSearchsorted(self, ashape, vshape, side, dtype, rng_factory):
rng = rng_factory(self.rng())
args_maker = lambda: [jnp.sort(rng(ashape, dtype)), rng(vshape, dtype)]
onp_fun = lambda a, v: onp.searchsorted(a, v, side=side)
jnp_fun = lambda a, v: jnp.searchsorted(a, v, side=side)
self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}".format(
jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), axis),
Expand Down

0 comments on commit ed9855c

Please sign in to comment.