Rewrite parallel sampling using multiprocessing #3011

Merged · 11 commits · Jun 14, 2018
5 changes: 3 additions & 2 deletions pymc3/backends/text.py
@@ -99,8 +99,9 @@ def record(self, point):
         self._fh.write(','.join(columns) + '\n')
 
     def close(self):
-        self._fh.close()
-        self._fh = None  # Avoid serialization issue.
+        if self._fh is not None:
+            self._fh.close()
+            self._fh = None  # Avoid serialization issue.
 
     # Selection methods
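
The guard above matters because open file handles cannot be pickled: the
backend sets `_fh` to None before it is shipped to a worker process, so
`close()` may now run more than once without error. A minimal sketch of the
failure mode this avoids (illustrative, not part of the diff):

    import pickle

    fh = open('trace.csv', 'w')
    try:
        pickle.dumps(fh)  # TypeError: cannot pickle a file object
    except TypeError as err:
        print(err)
    fh.close()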
332 changes: 332 additions & 0 deletions pymc3/parallel_sampling.py
@@ -0,0 +1,332 @@
import multiprocessing
import multiprocessing.sharedctypes
import multiprocessing.connection  # used by ProcessAdapter.recv_draw below
import ctypes
import time
import logging
from collections import namedtuple
import traceback

import six
import numpy as np

from . import theanof

logger = logging.getLogger('pymc3')


# Taken from https://hg.python.org/cpython/rev/c4f92b597074
class RemoteTraceback(Exception):
def __init__(self, tb):
self.tb = tb

def __str__(self):
return self.tb


class ExceptionWithTraceback:
def __init__(self, exc, tb):
tb = traceback.format_exception(type(exc), exc, tb)
tb = ''.join(tb)
self.exc = exc
self.tb = '\n"""\n%s"""' % tb

def __reduce__(self):
return rebuild_exc, (self.exc, self.tb)


def rebuild_exc(exc, tb):
exc.__cause__ = RemoteTraceback(tb)
return exc


# Messages
# ('writing_done', is_last, sample_idx, tuning, stats, warns)
# ('error', *exception_info)

# ('abort',)
# ('write_next',)
# ('start',)
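
# Illustrative exchange for one draw (a sketch of the flow implemented
# below, not part of this diff):
#
#     main   -> worker: ('start',)
#     main   -> worker: ('write_next',)
#     worker: steps once, copies the draw into shared memory
#     worker -> main:   ('writing_done', False, 0, True, stats, None)
#     ...repeated until is_last is True, or main sends ('abort',)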


class _Process(multiprocessing.Process):
"""Seperate process for each chain.

We communicate with the main process using a pipe,
and send finished samples using shared memory.
"""
def __init__(self, name, msg_pipe, step_method, shared_point,
draws, tune, seed):
super(_Process, self).__init__(daemon=True, name=name)
self._msg_pipe = msg_pipe
self._step_method = step_method
self._shared_point = shared_point
self._seed = seed
self._tt_seed = seed + 1
self._draws = draws
self._tune = tune

def run(self):
try:
# We do not create this in __init__, as pickling this
# would destroy the shared memory.
self._point = self._make_numpy_refs()
self._start_loop()
except KeyboardInterrupt:
pass
except BaseException as e:
e = ExceptionWithTraceback(e, e.__traceback__)
self._msg_pipe.send(('error', e))
finally:
self._msg_pipe.close()

def _make_numpy_refs(self):
shape_dtypes = self._step_method.vars_shape_dtype
point = {}
for name, (shape, dtype) in shape_dtypes.items():
            array = self._shared_point[name]
point[name] = np.frombuffer(array, dtype).reshape(shape)
return point

def _write_point(self, point):
for name, vals in point.items():
self._point[name][...] = vals
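
    # Sketch of the shared-memory mechanism (illustrative, not in this diff):
    #
    #     raw = multiprocessing.sharedctypes.RawArray('c', 3 * 8)
    #     view = np.frombuffer(raw, 'float64').reshape((3,))
    #     view[...] = [1., 2., 3.]  # the worker writes a draw in place
    #     # an np.frombuffer view of the same buffer in the main process
    #     # observes the new values without any copying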

def _recv_msg(self):
return self._msg_pipe.recv()

def _start_loop(self):
np.random.seed(self._seed)
theanof.set_tt_rng(self._tt_seed)

draw = 0
tuning = True

msg = self._recv_msg()
if msg[0] == 'abort':
raise KeyboardInterrupt()
if msg[0] != 'start':
raise ValueError('Unexpected msg ' + msg[0])

while True:
if draw < self._draws + self._tune:
point, stats = self._compute_point()
else:
return

if draw == self._tune:
self._step_method.stop_tuning()
tuning = False

msg = self._recv_msg()
if msg[0] == 'abort':
raise KeyboardInterrupt()
elif msg[0] == 'write_next':
self._write_point(point)
is_last = draw + 1 == self._draws + self._tune
if is_last:
warns = self._collect_warnings()
else:
warns = None
self._msg_pipe.send(
('writing_done', is_last, draw, tuning, stats, warns))
draw += 1
else:
raise ValueError('Unknown message ' + msg[0])

def _compute_point(self):
if self._step_method.generates_stats:
point, stats = self._step_method.step(self._point)
else:
point = self._step_method.step(self._point)
stats = None
return point, stats

def _collect_warnings(self):
if hasattr(self._step_method, 'warnings'):
return self._step_method.warnings()
else:
return []


class ProcessAdapter(object):
"""Control a Chain process from the main thread."""
def __init__(self, draws, tune, step_method, chain, seed, start):
self.chain = chain
process_name = "worker_chain_%s" % chain
self._msg_pipe, remote_conn = multiprocessing.Pipe()

self._shared_point = {}
self._point = {}
for name, (shape, dtype) in step_method.vars_shape_dtype.items():
size = 1
for dim in shape:
size *= int(dim)
size *= dtype.itemsize
if size != ctypes.c_size_t(size).value:
raise ValueError('Variable %s is too large' % name)

array = multiprocessing.sharedctypes.RawArray('c', size)
self._shared_point[name] = array
array_np = np.frombuffer(array, dtype).reshape(shape)
array_np[...] = start[name]
self._point[name] = array_np

self._readable = True
self._num_samples = 0

self._process = _Process(
process_name, remote_conn, step_method, self._shared_point,
draws, tune, seed)
# We fork right away, so that the main process can start tqdm threads
self._process.start()
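
        # Size arithmetic above (illustration): a variable of shape
        # (1000, 50) with dtype float64 needs 1000 * 50 * 8 = 400000
        # bytes; the ctypes.c_size_t round-trip rejects sizes that would
        # overflow the platform's size_t before RawArray allocates them.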

@property
def shared_point_view(self):
"""May only be written to or read between a `recv_draw`
call from the process and a `write_next` or `abort` call.
"""
if not self._readable:
raise RuntimeError()
return self._point

def start(self):
self._msg_pipe.send(('start',))

def write_next(self):
self._readable = False
self._msg_pipe.send(('write_next',))

def abort(self):
self._msg_pipe.send(('abort',))

def join(self, timeout=None):
self._process.join(timeout)

def terminate(self):
self._process.terminate()

@staticmethod
def recv_draw(processes, timeout=3600):
if not processes:
raise ValueError('No processes.')
pipes = [proc._msg_pipe for proc in processes]
        ready = multiprocessing.connection.wait(pipes, timeout)
if not ready:
raise multiprocessing.TimeoutError('No message from samplers.')
idxs = {id(proc._msg_pipe): proc for proc in processes}
proc = idxs[id(ready[0])]
msg = ready[0].recv()

if msg[0] == 'error':
old = msg[1]
six.raise_from(RuntimeError('Chain %s failed.' % proc.chain), old)
elif msg[0] == 'writing_done':
proc._readable = True
proc._num_samples += 1
return (proc,) + msg[1:]
else:
raise ValueError('Sampler sent bad message.')
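
        # Note: multiprocessing.connection.wait blocks until at least one
        # pipe is ready (or the timeout expires) and returns the ready
        # connection objects; only the first ready pipe is drained here,
        # so callers invoke recv_draw repeatedly to service all chains.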

@staticmethod
def terminate_all(processes, patience=2):
for process in processes:
try:
process.abort()
except EOFError:
pass

start_time = time.time()
try:
for process in processes:
timeout = time.time() + patience - start_time
if timeout < 0:
raise multiprocessing.TimeoutError()
process.join(timeout)
except multiprocessing.TimeoutError:
            logger.warning('Chain processes did not terminate as expected. '
                           'Terminating forcefully...')
for process in processes:
process.terminate()
for process in processes:
process.join()


Draw = namedtuple(
'Draw',
['chain', 'is_last', 'draw_idx', 'tuning', 'stats', 'point', 'warnings']
)


class ParallelSampler(object):
def __init__(self, draws, tune, chains, cores, seeds, start_points,
step_method, start_chain_num=0, progressbar=True):
if progressbar:
import tqdm
tqdm_ = tqdm.tqdm

self._samplers = [
ProcessAdapter(draws, tune, step_method,
chain + start_chain_num, seed, start)
for chain, seed, start in zip(range(chains), seeds, start_points)
]

self._inactive = self._samplers.copy()
self._finished = []
self._active = []
self._max_active = cores

self._in_context = False
self._start_chain_num = start_chain_num

self._progress = None
if progressbar:
self._progress = tqdm_(
                total=chains * (draws + tune), unit='draws',
                desc='Sampling %s chains' % chains)

[Review comment · Contributor] you have to add the position argument here in order to not have tqdms interfering with each other

[Reply · Member Author] There is only one progress bar now, but that progress bar counts samples from all chains.

def _make_active(self):
while self._inactive and len(self._active) < self._max_active:
proc = self._inactive.pop(0)
proc.start()
proc.write_next()
self._active.append(proc)

def __iter__(self):
if not self._in_context:
raise ValueError('Use ParallelSampler as context manager.')
self._make_active()

while self._active:
draw = ProcessAdapter.recv_draw(self._active)
proc, is_last, draw, tuning, stats, warns = draw
if self._progress is not None:
self._progress.update()

if is_last:
proc.join()
self._active.remove(proc)
self._finished.append(proc)
self._make_active()

# We could also yield proc.shared_point_view directly,
# and only call proc.write_next() after the yield returns.
            # This seems to be faster overall though, as the worker
# loses less time waiting.
point = {name: val.copy()
for name, val in proc.shared_point_view.items()}

# Already called for new proc in _make_active
if not is_last:
proc.write_next()

yield Draw(proc.chain, is_last, draw, tuning, stats, point, warns)

def __enter__(self):
self._in_context = True
return self

def __exit__(self, *args):
ProcessAdapter.terminate_all(self._samplers)
if self._progress is not None:
self._progress.close()
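
For reference, a minimal usage sketch of the new API (illustrative only;
`step` and `start` are placeholders for a PyMC3 step method and a dict of
initial values, and are not part of this diff):

    seeds = [10, 11, 12, 13]
    starts = [start] * 4  # one start point per chain
    sampler = ParallelSampler(
        draws=1000, tune=500, chains=4, cores=2,
        seeds=seeds, start_points=starts,
        step_method=step, progressbar=True)

    with sampler:
        for draw in sampler:
            # draw is the Draw namedtuple defined above
            if draw.is_last:
                print('chain %s finished' % draw.chain)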