Skip to content

Commit

Permalink
Merge branch 'main' into barker-inverse-mm2
Browse files Browse the repository at this point in the history
  • Loading branch information
junpenglao authored Oct 5, 2024
2 parents ef4b434 + 5a25352 commit 6d9c02d
Show file tree
Hide file tree
Showing 8 changed files with 350 additions and 59 deletions.
4 changes: 3 additions & 1 deletion blackjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from .sgmcmc import sgnht as _sgnht
from .smc import adaptive_tempered
from .smc import inner_kernel_tuning as _inner_kernel_tuning
from .smc import partial_posteriors_path as _partial_posteriors_smc
from .smc import tempered
from .vi import meanfield_vi as _meanfield_vi
from .vi import pathfinder as _pathfinder
Expand Down Expand Up @@ -119,8 +120,9 @@ def generate_top_level_api_from(module):
adaptive_tempered_smc = generate_top_level_api_from(adaptive_tempered)
tempered_smc = generate_top_level_api_from(tempered)
inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning)
partial_posteriors_smc = generate_top_level_api_from(_partial_posteriors_smc)

smc_family = [tempered_smc, adaptive_tempered_smc]
smc_family = [tempered_smc, adaptive_tempered_smc, partial_posteriors_smc]
"Step_fn returning state has a .particles attribute"

# stochastic gradient mcmc
Expand Down
1 change: 1 addition & 0 deletions blackjax/smc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
"tempered",
"inner_kernel_tuning",
"extend_params",
"partial_posteriors_path",
]
28 changes: 28 additions & 0 deletions blackjax/smc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,31 @@ def extend_params(params):
"""

return jax.tree.map(lambda x: jnp.asarray(x)[None, ...], params)


def update_and_take_last(
mcmc_init_fn,
tempered_logposterior_fn,
shared_mcmc_step_fn,
num_mcmc_steps,
n_particles,
):
"""Given N particles, runs num_mcmc_steps of a kernel starting at each particle, and
returns the last values, waisting the previous num_mcmc_steps-1
samples per chain.
"""

def mcmc_kernel(rng_key, position, step_parameters):
state = mcmc_init_fn(position, tempered_logposterior_fn)

def body_fn(state, rng_key):
new_state, info = shared_mcmc_step_fn(
rng_key, state, tempered_logposterior_fn, **step_parameters
)
return new_state, info

keys = jax.random.split(rng_key, num_mcmc_steps)
last_state, info = jax.lax.scan(body_fn, state, keys)
return last_state.position, info

return jax.vmap(mcmc_kernel), n_particles
64 changes: 64 additions & 0 deletions blackjax/smc/from_mcmc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from functools import partial
from typing import Callable

import jax

from blackjax import smc
from blackjax.smc.base import SMCState, update_and_take_last
from blackjax.types import PRNGKey


def build_kernel(
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
resampling_fn: Callable,
update_strategy: Callable = update_and_take_last,
):
"""SMC step from MCMC kernels.
Builds MCMC kernels from the input parameters, which may change across iterations.
Moreover, it defines the way such kernels are used to update the particles. This layer
adapts an API defined in terms of kernels (mcmc_step_fn and mcmc_init_fn) into an API
that depends on an update function over the set of particles.
Returns
-------
A callable that takes a rng_key and a state with .particles and .weights and returns a base.SMCState
and base.SMCInfo pair.
"""

def step(
rng_key: PRNGKey,
state,
num_mcmc_steps: int,
mcmc_parameters: dict,
logposterior_fn: Callable,
log_weights_fn: Callable,
) -> tuple[smc.base.SMCState, smc.base.SMCInfo]:
shared_mcmc_parameters = {}
unshared_mcmc_parameters = {}
for k, v in mcmc_parameters.items():
if v.shape[0] == 1:
shared_mcmc_parameters[k] = v[0, ...]
else:
unshared_mcmc_parameters[k] = v

shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters)

update_fn, num_resampled = update_strategy(
mcmc_init_fn,
logposterior_fn,
shared_mcmc_step_fn,
n_particles=state.weights.shape[0],
num_mcmc_steps=num_mcmc_steps,
)

return smc.base.step(
rng_key,
SMCState(state.particles, state.weights, unshared_mcmc_parameters),
update_fn,
jax.vmap(log_weights_fn),
resampling_fn,
num_resampled,
)

return step
127 changes: 127 additions & 0 deletions blackjax/smc/partial_posteriors_path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from typing import Callable, NamedTuple, Optional, Tuple

import jax
import jax.numpy as jnp

from blackjax import SamplingAlgorithm, smc
from blackjax.smc.base import update_and_take_last
from blackjax.smc.from_mcmc import build_kernel as smc_from_mcmc
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey


class PartialPosteriorsSMCState(NamedTuple):
"""Current state for the tempered SMC algorithm.
particles: PyTree
The particles' positions.
weights:
Weights of the particles, so that they represent a probability distribution
data_mask:
A 1D boolean array to indicate which datapoints to include
in the computation of the observed likelihood.
"""

particles: ArrayTree
weights: Array
data_mask: Array


def init(particles: ArrayLikeTree, num_datapoints: int) -> PartialPosteriorsSMCState:
"""num_datapoints are the number of observations that could potentially be
used in a partial posterior. Since the initial data_mask is all 0s, it
means that no likelihood term will be added (only prior).
"""
num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0]
weights = jnp.ones(num_particles) / num_particles
return PartialPosteriorsSMCState(particles, weights, jnp.zeros(num_datapoints))


def build_kernel(
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
resampling_fn: Callable,
num_mcmc_steps: Optional[int],
mcmc_parameters: ArrayTree,
partial_logposterior_factory: Callable[[Array], Callable],
update_strategy=update_and_take_last,
) -> Callable:
"""Build the Partial Posteriors (data tempering) SMC kernel.
The distribution's trajectory includes increasingly adding more
datapoints to the likelihood. See Section 2.2 of https://arxiv.org/pdf/2007.11936
Parameters
----------
mcmc_step_fn
A function that computes the log density of the prior distribution
mcmc_init_fn
A function that returns the probability at a given position.
resampling_fn
A random function that resamples generated particles based of weights
num_mcmc_steps
Number of iterations in the MCMC chain.
mcmc_parameters
A dictionary of parameters to be used by the inner MCMC kernels
partial_logposterior_factory:
A callable that given an array of 0 and 1, returns a function logposterior(x).
The array represents which values to include in the logposterior calculation. The logposterior
must be jax compilable.
Returns
-------
A callable that takes a rng_key and PartialPosteriorsSMCState and selectors for
the current and previous posteriors, and takes a data-tempered SMC state.
"""
delegate = smc_from_mcmc(mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy)

def step(
key, state: PartialPosteriorsSMCState, data_mask: Array
) -> Tuple[PartialPosteriorsSMCState, smc.base.SMCInfo]:
logposterior_fn = partial_logposterior_factory(data_mask)

previous_logposterior_fn = partial_logposterior_factory(state.data_mask)

def log_weights_fn(x):
return logposterior_fn(x) - previous_logposterior_fn(x)

state, info = delegate(
key, state, num_mcmc_steps, mcmc_parameters, logposterior_fn, log_weights_fn
)

return (
PartialPosteriorsSMCState(state.particles, state.weights, data_mask),
info,
)

return step


def as_top_level_api(
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
mcmc_parameters: dict,
resampling_fn: Callable,
num_mcmc_steps,
partial_logposterior_factory: Callable,
update_strategy=update_and_take_last,
) -> SamplingAlgorithm:
"""A factory that wraps the kernel into a SamplingAlgorithm object.
See build_kernel for full documentation on the parameters.
"""

kernel = build_kernel(
mcmc_step_fn,
mcmc_init_fn,
resampling_fn,
num_mcmc_steps,
mcmc_parameters,
partial_logposterior_factory,
update_strategy,
)

def init_fn(position: ArrayLikeTree, num_observations, rng_key=None):
del rng_key
return init(position, num_observations)

def step(key: PRNGKey, state: PartialPosteriorsSMCState, data_mask: Array):
return kernel(key, state, data_mask)

return SamplingAlgorithm(init_fn, step) # type: ignore[arg-type]
66 changes: 11 additions & 55 deletions blackjax/smc/tempered.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, NamedTuple, Optional

import jax
import jax.numpy as jnp

import blackjax.smc as smc
import blackjax.smc.from_mcmc as smc_from_mcmc
from blackjax.base import SamplingAlgorithm
from blackjax.smc.base import SMCState
from blackjax.smc.base import update_and_take_last
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey

__all__ = ["TemperedSMCState", "init", "build_kernel", "as_top_level_api"]
Expand Down Expand Up @@ -48,35 +48,6 @@ def init(particles: ArrayLikeTree):
return TemperedSMCState(particles, weights, 0.0)


def update_and_take_last(
mcmc_init_fn,
tempered_logposterior_fn,
shared_mcmc_step_fn,
num_mcmc_steps,
n_particles,
):
"""
Given N particles, runs num_mcmc_steps of a kernel starting at each particle, and
returns the last values, waisting the previous num_mcmc_steps-1
samples per chain.
"""

def mcmc_kernel(rng_key, position, step_parameters):
state = mcmc_init_fn(position, tempered_logposterior_fn)

def body_fn(state, rng_key):
new_state, info = shared_mcmc_step_fn(
rng_key, state, tempered_logposterior_fn, **step_parameters
)
return new_state, info

keys = jax.random.split(rng_key, num_mcmc_steps)
last_state, info = jax.lax.scan(body_fn, state, keys)
return last_state.position, info

return jax.vmap(mcmc_kernel), n_particles


def build_kernel(
logprior_fn: Callable,
loglikelihood_fn: Callable,
Expand Down Expand Up @@ -121,6 +92,9 @@ def build_kernel(
information about the transition.
"""
delegate = smc_from_mcmc.build_kernel(
mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy
)

