Skip to content

Commit

Permalink
Add searchsorted (#647)
Browse files Browse the repository at this point in the history
* Add `searchsorted`

* Use `where` rather than in-place modification for JAX support
  • Loading branch information
tomwhite authored Dec 18, 2024
1 parent 5993f59 commit fc2201e
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 5 deletions.
4 changes: 2 additions & 2 deletions cubed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,9 @@
"unstack",
]

from .array_api.searching_functions import argmax, argmin, where
from .array_api.searching_functions import argmax, argmin, searchsorted, where

__all__ += ["argmax", "argmin", "where"]
__all__ += ["argmax", "argmin", "searchsorted", "where"]

from .array_api.statistical_functions import max, mean, min, prod, std, sum, var

Expand Down
4 changes: 2 additions & 2 deletions cubed/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,9 @@
"unstack",
]

from .searching_functions import argmax, argmin, where
from .searching_functions import argmax, argmin, searchsorted, where

__all__ += ["argmax", "argmin", "where"]
__all__ += ["argmax", "argmin", "searchsorted", "where"]

from .statistical_functions import max, mean, min, prod, std, sum, var

Expand Down
52 changes: 51 additions & 1 deletion cubed/array_api/searching_functions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from cubed.array_api.creation_functions import asarray, zeros_like
from cubed.array_api.data_type_functions import result_type
from cubed.array_api.dtypes import _real_numeric_dtypes
from cubed.array_api.manipulation_functions import reshape
from cubed.array_api.statistical_functions import max
from cubed.backend_array_api import namespace as nxp
from cubed.core.ops import arg_reduction, elemwise
from cubed.core.ops import arg_reduction, blockwise, elemwise


def argmax(x, /, *, axis=None, keepdims=False, split_every=None):
Expand Down Expand Up @@ -37,6 +39,54 @@ def argmin(x, /, *, axis=None, keepdims=False, split_every=None):
)


def searchsorted(x1, x2, /, *, side="left", sorter=None):
    """Find the indices in ``x1`` at which the values of ``x2`` would be inserted.

    ``nxp.searchsorted`` is evaluated for every pair of blocks drawn from
    ``x1`` and ``x2``. Each per-block result is shifted by the offset of its
    ``x1`` block, and the maximum over all ``x1`` blocks is the final index.

    Parameters
    ----------
    x1 : array
        One-dimensional array to search.
        NOTE(review): presumably must already be sorted in ascending order,
        since ``sorter`` is rejected and nothing here sorts it -- confirm.
    x2 : array
        Values to look up in ``x1``.
    side : {"left", "right"}
        Passed through to ``nxp.searchsorted`` for each block pair.
    sorter : None
        Not supported; must be left as ``None``.

    Returns
    -------
    array
        ``int64`` insertion indices, one per element of ``x2``.

    Raises
    ------
    ValueError
        If ``x1`` is not one dimensional, or ``side`` is not ``"left"`` or
        ``"right"``.
    NotImplementedError
        If ``sorter`` is given.
    """
    if x1.ndim != 1:
        raise ValueError("Input array x1 must be one dimensional")

    if sorter is not None:
        raise NotImplementedError(
            "searchsorted with a sorter argument is not supported"
        )

    # validate eagerly so a bad value fails here rather than later inside a
    # block task when the computation is executed
    if side not in ("left", "right"):
        raise ValueError(f"side must be 'left' or 'right', got {side!r}")

    # call nxp.searchsorted for each pair of blocks in x1 and x2
    out = blockwise(
        _searchsorted,
        list(range(x2.ndim + 1)),
        x1,
        [0],
        x2,
        list(range(1, x2.ndim + 1)),
        dtype=nxp.int64,  # TODO: index dtype
        adjust_chunks={0: 1},  # one row for each block in x1
        side=side,
    )

    # add offsets to take account of the position of each block within the array x1
    x1_chunk_sizes = nxp.asarray((0, *x1.chunks[0]))
    x1_chunk_offsets = nxp.cumulative_sum(x1_chunk_sizes)[:-1]
    # append one trailing axis per dimension of x2 so the offsets broadcast
    # against the per-block results
    x1_chunk_offsets = x1_chunk_offsets[(Ellipsis,) + x2.ndim * (nxp.newaxis,)]
    x1_offsets = asarray(x1_chunk_offsets, chunks=1)
    # negative entries are the -1 sentinel written by _searchsorted; leave them
    # unshifted so they can never win the max below
    out = where(out < 0, out, out + x1_offsets)

    # combine the results from each block (of x1)
    out = max(out, axis=0)

    # fix up any -1 values
    # TODO: use general_blockwise which has block_id to avoid this
    out = where(out >= 0, out, zeros_like(out))

    return out


