Skip to content

Commit

Permalink
Update JAX to include all interpolation methods
Browse files Browse the repository at this point in the history
JAX v0.1.63 should include the fix from Jake Vanderplas
  • Loading branch information
matthewfeickert committed Apr 22, 2020
1 parent e885822 commit 75f591f
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 16 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
extras_require = {
'tensorflow': ['tensorflow~=2.0', 'tensorflow-probability~=0.8'],
'torch': ['torch~=1.2'],
'jax': ['jax~=0.1,>0.1.51', 'jaxlib~=0.1,>0.1.33'],
'jax': ['jax~=0.1,>=0.1.63', 'jaxlib~=0.1,>=0.1.44'],
'xmlio': ['uproot'],
'minuit': ['iminuit'],
}
Expand Down
1 change: 0 additions & 1 deletion src/pyhf/tensor/jax_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@ def percentile(self, tensor_in, q, axis=None, interpolation="linear"):
JAX ndarray: The value of the :math:`q`-th percentile of the tensor along the specified axis.
"""
# TODO: https://github.com/google/jax/issues/2607
return np.percentile(tensor_in, q, axis=axis, interpolation=interpolation)

def stack(self, sequence, axis=0):
Expand Down
18 changes: 4 additions & 14 deletions tests/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,26 +240,16 @@ def test_percentile(backend):
def test_percentile_interpolation(backend):
tb = pyhf.tensorlib
a = tb.astensor([[10, 7, 4], [3, 2, 1]])

assert tb.tolist(tb.percentile(a, 50, interpolation="linear")) == 3.5
# TODO: Unify this with NumPy through TFP team fixing difference
if tb.name == "tensorflow":
assert tb.tolist(tb.percentile(a, 50, interpolation="nearest")) == 4.0
# TODO: Unify this with NumPy thorugh JAX team implimenting
elif tb.name == "jax":
with pytest.raises(NotImplementedError):
assert tb.tolist(tb.percentile(a, 50, interpolation="nearest")) == 3.0
else:
assert tb.tolist(tb.percentile(a, 50, interpolation="nearest")) == 3.0
# TODO: Unify this with NumPy thorugh JAX team implimenting
if tb.name == "jax":
with pytest.raises(NotImplementedError):
assert tb.tolist(tb.percentile(a, 50, interpolation="lower")) == 3.0
assert tb.tolist(tb.percentile(a, 50, interpolation="midpoint")) == 3.5
assert tb.tolist(tb.percentile(a, 50, interpolation="higher")) == 4.0
else:
assert tb.tolist(tb.percentile(a, 50, interpolation="lower")) == 3.0
assert tb.tolist(tb.percentile(a, 50, interpolation="midpoint")) == 3.5
assert tb.tolist(tb.percentile(a, 50, interpolation="higher")) == 4.0
assert tb.tolist(tb.percentile(a, 50, interpolation="lower")) == 3.0
assert tb.tolist(tb.percentile(a, 50, interpolation="midpoint")) == 3.5
assert tb.tolist(tb.percentile(a, 50, interpolation="higher")) == 4.0


def test_tensor_tile(backend):
Expand Down

0 comments on commit 75f591f

Please sign in to comment.