
feat: Add percentile function to tensorlib #817

Merged: 35 commits, Nov 11, 2021
Commits
All 35 commits authored by matthewfeickert.

8f22095  Add percentile to NumPy backend (Apr 3, 2020)
4bded71  Add percentile to TensorFlow backend (Apr 3, 2020)
f0262f9  Add percentile to JAX backend (Apr 3, 2020)
a4d6223  Add dump attempt of percentile in PyTorch (Apr 3, 2020)
e63a2ac  Add temporary NumPy standin for PyTorch (Apr 3, 2020)
c00bd1d  Correct PyTorch returns in docstrings (Apr 3, 2020)
99db70d  Add TODO for native PyTorch adoption (Apr 3, 2020)
a214bb3  Reduce number of calls (Apr 3, 2020)
97b5583  Force consistent return structure (Apr 4, 2020)
7b4c9aa  Add tests for percentile (Apr 4, 2020)
7c20839  Add tests for all percentile interpolation schemes (Apr 5, 2020)
bc74618  Correct PyTorch docstring (Apr 5, 2020)
359a00f  Mark interpolation as fail or TensorFlow (Apr 5, 2020)
f1db2e6  Add min and max test to percentile (Apr 5, 2020)
a6ecf72  Test all failure cases for percentile interpolation (Apr 5, 2020)
724bd33  Improve interpolation method section of docstrings (Apr 6, 2020)
5d73911  Adopt the 'q' notation for consistency with backend library docs (Apr 6, 2020)
09ec33b  Add TODOs with relevant GitHub Issues for backends (Apr 6, 2020)
f85f4ae  Update JAX to include all interpolation methods (Apr 17, 2020)
9dc2380  Test if TensorFlow Probability has merged bugfix (Jul 17, 2020)
95fb385  Update to torch v1.7+ to use native quantile (Oct 27, 2020)
3e15519  JAX now supports interpolation (Oct 27, 2020)
cb2f44d  Note issue in PR (Oct 27, 2020)
09cc629  Use pytest.approx given JAX has floating point issues (Nov 11, 2021)
0500676  Skip pytorch interpolation (Nov 11, 2021)
2bddd5a  Split out JAX tests (Nov 11, 2021)
8310b7b  Fixup tensorflow doctest (Nov 11, 2021)
10016fe  fixup jax percentile docstring (Nov 11, 2021)
19da1c4  Remove TODO as implimented by Matthew (Nov 11, 2021)
e4adcfd  Note outstanding problems with PyTorch (Nov 11, 2021)
c6c01f6  Clarify problem is with linear interpolation (Nov 11, 2021)
22fc3aa  Fixup docstrings (Nov 11, 2021)
2768e20  Use jnp.float64 example that Jake Vanderplas demoed (Nov 11, 2021)
9b85741  Use np.float64 in tests (Nov 11, 2021)
eee1d30  Add Isssue 1693 as comments (Nov 11, 2021)
38 changes: 38 additions & 0 deletions src/pyhf/tensor/jax_backend.py
@@ -277,6 +277,44 @@ def log(self, tensor_in):
def exp(self, tensor_in):
return jnp.exp(tensor_in)

def percentile(self, tensor_in, q, axis=None, interpolation="linear"):
r"""
Compute the :math:`q`-th percentile of the tensor along the specified axis.

Example:

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> a = pyhf.tensorlib.astensor([[10, 7, 4], [3, 2, 1]])
>>> pyhf.tensorlib.percentile(a, 50)
DeviceArray(3.499999..., dtype=float64)
>>> pyhf.tensorlib.percentile(a, 50, axis=1)
DeviceArray([7., 2.], dtype=float64)

Args:
tensor_in (`tensor`): The tensor containing the data
q (:obj:`float` or `tensor`): The :math:`q`-th percentile to compute
axis (`number` or `tensor`): The dimensions along which to compute
interpolation (:obj:`str`): The interpolation method to use when the
desired percentile lies between two data points ``i < j``:

- ``'linear'``: ``i + (j - i) * fraction``, where ``fraction`` is the
fractional part of the index surrounded by ``i`` and ``j``.

- ``'lower'``: ``i``.

- ``'higher'``: ``j``.

- ``'midpoint'``: ``(i + j) / 2``.

- ``'nearest'``: ``i`` or ``j``, whichever is nearest.

Returns:
JAX ndarray: The value of the :math:`q`-th percentile of the tensor along the specified axis.

"""
return jnp.percentile(tensor_in, q, axis=axis, interpolation=interpolation)

def stack(self, sequence, axis=0):
return jnp.stack(sequence, axis=axis)

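For context on the floating point caveat raised later in this PR, here is a minimal standalone sketch of the jnp.percentile call that the backend method wraps. The jax_enable_x64 line is an assumption about how to reproduce the float64 behaviour outside of pyhf; the values mirror the doctest above.

from jax import config

config.update("jax_enable_x64", True)  # 64-bit mode, matching the float64 doctest output above

import jax.numpy as jnp

a = jnp.asarray([[10.0, 7.0, 4.0], [3.0, 2.0, 1.0]])

# Direct calls to the function wrapped by percentile() above
print(jnp.percentile(a, 50))          # roughly 3.5; "linear" interpolation can give 3.499999...
print(jnp.percentile(a, 50, axis=1))  # [7. 2.]
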
38 changes: 38 additions & 0 deletions src/pyhf/tensor/numpy_backend.py
@@ -261,6 +261,44 @@ def log(self, tensor_in):
def exp(self, tensor_in):
return np.exp(tensor_in)

def percentile(self, tensor_in, q, axis=None, interpolation="linear"):
r"""
Compute the :math:`q`-th percentile of the tensor along the specified axis.

Example:

>>> import pyhf
>>> pyhf.set_backend("numpy")
>>> a = pyhf.tensorlib.astensor([[10, 7, 4], [3, 2, 1]])
>>> pyhf.tensorlib.percentile(a, 50)
3.5
>>> pyhf.tensorlib.percentile(a, 50, axis=1)
array([7., 2.])

Args:
tensor_in (`tensor`): The tensor containing the data
q (:obj:`float` or `tensor`): The :math:`q`-th percentile to compute
axis (`number` or `tensor`): The dimensions along which to compute
interpolation (:obj:`str`): The interpolation method to use when the
desired percentile lies between two data points ``i < j``:

- ``'linear'``: ``i + (j - i) * fraction``, where ``fraction`` is the
fractional part of the index surrounded by ``i`` and ``j``.

- ``'lower'``: ``i``.

- ``'higher'``: ``j``.

- ``'midpoint'``: ``(i + j) / 2``.

- ``'nearest'``: ``i`` or ``j``, whichever is nearest.

Returns:
NumPy ndarray: The value of the :math:`q`-th percentile of the tensor along the specified axis.

"""
return np.percentile(tensor_in, q, axis=axis, interpolation=interpolation)

def stack(self, sequence, axis=0):
return np.stack(sequence, axis=axis)

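As a quick cross-check of the five interpolation schemes documented above, a minimal sketch that calls np.percentile directly; the expected values match the assertions added in tests/test_tensor.py below. Note that NumPy 1.22+ renames the interpolation keyword to method, so this spelling may emit a deprecation warning on newer releases.

import numpy as np

a = np.asarray([[10, 7, 4], [3, 2, 1]], dtype=float)

# Flattened and sorted the data is [1, 2, 3, 4, 7, 10], so the 50th percentile
# falls between i = 3 and j = 4
for scheme in ("linear", "lower", "higher", "midpoint", "nearest"):
    print(scheme, np.percentile(a, 50, interpolation=scheme))
# linear 3.5, lower 3.0, higher 4.0, midpoint 3.5, nearest 3.0
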
41 changes: 41 additions & 0 deletions src/pyhf/tensor/pytorch_backend.py
@@ -284,6 +284,47 @@ def log(self, tensor_in):
def exp(self, tensor_in):
return torch.exp(tensor_in)

def percentile(self, tensor_in, q, axis=None, interpolation="linear"):
r"""
Compute the :math:`q`-th percentile of the tensor along the specified axis.

Example:

>>> import pyhf
>>> pyhf.set_backend("pytorch")
>>> a = pyhf.tensorlib.astensor([[10, 7, 4], [3, 2, 1]])
>>> pyhf.tensorlib.percentile(a, 50)
tensor(3.5000)
>>> pyhf.tensorlib.percentile(a, 50, axis=1)
tensor([7., 2.])

Args:
tensor_in (`tensor`): The tensor containing the data
q (:obj:`float` or `tensor`): The :math:`q`-th percentile to compute
axis (`number` or `tensor`): The dimensions along which to compute
interpolation (:obj:`str`): The interpolation method to use when the
desired percentile lies between two data points ``i < j``:

- ``'linear'``: ``i + (j - i) * fraction``, where ``fraction`` is the
fractional part of the index surrounded by ``i`` and ``j``.

- ``'lower'``: Not yet implemented in PyTorch.

- ``'higher'``: Not yet implemented in PyTorch.

- ``'midpoint'``: Not yet implemented in PyTorch.

- ``'nearest'``: Not yet implemented in PyTorch.

Returns:
PyTorch tensor: The value of the :math:`q`-th percentile of the tensor along the specified axis.

"""
# Interpolation options not yet supported
# c.f. https://github.com/pytorch/pytorch/pull/49267
# c.f. https://github.com/pytorch/pytorch/pull/59397
return torch.quantile(tensor_in, q / 100, dim=axis)

def stack(self, sequence, axis=0):
return torch.stack(sequence, dim=axis)

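For context on the q / 100 rescaling in the return statement above: torch.quantile expects q in [0, 1], while the tensorlib API (like np.percentile) takes q in [0, 100]. A minimal sketch of the underlying call, assuming torch v1.7+ where torch.quantile is available:

import torch

a = torch.tensor([[10.0, 7.0, 4.0], [3.0, 2.0, 1.0]])

# percentile(a, 50) in the backend maps to torch.quantile with q = 50 / 100
print(torch.quantile(a, 0.5))         # tensor(3.5000)
print(torch.quantile(a, 0.5, dim=1))  # tensor([7., 2.])
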
42 changes: 42 additions & 0 deletions src/pyhf/tensor/tensorflow_backend.py
@@ -303,6 +303,48 @@ def log(self, tensor_in):
def exp(self, tensor_in):
return tf.exp(tensor_in)

def percentile(self, tensor_in, q, axis=None, interpolation="linear"):
r"""
Compute the :math:`q`-th percentile of the tensor along the specified axis.

Example:

>>> import pyhf
>>> pyhf.set_backend("tensorflow")
>>> a = pyhf.tensorlib.astensor([[10, 7, 4], [3, 2, 1]])
>>> t = pyhf.tensorlib.percentile(a, 50)
>>> print(t)
tf.Tensor(3.5, shape=(), dtype=float64)
>>> t = pyhf.tensorlib.percentile(a, 50, axis=1)
>>> print(t)
tf.Tensor([7. 2.], shape=(2,), dtype=float64)

Args:
tensor_in (`tensor`): The tensor containing the data
q (:obj:`float` or `tensor`): The :math:`q`-th percentile to compute
axis (`number` or `tensor`): The dimensions along which to compute
interpolation (:obj:`str`): The interpolation method to use when the
desired percentile lies between two data points ``i < j``:

- ``'linear'``: ``i + (j - i) * fraction``, where ``fraction`` is the
fractional part of the index surrounded by ``i`` and ``j``.

- ``'lower'``: ``i``.

- ``'higher'``: ``j``.

- ``'midpoint'``: ``(i + j) / 2``.

- ``'nearest'``: ``i`` or ``j``, whichever is nearest.

Returns:
TensorFlow Tensor: The value of the :math:`q`-th percentile of the tensor along the specified axis.

"""
return tfp.stats.percentile(
tensor_in, q, axis=axis, interpolation=interpolation
)

def stack(self, sequence, axis=0):
return tf.stack(sequence, axis=axis)

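The TensorFlow backend defers to TensorFlow Probability rather than core TensorFlow. A minimal sketch of the tfp.stats.percentile call wrapped above; the explicit float64 dtype mirrors the doctest output and is otherwise an assumption:

import tensorflow as tf
import tensorflow_probability as tfp

a = tf.constant([[10.0, 7.0, 4.0], [3.0, 2.0, 1.0]], dtype=tf.float64)

# Direct calls to the TensorFlow Probability function wrapped by percentile() above
print(tfp.stats.percentile(a, 50.0, interpolation="linear"))          # tf.Tensor(3.5, ...)
print(tfp.stats.percentile(a, 50.0, axis=1, interpolation="linear"))  # tf.Tensor([7. 2.], ...)
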
55 changes: 55 additions & 0 deletions tests/test_tensor.py
@@ -370,6 +370,61 @@ def test_boolean_mask(backend):
)


@pytest.mark.skip_jax
def test_percentile(backend):
tb = pyhf.tensorlib
a = tb.astensor([[10, 7, 4], [3, 2, 1]])
assert tb.tolist(tb.percentile(a, 0)) == 1

assert tb.tolist(tb.percentile(a, 50)) == 3.5
assert tb.tolist(tb.percentile(a, 100)) == 10
assert tb.tolist(tb.percentile(a, 50, axis=1)) == [7.0, 2.0]


# FIXME: PyTorch doesn't yet support interpolation schemes other than "linear"
# c.f. https://github.com/pytorch/pytorch/pull/59397
@pytest.mark.skip_pytorch
@pytest.mark.skip_pytorch64
@pytest.mark.skip_jax
Comment on lines +387 to +389 (matthewfeickert, Member and PR author, Nov 11, 2021):

Skipping PyTorch as interpolation kwargs aren't implemented yet as of torch v1.10.0. c.f. #815 (comment)

Skipping JAX so as to have a separate test below, given floating point issues with the "linear" option.

def test_percentile_interpolation(backend):
tb = pyhf.tensorlib
a = tb.astensor([[10, 7, 4], [3, 2, 1]])

assert tb.tolist(tb.percentile(a, 50, interpolation="linear")) == 3.5
assert tb.tolist(tb.percentile(a, 50, interpolation="nearest")) == 3.0
assert tb.tolist(tb.percentile(a, 50, interpolation="lower")) == 3.0
assert tb.tolist(tb.percentile(a, 50, interpolation="midpoint")) == 3.5
assert tb.tolist(tb.percentile(a, 50, interpolation="higher")) == 4.0


@pytest.mark.only_jax
Comment from matthewfeickert (Member and PR author, Nov 11, 2021):

Hopefully these tests are tolerable at the moment, though obviously not ideal.

def test_percentile_jax(backend):
tb = pyhf.tensorlib
a = tb.astensor([[10, 7, 4], [3, 2, 1]])
assert tb.tolist(tb.percentile(a, 0)) == 1

# FIXME: JAX has floating point issues with "linear" interpolation method
assert pytest.approx(tb.tolist(tb.percentile(a, 50)), rel=1e-6) == 3.5
assert pytest.approx(tb.tolist(tb.percentile(a, 100)), rel=1e-6) == 10
assert tb.tolist(tb.percentile(a, 50, axis=1)) == [7.0, 2.0]


@pytest.mark.only_jax
def test_percentile_interpolation_jax(backend):
tb = pyhf.tensorlib
a = tb.astensor([[10, 7, 4], [3, 2, 1]])

# FIXME: JAX has floating point issues with "linear" interpolation method
assert (
pytest.approx(tb.tolist(tb.percentile(a, 50, interpolation="linear")), rel=1e-6)
== 3.5
)
assert tb.tolist(tb.percentile(a, 50, interpolation="nearest")) == 3.0
assert tb.tolist(tb.percentile(a, 50, interpolation="lower")) == 3.0
assert tb.tolist(tb.percentile(a, 50, interpolation="midpoint")) == 3.5
assert tb.tolist(tb.percentile(a, 50, interpolation="higher")) == 4.0


def test_tensor_tile(backend):
a = [[1], [2], [3]]
tb = pyhf.tensorlib