Skip to content

Commit

Permalink
SMC: allow each mutation kernel to have different parameters. (#649)
Browse files Browse the repository at this point in the history
* vmaping over parameters in base

* switch from mcmc_factory to just passing in parameters

* pre-commit and typing

* CRU and docs improvement

* pre-commit

* code review updates

* pre-commit

* rename test
  • Loading branch information
ciguaran authored Mar 25, 2024
1 parent 2ccdfb0 commit 3dc3809
Show file tree
Hide file tree
Showing 9 changed files with 229 additions and 121 deletions.
8 changes: 7 additions & 1 deletion blackjax/smc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
from . import adaptive_tempered, inner_kernel_tuning, tempered
from .base import extend_params

__all__ = ["adaptive_tempered", "tempered", "inner_kernel_tuning"]
__all__ = [
"adaptive_tempered",
"tempered",
"inner_kernel_tuning",
"extend_params",
]
20 changes: 16 additions & 4 deletions blackjax/smc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class SMCState(NamedTuple):

particles: ArrayTree
weights: Array
update_parameters: ArrayTree


class SMCInfo(NamedTuple):
Expand All @@ -59,12 +60,12 @@ class SMCInfo(NamedTuple):
update_info: NamedTuple


def init(particles: ArrayLikeTree):
def init(particles: ArrayLikeTree, init_update_params):
# Infer the number of particles from the size of the leading dimension of
# the first leaf of the inputted PyTree.
num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0]
weights = jnp.ones(num_particles) / num_particles
return SMCState(particles, weights)
return SMCState(particles, weights, init_update_params)


def step(
Expand Down Expand Up @@ -137,13 +138,24 @@ def step(
particles = jax.tree_map(lambda x: x[resampling_idx], state.particles)

keys = jax.random.split(updating_key, num_resampled)
particles, update_info = update_fn(keys, particles)
particles, update_info = update_fn(keys, particles, state.update_parameters)

log_weights = weight_fn(particles)
logsum_weights = jax.scipy.special.logsumexp(log_weights)
normalizing_constant = logsum_weights - jnp.log(num_particles)
weights = jnp.exp(log_weights - logsum_weights)

return SMCState(particles, weights), SMCInfo(
return SMCState(particles, weights, state.update_parameters), SMCInfo(
resampling_idx, normalizing_constant, update_info
)


def extend_params(n_particles, params):
"""Given a dictionary of params, repeats them for every single particle. The expected
usage is in cases where the aim is to repeat the same parameters for all chains within SMC.
"""

def extend(param):
return jnp.repeat(jnp.asarray(param)[None, ...], n_particles, axis=0)

return jax.tree_map(extend, params)
39 changes: 20 additions & 19 deletions blackjax/smc/inner_kernel_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,15 @@


class StateWithParameterOverride(NamedTuple):
"""
Stores both the sampling status and also a dictionary
that contains an dictionary with parameter names as key
and (n_particles, *) arrays as meanings. The latter
represent a parameter per chain for the next mutation step.
"""

sampler_state: ArrayTree
parameter_override: ArrayTree
parameter_override: Dict[str, ArrayTree]


def init(alg_init_fn, position, initial_parameter_value):
Expand All @@ -20,11 +27,10 @@ def build_kernel(
smc_algorithm,
logprior_fn: Callable,
loglikelihood_fn: Callable,
mcmc_factory: Callable,
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
mcmc_parameters: Dict,
resampling_fn: Callable,
mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], ArrayTree],
mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], Dict[str, ArrayTree]],
num_mcmc_steps: int = 10,
**extra_parameters,
) -> Callable:
Expand All @@ -41,12 +47,11 @@ def build_kernel(
A function that computes the log density of the prior distribution
loglikelihood_fn
A function that returns the probability at a given position.
mcmc_factory
A callable that can construct an inner kernel out of the newly-computed parameter
mcmc_step_fn:
The transition kernel, should take as parameters the dictionary output of mcmc_parameter_update_fn.
mcmc_step_fn(rng_key, state, tempered_logposterior_fn, **mcmc_parameter_update_fn())
mcmc_init_fn
A callable that initializes the inner kernel
mcmc_parameters
Other (fixed across SMC iterations) parameters for the inner kernel
mcmc_parameter_update_fn
A callable that takes the SMCState and SMCInfo at step i and constructs a parameter to be used by the inner kernel in i+1 iteration.
extra_parameters:
Expand All @@ -59,9 +64,9 @@ def kernel(
step_fn = smc_algorithm(
logprior_fn=logprior_fn,
loglikelihood_fn=loglikelihood_fn,
mcmc_step_fn=mcmc_factory(state.parameter_override),
mcmc_step_fn=mcmc_step_fn,
mcmc_init_fn=mcmc_init_fn,
mcmc_parameters=mcmc_parameters,
mcmc_parameters=state.parameter_override,
resampling_fn=resampling_fn,
num_mcmc_steps=num_mcmc_steps,
**extra_parameters,
Expand Down Expand Up @@ -89,17 +94,15 @@ class inner_kernel_tuning:
A function that computes the log density of the prior distribution
loglikelihood_fn
A function that returns the probability at a given position.
mcmc_factory
A callable that can construct an inner kernel out of the newly-computed parameter
mcmc_step_fn
The transition kernel, should take as parameters the dictionary output of mcmc_parameter_update_fn.
mcmc_init_fn
A callable that initializes the inner kernel
mcmc_parameters
Other (fixed across SMC iterations) parameters for the inner kernel step
mcmc_parameter_update_fn
A callable that takes the SMCState and SMCInfo at step i and constructs a parameter to be used by the
inner kernel in i+1 iteration.
initial_parameter_value
Paramter to be used by the mcmc_factory before the first iteration.
Parameter to be used by the mcmc_factory before the first iteration.
extra_parameters:
parameters to be used for the creation of the smc_algorithm.
Expand All @@ -117,9 +120,8 @@ def __new__( # type: ignore[misc]
smc_algorithm: Union[adaptive_tempered_smc, tempered_smc],
logprior_fn: Callable,
loglikelihood_fn: Callable,
mcmc_factory: Callable,
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
mcmc_parameters: Dict,
resampling_fn: Callable,
mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], ArrayTree],
initial_parameter_value,
Expand All @@ -130,9 +132,8 @@ def __new__( # type: ignore[misc]
smc_algorithm,
logprior_fn,
loglikelihood_fn,
mcmc_factory,
mcmc_step_fn,
mcmc_init_fn,
mcmc_parameters,
resampling_fn,
mcmc_parameter_update_fn,
num_mcmc_steps,
Expand Down
6 changes: 3 additions & 3 deletions blackjax/smc/tempered.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,12 @@ def tempered_logposterior_fn(position: ArrayLikeTree) -> float:
tempered_loglikelihood = state.lmbda * loglikelihood_fn(position)
return logprior + tempered_loglikelihood

def mcmc_kernel(rng_key, position):
def mcmc_kernel(rng_key, position, step_parameters):
state = mcmc_init_fn(position, tempered_logposterior_fn)

def body_fn(state, rng_key):
new_state, info = mcmc_step_fn(
rng_key, state, tempered_logposterior_fn, **mcmc_parameters
rng_key, state, tempered_logposterior_fn, **step_parameters
)
return new_state, info

Expand All @@ -142,7 +142,7 @@ def body_fn(state, rng_key):

smc_state, info = smc.base.step(
rng_key,
SMCState(state.particles, state.weights),
SMCState(state.particles, state.weights, mcmc_parameters),
jax.vmap(mcmc_kernel),
jax.vmap(log_weights_fn),
resampling_fn,
Expand Down
8 changes: 5 additions & 3 deletions tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ def sample_orbit(orbit, weights, rng_key):
return samples


def irmh_proposal_distribution(rng_key):
def irmh_proposal_distribution(rng_key, mean):
"""
The proposal distribution is chosen to be wider than the target, so that the RMH rejection
doesn't make the sample overemphasize the center of the target distribution.
"""
return 1.0 + jax.random.normal(rng_key) * 25.0
return mean + jax.random.normal(rng_key) * 25.0


def rmh_proposal_distribution(rng_key, position):
Expand Down Expand Up @@ -657,7 +657,9 @@ def test_univariate_normal(
self, algorithm, initial_position, parameters, num_sampling_steps, burnin
):
if algorithm == blackjax.irmh:
parameters["proposal_distribution"] = irmh_proposal_distribution
parameters["proposal_distribution"] = functools.partial(
irmh_proposal_distribution, mean=1.0
)

if algorithm == blackjax.rmh:
parameters["proposal_generator"] = rmh_proposal_distribution
Expand Down
89 changes: 55 additions & 34 deletions tests/smc/test_inner_kernel_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import blackjax
import blackjax.smc.resampling as resampling
from blackjax import adaptive_tempered_smc, tempered_smc
from blackjax.mcmc.random_walk import build_irmh
from blackjax.smc import extend_params
from blackjax.smc.inner_kernel_tuning import inner_kernel_tuning
from blackjax.smc.tuning.from_kernel_info import update_scale_from_acceptance_rate
from blackjax.smc.tuning.from_particles import (
Expand Down Expand Up @@ -92,38 +94,37 @@ def smc_inner_kernel_tuning_test_case(
proposal_factory.return_value = 100

def mcmc_parameter_update_fn(state, info):
return 100
return extend_params(1000, {"mean": 100})

mcmc_factory = MagicMock()
sampling_algorithm = MagicMock()
mcmc_factory.return_value = sampling_algorithm
prior = lambda x: stats.norm.logpdf(x)

def kernel_factory(proposal_distribution):
kernel = blackjax.irmh.build_kernel()

def wrapped_kernel(rng_key, state, logdensity):
return kernel(rng_key, state, logdensity, proposal_distribution)

return wrapped_kernel
def wrapped_kernel(rng_key, state, logdensity, mean):
return build_irmh()(
rng_key,
state,
logdensity,
functools.partial(irmh_proposal_distribution, mean=mean),
)

kernel = inner_kernel_tuning(
logprior_fn=prior,
loglikelihood_fn=specialized_log_weights_fn,
mcmc_factory=kernel_factory,
mcmc_step_fn=wrapped_kernel,
mcmc_init_fn=blackjax.irmh.init,
resampling_fn=resampling.systematic,
smc_algorithm=smc_algorithm,
mcmc_parameters={},
mcmc_parameter_update_fn=mcmc_parameter_update_fn,
initial_parameter_value=irmh_proposal_distribution,
initial_parameter_value=extend_params(1000, {"mean": 1.0}),
**smc_parameters,
)

new_state, new_info = kernel.step(
self.key, state=kernel.init(init_particles), **step_parameters
)
assert new_state.parameter_override == 100
assert set(new_state.parameter_override.keys()) == {
"mean",
}
np.testing.assert_allclose(new_state.parameter_override["mean"], 100)


class MeanAndStdFromParticlesTest(chex.TestCase):
Expand Down Expand Up @@ -270,14 +271,6 @@ def setUp(self):
super().setUp()
self.key = jax.random.key(42)

def mcmc_factory(self, mass_matrix):
return functools.partial(
blackjax.hmc.build_kernel(),
inverse_mass_matrix=mass_matrix,
step_size=10e-2,
num_integration_steps=50,
)

@chex.all_variants(with_pmap=False)
def test_with_adaptive_tempered(self):
(
Expand All @@ -286,18 +279,32 @@ def test_with_adaptive_tempered(self):
loglikelihood_fn,
) = self.particles_prior_loglikelihood()

def parameter_update(state, info):
return extend_params(
100,
{
"inverse_mass_matrix": mass_matrix_from_particles(state.particles),
"step_size": 10e-2,
"num_integration_steps": 50,
},
)

init, step = blackjax.inner_kernel_tuning(
adaptive_tempered_smc,
logprior_fn,
loglikelihood_fn,
self.mcmc_factory,
blackjax.hmc.build_kernel(),
blackjax.hmc.init,
{},
resampling.systematic,
mcmc_parameter_update_fn=lambda state, info: mass_matrix_from_particles(
state.particles
mcmc_parameter_update_fn=parameter_update,
initial_parameter_value=extend_params(
100,
dict(
inverse_mass_matrix=jnp.eye(2),
step_size=10e-2,
num_integration_steps=50,
),
),
initial_parameter_value=jnp.eye(2),
num_mcmc_steps=10,
target_ess=0.5,
)
Expand All @@ -319,7 +326,7 @@ def body(carry):

state, _ = inference_loop(smc_kernel, self.key, init_state)

assert state.parameter_override.shape == (2, 2)
assert state.parameter_override["inverse_mass_matrix"].shape == (100, 2, 2)
self.assert_linear_regression_test_case(state.sampler_state)

@chex.all_variants(with_pmap=False)
Expand All @@ -331,18 +338,32 @@ def test_with_tempered_smc(self):
loglikelihood_fn,
) = self.particles_prior_loglikelihood()

def parameter_update(state, info):
return extend_params(
100,
{
"inverse_mass_matrix": mass_matrix_from_particles(state.particles),
"step_size": 10e-2,
"num_integration_steps": 50,
},
)

init, step = blackjax.inner_kernel_tuning(
tempered_smc,
logprior_fn,
loglikelihood_fn,
self.mcmc_factory,
blackjax.hmc.build_kernel(),
blackjax.hmc.init,
{},
resampling.systematic,
mcmc_parameter_update_fn=lambda state, info: mass_matrix_from_particles(
state.particles
mcmc_parameter_update_fn=parameter_update,
initial_parameter_value=extend_params(
100,
dict(
inverse_mass_matrix=jnp.eye(2),
step_size=10e-2,
num_integration_steps=50,
),
),
initial_parameter_value=jnp.eye(2),
num_mcmc_steps=10,
)

Expand Down
Loading

0 comments on commit 3dc3809

Please sign in to comment.