def _searchsorted(x, y, side):
    """Run ``nxp.searchsorted`` on one pair of blocks and add a leading axis.

    Parameters
    ----------
    x : array
        A block of the array being searched.
    y : array
        A block of the values to look up.
    side : str
        Forwarded to ``nxp.searchsorted``.

    Returns
    -------
    array
        The per-block insertion indices with a new leading axis of length 1,
        where every index of 0 has been replaced by the sentinel -1.
    """
    res = nxp.searchsorted(x, y, side=side)
    # 0 is only correct for the first block of x1, but blockwise doesn't have a
    # way of telling which block is being operated on (unlike map_blocks),
    # so set all 0 values to a special value and set back at the end of searchsorted
    res = nxp.where(res == 0, -1, res)
    # use Ellipsis rather than ``:`` so a zero-dimensional result (from a
    # zero-dimensional ``y``) can also be given the leading axis
    return res[nxp.newaxis, ...]


def where(condition, x1, x2, /):
    """Apply ``nxp.where`` elementwise to ``condition``, ``x1`` and ``x2``.

    The output dtype is ``result_type(x1, x2)``.
    """
    return elemwise(
        nxp.where,
        condition,
        x1,
        x2,
        dtype=result_type(x1, x2),
    )
31 changes: 31 additions & 0 deletions cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,37 @@ def test_argmin_axis_0(spec):
)


@pytest.mark.parametrize(
    "x1, x1_chunks, x2, x2_chunks",
    [
        ([], 1, [], 1),
        ([0], 1, [0], 1),
        ([-10, 0, 10, 20, 30], 3, [11, 30], 2),
        ([-10, 0, 10, 20, 30], 3, [11, 30, -20, 1, -10, 10, 37, 11], 5),
        ([-10, 0, 10, 20, 30], 3, [[11, 30, -20, 1, -10, 10, 37, 11]], 5),
        ([-10, 0, 10, 20, 30], 3, [[7, 0], [-10, 10], [11, -1], [15, 15]], (2, 2)),
    ],
)
@pytest.mark.parametrize("side", ["left", "right"])
def test_searchsorted(x1, x1_chunks, x2, x2_chunks, side):
    """Check ``xp.searchsorted`` against ``np.searchsorted`` on chunked inputs.

    The result must have the same shape and chunks as the chunked ``x2`` and
    the same values as NumPy computes on the unchunked data.
    """
    haystack = np.array(x1)
    needles = np.array(x2)
    # reference answer from NumPy on the unchunked data
    expected = np.searchsorted(haystack, needles, side=side)

    chunked_haystack = xp.asarray(haystack, chunks=x1_chunks)
    chunked_needles = xp.asarray(needles, chunks=x2_chunks)
    result = xp.searchsorted(chunked_haystack, chunked_needles, side=side)

    assert result.shape == chunked_needles.shape
    assert result.chunks == chunked_needles.chunks
    assert_array_equal(result.compute(), expected)


def test_searchsorted_sorter_not_implemented():
    """Passing a ``sorter`` to ``xp.searchsorted`` raises NotImplementedError."""
    unsorted = xp.asarray([1, 0])
    values = xp.asarray([1])
    order = xp.asarray([1, 0])
    with pytest.raises(NotImplementedError):
        xp.searchsorted(unsorted, values, sorter=order)


# Statistical functions


Expand Down

0 comments on commit fc2201e

Please sign in to comment.