
pyhf features #1

Closed
lukasheinrich opened this issue Apr 1, 2023 · 17 comments

Comments

@lukasheinrich

Hi @pfackeldey -

this looks nice! I'm curious if you've checked out pyhf and/or https://github.com/gradhep/relaxed, where @phinate has been working towards something similar.

Cheers,
Lukas

@pfackeldey
Owner

pfackeldey commented Apr 2, 2023

Dear @lukasheinrich,

thank you very much for your comment!
Yes, I know pyhf and really appreciate it! My experience with it is still a bit limited, though (it is not yet widely used in the CMS Collaboration).
Thank you for the link to gradhep/relaxed. Indeed, this looks very familiar... I'm very happy to see active developments in this direction! :)

Is there some sort of platform to discuss developments regarding binned likelihood fits? I would be very happy to join and contribute more to discussions/developments of these kinds of tools.

Thank you very much again for your kind words!

Best, Peter

@phinate
Contributor

phinate commented Apr 3, 2023

Hi @pfackeldey, really cool to see more interest in this line of work :)

To give a quick tl;dr of what's been going on:

  • pyhf is a pretty robust implementation of the HistFactory model spec that's also used in combine (there's ongoing work to make them match exactly, which is going well afaik!)
  • it's backed by jax already, meaning that jax is being used as the engine to do fits, set limits etc
  • we've more or less made this fully differentiable too -- this lets you use pyhf models in optimization loops (see e.g. https://github.com/gradhep/neos, which uses relaxed under the hood to make histograms/fits differentiable); a short sketch of taking gradients through a pyhf model follows below.
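
A minimal sketch of what that looks like in practice (toy numbers; this assumes pyhf's public simplemodels API, not anything from relaxed/neos):

import jax
import jax.numpy as jnp
import pyhf

pyhf.set_backend("jax")

# toy model: one signal bin on top of an uncertain background
model = pyhf.simplemodels.uncorrelated_background(
    signal=[3.0], bkg=[10.0], bkg_uncertainty=[1.0]
)
data = jnp.asarray([15.0] + model.config.auxdata)
pars = jnp.asarray(model.config.suggested_init())

# model.logpdf returns a length-1 array, so take its first element before grad
grad_logpdf = jax.grad(lambda p: model.logpdf(p, data)[0])
print(grad_logpdf(pars))  # gradient w.r.t. the model parameters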

Do you have a set of goals with this library/this work? I'm sure we could help with whatever direction you're planning to go in!

RE: a common discussion place -- there's certainly some activity, but a lot of it happens in private slacks at the moment. I still think we lack a proper open discussion forum for this kind of thing, though at least in ATLAS there are some semi-active mailing lists.

@pfackeldey
Owner

Hi @phinate, @lukasheinrich,

thank you very much for the overview of what's going on!

I developed this library a few days ago on a (long) train trip to a conference... It was primarily out of curiosity, my interest in these kinds of fits, and to deepen my understanding of JAX a bit more.

So far, I do not have concrete plans or a timeline. However, I have some ideas and a few philosophies/design choices that I want to pursue with this library. Some of them are listed here:

  1. Be as 'friendly' as possible to the JAX environment to easily integrate libraries from the JAX ecosystem
  2. Exploit vectorisation as much as possible, i.e. to be more efficient on (multiple) GPUs
  3. Separate the model from its parameters (similar to flax) into jittable and easy-to-introspect objects (PyTrees); a plain-JAX sketch of this idea follows below the list
  4. Minor additions: checkpointing, (TensorBoard) logging, serialisation of PyTrees (e.g. parameters), ...
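
To illustrate point 3, a minimal plain-JAX sketch (not the actual dilax API) of keeping the parameters in a pytree separate from a pure, jittable evaluation function:

import jax
import jax.numpy as jnp

processes = {"sig": jnp.array([1.0]), "bkg": jnp.array([2.0])}
parameters = {"mu": jnp.array(1.0), "norm": jnp.array(1.0)}

@jax.jit
def expectation(parameters, processes):
    # mu * sig + norm * bkg
    return parameters["mu"] * processes["sig"] + parameters["norm"] * processes["bkg"]

print(expectation(parameters, processes))     # -> [3.]
print(jax.tree_util.tree_leaves(parameters))  # parameters stay easy to introspect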

Most of these points seem rather straightforward to me (at least at first sight), but I would be very happy to learn from your experience here! There will definitely be pitfalls...

Best, Peter

@phinate
Contributor

phinate commented Apr 5, 2023

You'd be in reasonable company with these ideas @pfackeldey -- @lukasheinrich and I have discussed making pyhf models PyTrees a few times. He has a much better idea of where the gains are in terms of inference, but I believe things like vmapping limits across a number of points/models could be possible there, as well as adding a lot more @jit. I also think this type of rewrite could help solve the outstanding issues with making pyhf differentiable.

Looking through your code, I see you've also been on the same PyTree journey as me! @chex.dataclass was the second thing I tried, but things have gotten much easier now thanks to libraries like equinox and simple-pytree (the latter is my current favourite!)

With three people, I think we have a good shot at giving this a go, but the immediate issues in my mind are:

  • Merging something into pyhf is better for the community, but may be more work technically compared to making a new start [but there, we risk making something that is never used]
  • Highlighting the ways in which this type of model could outperform the existing implementation (to help guide development)

If you're free @lukasheinrich, I think the best starting point for this could be to have an initial discussion meeting with us three, and then try and prototype a basic model that has pyhf-compatible syntax. We could then work on a gradhep fork of pyhf if needed.

Let me know what you both think! :)

@phinate
Contributor

phinate commented Apr 5, 2023

here's an example to get us thinking, based on something @lukasheinrich and I wrote for neos:

from __future__ import annotations

from typing import Any, Iterable

import jax
import jax.numpy as jnp
import pyhf
from simple_pytree import Pytree
        
pyhf.set_backend("jax")


class _Config:
    def __init__(self) -> None:
        self.poi_index = 0
        self.npars = 2

    def suggested_init(self) -> jax.Array:
        return jnp.asarray([1.0, 1.0])


class Model(Pytree):
    """Dummy class to mimic the functionality of `pyhf.Model`."""

    def __init__(self, spec: Iterable[Any]) -> None:
        self.sig, self.nominal, self.uncert = spec
        self.factor = (self.nominal / self.uncert) ** 2
        self.aux = 1.0 * self.factor
        self.config = _Config()

    def expected_data(self, pars: jax.Array) -> jax.Array:
        mu, gamma = pars
        expected_main = jnp.asarray([gamma * self.nominal + mu * self.sig])
        aux_data = jnp.asarray([self.aux])
        return jnp.concatenate([expected_main, aux_data])

    # logpdf as the call method
    def __call__(self, pars: jax.Array, data: jax.Array) -> jax.Array:
        maindata, auxdata = data
        main, _ = self.expected_data(pars)
        _, gamma = pars
        main = pyhf.probability.Poisson(main).log_prob(maindata)
        constraint = pyhf.probability.Poisson(gamma * self.factor).log_prob(auxdata)
        # sum log probs over bins
        return [jnp.sum(jnp.asarray([main + constraint]), axis=None)]


def uncorrelated_background(s: jax.Array, b: jax.Array, db: jax.Array) -> Model:
    """Dummy class to mimic the functionality of `pyhf.simplemodels.hepdata_like`."""
    return Model([s, b, db])


model = uncorrelated_background(jnp.array([1,1]),jnp.array([2,2]),jnp.array([3,3]))

model = jax.jit(model)  # the cool step! could also use vmap etc

model(pars=jnp.array([1,1]), data = [jnp.array([3,1]), jnp.array([0])])
# >> [Array(-4.2861992, dtype=float64)]

@pfackeldey
Owner

That is very much in line with what I came up with! There are some minor differences in the API, but in principle this works very similarly. The next snippet is pseudo-code to highlight what I tried to add to the model usage:

from dilax.model import Model
from dilax.parameter import r, lnN

class MyModel(Model):
    @jax.jit
    def eval(self):
        # function to calculate the expected bin yield(s)
        # mu * sig + bkg (+ a bkg. norm. unc.)
        sig = self.parameters["mu"].apply(self.processes["sig"])
        bkg = self.parameters["norm"].apply(self.processes["bkg"])
        return sig + bkg
        
model = MyModel(
    processes={"sig": jnp.array([1.0]), "bkg": jnp.array([2.0])},
    parameters={"mu": r(strength=jnp.array(1.0)), "norm": lnN(strength=jnp.array(0.0), width=jnp.array(0.1))},
)

# "functional" way of evaluating and updating the model
# default uses __init__ values of `processes` and `parameters`
print(model.eval())
>> jnp.array([3.0])

# change a parameter
print(model.apply(parameters={"mu": jnp.array([2.0])}).eval())
>> jnp.array([4.0])

# change a process expectation
print(model.apply(processes={"bkg": jnp.array([5.0])}).eval())
>> jnp.array([6.0])

The model only holds priors and expectations and knows how to calculate the histogram expectation. The nll is decoupled and just takes the model + observation (data) + a certain set of parameters as arguments. The set of parameters might be all or just a subset of the parameters the model knows about. This way I can differentiate w.r.t. any combination of parameters. Another (also pseudo-code) snippet highlights this:

# ... assume we have performed a fit and extracted the fitted parameters in `best_fit_parameters` ...
from dilax.likelihood import nll

observation = jnp.array([4.0])

# now I want the gradient of the likelihood wrt to all parameters on the postfit model:
grad_nll = jax.grad(nll, argnums=0)
all_grad_postfitmodel = grad_nll(
    best_fit_parameters,
    model.apply(parameters=best_fit_parameters),
    observation,
)

# or just for `mu`:
mu_grad_prefitmodel = grad_nll(
    {"mu": best_fit_parameters["mu"]},
    model,
    observation,
)

# or for the fitted model wrt `mu`:
mu_grad_postfitmodel = grad_nll(
    {"mu": best_fit_parameters["mu"]},
    model.apply(parameters=best_fit_parameters),
    observation,
)

In addition, this of course allows for some nice vmaps, e.g. for likelihood profiling, where I can vectorise over multiple likelihood minimisations + evaluations for different fixed parameter values (see the sketch below).
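
As a generic sketch of that idea (plain JAX + jaxopt, not the actual dilax API; the toy nll below is made up for illustration):

import jax
import jax.numpy as jnp
import jaxopt


def nll(nuisance, mu, observation):
    # toy model: expectation = mu * signal + exp(nuisance) * background (lnN-style)
    expectation = mu * jnp.array([3.0]) + jnp.exp(nuisance) * jnp.array([10.0])
    # Poisson term plus a standard-normal constraint on the nuisance parameter
    return -jnp.sum(observation * jnp.log(expectation) - expectation) + 0.5 * nuisance**2


@jax.jit
def profile(mu, observation):
    # minimise the toy nll over the nuisance parameter for a fixed mu
    solver = jaxopt.LBFGS(fun=nll, maxiter=50)
    nuisance_hat, _ = solver.run(jnp.array(0.0), mu=mu, observation=observation)
    return nll(nuisance_hat, mu, observation)


# vectorise the profiling over a grid of fixed mu values
mus = jnp.linspace(0.0, 2.0, 21)
profiled_nll = jax.vmap(profile, in_axes=(0, None))(mus, jnp.array([15.0]))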

(This comment was just to give you a bit of a feeling for where I tried to go with the design choices...)

@pfackeldey
Owner

If you're free @lukasheinrich, I think the best starting point for this could be to have an initial discussion meeting with us three, and then try and prototype a basic model that has pyhf-compatible syntax. We could then work on a gradhep fork of pyhf if needed.

Yes, I would be happy to join a meeting and discuss these things in more detail :)

@phinate
Contributor

phinate commented Apr 14, 2023

To come back to this: I think this approach will solve an issue I've been having with jaxopt and closure_convert (see this jaxopt issue):

from functools import partial

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jaxopt
import optax
from typing import Any, Iterable

import pyhf
from simple_pytree import Pytree
        
pyhf.set_backend("jax")


class _Config(Pytree):
    def __init__(self) -> None:
        self.poi_index = 0
        self.npars = 2

    def suggested_init(self) -> jax.Array:
        return jnp.asarray([1.0, 1.0])


class Model(Pytree):
    """Dummy class to mimic the functionality of `pyhf.Model`."""

    def __init__(self, spec: Iterable[Any]) -> None:
        self.sig, self.nominal, self.uncert = spec
        self.factor = (self.nominal / self.uncert) ** 2
        self.aux = 1.0 * self.factor
        self.config = _Config()

    def expected_data(self, pars: jax.Array) -> jax.Array:
        mu, gamma = pars
        expected_main = jnp.asarray([gamma * self.nominal + mu * self.sig])
        aux_data = jnp.asarray([self.aux])
        return jnp.concatenate([expected_main, aux_data])

    # logpdf as the call method
    def logpdf(self, pars: jax.Array, data: jax.Array) -> jax.Array:
        maindata, auxdata = data
        main, _ = self.expected_data(pars)
        _, gamma = pars
        main = pyhf.probability.Poisson(main).log_prob(maindata)
        constraint = pyhf.probability.Poisson(gamma * self.factor).log_prob(auxdata)
        # sum log probs over bins
        return [jnp.sum(jnp.asarray([main + constraint]), axis=None)]


def uncorrelated_background(s: jax.Array, b: jax.Array, db: jax.Array) -> Model:
    """Dummy class to mimic the functionality of `pyhf.simplemodels.hepdata_like`."""
    return Model([s, b, db])


@jax.jit
def pipeline(param_for_grad):
    data=jnp.array([5.0, 5.0])
    init_pars=jnp.array([1.0, 1.1])
    lr=1e-3

    model = Model(param_for_grad*jnp.array([1.0, 1.0, 1.0]))

    def fit(pars, model, data):
        def fit_objective(pars, model, data):
            return -model.logpdf(pars, data)[0]

        solver = jaxopt.OptaxSolver(
            fun=fit_objective, opt=optax.adam(lr), implicit_diff=True, maxiter=5000
        )
        return solver.run(init_pars, model=model, data=data)[0]

    return fit(init_pars, model, data)

jax.jacrev(pipeline)(jnp.asarray(0.5))
> works!

If it's not clear what this is doing: it allows for differentiable optimization as targeted by relaxed, and allows one to jit the whole pipeline. I think this is probably the key ingredient to making this kind of differentiable workflow viable -- the current solution is broken, and other solutions are potentially more effort (model updates in-place). I'm optimistic for the time being!

@pfackeldey
Owner

Thanks for the example @phinate. As far as I can see this is exactly how dilax currently implements profiling of parameters. Basically it's your pipeline with an additional vmap across multiple fixed parameters (optimizing with respect to the rest), to profile the likelihood as a function of a parameter (vectorized).

@phinate
Contributor

phinate commented Apr 14, 2023

Thanks for the example @phinate. As far as I can see this is exactly how dilax currently implements profiling of parameters.

Nice, I actually got the above to work with your API too:

import jax
import jax.numpy as jnp

import chex

from dilax.likelihood import nll
from dilax.parameter import r, lnN
from dilax.model import Model
from dilax.optimizer import JaxOptimizer


# Define model; i.e. how is the expectation calculated with (nuisance) parameters?
@chex.dataclass(frozen=True)
class SPlusBModel(Model):
    @jax.jit
    def eval(self) -> jax.Array:
        expectation = jnp.array(0.0)

        # modify affected processes
        for process, sumw in self.processes.items():
            # mu
            if process == "signal":
                sumw = self.parameters["mu"].apply(sumw)

            # background norm
            elif process == "bkg":
                sumw = self.parameters["norm"].apply(sumw)

            expectation += sumw
        return expectation

@jax.jit
def pipeline(phi):

    # Initialize S+B model
    model = SPlusBModel(
        processes={"signal": phi*jnp.array([3.0]), "bkg": phi*jnp.array([10.0])},
        parameters={"mu": r(strength=jnp.array(1.0)), "norm": lnN(strength=jnp.array(0.0), width=jnp.array(0.1))},
    )

    # Define data
    observation = jnp.array([15.0])

    # Setup optimizer, see more at https://jaxopt.github.io/stable/ 
    optimizer = JaxOptimizer.make(name="LBFGS", settings={"maxiter": 30, "tol": 1e-6, "implicit_diff": True})

    # Run fit
    params, _ = optimizer.fit(fun=nll, init_params=model.parameter_strengths, model=model, observation=observation)

    return params

jax.jacrev(pipeline)(2.)
#> {'mu': Array(-1.2500001, dtype=float32, weak_type=True),
#'norm': Array(-5.662442e-07, dtype=float32, weak_type=True)}

which is super nice to see!

The real differences then are from an API standpoint. What's there in relaxed now can mimic the pyhf API, and is tested against that in CI to give the same numerical results for fitting, hypothesis testing etc. Given that the above differentiable example works with dilax, I believe that we could perhaps use dilax and relaxed as a starting point to design something that can mimic the API and numerics of pyhf, more as a technical prototype to explore what's possible without re-inventing the wheel too much.

If you're able to meet in the next hour or so, send me an email via GitHub -- we could have a quick chat about this if you want! Otherwise we can properly schedule something later :)

@patrick-kidger

Randomly stumbled across this thread. @phinate, I'm curious about the preference for simple-pytree over Equinox.
FWIW I believe eqx.Module tackles several edge cases that simple-pytree leaves out (probably deliberately), e.g. bound methods being pytrees too.

Is it simply that you don't want all the extra stuff provided by Equinox, when you're doing something minimal?

@phinate
Contributor

phinate commented Apr 14, 2023

Is it simply that you don't want all the extra stuff provided by Equinox, when you're doing something minimal?

Pretty much exactly this @patrick-kidger! I also like the syntax of inheriting something called Pytree -- it's semantically clearer when explaining to other people what that's doing as opposed to thinking in terms of neural network modules etc (purely a naming accident of course, as those more familiar with JAX/eqx understand what it does!)

I'm curious about these edge cases though, I think that's just something I've not needed as of yet. Have you found something like a bound method as a PyTree to be a useful concept in practice?

(oh, and thanks also for your reassuring comment on the jaxopt issue! Leaked tracers are something I've yet to find a good way to debug consistently...)

@patrick-kidger

Gotcha. I'm tempted to say we could split out eqx.Module into a separate mini-library (that Eqx would depend upon) if there's a use-case for that. But I suspect that may just complicate matters in the long run. Any opinion?

(Honestly Equinox could be split into 4 separate libraries at this point: eqx.Module, most-of-Equinox, eqx.nn, and eqx.internal. But again, so far I've preferred to keep things together for simplicity.)

Naming: yeah, the PyTree name is quite nice! However this does conflict with jaxtyping.PyTree, and that's actually encapsulating a slightly broader concept. (As it also refers to lists/tuples/dictionaries/custom nodes, and supports subscripting to indicate leaf type.)

Bound methods: I added this mostly just for the certainty of being safe from a possible footgun. It's certainly something other people seem to hit numerous times (e.g. jax-ml/jax#15338, jax-ml/jax#9672 are recent examples).
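
A small illustrative sketch of what this buys you (toy module, not from the linked issues):

import equinox as eqx
import jax
import jax.numpy as jnp

class Scale(eqx.Module):
    weight: jax.Array

    def apply(self, x):
        return self.weight * x

scale = Scale(weight=jnp.array(2.0))

# the bound method is itself a pytree carrying `self`, so its array leaves
# (here: `weight`) stay visible to JAX transformations instead of being
# silently closed over as stale Python state
print(jax.tree_util.tree_leaves(scale.apply))
print(eqx.filter_jit(scale.apply)(jnp.array(3.0)))  # -> 6.0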

(If it's interesting, various other edge cases eqx.Module handles include inheritance, hashing, pretty printing, and perpetuating annotations+docstrings etc. across flattening+unflattening.)

Leaked tracers: FWIW this is an issue I hit approximately-never, just by working exclusively with pytrees. (And never raw Python classes.)

@phinate
Contributor

phinate commented Apr 17, 2023

Gotcha. I'm tempted to say we could split out eqx.Module into a separate mini-library (that Eqx would depend upon) if there's a use-case for that. But I suspect that may just complicate matters in the long run. Any opinion?

I think it's an important feature deserving of a solution outside the neural network context; I'm tempted to say it could even be worthy of being in JAX core. But I don't think there would be a real impact from any of these decisions -- it could be nice to take the good bits of eqx and simple_pytree into one standard lib, which you could then just rename in eqx back to Module through a wrapper class. In any case, I think people who search for a solution to this issue will probably be googling things like 'pytree jax class' etc., which is why it might be confusing to end up at a neural network library as the solution.

if you wanna chat more, let's use Twitter DMs or similar to keep this issue clean @patrick-kidger :)

@pfackeldey
Owner

Hi @phinate, @patrick-kidger,

I finally managed to come back to this project and updated its core to use equinox 🎉 and I really appreciate it!! There have been some additional HEP-related updates, which might be interesting for you @phinate to discuss at some point...

Currently I'm primarily using eqx.Module and some transformations (eqx.filter_jit and eqx.filter_vmap); a small sketch of this kind of usage follows below. In the future I expect to additionally make use of the following equinox features: runtime errors, pytree manipulation, and equinox's serialisation of pytree leaves; users might want to use equinox's debugging tools. However, there is (and will be) no use case here for anything from eqx.nn. Are there any plans to separate the neural-network-related features of equinox into a separate library in the future, @patrick-kidger?
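
For context, a minimal sketch of the kind of usage I mean (illustrative only, not actual dilax code):

import equinox as eqx
import jax
import jax.numpy as jnp

class SPlusB(eqx.Module):
    mu: jax.Array
    norm: jax.Array

    def expectation(self, sig, bkg):
        return self.mu * sig + self.norm * bkg

model = SPlusB(mu=jnp.array(1.0), norm=jnp.array(1.0))

# eqx.filter_jit traces the array leaves and treats everything else as static
expectation = eqx.filter_jit(model.expectation)
print(expectation(jnp.array([3.0]), jnp.array([10.0])))  # -> [13.]

# serialise / restore the pytree leaves (e.g. fitted parameters)
eqx.tree_serialise_leaves("model.eqx", model)
restored = eqx.tree_deserialise_leaves("model.eqx", like=model)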

Best, Peter

@patrick-kidger

Hi there!

I get various versions of this request from time-to-time. The most common one is for eqx.Module to be in its own library, for example. (See the discussion above.) But right now I'm not inclined to do this:

  • eqx.nn is very small in terms of lines of code.
  • It would get a lot harder to maintain if I did so.
  • The JAX ecosystem already has a bit of a problem on this front. You need JAX+Equinox+Optax just to match what PyTorch ships with by default.
  • This makes the new-user onboarding story much more complicated.
  • No-one really seems to mind e.g. scipy shipping with loads of features that they often don't use!

@pfackeldey
Owner

Thank you for your insights @patrick-kidger!

Yes, that makes a lot of sense to me :) Indeed, I do not really mind that equinox ships with more features than I need in this project. I was just curious whether there was such a plan/roadmap for the equinox project.

Repository owner locked and limited conversation to collaborators Apr 9, 2024
@pfackeldey converted this issue into discussion #14 on Apr 9, 2024

This issue was moved to a discussion.

You can continue the conversation there.
