diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md
index 5b59cf88cae..bf814c9ff5c 100644
--- a/RELEASE-NOTES.md
+++ b/RELEASE-NOTES.md
@@ -13,6 +13,11 @@
 - Improve error message `NaN occurred in optimization.` during ADVI
 - Save and load traces without `pickle` using `pm.save_trace` and `pm.load_trace`
 - Add `Kumaraswamy` distribution
+- Rewrite parallel sampling of multiple chains on py3. This resolves
+  long-standing issues when transferring large traces to the main process,
+  avoids pickling issues on UNIX, and allows us to show a progress bar
+  for all chains. If parallel sampling is interrupted, we now return
+  partial results.
 
 ### Fixes
 
diff --git a/pymc3/backends/text.py b/pymc3/backends/text.py
index d5bec95639d..027a748e318 100644
--- a/pymc3/backends/text.py
+++ b/pymc3/backends/text.py
@@ -99,8 +99,9 @@ def record(self, point):
         self._fh.write(','.join(columns) + '\n')
 
     def close(self):
-        self._fh.close()
-        self._fh = None  # Avoid serialization issue.
+        if self._fh is not None:
+            self._fh.close()
+            self._fh = None  # Avoid serialization issue.
 
     # Selection methods
 
diff --git a/pymc3/parallel_sampling.py b/pymc3/parallel_sampling.py
new file mode 100644
index 00000000000..193a5d62e7a
--- /dev/null
+++ b/pymc3/parallel_sampling.py
@@ -0,0 +1,332 @@
+import multiprocessing
+import multiprocessing.sharedctypes
+import ctypes
+import time
+import logging
+from collections import namedtuple
+import traceback
+
+import six
+import numpy as np
+
+from . import theanof
+
+logger = logging.getLogger('pymc3')
+
+
+# Taken from https://hg.python.org/cpython/rev/c4f92b597074
+class RemoteTraceback(Exception):
+    def __init__(self, tb):
+        self.tb = tb
+
+    def __str__(self):
+        return self.tb
+
+
+class ExceptionWithTraceback:
+    def __init__(self, exc, tb):
+        tb = traceback.format_exception(type(exc), exc, tb)
+        tb = ''.join(tb)
+        self.exc = exc
+        self.tb = '\n"""\n%s"""' % tb
+
+    def __reduce__(self):
+        return rebuild_exc, (self.exc, self.tb)
+
+
+def rebuild_exc(exc, tb):
+    exc.__cause__ = RemoteTraceback(tb)
+    return exc
+
+
+# Messages sent by the worker process:
+# ('writing_done', is_last, sample_idx, tuning, stats, warns)
+# ('error', *exception_info)
+# Messages sent by the main process:
+# ('abort',)
+# ('write_next',)
+# ('start',)
+
+
+class _Process(multiprocessing.Process):
+    """Separate process for each chain.
+
+    We communicate with the main process using a pipe,
+    and send finished samples using shared memory.
+    """
+    def __init__(self, name, msg_pipe, step_method, shared_point,
+                 draws, tune, seed):
+        super(_Process, self).__init__(daemon=True, name=name)
+        self._msg_pipe = msg_pipe
+        self._step_method = step_method
+        self._shared_point = shared_point
+        self._seed = seed
+        self._tt_seed = seed + 1
+        self._draws = draws
+        self._tune = tune
+
+    def run(self):
+        try:
+            # We do not create this in __init__, as pickling this
+            # would destroy the shared memory.
+            self._point = self._make_numpy_refs()
+            self._start_loop()
+        except KeyboardInterrupt:
+            pass
+        except BaseException as e:
+            e = ExceptionWithTraceback(e, e.__traceback__)
+            self._msg_pipe.send(('error', e))
+        finally:
+            self._msg_pipe.close()
+
+    def _make_numpy_refs(self):
+        shape_dtypes = self._step_method.vars_shape_dtype
+        point = {}
+        for name, (shape, dtype) in shape_dtypes.items():
+            array = self._shared_point[name]
+            self._shared_point[name] = array
+            point[name] = np.frombuffer(array, dtype).reshape(shape)
+        return point
+
+    def _write_point(self, point):
+        for name, vals in point.items():
+            self._point[name][...] = vals
+
+    def _recv_msg(self):
+        return self._msg_pipe.recv()
+
+    def _start_loop(self):
+        np.random.seed(self._seed)
+        theanof.set_tt_rng(self._tt_seed)
+
+        draw = 0
+        tuning = True
+
+        msg = self._recv_msg()
+        if msg[0] == 'abort':
+            raise KeyboardInterrupt()
+        if msg[0] != 'start':
+            raise ValueError('Unexpected msg ' + msg[0])
+
+        while True:
+            if draw < self._draws + self._tune:
+                point, stats = self._compute_point()
+            else:
+                return
+
+            if draw == self._tune:
+                self._step_method.stop_tuning()
+                tuning = False
+
+            msg = self._recv_msg()
+            if msg[0] == 'abort':
+                raise KeyboardInterrupt()
+            elif msg[0] == 'write_next':
+                self._write_point(point)
+                is_last = draw + 1 == self._draws + self._tune
+                if is_last:
+                    warns = self._collect_warnings()
+                else:
+                    warns = None
+                self._msg_pipe.send(
+                    ('writing_done', is_last, draw, tuning, stats, warns))
+                draw += 1
+            else:
+                raise ValueError('Unknown message ' + msg[0])
+
+    def _compute_point(self):
+        if self._step_method.generates_stats:
+            point, stats = self._step_method.step(self._point)
+        else:
+            point = self._step_method.step(self._point)
+            stats = None
+        return point, stats
+
+    def _collect_warnings(self):
+        if hasattr(self._step_method, 'warnings'):
+            return self._step_method.warnings()
+        else:
+            return []
+
+
+class ProcessAdapter(object):
+    """Control a chain process from the main thread."""
+    def __init__(self, draws, tune, step_method, chain, seed, start):
+        self.chain = chain
+        process_name = "worker_chain_%s" % chain
+        self._msg_pipe, remote_conn = multiprocessing.Pipe()
+
+        self._shared_point = {}
+        self._point = {}
+        for name, (shape, dtype) in step_method.vars_shape_dtype.items():
+            size = 1
+            for dim in shape:
+                size *= int(dim)
+            size *= dtype.itemsize
+            if size != ctypes.c_size_t(size).value:
+                raise ValueError('Variable %s is too large' % name)
+
+            array = multiprocessing.sharedctypes.RawArray('c', size)
+            self._shared_point[name] = array
+            array_np = np.frombuffer(array, dtype).reshape(shape)
+            array_np[...] = start[name]
+            self._point[name] = array_np
+
+        self._readable = True
+        self._num_samples = 0
+
+        self._process = _Process(
+            process_name, remote_conn, step_method, self._shared_point,
+            draws, tune, seed)
+        # We fork right away, so that the main process can start tqdm threads
+        self._process.start()
+
+    @property
+    def shared_point_view(self):
+        """May only be written to or read between a `recv_draw`
+        call from the process and a `write_next` or `abort` call.
+        """
+        if not self._readable:
+            raise RuntimeError()
+        return self._point
+
+    def start(self):
+        self._msg_pipe.send(('start',))
+
+    def write_next(self):
+        self._readable = False
+        self._msg_pipe.send(('write_next',))
+
+    def abort(self):
+        self._msg_pipe.send(('abort',))
+
+    def join(self, timeout=None):
+        self._process.join(timeout)
+
+    def terminate(self):
+        self._process.terminate()
+
+    @staticmethod
+    def recv_draw(processes, timeout=3600):
+        if not processes:
+            raise ValueError('No processes.')
+        pipes = [proc._msg_pipe for proc in processes]
+        ready = multiprocessing.connection.wait(pipes, timeout)
+        if not ready:
+            raise multiprocessing.TimeoutError('No message from samplers.')
+        idxs = {id(proc._msg_pipe): proc for proc in processes}
+        proc = idxs[id(ready[0])]
+        msg = ready[0].recv()
+
+        if msg[0] == 'error':
+            old = msg[1]
+            six.raise_from(RuntimeError('Chain %s failed.' % proc.chain), old)
+        elif msg[0] == 'writing_done':
+            proc._readable = True
+            proc._num_samples += 1
+            return (proc,) + msg[1:]
+        else:
+            raise ValueError('Sampler sent bad message.')
+
+    @staticmethod
+    def terminate_all(processes, patience=2):
+        for process in processes:
+            try:
+                process.abort()
+            except EOFError:
+                pass
+
+        start_time = time.time()
+        try:
+            for process in processes:
+                timeout = start_time + patience - time.time()
+                if timeout < 0:
+                    raise multiprocessing.TimeoutError()
+                process.join(timeout)
+        except multiprocessing.TimeoutError:
+            logger.warning('Chain processes did not terminate as expected. '
+                           'Terminating forcefully...')
+            for process in processes:
+                process.terminate()
+            for process in processes:
+                process.join()
+
+
+Draw = namedtuple(
+    'Draw',
+    ['chain', 'is_last', 'draw_idx', 'tuning', 'stats', 'point', 'warnings']
+)
+
+
+class ParallelSampler(object):
+    def __init__(self, draws, tune, chains, cores, seeds, start_points,
+                 step_method, start_chain_num=0, progressbar=True):
+        if progressbar:
+            import tqdm
+            tqdm_ = tqdm.tqdm
+
+        self._samplers = [
+            ProcessAdapter(draws, tune, step_method,
+                           chain + start_chain_num, seed, start)
+            for chain, seed, start in zip(range(chains), seeds, start_points)
+        ]
+
+        self._inactive = self._samplers.copy()
+        self._finished = []
+        self._active = []
+        self._max_active = cores
+
+        self._in_context = False
+        self._start_chain_num = start_chain_num
+
+        self._progress = None
+        if progressbar:
+            self._progress = tqdm_(
+                total=chains * (draws + tune), unit='draws',
+                desc='Sampling %s chains' % chains)
+
+    def _make_active(self):
+        while self._inactive and len(self._active) < self._max_active:
+            proc = self._inactive.pop(0)
+            proc.start()
+            proc.write_next()
+            self._active.append(proc)
+
+    def __iter__(self):
+        if not self._in_context:
+            raise ValueError('Use ParallelSampler as context manager.')
+        self._make_active()
+
+        while self._active:
+            draw = ProcessAdapter.recv_draw(self._active)
+            proc, is_last, draw, tuning, stats, warns = draw
+            if self._progress is not None:
+                self._progress.update()
+
+            if is_last:
+                proc.join()
+                self._active.remove(proc)
+                self._finished.append(proc)
+                self._make_active()
+
+            # We could also yield proc.shared_point_view directly,
+            # and only call proc.write_next() after the yield returns.
+            # This seems to be faster overall though, as the worker
+            # loses less time waiting.
+            point = {name: val.copy()
+                     for name, val in proc.shared_point_view.items()}
+
+            # Already called for new proc in _make_active
+            if not is_last:
+                proc.write_next()
+
+            yield Draw(proc.chain, is_last, draw, tuning, stats, point, warns)
+
+    def __enter__(self):
+        self._in_context = True
+        return self
+
+    def __exit__(self, *args):
+        ProcessAdapter.terminate_all(self._samplers)
+        if self._progress is not None:
+            self._progress.close()
diff --git a/pymc3/sampling.py b/pymc3/sampling.py
index cfbfa52f016..133abd1549b 100644
--- a/pymc3/sampling.py
+++ b/pymc3/sampling.py
@@ -663,7 +663,7 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
     except KeyboardInterrupt:
         strace.close()
         if hasattr(step, 'warnings'):
-            warns = step.warnings(strace)
+            warns = step.warnings()
             strace._add_warnings(warns)
         raise
     except BaseException:
@@ -672,7 +672,7 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
     else:
         strace.close()
         if hasattr(step, 'warnings'):
-            warns = step.warnings(strace)
+            warns = step.warnings()
             strace._add_warnings(warns)
 
 
@@ -965,36 +965,104 @@ def _choose_backend(trace, chain, shortcuts=None, **kwds):
     raise ValueError('Argument `trace` is invalid.')
 
 
-def _mp_sample(**kwargs):
-    cores = kwargs.pop('cores')
-    chain = kwargs.pop('chain')
-    rseed = kwargs.pop('random_seed')
-    start = kwargs.pop('start')
-    chains = kwargs.pop('chains')
-    use_mmap = kwargs.pop('use_mmap')
+def _mp_sample(draws, tune, step, chains, cores, chain, random_seed,
+               start, progressbar, trace=None, model=None, use_mmap=False,
+               **kwargs):
 
-    chain_nums = list(range(chain, chain + chains))
-    pbars = [kwargs.pop('progressbar')] + [False] * (chains - 1)
-    jobs = (delayed(_sample)(*args, **kwargs)
-            for args in zip(chain_nums, pbars, rseed, start))
+    if sys.version_info.major >= 3:
+        import pymc3.parallel_sampling as ps
+
+        # We did draws += tune in pm.sample
+        draws -= tune
+
+        traces = []
+        for idx in range(chain, chain + chains):
+            if trace is not None:
+                strace = _choose_backend(copy(trace), idx, model=model)
+            else:
+                strace = _choose_backend(None, idx, model=model)
+            # TODO what is this for?
+            update_start_vals(start[idx - chain], model.test_point, model)
+            if step.generates_stats and strace.supports_sampler_stats:
+                strace.setup(draws + tune, idx, step.stats_dtypes)
+            else:
+                strace.setup(draws + tune, idx)
+            traces.append(strace)
+
+        sampler = ps.ParallelSampler(
+            draws, tune, chains, cores, random_seed, start, step,
+            chain, progressbar)
+        try:
+            with sampler:
+                for draw in sampler:
+                    trace = traces[draw.chain - chain]
+                    if trace.supports_sampler_stats and draw.stats is not None:
+                        trace.record(draw.point, draw.stats)
+                    else:
+                        trace.record(draw.point)
+                    if draw.is_last:
+                        trace.close()
+                        if draw.warnings is not None:
+                            trace._add_warnings(draw.warnings)
+            return MultiTrace(traces)
+        except KeyboardInterrupt:
+            traces, length = _choose_chains(traces, tune)
+            return MultiTrace(traces)[:length]
+        finally:
+            for trace in traces:
+                trace.close()
 
-    if use_mmap:
-        traces = Parallel(n_jobs=cores)(jobs)
     else:
-        traces = Parallel(n_jobs=cores, mmap_mode=None)(jobs)
+        chain_nums = list(range(chain, chain + chains))
+        pbars = [progressbar] + [False] * (chains - 1)
+        jobs = (
+            delayed(_sample)(
+                chain=args[0], progressbar=args[1], random_seed=args[2],
+                start=args[3], draws=draws, step=step, trace=trace,
+                tune=tune, model=model, **kwargs
+            )
+            for args in zip(chain_nums, pbars, random_seed, start)
+        )
+        if use_mmap:
+            traces = Parallel(n_jobs=cores)(jobs)
+        else:
+            traces = Parallel(n_jobs=cores, mmap_mode=None)(jobs)
+        return MultiTrace(traces)
 
-    return MultiTrace(traces)
 
+def _choose_chains(traces, tune):
+    if tune is None:
+        tune = 0
 
-def stop_tuning(step):
-    """ stop tuning the current step method """
+    if not traces:
+        return []
+
+    lengths = [max(0, len(trace) - tune) for trace in traces]
+    if not sum(lengths):
+        raise ValueError('Not enough samples to build a trace.')
+
+    idxs = np.argsort(lengths)[::-1]
+    l_sort = np.array(lengths)[idxs]
+
+    final_length = l_sort[0]
+    last_total = 0
+    for i, length in enumerate(l_sort):
+        total = (i + 1) * length
+        if total < last_total:
+            use_until = i
+            break
+        last_total = total
+        final_length = length
+    else:
+        use_until = len(lengths)
+
+    return [traces[idx] for idx in idxs[:use_until]], final_length + tune
 
-    if hasattr(step, 'tune'):
-        step.tune = False
-    if hasattr(step, 'methods'):
-        step.methods = [stop_tuning(s) for s in step.methods]
+def stop_tuning(step):
+    """ stop tuning the current step method """
+    step.stop_tuning()
     return step
diff --git a/pymc3/step_methods/arraystep.py b/pymc3/step_methods/arraystep.py
index 4cf9585e41e..8366bbb5c36 100644
--- a/pymc3/step_methods/arraystep.py
+++ b/pymc3/step_methods/arraystep.py
@@ -87,13 +87,26 @@ def _competence(cls, vars, have_grad):
         vars = np.atleast_1d(vars)
         have_grad = np.atleast_1d(have_grad)
         competences = []
-        for var,has_grad in zip(vars, have_grad):
+        for var, has_grad in zip(vars, have_grad):
             try:
                 competences.append(cls.competence(var, has_grad))
             except TypeError:
                 competences.append(cls.competence(var))
         return competences
 
+    @property
+    def vars_shape_dtype(self):
+        shape_dtypes = {}
+        for var in self.vars:
+            dtype = np.dtype(var.dtype)
+            shape = var.dshape
+            shape_dtypes[var.name] = (shape, dtype)
+        return shape_dtypes
+
+    def stop_tuning(self):
+        if hasattr(self, 'tune'):
+            self.tune = False
+
 
 class ArrayStep(BlockedStep):
     """
diff --git a/pymc3/step_methods/compound.py b/pymc3/step_methods/compound.py
index d8f66171cb9..56b98ac4575 100644
--- a/pymc3/step_methods/compound.py
+++ b/pymc3/step_methods/compound.py
@@ -6,11 +6,13 @@
 
 
 class CompoundStep(object):
- """Step method composed of a list of several other step methods applied in sequence.""" + """Step method composed of a list of several other step + methods applied in sequence.""" def __init__(self, methods): self.methods = list(methods) - self.generates_stats = any(method.generates_stats for method in self.methods) + self.generates_stats = any( + method.generates_stats for method in self.methods) self.stats_dtypes = [] for method in self.methods: if method.generates_stats: @@ -31,9 +33,20 @@ def step(self, point): point = method.step(point) return point - def warnings(self, strace): + def warnings(self): warns = [] for method in self.methods: if hasattr(method, 'warnings'): - warns.extend(method.warnings(strace)) + warns.extend(method.warnings()) return warns + + def stop_tuning(self): + for method in self.methods: + method.stop_tuning() + + @property + def vars_shape_dtype(self): + dtype_shapes = {} + for method in self.methods: + dtype_shapes.update(method.vars_shape_dtype) + return dtype_shapes diff --git a/pymc3/step_methods/hmc/base_hmc.py b/pymc3/step_methods/hmc/base_hmc.py index 68d0627e8c1..6fd1d81885d 100644 --- a/pymc3/step_methods/hmc/base_hmc.py +++ b/pymc3/step_methods/hmc/base_hmc.py @@ -164,7 +164,7 @@ def reset(self, start=None): self.tune = True self.potential.reset() - def warnings(self, strace): + def warnings(self): # list.copy() is not available in python2 warnings = self._warnings[:] diff --git a/pymc3/step_methods/hmc/nuts.py b/pymc3/step_methods/hmc/nuts.py index f4925def3c8..37f59c45d94 100644 --- a/pymc3/step_methods/hmc/nuts.py +++ b/pymc3/step_methods/hmc/nuts.py @@ -184,8 +184,8 @@ def competence(var, has_grad): return Competence.IDEAL return Competence.INCOMPATIBLE - def warnings(self, strace): - warnings = super(NUTS, self).warnings(strace) + def warnings(self): + warnings = super(NUTS, self).warnings() n_samples = self._samples_after_tune n_treedepth = self._reached_max_treedepth diff --git a/pymc3/tests/sampler_fixtures.py b/pymc3/tests/sampler_fixtures.py index 0a5e9e48c25..78f7ee8c52b 100644 --- a/pymc3/tests/sampler_fixtures.py +++ b/pymc3/tests/sampler_fixtures.py @@ -82,13 +82,13 @@ def make_model(cls): class StudentTFixture(KnownMean, KnownCDF): means = {'a': 0} - cdfs = {'a': stats.t(df=3).cdf} + cdfs = {'a': stats.t(df=4).cdf} ks_thin = 10 @classmethod def make_model(cls): with pm.Model() as model: - a = pm.StudentT("a", nu=3, mu=0, sd=1) + a = pm.StudentT("a", nu=4, mu=0, sd=1) return model diff --git a/pymc3/tests/test_parallel_sampling.py b/pymc3/tests/test_parallel_sampling.py new file mode 100644 index 00000000000..515c130a90c --- /dev/null +++ b/pymc3/tests/test_parallel_sampling.py @@ -0,0 +1,72 @@ +import time +import sys +import pytest + +import pymc3.parallel_sampling as ps +import pymc3 as pm + + +@pytest.mark.skipif(sys.version_info < (3,3), + reason="requires python3.3") +def test_abort(): + with pm.Model() as model: + a = pm.Normal('a', shape=1) + pm.HalfNormal('b') + step1 = pm.NUTS([a]) + step2 = pm.Metropolis([model.b_log__]) + + step = pm.CompoundStep([step1, step2]) + + proc = ps.ProcessAdapter(10, 10, step, chain=3, seed=1, + start={'a': 1., 'b_log__': 2.}) + proc.start() + proc.write_next() + proc.abort() + proc.join() + + +@pytest.mark.skipif(sys.version_info < (3,3), + reason="requires python3.3") +def test_explicit_sample(): + with pm.Model() as model: + a = pm.Normal('a', shape=1) + pm.HalfNormal('b') + step1 = pm.NUTS([a]) + step2 = pm.Metropolis([model.b_log__]) + + step = pm.CompoundStep([step1, step2]) + + start = 
+        proc = ps.ProcessAdapter(10, 10, step, chain=3, seed=1,
+                                 start={'a': 1., 'b_log__': 2.})
+        proc.start()
+        while True:
+            proc.write_next()
+            out = ps.ProcessAdapter.recv_draw([proc])
+            view = proc.shared_point_view
+            for name in view:
+                view[name].copy()
+            if out[1]:
+                break
+        proc.join()
+        print(time.time() - start)
+
+
+@pytest.mark.skipif(sys.version_info < (3, 3),
+                    reason="requires python3.3")
+def test_iterator():
+    with pm.Model() as model:
+        a = pm.Normal('a', shape=1)
+        pm.HalfNormal('b')
+        step1 = pm.NUTS([a])
+        step2 = pm.Metropolis([model.b_log__])
+
+        step = pm.CompoundStep([step1, step2])
+
+        start = {'a': 1., 'b_log__': 2.}
+        sampler = ps.ParallelSampler(10, 10, 3, 2, [2, 3, 4], [start] * 3,
+                                     step, 0, False)
+        with sampler:
+            for draw in sampler:
+                pass
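
A few notes on the machinery, with illustrative sketches; none of the snippets below are part of the diff. First, the shared-memory handoff: `ProcessAdapter` allocates one `RawArray` per model variable, both sides wrap it with `np.frombuffer`, and the worker writes each accepted draw in place, so sample values never pass through `pickle`. A self-contained sketch of that pattern (all names here are made up):

```python
import multiprocessing
import multiprocessing.sharedctypes

import numpy as np


def worker(shared, shape, dtype):
    # The child writes into the same buffer the parent reads from,
    # so the values cross the process boundary without pickling.
    view = np.frombuffer(shared, dtype).reshape(shape)
    view[...] = np.arange(view.size, dtype=dtype).reshape(shape)


if __name__ == '__main__':
    shape, dtype = (2, 3), np.dtype('float64')
    size = int(np.prod(shape)) * dtype.itemsize
    shared = multiprocessing.sharedctypes.RawArray('c', size)

    proc = multiprocessing.Process(target=worker, args=(shared, shape, dtype))
    proc.start()
    proc.join()
    print(np.frombuffer(shared, dtype).reshape(shape))
```

Because `RawArray` carries no lock, the `write_next`/`writing_done` handshake on the pipe is what guarantees the two processes never touch the buffer at the same time; that is the reason for the `_readable` guard on `shared_point_view`.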
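Second, error reporting: traceback objects cannot be pickled, which is why `ExceptionWithTraceback` (adapted from CPython, see the link in the module) formats the traceback to a string before it crosses the pipe, and `rebuild_exc` re-attaches it as `__cause__` on the main side. A small demonstration using the class from this patch:

```python
import pickle

from pymc3.parallel_sampling import ExceptionWithTraceback

try:
    raise ValueError('boom in worker')
except ValueError as e:
    # What the worker sends through the pipe: the exception plus its
    # traceback, pre-formatted as a string so that it survives pickling.
    payload = pickle.dumps(ExceptionWithTraceback(e, e.__traceback__))

err = pickle.loads(payload)   # rebuild_exc runs during unpickling
print(repr(err))              # the original ValueError
print(err.__cause__)          # the worker-side traceback, as text
```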
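Third, the partial-results heuristic: after a `KeyboardInterrupt` the chains usually have unequal lengths, and keeping an extra chain forces every kept chain down to that chain's length. `_choose_chains` therefore effectively maximizes `kept_chains * common_length` (and adds `tune` back onto the returned length). A simplified restatement of that objective with made-up lengths; the function itself does a single greedy pass:

```python
lengths = [100, 100, 30]                # post-tuning draws at interrupt time
l_sort = sorted(lengths, reverse=True)

# Keeping the n longest chains cuts each of them to the n-th longest length.
candidates = [(n, l_sort[n - 1]) for n in range(1, len(l_sort) + 1)]
n_keep, length = max(candidates, key=lambda c: c[0] * c[1])
print(n_keep, length)                   # 2 100: keep two chains of 100 draws
```

Here that keeps the two long chains (200 usable draws) rather than truncating all three to 30 (90 draws).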
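Finally, the user-visible effect, sketched against the public API (the model and the argument values are arbitrary): on Python 3, `pm.sample` now drives the chains through `ParallelSampler`, one progress bar covers all chains, and an interrupt returns the partial traces selected by `_choose_chains` instead of discarding the run.

```python
import pymc3 as pm

with pm.Model():
    pm.Normal('x', mu=0, sd=1)
    # Up to `cores` worker processes run at once; one bar tracks all chains.
    # Ctrl-C during sampling now yields a usable, partial MultiTrace.
    trace = pm.sample(draws=1000, tune=500, chains=4, cores=2)
```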