Add Contour SGMCMC sampler. #396

Merged: 4 commits merged into blackjax-devs:main on Jan 5, 2023
Conversation

@WayneDW (Contributor) commented on Oct 31, 2022

Thank you for opening a PR!

A few important guidelines and requirements before we can merge your PR:

  • If I add a new sampler, there is an issue discussing it already;
  • We should be able to understand what the PR does from its title only;
  • There is a high-level description of the changes;
  • There are links to all the relevant issues, discussions and PRs;
  • The branch is rebased on the latest main commit;
  • Commit messages follow these guidelines;
  • The code respects the current naming conventions;
  • Docstrings follow the numpy style guide
  • pre-commit is installed and configured on your machine, and you ran it before opening the PR;
  • There are tests covering the changes;
  • The doc is up-to-date;
  • If I add a new sampler, I added/updated related examples;

Consider opening a Draft PR if your work is still in progress but you would like some feedback from other contributors.

@rlouf changed the title from "Onging project: add Contour sampler to achieve free mode exploration on MNIST dataset." to "Add Contour SGMCMC sampler to achieve free mode exploration on MNIST dataset." on Oct 31, 2022
@rlouf added the sgmcmc (Stochastic Gradient MCMC samplers), sampler (Issue related to samplers), and enhancement (New feature or request) labels on Oct 31, 2022
@rlouf changed the title from "Add Contour SGMCMC sampler to achieve free mode exploration on MNIST dataset." to "Add Contour SGMCMC sampler." on Oct 31, 2022

@rlouf (Member) commented on Nov 10, 2022

This looks great, the code is very readable! A few things are missing, but what's there is a very good start:

  1. You may want to install and run pre-commit in the repo. At the very least, black should reformat your files, and flake8 detect the unused argument;
  2. We want to export this at the higher level blackjax.csgld, to do so you will need to add an implementation in kernels.py;
  3. We need a test for this, along the lines of the ones that are already there for SGMCMC algorithms;
  4. Finally an example would be really nice, a comparison with the other SGMCMC algorithms even better :)

@WayneDW (Contributor, Author) commented on Nov 10, 2022

> This looks great, the code is very readable! A few things are missing, but what's there is a very good start:
>
> 1. You may want to install and run pre-commit in the repo. At the very least, black should reformat your files, and flake8 detect the unused argument;
> 2. We want to export this at the higher level blackjax.csgld, to do so you will need to add an implementation in kernels.py;
> 3. We need a test for this, along the lines of the ones that are already there for SGMCMC algorithms;
> 4. Finally an example would be really nice, a comparison with the other SGMCMC algorithms even better :)

No problem. After I run pre-commit in the repo, I can import this sampler like the one below?

from blackjax.sgmcmc.gradients import grad_estimator

@rlouf (Member) commented on Nov 11, 2022

> No problem. After I run pre-commit in the repo, I can import this sampler like the one below?

pre-commit is just a tool that checks that the style of your code follows Blackjax's conventions.

Do you mean being able to import it as?

from blackjax import csgld

@WayneDW (Contributor, Author) commented on Nov 11, 2022

> No problem. After I run pre-commit in the repo, I can import this sampler like the one below?
>
> pre-commit is just a tool that checks that the style of your code follows Blackjax's conventions.
>
> Do you mean being able to import it as?
>
> from blackjax import csgld

Yes, in this way, I may be able to mimic your MNIST/SGHMC example using csgld. If there are other ways to test this code, please also feel free to suggest them.

@rlouf (Member) commented on Nov 11, 2022

You will need to add CSGLD to this file and then import this in __init__.py.

If this doesn't make sense I can do it whenever I have time.
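For readers following along, a minimal sketch of what that export could look like (assuming the sampler is implemented as a csgld object in kernels.py; this is illustrative, not the exact blackjax code):

# in blackjax/__init__.py (illustrative sketch only)
from .kernels import csgld  # assumes csgld is defined in blackjax/kernels.py

__all__ = [
    # ... existing exports ...
    "csgld",
]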

@WayneDW (Contributor, Author) commented on Nov 11, 2022

> You will need to add CSGLD to this file and then import this in __init__.py.
>
> If this doesn't make sense I can do it whenever I have time.

Let me try. Have a great weekend.

@WayneDW (Contributor, Author) commented on Nov 14, 2022

> No problem. After I run pre-commit in the repo, I can import this sampler like the one below?
>
> pre-commit is just a tool that checks that the style of your code follows Blackjax's conventions.
>
> Do you mean being able to import it as?
>
> from blackjax import csgld

pre-commit run --all-files works but make test fails:

JAX_PLATFORM_NAME=cpu pytest -n 4 --cov=blackjax --cov-report term --cov-report html:coverage tests
ERROR: usage: pytest [options] [file_or_dir] [file_or_dir] [...]
pytest: error: unrecognized arguments: -n --cov=blackjax --cov-report term --cov-report html:coverage tests

Have you run into similar issues before?
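(For context: the -n option is provided by the pytest-xdist plugin and --cov by pytest-cov, so an "unrecognized arguments" error like this usually just means those test dependencies, e.g. pip install pytest-xdist pytest-cov, are missing from the local environment.)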

@rlouf (Member) commented on Nov 24, 2022

Hey, I'm going to fix the merge conflicts (due to the changes in #299) in your branch and fix any other errors; you will have to make sure to pull it locally before making any other modification.

@WayneDW (Contributor, Author) commented on Nov 24, 2022

> Hey, I'm going to fix the merge conflicts (due to the changes in #299) in your branch and fix any other errors; you will have to make sure to pull it locally before making any other modification.

Got it, thanks a lot. Happy Thanksgiving!

@rlouf (Member) commented on Nov 24, 2022

I rebased your branch on the current main branch, and made your code match the new API for stochastic gradient MCMC algorithms. I have a few comments on the next steps and a question; please excuse my current ignorance regarding your paper (I will read it before my next review!):

  1. Are any of zeta, energy_gap, temperature likely to change during inference? Is there a world in which it would make sense? In that case we would need to pass them to one_step directly in diffusions.py.
  2. energy_estimator_fn requires a small refactor of the library (which I will do); I suggest that we have a blackjax.value_and_grad_estimator function to mimic jax.value_and_grad that would build the energy estimator and the estimator of its gradient.
  3. We'll need to add a test.
  4. And finally add an example.

I'll work on (2) and (3) tomorrow, happy Thanksgiving!

@WayneDW (Contributor, Author) commented on Nov 24, 2022

> 2. energy_estimator_fn requires a small refactor of the library (which I will do); I suggest that we have a blackjax.value_and_grad_estimator function to mimic jax.value_and_grad that would build the energy estimator and the estimator of its gradient.

Reply to 2: Yeah, we probably need an independent function for energy_estimator_fn, which is required by many interesting advanced samplers.

@WayneDW (Contributor, Author) commented on Nov 24, 2022

> 1. Are any of zeta, energy_gap, temperature likely to change during inference? Is there a world in which it would make sense? In that case we would need to pass them to one_step directly in diffusions.py.

These three can just be treated as hyperparameters.

  1. zeta should be fine-tuned;

  2. energy_gap should be specified before training (no fine-tuning needed), e.g. if we use vanilla training and find that the loss ranges from 1e4 down to 0, then we may set energy_gap = 150 (some buffer is helpful) with around 100 partitions (see the sketch after this list);

  3. temperature: due to the cold posterior phenomenon in DNNs, the temperature is often set to 1e-2 or similar values. If we fix the temperature, it is more appropriate for sampling; if we anneal it, it is better for global optimization.
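To make the energy_gap rule of thumb above concrete, here is one way to read it (the 50% buffer factor is an assumption used for illustration, not a value prescribed in this PR):

# Illustrative reading of the energy_gap rule of thumb (not blackjax code):
# losses observed roughly in [0, 1e4], split into ~100 partitions, with a
# ~50% buffer on each partition's width.
observed_energy_range = 1e4
num_partitions = 100
energy_gap = 1.5 * observed_energy_range / num_partitions  # = 150.0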

@rlouf (Member) commented on Nov 25, 2022

Thanks! Just wanted to say while I'm learning about the algorithm: your blog post is very well-written!

One thing I am noticing, both in the code and in the paper, is that Contour SGLD is a "meta-algorithm", in the sense that it uses the Langevin dynamics as a base transition kernel, but this transition kernel is called by another process that changes the base measure of the original density (and carries a latent vector) to make the transition kernel's job easier. I am wondering if we could make this more explicit in the code, which would be very in line with the rest of the Blackjax codebase where we try to implement basic elements and combine them to get more complex algorithms.

Which leads me to a naive question: could this be generalized to be used with other Langevin dynamics?

@codecov bot commented on Nov 25, 2022

Codecov Report

Merging #396 (0871831) into main (25a1e28) will increase coverage by 0.01%.
The diff coverage is 100.00%.

❗ Current head 0871831 differs from pull request most recent head 4057243. Consider uploading reports for the commit 4057243 to get more accurate results

@@            Coverage Diff             @@
##             main     #396      +/-   ##
==========================================
+ Coverage   99.23%   99.25%   +0.01%     
==========================================
  Files          46       47       +1     
  Lines        1834     1872      +38     
==========================================
+ Hits         1820     1858      +38     
  Misses         14       14              
Impacted Files Coverage Δ
blackjax/__init__.py 100.00% <ø> (ø)
blackjax/sgmcmc/sghmc.py 100.00% <ø> (ø)
blackjax/kernels.py 99.57% <100.00%> (+0.01%) ⬆️
blackjax/sgmcmc/__init__.py 100.00% <100.00%> (ø)
blackjax/sgmcmc/csgld.py 100.00% <100.00%> (ø)
blackjax/sgmcmc/diffusions.py 100.00% <100.00%> (ø)
blackjax/sgmcmc/gradients.py 100.00% <100.00%> (ø)
blackjax/sgmcmc/sgld.py 100.00% <100.00%> (ø)


@rlouf (Member) commented on Nov 25, 2022

> 3. temperature: due to the cold posterior phenomenon in DNNs, the temperature is often set to 1e-2 or similar values. If we fix the temperature, it is more appropriate for sampling; if we anneal it, it is better for global optimization.

Is that really necessary in practice? I've recently read this paper that claimed the opposite: https://arxiv.org/abs/2104.14421

@rlouf (Member) commented on Nov 25, 2022

> One thing I am noticing, both in the code and in the paper, is that Contour SGLD is a "meta-algorithm", in the sense that it uses the Langevin dynamics as a base transition kernel, but this transition kernel is called by another process that changes the base measure of the original density (and carries a latent vector) to make the transition kernel's job easier. I am wondering if we could make this more explicit in the code, which would be very in line with the rest of the Blackjax codebase where we try to implement basic elements and combine them to get more complex algorithms.

Yes, I am going to move the stochastic update directly into csgld.py since it is not an integrator for a diffusion equation per se. We could always abstract this logic out into another file, but I think we should wait until the need arises (another algorithm that exploits the same idea in a different way).

> Which leads me to a naive question: could this be generalized to be used with other Langevin dynamics?

I am currently re-using the Langevin dynamics integration in diffusions.py by re-scaling the gradients before calling one_step. We can probably do this with other diffusions.
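To make that concrete, here is a rough sketch of the rescaling, following the multiplicative factor from the CSGLD paper (variable names are illustrative, the first partition's edge case is glossed over, and this is not the exact code in diffusions.py):

import jax
import jax.numpy as jnp

def rescale_gradient(grad_estimate, energy_pdf, idx, zeta, temperature, energy_gap):
    # Multiplicative factor built from the learned density of states at the
    # current energy partition; it can become negative, which produces the
    # "bouncy" moves that help the sampler escape local energy wells.
    factor = 1.0 + zeta * temperature * (
        jnp.log(energy_pdf[idx]) - jnp.log(energy_pdf[idx - 1])
    ) / energy_gap
    return jax.tree_util.tree_map(lambda g: factor * g, grad_estimate)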

@rlouf (Member) commented on Nov 25, 2022

With the current code on the branch (don't forget to pull --rebase locally before working on this code again!), the following code outputs a result:

import jax
import jax.numpy as jnp

import blackjax
import blackjax.sgmcmc.gradients as gradients

rng_key = jax.random.PRNGKey(0)
rng_key, data_key = jax.random.split(rng_key, 2)

data_size = 1000
X_data = jax.random.normal(data_key, shape=(data_size, 5))
data_batch = X_data[:100, :]

init_position = 1.0

def logprior_fn(position):
    return -0.5 * jnp.dot(position, position) * 0.01

def loglikelihood_fn(position, x):
    w = x - position
    return -0.5 * jnp.dot(w, w)

# Build the CSGLD sampling step and init functions
logprob_fn, grad_fn = gradients.logprob_and_grad_estimator(logprior_fn, loglikelihood_fn, data_size)
csgld = blackjax.csgld(
    logprob_fn,
    grad_fn,
    zeta=1.,  # can be specified at each step in lower-level interface
    temperature=1e-2,  # can be specified at each step
    num_partitions=12, # cannot be specified at each step
    energy_gap=1000, # cannot be specified at each step
)

# Initialize and take one step using the CSGLD algorithm
init_state = csgld.init(init_position)
new_state = csgld.step(rng_key, init_state, data_batch, 1e-3, 1e-3)
print(new_state)
# CSGLDState(position=DeviceArray(-4.366815, dtype=float32), energy_pdf=DeviceArray([0.15382446, 0.14114678, 0.12818706, 0.11536834, 0.10254964,
#              0.08973093, 0.07691223, 0.06409353, 0.05127482, 0.03845612,
#              0.02563741, 0.01281871], dtype=float32), energy_idx=1)

... but it doesn't mean it yields correct samples! This is where I'll hand the PR over to you again. We need the following:

  1. A test. If there's a target energy function for which you have theoretical convergence results then we should turn this result into a test for csgld. Otherwise we'll figure something out.
  2. A working example.

I also have a question regarding step_size_stoch: is there a situation in which we'd want this step size to vary over time or does it always remain fixed?

@WayneDW (Contributor, Author) commented on Nov 25, 2022

> 3. temperature: due to the cold posterior phenomenon in DNNs, the temperature is often set to 1e-2 or similar values. If we fix the temperature, it is more appropriate for sampling; if we anneal it, it is better for global optimization.

> Is that really necessary in practice? I've recently read this paper that claimed the opposite: https://arxiv.org/abs/2104.14421

You can check this paper: https://arxiv.org/pdf/2002.02405.pdf. Data augmentation is a key factor that leads to the cold posterior effect: e.g. CIFAR has 50K images, but after data augmentation we effectively have many more (>>50K). When summing the log-likelihood, it might be inappropriate to sum over only 1 to 50K terms; more importantly, we don't know what the exact appropriate number would be.

@WayneDW (Contributor, Author) commented on Nov 26, 2022

> Thanks! Just wanted to say while I'm learning about the algorithm: your blog post is very well-written!
>
> One thing I am noticing, both in the code and in the paper, is that Contour SGLD is a "meta-algorithm", in the sense that it uses the Langevin dynamics as a base transition kernel, but this transition kernel is called by another process that changes the base measure of the original density (and carries a latent vector) to make the transition kernel's job easier. I am wondering if we could make this more explicit in the code, which would be very in line with the rest of the Blackjax codebase where we try to implement basic elements and combine them to get more complex algorithms.
>
> Which leads me to a naive question: could this be generalized to be used with other Langevin dynamics?

This is a great question; the short answer is yes, we can use SGLD or SGHMC as the base sampler. Maybe we can first implement the SGLD version of the contour algorithm and then generalize it in the next version.

@WayneDW (Contributor, Author) commented on Nov 26, 2022

> new_state = csgld.step(rng_key, init_state, data_batch, 1e-3, 1e-3)

That is great, I am working on it now and will try to add a demo. Thanks a lot for the updates.

Great catch on the choice of step_size_stoch. Yes, annealing it is more theoretically grounded to make it converge to a fixed point, especially when we can obtain the exact energy estimator; however, in practical approximation tasks, we may set it to a constant to facilitate fine-tuning.
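For reference, a decaying schedule of the kind discussed here could look like the following (the functional form and the constants are placeholders, not values recommended in this PR):

def step_size_stoch_schedule(step, a=10.0, b=1000.0, gamma=0.8):
    # Robbins-Monro style decay: larger steps early on to learn the density
    # of states quickly, smaller steps later so the estimate can settle.
    return a / (b + step**gamma)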

@WayneDW (Contributor, Author) commented on Dec 19, 2022

> Ok, just let me push the fix for my stupid mistake first!

OK, make sure the temperatures used for SGLD and CSGLD are the same.

@rlouf (Member) commented on Dec 19, 2022

That's fixed, your turn! As you can see I decreased the number of iterations with little to no effect on the quality of the results.

@WayneDW (Contributor, Author) commented on Dec 19, 2022

> your turn

What do you suggest we do next?

Remove the code on the top 5% of partitions?

@rlouf (Member) commented on Dec 19, 2022

Yes, and add some explanations. We should also take a step back now that we have an example to see if there is something that needs to be changed in the API. After that, I'll add a few tests and we should be good to merge!

@WayneDW (Contributor, Author) commented on Dec 19, 2022

> Yes, and add some explanations. We should also take a step back now that we have an example to see if there is something that needs to be changed in the API. After that, I'll add a few tests and we should be good to merge!

No problem.

@WayneDW (Contributor, Author) commented on Dec 20, 2022

One tricky thing is that the likelihood function doesn't allow us to explore much in the tails:

def loglikelihood_fn(position, x):
    mixture_1 = jax.scipy.stats.norm.pdf(x, loc=position, scale=sigma)
    mixture_2 = jax.scipy.stats.norm.pdf(x, loc=-position + gamma, scale=sigma)
    return jnp.log(0.5 * mixture_1 + 0.5 * mixture_2).sum()

loglikelihood_fn(80, 5)
# DeviceArray(-inf, dtype=float32)

Since CSGLD is so good at exploring the tails (which correspond to high energy), we have to use more partitions; as such, keeping only the top X% of partitions is probably inevitable for efficiency purposes. I will find a cleaner way to code it.

@rlouf (Member) commented on Dec 20, 2022

> Since CSGLD is so good at exploring the tails (which correspond to high energy), we have to use more partitions; as such, keeping only the top X% of partitions is probably inevitable for efficiency purposes. I will find a cleaner way to code it.

Wouldn't using jax.scipy.special.logsumexp in the logpdf be more numerically stable?
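For illustration, a numerically stable version of the mixture above could look like this (sigma and gamma are placeholder values, not taken from this thread):

import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

sigma, gamma = 5.0, 20.0  # placeholders for illustration only

def loglikelihood_fn(position, x):
    # compute log(0.5 * p1 + 0.5 * p2) entirely in log-space
    log_mix_1 = norm.logpdf(x, loc=position, scale=sigma)
    log_mix_2 = norm.logpdf(x, loc=-position + gamma, scale=sigma)
    return logsumexp(jnp.stack([log_mix_1, log_mix_2]), axis=0, b=0.5).sum()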

@WayneDW (Contributor, Author) commented on Dec 20, 2022

> Since CSGLD is so good at exploring the tails (which correspond to high energy), we have to use more partitions; as such, keeping only the top X% of partitions is probably inevitable for efficiency purposes. I will find a cleaner way to code it.

> Wouldn't using jax.scipy.special.logsumexp in the logpdf be more numerically stable?

Wow, that seems like a perfect function. Let me try!

@rlouf (Member) commented on Dec 20, 2022

You can take a look at my implementation of Cyclical SGLD, where I used it to define the log-density.

@WayneDW (Contributor, Author) commented on Dec 20, 2022

Gotcha, this is really a wonderful function; I like it. This 25-mode mixture is more challenging and interesting. Let me see whether the current CSGLD can handle it.

@rlouf mentioned this pull request on Dec 20, 2022
@WayneDW (Contributor, Author) commented on Dec 21, 2022

> You can take a look at my implementation of Cyclical SGLD, where I used it to define the log-density.

Maybe this should be the final version:

https://github.com/WayneDW/Jax_implementation_of_CSGLD/blob/main/howto_use_csgld_final_final.py

Updates:

  1. adopted a stable version of the mixture using logsumexp;
  2. greatly simplified the re-sampling step (still using the 5% selection);
  3. added some results to explain why the algorithm works;
  4. included more comments on the pros and cons of the algorithm.

@rlouf (Member) commented on Dec 22, 2022

Modified the notebook to take your changes into account, this works perfectly. We just need to add some tests and we'll be good to merge.

Unrelated question, could we apply the flattening/resampling idea to algorithms like HMC (full gradient)?

@WayneDW (Contributor, Author) commented on Dec 22, 2022

> Modified the notebook to take your changes into account, this works perfectly. We just need to add some tests and we'll be good to merge.
>
> Unrelated question, could we apply the flattening/resampling idea to algorithms like HMC (full gradient)?

I believe so, but I haven't rigorously tried yet. It should work.

@WayneDW (Contributor, Author) commented on Dec 22, 2022

> Modified the notebook to take your changes into account, this works perfectly. We just need to add some tests and we'll be good to merge.
>
> Unrelated question, could we apply the flattening/resampling idea to algorithms like HMC (full gradient)?

Actually, this is my first time merging an algorithm into a public repo. Very excited.

@WayneDW (Contributor, Author) commented on Dec 30, 2022

One thing to note: in contour_sgld.md, I noticed that you included a module to restart a particle when it goes beyond the domain. If you did that intentionally, that is OK; otherwise, you may check the code below:

https://github.com/WayneDW/Jax_implementation_of_CSGLD/blob/main/howto_use_csgld_final_final.py. In this code, we don't need to be concerned about going beyond the domain because a robust version of the mixture is implemented.


@zaxtax (Contributor) left a comment

I think this looks pretty good. Some improvements can be made, but I think the underlying algorithm implementation, tests, and example notebooks are top-notch


> Contour SGLD takes inspiration from the Wang-Landau algorithm to learn the density of states of the model at each energy level, and uses this information to flatten the target density to be able to explore it more easily.
>
> As a result, the samples returned by contour SGLD are not from the target density directly, and we need to resample them using the density of states as importance weights to get samples from the target distribution.
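A minimal sketch of what that resampling step can look like (illustrative names only, assuming the energy-partition index of each stored sample was recorded alongside its position; this is not the notebook's exact code):

import jax
import jax.numpy as jnp

def resample(rng_key, positions, energy_indices, energy_pdf, num_samples, zeta=1.0):
    # Importance weights proportional to the learned density of states at
    # each sample's energy partition, followed by weighted resampling.
    weights = energy_pdf[energy_indices] ** zeta
    weights = weights / weights.sum()
    resampled = jax.random.choice(
        rng_key, positions.shape[0], shape=(num_samples,), p=weights
    )
    return positions[resampled]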
Contributor:
Do we expect to do importance sampling like this more than once? Is this worth factoring into a library function?

Contributor (Author):
Currently just once. Maybe we can think about this in the future. My current implementation is only acceptable and may not be fast enough; in a future version, we would need to include some additional variables storing the importance weight for each simulated sample.

Member:
Periodic Orbital HMC also uses importance sampling and it would be good to generalise this at some point, but maybe not in this PR.

@WayneDW (Contributor, Author) commented on Dec 31, 2022

> I think this looks pretty good. Some improvements can be made, but I think the underlying algorithm implementation, tests, and example notebooks are top-notch

Thanks for the nice words. To the best of my knowledge, this is the most promising sampler for multi-modal distributions, and one of my academic family's proudest works (a family that includes COPSS Award winners Wing-Hung Wong and Jun S. Liu). It has been proven successful on MNIST-level datasets. Although MNIST looks like a baby CV dataset, none of the existing samplers can achieve the kind of free mode exploration, with losses fluctuating across modes, that ours does.

@rlouf (Member) commented on Dec 31, 2022

I didn’t see the change in the example, I will correct this before merging (it’s admittedly much cleaner this way!). I will take @zaxtax’s suggestions into account, maybe tweak a few things here and there in the API and we should be able to merge in the coming week.

@WayneDW (Contributor, Author) commented on Dec 31, 2022

> I didn’t see the change in the example, I will correct this before merging (it’s admittedly much cleaner this way!). I will take @zaxtax’s suggestions into account, maybe tweak a few things here and there in the API and we should be able to merge in the coming week.

Fantastic!

@zaxtax (Contributor) commented on Dec 31, 2022 via email

@rlouf (Member) left a comment

This is good to merge. Thank you so much for your implementation and being patient through the whole process. Time to celebrate 🥳

Comment on lines +171 to +172
energy_pdf_update = -energy_pdf.copy()
energy_pdf_update = energy_pdf_update.at[idx].set(energy_pdf_update[idx] + 1)
energy_pdf = jax.tree_util.tree_map(
    lambda e: e + step_size_stoch * energy_pdf[idx] * energy_pdf_update,
    energy_pdf,
)
Member:
Since energy_pdf_update is not used after that we can probably make this more efficient.
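One possible simplification (an untested sketch, not the committed code, reusing the same variables as the snippet above): the two-step construction is equivalent to nudging the vector toward the one-hot indicator of the visited partition.

import jax.numpy as jnp

indicator = jnp.zeros_like(energy_pdf).at[idx].set(1.0)
energy_pdf = energy_pdf + step_size_stoch * energy_pdf[idx] * (indicator - energy_pdf)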

Review thread on:

energy_gap: float = 100,
) -> MCMCSamplingAlgorithm:

step = cls.kernel(num_partitions, energy_gap)

Member: Of course!

Member: Thank you! I'll convert the example and add a test ASAP.

Member: Just curious: what happens if T=1?


@rlouf rlouf merged commit 2ddf548 into blackjax-devs:main Jan 5, 2023
@WayneDW (Contributor, Author) commented on Jan 6, 2023

T=1 recovers the exact posterior?

@WayneDW (Contributor, Author) commented on Jan 6, 2023

BTW, a few minor things:

  1. the line "domain_radius = 50" can be safely deleted;

  2. "# Build the SGDL sampler" should read "# Build the SGLD sampler";

  3. 100K iterations would probably make the result more appealing, but it is fine for now.

I will present more interesting results in the future, especially demonstrating it on the MNIST dataset; that is the real test.

@albcab mentioned this pull request on Jun 2, 2023
Labels: enhancement (New feature or request), sampler (Issue related to samplers), sgmcmc (Stochastic Gradient MCMC samplers)

Successfully merging this pull request may close these issues: Implement Contour SGLD

3 participants