numpyro and blackjax samplers producing different results #144
Comments
Hello!
@theorashid I'd be interested to know if this works with the latest version!
Hey both, so I left this stale because of the But I rewrote my model for the new blackjax API with numpyro earlier today and tested it out. There's good news and bad news: Good news: the model is converging nicely This is a weird one. Any ideas? Code is here.
Yay, great news! The values are negative because numpyro transforms the variables so that their values lie between -∞ and +∞ before sampling (NUTS works much better with unconstrained values). The logpdf that numpyro returns works with these "unconstrained" variables, so this is what blackjax returns. To get values of the untransformed variables you thus need to apply the inverse of the transformation that numpyro used, which I guess here is the absolute value. Does that make sense?
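To make the explanation above concrete, here is a minimal sketch in plain Python of sampling in unconstrained space and mapping back. It assumes the positive parameter is mapped with an exp transform, which is a common choice for positive parameters but is an assumption here, not something the thread confirms for this model. The function names are illustrative and are not part of numpyro or blackjax.

```python
import math


def to_constrained(u):
    """Map an unconstrained real value to a positive one (assumed exp transform)."""
    return math.exp(u)


def to_unconstrained(sigma):
    """Inverse map: a positive value back to the real line."""
    return math.log(sigma)


def half_normal_logpdf(sigma, scale=1.0):
    """Log-density of a half-normal prior on sigma > 0."""
    return 0.5 * math.log(2.0 / math.pi) - math.log(scale) - 0.5 * (sigma / scale) ** 2


def unconstrained_logpdf(u):
    """Log-density the sampler sees: prior at exp(u) plus log|d sigma/du| = u."""
    return half_normal_logpdf(to_constrained(u)) + u


# A sampler working in unconstrained space can return negative draws...
unconstrained_draws = [-2.0, -0.5, 0.3]
# ...which correspond to strictly positive values of sigma.
constrained_draws = [to_constrained(u) for u in unconstrained_draws]
```

Under this assumption a negative draw for a scale parameter is not an error: it is the logarithm of a positive sigma, and the extra `+ u` term is the Jacobian correction that keeps the density consistent across the change of variables.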
Also, did you run the code with the latest version on master or the last release on PyPI?
Okay great, I was wondering how numpyro was sampling a normal with negative sigma. Glad to hear it all makes sense. I ran this using the latest PyPI release, 0.7.0:
    blackjax 0.7.0 pypi_0 pypi
Great, thank you for rerunning the example!
Is there a general way to know which inverse transformation to apply to recover the constrained parameter when using numpyro with blackjax? In the case of sigma, it could be either the absolute value or the exponential. How do I decide which one it is, @rlouf?
I am also seeing poor performance when using numpyro.initialize_model and, looking at the samples, I think it might also be related to some transformation. Was a solution to this found in the end?
No, I wasn't able to find a solution to it :(
For numpyro, default transformations are associated with each bounded distribution at init time, e.g., for LogNormal the default transformation is Exp:
There should be a way to get all the transformations being applied using numpyro's effect handlers; maybe @fehiepsi knows?
Hi all, I managed to get what I think is a solution! In the end I made use of the util function numpyro.infer.util.initialize_model:

    import numpyro

    init_params_unconstrained, potential_fn_gen, post_proc_fun, model_trace = numpyro.infer.util.initialize_model(
        rng_key=...,
        model=...,
        model_args=...,
        dynamic_args=...,
    )

Then you can get a function that undoes the transformations as:

    unmapper = post_proc_fun(None)
    samples_constrained = unmapper(samples_unconstrained)

Hope this helps!
Bug Description
I followed the use-with-numpyro notebook to get a model that works with numpyro's sampler running on blackjax. The model runs (quickly) but the values it produces are well off. It is suspected this is due to a poor choice of step size and mass matrix.
Steps/Code to Reproduce
The code requires external data, so it is best to clone the repo if the problem isn't immediately solvable. The working numpyro code is here and the attempt using blackjax is here. The numpyro model and blackjax code are also below.
Expected Results
As an example, the "home" parameter should be around 0.2-0.3. The "sd_att" and "sd_def" parameters should be constrained by the model to be positive (using FoldedDistribution()).
Actual Results
The parameter values are well off:
Versions
BlackJAX 0.2.1
numpyro 0.7.2
Python 3.8.0 | packaged by conda-forge | (default, Nov 22 2019, 19:11:19)
[Clang 9.0.0 (tags/RELEASE_900/final)]
Jax 0.2.17
Jaxlib 0.1.67