Skip to content

Commit

Permalink
Add progressbar to sample_smc and deprecate parallel (#4826)
Browse files Browse the repository at this point in the history
* Add progressbar to `sample_smc` and deprecate `parallel`

Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
Co-authored-by: Michael Osthege <michael.osthege@outlook.com>
  • Loading branch information
ricardoV94 and michaelosthege authored Jul 2, 2021
1 parent 13487d0 commit de83381
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 25 deletions.
53 changes: 40 additions & 13 deletions pymc3/smc/sample_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import numpy as np

from arviz import InferenceData
from fastprogress.fastprogress import progress_bar

import pymc3

Expand All @@ -45,12 +46,13 @@ def sample_smc(
save_log_pseudolikelihood=True,
model=None,
random_seed=-1,
parallel=False,
parallel=None,
chains=None,
cores=None,
compute_convergence_checks=True,
return_inferencedata=True,
idata_kwargs=None,
progressbar=True,
):
r"""
Sequential Monte Carlo based sampling.
Expand Down Expand Up @@ -90,12 +92,9 @@ def sample_smc(
model: Model (optional if in ``with`` context).
random_seed: int
random seed
parallel: bool
Distribute computations across cores if the number of cores is larger than 1.
Defaults to False.
cores : int
The number of chains to run in parallel. If ``None``, set to the number of CPUs in the
system, but at most 4.
system.
chains : int
The number of chains to sample. Running independent chains is important for some
convergence statistics. If ``None`` (default), then set to either ``cores`` or 2, whichever
Expand All @@ -108,6 +107,9 @@ def sample_smc(
Defaults to ``True``.
idata_kwargs : dict, optional
Keyword arguments for :func:`pymc3.to_inference_data`
progressbar : bool, optional, default=True
Whether or not to display a progress bar in the command line.
Notes
-----
SMC works by moving through successive stages. At each stage the inverse temperature
Expand Down Expand Up @@ -153,6 +155,16 @@ def sample_smc(
816-832. `link <http://ascelibrary.org/doi/abs/10.1061/%28ASCE%290733-9399
%282007%29133:7%28816%29>`__
"""

if parallel is not None:
warnings.warn(
"The argument parallel is deprecated, use the argument cores instead.",
DeprecationWarning,
stacklevel=2,
)
if parallel is False:
cores = 1

_log = logging.getLogger("pymc3")
_log.info("Initializing SMC sampler...")

Expand Down Expand Up @@ -206,19 +218,26 @@ def sample_smc(
)

t1 = time.time()
if parallel and chains > 1:
loggers = [_log] + [None] * (chains - 1)
if cores > 1:
pbar = progress_bar((), total=100, display=progressbar)
pbar.update(0)
pbars = [pbar] + [None] * (chains - 1)

pool = mp.Pool(cores)
results = pool.starmap(
sample_smc_int, [(*params, random_seed[i], i, loggers[i]) for i in range(chains)]
sample_smc_int, [(*params, random_seed[i], i, pbars[i]) for i in range(chains)]
)

pool.close()
pool.join()

else:
results = []
pbar = progress_bar((), total=100 * chains, display=progressbar)
pbar.update(0)
for i in range(chains):
results.append(sample_smc_int(*params, random_seed[i], i, _log))
pbar.offset = 100 * i
pbar.base_comment = f"Chain: {i+1}/{chains}"
results.append(sample_smc_int(*params, random_seed[i], i, pbar))

(
traces,
Expand Down Expand Up @@ -310,7 +329,7 @@ def sample_smc_int(
model,
random_seed,
chain,
_log,
progressbar=None,
):
"""Run one SMC instance."""
smc = SMC(
Expand All @@ -331,14 +350,22 @@ def sample_smc_int(
betas = []
accept_ratios = []
nsteps = []

if progressbar:
progressbar.comment = f"{getattr(progressbar, 'base_comment', '')} Stage: 0 Beta: 0"
progressbar.update_bar(getattr(progressbar, "offset", 0) + 0)

smc.initialize_population()
smc.setup_kernel()
smc.initialize_logp()

while smc.beta < 1:
smc.update_weights_beta()
if _log is not None:
_log.info(f"Stage: {stage:3d} Beta: {smc.beta:.3f}")
if progressbar:
progressbar.comment = (
f"{getattr(progressbar, 'base_comment', '')} Stage: {stage} Beta: {smc.beta:.3f}"
)
progressbar.update_bar(getattr(progressbar, "offset", 0) + int(smc.beta * 100))
smc.update_proposal()
smc.resample()
smc.mutate()
Expand Down
65 changes: 53 additions & 12 deletions pymc3/tests/test_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import aesara
import aesara.tensor as at
import numpy as np
Expand Down Expand Up @@ -60,9 +62,22 @@ def two_gaussians(x):

self.muref = mu1

with pm.Model() as self.fast_model:
x = pm.Normal("x", 0, 1)
y = pm.Normal("y", x, 1, observed=0)

with pm.Model() as self.slow_model:
x = pm.Normal("x", 0, 1)
y = pm.Normal("y", x, 1, observed=100)

def test_sample(self):
with self.SMC_test:
mtrace = pm.sample_smc(draws=self.samples, return_inferencedata=False)

mtrace = pm.sample_smc(
draws=self.samples,
cores=1, # Fails in parallel due to #4799
return_inferencedata=False,
)

x = mtrace["X"]
mu1d = np.abs(x).mean(axis=0)
Expand Down Expand Up @@ -107,39 +122,65 @@ def test_slowdown_warning(self):
with pm.Model() as model:
a = pm.Poisson("a", 5)
y = pm.Normal("y", a, 5, observed=[1, 2, 3, 4])
trace = pm.sample_smc(draws=100, chains=2)
trace = pm.sample_smc(draws=100, chains=2, cores=1)

@pytest.mark.parametrize("chains", (1, 2))
def test_return_datatype(self, chains):
draws = 10

with pm.Model() as m:
x = pm.Normal("x", 0, 1)
y = pm.Normal("y", x, 1, observed=5)

with self.fast_model:
idata = pm.sample_smc(chains=chains, draws=draws)
mt = pm.sample_smc(chains=chains, draws=draws, return_inferencedata=False)

assert isinstance(idata, InferenceData)
assert "sample_stats" in idata
assert len(idata.posterior.chain) == chains
assert len(idata.posterior.draw) == draws
assert idata.posterior.dims["chain"] == chains
assert idata.posterior.dims["draw"] == draws

assert isinstance(mt, MultiTrace)
assert mt.nchains == chains
assert mt["x"].size == chains * draws

def test_convergence_checks(self):
with pm.Model() as m:
x = pm.Normal("x", 0, 1)
y = pm.Normal("y", x, 1, observed=5)

with self.fast_model:
with pytest.warns(
UserWarning,
match="The number of samples is too small",
):
pm.sample_smc(draws=99)

def test_parallel_sampling(self):
    """Check that sampling with ``cores=4`` is faster than with ``cores=1``.

    Both runs draw the same number of chains and draws from
    ``self.slow_model``; each result is checked for the expected ``chain``
    and ``draw`` dimensions before the two elapsed times are compared.
    """
    chains = 4
    draws = 100

    def timed_sample(cores):
        """Sample with the given number of cores and return the elapsed seconds."""
        # perf_counter is monotonic and high-resolution, so the measured
        # interval cannot be distorted by system clock adjustments the way
        # a time.time() difference can.
        start = time.perf_counter()
        with self.slow_model:
            idata = pm.sample_smc(draws=draws, chains=chains, cores=cores)
        elapsed = time.perf_counter() - start
        assert idata.posterior.dims["chain"] == chains
        assert idata.posterior.dims["draw"] == draws
        return elapsed

    # Cache graph: one short untimed run, presumably so one-off setup cost
    # is not charged to the first timed run.
    with self.slow_model:
        pm.sample_smc(draws=10, chains=1, cores=1, return_inferencedata=False)

    t_mp = timed_sample(cores=4)
    t_seq = timed_sample(cores=1)

    assert t_mp < t_seq

def test_depracated_parallel_arg(self):
    """Passing ``parallel`` to ``sample_smc`` must emit a DeprecationWarning."""
    expected_message = "The argument parallel is deprecated"
    # Enter the model context first, then the warning recorder, in the same
    # order as two nested ``with`` statements would.
    with self.fast_model, pytest.warns(DeprecationWarning, match=expected_message):
        pm.sample_smc(draws=10, chains=1, parallel=False)


@pytest.mark.xfail(reason="SMC-ABC not refactored yet")
class TestSMCABC(SeededTest):
Expand Down

0 comments on commit de83381

Please sign in to comment.