diff --git a/blackjax/smc/adaptive_tempered.py b/blackjax/smc/adaptive_tempered.py index b8a611606..10fb194fa 100644 --- a/blackjax/smc/adaptive_tempered.py +++ b/blackjax/smc/adaptive_tempered.py @@ -130,7 +130,8 @@ def as_top_level_api( mcmc_init_fn The MCMC init function used to build a MCMC state from a particle position. mcmc_parameters - The parameters of the MCMC step function. + The parameters of the MCMC step function. Parameters with leading dimension + length of 1 are shared amongst the particles. resampling_fn The function used to resample the particles. target_ess diff --git a/blackjax/smc/base.py b/blackjax/smc/base.py index 21d8e12f4..5093cf06b 100644 --- a/blackjax/smc/base.py +++ b/blackjax/smc/base.py @@ -150,12 +150,9 @@ def step( ) -def extend_params(n_particles, params): +def extend_params(params): """Given a dictionary of params, repeats them for every single particle. The expected usage is in cases where the aim is to repeat the same parameters for all chains within SMC. """ - def extend(param): - return jnp.repeat(jnp.asarray(param)[None, ...], n_particles, axis=0) - - return jax.tree.map(extend, params) + return jax.tree.map(lambda x: jnp.asarray(x)[None, ...], params) diff --git a/blackjax/smc/tempered.py b/blackjax/smc/tempered.py index b373d062f..43b83d034 100644 --- a/blackjax/smc/tempered.py +++ b/blackjax/smc/tempered.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial from typing import Callable, NamedTuple import jax @@ -108,6 +109,9 @@ def kernel( Current state of the tempered SMC algorithm lmbda Current value of the tempering parameter + mcmc_parameters + The parameters of the MCMC step function. Parameters with leading dimension + length of 1 are shared amongst the particles. Returns ------- @@ -119,6 +123,14 @@ def kernel( """ delta = lmbda - state.lmbda + shared_mcmc_parameters = {} + unshared_mcmc_parameters = {} + for k, v in mcmc_parameters.items(): + if v.shape[0] == 1: + shared_mcmc_parameters[k] = v[0, ...] + else: + unshared_mcmc_parameters[k] = v + def log_weights_fn(position: ArrayLikeTree) -> float: return delta * loglikelihood_fn(position) @@ -127,11 +139,13 @@ def tempered_logposterior_fn(position: ArrayLikeTree) -> float: tempered_loglikelihood = state.lmbda * loglikelihood_fn(position) return logprior + tempered_loglikelihood + shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters) + def mcmc_kernel(rng_key, position, step_parameters): state = mcmc_init_fn(position, tempered_logposterior_fn) def body_fn(state, rng_key): - new_state, info = mcmc_step_fn( + new_state, info = shared_mcmc_step_fn( rng_key, state, tempered_logposterior_fn, **step_parameters ) return new_state, info @@ -142,7 +156,7 @@ def body_fn(state, rng_key): smc_state, info = smc.base.step( rng_key, - SMCState(state.particles, state.weights, mcmc_parameters), + SMCState(state.particles, state.weights, unshared_mcmc_parameters), jax.vmap(mcmc_kernel), jax.vmap(log_weights_fn), resampling_fn, @@ -178,7 +192,8 @@ def as_top_level_api( mcmc_init_fn The MCMC init function used to build a MCMC state from a particle position. mcmc_parameters - The parameters of the MCMC step function. + The parameters of the MCMC step function. Parameters with leading dimension + length of 1 are shared amongst the particles. resampling_fn The function used to resample the particles. num_mcmc_steps diff --git a/tests/smc/test_inner_kernel_tuning.py b/tests/smc/test_inner_kernel_tuning.py index bf970ae47..7d6190af5 100644 --- a/tests/smc/test_inner_kernel_tuning.py +++ b/tests/smc/test_inner_kernel_tuning.py @@ -94,7 +94,7 @@ def smc_inner_kernel_tuning_test_case( proposal_factory.return_value = 100 def mcmc_parameter_update_fn(state, info): - return extend_params(1000, {"mean": 100}) + return extend_params({"mean": 100}) prior = lambda x: stats.norm.logpdf(x) @@ -114,7 +114,7 @@ def wrapped_kernel(rng_key, state, logdensity, mean): resampling_fn=resampling.systematic, smc_algorithm=smc_algorithm, mcmc_parameter_update_fn=mcmc_parameter_update_fn, - initial_parameter_value=extend_params(1000, {"mean": 1.0}), + initial_parameter_value=extend_params({"mean": 1.0}), **smc_parameters, ) @@ -281,7 +281,6 @@ def test_with_adaptive_tempered(self): def parameter_update(state, info): return extend_params( - 100, { "inverse_mass_matrix": mass_matrix_from_particles(state.particles), "step_size": 10e-2, @@ -298,7 +297,6 @@ def parameter_update(state, info): resampling.systematic, mcmc_parameter_update_fn=parameter_update, initial_parameter_value=extend_params( - 100, dict( inverse_mass_matrix=jnp.eye(2), step_size=10e-2, @@ -326,7 +324,7 @@ def body(carry): _, state = inference_loop(smc_kernel, self.key, init_state) - assert state.parameter_override["inverse_mass_matrix"].shape == (100, 2, 2) + assert state.parameter_override["inverse_mass_matrix"].shape == (1, 2, 2) self.assert_linear_regression_test_case(state.sampler_state) @chex.all_variants(with_pmap=False) @@ -340,7 +338,6 @@ def test_with_tempered_smc(self): def parameter_update(state, info): return extend_params( - 100, { "inverse_mass_matrix": mass_matrix_from_particles(state.particles), "step_size": 10e-2, @@ -357,7 +354,6 @@ def parameter_update(state, info): resampling.systematic, mcmc_parameter_update_fn=parameter_update, initial_parameter_value=extend_params( - 100, dict( inverse_mass_matrix=jnp.eye(2), step_size=10e-2, diff --git a/tests/smc/test_kernel_compatibility.py b/tests/smc/test_kernel_compatibility.py index 3e675c2cc..fdda30b3a 100644 --- a/tests/smc/test_kernel_compatibility.py +++ b/tests/smc/test_kernel_compatibility.py @@ -50,7 +50,7 @@ def kernel(rng_key, state, logdensity_fn, proposal_mean): self.check_compatible( kernel, blackjax.additive_step_random_walk.init, - extend_params(self.n_particles, {"proposal_mean": 1.0}), + extend_params({"proposal_mean": 1.0}), ) def test_compatible_with_rmh(self): @@ -70,7 +70,7 @@ def kernel( self.check_compatible( kernel, blackjax.rmh.init, - extend_params(self.n_particles, {"proposal_mean": 1.0}), + extend_params({"proposal_mean": 1.0}), ) def test_compatible_with_hmc(self): @@ -78,7 +78,6 @@ def test_compatible_with_hmc(self): blackjax.hmc.build_kernel(), blackjax.hmc.init, extend_params( - self.n_particles, { "step_size": 0.3, "inverse_mass_matrix": jnp.array([1.0]), @@ -100,7 +99,7 @@ def kernel(rng_key, state, logdensity_fn, mean, proposal_logdensity_fn=None): self.check_compatible( kernel, blackjax.irmh.init, - extend_params(self.n_particles, {"mean": jnp.array([1.0, 1.0])}), + extend_params({"mean": jnp.array([1.0, 1.0])}), ) def test_compatible_with_nuts(self): @@ -108,7 +107,6 @@ def test_compatible_with_nuts(self): blackjax.nuts.build_kernel(), blackjax.nuts.init, extend_params( - self.n_particles, {"step_size": 1e-10, "inverse_mass_matrix": jnp.eye(2)}, ), ) @@ -117,7 +115,7 @@ def test_compatible_with_mala(self): self.check_compatible( blackjax.mala.build_kernel(), blackjax.mala.init, - extend_params(self.n_particles, {"step_size": 1e-10}), + extend_params({"step_size": 1e-10}), ) @staticmethod diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py index 2838e984f..6366182a8 100644 --- a/tests/smc/test_smc.py +++ b/tests/smc/test_smc.py @@ -50,16 +50,17 @@ def body_fn(state, rng_key): same_for_all_params = dict( step_size=1e-2, inverse_mass_matrix=jnp.eye(1), num_integration_steps=50 ) + state = init( init_particles, - extend_params(num_particles, same_for_all_params), + same_for_all_params, ) # Run the SMC sampler once new_state, info = self.variant(step, static_argnums=(2, 3, 4))( sample_key, state, - jax.vmap(update_fn), + jax.vmap(update_fn, in_axes=(0, 0, None)), jax.vmap(logdensity_fn), resampling.systematic, ) @@ -87,7 +88,9 @@ def body_fn(state, rng_key): _, (states, info) = jax.lax.scan(body_fn, state, keys) return states.position, info - particles, info = jax.vmap(one_particle_fn)(keys, particles, update_params) + particles, info = jax.vmap(one_particle_fn, in_axes=(0, 0, None))( + keys, particles, update_params + ) particles = particles.reshape((num_particles,)) return particles, info @@ -97,13 +100,10 @@ def body_fn(state, rng_key): init_particles = 0.25 + jax.random.normal(init_key, shape=(num_particles,)) state = init( init_particles, - extend_params( - num_resampled, - dict( - step_size=1e-2, - inverse_mass_matrix=jnp.eye(1), - num_integration_steps=100, - ), + dict( + step_size=1e-2, + inverse_mass_matrix=jnp.eye(1), + num_integration_steps=100, ), ) @@ -125,7 +125,6 @@ def body_fn(state, rng_key): class ExtendParamsTest(chex.TestCase): def test_extend_params(self): extended = extend_params( - 3, { "a": 50, "b": np.array([50]), @@ -133,14 +132,12 @@ def test_extend_params(self): "d": np.array([[1, 2], [3, 4]]), }, ) - np.testing.assert_allclose(extended["a"], np.ones((3,)) * 50) - np.testing.assert_allclose(extended["b"], np.array([[50], [50], [50]])) - np.testing.assert_allclose( - extended["c"], np.array([[50, 60], [50, 60], [50, 60]]) - ) + np.testing.assert_allclose(extended["a"], np.ones((1,)) * 50) + np.testing.assert_allclose(extended["b"], np.array([[50]])) + np.testing.assert_allclose(extended["c"], np.array([[50, 60]])) np.testing.assert_allclose( extended["d"], - np.array([[[1, 2], [3, 4]], [[1, 2], [3, 4]], [[1, 2], [3, 4]]]), + np.array([[[1, 2], [3, 4]]]), ) diff --git a/tests/smc/test_tempered_smc.py b/tests/smc/test_tempered_smc.py index a7d9acdd8..527457d62 100644 --- a/tests/smc/test_tempered_smc.py +++ b/tests/smc/test_tempered_smc.py @@ -65,16 +65,28 @@ def logprior_fn(x): hmc_kernel = blackjax.hmc.build_kernel() hmc_init = blackjax.hmc.init - hmc_parameters = extend_params( - num_particles, + + base_params = extend_params( { "step_size": 10e-2, "inverse_mass_matrix": jnp.eye(2), "num_integration_steps": 50, - }, + } ) - for target_ess in [0.5, 0.75]: + # verify results are equivalent with all shared, all unshared, and mixed params + hmc_parameters_list = [ + base_params, + jax.tree.map(lambda x: jnp.repeat(x, num_particles, axis=0), base_params), + jax.tree_util.tree_map_with_path( + lambda path, x: jnp.repeat(x, num_particles, axis=0) + if path[0].key == "step_size" + else x, + base_params, + ), + ] + + for target_ess, hmc_parameters in zip([0.5, 0.5, 0.75], hmc_parameters_list): tempering = adaptive_tempered_smc( logprior_fn, loglikelihood_fn, @@ -115,7 +127,6 @@ def test_fixed_schedule_tempered_smc(self): hmc_init = blackjax.hmc.init hmc_kernel = blackjax.hmc.build_kernel() hmc_parameters = extend_params( - 100, { "step_size": 10e-2, "inverse_mass_matrix": jnp.eye(2), @@ -182,7 +193,6 @@ def test_normalizing_constant(self): hmc_init = blackjax.hmc.init hmc_kernel = blackjax.hmc.build_kernel() hmc_parameters = extend_params( - num_particles, { "step_size": 10e-2, "inverse_mass_matrix": jnp.eye(num_dim),