Implement SVGD #512
Conversation
I've made a quick Colab notebook to play around with the algorithm: https://colab.research.google.com/drive/115rjkthFjz10d_l_M2BqAQdn-kaDdulC?usp=sharing
Thanks for this PR!
Let me know if there is anything that needs clarification on my review (or if I've missed anything).
To be able to do pmap, and to use minibatches for the gradient computations, you can consider taking as input an optional grad_estimator or grad_fn that allows the user to specify their own gradient computations, and have it default to the usual jax.grad of the log density (see the sketch below). In the case of minibatches, the script sgmcmc/gradients.py has all the basic components you'll need already implemented!
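For illustration, a minimal sketch of that suggestion; the helper name `svgd_gradient_fn` and its exact signature are hypothetical, not this PR's API:

```python
import jax
import jax.numpy as jnp

def svgd_gradient_fn(logdensity_fn=None, grad_fn=None):
    """Hypothetical helper: pick the gradient function SVGD will use.

    The user may pass `grad_fn` directly, e.g. a minibatch estimator
    built from the components in `sgmcmc/gradients.py`; otherwise we
    fall back to the full-data gradient of the log density.
    """
    if grad_fn is None:
        grad_fn = jax.grad(logdensity_fn)
    return grad_fn

# Default behaviour: jax.grad of the log density.
grad_fn = svgd_gradient_fn(logdensity_fn=lambda x: -0.5 * jnp.sum(x**2))
print(grad_fn(jnp.ones(3)))  # [-1. -1. -1.]
```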
An example notebook meant to run on GPU can be left as a separate PR for blackjax/sampling-book; it would make a very nice addition to this implementation.
It might be the best use of your time to leave the tests for when you are on the final draft of the implementation, i.e. when you have minibatch gradients, pmapping, etc.
Thanks for the detailed comments! I'll try and get to them next week 😃
Codecov Report
@@            Coverage Diff             @@
##              main     #512      +/-   ##
==========================================
- Coverage    99.30%   99.29%   -0.02%
==========================================
  Files           48       48
  Lines         2019     1989      -30
==========================================
- Hits          2005     1975      -30
  Misses          14       14
I had to implement SVGD for a project I'm working on. The latest commit is a working version with the improvements from my review. I've tested it on medium-sized models and the vmapping hasn't broken yet, so we can leave it like this for now. It might still need some work on the documentation; other than that, it should be ready to merge.
Force-pushed from dacf636 to 749243d.
SVGD done
- pass gradient function instead of logprob function, allowing gradient estimators, on SVGD
- pass positive gradients instead of negative gradients to SVGD
- move kernels to individual folders (restructuring)
- rewording and documentation, adding median heuristic for RBF kernel

Co-authored-by: Antoine Carnec <antoinecarnec@gmail.com>
Co-authored-by: Alberto Cabezas <a.cabezasgonzalez@lancaster.ac.uk>
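For context, the median heuristic mentioned in the last bullet sets the RBF bandwidth from the particles' pairwise distances. A sketch following Liu & Wang (2016), not the exact code merged here:

```python
import jax.numpy as jnp

def median_heuristic(particles):
    # particles: (n, d) array. Bandwidth h = med / log(n + 1), where
    # med is the median pairwise squared distance; log(n + 1) avoids
    # a zero denominator when n == 1. Exact constants vary by paper.
    n = particles.shape[0]
    diffs = particles[:, None, :] - particles[None, :, :]
    sq_dists = jnp.sum(diffs**2, axis=-1)
    return jnp.median(sq_dists) / jnp.log(n + 1)
```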
Thank you both!
As discussed in #385.
Done:
TODO:
Appreciate any comments & feedback!
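For readers landing here from #385, here is a self-contained sketch of one SVGD update with an RBF kernel. It mirrors the algorithm (Liu & Wang, 2016) rather than this PR's exact API, and uses the positive-gradient convention adopted above:

```python
import jax
import jax.numpy as jnp

def rbf_kernel(x, y, h):
    return jnp.exp(-jnp.sum((x - y) ** 2) / h)

def svgd_step(particles, grad_logdensity, h, step_size=1e-2):
    # One Stein variational gradient descent update:
    # phi(x_i) = mean_j [ k(x_j, x_i) grad log p(x_j) + grad_{x_j} k(x_j, x_i) ]
    grads = jax.vmap(grad_logdensity)(particles)

    def phi(x_i):
        k = jax.vmap(lambda x_j: rbf_kernel(x_j, x_i, h))(particles)
        dk = jax.vmap(jax.grad(lambda x_j: rbf_kernel(x_j, x_i, h)))(particles)
        return jnp.mean(k[:, None] * grads + dk, axis=0)

    return particles + step_size * jax.vmap(phi)(particles)

# Usage: 50 particles drifting toward a standard normal target.
particles = jax.random.normal(jax.random.PRNGKey(0), (50, 2))
grad_logdensity = jax.grad(lambda x: -0.5 * jnp.sum(x**2))
particles = svgd_step(particles, grad_logdensity, h=1.0)
```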
A few important guidelines and requirements before we can merge your PR:
- your PR targets the main branch and is squashed into a single commit;
- pre-commit is installed and configured on your machine, and you ran it before opening the PR;