Skip to content

Commit 287b834

Browse files
kgrytemdhaber
andauthored
fix: address incorrect boundary conditions in searchsorted
PR-URL: #898 Closes: #861 Reviewed-by: Matt Haberland <mhaberla@calpoly.edu> Co-authored-by: Matt Haberland <mhaberla@calpoly.edu>
1 parent 0a425d1 commit 287b834

File tree

2 files changed

+12
-18
lines changed

2 files changed

+12
-18
lines changed

src/array_api_stubs/_2023_12/searching_functions.py

+6-9
Original file line numberDiff line numberDiff line change
@@ -107,17 +107,14 @@ def searchsorted(
107107
side: Literal['left', 'right']
108108
argument controlling which index is returned if a value lands exactly on an edge.
109109
110-
Let ``x`` be an array of rank ``N`` where ``v`` is an individual element given by ``v = x2[n,m,...,j]``.
110+
Let ``v`` be an element of ``x2`` given by ``v = x2[j]``, where ``j`` refers to a valid index (see :ref:`indexing`).
111111
112-
If ``side == 'left'``, then
112+
- If ``v`` is less than all elements in ``x1``, then ``out[j]`` must be ``0``.
113+
- If ``v`` is greater than all elements in ``x1``, then ``out[j]`` must be ``M``, where ``M`` is the number of elements in ``x1``.
114+
- Otherwise, each returned index ``i = out[j]`` must satisfy an index condition:
113115
114-
- each returned index ``i`` must satisfy the index condition ``x1[i-1] < v <= x1[i]``.
115-
- if no index satisfies the index condition, then the returned index for that element must be ``0``.
116-
117-
Otherwise, if ``side == 'right'``, then
118-
119-
- each returned index ``i`` must satisfy the index condition ``x1[i-1] <= v < x1[i]``.
120-
- if no index satisfies the index condition, then the returned index for that element must be ``N``, where ``N`` is the number of elements in ``x1``.
116+
- If ``side == 'left'``, then ``x1[i-1] < v <= x1[i]``.
117+
- If ``side == 'right'``, then ``x1[i-1] <= v < x1[i]``.
121118
122119
Default: ``'left'``.
123120
sorter: Optional[array]

src/array_api_stubs/_draft/searching_functions.py

+6-9
Original file line numberDiff line numberDiff line change
@@ -136,17 +136,14 @@ def searchsorted(
136136
side: Literal['left', 'right']
137137
argument controlling which index is returned if a value lands exactly on an edge.
138138
139-
Let ``x`` be an array of rank ``N`` where ``v`` is an individual element given by ``v = x2[n,m,...,j]``.
139+
Let ``v`` be an element of ``x2`` given by ``v = x2[j]``, where ``j`` refers to a valid index (see :ref:`indexing`).
140140
141-
If ``side == 'left'``, then
141+
- If ``v`` is less than all elements in ``x1``, then ``out[j]`` must be ``0``.
142+
- If ``v`` is greater than all elements in ``x1``, then ``out[j]`` must be ``M``, where ``M`` is the number of elements in ``x1``.
143+
- Otherwise, each returned index ``i = out[j]`` must satisfy an index condition:
142144
143-
- each returned index ``i`` must satisfy the index condition ``x1[i-1] < v <= x1[i]``.
144-
- if no index satisfies the index condition, then the returned index for that element must be ``0``.
145-
146-
Otherwise, if ``side == 'right'``, then
147-
148-
- each returned index ``i`` must satisfy the index condition ``x1[i-1] <= v < x1[i]``.
149-
- if no index satisfies the index condition, then the returned index for that element must be ``N``, where ``N`` is the number of elements in ``x1``.
145+
- If ``side == 'left'``, then ``x1[i-1] < v <= x1[i]``.
146+
- If ``side == 'right'``, then ``x1[i-1] <= v < x1[i]``.
150147
151148
Default: ``'left'``.
152149
sorter: Optional[array]

0 commit comments

Comments
 (0)