Skip to content

Commit

Permalink
Merge pull request #24717 from jakevdp:fix-rankdata
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 693416741
  • Loading branch information
Google-ML-Automation committed Nov 5, 2024
2 parents c1af808 + 5f90f63 commit 939b41f
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions jax/_src/scipy/stats/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,12 @@ def rankdata(
return jnp.apply_along_axis(rankdata, axis, a, method)

arr = jnp.ravel(a)
sorter = jnp.argsort(arr)
arr, sorter = jax.lax.sort_key_val(arr, jnp.arange(len(arr)))
inv = invert_permutation(sorter)

if method == "ordinal":
return inv + 1
arr = arr[sorter]
obs = jnp.insert(arr[1:] != arr[:-1], 0, True)
obs = jnp.concatenate([jnp.array([True]), arr[1:] != arr[:-1]])
dense = obs.cumsum()[inv]
if method == "dense":
return dense
Expand Down

0 comments on commit 939b41f

Please sign in to comment.