Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding log_likelihood, observed_data, and sample_stats to numpyro sampler #5189
Adding log_likelihood, observed_data, and sample_stats to numpyro sampler #5189
Changes from 9 commits
549f714
ff3a0cd
6987bac
af108d4
f295288
8a28fad
59ebfdb
f5aeaf6
bf2ad0d
a3c6cc2
0098876
5f5cd87
6e4fcab
File filter
Filter by extension
Conversations
Jump to
There are no files selected for viewing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Out of curiosity, would we expect any benefits to jit_compiling this outer
vmap
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be possible to use a similar approach with Aesara directly?
Here we only loop over observed variables in order to get the pointwise log likelihood. We had some discussion about this in #4489 but ended up keeping the 3 nested loops over variables, chains and draws.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should be possible, but requires a Aesara Scan, and at least for small models this was not faster than python looping when I checked it. Here is a Notebook that documents some things I tried: https://gist.github.com/ricardoV94/6089a8c46a0e19665f01c79ea04e1cb2
It might be faster if using shared variables...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No idea. I think the easiest thing to do is just benchmark it. I don't even call
optimize_graph
on either the graph in this function or the main sample routine.When I run the model in the unit test with the change
result = jax.vmap(jax.vmap(jax_fn))(*samples)[0]
toresult = jax.jit(jax.vmap(jax.vmap(jax_fn)))(*samples)[0]
I don't really get a speed-up until there are millions of samples.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should definitely call
optimize_graph
, otherwise the computed logps may not correspond to the ones used during sampling. For instance we have many optimizations that improve numerically stability, so you might get underflows to-inf
for some of the posterior samples (which would never have been accepted by NUTS) which could screw up things downstream.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Then it's probably not worth it. I was under the impression it would be possible to vectorize/broadcast the operation from the conversations in #4489 and in slack.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It must be possible, since the vmap above works just fine. I just have no idea how they do it xD, or how/if you could do it in Aesara. I also wonder whether the vmap works for more complicated models with multivariate distributions and the like
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alright. I'm going to make a separate PR for some of this other stuff.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool, feel free to tag me if you want me to review, I am not watching PRs. I can already say I won't be able to help with the vectorized log_likelihood thing, I tried and I lost much more time with that than what would have been healthy. I should be able to help with coords and dims though