feat: Configurable default backend #1646
Conversation
@kratsg I think we don't need all the … Feel free to elaborate though.
Can you add a test that shows that a function that uses both default backend and tensorlib, with both set to jax, becomes diffable?
we need this to access … Note this is just one step in the PR. The first step is to keep all the API changes the same while migrating the core functionality of the code. The next step is to fix things up to not have … I've already dropped globals.
Yes, I can definitely add a test. I'm hoping @phinate can provide some suggestions on a way to ensure diffability.
Can you make an Issue that lays out what you're planning to do? It is easier to review PRs when there is context.
I would try something like

```python
def func(x):
    ...

jax.jacrev(func)(default_backend.astensor([1., 2.]))
```
Yeah, this is a nice way to check there’s no backend foul play at the scale of individual ops. Probably not a planned use case, but I imagine this is also possible to test for the other backends. To cover the most general case in terms of model construction, one could also do an example similar to the one posted in #742 that differentiates with respect to a latent param that controls the yields, but with a more general set of models that explores the different types of systematics — though that's probably out of scope for this PR!
Forked and tried this, it's probably enough:

```python
import pyhf
import jax
import pytest


def test_diffable_backend():
    pyhf.set_backend("jax", default=True)

    def example_op(x):
        y = pyhf.default_backend.astensor(x)
        return 2 * y

    assert jax.jacrev(example_op)([1.0]) == [2.0]


def test_diffable_backend_failure():
    pyhf.set_backend("numpy", default=True)
    pyhf.set_backend("jax")

    def example_op(x):
        y = pyhf.default_backend.astensor(x)
        return 2 * y

    with pytest.raises(Exception):
        jax.jacrev(example_op)([1.0])
```

Could maybe cover slightly more ground with a tensorlib op instead of `* 2` like Lukas wrote, but this is the idea.
thanks @phinate, can you verify that this works too?
sure -- guessing you meant …
Yes, sorry :)
may not need both, but this still runs without error:

```python
def test_diffable_backend():
    pyhf.set_backend("jax", default=True)

    def example_op(x):
        y = pyhf.default_backend.astensor(x)
        return 2 * y

    assert jax.jacrev(example_op)([1.0]) == [2.0]

    def example_op2(x):
        y = pyhf.default_backend.power(x, 2)
        z = pyhf.tensorlib.sum(y)
        return z

    assert jax.jacrev(example_op2)(pyhf.tensorlib.astensor([2.0, 3.0])).tolist() == [
        4.0,
        6.0,
    ]


def test_diffable_backend_failure():
    pyhf.set_backend("numpy", default=True)
    pyhf.set_backend("jax")

    def example_op(x):
        y = pyhf.default_backend.astensor(x)
        return 2 * y

    with pytest.raises(Exception):
        jax.value_and_grad(example_op)(1.0)

    def example_op2(x):
        y = pyhf.default_backend.power(x, 2)
        z = pyhf.tensorlib.sum(y)
        return z

    with pytest.raises(Exception):
        jax.jacrev(example_op2)(pyhf.tensorlib.astensor([2.0, 3.0]))
```
Great tests. Would be nice to also test that you can jit through this if the two backends are jax.
nice idea! still works :)

```python
@pytest.mark.parametrize('jitted', (False, True))
def test_diffable_backend(jitted):
    pyhf.set_backend("jax", default=True)

    def example_op(x):
        y = pyhf.default_backend.astensor(x)
        return 2 * y

    if jitted:
        assert jax.jacrev(jax.jit(example_op))([1.0]) == [2.0]
    else:
        assert jax.jacrev(example_op)([1.0]) == [2.0]

    def example_op2(x):
        y = pyhf.default_backend.power(x, 2)
        z = pyhf.tensorlib.sum(y)
        return z

    if jitted:
        assert jax.jacrev(jax.jit(example_op2))(
            pyhf.tensorlib.astensor([2.0, 3.0])
        ).tolist() == [4.0, 6.0]
    else:
        assert jax.jacrev(example_op2)(
            pyhf.tensorlib.astensor([2.0, 3.0])
        ).tolist() == [4.0, 6.0]


def test_diffable_backend_failure():
    pyhf.set_backend("numpy", default=True)
    pyhf.set_backend("jax")

    def example_op(x):
        y = pyhf.default_backend.astensor(x)
        return 2 * y

    with pytest.raises(Exception):
        jax.jacrev(example_op)([1.0])

    def example_op2(x):
        y = pyhf.default_backend.power(x, 2)
        z = pyhf.tensorlib.sum(y)
        return z

    with pytest.raises(Exception):
        jax.jacrev(example_op2)(pyhf.tensorlib.astensor([2.0, 3.0]))
```

edit: parametrized over jitting, seems a little neater
…y to support default_backend changing
Codecov Report
```diff
@@            Coverage Diff             @@
##           master    #1646      +/-   ##
==========================================
+ Coverage   98.04%   98.07%   +0.02%
==========================================
  Files          63       64       +1
  Lines        4142     4199      +57
  Branches      572      578       +6
==========================================
+ Hits         4061     4118      +57
  Misses         48       48
  Partials       33       33
```
Flags with carried forward coverage won't be shown. Click here to find out more.
Continue to review full report at Codecov.
I think we need a couple of changes and tests: when playing around with this branch, I found that `pyhf.simplemodels.correlated_background([2], [3], [5], [5])` fails with an error.
This is because … We should also try to flag this with some tests; were there not already tests to check basic functionality across multiple default/tensorlib combinations? Edit: of course, I should clarify that this is with `pyhf.set_backend("jax", default=True)`.
This is mainly an issue as of #1647, since that's where the concatenate calls were likely added, so I'm only seeing it from the rebase you did @kratsg.
just for clarity, here's the relevant page in the jax docs that goes over this
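The failure mode discussed above — NumPy-backed `default_backend` ops being hit with JAX tracers — can be illustrated with a minimal standalone sketch (this does not use pyhf; `leaky_op` is a made-up function for illustration). Inside a JAX transformation the argument is an abstract tracer, and forcing it into a concrete NumPy array raises:

```python
import jax
import numpy as np


def leaky_op(x):
    # Inside jax.jacrev, x is an abstract Tracer; np.asarray tries to
    # force it into a concrete numpy array, which JAX disallows and
    # raises a TracerArrayConversionError
    return 2 * np.asarray(x)


try:
    jax.jacrev(leaky_op)(np.array([1.0]))
    raised = False
except Exception:
    raised = True

print("numpy op on a traced value raised:", raised)
```

This is the same reason a NumPy `concatenate` inside model construction fails once the inputs are JAX arrays being traced.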
Huh, we do have …
Actually, now I'm confused. We're loading in regular workspaces -- but those also won't work with non-numpy backends by default, right? Edit: doesn't work. This is just revealing an underlying feature we never actually supported.
Thanks @kratsg for this PR. This looks good overall to me, and I just have a few small questions and suggested changes.
I also agree with @lukasheinrich that using

```python
default_backend = pyhf.default_backend
```

is a nicer pattern (and certainly made this easier to review), so thanks for adopting that. 👍
```python
import pyhf
from pyhf.tensor.manager import get_backend
```
Just so I understand, this is necessary to avoid

```
ImportError: cannot import name 'default_backend' from partially initialized module 'pyhf' (most likely due to a circular import)
```

because `pyhf.default_backend` needs to be evaluated as an attr and so doesn't exist as a static attribute? Or am I just missing the obvious here?
Global vs non-global. You don't want to evaluate default_backend at import time.
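The pattern being described — resolving `default_backend` on access rather than at import time — is what PEP 562 module-level `__getattr__` (Python 3.7+) enables. A minimal standalone sketch, with a made-up module name and backend state purely for illustration:

```python
import sys
import types

# Stand-in for a real module file; in practice you would define
# __getattr__ at the top level of the module's .py file (PEP 562)
lazy_mod = types.ModuleType("fake_pyhf")
lazy_mod._state = {"backend": "numpy"}


def _module_getattr(name):
    # Called only when normal attribute lookup fails, i.e. on access,
    # so the current backend is read at call time, not import time
    if name == "default_backend":
        return lazy_mod._state["backend"]
    raise AttributeError(name)


lazy_mod.__getattr__ = _module_getattr
sys.modules["fake_pyhf"] = lazy_mod

import fake_pyhf

print(fake_pyhf.default_backend)  # reads the current state
fake_pyhf._state["backend"] = "jax"
print(fake_pyhf.default_backend)  # reflects the change on next access
```

Because the attribute is computed on every access, importing the module never freezes a particular backend instance into the importer's namespace.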
Co-authored-by: Matthew Feickert <matthew.feickert@cern.ch>
LGTM. 👍 Thanks!
Pull Request Description
Supersedes #1121. Should allow the default backend to be changeable.
Resolves #1004.
Checklist Before Requesting Reviewer
Before Merging
For the PR Assignees: