Floating point deviation in jax.numpy.percentile with linear interpolation between v0.2.20 and v0.2.21 #8513

Closed
3 tasks done
matthewfeickert opened this issue Nov 11, 2021 · 3 comments · Fixed by #8520

matthewfeickert commented Nov 11, 2021

Hi. There are some (very minor) deviations in the output of jax.numpy.percentile between jax v0.2.20 and v0.2.21 when linear interpolation is used (the default). Interestingly, the deviation appears only in jax.numpy.percentile and not in jax.numpy.quantile, as the included example shows (for convenience this Issue also exists as a GitHub Gist).

Minimal failing example

# example.py
import jax
import jax.numpy as jnp
import numpy as np
from jax.config import config

config.update("jax_enable_x64", True)

if __name__ == "__main__":

    # percentile interpolation options:
    # This optional parameter specifies the interpolation method to use when the desired percentile lies between two data points i < j:
    # * ’linear’: i + (j - i) * fraction, where fraction is the fractional part of the index surrounded by i and j.
    # * ’lower’: i.
    # * ’higher’: j.
    # * ’nearest’: i or j, whichever is nearest.
    # * ’midpoint’: (i + j) / 2.

    input = [[10, 7, 4], [3, 2, 1]]
    print(f"input list: {input}")
    print(f"input list ravel: {np.asarray(input).ravel()}")
    # [10  7  4  3  2  1]

    print(f"\nNumPy v{np.__version__}")
    print(f"JAX v{jax.__version__}\n")
    numpy_array = np.asarray(input)
    print(f"{numpy_array=}")
    jax_array = jnp.asarray(input, dtype="float")
    print(f"{jax_array=}")

    print("\n# Checking quantile\n")
    assert np.quantile(numpy_array, 0) == 1.0
    assert np.quantile(numpy_array, 0.50) == 3.5
    assert np.quantile(numpy_array, 1) == 10
    assert np.quantile(numpy_array, 0.50, axis=1).tolist() == [7.0, 2.0]

    assert np.quantile(numpy_array, 0.50, interpolation="linear") == 3.5
    assert np.quantile(numpy_array, 0.50, interpolation="nearest") == 3.0
    assert np.quantile(numpy_array, 0.50, interpolation="lower") == 3.0
    assert np.quantile(numpy_array, 0.50, interpolation="midpoint") == 3.5
    assert np.quantile(numpy_array, 0.50, interpolation="higher") == 4.0

    assert jnp.quantile(jax_array, 0) == 1.0
    assert jnp.quantile(jax_array, 0.50) == 3.5
    assert jnp.quantile(jax_array, 1) == 10
    assert jnp.quantile(jax_array, 0.50, axis=1).tolist() == [7.0, 2.0]

    assert jnp.quantile(jax_array, 0.50, interpolation="linear") == 3.5
    assert jnp.quantile(jax_array, 0.50, interpolation="nearest") == 3.0
    assert jnp.quantile(jax_array, 0.50, interpolation="lower") == 3.0
    assert jnp.quantile(jax_array, 0.50, interpolation="midpoint") == 3.5
    assert jnp.quantile(jax_array, 0.50, interpolation="higher") == 4.0

    print("# Checking percentile")
    assert np.percentile(numpy_array, 0) == 1.0
    assert np.percentile(numpy_array, 50) == 3.5
    assert np.percentile(numpy_array, 100) == 10
    assert np.percentile(numpy_array, 50, axis=1).tolist() == [7.0, 2.0]

    assert np.percentile(numpy_array, 50, interpolation="linear") == 3.5
    assert np.percentile(numpy_array, 50, interpolation="nearest") == 3.0
    assert np.percentile(numpy_array, 50, interpolation="lower") == 3.0
    assert np.percentile(numpy_array, 50, interpolation="midpoint") == 3.5
    assert np.percentile(numpy_array, 50, interpolation="higher") == 4.0

    # default interpolation method is "linear"
    assert jnp.percentile(jax_array, 0) == 1.0
    assert jnp.percentile(jax_array, 50) == 3.5  # 3.499999761581421
    assert jnp.percentile(jax_array, 100) == 10  # 9.999998092651367
    assert jnp.percentile(jax_array, 50, axis=1).tolist() == [7.0, 2.0]

    assert jnp.percentile(jax_array, 50, interpolation="linear") == 3.5  # 3.499999761581421
    assert jnp.percentile(jax_array, 50, interpolation="nearest") == 3.0
    assert jnp.percentile(jax_array, 50, interpolation="lower") == 3.0
    assert jnp.percentile(jax_array, 50, interpolation="midpoint") == 3.5
    assert jnp.percentile(jax_array, 50, interpolation="higher") == 4.0
user@machine:~$ python --version
Python 3.9.6
user@machine:~$ python -m venv /tmp/venv && . /tmp/venv/bin/activate
(venv) user@machine:~$ python -m pip install --upgrade pip setuptools wheel
(venv) user@machine:~$ cat requirements_passing.txt
jax==0.2.20
jaxlib==0.1.69
(venv) user@machine:~$ python -m pip install -r requirements_passing.txt
(venv) user@machine:~$ python example.py
input list: [[10, 7, 4], [3, 2, 1]]
input list ravel: [10  7  4  3  2  1]

NumPy v1.21.4
JAX v0.2.20

numpy_array=array([[10,  7,  4],
       [ 3,  2,  1]])
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax_array=DeviceArray([[10.,  7.,  4.],
             [ 3.,  2.,  1.]], dtype=float64)

# Checking quantile

# Checking percentile
(venv) user@machine:~$ cat requirements_failing.txt
jax==0.2.21
jaxlib==0.1.69
(venv) user@machine:~$ python -m pip install -r requirements_failing.txt
(venv) user@machine:~$ python example.py
input list: [[10, 7, 4], [3, 2, 1]]
input list ravel: [10  7  4  3  2  1]

NumPy v1.21.4
JAX v0.2.21

numpy_array=array([[10,  7,  4],
       [ 3,  2,  1]])
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax_array=DeviceArray([[10.,  7.,  4.],
             [ 3.,  2.,  1.]], dtype=float64)

# Checking quantile

# Checking percentile
Traceback (most recent call last):
  File "/home/feickert/Code/debug/jax-percentile-drift/example.py", line 67, in <module>
    assert jnp.percentile(jax_array, 50) == 3.5  # 3.499999761581421
AssertionError
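
The interpolation options listed in the comments of example.py can be sketched in plain NumPy. This is an illustrative reimplementation, not the NumPy or JAX source: the helper name `percentile_sketch` is hypothetical, and resolving "nearest" ties with round-half-to-even is an assumption chosen to agree with the assertions in example.py.

```python
import math

import numpy as np


def percentile_sketch(values, q, interpolation="linear"):
    """Illustrative q-th percentile (q in [0, 100]) of the flattened input."""
    data = np.sort(np.asarray(values, dtype=np.float64).ravel())
    # Fractional index of the requested percentile in the sorted data.
    pos = (q / 100.0) * (data.size - 1)
    lower_index, upper_index = math.floor(pos), math.ceil(pos)
    i, j = data[lower_index], data[upper_index]
    fraction = pos - lower_index

    if interpolation == "linear":
        return i + (j - i) * fraction
    if interpolation == "lower":
        return i
    if interpolation == "higher":
        return j
    if interpolation == "nearest":
        # Assumption: ties (fraction == 0.5) use Python's round-half-to-even.
        return data[round(pos)]
    if interpolation == "midpoint":
        return (i + j) / 2
    raise ValueError(f"unknown interpolation: {interpolation!r}")


if __name__ == "__main__":
    sample = [[10, 7, 4], [3, 2, 1]]
    for method in ("linear", "lower", "higher", "nearest", "midpoint"):
        print(method, percentile_sketch(sample, 50, method))
```

In float64 every step here is exact for the sample input (0.5 * 5 = 2.5, then 3 + (4 - 3) * 0.5), which is why 3.5 is the expected result for the linear case.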

Notes

Comparing the code for v0.2.20

https://github.com/google/jax/blob/a7b61c0e00d1b535df8a30a82edc0074884d5f4c/jax/_src/numpy/lax_numpy.py#L5905-L5912

and v0.2.21

https://github.com/google/jax/blob/dbeb97d394740bfd122a46249c967139c10d3f11/jax/_src/numpy/lax_numpy.py#L6420-L6429

It seems (at first glance, as I haven't dug into this yet) that the only relevant difference is the removal of asarray(q) in the true_divide call in PR #7747 (though, given the point of that PR, I would have expected nothing to change):

-q = true_divide(asarray(q), float32(100.0))
+q = true_divide(q, float32(100.0))

This effect is quite minor and probably insignificant in most cases, but it deviates from the behavior described in the docstring. The most obvious example is probably the extreme case q = 100 (a quantile of 1), which should return the maximum element of the array (10 in the example) but instead returns a floating point approximation of that element (9.999998092651367).
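
A quick check suggests that the two deviating values are exactly representable in float32 and lie a whole number of float32 ULPs below the expected results, even though jax_enable_x64 is set. This is only a sketch with hypothetical helper names; it is consistent with part of the computation running in float32 but does not by itself show where that happens.

```python
import numpy as np


def is_float32_representable(value):
    """True if rounding `value` to float32 and back loses nothing."""
    return float(np.float32(value)) == value


def float32_ulps_below(expected, observed):
    """Distance from `observed` up to `expected`, in float32 ULPs at `expected`."""
    ulp = float(np.spacing(np.float32(expected)))
    return (expected - observed) / ulp


if __name__ == "__main__":
    # (expected, observed) pairs taken from the comments in example.py
    for expected, observed in [(3.5, 3.499999761581421), (10.0, 9.999998092651367)]:
        print(
            expected,
            observed,
            is_float32_representable(observed),
            float32_ulps_below(expected, observed),
        )
```

For the values quoted in example.py this reports one ULP for the 50th percentile and two ULPs for the 100th.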

Request

Would it be possible to revert to the v0.2.20 behavior? This would be more consistent with both the docstring and NumPy.

JAX Issues checklist

Please:

  • Check for duplicate issues.
  • Provide a complete example of how to reproduce the bug, wrapped in triple backticks
  • If applicable, include full error messages/tracebacks.
@matthewfeickert matthewfeickert added the bug Something isn't working label Nov 11, 2021
@jakevdp jakevdp self-assigned this Nov 11, 2021

jakevdp commented Nov 11, 2021

I think the operative change here is that jnp.percentile is JIT-compiled by default starting in v0.2.21. Try running the following in version 0.2.20:

import jax
import jax.numpy as jnp
import numpy as np
from jax.config import config

config.update("jax_enable_x64", True)

input = [[10, 7, 4], [3, 2, 1]]
numpy_array = np.asarray(input)
jax_array = jnp.asarray(input, dtype="float")

print("jax_version:", jax.__version__)
print("no jit:", jnp.percentile(jax_array, 50))
print("jit:", jax.jit(jnp.percentile)(jax_array, 50))
# jax_version: 0.2.20
# no jit: 3.5
# jit: 3.499999761581421

Why does JIT compiling cause this kind of inaccuracy? As part of compilation, XLA is free to re-arrange mathematical operations for efficiency, and sometimes this changes results slightly due to the imprecision inherent to floating point.
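
As a generic illustration of why reordering matters (not a claim about which rewrite XLA applies to percentile), floating point addition is not associative, so two mathematically equal groupings can differ in the last bit. The function names below are hypothetical.

```python
def sum_left_first(a, b, c):
    """Evaluate (a + b) + c."""
    return (a + b) + c


def sum_right_first(a, b, c):
    """Evaluate a + (b + c)."""
    return a + (b + c)


if __name__ == "__main__":
    left = sum_left_first(0.1, 0.2, 0.3)
    right = sum_right_first(0.1, 0.2, 0.3)
    # Same inputs, same mathematical expression, different rounding.
    print(left, right, left == right)
```

A compiler that is allowed to regroup operations can therefore change a result by a rounding error without any bug in the source expression.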

On further exploration, it looks like passing np.float64(50) rather than 50 fixes the issue. I think we could probably address this by being a bit more careful about dtype promotion within the percentile implementation.
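
The workaround described above, passing q as an explicit float64 rather than a Python int, could be wrapped as follows. This is a sketch under assumptions: `percentile_with_float64_q` is a hypothetical helper, and np.percentile stands in for jnp.percentile so the block runs without JAX installed. Whether it removes the deviation on a given JAX version is not verified here.

```python
import numpy as np


def percentile_with_float64_q(percentile_fn, a, q, **kwargs):
    """Call `percentile_fn(a, q, ...)` with q converted to float64 first.

    `percentile_fn` is any function with the np.percentile calling
    convention; jnp.percentile is the intended target in this thread.
    """
    q64 = np.asarray(q, dtype=np.float64)
    return percentile_fn(a, q64, **kwargs)


if __name__ == "__main__":
    data = np.asarray([[10, 7, 4], [3, 2, 1]], dtype=np.float64)
    # NumPy is used as a stand-in for jnp.percentile.
    print(percentile_with_float64_q(np.percentile, data, 50))
    print(percentile_with_float64_q(np.percentile, data, 50, axis=1))
```

Keeping the conversion in one place makes it easy to remove once the library handles dtype promotion itself.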

@matthewfeickert

> Why does JIT compiling cause this kind of inaccuracy? As part of compilation, XLA is free to re-arrange mathematical operations for efficiency, and sometimes this changes results slightly due to the imprecision inherent to floating point.

Thanks very much for the example and explanation @jakevdp — this is already quite helpful!

> I think we could probably address this by being a bit more careful about dtype promotion within the percentile implementation.

That would be great if possible in the future. 👍

matthewfeickert added a commit to scikit-hep/pyhf that referenced this issue Nov 11, 2021
* Add percentile function to the tensor backends
* Add tests for percentile and its interpolation methods
   - JAX requires additional dtype support with the 'linear' interpolation method
     c.f. jax-ml/jax#8513
   - PyTorch has yet to implement interpolation method options
   - c.f. #1693

jakevdp commented Nov 11, 2021

Fix in #8520
