-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding log_likelihood, observed_data, and sample_stats to numpyro sampler #5189
Adding log_likelihood, observed_data, and sample_stats to numpyro sampler #5189
Conversation
Codecov Report
@@ Coverage Diff @@
## main #5189 +/- ##
==========================================
- Coverage 78.11% 77.97% -0.15%
==========================================
Files 88 88
Lines 14159 14210 +51
==========================================
+ Hits 11061 11080 +19
- Misses 3098 3130 +32
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! Just left a comment below to avoid recreating the same code
3ba2643
to
8fa77c3
Compare
7d742f0
to
661ca8c
Compare
79dd918
to
f5aeaf6
Compare
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
Thanks @zaxtax. Awesome work! |
logp_v = replace_shared_variables([logpt(v)]) | ||
fgraph = FunctionGraph(model.value_vars, logp_v, clone=False) | ||
jax_fn = jax_funcify(fgraph) | ||
result = jax.vmap(jax.vmap(jax_fn))(*samples)[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Out of curiosity, would we expect any benefits to jit_compiling this outer vmap
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be possible to use a similar approach with Aesara directly?
Here we only loop over observed variables in order to get the pointwise log likelihood. We had some discussion about this in #4489 but ended up keeping the 3 nested loops over variables, chains and draws.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should be possible, but requires a Aesara Scan, and at least for small models this was not faster than python looping when I checked it. Here is a Notebook that documents some things I tried: https://gist.github.com/ricardoV94/6089a8c46a0e19665f01c79ea04e1cb2
It might be faster if using shared variables...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No idea. I think the easiest thing to do is just benchmark it. I don't even call optimize_graph
on either the graph in this function or the main sample routine.
When I run the model in the unit test with the change
result = jax.vmap(jax.vmap(jax_fn))(*samples)[0]
to
result = jax.jit(jax.vmap(jax.vmap(jax_fn)))(*samples)[0]
I don't really get a speed-up until there are millions of samples.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't even call
optimize_graph
on either the graph in this function or the main sample routine
We should definitely call optimize_graph
, otherwise the computed logps may not correspond to the ones used during sampling. For instance we have many optimizations that improve numerically stability, so you might get underflows to -inf
for some of the posterior samples (which would never have been accepted by NUTS) which could screw up things downstream.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should be possible, but requires a Aesara Scan, and at least for small models this was not faster than python looping when I checked it.
Then it's probably not worth it. I was under the impression it would be possible to vectorize/broadcast the operation from the conversations in #4489 and in slack.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It must be possible, since the vmap above works just fine. I just have no idea how they do it xD, or how/if you could do it in Aesara. I also wonder whether the vmap works for more complicated models with multivariate distributions and the like
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alright. I'm going to make a separate PR for some of this other stuff.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool, feel free to tag me if you want me to review, I am not watching PRs. I can already say I won't be able to help with the vectorized log_likelihood thing, I tried and I lost much more time with that than what would have been healthy. I should be able to help with coords and dims though
Thanks! We should document that while posterior, log_likelihood, sample_stats and observed_data groups will be created, all coords and dims are ignored unlike with the "regular" backend. Is sample numpyro in the docs already? Should it be? |
I guess we could also retrieve these, no?
If it's not yet, it should! |
@OriolAbril I think it would take me about the same effort to document the discrepancy as just do the correct thing with coords and dims. Out of curiosity, there used to be a jax sampler based on TFP. Did that just silently get dropped? |
Thanks! I assume the code is already written actually, either here in the |
Yes |
This adds more fields to the trace object returned from
sample_numpyro_nuts
addressing some of the concerns in #5121