Skip to content

Commit

Permalink
Show one progress bar for all chains
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Jun 12, 2018
1 parent 9cc5bfb commit ec02da7
Showing 1 changed file with 8 additions and 25 deletions.
33 changes: 8 additions & 25 deletions pymc3/parallel_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,8 @@ def terminate_all(processes, patience=2):

class ParallelSampler(object):
def __init__(self, draws, tune, chains, cores, seeds, start_points,
step_method, start_chain_num=0, progressbar=True,
notebook=True):
if progressbar and notebook:
import tqdm
tqdm_ = tqdm.tqdm_notebook
elif progressbar:
step_method, start_chain_num=0, progressbar=True):
if progressbar:
import tqdm
tqdm_ = tqdm.tqdm

Expand All @@ -257,18 +253,11 @@ def __init__(self, draws, tune, chains, cores, seeds, start_points,
self._in_context = False
self._start_chain_num = start_chain_num

self._global_progress = self._progress = None
self._progress = None
if progressbar:
self._global_progress = tqdm_(
total=chains, unit='chains', position=0)
self._progress = [
tqdm_(
desc=' Chain %i' % (chain + start_chain_num),
unit='draws',
position=chain + 1,
total=draws + tune)
for chain in range(chains)
]
self._progress = tqdm_(
total=chains * (draws + tune), unit='draws',
desc='Sampling %s chains' % chains)

def _make_active(self):
while self._inactive and len(self._active) < self._max_active:
Expand All @@ -286,17 +275,13 @@ def __iter__(self):
draw = ProcessAdapter.recv_draw(self._active)
proc, is_last, draw, tuning, stats, warns = draw
if self._progress is not None:
self._progress[proc.chain - self._start_chain_num].update()
self._progress.update()

if is_last:
proc.join()
self._active.remove(proc)
self._finished.append(proc)
self._make_active()
if self._global_progress is not None:
self._global_progress.update()
if self._progress is not None:
self._progress[proc.chain - self._start_chain_num].close()

# We could also yield proc.shared_point_view directly,
# and only call proc.write_next() after the yield returns.
Expand All @@ -318,6 +303,4 @@ def __enter__(self):
def __exit__(self, *args):
ProcessAdapter.terminate_all(self._samplers)
if self._progress is not None:
self._global_progress.close()
for progress in self._progress:
progress.close()
self._progress.close()

0 comments on commit ec02da7

Please sign in to comment.