-
Notifications
You must be signed in to change notification settings - Fork 85
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Add percentile function to tensorlib #817
Merged
Merged
Changes from 32 commits
Commits
Show all changes
35 commits
Select commit
Hold shift + click to select a range
8f22095
Add percentile to NumPy backend
matthewfeickert 4bded71
Add percentile to TensorFlow backend
matthewfeickert f0262f9
Add percentile to JAX backend
matthewfeickert a4d6223
Add dump attempt of percentile in PyTorch
matthewfeickert e63a2ac
Add temporary NumPy standin for PyTorch
matthewfeickert c00bd1d
Correct PyTorch returns in docstrings
matthewfeickert 99db70d
Add TODO for native PyTorch adoption
matthewfeickert a214bb3
Reduce number of calls
matthewfeickert 97b5583
Force consistent return structure
matthewfeickert 7b4c9aa
Add tests for percentile
matthewfeickert 7c20839
Add tests for all percentile interpolation schemes
matthewfeickert bc74618
Correct PyTorch docstring
matthewfeickert 359a00f
Mark interpolation as fail or TensorFlow
matthewfeickert f1db2e6
Add min and max test to percentile
matthewfeickert a6ecf72
Test all failure cases for percentile interpolation
matthewfeickert 724bd33
Improve interpolation method section of docstrings
matthewfeickert 5d73911
Adopt the 'q' notation for consistency with backend library docs
matthewfeickert 09ec33b
Add TODOs with relevant GitHub Issues for backends
matthewfeickert f85f4ae
Update JAX to include all interpolation methods
matthewfeickert 9dc2380
Test if TensorFlow Probability has merged bugfix
matthewfeickert 95fb385
Update to torch v1.7+ to use native quantile
matthewfeickert 3e15519
JAX now supports interpolation
matthewfeickert cb2f44d
Note issue in PR
matthewfeickert 09cc629
Use pytest.approx given JAX has floating point issues
matthewfeickert 0500676
Skip pytorch interpolation
matthewfeickert 2bddd5a
Split out JAX tests
matthewfeickert 8310b7b
Fixup tensorflow doctest
matthewfeickert 10016fe
fixup jax percentile docstring
matthewfeickert 19da1c4
Remove TODO as implimented by Matthew
matthewfeickert e4adcfd
Note outstanding problems with PyTorch
matthewfeickert c6c01f6
Clarify problem is with linear interpolation
matthewfeickert 22fc3aa
Fixup docstrings
matthewfeickert 2768e20
Use jnp.float64 example that Jake Vanderplas demoed
matthewfeickert 9b85741
Use np.float64 in tests
matthewfeickert eee1d30
Add Isssue 1693 as comments
matthewfeickert File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -370,6 +370,61 @@ def test_boolean_mask(backend): | |
) | ||
|
||
|
||
@pytest.mark.skip_jax | ||
def test_percentile(backend): | ||
tb = pyhf.tensorlib | ||
a = tb.astensor([[10, 7, 4], [3, 2, 1]]) | ||
assert tb.tolist(tb.percentile(a, 0)) == 1 | ||
|
||
assert tb.tolist(tb.percentile(a, 50)) == 3.5 | ||
assert tb.tolist(tb.percentile(a, 100)) == 10 | ||
assert tb.tolist(tb.percentile(a, 50, axis=1)) == [7.0, 2.0] | ||
|
||
|
||
# FIXME: PyTorch doesn't yet support interpolation schemes other than "linear" | ||
# c.f. https://github.com/pytorch/pytorch/pull/59397 | ||
@pytest.mark.skip_pytorch | ||
@pytest.mark.skip_pytorch64 | ||
@pytest.mark.skip_jax | ||
def test_percentile_interpolation(backend): | ||
tb = pyhf.tensorlib | ||
a = tb.astensor([[10, 7, 4], [3, 2, 1]]) | ||
|
||
assert tb.tolist(tb.percentile(a, 50, interpolation="linear")) == 3.5 | ||
assert tb.tolist(tb.percentile(a, 50, interpolation="nearest")) == 3.0 | ||
assert tb.tolist(tb.percentile(a, 50, interpolation="lower")) == 3.0 | ||
assert tb.tolist(tb.percentile(a, 50, interpolation="midpoint")) == 3.5 | ||
assert tb.tolist(tb.percentile(a, 50, interpolation="higher")) == 4.0 | ||
|
||
|
||
@pytest.mark.only_jax | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hopefully these tests are tolerable at the moment, though obviously not ideal. |
||
def test_percentile_jax(backend): | ||
tb = pyhf.tensorlib | ||
a = tb.astensor([[10, 7, 4], [3, 2, 1]]) | ||
assert tb.tolist(tb.percentile(a, 0)) == 1 | ||
|
||
# FIXME: JAX has floating point issues with "linear" interpolation method | ||
assert pytest.approx(tb.tolist(tb.percentile(a, 50)), rel=1e-6) == 3.5 | ||
assert pytest.approx(tb.tolist(tb.percentile(a, 100)), rel=1e-6) == 10 | ||
assert tb.tolist(tb.percentile(a, 50, axis=1)) == [7.0, 2.0] | ||
|
||
|
||
@pytest.mark.only_jax | ||
def test_percentile_interpolation_jax(backend): | ||
tb = pyhf.tensorlib | ||
a = tb.astensor([[10, 7, 4], [3, 2, 1]]) | ||
|
||
# FIXME: JAX has floating point issues with "linear" interpolation method | ||
assert ( | ||
pytest.approx(tb.tolist(tb.percentile(a, 50, interpolation="linear")), rel=1e-6) | ||
== 3.5 | ||
) | ||
assert tb.tolist(tb.percentile(a, 50, interpolation="nearest")) == 3.0 | ||
assert tb.tolist(tb.percentile(a, 50, interpolation="lower")) == 3.0 | ||
assert tb.tolist(tb.percentile(a, 50, interpolation="midpoint")) == 3.5 | ||
assert tb.tolist(tb.percentile(a, 50, interpolation="higher")) == 4.0 | ||
|
||
|
||
def test_tensor_tile(backend): | ||
a = [[1], [2], [3]] | ||
tb = pyhf.tensorlib | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Skipping PyTorch as interpolation kwargs aren't implemented yet as of
torch
v1.10.0
. c.f. #815 (comment)Skipping JAX so as to have separate test below given floating point issues with
"linear"
option.