From 85658d71286098d0da3cda065264903551fc76cc Mon Sep 17 00:00:00 2001
From: Adrian Seyboldt
Date: Thu, 7 Jun 2018 15:44:28 +0200
Subject: [PATCH 01/11] Rewrite of multiprocessing code

---
 pymc3/parallel_sampling.py            | 297 ++++++++++++++++++++++++++
 pymc3/sampling.py                     |  59 +++--
 pymc3/step_methods/arraystep.py       |  15 +-
 pymc3/step_methods/compound.py        |  17 +-
 pymc3/tests/test_parallel_sampling.py |  64 ++++++
 5 files changed, 434 insertions(+), 18 deletions(-)
 create mode 100644 pymc3/parallel_sampling.py
 create mode 100644 pymc3/tests/test_parallel_sampling.py

diff --git a/pymc3/parallel_sampling.py b/pymc3/parallel_sampling.py
new file mode 100644
index 00000000000..48aae1bc94d
--- /dev/null
+++ b/pymc3/parallel_sampling.py
@@ -0,0 +1,297 @@
+import multiprocessing
+import multiprocessing.sharedctypes
+import sys
+import ctypes
+import time
+import logging
+from collections import namedtuple
+
+import six
+import tqdm
+import numpy as np
+
+from . import theanof
+
+logger = logging.getLogger('pymc3')
+
+# Messages
+# ('writing_done', is_last, sample_idx, tuning, stats)
+# ('error', *exception_info)
+
+# ('abort', reason)
+# ('write_next',)
+# ('start',)
+
+
+class _Process(multiprocessing.Process):
+    """Separate process for each chain.
+
+    We communicate with the main process using a pipe,
+    and send finished samples using shared memory.
+    """
+    def __init__(self, msg_pipe, step_method, shared_point, draws, tune, seed):
+        super(_Process, self).__init__(daemon=True)
+        self._msg_pipe = msg_pipe
+        self._step_method = step_method
+        self._shared_point = shared_point
+        self._seed = seed
+        self._tt_seed = seed + 1
+        self._draws = draws
+        self._tune = tune
+
+    def run(self):
+        try:
+            # We do not create this in __init__, as pickling this
+            # would destroy the shared memory.
+            self._point = self._make_numpy_refs()
+            self._start_loop()
+        except KeyboardInterrupt:
+            pass
+        except BaseException:
+            exc_info = sys.exc_info()
+            self._msg_pipe.send(('error', exc_info[:2]))
+        finally:
+            self._msg_pipe.close()
+
+    def _make_numpy_refs(self):
+        shape_dtypes = self._step_method.vars_shape_dtype
+        point = {}
+        for name, (shape, dtype) in shape_dtypes.items():
+            array = self._shared_point[name]
+            self._shared_point[name] = array
+            point[name] = np.frombuffer(array, dtype).reshape(shape)
+        return point
+
+    def _write_point(self, point):
+        for name, vals in point.items():
+            self._point[name][...] = vals
+
+    def _recv_msg(self):
+        return self._msg_pipe.recv()
+
+    def _start_loop(self):
+        np.random.seed(self._seed)
+        theanof.set_tt_rng(self._tt_seed)
+
+        draw = 0
+        tuning = True
+
+        msg = self._recv_msg()
+        if msg[0] == 'abort':
+            raise KeyboardInterrupt()
+        if msg[0] != 'start':
+            raise ValueError('Unexpected msg ' + msg[0])
+
+        while True:
+            if draw < self._draws + self._tune:
+                point, stats = self._compute_point()
+            else:
+                return
+
+            if draw == self._tune:
+                self._step_method.stop_tuning()
+                tuning = False
+
+            msg = self._recv_msg()
+            if msg[0] == 'abort':
+                raise KeyboardInterrupt()
+            elif msg[0] == 'write_next':
+                self._write_point(point)
+                is_last = draw + 1 == self._draws + self._tune
+                self._msg_pipe.send(
+                    ('writing_done', is_last, draw, tuning, stats))
+                draw += 1
+            else:
+                raise ValueError('Unknown message ' + msg[0])
+
+    def _compute_point(self):
+        if self._step_method.generates_stats:
+            point, stats = self._step_method.step(self._point)
+        else:
+            point = self._step_method.step(self._point)
+            stats = None
+        return point, stats
+
+
+class ProcessAdapter(object):
+    """Control a Chain process from the main thread."""
+    def __init__(self, draws, tune, step_method, chain, seed, start):
+        self.chain = chain
+        self._msg_pipe, remote_conn = multiprocessing.Pipe()
+
+        self._shared_point = {}
+        self._point = {}
+        for name, (shape, dtype) in step_method.vars_shape_dtype.items():
+            size = 1
+            for dim in shape:
+                size *= int(dim)
+            size *= dtype.itemsize
+            if size != ctypes.c_size_t(size).value:
+                raise ValueError('Variable %s is too large' % name)
+
+            array = multiprocessing.sharedctypes.RawArray('c', size)
+            self._shared_point[name] = array
+            array_np = np.frombuffer(array, dtype).reshape(shape)
+            array_np[...] = start[name]
+            self._point[name] = array_np
+
+        self._readable = True
+        self._num_samples = 0
+
+        self._process = _Process(
+            remote_conn, step_method, self._shared_point, draws, tune, seed)
+        # We fork right away, so that the main process can start tqdm threads
+        self._process.start()
+
+    @property
+    def shared_point_view(self):
+        """May only be written to or read between a `recv_draw`
+        call from the process and a `write_next` or `abort` call.
+        """
+        if not self._readable:
+            raise RuntimeError()
+        return self._point
+
+    def start(self):
+        self._msg_pipe.send(('start',))
+
+    def write_next(self):
+        self._readable = False
+        self._msg_pipe.send(('write_next',))
+
+    def abort(self):
+        self._msg_pipe.send(('abort',))
+
+    def join(self, timeout=None):
+        self._process.join(timeout)
+
+    def terminate(self):
+        self._process.terminate()
+
+    @staticmethod
+    def recv_draw(processes, timeout=3600):
+        if not processes:
+            raise ValueError('No processes.')
+        pipes = [proc._msg_pipe for proc in processes]
+        ready = multiprocessing.connection.wait(pipes)
+        if not ready:
+            raise multiprocessing.TimeoutError('No message from samplers.')
+        idxs = {id(proc._msg_pipe): proc for proc in processes}
+        proc = idxs[id(ready[0])]
+        msg = ready[0].recv()
+
+        if msg[0] == 'error':
+            old = msg[1][1]#.with_traceback(msg[1][2])
+            six.raise_from(RuntimeError('Chain %s failed.' % proc.chain), old)
+        elif msg[0] == 'writing_done':
+            proc._readable = True
+            proc._num_samples += 1
+            return (proc, *msg[1:])
+        else:
+            raise ValueError('Sampler sent bad message.')
+
+    @staticmethod
+    def terminate_all(processes, patience=2):
+        for process in processes:
+            try:
+                process.abort()
+            except EOFError:
+                pass
+
+        start_time = time.time()
+        try:
+            for process in processes:
+                timeout = start_time + patience - time.time()
+                if timeout < 0:
+                    raise multiprocessing.TimeoutError()
+                process.join(timeout)
+        except multiprocessing.TimeoutError:
+            logger.warn('Chain processes did not terminate as expected. '
+                        'Terminating forcefully...')
+            for process in processes:
+                process.terminate()
+            for process in processes:
+                process.join()
+
+
+Draw = namedtuple(
+    'Draw',
+    ['chain', 'is_last', 'draw_idx', 'tuning', 'stats', 'point']
+)
+
+
+class ParallelSampler(object):
+    def __init__(self, draws, tune, chains, cores, seeds, start_points,
+                 step_method, start_chain_num=0, progressbar=True):
+        self._samplers = [
+            ProcessAdapter(draws, tune, step_method,
+                           chain + start_chain_num, seed, start)
+            for chain, seed, start in zip(range(chains), seeds, start_points)
+        ]
+
+        self._inactive = self._samplers.copy()
+        self._finished = []
+        self._active = []
+        self._max_active = cores
+
+        self._in_context = False
+        self._start_chain_num = start_chain_num
+
+        self._global_progress = self._progress = None
+        if progressbar:
+            self._global_progress = tqdm.tqdm(
+                total=chains, unit='chains', position=1)
+            self._progress = [
+                tqdm.tqdm(
+                    desc=' Chain %i' % (chain + start_chain_num),
+                    unit='draws',
+                    position=chain + 2,
+                    total=draws + tune)
+                for chain in range(chains)
+            ]
+
+    def _make_active(self):
+        while self._inactive and len(self._active) < self._max_active:
+            proc = self._inactive.pop()
+            proc.start()
+            proc.write_next()
+            self._active.append(proc)
+
+    def __iter__(self):
+        if not self._in_context:
+            raise ValueError('Use ParallelSampler as context manager.')
+        self._make_active()
+
+        while self._active:
+            draw = ProcessAdapter.recv_draw(self._active)
+            proc, is_last, draw, tuning, stats = draw
+            if self._progress is not None:
+                self._progress[proc.chain - self._start_chain_num].update()
+
+            if is_last:
+                proc.join()
+                self._active.remove(proc)
+                self._finished.append(proc)
+                self._make_active()
+                if self._global_progress is not None:
+                    self._global_progress.update()
+
+            yield Draw(
+                proc.chain, is_last, draw, tuning,
+                stats, proc.shared_point_view
+            )
+
+            # Already called for new proc in _make_active
+            if not is_last:
+                proc.write_next()
+
+    def __enter__(self):
+        self._in_context = True
+        return self
+
+    def __exit__(self, *args):
+        ProcessAdapter.terminate_all(self._samplers)
+        if self._progress is not None:
+            self._global_progress.close()
+            for progress in self._progress:
+                progress.close()
diff --git a/pymc3/sampling.py b/pymc3/sampling.py
index cfbfa52f016..ea2cae3d232 100644
--- a/pymc3/sampling.py
+++ b/pymc3/sampling.py
@@ -966,35 +966,64 @@ def _choose_backend(trace, chain, shortcuts=None, **kwds):
 
 
 def _mp_sample(**kwargs):
+    import sys
+
     cores = kwargs.pop('cores')
     chain = kwargs.pop('chain')
     rseed = kwargs.pop('random_seed')
     start = kwargs.pop('start')
     chains = kwargs.pop('chains')
+    draws = kwargs.pop('draws')
+    tune = kwargs.pop('tune')
+    step = kwargs.pop('step')
+    progressbar = kwargs.pop('progressbar')
     use_mmap = kwargs.pop('use_mmap')
 
-    chain_nums = list(range(chain, chain + chains))
-    pbars = [kwargs.pop('progressbar')] + [False] * (chains - 1)
-    jobs = (delayed(_sample)(*args, **kwargs)
-            for args in zip(chain_nums, pbars, rseed, start))
+    if sys.version_info.major >= 3:
+        import pymc3.parallel_sampling as ps
+
+        model = modelcontext(kwargs.pop('model', None))
+        trace = kwargs.pop('trace', None)
+        traces = []
+        for idx in range(chain, chain + chains):
+            strace = _choose_backend(trace, idx, model=model)
+            # TODO what is this for?
+            update_start_vals(start[idx - chain], model.test_point, model)
+            if step.generates_stats and strace.supports_sampler_stats:
+                strace.setup(draws + tune, idx + chain, step.stats_dtypes)
+            else:
+                strace.setup(draws + tune, idx + chain)
+            traces.append(strace)
+
+        sampler = ps.ParallelSampler(
+            draws, tune, chains, cores, rseed, start, step, chain, progressbar)
+        with sampler:
+            for draw in sampler:
+                trace = traces[draw.chain - chain]
+                if trace.supports_sampler_stats and draw.stats is not None:
+                    trace.record(draw.point, draw.stats)
+                else:
+                    trace.record(draw.point)
+                if draw.is_last:
+                    trace.close()
+        return MultiTrace(traces)
 
-    if use_mmap:
-        traces = Parallel(n_jobs=cores)(jobs)
     else:
-        traces = Parallel(n_jobs=cores, mmap_mode=None)(jobs)
-
-    return MultiTrace(traces)
+        chain_nums = list(range(chain, chain + chains))
+        pbars = [kwargs.pop('progressbar')] + [False] * (chains - 1)
+        jobs = (delayed(_sample)(*args, **kwargs)
+                for args in zip(chain_nums, pbars, rseed, start))
+        if use_mmap:
+            traces = Parallel(n_jobs=cores)(jobs)
+        else:
+            traces = Parallel(n_jobs=cores, mmap_mode=None)(jobs)
+        return MultiTrace(traces)
 
 
 def stop_tuning(step):
     """ stop tuning the current step method """
-    if hasattr(step, 'tune'):
-        step.tune = False
-
-    if hasattr(step, 'methods'):
-        step.methods = [stop_tuning(s) for s in step.methods]
-
+    step.stop_tuning()
     return step
diff --git a/pymc3/step_methods/arraystep.py b/pymc3/step_methods/arraystep.py
index 4cf9585e41e..8366bbb5c36 100644
--- a/pymc3/step_methods/arraystep.py
+++ b/pymc3/step_methods/arraystep.py
@@ -87,13 +87,26 @@ def _competence(cls, vars, have_grad):
         vars = np.atleast_1d(vars)
         have_grad = np.atleast_1d(have_grad)
         competences = []
-        for var,has_grad in zip(vars, have_grad):
+        for var, has_grad in zip(vars, have_grad):
             try:
                 competences.append(cls.competence(var, has_grad))
             except TypeError:
                 competences.append(cls.competence(var))
         return competences
 
+    @property
+    def vars_shape_dtype(self):
+        shape_dtypes = {}
+        for var in self.vars:
+            dtype = np.dtype(var.dtype)
+            shape = var.dshape
+            shape_dtypes[var.name] = (shape, dtype)
+        return shape_dtypes
+
+    def stop_tuning(self):
+        if hasattr(self, 'tune'):
+            self.tune = False
+
 
 class ArrayStep(BlockedStep):
     """
diff --git a/pymc3/step_methods/compound.py b/pymc3/step_methods/compound.py
index d8f66171cb9..7b089a1104a 100644
--- a/pymc3/step_methods/compound.py
+++ b/pymc3/step_methods/compound.py
@@ -6,11 +6,13 @@
 
 
 class CompoundStep(object):
-    """Step method composed of a list of several other step methods applied in sequence."""
+    """Step method composed of a list of several other step
+    methods applied in sequence."""
 
     def __init__(self, methods):
         self.methods = list(methods)
-        self.generates_stats = any(method.generates_stats for method in self.methods)
+        self.generates_stats = any(
+            method.generates_stats for method in self.methods)
         self.stats_dtypes = []
         for method in self.methods:
             if method.generates_stats:
@@ -37,3 +39,14 @@ def warnings(self, strace):
             if hasattr(method, 'warnings'):
                 warns.extend(method.warnings(strace))
         return warns
+
+    def stop_tuning(self):
+        for method in self.methods:
+            method.stop_tuning()
+
+    @property
+    def vars_shape_dtype(self):
+        dtype_shapes = {}
+        for method in self.methods:
+            dtype_shapes.update(method.vars_shape_dtype)
+        return dtype_shapes
diff --git a/pymc3/tests/test_parallel_sampling.py b/pymc3/tests/test_parallel_sampling.py
new file mode 100644
index 00000000000..942d0fdbd22
--- /dev/null
+++ b/pymc3/tests/test_parallel_sampling.py
@@ -0,0 +1,64 @@
+import time
+
+import pymc3.parallel_sampling as ps
+import pymc3 as pm
+
+
+def test_abort():
+    with pm.Model() as model:
+        a = pm.Normal('a', shape=1)
+        pm.HalfNormal('b')
+        step1 = pm.NUTS([a])
+        step2 = pm.Metropolis([model.b_log__])
+
+    step = pm.CompoundStep([step1, step2])
+
+    proc = ps.ProcessAdapter(10, 10, step, chain=3, seed=1,
+                             start={'a': 1., 'b_log__': 2.})
+    proc.start()
+    proc.write_next()
+    proc.abort()
+    proc.join()
+
+
+def test_explicit_sample():
+    with pm.Model() as model:
+        a = pm.Normal('a', shape=1)
+        pm.HalfNormal('b')
+        step1 = pm.NUTS([a])
+        step2 = pm.Metropolis([model.b_log__])
+
+    step = pm.CompoundStep([step1, step2])
+
+    start = time.time()
+    proc = ps.ProcessAdapter(10, 10, step, chain=3, seed=1,
+                             start={'a': 1., 'b_log__': 2.})
+    proc.start()
+    while True:
+        proc.write_next()
+        out = ps.ProcessAdapter.recv_draw([proc])
+        view = proc.shared_point_view
+        for name in view:
+            view[name].copy()
+        if out[1]:
+            break
+    proc.join()
+    print(time.time() - start)
+
+
+def test_iterator():
+    with pm.Model() as model:
+        a = pm.Normal('a', shape=1)
+        pm.HalfNormal('b')
+        step1 = pm.NUTS([a])
+        step2 = pm.Metropolis([model.b_log__])
+
+    step = pm.CompoundStep([step1, step2])
+
+    start = time.time()
+    start = {'a': 1., 'b_log__': 2.}
+    sampler = ps.ParallelSampler(10, 10, 3, 2, [2, 3, 4], [start] * 3,
+                                 step, 0, False)
+    with sampler:
+        for draw in sampler:
+            pass
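
The tests above drive the adapter by hand; the underlying mechanism of this patch — a `Pipe` for control messages plus a `RawArray` of shared memory for sample values, wrapped in numpy arrays on both sides — can also be seen in isolation. The following standalone sketch is illustrative only (the `_worker` function, the single-float payload, and the message names are invented for the demo); it mirrors what `_Process._make_numpy_refs` and `ProcessAdapter.write_next` do:

    import multiprocessing
    import multiprocessing.sharedctypes

    import numpy as np


    def _worker(pipe, shared):
        # Like _Process._make_numpy_refs: wrap the raw shared buffer in a
        # numpy array, so writes become visible to the parent without pickling.
        values = np.frombuffer(shared, dtype='float64')
        if pipe.recv()[0] == 'write_next':
            values[...] = 42.0
            pipe.send(('writing_done',))


    if __name__ == '__main__':
        pipe, remote = multiprocessing.Pipe()
        # One float64 = 8 bytes of raw shared memory ('c' = char elements)
        shared = multiprocessing.sharedctypes.RawArray('c', 8)
        proc = multiprocessing.Process(target=_worker, args=(remote, shared))
        proc.start()
        pipe.send(('write_next',))
        assert pipe.recv()[0] == 'writing_done'
        print(np.frombuffer(shared, dtype='float64'))  # [42.]
        proc.join()
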
From 08603a4998a71cf3a51c366667be9721e4933536 Mon Sep 17 00:00:00 2001
From: Adrian Seyboldt
Date: Thu, 7 Jun 2018 15:55:28 +0200
Subject: [PATCH 02/11] Use tqdm_notebook

---
 pymc3/parallel_sampling.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/pymc3/parallel_sampling.py b/pymc3/parallel_sampling.py
index 48aae1bc94d..d1b75a6ef9d 100644
--- a/pymc3/parallel_sampling.py
+++ b/pymc3/parallel_sampling.py
@@ -7,7 +7,6 @@
 from collections import namedtuple
 
 import six
-import tqdm
 import numpy as np
 
 from . import theanof
@@ -222,7 +221,15 @@ def terminate_all(processes, patience=2):
 
 class ParallelSampler(object):
     def __init__(self, draws, tune, chains, cores, seeds, start_points,
-                 step_method, start_chain_num=0, progressbar=True):
+                 step_method, start_chain_num=0, progressbar=True,
+                 notebook=True):
+        if progressbar and notebook:
+            import tqdm
+            tqdm_ = tqdm.tqdm_notebook
+        elif progressbar:
+            import tqdm
+            tqdm_ = tqdm.tqdm
+
         self._samplers = [
             ProcessAdapter(draws, tune, step_method,
                            chain + start_chain_num, seed, start)
@@ -239,10 +246,10 @@ def __init__(self, draws, tune, chains, cores, seeds, start_points,
 
         self._global_progress = self._progress = None
         if progressbar:
-            self._global_progress = tqdm.tqdm(
+            self._global_progress = tqdm_(
                 total=chains, unit='chains', position=1)
             self._progress = [
-                tqdm.tqdm(
+                tqdm_(
                     desc=' Chain %i' % (chain + start_chain_num),
                     unit='draws',
                     position=chain + 2,

From 51f1c8bc9f0d858e649b761347bc6263d0455712 Mon Sep 17 00:00:00 2001
From: Adrian Seyboldt
Date: Thu, 7 Jun 2018 16:53:58 +0200
Subject: [PATCH 03/11] Sample lower chains first

---
 pymc3/parallel_sampling.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/pymc3/parallel_sampling.py b/pymc3/parallel_sampling.py
index d1b75a6ef9d..1531b015cfe 100644
--- a/pymc3/parallel_sampling.py
+++ b/pymc3/parallel_sampling.py
@@ -247,19 +247,19 @@ def __init__(self, draws, tune, chains, cores, seeds, start_points,
         self._global_progress = self._progress = None
         if progressbar:
             self._global_progress = tqdm_(
-                total=chains, unit='chains', position=1)
+                total=chains, unit='chains', position=0)
             self._progress = [
                 tqdm_(
                     desc=' Chain %i' % (chain + start_chain_num),
                     unit='draws',
-                    position=chain + 2,
+                    position=chain + 1,
                     total=draws + tune)
                 for chain in range(chains)
             ]
 
     def _make_active(self):
         while self._inactive and len(self._active) < self._max_active:
-            proc = self._inactive.pop()
+            proc = self._inactive.pop(0)
             proc.start()
             proc.write_next()
             self._active.append(proc)
@@ -282,6 +282,8 @@ def __iter__(self):
                 self._make_active()
                 if self._global_progress is not None:
                     self._global_progress.update()
+                if self._progress is not None:
+                    self._progress[proc.chain - self._start_chain_num].close()
 
             yield Draw(
                 proc.chain, is_last, draw, tuning,

From f1228721a1e138c72218e06a6870911ffcde8b16 Mon Sep 17 00:00:00 2001
From: Adrian Seyboldt
Date: Thu, 7 Jun 2018 17:54:36 +0200
Subject: [PATCH 04/11] Copy shared memory before yielding sampling

---
 pymc3/parallel_sampling.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/pymc3/parallel_sampling.py b/pymc3/parallel_sampling.py
index 1531b015cfe..a90c36686cf 100644
--- a/pymc3/parallel_sampling.py
+++ b/pymc3/parallel_sampling.py
@@ -285,15 +285,15 @@ def __iter__(self):
                 if self._progress is not None:
                     self._progress[proc.chain - self._start_chain_num].close()
 
-            yield Draw(
-                proc.chain, is_last, draw, tuning,
-                stats, proc.shared_point_view
-            )
+            point = {name: val.copy()
+                     for name, val in proc.shared_point_view.items()}
 
             # Already called for new proc in _make_active
             if not is_last:
                 proc.write_next()
 
+            yield Draw(proc.chain, is_last, draw, tuning, stats, point)
+
     def __enter__(self):
         self._in_context = True
         return self
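
The copy introduced in patch 04 matters because each worker keeps reusing the same shared buffer: once `write_next` is sent for the following draw, any view of that buffer yielded earlier would silently change under the consumer. A minimal illustration of the aliasing hazard (plain numpy, no multiprocessing; all names are invented for the demo):

    import numpy as np

    buf = np.zeros(3)      # stand-in for one entry of shared_point_view

    view = buf             # what yielding the view directly would hand out
    snapshot = buf.copy()  # what the sampler yields after this patch

    buf[...] = 1.0         # the worker writes the next draw into the buffer
    print(view)            # [1. 1. 1.] -- the earlier draw changed underfoot
    print(snapshot)        # [0. 0. 0.] -- the copy is stable
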
From dd21cc4e78bc81000d97f44abc0210878bc84f2e Mon Sep 17 00:00:00 2001
From: Adrian Seyboldt
Date: Fri, 8 Jun 2018 14:11:05 +0200
Subject: [PATCH 05/11] Return partial traces if sampling is interrupted

---
 pymc3/backends/text.py     |  5 +--
 pymc3/parallel_sampling.py | 15 +++++---
 pymc3/sampling.py          | 72 +++++++++++++++++++++++++++++---------
 3 files changed, 70 insertions(+), 22 deletions(-)

diff --git a/pymc3/backends/text.py b/pymc3/backends/text.py
index d5bec95639d..027a748e318 100644
--- a/pymc3/backends/text.py
+++ b/pymc3/backends/text.py
@@ -99,8 +99,9 @@ def record(self, point):
         self._fh.write(','.join(columns) + '\n')
 
     def close(self):
-        self._fh.close()
-        self._fh = None  # Avoid serialization issue.
+        if self._fh is not None:
+            self._fh.close()
+            self._fh = None  # Avoid serialization issue.
 
     # Selection methods
 
diff --git a/pymc3/parallel_sampling.py b/pymc3/parallel_sampling.py
index a90c36686cf..396b9f963fc 100644
--- a/pymc3/parallel_sampling.py
+++ b/pymc3/parallel_sampling.py
@@ -28,7 +28,8 @@ class _Process(multiprocessing.Process):
     We communicate with the main process using a pipe,
     and send finished samples using shared memory.
     """
-    def __init__(self, msg_pipe, step_method, shared_point, draws, tune, seed):
+    def __init__(self, name, msg_pipe, step_method, shared_point,
+                 draws, tune, seed):
         super(_Process, self).__init__(daemon=True)
         self._msg_pipe = msg_pipe
         self._step_method = step_method
@@ -116,6 +117,7 @@ class ProcessAdapter(object):
     """Control a Chain process from the main thread."""
     def __init__(self, draws, tune, step_method, chain, seed, start):
         self.chain = chain
+        process_name = "worker_chain_%s" % chain
         self._msg_pipe, remote_conn = multiprocessing.Pipe()
 
         self._shared_point = {}
@@ -138,7 +140,8 @@ def __init__(self, draws, tune, step_method, chain, seed, start):
         self._num_samples = 0
 
         self._process = _Process(
-            remote_conn, step_method, self._shared_point, draws, tune, seed)
+            process_name, remote_conn, step_method, self._shared_point,
+            draws, tune, seed)
         # We fork right away, so that the main process can start tqdm threads
         self._process.start()
 
@@ -185,7 +188,7 @@ def recv_draw(processes, timeout=3600):
         elif msg[0] == 'writing_done':
             proc._readable = True
             proc._num_samples += 1
-            return (proc, *msg[1:])
+            return (proc,) + msg[1:]
         else:
             raise ValueError('Sampler sent bad message.')
 
@@ -200,7 +203,7 @@ def terminate_all(processes, patience=2):
         start_time = time.time()
         try:
             for process in processes:
-                timeout = start_time + patience - time.time()
+                timeout = time.time() + patience - start_time
                 if timeout < 0:
                     raise multiprocessing.TimeoutError()
                 process.join(timeout)
@@ -285,6 +288,10 @@ def __iter__(self):
                 if self._progress is not None:
                     self._progress[proc.chain - self._start_chain_num].close()
 
+            # We could also yield proc.shared_point_view directly,
+            # and only call proc.write_next() after the yield returns.
+            # This seems to be faster overall though, as the worker
+            # loses less time waiting.
             point = {name: val.copy()
                      for name, val in proc.shared_point_view.items()}
 
diff --git a/pymc3/sampling.py b/pymc3/sampling.py
index ea2cae3d232..245ceed823e 100644
--- a/pymc3/sampling.py
+++ b/pymc3/sampling.py
@@ -966,8 +966,6 @@ def _choose_backend(trace, chain, shortcuts=None, **kwds):
 
 
 def _mp_sample(**kwargs):
-    import sys
-
     cores = kwargs.pop('cores')
     chain = kwargs.pop('chain')
     rseed = kwargs.pop('random_seed')
@@ -978,15 +976,21 @@ def _mp_sample(**kwargs):
     step = kwargs.pop('step')
     progressbar = kwargs.pop('progressbar')
     use_mmap = kwargs.pop('use_mmap')
+    model = kwargs.pop('model', None)
+    trace = kwargs.pop('trace', None)
 
     if sys.version_info.major >= 3:
         import pymc3.parallel_sampling as ps
 
-        model = modelcontext(kwargs.pop('model', None))
-        trace = kwargs.pop('trace', None)
+        # We did draws += tune in pm.sample
+        draws -= tune
+
         traces = []
         for idx in range(chain, chain + chains):
-            strace = _choose_backend(trace, idx, model=model)
+            if trace is not None:
+                strace = _choose_backend(copy(trace), idx, model=model)
+            else:
+                strace = _choose_backend(None, idx, model=model)
             # TODO what is this for?
             update_start_vals(start[idx - chain], model.test_point, model)
             if step.generates_stats and strace.supports_sampler_stats:
@@ -997,20 +1001,27 @@ def _mp_sample(**kwargs):
 
         sampler = ps.ParallelSampler(
             draws, tune, chains, cores, rseed, start, step, chain, progressbar)
-        with sampler:
-            for draw in sampler:
-                trace = traces[draw.chain - chain]
-                if trace.supports_sampler_stats and draw.stats is not None:
-                    trace.record(draw.point, draw.stats)
-                else:
-                    trace.record(draw.point)
-                if draw.is_last:
-                    trace.close()
-        return MultiTrace(traces)
+        try:
+            with sampler:
+                for draw in sampler:
+                    trace = traces[draw.chain - chain]
+                    if trace.supports_sampler_stats and draw.stats is not None:
+                        trace.record(draw.point, draw.stats)
+                    else:
+                        trace.record(draw.point)
+                    if draw.is_last:
+                        trace.close()
+            return MultiTrace(traces)
+        except KeyboardInterrupt:
+            traces, length = _choose_chains(traces, tune)
+            return MultiTrace(traces)[:length]
+        finally:
+            for trace in traces:
+                trace.close()
 
     else:
         chain_nums = list(range(chain, chain + chains))
-        pbars = [kwargs.pop('progressbar')] + [False] * (chains - 1)
+        pbars = [progressbar] + [False] * (chains - 1)
         jobs = (delayed(_sample)(*args, **kwargs)
                 for args in zip(chain_nums, pbars, rseed, start))
         if use_mmap:
@@ -1020,6 +1031,35 @@ def _mp_sample(**kwargs):
         return MultiTrace(traces)
 
 
+def _choose_chains(traces, tune):
+    if tune is None:
+        tune = 0
+
+    if not traces:
+        return []
+
+    lengths = [max(0, len(trace) - tune) for trace in traces]
+    if not sum(lengths):
+        raise ValueError('Not enough samples to build a trace.')
+
+    idxs = np.argsort(lengths)[::-1]
+    l_sort = np.array(lengths)[idxs]
+
+    final_length = l_sort[0]
+    last_total = 0
+    for i, length in enumerate(l_sort):
+        total = (i + 1) * length
+        if total < last_total:
+            use_until = i
+            break
+        last_total = total
+        final_length = length
+    else:
+        use_until = len(lengths)
+
+    return [traces[idx] for idx in idxs[:use_until]], final_length + tune
+
+
 def stop_tuning(step):
     """ stop tuning the current step method """
 
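
`_choose_chains` decides, after an interrupt, how many of the unequal-length chains to keep: it sorts the post-tuning lengths in descending order and keeps the prefix of chains that maximizes chains-kept times common-length. A worked example of that criterion (a sketch only; the function itself scans greedily and re-adds `tune` to the returned cut-off):

    # Three chains interrupted after different numbers of post-tuning draws:
    lengths = [300, 300, 50]

    best = max(
        (k * sorted(lengths, reverse=True)[k - 1], k)
        for k in range(1, len(lengths) + 1)
    )
    print(best)  # (600, 2): keep two chains of 300 draws, drop the short one
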
From 03b10dec480792e91b82b11383d71a6d7f741adf Mon Sep 17 00:00:00 2001
From: Adrian Seyboldt
Date: Fri, 8 Jun 2018 15:16:46 +0200
Subject: [PATCH 06/11] Fix tests

---
 pymc3/parallel_sampling.py            |  2 +-
 pymc3/sampling.py                     | 27 +++++++++++----------------
 pymc3/tests/test_parallel_sampling.py |  8 ++++++++
 3 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/pymc3/parallel_sampling.py b/pymc3/parallel_sampling.py
index 396b9f963fc..5338251af31 100644
--- a/pymc3/parallel_sampling.py
+++ b/pymc3/parallel_sampling.py
@@ -30,7 +30,7 @@ class _Process(multiprocessing.Process):
     """
     def __init__(self, name, msg_pipe, step_method, shared_point,
                  draws, tune, seed):
-        super(_Process, self).__init__(daemon=True)
+        super(_Process, self).__init__(daemon=True, name=name)
         self._msg_pipe = msg_pipe
         self._step_method = step_method
         self._shared_point = shared_point
diff --git a/pymc3/sampling.py b/pymc3/sampling.py
index 245ceed823e..55a884dcbf4 100644
--- a/pymc3/sampling.py
+++ b/pymc3/sampling.py
@@ -965,19 +965,9 @@ def _choose_backend(trace, chain, shortcuts=None, **kwds):
         raise ValueError('Argument `trace` is invalid.')
 
 
-def _mp_sample(**kwargs):
-    cores = kwargs.pop('cores')
-    chain = kwargs.pop('chain')
-    rseed = kwargs.pop('random_seed')
-    start = kwargs.pop('start')
-    chains = kwargs.pop('chains')
-    draws = kwargs.pop('draws')
-    tune = kwargs.pop('tune')
-    step = kwargs.pop('step')
-    progressbar = kwargs.pop('progressbar')
-    use_mmap = kwargs.pop('use_mmap')
-    model = kwargs.pop('model', None)
-    trace = kwargs.pop('trace', None)
+def _mp_sample(draws, tune, step, chains, cores, chain, random_seed,
+               start, progressbar, trace=None, model=None, use_mmap=False,
+               **kwargs):
 
     if sys.version_info.major >= 3:
         import pymc3.parallel_sampling as ps
@@ -1000,7 +990,8 @@ def _mp_sample(**kwargs):
             traces.append(strace)
 
         sampler = ps.ParallelSampler(
-            draws, tune, chains, cores, rseed, start, step, chain, progressbar)
+            draws, tune, chains, cores, random_seed, start, step,
+            chain, progressbar)
         try:
             with sampler:
                 for draw in sampler:
@@ -1022,8 +1013,12 @@ def _mp_sample(**kwargs):
     else:
         chain_nums = list(range(chain, chain + chains))
         pbars = [progressbar] + [False] * (chains - 1)
-        jobs = (delayed(_sample)(*args, **kwargs)
-                for args in zip(chain_nums, pbars, rseed, start))
+        jobs = (delayed(_sample)(
+            chain=args[0], progressbar=args[1], random_seed=args[2],
+            start=args[3], draws=draws, step=step, trace=trace,
+            tune=tune, model=model, **kwargs
+        )
+            for args in zip(chain_nums, pbars, random_seed, start))
         if use_mmap:
             traces = Parallel(n_jobs=cores)(jobs)
         else:
diff --git a/pymc3/tests/test_parallel_sampling.py b/pymc3/tests/test_parallel_sampling.py
index 942d0fdbd22..515c130a90c 100644
--- a/pymc3/tests/test_parallel_sampling.py
+++ b/pymc3/tests/test_parallel_sampling.py
@@ -1,9 +1,13 @@
 import time
+import sys
 
+import pytest
 import pymc3.parallel_sampling as ps
 import pymc3 as pm
 
 
+@pytest.mark.skipif(sys.version_info < (3,3),
+                    reason="requires python3.3")
 def test_abort():
     with pm.Model() as model:
         a = pm.Normal('a', shape=1)
@@ -21,6 +25,8 @@ def test_abort():
     proc.join()
 
 
+@pytest.mark.skipif(sys.version_info < (3,3),
+                    reason="requires python3.3")
 def test_explicit_sample():
     with pm.Model() as model:
         a = pm.Normal('a', shape=1)
@@ -46,6 +52,8 @@ def test_explicit_sample():
     print(time.time() - start)
 
 
+@pytest.mark.skipif(sys.version_info < (3,3),
+                    reason="requires python3.3")
 def test_iterator():
     with pm.Model() as model:
         a = pm.Normal('a', shape=1)
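
With the explicit signature from patch 06, the Python 2 fallback forwards arguments to `_sample` by keyword instead of relying on positional order. The shape of that joblib call can be sketched as follows — a hedged, standalone approximation in which `_sample` is a stand-in stub (the real `_sample` lives in `pymc3/sampling.py`), assuming joblib, which provides the `Parallel`/`delayed` used in this file:

    from joblib import Parallel, delayed


    def _sample(chain, progressbar, random_seed, start, **kwargs):
        # Stub standing in for pymc3's single-chain sampler
        return 'chain %d (seed %d)' % (chain, random_seed)


    jobs = (
        delayed(_sample)(chain=c, progressbar=(c == 0),
                         random_seed=s, start={})
        for c, s in zip([0, 1], [10, 11])
    )
    print(Parallel(n_jobs=2)(jobs))  # ['chain 0 (seed 10)', 'chain 1 (seed 11)']
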
From 212ff07f8740e2d82cae0cb778883c3090913135 Mon Sep 17 00:00:00 2001
From: Adrian Seyboldt
Date: Tue, 12 Jun 2018 18:19:35 +0200
Subject: [PATCH 07/11] Add warnings in new multiprocessing sampling

---
 pymc3/parallel_sampling.py         | 18 ++++++++++++++----
 pymc3/sampling.py                  | 20 ++++++++++++--------
 pymc3/step_methods/compound.py     |  4 ++--
 pymc3/step_methods/hmc/base_hmc.py |  2 +-
 pymc3/step_methods/hmc/nuts.py     |  4 ++--
 5 files changed, 31 insertions(+), 17 deletions(-)

diff --git a/pymc3/parallel_sampling.py b/pymc3/parallel_sampling.py
index 5338251af31..c1381156dd5 100644
--- a/pymc3/parallel_sampling.py
+++ b/pymc3/parallel_sampling.py
@@ -98,8 +98,12 @@ def _start_loop(self):
             elif msg[0] == 'write_next':
                 self._write_point(point)
                 is_last = draw + 1 == self._draws + self._tune
+                if is_last:
+                    warns = self._collect_warnings()
+                else:
+                    warns = None
                 self._msg_pipe.send(
-                    ('writing_done', is_last, draw, tuning, stats))
+                    ('writing_done', is_last, draw, tuning, stats, warns))
                 draw += 1
             else:
                 raise ValueError('Unknown message ' + msg[0])
@@ -112,6 +116,12 @@ def _compute_point(self):
             stats = None
         return point, stats
 
+    def _collect_warnings(self):
+        if hasattr(self._step_method, 'warnings'):
+            return self._step_method.warnings()
+        else:
+            return []
+
 
 class ProcessAdapter(object):
     """Control a Chain process from the main thread."""
@@ -218,7 +228,7 @@ def terminate_all(processes, patience=2):
 
 Draw = namedtuple(
     'Draw',
-    ['chain', 'is_last', 'draw_idx', 'tuning', 'stats', 'point']
+    ['chain', 'is_last', 'draw_idx', 'tuning', 'stats', 'point', 'warnings']
 )
 
 
@@ -274,7 +284,7 @@ def __iter__(self):
 
         while self._active:
             draw = ProcessAdapter.recv_draw(self._active)
-            proc, is_last, draw, tuning, stats = draw
+            proc, is_last, draw, tuning, stats, warns = draw
             if self._progress is not None:
                 self._progress[proc.chain - self._start_chain_num].update()
 
@@ -299,7 +309,7 @@ def __iter__(self):
             if not is_last:
                 proc.write_next()
 
-            yield Draw(proc.chain, is_last, draw, tuning, stats, point)
+            yield Draw(proc.chain, is_last, draw, tuning, stats, point, warns)
 
     def __enter__(self):
         self._in_context = True
diff --git a/pymc3/sampling.py b/pymc3/sampling.py
index 55a884dcbf4..133abd1549b 100644
--- a/pymc3/sampling.py
+++ b/pymc3/sampling.py
@@ -663,7 +663,7 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
     except KeyboardInterrupt:
         strace.close()
         if hasattr(step, 'warnings'):
-            warns = step.warnings(strace)
+            warns = step.warnings()
             strace._add_warnings(warns)
         raise
     except BaseException:
@@ -672,7 +672,7 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
     else:
         strace.close()
         if hasattr(step, 'warnings'):
-            warns = step.warnings(strace)
+            warns = step.warnings()
             strace._add_warnings(warns)
 
 
@@ -1002,6 +1002,8 @@ def _mp_sample(draws, tune, step, chains, cores, chain, random_seed,
                         trace.record(draw.point)
                     if draw.is_last:
                         trace.close()
+                        if draw.warnings is not None:
+                            trace._add_warnings(draw.warnings)
             return MultiTrace(traces)
         except KeyboardInterrupt:
@@ -1013,12 +1015,14 @@ def _mp_sample(draws, tune, step, chains, cores, chain, random_seed,
     else:
         chain_nums = list(range(chain, chain + chains))
         pbars = [progressbar] + [False] * (chains - 1)
-        jobs = (delayed(_sample)(
-            chain=args[0], progressbar=args[1], random_seed=args[2],
-            start=args[3], draws=draws, step=step, trace=trace,
-            tune=tune, model=model, **kwargs
-        )
-            for args in zip(chain_nums, pbars, random_seed, start))
+        jobs = (
+            delayed(_sample)(
+                chain=args[0], progressbar=args[1], random_seed=args[2],
+                start=args[3], draws=draws, step=step, trace=trace,
+                tune=tune, model=model, **kwargs
+            )
+            for args in zip(chain_nums, pbars, random_seed, start)
+        )
         if use_mmap:
             traces = Parallel(n_jobs=cores)(jobs)
         else:
diff --git a/pymc3/step_methods/compound.py b/pymc3/step_methods/compound.py
index 7b089a1104a..56b98ac4575 100644
--- a/pymc3/step_methods/compound.py
+++ b/pymc3/step_methods/compound.py
@@ -33,11 +33,11 @@ def step(self, point):
             point = method.step(point)
         return point
 
-    def warnings(self, strace):
+    def warnings(self):
         warns = []
         for method in self.methods:
             if hasattr(method, 'warnings'):
-                warns.extend(method.warnings(strace))
+                warns.extend(method.warnings())
         return warns
 
     def stop_tuning(self):
diff --git a/pymc3/step_methods/hmc/base_hmc.py b/pymc3/step_methods/hmc/base_hmc.py
index 68d0627e8c1..6fd1d81885d 100644
--- a/pymc3/step_methods/hmc/base_hmc.py
+++ b/pymc3/step_methods/hmc/base_hmc.py
@@ -164,7 +164,7 @@ def reset(self, start=None):
         self.tune = True
         self.potential.reset()
 
-    def warnings(self, strace):
+    def warnings(self):
         # list.copy() is not available in python2
         warnings = self._warnings[:]
 
diff --git a/pymc3/step_methods/hmc/nuts.py b/pymc3/step_methods/hmc/nuts.py
index f4925def3c8..37f59c45d94 100644
--- a/pymc3/step_methods/hmc/nuts.py
+++ b/pymc3/step_methods/hmc/nuts.py
@@ -184,8 +184,8 @@ def competence(var, has_grad):
             return Competence.IDEAL
         return Competence.INCOMPATIBLE
 
-    def warnings(self, strace):
-        warnings = super(NUTS, self).warnings(strace)
+    def warnings(self):
+        warnings = super(NUTS, self).warnings()
 
         n_samples = self._samples_after_tune
         n_treedepth = self._reached_max_treedepth
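
After patch 07, divergence and tree-depth warnings no longer require the worker's trace: the step method collects them itself, they ride back on the final `writing_done` message, and they land in the new `warnings` field of `Draw`. A small self-contained illustration of the consumer side (the field values here are made up):

    from collections import namedtuple

    # Mirrors the extended namedtuple from the patch:
    Draw = namedtuple(
        'Draw',
        ['chain', 'is_last', 'draw_idx', 'tuning', 'stats', 'point', 'warnings']
    )

    draw = Draw(chain=0, is_last=True, draw_idx=19, tuning=False,
                stats=None, point={'a': 1.0}, warnings=[])
    if draw.is_last and draw.warnings is not None:
        # This is where _mp_sample calls trace._add_warnings(draw.warnings)
        print('chain %d finished with %d warnings'
              % (draw.chain, len(draw.warnings)))
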
From ebb3b3eb77b2afa49ca559a23c4299baa02417f7 Mon Sep 17 00:00:00 2001
From: Adrian Seyboldt
Date: Tue, 12 Jun 2018 18:20:14 +0200
Subject: [PATCH 08/11] Show one progress bar for all chains

---
 pymc3/parallel_sampling.py | 33 ++++++++-------------------------
 1 file changed, 8 insertions(+), 25 deletions(-)

diff --git a/pymc3/parallel_sampling.py b/pymc3/parallel_sampling.py
index c1381156dd5..e62eeedc303 100644
--- a/pymc3/parallel_sampling.py
+++ b/pymc3/parallel_sampling.py
@@ -234,12 +234,8 @@ def terminate_all(processes, patience=2):
 
 class ParallelSampler(object):
     def __init__(self, draws, tune, chains, cores, seeds, start_points,
-                 step_method, start_chain_num=0, progressbar=True,
-                 notebook=True):
-        if progressbar and notebook:
-            import tqdm
-            tqdm_ = tqdm.tqdm_notebook
-        elif progressbar:
+                 step_method, start_chain_num=0, progressbar=True):
+        if progressbar:
             import tqdm
             tqdm_ = tqdm.tqdm
 
@@ -257,18 +253,11 @@ def __init__(self, draws, tune, chains, cores, seeds, start_points,
         self._in_context = False
         self._start_chain_num = start_chain_num
 
-        self._global_progress = self._progress = None
+        self._progress = None
         if progressbar:
-            self._global_progress = tqdm_(
-                total=chains, unit='chains', position=0)
-            self._progress = [
-                tqdm_(
-                    desc=' Chain %i' % (chain + start_chain_num),
-                    unit='draws',
-                    position=chain + 1,
-                    total=draws + tune)
-                for chain in range(chains)
-            ]
+            self._progress = tqdm_(
+                total=chains * (draws + tune), unit='draws',
+                desc='Sampling %s chains' % chains)
 
     def _make_active(self):
         while self._inactive and len(self._active) < self._max_active:
@@ -286,17 +275,13 @@ def __iter__(self):
             draw = ProcessAdapter.recv_draw(self._active)
             proc, is_last, draw, tuning, stats, warns = draw
             if self._progress is not None:
-                self._progress[proc.chain - self._start_chain_num].update()
+                self._progress.update()
 
             if is_last:
                 proc.join()
                 self._active.remove(proc)
                 self._finished.append(proc)
                 self._make_active()
-                if self._global_progress is not None:
-                    self._global_progress.update()
-                if self._progress is not None:
-                    self._progress[proc.chain - self._start_chain_num].close()
 
             # We could also yield proc.shared_point_view directly,
             # and only call proc.write_next() after the yield returns.
@@ -318,6 +303,4 @@ def __enter__(self):
 
     def __exit__(self, *args):
         ProcessAdapter.terminate_all(self._samplers)
         if self._progress is not None:
-            self._global_progress.close()
-            for progress in self._progress:
-                progress.close()
+            self._progress.close()

From 145856deb04cbab5d12a45f498a371ee43904e1d Mon Sep 17 00:00:00 2001
From: Adrian Seyboldt
Date: Tue, 12 Jun 2018 19:49:52 +0200
Subject: [PATCH 09/11] Better remote exception formatting

---
 pymc3/parallel_sampling.py | 36 +++++++++++++++++++++++++++++-----
 1 file changed, 31 insertions(+), 5 deletions(-)

diff --git a/pymc3/parallel_sampling.py b/pymc3/parallel_sampling.py
index e62eeedc303..193a5d62e7a 100644
--- a/pymc3/parallel_sampling.py
+++ b/pymc3/parallel_sampling.py
@@ -1,10 +1,10 @@
 import multiprocessing
 import multiprocessing.sharedctypes
-import sys
 import ctypes
 import time
 import logging
 from collections import namedtuple
+import traceback
 
 import six
 import numpy as np
@@ -13,6 +13,32 @@
 
 logger = logging.getLogger('pymc3')
 
+
+# Taken from https://hg.python.org/cpython/rev/c4f92b597074
+class RemoteTraceback(Exception):
+    def __init__(self, tb):
+        self.tb = tb
+
+    def __str__(self):
+        return self.tb
+
+
+class ExceptionWithTraceback:
+    def __init__(self, exc, tb):
+        tb = traceback.format_exception(type(exc), exc, tb)
+        tb = ''.join(tb)
+        self.exc = exc
+        self.tb = '\n"""\n%s"""' % tb
+
+    def __reduce__(self):
+        return rebuild_exc, (self.exc, self.tb)
+
+
+def rebuild_exc(exc, tb):
+    exc.__cause__ = RemoteTraceback(tb)
+    return exc
+
+
 # Messages
 # ('writing_done', is_last, sample_idx, tuning, stats)
 # ('error', *exception_info)
@@ -47,9 +73,9 @@ def run(self):
             self._start_loop()
         except KeyboardInterrupt:
             pass
-        except BaseException:
-            exc_info = sys.exc_info()
-            self._msg_pipe.send(('error', exc_info[:2]))
+        except BaseException as e:
+            e = ExceptionWithTraceback(e, e.__traceback__)
+            self._msg_pipe.send(('error', e))
         finally:
             self._msg_pipe.close()
 
@@ -193,7 +219,7 @@ def recv_draw(processes, timeout=3600):
         msg = ready[0].recv()
 
         if msg[0] == 'error':
-            old = msg[1][1]#.with_traceback(msg[1][2])
+            old = msg[1]
             six.raise_from(RuntimeError('Chain %s failed.' % proc.chain), old)
         elif msg[0] == 'writing_done':
             proc._readable = True
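
The effect of patch 09 is that a crash in a worker now reaches the main process as the original exception with the remote traceback chained onto it, instead of a bare two-element `exc_info`. The pickle round-trip can be demonstrated with the three helpers from the patch, restated here so the snippet runs standalone:

    import pickle
    import traceback


    class RemoteTraceback(Exception):
        def __init__(self, tb):
            self.tb = tb

        def __str__(self):
            return self.tb


    def rebuild_exc(exc, tb):
        exc.__cause__ = RemoteTraceback(tb)
        return exc


    class ExceptionWithTraceback:
        def __init__(self, exc, tb):
            self.exc = exc
            self.tb = '\n"""\n%s"""' % ''.join(
                traceback.format_exception(type(exc), exc, tb))

        def __reduce__(self):
            return rebuild_exc, (self.exc, self.tb)


    try:
        raise ValueError('boom in worker')
    except ValueError as exc:
        wrapped = ExceptionWithTraceback(exc, exc.__traceback__)

    # Pickling goes through __reduce__, so unpickling in the main process
    # rebuilds the original exception with the formatted worker traceback
    # attached as its __cause__:
    err = pickle.loads(pickle.dumps(wrapped))
    print(type(err).__name__)  # ValueError
    print(err.__cause__)       # the full remote traceback, as text
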
From c7b43b4c39181232d05e665da92589895cdc83e7 Mon Sep 17 00:00:00 2001
From: Adrian Seyboldt
Date: Wed, 13 Jun 2018 18:37:54 +0200
Subject: [PATCH 10/11] Make posterior tests more stable

---
 pymc3/tests/sampler_fixtures.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pymc3/tests/sampler_fixtures.py b/pymc3/tests/sampler_fixtures.py
index 0a5e9e48c25..78f7ee8c52b 100644
--- a/pymc3/tests/sampler_fixtures.py
+++ b/pymc3/tests/sampler_fixtures.py
@@ -82,13 +82,13 @@ def make_model(cls):
 
 class StudentTFixture(KnownMean, KnownCDF):
     means = {'a': 0}
-    cdfs = {'a': stats.t(df=3).cdf}
+    cdfs = {'a': stats.t(df=4).cdf}
     ks_thin = 10
 
     @classmethod
     def make_model(cls):
         with pm.Model() as model:
-            a = pm.StudentT("a", nu=3, mu=0, sd=1)
+            a = pm.StudentT("a", nu=4, mu=0, sd=1)
         return model
 

From ae1025bd4c865e515762bf30fce99ced005ef7b6 Mon Sep 17 00:00:00 2001
From: Adrian Seyboldt
Date: Thu, 14 Jun 2018 12:05:36 +0200
Subject: [PATCH 11/11] Add release notes for multiproc rewrite

---
 RELEASE-NOTES.md | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md
index 5b59cf88cae..bf814c9ff5c 100644
--- a/RELEASE-NOTES.md
+++ b/RELEASE-NOTES.md
@@ -13,6 +13,11 @@
 - Improve error message `NaN occurred in optimization.` during ADVI
 - Save and load traces without `pickle` using `pm.save_trace` and `pm.load_trace`
 - Add `Kumaraswamy` distribution
+- Rewrite parallel sampling of multiple chains on py3. This resolves
+  long-standing issues when transferring large traces to the main process,
+  avoids pickling issues on UNIX, and allows us to show a progress bar
+  for all chains. If parallel sampling is interrupted, we now return
+  partial results.
 
 ### Fixes
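
Taken together, the series changes what a user sees from a plain `pm.sample` call on Python 3: one progress bar covering all chains, real tracebacks from failed chains, and a partial trace instead of nothing on Ctrl-C. A usage sketch under those assumptions (the model and the draw counts are arbitrary; the `chains`/`cores` keywords match the arguments `_mp_sample` receives above):

    import pymc3 as pm

    with pm.Model():
        mu = pm.Normal('mu', mu=0, sd=1)
        pm.Normal('obs', mu=mu, sd=1, observed=[0.1, -0.3, 0.2])
        # Runs through ParallelSampler: 4 chains on 4 worker processes.
        # Interrupting with Ctrl-C returns the draws collected so far.
        trace = pm.sample(draws=1000, tune=1000, chains=4, cores=4)

    print(trace['mu'].mean())
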