Hi @yyang97, yes, you can `vmap` over the levels, e.g.:

```python
import jax
import jax.numpy as jnp
from ott.tools import soft_sort

key = jax.random.PRNGKey(0)
x_test = jax.random.normal(key, shape=(100,))
levels = jnp.array([0.2, 0.8])

# Map soft_sort.quantile over the levels, passing one scalar level per call.
jax.vmap(lambda level: soft_sort.quantile(x_test, level=level))(levels)
# Array([[-0.7245259],
#        [ 0.93384  ]], dtype=float32)
```
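To illustrate how `vmap` turns a scalar-level quantile into one result per level, here is a minimal sketch that runs without ott installed. It assumes `jnp.quantile` (the exact, non-differentiable quantile) as a stand-in for `soft_sort.quantile`; the helper name `quantile_at` is hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a quantile function that takes a single scalar
# level (like soft_sort.quantile above). jnp.quantile is the exact quantile.
def quantile_at(x, level):
    return jnp.quantile(x, level)

key = jax.random.PRNGKey(0)
x_test = jax.random.normal(key, shape=(100,))
levels = jnp.array([0.2, 0.8])

# vmap maps over the leading axis of `levels` while closing over x_test,
# so every level is evaluated in a single batched call.
batched = jax.vmap(lambda level: quantile_at(x_test, level))(levels)

# jnp.quantile also accepts an array of levels directly; for the exact
# quantile this gives the same values as the vmapped version.
direct = jnp.quantile(x_test, levels)

print(batched.shape)  # (2,)
```

Each scalar-level call returns a scalar here, so the batched result has shape `(2,)`. The soft quantile in the answer returns a `(1,)`-shaped array per level, which is why its output has shape `(2, 1)`.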
Another approach would be to have … If you need this for an application, could you raise an issue? I will try to implement this shortly.
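Until multiple levels are supported natively, a user-side wrapper can accept either a scalar or an array of levels. The sketch below is hypothetical: `with_vector_levels` is not part of ott, and `jnp.quantile` again stands in for the scalar-level quantile so the example runs without ott.

```python
import jax
import jax.numpy as jnp

def with_vector_levels(quantile_fn):
    """Hypothetical helper: lift quantile_fn(x, level), which expects a
    scalar level, so it accepts a scalar or a 1-D array of levels."""
    def wrapped(x, levels):
        levels = jnp.asarray(levels)
        if levels.ndim == 0:
            # Scalar level: call the original function unchanged.
            return quantile_fn(x, levels)
        # 1-D levels: evaluate one quantile per level via vmap.
        return jax.vmap(lambda level: quantile_fn(x, level))(levels)
    return wrapped

# Stand-in scalar-level quantile. With ott installed, one could instead pass
# lambda x, level: soft_sort.quantile(x, level=level), as in the answer.
quantiles = with_vector_levels(lambda x, level: jnp.quantile(x, level))

x = jnp.arange(11.0)  # 0.0, 1.0, ..., 10.0
print(quantiles(x, 0.5))
print(quantiles(x, [0.2, 0.8]))
```

The branch on `levels.ndim` is resolved from the static shape, so the wrapper still works under `jax.jit` as long as the number of dimensions of `levels` does not change between calls.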
Computing the 0.2 quantile of x_test works well.
However, how can I compute the 0.2 and 0.8 quantiles of x_test at the same time?
I tried something like … However, it raises an error:
ValueError: All input arrays must have the same shape.