Commit b6807b2

xidulu authored and rlouf committed
Add Mean-Field Variational Inference
Add some docs, remove the extra tuple in base.py, change log prob to logdensity, remove fn from log density.
1 parent 27d981e commit b6807b2

9 files changed: +234 −9 lines

blackjax/__init__.py

+3 −1

@@ -11,6 +11,7 @@
     irmh,
     mala,
     meads,
+    meanfield_vi,
     mgrad_gaussian,
     nuts,
     orbital_hmc,
@@ -45,7 +46,8 @@
     "pathfinder_adaptation",
     "adaptive_tempered_smc",  # smc
     "tempered_smc",
-    "pathfinder",  # variational inference
+    "meanfield_vi",  # variational inference
+    "pathfinder",
     "ess",  # diagnostics
     "rhat",
 ]

blackjax/base.py

+2 −2

@@ -1,5 +1,4 @@
 # Copyright 2020- The Blackjax Authors.
-#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -136,7 +135,8 @@ class VIAlgorithm(NamedTuple):

     """

-    approximate: Callable
+    init: Callable
+    step: Callable
     sample: Callable


blackjax/kernels.py

+36 −3

@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Blackjax high-level interface with sampling algorithms."""
-from typing import Callable, Dict, NamedTuple, Optional, Union
+from typing import Callable, Dict, NamedTuple, Optional, Tuple, Union

 import jax
 import jax.numpy as jnp
+from optax import GradientTransformation

 import blackjax.adaptation as adaptation
 import blackjax.mcmc as mcmc
@@ -1251,6 +1252,11 @@ def step_fn(rng_key: PRNGKey, state):
 # -----------------------------------------------------------------------------


+class PathFinderAlgorithm(NamedTuple):
+    approximate: Callable
+    sample: Callable
+
+
 class pathfinder:
     """Implements the (basic) user interface for the pathfinder kernel.

@@ -1273,7 +1279,7 @@ class pathfinder:
     approximate = staticmethod(vi.pathfinder.approximate)
     sample = staticmethod(vi.pathfinder.sample)

-    def __new__(cls, logdensity_fn: Callable) -> VIAlgorithm:  # type: ignore[misc]
+    def __new__(cls, logdensity_fn: Callable) -> PathFinderAlgorithm:  # type: ignore[misc]
         def approximate_fn(
             rng_key: PRNGKey,
             position: PyTree,
@@ -1289,7 +1295,7 @@ def sample_fn(
         ):
             return cls.sample(rng_key, state, num_samples)

-        return VIAlgorithm(approximate_fn, sample_fn)
+        return PathFinderAlgorithm(approximate_fn, sample_fn)


 def pathfinder_adaptation(
@@ -1385,3 +1391,30 @@ def kernel(rng_key, state):
         return AdaptationResults(last_chain_state, kernel, parameters)

     return AdaptationAlgorithm(run)
+
+
+class meanfield_vi:
+    init = staticmethod(vi.meanfield_vi.init)
+    step = staticmethod(vi.meanfield_vi.step)
+    sample = staticmethod(vi.meanfield_vi.sample)
+
+    def __new__(
+        cls,
+        logdensity_fn: Callable,
+        optimizer: GradientTransformation,
+        num_samples: int = 100,
+    ):  # type: ignore[misc]
+        def init_fn(position: PyTree):
+            return cls.init(position, optimizer)
+
+        def step_fn(
+            rng_key: PRNGKey, state: vi.meanfield_vi.MFVIState
+        ) -> Tuple[vi.meanfield_vi.MFVIState, vi.meanfield_vi.MFVIInfo]:
+            return cls.step(rng_key, state, logdensity_fn, optimizer, num_samples)
+
+        def sample_fn(
+            rng_key: PRNGKey, state: vi.meanfield_vi.MFVIState, num_samples: int
+        ):
+            return cls.sample(rng_key, state, num_samples)
+
+        return VIAlgorithm(init_fn, step_fn, sample_fn)
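
For orientation, here is a minimal sketch of how the new blackjax.meanfield_vi interface is driven end to end. It mirrors the test added in this commit; the target density, optimizer settings, and step count below are illustrative and not part of the change:

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import optax

import blackjax

# Illustrative target: independent Gaussians over a dict-shaped position.
def logdensity_fn(position):
    return jnp.sum(stats.norm.logpdf(position["x"], loc=1.0, scale=2.0))

optimizer = optax.sgd(1e-2)  # any optax GradientTransformation works
mfvi = blackjax.meanfield_vi(logdensity_fn, optimizer, num_samples=200)

state = mfvi.init({"x": jnp.zeros(3)})  # init sets mu = 0 and rho = -2 on every leaf

rng_key = jax.random.PRNGKey(0)
step = jax.jit(mfvi.step)
for _ in range(1_000):
    rng_key, subkey = jax.random.split(rng_key)
    state, info = step(subkey, state)  # info.elbo carries the Monte Carlo objective value

# Draw from the fitted mean-field Gaussian.
posterior_samples = mfvi.sample(rng_key, state, 100)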

blackjax/vi/__init__.py

+2 −2

@@ -1,3 +1,3 @@
-from . import pathfinder
+from . import meanfield_vi, pathfinder

-__all__ = ["pathfinder"]
+__all__ = ["pathfinder", "meanfield_vi"]

blackjax/vi/meanfield_vi.py

+131

@@ -0,0 +1,131 @@
+# Copyright 2020- The Blackjax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, NamedTuple, Tuple
+
+import jax
+import jax.numpy as jnp
+import jax.scipy as jsp
+from optax import GradientTransformation, OptState
+
+from blackjax.types import PRNGKey, PyTree
+
+__all__ = ["MFVIState", "MFVIInfo", "sample", "generate_meanfield_logdensity", "step"]
+
+
+class MFVIState(NamedTuple):
+    mu: PyTree
+    rho: PyTree
+    opt_state: OptState
+
+
+class MFVIInfo(NamedTuple):
+    elbo: float
+
+
+def init(
+    position: PyTree,
+    optimizer: GradientTransformation,
+    *optimizer_args,
+    **optimizer_kwargs
+) -> MFVIState:
+    """Initialize the mean-field VI state."""
+    mu = jax.tree_map(jnp.zeros_like, position)
+    rho = jax.tree_map(lambda x: -2.0 * jnp.ones_like(x), position)
+    opt_state = optimizer.init((mu, rho))
+    return MFVIState(mu, rho, opt_state)
+
+
+def step(
+    rng_key: PRNGKey,
+    state: MFVIState,
+    logdensity_fn: Callable,
+    optimizer: GradientTransformation,
+    num_samples: int = 5,
+    stl_estimator: bool = True,
+) -> Tuple[MFVIState, MFVIInfo]:
+    """Approximate the target density using the mean-field approximation.
+
+    Parameters
+    ----------
+    rng_key
+        Key for JAX's pseudo-random number generator.
+    state
+        Current state of the mean-field approximation.
+    logdensity_fn
+        Function that represents the target log-density to approximate.
+    optimizer
+        Optax `GradientTransformation` to be used for optimization.
+    num_samples
+        The number of samples that are taken from the approximation
+        at each step to compute the Kullback-Leibler divergence between
+        the approximation and the target log-density.
+    stl_estimator
+        Whether to use the stick-the-landing (STL) gradient estimator [1] for gradient estimation.
+        The STL estimator has lower gradient variance because it removes the score function term
+        from the gradient; [2] suggests always using it for better results.
+
+    References
+    ----------
+    .. [1]: Roeder, G., Wu, Y., & Duvenaud, D. K. (2017).
+            Sticking the landing: Simple, lower-variance gradient estimators for variational inference.
+            Advances in Neural Information Processing Systems, 30.
+    .. [2]: Agrawal, A., Sheldon, D. R., & Domke, J. (2020).
+            Advances in black-box VI: Normalizing flows, importance weighting, and optimization.
+            Advances in Neural Information Processing Systems, 33.
+    """
+
+    parameters = (state.mu, state.rho)
+
+    def kl_divergence_fn(parameters):
+        mu, rho = parameters
+        z = _sample(rng_key, mu, rho, num_samples)
+        if stl_estimator:
+            mu = jax.lax.stop_gradient(mu)
+            rho = jax.lax.stop_gradient(rho)
+        logq = jax.vmap(generate_meanfield_logdensity(mu, rho))(z)
+        logp = jax.vmap(logdensity_fn)(z)
+        return (logq - logp).mean()
+
+    elbo, elbo_grad = jax.value_and_grad(kl_divergence_fn)(parameters)
+    updates, new_opt_state = optimizer.update(elbo_grad, state.opt_state, parameters)
+    new_parameters = jax.tree_map(lambda p, u: p + u, parameters, updates)
+    new_state = MFVIState(new_parameters[0], new_parameters[1], new_opt_state)
+    return new_state, MFVIInfo(elbo)
+
+
+def sample(rng_key: PRNGKey, state: MFVIState, num_samples: int = 1):
+    """Sample from the mean-field approximation."""
+    return _sample(rng_key, state.mu, state.rho, num_samples)
+
+
+def _sample(rng_key, mu, rho, num_samples):
+    sigma = jax.tree_map(jnp.exp, rho)
+    mu_flatten, unravel_fn = jax.flatten_util.ravel_pytree(mu)
+    sigma_flat, _ = jax.flatten_util.ravel_pytree(sigma)
+    flatten_sample = (
+        jax.random.normal(rng_key, (num_samples,) + mu_flatten.shape) * sigma_flat
+        + mu_flatten
+    )
+    return jax.vmap(unravel_fn)(flatten_sample)
+
+
+def generate_meanfield_logdensity(mu, rho):
+    sigma_param = jax.tree_map(jnp.exp, rho)
+
+    def meanfield_logdensity(position):
+        logq_pytree = jax.tree_map(jsp.stats.norm.logpdf, position, mu, sigma_param)
+        logq = jax.tree_map(jnp.sum, logq_pytree)
+        return jax.tree_util.tree_reduce(jnp.add, logq)
+
+    return meanfield_logdensity
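
The gradient taken in step is the reparameterization-trick estimator: _sample draws z = mu + exp(rho) * eps with eps ~ N(0, I), so gradients reach mu and rho only through z, and with stl_estimator=True the copies of the variational parameters inside log q are wrapped in stop_gradient so the score-function term drops out ("sticking the landing"). Below is a self-contained sketch of that objective for a single flat parameter vector, with no pytree handling; the function name and the toy target are illustrative only, not part of the commit:

import jax
import jax.numpy as jnp
import jax.scipy.stats as jss

def stl_objective(params, rng_key, logdensity_fn, num_samples=5):
    """Monte Carlo estimate of E_q[log q(z) - log p(z)] with the STL trick."""
    mu, rho = params
    sigma = jnp.exp(rho)
    # Reparameterized draws: gradients reach mu and rho only through z.
    eps = jax.random.normal(rng_key, (num_samples,) + mu.shape)
    z = mu + sigma * eps
    # Stick-the-landing: the variational parameters inside log q are held constant.
    mu_sg, sigma_sg = jax.lax.stop_gradient(mu), jax.lax.stop_gradient(sigma)
    logq = jss.norm.logpdf(z, mu_sg, sigma_sg).sum(axis=-1)
    logp = jax.vmap(logdensity_fn)(z)
    return (logq - logp).mean()

# Value and gradient with respect to (mu, rho) for a toy standard-normal target.
target = lambda z: jss.norm.logpdf(z).sum()
params = (jnp.zeros(2), -2.0 * jnp.ones(2))
value, grads = jax.value_and_grad(stl_objective)(params, jax.random.PRNGKey(0), target)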

docs/vi.rst

+6

@@ -7,9 +7,15 @@ Variational Inference
    :nosignatures:

    pathfinder
+   meanfield_vi


 Pathfinder
 ~~~~~~~~~~

 .. autoclass:: blackjax.pathfinder
+
+Mean-field VI
+~~~~~~~~~~~~~
+
+.. autoclass:: blackjax.meanfield_vi

pyproject.toml

+1

@@ -36,6 +36,7 @@ dependencies = [
     "jax>=0.3.13",
     "jaxlib>=0.3.10",
     "jaxopt>=0.5.5",
+    "optax",
     "typing-extensions>=4.4.0",
 ]
 dynamic = ["version"]

pytest.ini

−1

@@ -1,5 +1,4 @@
 [pytest]
-addopts = -n auto
 testpaths= "tests"
 filterwarnings =
     error

tests/test_meanfield_vi.py

+53

@@ -0,0 +1,53 @@
+import chex
+import jax
+import jax.numpy as jnp
+import jax.scipy.stats as stats
+import optax
+from absl.testing import absltest
+
+import blackjax
+
+
+class MFVITest(chex.TestCase):
+    def setUp(self):
+        super().setUp()
+        self.key = jax.random.PRNGKey(42)
+
+    def test_recover_posterior(self):
+        ground_truth = [
+            # loc, scale
+            (2, 4),
+            (3, 5),
+        ]
+
+        def logdensity_fn(x):
+            logpdf = stats.norm.logpdf(x["x_1"], *ground_truth[0]) + stats.norm.logpdf(
+                x["x_2"], *ground_truth[1]
+            )
+            return jnp.sum(logpdf)
+
+        initial_position = {"x_1": 0.0, "x_2": 0.0}
+
+        num_steps = 50_000
+        num_samples = 500
+
+        optimizer = optax.sgd(1e-2)
+        mfvi = blackjax.meanfield_vi(logdensity_fn, optimizer, num_samples)
+        state = mfvi.init(initial_position)
+
+        rng_key = self.key
+        for _ in range(num_steps):
+            rng_key, _ = jax.random.split(rng_key)
+            state, _ = jax.jit(mfvi.step)(rng_key, state)
+
+        loc_1, loc_2 = state.mu["x_1"], state.mu["x_2"]
+        scale = jax.tree_map(jnp.exp, state.rho)
+        scale_1, scale_2 = scale["x_1"], scale["x_2"]
+        self.assertAlmostEqual(loc_1, ground_truth[0][0], delta=0.01)
+        self.assertAlmostEqual(scale_1, ground_truth[0][1], delta=0.01)
+        self.assertAlmostEqual(loc_2, ground_truth[1][0], delta=0.01)
+        self.assertAlmostEqual(scale_2, ground_truth[1][1], delta=0.01)
+
+
+if __name__ == "__main__":
+    absltest.main()
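
Once a loop like the one in this test has converged, the fitted approximation can be queried through the same interface. A short continuation of the test's namespace (mfvi and state refer to the objects built above), purely illustrative:

# Draw samples from the fitted mean-field Gaussian and inspect their moments;
# they should roughly match the (loc, scale) pairs in ground_truth.
samples = mfvi.sample(jax.random.PRNGKey(1), state, 1_000)
print(samples["x_1"].mean(), samples["x_1"].std())  # approximately 2 and 4
print(samples["x_2"].mean(), samples["x_2"].std())  # approximately 3 and 5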
