Add Contour SGMCMC sampler. #396
Conversation
This looks great, the code is very readable! A few things are missing, but what's there is a very good start:
No problem. After I run from blackjax.sgmcmc.gradients import grad_estimator
Do you mean being able to import it as from blackjax import csgld?
Yes, in this way I may be able to mimic your MNIST/SGHMC example using csgld. If there are other ways to test this code, please feel free to suggest them.
You will need to add CSGLD to this file and then import this in. If this doesn't make sense I can do it whenever I have time.
Let me try. Have a great weekend.
pre-commit run --all-files works but make test fails. Did you have similar issues before?
Hey, I'm going to fix the merge conflicts (due to the changes in #299) in your branch and fix any other errors; you will have to make sure to pull it locally before making any other modification.
Got it, thanks a lot. Happy Thanksgiving!
I rebased your branch on the current main.
I'll work on (2) and (3) tomorrow, happy Thanksgiving!
Reply to 2: Yeah, we probably need an independent function for energy_estimator_fn, which is required by many interesting advanced samplers.
These three can just be treated as hyperparameters.
Thanks! Just wanted to say, while I'm learning about the algorithm: your blog post is very well written! One thing I am noticing, both in the code and in the paper, is that Contour SGLD is a "meta-algorithm", in the sense that it uses Langevin dynamics as a base transition kernel, but this transition kernel is called by another process that changes the base measure of the original density (and carries a latent vector) to make the transition kernel's job easier. I am wondering if we could make this more explicit in the code, which would be very much in line with the rest of the Blackjax codebase, where we try to implement basic elements and combine them to get more complex algorithms. Which leads me to a naive question: could this be generalized to be used with other Langevin dynamics?
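To make that question concrete, here is a rough structural sketch of the "meta-algorithm" view (my own illustration, not the PR's API): any Langevin-type kernel that only consumes a gradient estimator could be handed the gradient of the flattened density, while a wrapper maintains the latent energy histogram on the side. The `multiplier_fn` helper below is hypothetical.

```python
from typing import Callable, NamedTuple

import jax.numpy as jnp


class ContourState(NamedTuple):
    position: jnp.ndarray
    energy_pdf: jnp.ndarray  # latent density-of-states estimate


def flattened_grad_estimator(grad_fn: Callable, multiplier_fn: Callable) -> Callable:
    """Wrap a minibatch gradient estimator so that a base SGLD/SGHMC kernel
    targets the flattened density instead of the original one.

    `multiplier_fn(energy_pdf, position, minibatch)` is a hypothetical helper
    returning the scalar reweighting factor (in CSGLD it depends on the learned
    energy PDF at the current energy level).
    """

    def grad(state: ContourState, minibatch):
        return multiplier_fn(state.energy_pdf, state.position, minibatch) * grad_fn(
            state.position, minibatch
        )

    return grad
```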
Codecov Report
@@ Coverage Diff @@
## main #396 +/- ##
==========================================
+ Coverage 99.23% 99.25% +0.01%
==========================================
Files 46 47 +1
Lines 1834 1872 +38
==========================================
+ Hits 1820 1858 +38
Misses 14 14
Is that really necessary in practice? I've recently read this paper that claimed the opposite: https://arxiv.org/abs/2104.14421
Yes, I am going to move the stochastic update directly to
I am currently re-using the Langevin dynamics integration in |
With the current code on the branch (don't forget to pull it locally first) the following runs:

```python
import jax
import jax.numpy as jnp

import blackjax
import blackjax.sgmcmc.gradients as gradients

rng_key = jax.random.PRNGKey(0)
rng_key, data_key = jax.random.split(rng_key, 2)

data_size = 1000
X_data = jax.random.normal(data_key, shape=(data_size, 5))
data_batch = X_data[:100, :]

init_position = 1.0


def logprior_fn(position):
    return -0.5 * jnp.dot(position, position) * 0.01


def loglikelihood_fn(position, x):
    w = x - position
    return -0.5 * jnp.dot(w, w)


# Build the CSGLD sampling step and init functions
logprob_fn, grad_fn = gradients.logprob_and_grad_estimator(
    logprior_fn, loglikelihood_fn, data_size
)

csgld = blackjax.csgld(
    logprob_fn,
    grad_fn,
    zeta=1.0,  # can be specified at each step in the lower-level interface
    temperature=1e-2,  # can be specified at each step
    num_partitions=12,  # cannot be specified at each step
    energy_gap=1000,  # cannot be specified at each step
)

# Initialize and take one step using the CSGLD algorithm
init_state = csgld.init(init_position)
new_state = csgld.step(rng_key, init_state, data_batch, 1e-3, 1e-3)
print(new_state)
# CSGLDState(position=DeviceArray(-4.366815, dtype=float32),
#            energy_pdf=DeviceArray([0.15382446, 0.14114678, 0.12818706, 0.11536834, 0.10254964,
#                                    0.08973093, 0.07691223, 0.06409353, 0.05127482, 0.03845612,
#                                    0.02563741, 0.01281871], dtype=float32),
#            energy_idx=1)
```

... but it doesn't mean it yields correct samples! This is where I'll hand the PR over to you again. We need the following:
I also have a question regarding
You can check this paper: https://arxiv.org/pdf/2002.02405.pdf. Data augmentation is a key factor that leads to the cold posterior effect. E.g., CIFAR has 50K images, but after data augmentation we have many more images (>>50K). For the log-likelihood sum it might be inappropriate to sum the likelihood from 1 to 50K; more importantly, we don't know what an appropriate number would be.
This is a great question; the short answer is yes, we can use SGLD or SGHMC as the base sampler. Maybe we can first implement the SGLD version of contour and then generalize it in the next version.
That is great, I am working on it now and will try to add a demo. Thanks a lot for the updates. Great catch on the choice of step_size_stoch. Yes, annealing it is more theoretically grounded and makes it converge to a fixed point, especially when we can obtain the exact energy estimator; however, in practical approximation tasks we may set it to a constant to facilitate fine-tuning.
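For reference, a typical stochastic-approximation decay one could use for step_size_stoch looks like the sketch below (my own illustration; the constants a, b and gamma are placeholders, not values from this PR). Any schedule whose sum diverges while the sum of its squares converges gives the fixed-point convergence mentioned above; a constant step size trades that guarantee for easier tuning.

```python
def step_size_stoch_schedule(iteration, a=1.0, b=100.0, gamma=0.75):
    # Robbins-Monro style decay: with 0.5 < gamma <= 1 the step sizes sum to
    # infinity while their squares sum to a finite value.
    return a / (b + iteration) ** gamma
```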
OK, make sure the temperatures used for SGLD and CSGLD are the same.
That's fixed, your turn! As you can see I decreased the number of iterations with little to no effect on the quality of the results.
What do you suggest we do next? Remove the code on the top 5% of partitions?
Yes, and add some explanations. We should also take a step back, now that we have an example, to see if there is something that needs to be changed in the API. After that I'll add a few tests and we should be good to merge!
No problem.
One tricky thing is that the function doesn't allow us to explore much in the tails.
Since CSGLD is so good at exploring tails (which correspond to high energies), we have to use more partitions; as such, maybe the top X% partition is inevitable for efficiency purposes. I will use a cleaner way to code it.
Wouldn't using
Wow, that seems like a perfect function. Let me try!
You can take a look at my implementation of Cyclical SGLD, where I used it to define the log-density.
Gotcha, this is really a wonderful function, I like it. This 25-mode mixture is more challenging and interesting. Let me see if CSGLD can handle it at this point.
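For illustration, here is a minimal sketch of that kind of mixture log-density written with jax.scipy.special.logsumexp (my own example; the 5x5 grid of modes, the spacing, and the scale are made up, and it is only an assumption that this is the style of helper being discussed). Computing the mixture in log-space keeps the tails finite, so particles never see a log-density that underflows to -inf far from the modes.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Hypothetical 25-mode Gaussian mixture on a 5x5 grid (illustration only).
mu = 4.0 * jnp.stack(jnp.meshgrid(jnp.arange(5.0), jnp.arange(5.0)), axis=-1).reshape(-1, 2)
sigma = 0.5


def logdensity_fn(position):
    # log of sum_k (1/25) * N(position; mu_k, sigma^2 I), evaluated in log-space
    # so the tails stay numerically stable.
    sq_dists = jnp.sum((position - mu) ** 2, axis=-1)
    log_components = -0.5 * sq_dists / sigma**2 - jnp.log(2 * jnp.pi * sigma**2)
    return logsumexp(log_components) - jnp.log(mu.shape[0])
```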
Maybe this should be the final version. Updates:
Modified the notebook to take your changes into account, this works perfectly. We just need to add some tests and we'll be good to merge. Unrelated question: could we apply the flattening/resampling idea to algorithms like HMC (full gradient)?
I believe so, but I haven't rigorously tried it yet. It should work.
Actually, this is my first time merging an algorithm into a public repo. Very excited.
One thing to comment on: for contour_sgld.md, I noticed that you included a module to re-start a particle when it goes beyond the domain. If you did that intentionally, that is OK; otherwise, you may check the code here: https://github.com/WayneDW/Jax_implementation_of_CSGLD/blob/main/howto_use_csgld_final_final.py. In that code we don't need to worry about going beyond the domain, because a robust version of the mixture is implemented.
I think this looks pretty good. Some improvements can be made, but I think the underlying algorithm implementation, tests, and example notebooks are top-notch.
Contour SGLD takes inspiration from the Wang-Landau algorithm to learn the density of states of the model at each energy level, and uses this information to flatten the target density to be able to explore it more easily.
As a result, the samples returned by Contour SGLD are not from the target density directly, and we need to resample them using the density of states as importance weights to get samples from the target distribution.
Do we expect to do importance sampling like this more than once? Is this worth factoring into a library function?
Currently just once. Maybe we can think about this in the future. My current implementation is only passable and may not be fast enough; in a future version we should store an explicit importance weight for each simulated sample.
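For concreteness, the resampling step itself can be as small as the sketch below (my own illustration, not code from this PR). How the per-sample log-weights are obtained from the learned energy PDF is left as an assumption here; the sketch only shows the generic importance-resampling mechanics.

```python
import jax
import jax.numpy as jnp


def importance_resample(rng_key, positions, log_weights, num_samples):
    # Normalize the importance weights in log-space and draw indices with replacement.
    probs = jax.nn.softmax(log_weights)
    idx = jax.random.choice(rng_key, positions.shape[0], shape=(num_samples,), p=probs)
    return positions[idx]
```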
Periodic Orbital HMC also uses importance sampling and it would be good to generalise this at some point, but maybe not in this PR.
Thanks for the nice words. To the best of my knowledge, this is the most promising sampler for multi-modal distributions (one of the proudest works of my academic family, which includes COPSS Award winners Wing-Hung Wong and Jun S. Liu). It has been proven successful on MNIST-level datasets. Although MNIST looks like a baby CV dataset, none of the existing samplers can achieve fluctuation losses like ours.
I didn't see the change in the example, I will correct this before merging (it's admittedly much cleaner this way!). I will take @zaxtax's suggestions into account, maybe tweak a few things here and there in the API, and we should be able to merge in the coming week.
Fantastic!
Yes! I don't want my remarks to be a blocker for this fantastic contribution!
Add Gaussian mixture example for the Contour SGLD sampler.
This is good to merge. Thank you so much for your implementation and being patient through the whole process. Time to celebrate 🥳
```python
energy_pdf_update = -energy_pdf.copy()
energy_pdf_update = energy_pdf_update.at[idx].set(energy_pdf_update[idx] + 1)
energy_pdf = jax.tree_util.tree_map(
    lambda e: e + step_size_stoch * energy_pdf[idx] * energy_pdf_update,
    energy_pdf,
)
```
Since energy_pdf_update is not used after this, we can probably make this more efficient.
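For instance, something along these lines should be equivalent (a sketch, assuming energy_pdf is a single jnp array so the tree_map is unnecessary):

```python
import jax.numpy as jnp


def update_energy_pdf(energy_pdf, idx, step_size_stoch):
    # Same update as above, pdf <- pdf + scale * (one_hot(idx) - pdf), without
    # materialising the intermediate `energy_pdf_update` array.
    scale = step_size_stoch * energy_pdf[idx]
    return ((1.0 - scale) * energy_pdf).at[idx].add(scale)
```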
blackjax/kernels.py (outdated):

```python
    energy_gap: float = 100,
) -> MCMCSamplingAlgorithm:

    step = cls.kernel(num_partitions, energy_gap)
```
Of course!
Thank you! I'll convert the example and add a test ASAP.
Just curious: what happens if T=1?
T=1 recovers the exact posterior?
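For context, under the usual convention where the Langevin noise is scaled by the temperature (my own note, assuming the standard tempered-Langevin setup rather than quoting this PR's code), the stationary distribution is the tempered posterior:

$$\mathrm{d}x_t = \nabla \log \pi(x_t)\,\mathrm{d}t + \sqrt{2T}\,\mathrm{d}W_t \quad\Longrightarrow\quad x_\infty \sim \pi(x)^{1/T},$$

so $T = 1$ leaves the target posterior unchanged, $T < 1$ gives a sharpened ("cold") posterior, and $T > 1$ flattens it.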
BTW, for some minor things:
I will present more interesting results in the future, especially a demonstration on the MNIST dataset. That is the real one.
Thank you for opening a PR!
A few important guidelines and requirements before we can merge your PR:
- your branch is rebased on the latest main commit;
- pre-commit is installed and configured on your machine, and you ran it before opening the PR.

Consider opening a Draft PR if your work is still in progress but you would like some feedback from other contributors.