
Implement SVGD #512

Merged
merged 1 commit into blackjax-devs:main on Jun 1, 2023

Conversation

antotocar34 (Contributor)

As discussed in #385.

Done:

  • Implementation of the algorithm
  • Added a very simple (bad) test.

TODO:

  • Add an example on GPU where the gradients are calculated with pmap.
  • How to take minibatch gradients?
  • Come up with a better test.

Appreciate any comments & feedback!
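
For context, here is a minimal sketch of one SVGD update step in JAX (illustrative only, not this PR's code; `rbf_kernel`, `svgd_step`, and their defaults are hypothetical names):

```python
import jax
import jax.numpy as jnp


def rbf_kernel(x, y, bandwidth=1.0):
    # Gaussian RBF kernel between two particles.
    return jnp.exp(-jnp.sum((x - y) ** 2) / (2 * bandwidth**2))


def svgd_step(particles, logdensity_fn, step_size=1e-2, bandwidth=1.0):
    # One SVGD update: a kernel-weighted attraction along the score of the
    # target, plus a repulsion term from the gradient of the kernel.
    n = particles.shape[0]
    score = jax.vmap(jax.grad(logdensity_fn))(particles)  # shape (n, d)

    def phi(x_i):
        # k[j] = k(x_j, x_i) for every particle x_j.
        k = jax.vmap(lambda x_j: rbf_kernel(x_j, x_i, bandwidth))(particles)
        attraction = k @ score / n
        # Gradient of the kernel w.r.t. its first argument: the repulsive
        # term that keeps the particles from collapsing onto a mode.
        grad_k = jax.vmap(jax.grad(rbf_kernel), in_axes=(0, None, None))(
            particles, x_i, bandwidth
        )
        repulsion = grad_k.mean(axis=0)
        return attraction + repulsion

    return particles + step_size * jax.vmap(phi)(particles)
```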

A few important guidelines and requirements before we can merge your PR:

  • [x] If I add a new sampler, there is an issue discussing it already;
  • [x] We should be able to understand what the PR does from its title only;
  • [x] There is a high-level description of the changes;
  • [x] There are links to all the relevant issues, discussions and PRs;
  • [x] The branch is rebased on the latest main commit;
  • [ ] Commit messages follow these guidelines;
  • [ ] The code respects the current naming conventions;
  • [ ] Docstrings follow the numpy style guide;
  • [x] pre-commit is installed and configured on your machine, and you ran it before opening the PR;
  • [~] There are tests covering the changes;
  • [ ] The doc is up-to-date;
  • [ ] If I add a new sampler, I added/updated related examples.

antotocar34 marked this pull request as draft on March 24, 2023.
antotocar34 (Contributor, Author)

I've made a quick Colab notebook to play around with the algorithm.

https://colab.research.google.com/drive/115rjkthFjz10d_l_M2BqAQdn-kaDdulC?usp=sharing

albcab self-requested a review on March 27, 2023.
albcab (Member) left a comment


Thanks for this PR!

Let me know if there is anything that needs clarification on my review (or if I've missed anything).

To support pmap and minibatch gradient computations, you could take as input an optional grad_estimator or grad_fn that lets the user specify their own gradient computation, and have it default to the usual jax.grad of the log density. In the case of minibatches, the script sgmcmc/gradients.py already implements all the basic components you'll need!
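
A sketch of that idea (hypothetical names, not the final API; the estimator rescales the minibatch mean log-likelihood by the full data size, in the spirit of sgmcmc/gradients.py):

```python
import jax
import jax.numpy as jnp


def as_grad_fn(logdensity_fn, grad_fn=None):
    # Fall back to plain jax.grad of the log density when the user
    # does not supply their own gradient estimator.
    return grad_fn if grad_fn is not None else jax.grad(logdensity_fn)


def minibatch_grad_estimator(logprior_fn, loglikelihood_fn, data_size):
    # Minibatch estimate of the posterior score: the minibatch mean
    # log-likelihood is rescaled by the total number of data points.
    def grad_fn(position, minibatch):
        def logposterior(theta):
            batch_loglik = jax.vmap(lambda x: loglikelihood_fn(theta, x))(minibatch)
            return logprior_fn(theta) + data_size * jnp.mean(batch_loglik)

        return jax.grad(logposterior)(position)

    return grad_fn
```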

An example notebook meant to run on GPU can be left as a separate PR for blackjax/sampling-book, it would make a very nice addition to this implementation.

It might be the best use of your time to leave the tests for when you are on the final draft of the implementation, i.e. when you have minibatch gradients, pmapping, etc.

antotocar34 (Contributor, Author)

Thanks for the detailed comments! I'll try and get to them next week 😃

codecov bot commented on May 4, 2023

Codecov Report

Merging #512 (8f29ad3) into main (87d94fe) will decrease coverage by 0.02%.
The diff coverage is 97.72%.

❗ Current head 8f29ad3 differs from the pull request's most recent head b9e8347. Consider uploading reports for commit b9e8347 to get more accurate results.

@@            Coverage Diff             @@
##             main     #512      +/-   ##
==========================================
- Coverage   99.30%   99.29%   -0.02%     
==========================================
  Files          48       48              
  Lines        2019     1989      -30     
==========================================
- Hits         2005     1975      -30     
  Misses         14       14              
Impacted Files            Coverage Δ
blackjax/__init__.py      100.00% <ø> (ø)
blackjax/kernels.py        99.25% <91.66%> (ø)
blackjax/vi/__init__.py   100.00% <100.00%> (ø)
blackjax/vi/svgd.py       100.00% <100.00%> (ø)

... and 23 files with indirect coverage changes

albcab (Member) commented on May 4, 2023

I had to implement SVGD for a project I'm working on. The latest commit is a working version with the improvements from my review. I've tested it on medium-sized models and the vmapping hasn't broken yet, so we can leave it like this for now. It might still need some work on the documentation; other than that, it should be ready to merge.
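
For reference, a hypothetical usage sketch (the blackjax.svgd constructor, its arguments, and the optax wiring here are assumptions rather than details confirmed in this thread; blackjax/vi/svgd.py defines the actual interface):

```python
import jax
import jax.numpy as jnp
import blackjax
import optax

# Toy standard-normal target; in practice this would be your model's
# log density (or a gradient estimator, per the review above).
logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)
grad_fn = jax.grad(logdensity_fn)

# Hypothetical API: pass the gradient function and an optax optimizer.
svgd = blackjax.svgd(grad_fn, optax.sgd(learning_rate=1e-2))

initial_particles = jax.random.normal(jax.random.PRNGKey(0), (50, 2))
state = svgd.init(initial_particles)
for _ in range(500):
    state = svgd.step(state)
```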

SamDuffield mentioned this pull request on May 31, 2023.
albcab force-pushed the svgd branch 3 times, most recently from dacf636 to 749243d on June 1, 2023.
albcab marked this pull request as ready for review on June 1, 2023.
albcab previously approved these changes on Jun 1, 2023.
SVGD done

pass gradient function instead of logprob function, allowing gradient estimators, on SVGD

pass positive gradients instead of negative gradients to SVGD

move kernels to individual folders (restructuring)

rewording and documentation, adding median heuristic for RBF kernel

Co-authored-by: Antoine Carnec <antoinecarnec@gmail.com>
Co-authored-by: Alberto Cabezas <a.cabezasgonzalez@lancaster.ac.uk>
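
The median heuristic mentioned in the commit message can be sketched as follows (illustrative code with assumed names, not the file's contents; Liu & Wang's 2016 SVGD paper sets the RBF bandwidth from the median pairwise squared distance):

```python
import jax.numpy as jnp


def median_heuristic(particles):
    # Pairwise squared Euclidean distances between particles, shape (n, n).
    diffs = particles[:, None, :] - particles[None, :, :]
    sq_dists = jnp.sum(diffs**2, axis=-1)
    n = particles.shape[0]
    # Common choice: h = median(d^2) / log(n), as in Liu & Wang (2016).
    return jnp.median(sq_dists) / jnp.log(n)
```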
junpenglao (Member)

Thank you both!

junpenglao merged commit 5004f9f into blackjax-devs:main on Jun 1, 2023.
rlouf changed the title from "SVGD implementation" to "Implement SVGD" on Jun 1, 2023.
junpenglao pushed a commit that referenced this pull request on Mar 12, 2024: "SVGD done".