Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use fastprogress instead of tqdm progressbar #3693

Merged
merged 10 commits into from
Dec 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion RELEASE-NOTES.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Release Notes

## PyMC3 3.8 (on deck)
## PyMC3 3.9 (On deck)

### New features
- Use [fastprogress](https://github.com/fastai/fastprogress) instead of tqdm for progress bars [#3693](https://github.com/pymc-devs/pymc3/pull/3693)

## PyMC3 3.8 (November 29 2019)

### New features
- Implemented robust U-turn check in NUTS (similar to stan-dev/stan#2800). See PR [#3605]
Expand Down
75 changes: 37 additions & 38 deletions pymc3/parallel_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import errno

import numpy as np
from fastprogress import progress_bar

from . import theanof

Expand All @@ -17,28 +18,31 @@

def _get_broken_pipe_exception():
import sys
if sys.platform == 'win32':
return RuntimeError("The communication pipe between the main process "
"and its spawned children is broken.\n"
"In Windows OS, this usually means that the child "
"process raised an exception while it was being "
"spawned, before it was setup to communicate to "
"the main process.\n"
"The exceptions raised by the child process while "
"spawning cannot be caught or handled from the "
"main process, and when running from an IPython or "
"jupyter notebook interactive kernel, the child's "
"exception and traceback appears to be lost.\n"
"A known way to see the child's error, and try to "
"fix or handle it, is to run the problematic code "
"as a batch script from a system's Command Prompt. "
"The child's exception will be printed to the "
"Command Promt's stderr, and it should be visible "
"above this error and traceback.\n"
"Note that if running a jupyter notebook that was "
"invoked from a Command Prompt, the child's "
"exception should have been printed to the Command "
"Prompt on which the notebook is running.")

if sys.platform == "win32":
return RuntimeError(
"The communication pipe between the main process "
"and its spawned children is broken.\n"
"In Windows OS, this usually means that the child "
"process raised an exception while it was being "
"spawned, before it was setup to communicate to "
"the main process.\n"
"The exceptions raised by the child process while "
"spawning cannot be caught or handled from the "
"main process, and when running from an IPython or "
"jupyter notebook interactive kernel, the child's "
"exception and traceback appears to be lost.\n"
"A known way to see the child's error, and try to "
"fix or handle it, is to run the problematic code "
"as a batch script from a system's Command Prompt. "
"The child's exception will be printed to the "
"Command Promt's stderr, and it should be visible "
"above this error and traceback.\n"
"Note that if running a jupyter notebook that was "
"invoked from a Command Prompt, the child's "
"exception should have been printed to the Command "
"Prompt on which the notebook is running."
)
else:
return None

Expand Down Expand Up @@ -237,7 +241,6 @@ def __init__(self, draws, tune, step_method, chain, seed, start):
tune,
seed,
)
# We fork right away, so that the main process can start tqdm threads
try:
self._process.start()
except IOError as e:
Expand Down Expand Up @@ -346,8 +349,6 @@ def __init__(
start_chain_num=0,
progressbar=True,
):
if progressbar:
from tqdm import tqdm

if any(len(arg) != chains for arg in [seeds, start_points]):
raise ValueError("Number of seeds and start_points must be %s." % chains)
Expand All @@ -369,14 +370,13 @@ def __init__(

self._progress = None
self._divergences = 0
self._total_draws = 0
self._desc = "Sampling {0._chains:d} chains, {0._divergences:,d} divergences"
self._chains = chains
if progressbar:
self._progress = tqdm(
total=chains * (draws + tune),
unit="draws",
desc=self._desc.format(self)
)
self._progress = progress_bar(
range(chains * (draws + tune)), display=progressbar, auto_update=False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oooh, this display=progressbar is fancy! I would move the import to the top for sure, then!

)
self._progress.comment = self._desc.format(self)

def _make_active(self):
while self._inactive and len(self._active) < self._max_active:
Expand All @@ -393,11 +393,11 @@ def __iter__(self):
while self._active:
draw = ProcessAdapter.recv_draw(self._active)
proc, is_last, draw, tuning, stats, warns = draw
if self._progress is not None:
if not tuning and stats and stats[0].get('diverging'):
self._divergences += 1
self._progress.set_description(self._desc.format(self))
self._progress.update()
self._total_draws += 1
if not tuning and stats and stats[0].get("diverging"):
self._divergences += 1
self._progress.comment = self._desc.format(self)
self._progress.update(self._total_draws)

if is_last:
proc.join()
Expand All @@ -423,8 +423,7 @@ def __enter__(self):

def __exit__(self, *args):
ProcessAdapter.terminate_all(self._samplers)
if self._progress is not None:
self._progress.close()


def _cpu_count():
"""Try to guess the number of CPUs in the system.
Expand Down
55 changes: 27 additions & 28 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from .parallel_sampling import _cpu_count
from pymc3.step_methods.hmc import quadpotential
import pymc3 as pm
from tqdm import tqdm
from fastprogress import progress_bar


import sys
Expand Down Expand Up @@ -568,11 +568,17 @@ def _sample_population(
# create the generator that iterates all chains in parallel
chains = [chain + c for c in range(chains)]
sampling = _prepare_iter_population(
draws, chains, step, start, parallelize, tune=tune, model=model, random_seed=random_seed
draws,
chains,
step,
start,
parallelize,
tune=tune,
model=model,
random_seed=random_seed,
)

if progressbar:
sampling = tqdm(sampling, total=draws)
sampling = progress_bar(sampling, total=draws, display=progressbar)

latest_traces = None
for it, traces in enumerate(sampling):
Expand All @@ -596,23 +602,20 @@ def _sample(

sampling = _iter_sample(draws, step, start, trace, chain, tune, model, random_seed)
_pbar_data = None
if progressbar:
_pbar_data = {"chain": chain, "divergences": 0}
_desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
sampling = tqdm(sampling, total=draws, desc=_desc.format(**_pbar_data))
_pbar_data = {"chain": chain, "divergences": 0}
_desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
sampling = progress_bar(sampling, total=draws, display=progressbar)
sampling.comment = _desc.format(**_pbar_data)
try:
strace = None
for it, (strace, diverging) in enumerate(sampling):
if it >= skip_first:
trace = MultiTrace([strace])
if diverging and _pbar_data is not None:
_pbar_data["divergences"] += 1
sampling.set_description(_desc.format(**_pbar_data))
sampling.comment = _desc.format(**_pbar_data)
except KeyboardInterrupt:
pass
finally:
if progressbar:
sampling.close()
return strace


Expand Down Expand Up @@ -753,7 +756,7 @@ def __init__(self, steppers, parallelize):
)
import multiprocessing

for c, stepper in enumerate(tqdm(steppers)):
for c, stepper in enumerate(progress_bar(steppers)):
slave_end, master_end = multiprocessing.Pipe()
stepper_dumps = pickle.dumps(stepper, protocol=4)
process = multiprocessing.Process(
Expand Down Expand Up @@ -1235,9 +1238,13 @@ def sample_posterior_predictive(
nchain = 1

if keep_size and samples is not None:
raise IncorrectArgumentsError("Should not specify both keep_size and samples argukments")
raise IncorrectArgumentsError(
"Should not specify both keep_size and samples argukments"
)
if keep_size and size is not None:
raise IncorrectArgumentsError("Should not specify both keep_size and size argukments")
raise IncorrectArgumentsError(
"Should not specify both keep_size and size argukments"
)

if samples is None:
samples = sum(len(v) for v in trace._straces.values())
Expand All @@ -1253,7 +1260,9 @@ def sample_posterior_predictive(

if var_names is not None:
if vars is not None:
raise IncorrectArgumentsError("Should not specify both vars and var_names arguments.")
raise IncorrectArgumentsError(
"Should not specify both vars and var_names arguments."
)
else:
vars = [model[x] for x in var_names]
elif vars is not None: # var_names is None, and vars is not.
Expand All @@ -1266,8 +1275,7 @@ def sample_posterior_predictive(

indices = np.arange(samples)

if progressbar:
indices = tqdm(indices, total=samples)
indices = progress_bar(indices, total=samples, display=progressbar)

ppc_trace_t = _DefaultTrace(samples)
try:
Expand All @@ -1285,10 +1293,6 @@ def sample_posterior_predictive(
except KeyboardInterrupt:
pass

finally:
if progressbar:
indices.close()

ppc_trace = ppc_trace_t.trace_dict
if keep_size:
for k, ary in ppc_trace.items():
Expand Down Expand Up @@ -1411,8 +1415,7 @@ def sample_posterior_predictive_w(

indices = np.random.randint(0, len_trace, samples)

if progressbar:
indices = tqdm(indices, total=samples)
indices = progress_bar(indices, total=samples, display=progressbar)

try:
ppc = defaultdict(list)
Expand All @@ -1426,10 +1429,6 @@ def sample_posterior_predictive_w(
except KeyboardInterrupt:
pass

finally:
if progressbar:
indices.close()

return {k: np.asarray(v) for k, v in ppc.items()}


Expand Down
Loading