This issue was moved to a discussion.
pyhf features #1
Dear @lukasheinrich, thank you very much for your comment! Is there some sort of platform for discussing developments in binned likelihood fits? I would be very happy to join and contribute to discussions and development of these kinds of tools. Thank you very much again for your kind words! Best, Peter
Hi @pfackeldey, really cool to see more interest in this line of work :) To give a quick tl;dr of what's been going on:
Do you have a set of goals with this library/this work? I'm sure we could help with whatever direction you're planning to go in! RE: a common discussion place -- there's certainly some activity, but a lot of it happens in private Slack workspaces at the moment. I still think we lack a proper open discussion forum for this kind of thing, though at least in ATLAS there are some semi-active mailing lists.
Hi @phinate, @lukasheinrich, thank you very much for the overview of what's going on! I developed this library a few days ago on a (long) train trip to a conference... It was primarily out of curiosity, my interest in these kinds of fits, and a wish to deepen my understanding of JAX. So far, I do not have concrete plans or a timeline. However, I have some ideas and a few philosophies/design choices which I want to pursue with this library. Some of them are listed here:
Most of these points seem rather straightforward to me (at least at first sight), but I would be very happy to learn from your experience here! There will definitely be pitfalls... Best, Peter
You'd be in reasonable company with these ideas @pfackeldey -- @lukasheinrich and I have discussed making Looking through your code, I see you've also been on the same PyTree journey as me! @chex.dataclass was the second thing I tried, but things have gotten much easier now thanks to libraries like equinox and simple-pytree (the latter is my current favourite!) With three people, I think we have a good shot at giving this a go, but the immediate issues in my mind are:
If you're free @lukasheinrich, I think the best starting point for this could be to have an initial discussion meeting with the three of us, and then try to prototype a basic model that has Let me know what you both think! :)
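Since several of the models in this thread lean on PyTree base classes (`chex.dataclass`, simple-pytree, equinox), here is a minimal sketch of what those libraries automate, using only `jax.tree_util.register_pytree_node_class`. The class `Yields` and its fields are hypothetical. The point is that once a class is a registered pytree, `jax.grad` can differentiate with respect to the whole object and returns gradients with the same structure.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Yields:
    """Hypothetical container of per-bin signal and background yields."""

    def __init__(self, sig, bkg):
        self.sig = sig
        self.bkg = bkg

    def tree_flatten(self):
        # array fields are children (traced by JAX); no static auxiliary data
        return (self.sig, self.bkg), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def total(yields: Yields, mu: float) -> jax.Array:
    # total expected events: mu * signal + background, summed over bins
    return jnp.sum(mu * yields.sig + yields.bkg)


yields = Yields(sig=jnp.array([1.0, 2.0]), bkg=jnp.array([3.0, 4.0]))
# differentiate w.r.t. the whole pytree; the result is again a `Yields`
grads = jax.grad(total)(yields, 2.0)
```

Here `grads.sig` holds `mu` in every bin and `grads.bkg` holds ones, packaged as a `Yields` instance. The base classes discussed above generate the `tree_flatten`/`tree_unflatten` pair from the class fields instead of requiring it by hand.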
here's an example to get thinking based on something @lukasheinrich and I wrote for

```python
from __future__ import annotations

from typing import Any, Iterable

import jax
import jax.numpy as jnp
import pyhf
from simple_pytree import Pytree

pyhf.set_backend("jax")


class _Config:
    def __init__(self) -> None:
        self.poi_index = 0
        self.npars = 2

    def suggested_init(self) -> jax.Array:
        return jnp.asarray([1.0, 1.0])


class Model(Pytree):
    """Dummy class to mimic the functionality of `pyhf.Model`."""

    def __init__(self, spec: Iterable[Any]) -> None:
        self.sig, self.nominal, self.uncert = spec
        self.factor = (self.nominal / self.uncert) ** 2
        self.aux = 1.0 * self.factor
        self.config = _Config()

    def expected_data(self, pars: jax.Array) -> jax.Array:
        mu, gamma = pars
        expected_main = jnp.asarray([gamma * self.nominal + mu * self.sig])
        aux_data = jnp.asarray([self.aux])
        return jnp.concatenate([expected_main, aux_data])

    # logpdf as the call method
    def __call__(self, pars: jax.Array, data: jax.Array) -> list[jax.Array]:
        maindata, auxdata = data
        main, _ = self.expected_data(pars)
        _, gamma = pars
        main = pyhf.probability.Poisson(main).log_prob(maindata)
        constraint = pyhf.probability.Poisson(gamma * self.factor).log_prob(auxdata)
        # sum log probs over bins
        return [jnp.sum(jnp.asarray([main + constraint]), axis=None)]


def uncorrelated_background(s: jax.Array, b: jax.Array, db: jax.Array) -> Model:
    """Dummy function to mimic the functionality of `pyhf.simplemodels.hepdata_like`."""
    return Model([s, b, db])


model = uncorrelated_background(jnp.array([1, 1]), jnp.array([2, 2]), jnp.array([3, 3]))
model = jax.jit(model)  # the cool step! could also use vmap etc
model(pars=jnp.array([1, 1]), data=[jnp.array([3, 1]), jnp.array([0])])
# >> [Array(-4.2861992, dtype=float64)]
```
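The comment in the example above notes that `vmap` could be used as well as `jit`. As a sketch of that, assuming pyhf's Poisson `log_prob` uses the continuous form n log(λ) - λ - log Γ(n+1), the same two-parameter model can be written as a pure function of `pars` in plain JAX (no pyhf) and vectorized over a batch of parameter points. The names `poisson_logpdf` and `logpdf` and the argument layout are illustrative, not pyhf API.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln, xlogy


def poisson_logpdf(n, lam):
    # continuous Poisson log-likelihood: n * log(lam) - lam - log(n!)
    return xlogy(n, lam) - lam - gammaln(n + 1.0)


def logpdf(pars, sig, nominal, uncert, maindata, auxdata):
    """Sketch of the toy model's log-likelihood as a pure function of `pars`."""
    mu, gamma = pars[0], pars[1]
    factor = (nominal / uncert) ** 2
    main = poisson_logpdf(maindata, gamma * nominal + mu * sig)
    constraint = poisson_logpdf(auxdata, gamma * factor)
    # sum log probs over bins
    return jnp.sum(main + constraint)


sig = jnp.array([1.0, 1.0])
nominal = jnp.array([2.0, 2.0])
uncert = jnp.array([3.0, 3.0])
maindata = jnp.array([3.0, 1.0])
auxdata = jnp.array([0.0])

single = logpdf(jnp.array([1.0, 1.0]), sig, nominal, uncert, maindata, auxdata)

# evaluate many parameter points in one call: map over axis 0 of `pars` only
batched = jax.jit(jax.vmap(logpdf, in_axes=(0, None, None, None, None, None)))
pars_batch = jnp.array([[0.5, 1.0], [1.0, 1.0], [2.0, 1.0]])
values = batched(pars_batch, sig, nominal, uncert, maindata, auxdata)
```

With the inputs used above, `single` comes out at about -4.286, in line with the quoted output, and `values` holds one log-likelihood per row of `pars_batch`. Keeping the model data as explicit arguments (or as a pytree) is what lets `in_axes` choose which inputs are batched.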
That is very much in line with what I came up with! There are some minor differences in the API, but in principle it works very similarly. I'll sketch the next snippet in pseudo-code to highlight what I tried to add to the model usage:

```python
import jax
import jax.numpy as jnp

from dilax.model import Model
from dilax.parameter import r, lnN


class MyModel(Model):
    @jax.jit
    def eval(self):
        # function to calculate the expected bin yield(s)
        # mu * sig + bkg (+ a bkg. norm. unc.)
        sig = self.parameters["mu"].apply(self.processes["sig"])
        bkg = self.parameters["norm"].apply(self.processes["bkg"])
        return sig + bkg


model = MyModel(
    processes={"sig": jnp.array([1.0]), "bkg": jnp.array([2.0])},
    parameters={"mu": r(strength=jnp.array(1.0)), "norm": lnN(strength=jnp.array(0.0), width=jnp.array(0.1))},
)

# "functional" way of evaluating and updating the model
# default uses __init__ values of `processes` and `parameters`
print(model.eval())
# >> jnp.array([3.0])

# change a parameter
print(model.apply(parameters={"mu": jnp.array([2.0])}).eval())
# >> jnp.array([4.0])

# change a process expectation
print(model.apply(processes={"bkg": jnp.array([5.0])}).eval())
# >> jnp.array([6.0])
```

The

```python
# ... assume we have performed a fit and extracted the fitted parameters in `best_fit_parameters` ...
from dilax.likelihood import nll

observation = jnp.array([4.0])

# now I want the gradient of the likelihood w.r.t. all parameters on the postfit model:
grad_nll = jax.grad(nll, argnums=0)
all_grad_postfitmodel = grad_nll(
    best_fit_parameters,
    model.apply(parameters=best_fit_parameters),
    observation,
)

# or just for `mu`:
mu_grad_prefitmodel = grad_nll(
    {"mu": best_fit_parameters["mu"]},
    model,
    observation,
)

# or for the fitted model w.r.t. `mu`:
mu_grad_postfitmodel = grad_nll(
    {"mu": best_fit_parameters["mu"]},
    model.apply(parameters=best_fit_parameters),
    observation,
)
```

And in addition, this of course allows us to do some nice (This comment was just meant to give you a bit of a feeling for where I tried to go with the design choices...)
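To make the "functional" update and the `jax.grad(nll, argnums=0)` calls above concrete without depending on dilax internals, here is a self-contained sketch in which the parameters are a plain dict (itself a pytree), so the gradient comes back as a dict with the same keys. The names `expectation`, `nll`, `with_params` and `WIDTH` are hypothetical. The background modifier uses an assumed log-normal-style scaling `(1 + WIDTH) ** norm` with a unit-Gaussian penalty on `norm`; dilax's own `lnN` may be parameterized differently.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import xlogy

WIDTH = 0.1  # assumed relative background-normalization uncertainty


def expectation(params, processes):
    # mu scales the signal; `norm` scales the background log-normally (assumed form)
    return params["mu"] * processes["sig"] + processes["bkg"] * (1.0 + WIDTH) ** params["norm"]


def nll(params, processes, observation):
    lam = expectation(params, processes)
    # Poisson NLL up to a constant, plus a unit-Gaussian constraint on `norm`
    return jnp.sum(lam - xlogy(observation, lam)) + 0.5 * params["norm"] ** 2


def with_params(params, **updates):
    # functional update: returns a new dict, the original stays untouched
    return {**params, **updates}


processes = {"sig": jnp.array([1.0]), "bkg": jnp.array([2.0])}
params = {"mu": jnp.array(1.0), "norm": jnp.array(0.0)}
observation = jnp.array([4.0])

prefit = expectation(params, processes)  # 1 * 1 + 2 * 1.1**0 = 3
updated = expectation(with_params(params, mu=jnp.array(2.0)), processes)  # 2 * 1 + 2 = 4

# gradient w.r.t. every parameter: a dict with keys "mu" and "norm"
all_grads = jax.grad(nll, argnums=0)(params, processes, observation)

# gradient w.r.t. `mu` only, holding the other parameters fixed
mu_grad = jax.grad(lambda mu: nll(with_params(params, mu=mu), processes, observation))(params["mu"])
```

Restricting the gradient to a subset of parameters is done here with a small closure over `with_params`, which keeps `nll` itself agnostic to which entries are being differentiated.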
Yes, I would be happy to join a meeting and discuss these things in more detail :)
To come back to this, I think this approach will solve an issue I've been having with

```python
from typing import Any, Iterable

import jax
import jax.numpy as jnp
import jaxopt
import optax
import pyhf
from simple_pytree import Pytree

pyhf.set_backend("jax")


class _Config(Pytree):
    def __init__(self) -> None:
        self.poi_index = 0
        self.npars = 2

    def suggested_init(self) -> jax.Array:
        return jnp.asarray([1.0, 1.0])


class Model(Pytree):
    """Dummy class to mimic the functionality of `pyhf.Model`."""

    def __init__(self, spec: Iterable[Any]) -> None:
        self.sig, self.nominal, self.uncert = spec
        self.factor = (self.nominal / self.uncert) ** 2
        self.aux = 1.0 * self.factor
        self.config = _Config()

    def expected_data(self, pars: jax.Array) -> jax.Array:
        mu, gamma = pars
        expected_main = jnp.asarray([gamma * self.nominal + mu * self.sig])
        aux_data = jnp.asarray([self.aux])
        return jnp.concatenate([expected_main, aux_data])

    # logpdf as a regular method (instead of __call__)
    def logpdf(self, pars: jax.Array, data: jax.Array) -> list[jax.Array]:
        maindata, auxdata = data
        main, _ = self.expected_data(pars)
        _, gamma = pars
        main = pyhf.probability.Poisson(main).log_prob(maindata)
        constraint = pyhf.probability.Poisson(gamma * self.factor).log_prob(auxdata)
        # sum log probs over bins
        return [jnp.sum(jnp.asarray([main + constraint]), axis=None)]


def uncorrelated_background(s: jax.Array, b: jax.Array, db: jax.Array) -> Model:
    """Dummy function to mimic the functionality of `pyhf.simplemodels.hepdata_like`."""
    return Model([s, b, db])


@jax.jit
def pipeline(param_for_grad):
    data = jnp.array([5.0, 5.0])
    init_pars = jnp.array([1.0, 1.1])
    lr = 1e-3
    # the model's yields depend on the quantity we differentiate with respect to
    model = Model(param_for_grad * jnp.array([1.0, 1.0, 1.0]))

    def fit(pars, model, data):
        def fit_objective(pars, model, data):
            return -model.logpdf(pars, data)[0]

        solver = jaxopt.OptaxSolver(
            fun=fit_objective, opt=optax.adam(lr), implicit_diff=True, maxiter=5000
        )
        # start from the `pars` passed in and return the best-fit parameters
        return solver.run(pars, model=model, data=data)[0]

    return fit(init_pars, model, data)


jax.jacrev(pipeline)(jnp.asarray(0.5))
# > works!
```

If it's not clear what this is doing: it allows for differentiable optimization as targeted by
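For readers unfamiliar with the `implicit_diff=True` option: rather than backpropagating through thousands of optimizer steps, the derivative of the best-fit point is taken from the implicit function theorem at the optimum, d(theta_hat)/d(phi) = -(d2L/dtheta2)^-1 * d2L/(dtheta dphi). Below is a minimal sketch of that idea in plain JAX with a hand-written `jax.custom_jvp` rule instead of jaxopt. It assumes a hypothetical one-bin, signal-strength-only toy (signal 3*phi, background 10*phi, 15 observed events, no nuisance parameter), and the Newton-step fitter `_newton_fit` is only a stand-in for a real optimizer.

```python
import jax
import jax.numpy as jnp

N_OBS = 15.0  # observed events in the single (hypothetical) bin


def nll(mu, phi):
    # Poisson NLL up to a constant; signal = 3 * phi, background = 10 * phi
    lam = mu * 3.0 * phi + 10.0 * phi
    return lam - N_OBS * jnp.log(lam)


dnll_dmu = jax.grad(nll, argnums=0)
d2nll_dmu2 = jax.grad(dnll_dmu, argnums=0)
d2nll_dmu_dphi = jax.grad(dnll_dmu, argnums=1)


def _newton_fit(phi):
    # stand-in optimizer: a fixed number of Newton steps on mu, starting at mu = 1
    phi = jnp.asarray(phi, dtype=jnp.float32)

    def step(_, mu):
        return mu - dnll_dmu(mu, phi) / d2nll_dmu2(mu, phi)

    return jax.lax.fori_loop(0, 50, step, jnp.asarray(1.0, dtype=jnp.float32))


@jax.custom_jvp
def best_fit_mu(phi):
    return _newton_fit(phi)


@best_fit_mu.defjvp
def _best_fit_mu_jvp(primals, tangents):
    (phi,), (phi_dot,) = primals, tangents
    mu_hat = _newton_fit(phi)
    # implicit function theorem at the optimum, where dnll_dmu(mu_hat, phi) == 0
    dmu_dphi = -d2nll_dmu_dphi(mu_hat, phi) / d2nll_dmu2(mu_hat, phi)
    return mu_hat, dmu_dphi * phi_dot


mu_hat = best_fit_mu(2.0)
grad_mu_hat = jax.grad(best_fit_mu)(2.0)
```

The custom rule only needs second derivatives of the NLL at the optimum, so the cost of the gradient does not depend on how many steps the fitter took, which is the main appeal of implicit differentiation for fits inside a larger differentiable pipeline.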
Thanks for the example @phinate. As far as I can see this is exactly how
Nice, I actually got the above to work with your API too:

```python
import chex
import jax
import jax.numpy as jnp

from dilax.likelihood import nll
from dilax.model import Model
from dilax.optimizer import JaxOptimizer
from dilax.parameter import lnN, r


# Define model; i.e. how is the expectation calculated with (nuisance) parameters?
@chex.dataclass(frozen=True)
class SPlusBModel(Model):
    @jax.jit
    def eval(self) -> jax.Array:
        expectation = jnp.array(0.0)
        # modify affected processes
        for process, sumw in self.processes.items():
            # mu
            if process == "signal":
                sumw = self.parameters["mu"].apply(sumw)
            # background norm
            elif process == "bkg":
                sumw = self.parameters["norm"].apply(sumw)
            expectation += sumw
        return expectation


@jax.jit
def pipeline(phi):
    # Initialize S+B model
    model = SPlusBModel(
        processes={"signal": phi * jnp.array([3.0]), "bkg": phi * jnp.array([10.0])},
        parameters={"mu": r(strength=jnp.array(1.0)), "norm": lnN(strength=jnp.array(0.0), width=jnp.array(0.1))},
    )
    # Define data
    observation = jnp.array([15.0])
    # Setup optimizer, see more at https://jaxopt.github.io/stable/
    optimizer = JaxOptimizer.make(name="LBFGS", settings={"maxiter": 30, "tol": 1e-6, "implicit_diff": True})
    # Run fit
    params, _ = optimizer.fit(fun=nll, init_params=model.parameter_strengths, model=model, observation=observation)
    return params


jax.jacrev(pipeline)(2.0)
# > {'mu': Array(-1.2500001, dtype=float32, weak_type=True),
# >  'norm': Array(-5.662442e-07, dtype=float32, weak_type=True)}
```

which is super nice to see! The real differences, then, are from an API standpoint. What's there in If you're able to meet in the next hour or so, send me an email via GitHub -- I could have a quick chat about this if you want! Otherwise we can properly schedule something later :)
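As a sanity check on the `mu` entry of that Jacobian, the signal-strength part can be worked out by hand. This assumes the fitted `norm` sits at its nominal value: with `mu` free to absorb any change in the total yield, only the constraint term acts on `norm`, consistent with its ~1e-7 derivative above. With signal s = 3φ, background b = 10φ and n = 15 observed events in one bin, the Poisson best fit sets the expectation equal to the observation:

```latex
\hat\mu\, s + b = n
\;\Rightarrow\;
\hat\mu(\phi) = \frac{15 - 10\phi}{3\phi} = \frac{5}{\phi} - \frac{10}{3},
\qquad
\frac{d\hat\mu}{d\phi} = -\frac{5}{\phi^{2}}
\;\Big|_{\phi = 2} = -1.25
```

This matches the `-1.2500001` reported by `jax.jacrev`, up to float32 rounding.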
Randomly stumbled across this thread. @phinate, I'm curious about the preference for simple-pytree over Equinox. Is it simply that you don't want all the extra stuff provided by Equinox, when you're doing something minimal? |
Pretty much exactly this @patrick-kidger! I also like the syntax of inheriting something called I'm curious about these edge cases though; I think that's just something I haven't needed as of yet. Have you found something like a bound method as a PyTree to be a useful concept in practice? (oh, and thanks also for your reassuring comment on the JAXopt issue! leaked tracers are something I've yet to find a good way to debug consistently...)
Gotcha. I'm tempted to say we could split out (Honestly Equinox could be split into 4 separate libraries at this point: Naming: yeah, the Bound methods: I added this mostly just for the certainty of being safe from a possible footgun. It's certainly something other people seem to hit numerous times (e.g. jax-ml/jax#15338, jax-ml/jax#9672 are recent examples). (If it's interesting, various other edge cases Leaked tracers: FWIW this is an issue I hit approximately-never, just by working exclusively with pytrees. (And never raw Python classes.) |
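On the bound-method point, here is a short sketch of the kind of footgun being avoided, using a hypothetical plain-Python `Scale` class (not from any library). `jax.jit(obj.method)` closes over `self`, so attribute values are baked into the compiled trace as constants and later mutations are silently ignored on cache hits. Passing the object as a registered pytree argument instead makes its fields traced inputs.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


class Scale:
    """Plain Python class: NOT a pytree."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        return self.factor * x


s = Scale(2.0)
f = jax.jit(s.__call__)  # `self` is captured by closure, not passed as an input
first = f(jnp.array(1.0))  # traced with factor == 2.0 baked in as a constant
s.factor = 3.0
stale = f(jnp.array(1.0))  # cache hit: still uses the old constant 2.0


@register_pytree_node_class
class PytreeScale(Scale):
    """Same behaviour, but `factor` is a pytree leaf."""

    def tree_flatten(self):
        return (self.factor,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def call_scale(scale, x):
    # `scale` is a traced input, so a new factor is picked up without going stale
    return scale(x)


fresh_2 = call_scale(PytreeScale(2.0), jnp.array(1.0))
fresh_3 = call_scale(PytreeScale(3.0), jnp.array(1.0))
```

Making bound methods themselves pytrees, as described above, removes the need to remember this distinction: the method carries `self` as a traced child rather than as a hidden closure constant.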
I think it's an important feature deserving of a solution outside the neural network context; tempted to say it could even be worthy of being in jax core. But I don't think there would be a real impact from any of these decisions -- it could be nice to take the good bits of eqx and simple_pytree into one standard lib, which you could then just rename in eqx back to if you wanna chat more, let's use Twitter DMs or similar to keep this issue clean @patrick-kidger :)
Hi @phinate, @patrick-kidger, I finally managed to come back to this project and updated its core to use equinox 🎉 and I really appreciate it!! There have been some additional HEP-related updates, which might be interesting to discuss with you at some point, @phinate... Currently I'm primarily using Best, Peter
Hi there! I get various versions of this request from time to time. The most common one is for
Thank you for your insights @patrick-kidger! Yes, that makes a lot of sense to me :) Indeed, I do not really mind that
Hi @pfackeldey -
this looks nice! I'm curious if you've checked out pyhf and/or https://github.com/gradhep/relaxed where @phinate has been working towards something similar.
Cheers,
Lukas