diff --git a/pymc3/parallel_sampling.py b/pymc3/parallel_sampling.py index c1381156dd5..e62eeedc303 100644 --- a/pymc3/parallel_sampling.py +++ b/pymc3/parallel_sampling.py @@ -234,12 +234,8 @@ def terminate_all(processes, patience=2): class ParallelSampler(object): def __init__(self, draws, tune, chains, cores, seeds, start_points, - step_method, start_chain_num=0, progressbar=True, - notebook=True): - if progressbar and notebook: - import tqdm - tqdm_ = tqdm.tqdm_notebook - elif progressbar: + step_method, start_chain_num=0, progressbar=True): + if progressbar: import tqdm tqdm_ = tqdm.tqdm @@ -257,18 +253,11 @@ def __init__(self, draws, tune, chains, cores, seeds, start_points, self._in_context = False self._start_chain_num = start_chain_num - self._global_progress = self._progress = None + self._progress = None if progressbar: - self._global_progress = tqdm_( - total=chains, unit='chains', position=0) - self._progress = [ - tqdm_( - desc=' Chain %i' % (chain + start_chain_num), - unit='draws', - position=chain + 1, - total=draws + tune) - for chain in range(chains) - ] + self._progress = tqdm_( + total=chains * (draws + tune), unit='draws', + desc='Sampling %s chains' % chains) def _make_active(self): while self._inactive and len(self._active) < self._max_active: @@ -286,17 +275,13 @@ def __iter__(self): draw = ProcessAdapter.recv_draw(self._active) proc, is_last, draw, tuning, stats, warns = draw if self._progress is not None: - self._progress[proc.chain - self._start_chain_num].update() + self._progress.update() if is_last: proc.join() self._active.remove(proc) self._finished.append(proc) self._make_active() - if self._global_progress is not None: - self._global_progress.update() - if self._progress is not None: - self._progress[proc.chain - self._start_chain_num].close() # We could also yield proc.shared_point_view directly, # and only call proc.write_next() after the yield returns. @@ -318,6 +303,4 @@ def __enter__(self): def __exit__(self, *args): ProcessAdapter.terminate_all(self._samplers) if self._progress is not None: - self._global_progress.close() - for progress in self._progress: - progress.close() + self._progress.close()