Refactor of Sequential Monte Carlo internals (#5281)
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
ciguaran and ricardoV94 authored Dec 24, 2021
1 parent 29720d0 commit e6fc2ec
Showing 3 changed files with 100 additions and 50 deletions.
123 changes: 80 additions & 43 deletions pymc/smc/sample_smc.py
@@ -222,52 +222,54 @@ def sample_smc(
     )
 
     t1 = time.time()
 
     if cores > 1:
-        pbar = progress_bar((), total=100, display=progressbar)
-        pbar.update(0)
-        pbars = [pbar] + [None] * (chains - 1)
-
-        pool = mp.Pool(cores)
-
-        # "manually" (de)serialize params before/after multiprocessing
-        params = tuple(cloudpickle.dumps(p) for p in params)
-        kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}
-        results = _starmap_with_kwargs(
-            pool,
-            _sample_smc_int,
-            [(*params, random_seed[chain], chain, pbars[chain]) for chain in range(chains)],
-            repeat(kernel_kwargs),
+        results = run_chains_parallel(
+            chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs, cores
         )
-        results = tuple(cloudpickle.loads(r) for r in results)
-        pool.close()
-        pool.join()
-
     else:
-        results = []
-        pbar = progress_bar((), total=100 * chains, display=progressbar)
-        pbar.update(0)
-        for chain in range(chains):
-            pbar.offset = 100 * chain
-            pbar.base_comment = f"Chain: {chain+1}/{chains}"
-            results.append(
-                _sample_smc_int(*params, random_seed[chain], chain, pbar, **kernel_kwargs)
-            )
-
+        results = run_chains_sequential(
+            chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs
+        )
     (
         traces,
         sample_stats,
         sample_settings,
     ) = zip(*results)
 
     trace = MultiTrace(traces)
-    idata = None
 
-    # Save sample_stats
     _t_sampling = time.time() - t1
+    sample_stats, idata = _save_sample_stats(
+        sample_settings,
+        sample_stats,
+        chains,
+        trace,
+        return_inferencedata,
+        _t_sampling,
+        idata_kwargs,
+        model,
+    )
+
+    if compute_convergence_checks:
+        _compute_convergence_checks(idata, draws, model, trace)
+    return idata if return_inferencedata else trace
+
+
+def _save_sample_stats(
+    sample_settings,
+    sample_stats,
+    chains,
+    trace,
+    return_inferencedata,
+    _t_sampling,
+    idata_kwargs,
+    model,
+):
     sample_settings_dict = sample_settings[0]
     sample_settings_dict["_t_sampling"] = _t_sampling
 
     sample_stats_dict = sample_stats[0]
 
     if chains > 1:
         # Collect the stat values from each chain in a single list
         for stat in sample_stats[0].keys():
@@ -281,6 +283,7 @@ def sample_smc(
             setattr(trace.report, stat, value)
         for stat, value in sample_settings_dict.items():
             setattr(trace.report, stat, value)
+        idata = None
     else:
         for stat, value in sample_stats_dict.items():
             if chains > 1:
@@ -303,19 +306,20 @@ def sample_smc(
         idata = to_inference_data(trace, **ikwargs)
         idata = InferenceData(**idata, sample_stats=sample_stats)
 
-    if compute_convergence_checks:
-        if draws < 100:
-            warnings.warn(
-                "The number of samples is too small to check convergence reliably.",
-                stacklevel=2,
-            )
-        else:
-            if idata is None:
-                idata = to_inference_data(trace, log_likelihood=False)
-            trace.report._run_convergence_checks(idata, model)
-        trace.report._log_summary()
+    return sample_stats, idata
 
-    return idata if return_inferencedata else trace
 
+def _compute_convergence_checks(idata, draws, model, trace):
+    if draws < 100:
+        warnings.warn(
+            "The number of samples is too small to check convergence reliably.",
+            stacklevel=2,
+        )
+    else:
+        if idata is None:
+            idata = to_inference_data(trace, log_likelihood=False)
+        trace.report._run_convergence_checks(idata, model)
+    trace.report._log_summary()

def _sample_smc_int(
@@ -391,6 +395,39 @@ def _sample_smc_int(
     return results
 
 
+def run_chains_parallel(chains, progressbar, to_run, params, random_seed, kernel_kwargs, cores):
+    pbar = progress_bar((), total=100, display=progressbar)
+    pbar.update(0)
+    pbars = [pbar] + [None] * (chains - 1)
+
+    pool = mp.Pool(cores)
+
+    # "manually" (de)serialize params before/after multiprocessing
+    params = tuple(cloudpickle.dumps(p) for p in params)
+    kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}
+    results = _starmap_with_kwargs(
+        pool,
+        to_run,
+        [(*params, random_seed[chain], chain, pbars[chain]) for chain in range(chains)],
+        repeat(kernel_kwargs),
+    )
+    results = tuple(cloudpickle.loads(r) for r in results)
+    pool.close()
+    pool.join()
+    return results
+
+
+def run_chains_sequential(chains, progressbar, to_run, params, random_seed, kernel_kwargs):
+    results = []
+    pbar = progress_bar((), total=100 * chains, display=progressbar)
+    pbar.update(0)
+    for chain in range(chains):
+        pbar.offset = 100 * chain
+        pbar.base_comment = f"Chain: {chain + 1}/{chains}"
+        results.append(to_run(*params, random_seed[chain], chain, pbar, **kernel_kwargs))
+    return results
+
+
 def _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter):
     # Helper function to allow kwargs with Pool.starmap
     # Copied from https://stackoverflow.com/a/53173433/13311693
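The helper's body is cut off by the diff view here. For context, the Stack Overflow answer cited above implements the pattern roughly as follows (a sketch, not verbatim diff content):

from itertools import repeat


def apply_args_and_kwargs(fn, args, kwargs):
    return fn(*args, **kwargs)


def _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter):
    # Pool.starmap only unpacks positional arguments, so pair each args
    # tuple with its kwargs dict and let a shim unpack both.
    args_for_starmap = zip(repeat(fn), args_iter, kwargs_iter)
    return pool.starmap(apply_args_and_kwargs, args_for_starmap)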
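Aside on why run_chains_parallel (de)serializes "manually": multiprocessing's default pickler cannot handle some of the callables and model objects PyMC ships between processes, while cloudpickle can, so everything crosses the process boundary as cloudpickle bytes. A minimal, self-contained sketch of the same round-trip pattern (illustrative names, not part of the commit):

import multiprocessing as mp

import cloudpickle


def run_payload(payload):
    # Rebuild the function and its arguments inside the worker, then
    # pickle the result for the trip back to the parent process.
    fn, args = cloudpickle.loads(payload)
    return cloudpickle.dumps(fn(*args))


if __name__ == "__main__":
    # Lambdas are exactly what the stdlib pickler rejects; cloudpickle copes.
    jobs = [cloudpickle.dumps((lambda x: x * x, (i,))) for i in range(4)]
    with mp.Pool(2) as pool:
        out = [cloudpickle.loads(r) for r in pool.map(run_payload, jobs)]
    print(out)  # [0, 1, 4, 9]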
13 changes: 9 additions & 4 deletions pymc/smc/smc.py
@@ -151,6 +151,8 @@ def __init__(
 
         self.draws = draws
         self.start = start
+        if threshold < 0 or threshold > 1:
+            raise ValueError(f"Threshold value {threshold} must be between 0 and 1")
         self.threshold = threshold
         self.model = model
         self.rng = np.random.default_rng(seed=random_seed)
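A quick illustration of the new guard (hypothetical usage; the kernel class name and call signature are assumed from context, not shown in the diff):

import pymc as pm
from pymc.smc.smc import IMH

with pm.Model() as model:
    pm.Normal("x", 0.0, 1.0)

# An out-of-range threshold now fails fast at kernel construction:
IMH(model=model, threshold=1.5)
# ValueError: Threshold value 1.5 must be between 0 and 1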
@@ -192,7 +194,6 @@ def _initialize_kernel(self):
         initial_point = self.model.recompute_initial_point(seed=self.rng.integers(2 ** 30))
         for v in self.variables:
             self.var_info[v.name] = (initial_point[v.name].shape, initial_point[v.name].size)
-
         # Create particles bijection map
         if self.start:
             init_rnd = self.start
@@ -203,6 +204,7 @@
         for i in range(self.draws):
             point = Point({v.name: init_rnd[v.name][i] for v in self.variables}, model=self.model)
             population.append(DictToArrayBijection.map(point).data)
+
         self.tempered_posterior = np.array(floatX(population))
 
         # Initialize prior and likelihood log probabilities
@@ -228,13 +230,16 @@ def setup_kernel(self):
     def update_beta_and_weights(self):
         """Calculate the next inverse temperature (beta)
 
-        The importance weights based on two sucesive tempered likelihoods (i.e.
+        The importance weights based on two successive tempered likelihoods (i.e.
         two successive values of beta) and updates the marginal likelihood estimate.
+
+        ESS is calculated for importance sampling. BDA 3rd ed. eq 10.4
         """
         self.iteration += 1
 
         low_beta = old_beta = self.beta
         up_beta = 2.0
 
         rN = int(len(self.likelihood_logp) * self.threshold)
 
         while up_beta - low_beta > 1e-6:
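The loop body is elided above. What it performs is the standard adaptive-beta bisection: find the largest temperature increase whose incremental importance weights still keep the effective sample size at `threshold * draws`. A standalone sketch under those assumptions (not the verbatim loop):

import numpy as np
from scipy.special import logsumexp


def next_beta(likelihood_logp, old_beta, threshold, tol=1e-6):
    rN = int(len(likelihood_logp) * threshold)
    low_beta, up_beta = old_beta, 2.0
    while up_beta - low_beta > tol:
        new_beta = (low_beta + up_beta) / 2.0
        log_weights_un = (new_beta - old_beta) * likelihood_logp
        log_weights = log_weights_un - logsumexp(log_weights_un)
        ESS = int(np.exp(-logsumexp(2 * log_weights)))  # BDA 3rd ed., eq. 10.4
        if ESS == rN:
            break
        elif ESS < rN:
            up_beta = new_beta  # too aggressive: weights degenerate
        else:
            low_beta = new_beta  # room to heat up further
    return min(new_beta, 1.0)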
@@ -268,6 +273,7 @@ def resample(self):
         self.tempered_posterior = self.tempered_posterior[self.resampling_indexes]
         self.prior_logp = self.prior_logp[self.resampling_indexes]
         self.likelihood_logp = self.likelihood_logp[self.resampling_indexes]
+
         self.tempered_posterior_logp = self.prior_logp + self.likelihood_logp * self.beta
 
     def tune(self):
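Aside on the `resample` hunk above: `resampling_indexes` is typically drawn in proportion to the normalized importance weights, e.g. multinomial resampling (a sketch; variable names assumed):

import numpy as np

rng = np.random.default_rng(1)
draws = 5
weights = np.array([0.05, 0.05, 0.1, 0.3, 0.5])  # normalized importance weights

# High-weight particles get duplicated, low-weight ones dropped.
resampling_indexes = rng.choice(np.arange(draws), size=draws, p=weights)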
@@ -303,7 +309,7 @@ def sample_settings(self) -> Dict:
     def _posterior_to_trace(self, chain=0) -> NDArray:
         """Save results into a PyMC trace
 
-        This method shoud not be overwritten.
+        This method should not be overwritten.
         """
         lenght_pos = len(self.tempered_posterior)
         varnames = [v.name for v in self.variables]
@@ -497,7 +503,6 @@ def tune(self):
     def mutate(self):
         """Metropolis-Hastings perturbation."""
         ac_ = np.empty((self.n_steps, self.draws))
-
         log_R = np.log(self.rng.random((self.n_steps, self.draws)))
         for n_step in range(self.n_steps):
             proposal = floatX(
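On the `mutate` hunk above: drawing all the log-uniforms (`log_R`) up front lets each step vectorize the Metropolis-Hastings accept test over every particle at once, roughly as follows (a sketch with assumed names standing in for the real log densities):

import numpy as np

rng = np.random.default_rng(0)
n_steps, draws = 3, 1000
log_R = np.log(rng.random((n_steps, draws)))

# Inside step `n_step`, with per-particle log posterior densities:
proposal_logp = rng.normal(size=draws)  # stand-in values
current_logp = rng.normal(size=draws)   # stand-in values
accepted = log_R[0] < (proposal_logp - current_logp)  # boolean mask per particle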
14 changes: 11 additions & 3 deletions pymc/tests/test_smc.py
@@ -42,7 +42,7 @@ def setup_class(self):
         super().setup_class()
         self.samples = 1000
         n = 4
-        mu1 = np.ones(n) * (1.0 / 2)
+        mu1 = np.ones(n) * 0.5
         mu2 = -mu1
 
         stdev = 0.1
@@ -54,6 +54,9 @@ def setup_class(self):
         w2 = 1 - stdev
 
         def two_gaussians(x):
+            """
+            Mixture of gaussians likelihood
+            """
             log_like1 = (
                 -0.5 * n * at.log(2 * np.pi)
                 - 0.5 * at.log(dsigma)
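The hunk cuts off mid-function; presumably the elided remainder mirrors `log_like1` for the second component and mixes the two with the weights, along these lines (a reconstruction for readability, not diff content; `at` is aesara.tensor in this test module):

    log_like2 = (
        -0.5 * n * at.log(2 * np.pi)
        - 0.5 * at.log(dsigma)
        - 0.5 * (x - mu2).T.dot(isigma).dot(x - mu2)
    )
    return at.log(w1 * at.exp(log_like1) + w2 * at.exp(log_like2))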
@@ -80,8 +83,9 @@ def test_sample(self):
         initial_rng_state = np.random.get_state()
         with self.SMC_test:
             mtrace = pm.sample_smc(draws=self.samples, return_inferencedata=False)
-        assert_random_state_equal(initial_rng_state, np.random.get_state())
 
+        # Verify sampling was done with a non-global random generator
+        assert_random_state_equal(initial_rng_state, np.random.get_state())
         x = mtrace["X"]
         mu1d = np.abs(x).mean(axis=0)
         np.testing.assert_allclose(self.muref, mu1d, rtol=0.0, atol=0.03)
@@ -109,7 +113,6 @@ def test_discrete_rounding_proposal(self):
     def test_unobserved_discrete(self):
         n = 10
         rng = self.get_random_state()
-
         z_true = np.zeros(n, dtype=int)
         z_true[int(n / 2) :] = 1
         y = st.norm(np.array([-1, 1])[z_true], 0.25).rvs(random_state=rng)
@@ -124,6 +127,10 @@
         assert np.all(np.median(trace["z"], axis=0) == z_true)
 
     def test_marginal_likelihood(self):
+        """
+        Verifies that the log marginal likelihood function
+        can be correctly computed for a Beta-Bernoulli model.
+        """
         data = np.repeat([1, 0], [50, 50])
         marginals = []
         a_prior_0, b_prior_0 = 1.0, 1.0
@@ -135,6 +142,7 @@
             y = pm.Bernoulli("y", a, observed=data)
             trace = pm.sample_smc(2000, return_inferencedata=False)
             marginals.append(trace.report.log_marginal_likelihood)
+
         # compare to the analytical result
         assert abs(np.exp(np.nanmean(marginals[1]) - np.nanmean(marginals[0])) - 4.0) <= 1
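On the analytical result: for a Bernoulli likelihood with a Beta(a, b) prior, the marginal likelihood is B(a + heads, b + tails) / B(a, b). Assuming the elided second prior is Beta(20, 20) (the shown lines only give the Beta(1, 1) case), the expected Bayes factor can be checked directly (a standalone sketch, not part of the test suite):

import numpy as np
from scipy.special import betaln


def log_marginal(a, b, heads, tails):
    # log p(D) = log B(a + heads, b + tails) - log B(a, b)
    return betaln(a + heads, b + tails) - betaln(a, b)


bf = np.exp(log_marginal(20, 20, 50, 50) - log_marginal(1, 1, 50, 50))
print(bf)  # ~4.28, hence the assert's tolerance of 1 around 4.0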