def kernel(
rng_key: PRNGKey,
Expand Down Expand Up @@ -153,14 +127,6 @@ def kernel(
"""
delta = lmbda - state.lmbda

shared_mcmc_parameters = {}
unshared_mcmc_parameters = {}
for k, v in mcmc_parameters.items():
if v.shape[0] == 1:
shared_mcmc_parameters[k] = v[0, ...]
else:
unshared_mcmc_parameters[k] = v

def log_weights_fn(position: ArrayLikeTree) -> float:
return delta * loglikelihood_fn(position)

Expand All @@ -169,23 +135,13 @@ def tempered_logposterior_fn(position: ArrayLikeTree) -> float:
tempered_loglikelihood = state.lmbda * loglikelihood_fn(position)
return logprior + tempered_loglikelihood

shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters)

update_fn, num_resampled = update_strategy(
mcmc_init_fn,
tempered_logposterior_fn,
shared_mcmc_step_fn,
n_particles=state.weights.shape[0],
num_mcmc_steps=num_mcmc_steps,
)

smc_state, info = smc.base.step(
smc_state, info = delegate(
rng_key,
SMCState(state.particles, state.weights, unshared_mcmc_parameters),
update_fn,
jax.vmap(log_weights_fn),
resampling_fn,
num_resampled,
state,
num_mcmc_steps,
mcmc_parameters,
tempered_logposterior_fn,
log_weights_fn,
)

tempered_state = TemperedSMCState(
Expand Down
Loading

0 comments on commit 6d9c02d

Please sign in to comment.