Floating point deviation in jax.numpy.percentile with linear interpolation between v0.2.20 and v0.2.21 #8513
Comments
I think the operative change here is that

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.config import config  # newer JAX releases: call jax.config.update(...) directly

# Enable 64-bit types so that dtype="float" yields float64 arrays.
config.update("jax_enable_x64", True)

input = [[10, 7, 4], [3, 2, 1]]
numpy_array = np.asarray(input)
jax_array = jnp.asarray(input, dtype="float")

print("jax_version:", jax.__version__)
print("no jit:", jnp.percentile(jax_array, 50))        # eager (op-by-op) evaluation
print("jit:", jax.jit(jnp.percentile)(jax_array, 50))  # XLA-compiled evaluation
# jax_version: 0.2.20
# no jit: 3.5
# jit: 3.499999761581421
```

Why does JIT compiling cause this kind of inaccuracy? As part of compilation, XLA is free to rearrange mathematical operations for efficiency, and sometimes this changes results slightly because of the imprecision inherent in floating point. On further exploration, it looks like passing
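The point about reordering can be illustrated without JAX at all. In IEEE 754 double precision, floating-point addition is not associative, so merely regrouping the same operations (as a compiler such as XLA is allowed to do) can change the last bits of a result. A minimal pure-Python sketch:

```python
# The same three additions, grouped two different ways.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)

print(left)           # 0.6000000000000001
print(right)          # 0.6
print(left == right)  # False: regrouping alone changed the rounded result
```

This only demonstrates the general mechanism; it says nothing about which specific rewrite XLA applied inside `jnp.percentile`.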
Thanks very much for the example and explanation, @jakevdp; this is already quite helpful!
That would be great if possible in the future. 👍
* Add percentile function to the tensor backends
* Add tests for percentile and its interpolation methods
  - JAX requires additional dtype support with the 'linear' interpolation method, cf. jax-ml/jax#8513
  - PyTorch has yet to implement interpolation method options, cf. #1693
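The commit message above adds a percentile function with a 'linear' interpolation option to several tensor backends. A backend-independent reference for what 'linear' should return can help with such tests. The sketch below is a minimal pure-Python version of the rule NumPy documents for its default method (sort the data, then interpolate at fractional index (n - 1) · q / 100). `quantile_linear` and `percentile_linear` are hypothetical names, not part of any library, and the arithmetic is not guaranteed to match NumPy bit for bit.

```python
import math


def quantile_linear(values, p):
    """p-th quantile (0 <= p <= 1) using linear interpolation.

    Sorts the data and interpolates between the two neighbours of the
    fractional index (n - 1) * p, the rule behind NumPy's 'linear' method.
    """
    xs = sorted(float(v) for v in values)
    if not xs:
        raise ValueError("quantile of empty data")
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must lie in [0, 1]")
    h = (len(xs) - 1) * p  # fractional position in the sorted data
    lo = math.floor(h)     # index of the lower neighbour
    hi = math.ceil(h)      # index of the upper neighbour (== lo when h is whole)
    return xs[lo] + (h - lo) * (xs[hi] - xs[lo])


def percentile_linear(values, q):
    """q-th percentile (0 <= q <= 100): percentile(a, q) == quantile(a, q / 100)."""
    return quantile_linear(values, q / 100)


# The 2x3 array used in this thread, flattened (no axis argument).
data = [10, 7, 4, 3, 2, 1]
print(percentile_linear(data, 50))   # 3.5
print(percentile_linear(data, 100))  # 10.0
print(quantile_linear(data, 0.5))    # 3.5
```

For this data, q = 50 falls exactly halfway between 3 and 4, and q = 100 lands exactly on index n - 1, so the sketch returns the maximum 10 with a zero interpolation weight.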
Fix in #8520
Hi. There are some (very minor) deviations in the output of `jax.numpy.percentile` between jax v0.2.20 and v0.2.21 when linear interpolation is used (the default). Interestingly, the deviation occurs only in `jax.numpy.percentile` and not in `jax.numpy.quantile`, as the included example shows (for convenience, this issue also exists as a GitHub Gist).

Minimal failing example
Notes
Comparing the code for v0.2.20

https://github.com/google/jax/blob/a7b61c0e00d1b535df8a30a82edc0074884d5f4c/jax/_src/numpy/lax_numpy.py#L5905-L5912

and v0.2.21

https://github.com/google/jax/blob/dbeb97d394740bfd122a46249c967139c10d3f11/jax/_src/numpy/lax_numpy.py#L6420-L6429
It seems (at first glance, as I haven't dug into this yet) that the only relevant difference is the removal of `asarray(q)` in the `true_divide` call in PR #7747 (though, given the point of that PR, I would not have expected anything to change).

This effect is quite minor and probably insignificant in most cases, but it deviates from the behavior described in the docstring. Perhaps the most obvious example is the extreme case where the requested quantile is 1 (q = 100 for percentile), which should return the maximum element of the array (`10` in the example) but instead returns a floating-point approximation of that element (`9.999998092651367`).

Request
Would it be possible to revert to the v0.2.20 behavior? This would be more consistent with both the docstring and NumPy.
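The hypothesis in the notes (that dropping `asarray(q)` changes the precision of the `q / 100` intermediate) can be illustrated, though not confirmed, with a pure-Python sketch that rounds the fraction q / 100 to float32 before interpolating. This is not JAX's code path: `to_float32`, `percentile_linear_sketch`, and the `fraction_in_float32` flag are hypothetical helpers for this illustration only.

```python
import math
import struct


def to_float32(x):
    """Round a Python float (IEEE 754 double) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def percentile_linear_sketch(values, q, fraction_in_float32=False):
    """Linear-interpolation percentile (hypothetical helper for illustration).

    If fraction_in_float32 is True, the intermediate q / 100 is rounded to
    float32 before use, simulating a lower-precision intermediate.
    """
    xs = sorted(float(v) for v in values)
    frac = q / 100
    if fraction_in_float32:
        frac = to_float32(frac)  # e.g. 0.1 becomes 0.100000001490116...
    h = (len(xs) - 1) * frac     # fractional position in the sorted data
    lo = math.floor(h)
    hi = math.ceil(h)
    return xs[lo] + (h - lo) * (xs[hi] - xs[lo])


data = [10, 7, 4, 3, 2, 1]
exact = percentile_linear_sketch(data, 10)
approx = percentile_linear_sketch(data, 10, fraction_in_float32=True)
print(exact)           # 1.5
print(approx > exact)  # True: float32(0.1) is slightly larger than 0.1
```

With q = 50 or q = 100 this sketch still returns the exact answers, because 0.5 and 1.0 are exactly representable in float32, so it does not reproduce the specific values reported in this issue. It only shows that the precision of one intermediate is enough to move a linear-interpolation result by a tiny, float32-rounding-sized amount.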