From 7c7f6e600640e0ebd6d3b9cb9fb6cbeebfb54c79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Mon, 12 Dec 2022 09:49:23 +0100 Subject: [PATCH 1/4] Allow temperatures != 1 in SGLD --- blackjax/kernels.py | 12 ++++++++++-- blackjax/sgmcmc/diffusions.py | 5 ++++- blackjax/sgmcmc/sgld.py | 5 ++++- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/blackjax/kernels.py b/blackjax/kernels.py index 323f31c84..b0cf1459f 100644 --- a/blackjax/kernels.py +++ b/blackjax/kernels.py @@ -538,8 +538,16 @@ def __new__( # type: ignore[misc] step = cls.kernel() - def step_fn(rng_key: PRNGKey, state, minibatch: PyTree, step_size: float): - return step(rng_key, state, grad_estimator, minibatch, step_size) + def step_fn( + rng_key: PRNGKey, + state, + minibatch: PyTree, + step_size: float, + temperature: float = 1, + ): + return step( + rng_key, state, grad_estimator, minibatch, step_size, temperature + ) return step_fn diff --git a/blackjax/sgmcmc/diffusions.py b/blackjax/sgmcmc/diffusions.py index 44dbc07d4..1de3b4985 100644 --- a/blackjax/sgmcmc/diffusions.py +++ b/blackjax/sgmcmc/diffusions.py @@ -39,11 +39,14 @@ def one_step( position: PyTree, logdensity_grad: PyTree, step_size: float, + temperature: float = 1.0, ) -> PyTree: noise = generate_gaussian_noise(rng_key, position) position = jax.tree_util.tree_map( - lambda p, g, n: p + step_size * g + jnp.sqrt(2 * step_size) * n, + lambda p, g, n: p + + step_size * g + + jnp.sqrt(2 * temperature * step_size) * n, position, logdensity_grad, noise, diff --git a/blackjax/sgmcmc/sgld.py b/blackjax/sgmcmc/sgld.py index 5a21212cb..c41a39ff1 100644 --- a/blackjax/sgmcmc/sgld.py +++ b/blackjax/sgmcmc/sgld.py @@ -30,10 +30,13 @@ def one_step( grad_estimator: Callable, minibatch: PyTree, step_size: float, + temperature: float = 1.0, ): logdensity_grad = grad_estimator(position, minibatch) - new_position = integrator(rng_key, position, logdensity_grad, step_size) + new_position = integrator( + rng_key, position, logdensity_grad, step_size, temperature + ) return new_position From 856098152a27e9d0d70b387f19af82e92dca62f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Sun, 18 Dec 2022 14:32:02 +0100 Subject: [PATCH 2/4] Allow temperatures != 1 in SgHMC --- blackjax/sgmcmc/diffusions.py | 3 ++- blackjax/sgmcmc/sghmc.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/blackjax/sgmcmc/diffusions.py b/blackjax/sgmcmc/diffusions.py index 1de3b4985..7a75760f2 100644 --- a/blackjax/sgmcmc/diffusions.py +++ b/blackjax/sgmcmc/diffusions.py @@ -79,13 +79,14 @@ def one_step( momentum: PyTree, logdensity_grad: PyTree, step_size: float, + temperature: float = 1.0, ): noise = generate_gaussian_noise(rng_key, position) position = jax.tree_util.tree_map(lambda x, p: x + p, position, momentum) momentum = jax.tree_util.tree_map( lambda p, g, n: (1.0 - alpha) * p + step_size * g - + jnp.sqrt(2 * step_size * (alpha - beta)) * n, + + jnp.sqrt(2 * step_size * (alpha - beta) * temperature) * n, momentum, logdensity_grad, noise, diff --git a/blackjax/sgmcmc/sghmc.py b/blackjax/sgmcmc/sghmc.py index 177b958b1..141f8504e 100644 --- a/blackjax/sgmcmc/sghmc.py +++ b/blackjax/sgmcmc/sghmc.py @@ -35,12 +35,13 @@ def one_step( minibatch: PyTree, step_size: float, num_integration_steps: int, + temperature: float = 1.0, ) -> PyTree: def body_fn(state, rng_key): position, momentum = state logdensity_grad = grad_estimator(position, minibatch) position, momentum = integrator( - rng_key, position, momentum, logdensity_grad, step_size + 
rng_key, position, momentum, logdensity_grad, step_size, temperature ) return ((position, momentum), position) From 974b173bcf33540d1c96114311edd044a5d02d9a Mon Sep 17 00:00:00 2001 From: Wei Deng Date: Sun, 30 Oct 2022 20:30:58 -0400 Subject: [PATCH 3/4] Add the Contour SgLD sampler --- blackjax/__init__.py | 2 + blackjax/kernels.py | 44 ++++++++- blackjax/sgmcmc/__init__.py | 5 +- blackjax/sgmcmc/csgld.py | 176 ++++++++++++++++++++++++++++++++++ blackjax/sgmcmc/diffusions.py | 2 +- blackjax/sgmcmc/gradients.py | 40 ++++---- docs/examples/SGMCMC.md | 6 +- tests/test_sampling.py | 31 ++++-- 8 files changed, 272 insertions(+), 34 deletions(-) create mode 100644 blackjax/sgmcmc/csgld.py diff --git a/blackjax/__init__.py b/blackjax/__init__.py index e2f60851a..74d7bd38e 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -4,6 +4,7 @@ from .diagnostics import potential_scale_reduction as rhat from .kernels import ( adaptive_tempered_smc, + csgld, elliptical_slice, ghmc, hmc, @@ -39,6 +40,7 @@ "meads", "sgld", # stochastic gradient mcmc "sghmc", + "csgld", "window_adaptation", # mcmc adaptation "pathfinder_adaptation", "adaptive_tempered_smc", # smc diff --git a/blackjax/kernels.py b/blackjax/kernels.py index b0cf1459f..9b13ae0d3 100644 --- a/blackjax/kernels.py +++ b/blackjax/kernels.py @@ -35,6 +35,7 @@ "rmh", "sgld", "sghmc", + "csgld", "tempered_smc", "window_adaptation", "irmh", @@ -533,7 +534,7 @@ class sgld: def __new__( # type: ignore[misc] cls, - grad_estimator: sgmcmc.gradients.GradientEstimator, + grad_estimator: Callable, ) -> Callable: step = cls.kernel() @@ -631,6 +632,47 @@ def step_fn(rng_key: PRNGKey, state, minibatch: PyTree, step_size: float): return step_fn +class csgld: + + init = staticmethod(sgmcmc.csgld.init) + kernel = staticmethod(sgmcmc.csgld.kernel) + + def __new__( # type: ignore[misc] + cls, + logdensity_estimator_fn: Callable, + zeta: float = 1, + temperature: float = 0.01, + num_partitions: int = 512, + energy_gap: float = 100, + min_energy: float = 0, + ) -> MCMCSamplingAlgorithm: + + step = cls.kernel(num_partitions, energy_gap, min_energy) + + def init_fn(position: PyTree): + return cls.init(position, num_partitions) + + def step_fn( + rng_key: PRNGKey, + state, + minibatch: PyTree, + step_size_diff: float, + step_size_stoch: float, + ): + return step( + rng_key, + state, + logdensity_estimator_fn, + minibatch, + step_size_diff, + step_size_stoch, + zeta, + temperature, + ) + + return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + + # ----------------------------------------------------------------------------- # ADAPTATION # ----------------------------------------------------------------------------- diff --git a/blackjax/sgmcmc/__init__.py b/blackjax/sgmcmc/__init__.py index 40a748c13..5ff3f170c 100644 --- a/blackjax/sgmcmc/__init__.py +++ b/blackjax/sgmcmc/__init__.py @@ -1,3 +1,4 @@ -from . import gradients, sghmc, sgld +from . import csgld, sghmc, sgld +from .gradients import grad_estimator, logdensity_estimator -__all__ = ["gradients", "sgld", "sghmc"] +__all__ = ["grad_estimator", "logdensity_estimator", "csgld", "sgld", "sghmc"] diff --git a/blackjax/sgmcmc/csgld.py b/blackjax/sgmcmc/csgld.py new file mode 100644 index 000000000..22bf36ae8 --- /dev/null +++ b/blackjax/sgmcmc/csgld.py @@ -0,0 +1,176 @@ +"""Public API for the Contour Stochastic gradient Langevin Dynamics kernel. + +References +---------- +.. [0]: Deng, W., Lin, G., Liang, F. (2020). 
+ A Contour Stochastic Gradient Langevin Dynamics Algorithm + for Simulations of Multi-modal Distributions. + In Neural Information Processing Systems (NeurIPS 2020). + +.. [1]: Deng, W., Liang, S., Hao, B., Lin, G., Liang, F. (2022) + Interacting Contour Stochastic Gradient Langevin Dynamics + In International Conference on Learning Representations (ICLR) +""" +from typing import Callable, NamedTuple + +import jax +import jax.numpy as jnp + +from blackjax.sgmcmc.diffusions import overdamped_langevin +from blackjax.types import Array, PRNGKey, PyTree + +__all__ = ["ContourSGLDState", "init", "kernel"] + + +class ContourSGLDState(NamedTuple): + r"""State of the Contour SgLD algorithm. + + Parameters + ---------- + position + Current position in the sample space. + energy_pdf + Vector with `m` non-negative values that sum to 1. The `i`-th value + of the vector is equal to :math:`\int_{S_1} \pi(\mathrm{d}x)` where + :math:`S_i` is the `i`-th energy partition. + energy_idx + Index `i` such that the current position belongs to :math:`S_i`. + + """ + position: PyTree + energy_pdf: Array + energy_idx: int + + +def init(position: PyTree, num_partitions=512): + energy_pdf = ( + jnp.arange(num_partitions, 0, -1) / jnp.arange(num_partitions, 0, -1).sum() + ) + return ContourSGLDState(position, energy_pdf, num_partitions - 1) + + +def kernel(num_partitions=512, energy_gap=10, min_energy=0) -> Callable: + r""" + + Parameters + ---------- + num_partitions + The number of partitions we divide the energy landscape into. + energy_gap + The difference in energy :math:`\Delta u` between the successive + partitions. Can be determined by running e.g. an optimizer to determine + the range of energies. `num_partition` * `energy_gap` should match this + range. + min_energy + A rough estimate of the minimum energy in a dataset, which should be + strictly smaller than the exact minimum energy! e.g. if the minimum + energy of a dataset is 3456, we can set min_energy to be any value + smaller than 3456. Set it to 0 is acceptable, but not efficient enough. + the closer the gap between min_energy and 3456 is, the better. + """ + + integrator = overdamped_langevin() + + def one_step( + rng_key: PRNGKey, + state: ContourSGLDState, + logdensity_estimator_fn: Callable, + minibatch: PyTree, + step_size_diff: float, # step size for Langevin diffusion + step_size_stoch: float = 1e-3, # step size for stochastic approximation + zeta: float = 1, + temperature: float = 1.0, + ) -> ContourSGLDState: + r"""Multil-modal sampling via Contour SGLD. + + We are interested in the simulations of :math:`\exp(-U(x) / T)`, + where :math:`U` is an energy function and :math:`T` is the temperature. + + To do so we partition the energy space into :math:`m`: + + .. math:: + S_0 = {x: U(x) <= u_1} + S_1 = {x: u_1 < U(x) <= u_2} + S_2 = {x: u_2 < U(x) <= u_3} + ... + S_{m-2} = {x: u_{m-2} < U(x) <= u_{m-1}} + S_{m-1} = {x: U(x) > u_{m-1}} + + where :math:`-\inf < u_1 < u_2 < · · · < u_{m−1} < \inf`. We assume + :math:`u_{i+1} − u_i = \Delta u` for :math:`i = 1, \dots , m−2`. + + Parameters + ---------- + rng_key + State of the pseudo-random number generator. + state + Current state of the CSGLD sampler + logdensity_estimator_fn + Function that returns an estimation of the value of the density + function at the current position. + minibatch + Minibatch of data. + step_size_diff + Step size for the dynamics integration. Also called learning rate. + step_size_stoch + Step size for the update of the energy estimation. 
+ zeta + Hyperparameter that controls the geometric property of the flattened + density. If `zeta=0` the function reduces to the SGLD step function. + temperature + Temperature parameter :math:`T`. + + References + ---------- + .. [0]: Deng, W., Lin, G., Liang, F. (2020). + A Contour Stochastic Gradient Langevin Dynamics Algorithm + for Simulations of Multi-modal Distributions. + In Neural Information Processing Systems (NeurIPS 2020). + + [1]: Deng, W., Liang, S., Hao, B., Lin, G., Liang, F. (2022) + Interacting Contour Stochastic Gradient Langevin Dynamics + In International Conference on Learning Representations (ICLR) + """ + + position, energy_pdf, idx = state + + # Update the position using the overdamped Langevin diffusion + gradient_multiplier = ( + 1.0 + + zeta + * temperature + * (jnp.log(energy_pdf[idx]) - jnp.log(energy_pdf[idx - 1])) + / energy_gap + ) + + logprob_grad = jax.grad(logdensity_estimator_fn)(position, minibatch) + position = integrator( + rng_key, + position, + jax.tree_util.tree_map(lambda g: gradient_multiplier * g, logprob_grad), + step_size_diff, + temperature, + ) + + # Update the stochastic approximation to the energy histogram + neg_logprob = -logdensity_estimator_fn(position, minibatch) + idx = jax.lax.min( + jax.lax.max( + jax.lax.floor((neg_logprob - min_energy) / energy_gap + 1).astype( + "int32" + ), + 1, + ), + num_partitions - 1, + ) + + energy_pdf_update = -energy_pdf.copy() + energy_pdf_update = energy_pdf_update.at[idx].set(energy_pdf_update[idx] + 1) + energy_pdf = jax.tree_util.tree_map( + lambda e: e + step_size_stoch * energy_pdf[idx] * energy_pdf_update, + energy_pdf, + ) + + return ContourSGLDState(position, energy_pdf, idx) + + return one_step diff --git a/blackjax/sgmcmc/diffusions.py b/blackjax/sgmcmc/diffusions.py index 7a75760f2..75c326fea 100644 --- a/blackjax/sgmcmc/diffusions.py +++ b/blackjax/sgmcmc/diffusions.py @@ -18,7 +18,7 @@ from blackjax.types import PRNGKey, PyTree from blackjax.util import generate_gaussian_noise -__all__ = ["overdamped_langevin"] +__all__ = ["overdamped_langevin", "sghmc"] def overdamped_langevin(): diff --git a/blackjax/sgmcmc/gradients.py b/blackjax/sgmcmc/gradients.py index e29ab096d..a4b64ef4e 100644 --- a/blackjax/sgmcmc/gradients.py +++ b/blackjax/sgmcmc/gradients.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, NamedTuple +from typing import Callable import jax import jax.numpy as jnp @@ -19,15 +19,10 @@ from blackjax.types import PyTree -class GradientEstimator(NamedTuple): - init: Callable - estimate: Callable - - -def estimator( +def logdensity_estimator( logprior_fn: Callable, loglikelihood_fn: Callable, data_size: int -) -> GradientEstimator: - """Builds a simple gradient estimator. +) -> Callable: + """Builds a simple estimator for the log-density. This estimator first appeared in [1]_. The `logprior_fn` function has a single argument: the current position (value of parameters). The @@ -57,8 +52,8 @@ def estimator( """ - def logposterior_estimator_fn(position: PyTree, minibatch: PyTree) -> PyTree: - """Returns an approximation of the log-posterior density. + def logdensity_estimator_fn(position: PyTree, minibatch: PyTree) -> PyTree: + """Return an approximation of the log-posterior density. 
Parameters ---------- @@ -79,11 +74,22 @@ def logposterior_estimator_fn(position: PyTree, minibatch: PyTree) -> PyTree: batch_loglikelihood(position, minibatch), axis=0 ) - return jax.grad(logposterior_estimator_fn) + return logdensity_estimator_fn + + +def grad_estimator( + logprior_fn: Callable, loglikelihood_fn: Callable, data_size: int +) -> Callable: + """Build a simple estimator for the gradient of the log-density.""" + + logdensity_estimator_fn = logdensity_estimator( + logprior_fn, loglikelihood_fn, data_size + ) + return jax.grad(logdensity_estimator_fn) def control_variates( - grad_estimator: Callable, + logdensity_grad_estimator: Callable, centering_position: PyTree, data: PyTree, ) -> Callable: @@ -93,7 +99,7 @@ def control_variates( Parameters ---------- - grad_estimator + logdensity_grad_estimator A function that approximates the target's gradient function. data The full dataset. @@ -110,7 +116,7 @@ def control_variates( Journal of Open Source Software, 7(72), 4113. """ - cv_grad_value = grad_estimator(centering_position, data) + cv_grad_value = logdensity_grad_estimator(centering_position, data) def cv_grad_estimator_fn(position: PyTree, minibatch: PyTree) -> PyTree: """Return an approximation of the log-posterior density. @@ -129,8 +135,8 @@ def cv_grad_estimator_fn(position: PyTree, minibatch: PyTree) -> PyTree: the current value of the random variables. """ - grad_estimate = grad_estimator(position, minibatch) - center_grad_estimate = grad_estimator(centering_position, minibatch) + grad_estimate = logdensity_grad_estimator(position, minibatch) + center_grad_estimate = logdensity_grad_estimator(centering_position, minibatch) return jax.tree_map( lambda grad_est, cv_grad_est, cv_grad: cv_grad + grad_est - cv_grad_est, diff --git a/docs/examples/SGMCMC.md b/docs/examples/SGMCMC.md index 996fbc266..f9cbc51bc 100644 --- a/docs/examples/SGMCMC.md +++ b/docs/examples/SGMCMC.md @@ -144,7 +144,7 @@ We now sample from the model's posteriors using SGLD. 
We discard the first 1000 from fastprogress.fastprogress import progress_bar import blackjax -import blackjax.sgmcmc.gradients as gradients +from blackjax.sgmcmc import grad_estimator data_size = len(y_train) @@ -162,7 +162,7 @@ batches = batch_data(rng_key, (X_train, y_train), batch_size, data_size) state = jax.jit(model.init)(rng_key, jnp.ones(X_train.shape[-1])) # Build the SGLD kernel with a constant learning rate -grad_fn = gradients.estimator(logprior_fn, loglikelihood_fn, data_size) +grad_fn = grad_estimator(logprior_fn, loglikelihood_fn, data_size) sgld = blackjax.sgld(grad_fn) # Sample from the posterior @@ -207,7 +207,7 @@ We can also use SGHMC to samples from this model ```{code-cell} python # Build the SGHMC kernel with a constant learning rate step_size = 9e-6 -grad_fn = gradients.estimator(logprior_fn, loglikelihood_fn, data_size) +grad_fn = grad_estimator(logprior_fn, loglikelihood_fn, data_size) sghmc = blackjax.sghmc(grad_fn) # Batch the data diff --git a/tests/test_sampling.py b/tests/test_sampling.py index 227da253c..17b203588 100644 --- a/tests/test_sampling.py +++ b/tests/test_sampling.py @@ -232,18 +232,32 @@ def loglikelihood_fn(self, position, x): w = x - position return -0.5 * jnp.dot(w, w) - def constant_step_size(_): - return 1e-3 + def test_linear_regression_contour_sgld(self): + + rng_key, data_key = jax.random.split(self.key, 2) + + data_size = 1000 + X_data = jax.random.normal(data_key, shape=(data_size, 5)) + + logdensity_fn = blackjax.sgmcmc.logdensity_estimator( + self.logprior_fn, self.loglikelihood_fn, data_size + ) + csgld = blackjax.csgld(logdensity_fn) + + _, rng_key = jax.random.split(rng_key) + data_batch = X_data[:100, :] + init_position = 1.0 + init_state = csgld.init(init_position) + _ = csgld.step(rng_key, init_state, data_batch, 1e-3, 1e-2) def test_linear_regression_sgld(self): - import blackjax.sgmcmc.gradients rng_key, data_key = jax.random.split(self.key, 2) data_size = 1000 X_data = jax.random.normal(data_key, shape=(data_size, 5)) - grad_fn = blackjax.sgmcmc.gradients.estimator( + grad_fn = blackjax.sgmcmc.grad_estimator( self.logprior_fn, self.loglikelihood_fn, data_size ) sgld = blackjax.sgld(grad_fn) @@ -254,7 +268,6 @@ def test_linear_regression_sgld(self): _ = sgld(rng_key, init_position, data_batch, 1e-3) def test_linear_regression_sgld_cv(self): - import blackjax.sgmcmc.gradients rng_key, data_key = jax.random.split(self.key, 2) @@ -263,7 +276,7 @@ def test_linear_regression_sgld_cv(self): centering_position = 1.0 - grad_fn = blackjax.sgmcmc.gradients.estimator( + grad_fn = blackjax.sgmcmc.grad_estimator( self.logprior_fn, self.loglikelihood_fn, data_size ) cv_grad_fn = blackjax.sgmcmc.gradients.control_variates( @@ -278,14 +291,13 @@ def test_linear_regression_sgld_cv(self): _ = sgld(rng_key, init_position, data_batch, 1e-3) def test_linear_regression_sghmc(self): - import blackjax.sgmcmc.gradients rng_key, data_key = jax.random.split(self.key, 2) data_size = 1000 X_data = jax.random.normal(data_key, shape=(data_size, 5)) - grad_fn = blackjax.sgmcmc.gradients.estimator( + grad_fn = blackjax.sgmcmc.grad_estimator( self.logprior_fn, self.loglikelihood_fn, data_size ) sghmc = blackjax.sghmc(grad_fn, 10) @@ -297,7 +309,6 @@ def test_linear_regression_sghmc(self): _ = sghmc(rng_key, init_position, data_batch, 1e-3) def test_linear_regression_sghmc_cv(self): - import blackjax.sgmcmc.gradients rng_key, data_key = jax.random.split(self.key, 2) @@ -305,7 +316,7 @@ def test_linear_regression_sghmc_cv(self): X_data = 
jax.random.normal(data_key, shape=(data_size, 5)) centering_position = 1.0 - grad_fn = blackjax.sgmcmc.gradients.estimator( + grad_fn = blackjax.sgmcmc.grad_estimator( self.logprior_fn, self.loglikelihood_fn, data_size ) cv_grad_fn = blackjax.sgmcmc.gradients.control_variates( From 4057243d2cf3e5e7e159af92ecd225808462f577 Mon Sep 17 00:00:00 2001 From: Wei Deng Date: Mon, 12 Dec 2022 18:42:20 -0500 Subject: [PATCH 4/4] Add Contour SgLD example Add Gaussian mixture example for the Contour SGLD sampler. --- docs/examples.rst | 7 +- docs/examples/contour_sgld.md | 290 ++++++++++++++++++++++++++++++++++ 2 files changed, 294 insertions(+), 3 deletions(-) create mode 100644 docs/examples/contour_sgld.md diff --git a/docs/examples.rst b/docs/examples.rst index 2feef1910..8590fbd49 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -7,13 +7,14 @@ Examples examples/Introduction.md examples/LogisticRegression.md examples/LogisticRegressionWithLatentGaussianSampler.md + examples/SparseLogisticRegression.md examples/TemperedSMC.md - examples/HierarchicalBNN.md examples/PeriodicOrbitalMCMC.md examples/GP_EllipticalSliceSampler.md examples/GP_Marginal.md - examples/SGMCMC.md examples/change_of_variable_hmc.md examples/Pathfinder.md examples/RegimeSwitchingModel.md - examples/SparseLogisticRegression.md + examples/HierarchicalBNN.md + examples/SGMCMC.md + examples/contour_sgld.md diff --git a/docs/examples/contour_sgld.md b/docs/examples/contour_sgld.md new file mode 100644 index 000000000..adb904e07 --- /dev/null +++ b/docs/examples/contour_sgld.md @@ -0,0 +1,290 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.14.0 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# Contour stochastic gradient Langevin dynamics + ++++ + +Sampling in big data problems is fundamentally limited by the multi-modality of the target distributions, with extremely high energy barriers. Multi-modality is often empirically solved via cyclical learning rates or different initializations (parallel chains). + +Contour SgLD takes a different approach altogether: the algorithms learns the energy landscape with sampling, and uses this approximation to effectively integrate the diffusion on a flat landscape, before using the importance weight to reweigh the obtained samples. + +In this notebook we will compare the performance of SGLD and Contour SGLD on a simple bimodal gaussian target. This example looks simple, but is rather challenging to sample with most methods. + ++++ + +## Gaussian Mixture + +Let us first generate data points that follow a gaussian mixture distributions. The example appears simple, and yet it is hard enough for most algorithms to fail to recover the two modes. 
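For reference, the likelihood implemented in the next cell is an equal-weight mixture of two Gaussians. Writing $\theta$ for the sampled location (`position` in the code below) and using the default values $\mu = -5$, $\sigma = 5$, $\gamma = 20$, the flat-prior posterior is symmetric under $\theta \mapsto \gamma - \theta$ and therefore has two well-separated modes, near $-5$ and $25$:

$$
p(x \mid \theta) = \tfrac{1}{2}\,\mathcal{N}\!\left(x;\, \theta,\, \sigma^{2}\right) + \tfrac{1}{2}\,\mathcal{N}\!\left(x;\, \gamma - \theta,\, \sigma^{2}\right),
\qquad
p(\theta) \propto 1 .
$$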
+ +```{code-cell} ipython3 +import jax +import jax.numpy as jnp +import jax.scipy as jsp + + +def gaussian_mixture_model(mu=-5.0, sigma=5.0, gamma=20.0): + def sample_fn(rng_key, num_samples): + key1, key2, key3 = jax.random.split(rng_key, 3) + prob_mixture = jax.random.bernoulli(key1, p=0.5, shape=(num_samples, 1)) + mixture_1 = jax.random.normal(key2, shape=(num_samples, 1)) * sigma + mu + mixture_2 = jax.random.normal(key3, shape=(num_samples, 1)) * sigma + gamma - mu + return prob_mixture * mixture_1 + (1 - prob_mixture) * mixture_2 + + def logprior_fn(position): + return 0 + + def loglikelihood_fn(position, x): + mixture_1 = jax.scipy.stats.norm.logpdf(x, loc=position, scale=sigma) + mixture_2 = jax.scipy.stats.norm.logpdf(x, loc=-position + gamma, scale=sigma) + return jsp.special.logsumexp(jnp.array([mixture_1, mixture_2])) + jnp.log(0.5) + + return sample_fn, logprior_fn, loglikelihood_fn + + +sample_fn, logprior_fn, loglikelihood_fn = gaussian_mixture_model() +``` + +```{code-cell} ipython3 +data_size = 1000 + +rng_key = jax.random.PRNGKey(888) +rng_key, sample_key = jax.random.split(rng_key) +X_data = sample_fn(sample_key, data_size) +``` + +```{code-cell} ipython3 +import matplotlib.pylab as plt + +ax = plt.subplot(111) +ax.hist(X_data.squeeze(), 100) +ax.set_xlabel("X") +ax.set_xlim(left=-15, right=35) + +ax.set_yticks([]) +ax.spines["top"].set_visible(False) +ax.spines["right"].set_visible(False) +ax.spines["left"].set_visible(False) + +plt.title("Data") +``` + +# Use the stochastic gradient Langevin dynamics algorithm + +```{code-cell} ipython3 +from fastprogress import progress_bar + +import blackjax +import blackjax.sgmcmc.gradients as gradients + +# Specify hyperparameters for SGLD +total_iter = 50_000 +thinning_factor = 10 + +batch_size = 100 +lr = 1e-3 +temperature = 50.0 + +init_position = 10.0 + + +# Build the SGDL sampler +grad_fn = gradients.grad_estimator(logprior_fn, loglikelihood_fn, data_size) +sgld = blackjax.sgld(grad_fn) + + +# Initialize and take one step using the vanilla SGLD algorithm +position = init_position +sgld_sample_list = jnp.array([]) + +pb = progress_bar(range(total_iter)) +for iter_ in pb: + rng_key, batch_key, sample_key = jax.random.split(rng_key, 3) + data_batch = jax.random.shuffle(batch_key, X_data)[:batch_size, :] + position = jax.jit(sgld)(sample_key, position, data_batch, lr, temperature) + if iter_ % thinning_factor == 0: + sgld_sample_list = jnp.append(sgld_sample_list, position) + pb.comment = f"| position: {position: .2f}" +``` + +```{code-cell} ipython3 +import matplotlib.gridspec as gridspec +import matplotlib.pylab as plt + +fig = plt.figure(figsize=(12, 6)) + +G = gridspec.GridSpec(1, 3) + +# Trajectory +ax = plt.subplot(G[0, :2]) +ax.plot(sgld_sample_list, label="SGLD") +ax.set_xlabel(f"Iterations (x{thinning_factor})") +ax.set_ylabel("X") + +ax.spines["top"].set_visible(False) +ax.spines["right"].set_visible(False) + + +# Histogram +ax = plt.subplot(G[0, 2]) +ax.hist(sgld_sample_list, 100) +ax.set_xlabel("X") +ax.set_xlim(left=-15, right=35) + +ax.set_yticks([]) +ax.spines["top"].set_visible(False) +ax.spines["right"].set_visible(False) +ax.spines["left"].set_visible(False) + + +plt.suptitle("Stochastic gradient Langevin dynamics (SGLD)") +``` + +```{code-cell} ipython3 +# specify hyperparameters (zeta and sz are the only two hyperparameters to tune) +zeta = 2 +sz = 10 +temperature = 50 + +lr = 1e-3 +init_position = 10.0 + + +# The following parameters partition the energy space and no tuning is needed. 
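+# Per the kernel docstring, the partitions cover energies from `min_energy` up to
+# `min_energy + num_partitions * energy_gap`; with the values below that is
+# 0 + 100000 * 0.25 = 25000, well beyond the energies actually visited (the minimum
+# energy for this dataset is roughly 3681, see the last cell). A tighter `min_energy`
+# would be more efficient but is not required.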
+num_partitions = 100000 +energy_gap = 0.25 +domain_radius = 50 # restart sampling when the particle explores too deep over the tails and leads to nan. + + +logdensity_fn = gradients.logdensity_estimator(logprior_fn, loglikelihood_fn, data_size) +csgld = blackjax.csgld( + logdensity_fn, + zeta=zeta, # can be specified at each step in lower-level interface + temperature=temperature, # can be specified at each step + num_partitions=num_partitions, # cannot be specified at each step + energy_gap=energy_gap, # cannot be specified at each step + min_energy=0, +) + +# 3.1 Simulate via the CSGLD algorithm +state = csgld.init(init_position) + +csgld_sample_list, csgld_energy_idx_list = jnp.array([]), jnp.array([]) + +pb = progress_bar(range(total_iter)) +for iter_ in pb: + rng_key, subkey = jax.random.split(rng_key) + stepsize_SA = min(1e-2, (iter_ + 100) ** (-0.8)) * sz + + data_batch = jax.random.shuffle(rng_key, X_data)[:batch_size, :] + state = jax.jit(csgld.step)(subkey, state, data_batch, lr, stepsize_SA) + + if iter_ % thinning_factor == 0: + csgld_sample_list = jnp.append(csgld_sample_list, state.position) + csgld_energy_idx_list = jnp.append(csgld_energy_idx_list, state.energy_idx) + pb.comment = ( + f"| position {state.position: .2f}" + ) +``` + +Contour SGLD takes inspiration from the Wang-Landau algorithm to learn the density of states of the model at each energy level, and uses this information to flatten the target density to be able to explore it more easily. + +As a result, the samples returned by contour SGLD are not from the target density directly, and we need to resample them using the density of state as importance weights to get samples from the target distribution. + +```{code-cell} ipython3 +important_idx = jnp.where(state.energy_pdf > jnp.quantile(state.energy_pdf, 0.95))[0] +scaled_energy_pdf = ( + state.energy_pdf[important_idx] ** zeta + / (state.energy_pdf[important_idx] ** zeta).max() +) + +csgld_re_sample_list = jnp.array([]) +for _ in range(5): + rng_key, subkey = jax.random.split(rng_key) + for my_idx in important_idx: + if jax.random.bernoulli(rng_key, p=scaled_energy_pdf[my_idx], shape=None) == 1: + samples_in_my_idx = csgld_sample_list[csgld_energy_idx_list == my_idx] + csgld_re_sample_list = jnp.concatenate( + (csgld_re_sample_list, samples_in_my_idx) + ) +``` + +```{code-cell} ipython3 +import matplotlib.gridspec as gridspec +import matplotlib.pylab as plt + +fig = plt.figure(figsize=(12, 6)) + +G = gridspec.GridSpec(1, 3) + +# Trajectory +ax = plt.subplot(G[0, :2]) +ax.plot(csgld_sample_list, label="SGLD") +ax.set_xlabel(f"Iterations (x{thinning_factor})") +ax.set_ylabel("X") + +ax.spines["top"].set_visible(False) +ax.spines["right"].set_visible(False) + + +# Histogram before resampling +ax = plt.subplot(G[0, 2]) +ax.hist(csgld_sample_list, 100, label="before resampling") +ax.hist(csgld_re_sample_list, 100, label="after resampling") + +ax.set_xlabel("X") +ax.set_xlim(left=-15, right=35) + +ax.set_yticks([]) +ax.spines["top"].set_visible(False) +ax.spines["right"].set_visible(False) +ax.spines["left"].set_visible(False) + +plt.legend() +plt.suptitle("Contour SGLD") +``` + +### Why does Contour SGLD work? + +The energy density is crucial for us to build a flat density, so let's take a look at the estimation returned by the algorithm. For illustration purposes, we smooth out fluctations and focus on the energy range from 3700 to 100000, which covers the major part of sample space. 
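Before looking at the plot, it helps to restate what the kernel does with this estimate. Following Deng et al. (2020), and writing $\theta$ for the vector stored in `state.energy_pdf`, $T$ for the `temperature`, $\Delta u$ for `energy_gap`, and $J(x)$ for the index of the energy partition the current sample falls in, CSGLD targets a flattened surrogate of $\pi$ and multiplies the stochastic gradient by the factor computed in `csgld.py`:

$$
\varpi_{\theta}(x) \;\propto\; \frac{\pi(x)}{\theta_{J(x)}^{\,\zeta}},
\qquad
\text{drift multiplier} \;=\; 1 + \zeta T\, \frac{\log \theta_{J(x)} - \log \theta_{J(x)-1}}{\Delta u} .
$$

Where the estimated mass decreases steeply with energy, as it does in the low-energy region shown in the plot, the second term is large and negative, the multiplier flips sign, and the drift pushes the sample out of the local trap. This is the "negative learning rate" effect discussed at the end of the notebook.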
+ +```{code-cell} ipython3 +smooth_energy_pdf = jnp.convolve( + state.energy_pdf, jsp.stats.norm.pdf(jnp.arange(-100, 101), scale=10), mode="same" +) +interested_idx = jax.lax.floor((jnp.arange(3700, 10000)) / energy_gap).astype( + "int32" +) # min 3681 + +fig = plt.figure() +ax = fig.add_subplot(111) +ax.plot( + jnp.arange(num_partitions)[interested_idx] * energy_gap, + smooth_energy_pdf[interested_idx], +) + +ax.set_xlabel("Energy") +ax.set_ylabel("Energy Density") + +ax.spines["top"].set_visible(False) +ax.spines["right"].set_visible(False) + +plt.show() +``` + +From the figure above, we see that low-energy regions usually lead to much higher probability mass. Moreover, the slope is negative with a higher scale in low energy regions. In view of Eq.(8) in [the paper]( https://proceedings.neurips.cc/paper/2020/file/b5b8c484824d8a06f4f3d570bc420313-Paper.pdf), we can expect a **negative learning rate** to help the particle escape the local trap. Eventually, a particle is able to bounce out of the deep local traps freely instead of being absorbed into it. + ++++ + +Admittedly, this algorithm is a little sophisticated due to the need to partition the energy space; Learning energy pdf also makes this algorithm delicate and leads to a large variance. However, this allows to escape deep local traps in a principled sampling framework without using any tricks (cyclical learning rates or different initializations). The variance-reduced version is studied in [this work](https://arxiv.org/pdf/2202.09867.pdf).
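The `blackjax.csgld` wrapper used above fixes `zeta` and `temperature` when the sampler is built. As the comments in `kernels.py` note, both can instead be supplied at every step through the lower-level `blackjax.sgmcmc.csgld` module added in this patch. A minimal sketch, reusing the notebook's `logdensity_fn`, data and hyperparameters, and importing the module under a different name so it does not shadow the `csgld` sampler object defined earlier:

```python
from blackjax.sgmcmc import csgld as csgld_base

# Build the lower-level kernel: the partition parameters are fixed here,
# while zeta and temperature become per-step arguments.
init_state = csgld_base.init(init_position, num_partitions)
step = csgld_base.kernel(num_partitions, energy_gap, min_energy=0)

rng_key, subkey = jax.random.split(rng_key)
data_batch = X_data[:batch_size, :]

new_state = step(
    subkey,
    init_state,
    logdensity_fn,  # the log-density estimator defined earlier in the notebook
    data_batch,
    lr,    # step size of the Langevin diffusion
    1e-2,  # step size of the stochastic approximation
    zeta,
    temperature,
)
```

A step of the wrapper built earlier is equivalent to this call with `zeta` and `temperature` fixed at construction time.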