
feat: Configurable default backend #1646

Merged (31 commits) on Oct 22, 2021

Conversation

@kratsg (Contributor) commented Oct 15, 2021

Pull Request Description

Supersedes #1121. Allows the default backend to be changed.

Resolves #1004.

Checklist Before Requesting Reviewer

  • Tests are passing
  • "WIP" removed from the title of the pull request
  • Selected an Assignee for the PR to be responsible for the log summary

Before Merging

For the PR Assignees:

  • Summarize commit messages into a comprehensive review of the PR
  • Add conditional logic in `set_backend` and `get_backend` based on a new kwarg `default=False` that sets or gets the default backend instead of the tensorlib backend. This is particularly useful for differentiable model construction, which needs to be propagated through tensors instead of numpy arrays (see the sketch below).
  • Migrate the relevant code from src/pyhf/__init__.py into a submodule and clean it up further.
  • Add tests to ensure autodifferentiability and jit compatibility when changing the default backend.
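
For concreteness, here is a minimal sketch of the new kwarg in use (that `get_backend(default=True)` mirrors the existing `(backend, optimizer)` return tuple is an assumption):

import pyhf

pyhf.set_backend("jax")                # set the tensorlib backend (unchanged behavior)
pyhf.set_backend("jax", default=True)  # set the default backend used during model construction

tensorlib, optimizer = pyhf.get_backend()
default_backend, _ = pyhf.get_backend(default=True)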

@kratsg added the labels API (Changes the public API), feat/enhancement (New feature or request), and refactor (A code change that neither fixes a bug nor adds a feature) on Oct 15, 2021
@kratsg self-assigned this on Oct 15, 2021
@matthewfeickert (Member):

@kratsg I think we don't need all the import pyhf everywhere. I know that you said you're doing this to avoid circular imports, but I think this can be avoided.

Feel free to elaborate though.

@matthewfeickert marked this pull request as a draft on October 15, 2021 16:02
@lukasheinrich (Contributor):

Can you add a test that shows that a function that uses both default backend and tensorlib with both set to jax becomes diffable?

@lukasheinrich marked this pull request as ready for review on October 15, 2021 16:04
@kratsg (Contributor, author) commented Oct 15, 2021

@kratsg I think we don't need all the import pyhf everywhere. I know that you said you're doing this to avoid circular imports, but I think this can be avoided.

We need this to access pyhf.default_backend. I cannot do `from pyhf import default_backend` anymore, FYI.

Note this is just one step in the PR. The first step is to keep all the API changes the same while migrating the core functionality of the code. The next step is to fix things up to not have import pyhf. Then the next will be to drop events and use only one backend instance everywhere.

I've already dropped globals.

@kratsg (Contributor, author) commented Oct 15, 2021

Can you add a test that shows that a function that uses both default backend and tensorlib with both set to jax becomes diffable?

Yes, I can definitely add a test. I'm hoping @phinate can provide some suggestions on a way to ensure diffability.

@matthewfeickert (Member):

Note this is just one step in the PR. The first step is to keep all the API changes the same while migrating the core functionality of the code. The next step is to fix things up to not have import pyhf. Then the next will be to drop events and use only one backend instance everywhere.

Can you make an Issue that lays out what you're planning to do? It is easier to review PRs when there is context.

@kratsg (Contributor, author) commented Oct 15, 2021

Can you make an Issue that lays out what you're planning to do? It is easier to review PRs when there is context.

#1648

@kratsg changed the title from "feat: Change default backend" to "feat: Configurable default backend" on Oct 15, 2021
@lukasheinrich (Contributor):

I would try something like

def func(x):
    y = default_backend.power(x, 2)
    z = tensorlib.power(y, 2)
    return z

jax.jacrev(func)(default_backend.astensor([1.,2.]))

@phinate (Contributor) commented Oct 15, 2021

I would try something like

def func(x):
    y = default_backend.power(x, 2)
    z = tensorlib.power(y, 2)
    return z

jax.jacrev(func)(default_backend.astensor([1.,2.]))

Yeah, this is a nice way to check there’s no backend foul play on the scale of individual ops. Probably not a planned use case, but I imagine this is also possible to test for the other backends.

To ensure the most general case in terms of model construction, one could also do an example similar to the one posted in #742 that differentiates with respect to a latent param that controls the yields, but with a more general set of models that explores the different types of systematics, though probably out of scope for this PR!

@kratsg mentioned this pull request on Oct 15, 2021
@phinate (Contributor) commented Oct 16, 2021

Forked and tried this, it's probably enough:

import pyhf
import jax
import pytest

def test_diffable_backend():
    pyhf.set_backend("jax", default=True)

    def example_op(x):
        y = pyhf.default_backend.astensor(x)
        return 2*y
    assert jax.jacrev(example_op)([1.]) == [2.] 

def test_diffable_backend_failure():
    pyhf.set_backend("numpy", default=True)
    pyhf.set_backend("jax")

    def example_op(x):
        y = pyhf.default_backend.astensor(x)
        return 2*y

    with pytest.raises(Exception):
        jax.jacrev(example_op)([1.])

Could maybe cover slightly more ground with a tensorlib op instead of * 2 like Lukas wrote, but this is the idea

@lukasheinrich (Contributor):

thanks @phinate - can you verify that this works too?

    def example_op(x):
        y = pyhf.default_backend.power(x,2)
        z = pyhf.tensorlib.sum(x)
        return z

@phinate (Contributor) commented Oct 16, 2021

thanks @phinate - can you verify that this works too?

    def example_op(x):
        y = pyhf.default_backend.power(x,2)
        z = pyhf.tensorlib.sum(x)
        return z

sure -- guessing you meant z = pyhf.tensorlib.sum(y)?

@lukasheinrich (Contributor):

Yes sorry :)

@phinate (Contributor) commented Oct 16, 2021

may not need both, but this still runs without error:

def test_diffable_backend():
    pyhf.set_backend("jax", default=True)

    def example_op(x):
        y = pyhf.default_backend.astensor(x)
        return 2 * y

    assert jax.jacrev(example_op)([1.0]) == [2.0]

    def example_op2(x):
        y = pyhf.default_backend.power(x, 2)
        z = pyhf.tensorlib.sum(y)
        return z

    assert jax.jacrev(example_op2)(pyhf.tensorlib.astensor([2.0, 3.0])).tolist() == [
        4.0,
        6.0,
    ]


def test_diffable_backend_failure():
    pyhf.set_backend("numpy", default=True)
    pyhf.set_backend("jax")

    def example_op(x):
        y = pyhf.default_backend.astensor(x)
        return 2 * y

    with pytest.raises(Exception):
        jax.value_and_grad(example_op)(1.0)

    def example_op2(x):
        y = pyhf.default_backend.power(x, 2)
        z = pyhf.tensorlib.sum(y)
        return z

    with pytest.raises(Exception):
        jax.jacrev(example_op2)(pyhf.tensorlib.astensor([2.0, 3.0]))

@lukasheinrich (Contributor):

Great tests. Would be nice to also test that you can jit through this if the two backends are jax.

@phinate (Contributor) commented Oct 16, 2021

Great tests. Would be nice to also test that you can jit through this if the two backends are jax.

nice idea! still works :)

@pytest.mark.parametrize('jitted', (False, True))
def test_diffable_backend(jitted):
    pyhf.set_backend("jax", default=True)

    def example_op(x):
        y = pyhf.default_backend.astensor(x)
        return 2 * y

    if jitted: 
        assert jax.jacrev(jax.jit(example_op))([1.0]) == [2.0]
    else:
        assert jax.jacrev(example_op)([1.0]) == [2.0]

    def example_op2(x):
        y = pyhf.default_backend.power(x, 2)
        z = pyhf.tensorlib.sum(y)
        return z

    if jitted:
        assert jax.jacrev(jax.jit(example_op2))(pyhf.tensorlib.astensor([2.0, 3.0])).tolist() == [
            4.0,
            6.0,
        ]
    else:
        assert jax.jacrev(example_op2)(pyhf.tensorlib.astensor([2.0, 3.0])).tolist() == [
            4.0,
            6.0,
        ]


def test_diffable_backend_failure():
    pyhf.set_backend("numpy", default=True)
    pyhf.set_backend("jax")

    def example_op(x):
        y = pyhf.default_backend.astensor(x)
        return 2 * y

    with pytest.raises(Exception):
        jax.jacrev(example_op)([1.0])

    def example_op2(x):
        y = pyhf.default_backend.power(x, 2)
        z = pyhf.tensorlib.sum(y)
        return z

    with pytest.raises(Exception):
        jax.jacrev(example_op2)(pyhf.tensorlib.astensor([2.0, 3.0]))

Edit: parametrized over jitting, seems a little neater.

@lukasheinrich (Contributor):

thanks @phinate - @kratsg these would be very good tests to add

@codecov (bot) commented Oct 19, 2021

Codecov Report

Merging #1646 (972435c) into master (c910f86) will increase coverage by 0.02%.
The diff coverage is 99.27%.


@@            Coverage Diff             @@
##           master    #1646      +/-   ##
==========================================
+ Coverage   98.04%   98.07%   +0.02%     
==========================================
  Files          63       64       +1     
  Lines        4142     4199      +57     
  Branches      572      578       +6     
==========================================
+ Hits         4061     4118      +57     
  Misses         48       48              
  Partials       33       33              
Flag Coverage Δ
contrib 25.24% <49.27%> (+0.35%) ⬆️
doctest 61.15% <70.28%> (+0.17%) ⬆️
unittests 96.38% <99.27%> (+0.04%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
src/pyhf/tensor/manager.py 98.55% <98.55%> (ø)
src/pyhf/__init__.py 100.00% <100.00%> (+1.35%) ⬆️
src/pyhf/constraints.py 97.08% <100.00%> (+0.06%) ⬆️
src/pyhf/interpolators/code0.py 100.00% <100.00%> (ø)
src/pyhf/interpolators/code1.py 100.00% <100.00%> (ø)
src/pyhf/interpolators/code2.py 93.54% <100.00%> (+0.21%) ⬆️
src/pyhf/interpolators/code4.py 95.34% <100.00%> (+0.11%) ⬆️
src/pyhf/interpolators/code4p.py 71.64% <100.00%> (+0.87%) ⬆️
src/pyhf/modifiers/histosys.py 100.00% <100.00%> (ø)
src/pyhf/modifiers/shapefactor.py 100.00% <100.00%> (ø)
... and 8 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update c910f86...972435c.

@phinate (Contributor) commented Oct 19, 2021

I think we need a couple of changes and tests: when playing around with this branch, I found that

pyhf.simplemodels.correlated_background([2],[3],[5],[5])

fails with error

~/new-neos/neos/venv/lib/python3.9/site-packages/pyhf/pdf.py in <listcomp>(.0)
     68         nominal_rates = pyhf.default_backend.astensor(
     69             [
---> 70                 pyhf.default_backend.concatenate(self.mega_samples[sample]['nom'])
     71                 for sample in self.config.samples
     72             ]

~/new-neos/neos/venv/lib/python3.9/site-packages/pyhf/tensor/jax_backend.py in concatenate(self, sequence, axis)
    296 
    297         """
--> 298         return jnp.concatenate(sequence, axis=axis)
    299 
    300     def simple_broadcast(self, *args):

~/new-neos/neos/venv/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py in concatenate(arrays, axis)
   2992   if isinstance(arrays, ndarray):
   2993     return _concatenate_array(arrays, axis)
-> 2994   _check_arraylike("concatenate", *arrays)
   2995   if not len(arrays):
   2996     raise ValueError("Need at least one array to concatenate.")

~/new-neos/neos/venv/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py in _check_arraylike(fun_name, *args)
    307                     if not _arraylike(arg))
    308     msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 309     raise TypeError(msg.format(fun_name, type(arg), pos))
    310 
    311 def _check_no_float0s(fun_name, *args):

TypeError: concatenate requires ndarray or scalar arguments, got <class 'list'> at position 0.

This is because jax does not mimic numpy's behaviour when using a list as an arg -- you need to explicitly cast to a tensor when doing this. Probably solved with some default_backend.astensor calls.
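
A quick standalone demonstration of the difference (only numpy and jax assumed installed; values are illustrative):

import numpy as np
import jax.numpy as jnp

np.concatenate([[1.0, 2.0], [3.0, 4.0]])  # works: array([1., 2., 3., 4.])

try:
    jnp.concatenate([[1.0, 2.0], [3.0, 4.0]])
except TypeError as err:
    print(err)  # concatenate requires ndarray or scalar arguments ...

# The fix suggested above: cast each element to a tensor first.
jnp.concatenate([jnp.asarray(x) for x in [[1.0, 2.0], [3.0, 4.0]]])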

We should also try to flag this with some tests; weren't there already tests to check basic functionality across multiple default/tensorlib combinations?

Edit: of course, I should clarify that this is with

pyhf.set_backend("jax", default=True)

@phinate (Contributor) commented Oct 19, 2021

This is because jax does not mimic numpy's behaviour when using a list as an arg -- you need to explicitly cast to a tensor when doing this. Probably solved with some default_backend.astensor calls.

This is mainly an issue as of #1647, since that's where the concatenate calls were likely added, so I'm only seeing it after the rebase you did, @kratsg.

@phinate (Contributor) commented Oct 19, 2021

just for clarity, here's the relevant page in the jax docs that goes over this

@kratsg (Contributor, author) commented Oct 19, 2021

Huh, we do have tests/test_simplemodels.py but those don't cover the other backends.

@kratsg (Contributor, author) commented Oct 19, 2021

Actually, now I'm confused. We're loading in regular workspaces -- but those also won't work with non-numpy backends by default, right?

Edit: it doesn't work. This is just revealing an underlying feature we never actually supported:

>>> import json
>>> ws = pyhf.Workspace(json.load(open('mysigfit_brZ_100_brH_0_brW_0_bre_33_brm_33_brt_34_mass_100.json')))
>>> pdf = ws.model()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/kratsg/.pyenv/versions/pyhf-dev/lib/python3.8/site-packages/pyhf/workspace.py", line 425, in model
    return Model(modelspec, **config_kwargs)
  File "/Users/kratsg/.pyenv/versions/pyhf-dev/lib/python3.8/site-packages/pyhf/pdf.py", line 632, in __init__
    modifiers, _nominal_rates = _nominal_and_modifiers_from_spec(
  File "/Users/kratsg/.pyenv/versions/pyhf-dev/lib/python3.8/site-packages/pyhf/pdf.py", line 129, in _nominal_and_modifiers_from_spec
    nominal_rates = nominal.finalize()
  File "/Users/kratsg/.pyenv/versions/pyhf-dev/lib/python3.8/site-packages/pyhf/pdf.py", line 69, in finalize
    [
  File "/Users/kratsg/.pyenv/versions/pyhf-dev/lib/python3.8/site-packages/pyhf/pdf.py", line 70, in <listcomp>
    pyhf.default_backend.concatenate(self.mega_samples[sample]['nom'])
  File "/Users/kratsg/.pyenv/versions/pyhf-dev/lib/python3.8/site-packages/pyhf/tensor/jax_backend.py", line 298, in concatenate
    return jnp.concatenate(sequence, axis=axis)
  File "/Users/kratsg/.pyenv/versions/pyhf-dev/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 3382, in concatenate
    _check_arraylike("concatenate", *arrays)
  File "/Users/kratsg/.pyenv/versions/pyhf-dev/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 560, in _check_arraylike
    raise TypeError(msg.format(fun_name, type(arg), pos))
TypeError: concatenate requires ndarray or scalar arguments, got <class 'list'> at position 0.

@matthewfeickert (Member) left a review:

Thanks @kratsg for this PR. This looks good overall to me, and I just have a few small questions and suggested changes.

I also agree with @lukasheinrich that using

default_backend = pyhf.default_backend

is a nicer pattern (and certainly made this easier to review) so thanks for adopting that. 👍
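
For reference, a minimal sketch of that pattern (the function and its inputs are hypothetical, for illustration only):

import pyhf

def scaled_rates(rates):
    # Alias the dynamic attribute locally; it is resolved at call time,
    # so a later pyhf.set_backend(..., default=True) is picked up.
    default_backend = pyhf.default_backend
    tensor = default_backend.astensor(rates)
    return default_backend.power(tensor, 2)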

Review threads (resolved): src/pyhf/tensor/manager.py, tests/test_public_api_repr.py, tests/test_simplemodels.py, src/pyhf/__init__.py
Comment on lines +1 to +2
import pyhf
from pyhf.tensor.manager import get_backend
@matthewfeickert (Member):

Just so I understand, this is necessary to avoid

ImportError: cannot import name 'default_backend' from partially initialized module 'pyhf' (most likely due to a circular import)

because pyhf.default_backend needs to be evaluated as an attr and so doesn't exist as a static attribute? Or am I just missing the obvious here?

@kratsg (Contributor, author):

Global vs non-global. You don't want to evaluate default_backend at import time.
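
A minimal sketch of the distinction being drawn here (values are illustrative):

import pyhf

# A module-level `from pyhf import default_backend` would either fail during
# partial initialization (the ImportError quoted above) or snapshot whatever
# backend existed at import time, so later switches would be invisible.

pyhf.set_backend("jax", default=True)

# Accessing the attribute at call time always sees the current default backend:
tensor = pyhf.default_backend.astensor([1.0, 2.0])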

@matthewfeickert (Member) left a review:

LGTM. 👍 Thanks!

@matthewfeickert merged commit 5ea4e0a into master on Oct 22, 2021
@matthewfeickert deleted the feat/configurable-default-backend branch on October 22, 2021 16:07
Labels: API (Changes the public API), feat/enhancement (New feature or request), refactor (A code change that neither fixes a bug nor adds a feature)

Merging this pull request may close: Explicit setting of the default backend