|
| 1 | +# Copyright 2020- The Blackjax Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +from typing import Callable, NamedTuple, Tuple |
| 15 | + |
| 16 | +import jax |
| 17 | +import jax.numpy as jnp |
| 18 | +import jax.scipy as jsp |
| 19 | +from optax import GradientTransformation, OptState |
| 20 | + |
| 21 | +from blackjax.types import PRNGKey, PyTree |
| 22 | + |
| 23 | +__all__ = ["MFVIState", "MFVIInfo", "sample", "generate_meanfield_logdensity", "step"] |
| 24 | + |
| 25 | + |
| 26 | +class MFVIState(NamedTuple): |
| 27 | + mu: PyTree |
| 28 | + rho: PyTree |
| 29 | + opt_state: OptState |
| 30 | + |
| 31 | + |
| 32 | +class MFVIInfo(NamedTuple): |
| 33 | + elbo: float |
| 34 | + |
| 35 | + |
| 36 | +def init( |
| 37 | + position: PyTree, |
| 38 | + optimizer: GradientTransformation, |
| 39 | + *optimizer_args, |
| 40 | + **optimizer_kwargs |
| 41 | +) -> MFVIState: |
| 42 | + """Initialize the mean-field VI state.""" |
| 43 | + mu = jax.tree_map(jnp.zeros_like, position) |
| 44 | + rho = jax.tree_map(lambda x: -2.0 * jnp.ones_like(x), position) |
| 45 | + opt_state = optimizer.init((mu, rho)) |
| 46 | + return MFVIState(mu, rho, opt_state) |
| 47 | + |
| 48 | + |
| 49 | +def step( |
| 50 | + rng_key: PRNGKey, |
| 51 | + state: MFVIState, |
| 52 | + logdensity_fn: Callable, |
| 53 | + optimizer: GradientTransformation, |
| 54 | + num_samples: int = 5, |
| 55 | + stl_estimator: bool = True, |
| 56 | +) -> Tuple[MFVIState, MFVIInfo]: |
| 57 | + """Approximate the target density using the mean-field approximation. |
| 58 | +
|
| 59 | + Parameters |
| 60 | + ---------- |
| 61 | + rng_key |
| 62 | + Key for JAX's pseudo-random number generator. |
| 63 | + init_state |
| 64 | + Initial state of the mean-field approximation. |
| 65 | + logdensity_fn |
| 66 | + Function that represents the target log-density to approximate. |
| 67 | + optimizer |
| 68 | + Optax `GradientTransformation` to be used for optimization. |
| 69 | + num_samples |
| 70 | + The number of samples that are taken from the approximation |
| 71 | + at each step to compute the Kullback-Leibler divergence between |
| 72 | + the approximation and the target log-density. |
| 73 | + stl_estimator |
| 74 | + Whether to use stick-the-landing (STL) gradient estimator [1] for gradient estimation. |
| 75 | + The STL estimator has lower gradient variance by removing the score function term |
| 76 | + from the gradient. It is suggested by [2] to always keep it in order for better results. |
| 77 | +
|
| 78 | + References |
| 79 | + ---------- |
| 80 | + .. [1]: Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). |
| 81 | + Sticking the landing: Simple, lower-variance gradient estimators for variational inference. |
| 82 | + Advances in Neural Information Processing Systems, 30. |
| 83 | + .. [2]: Agrawal, A., Sheldon, D. R., & Domke, J. (2020). |
| 84 | + Advances in black-box VI: Normalizing flows, importance weighting, and optimization. |
| 85 | + Advances in Neural Information Processing Systems, 33. |
| 86 | + """ |
| 87 | + |
| 88 | + parameters = (state.mu, state.rho) |
| 89 | + |
| 90 | + def kl_divergence_fn(parameters): |
| 91 | + mu, rho = parameters |
| 92 | + z = _sample(rng_key, mu, rho, num_samples) |
| 93 | + if stl_estimator: |
| 94 | + mu = jax.lax.stop_gradient(mu) |
| 95 | + rho = jax.lax.stop_gradient(rho) |
| 96 | + logq = jax.vmap(generate_meanfield_logdensity(mu, rho))(z) |
| 97 | + logp = jax.vmap(logdensity_fn)(z) |
| 98 | + return (logq - logp).mean() |
| 99 | + |
| 100 | + elbo, elbo_grad = jax.value_and_grad(kl_divergence_fn)(parameters) |
| 101 | + updates, new_opt_state = optimizer.update(elbo_grad, state.opt_state, parameters) |
| 102 | + new_parameters = jax.tree_map(lambda p, u: p + u, parameters, updates) |
| 103 | + new_state = MFVIState(new_parameters[0], new_parameters[1], new_opt_state) |
| 104 | + return new_state, MFVIInfo(elbo) |
| 105 | + |
| 106 | + |
| 107 | +def sample(rng_key: PRNGKey, state: MFVIState, num_samples: int = 1): |
| 108 | + """Sample from the mean-field approximation.""" |
| 109 | + return _sample(rng_key, state.mu, state.rho, num_samples) |
| 110 | + |
| 111 | + |
| 112 | +def _sample(rng_key, mu, rho, num_samples): |
| 113 | + sigma = jax.tree_map(jnp.exp, rho) |
| 114 | + mu_flatten, unravel_fn = jax.flatten_util.ravel_pytree(mu) |
| 115 | + sigma_flat, _ = jax.flatten_util.ravel_pytree(sigma) |
| 116 | + flatten_sample = ( |
| 117 | + jax.random.normal(rng_key, (num_samples,) + mu_flatten.shape) * sigma_flat |
| 118 | + + mu_flatten |
| 119 | + ) |
| 120 | + return jax.vmap(unravel_fn)(flatten_sample) |
| 121 | + |
| 122 | + |
| 123 | +def generate_meanfield_logdensity(mu, rho): |
| 124 | + sigma_param = jax.tree_map(jnp.exp, rho) |
| 125 | + |
| 126 | + def meanfield_logdensity(position): |
| 127 | + logq_pytree = jax.tree_map(jsp.stats.norm.logpdf, position, mu, sigma_param) |
| 128 | + logq = jax.tree_map(jnp.sum, logq_pytree) |
| 129 | + return jax.tree_util.tree_reduce(jnp.add, logq) |
| 130 | + |
| 131 | + return meanfield_logdensity |
0 commit comments