Skip to content

Commit

Permalink
Fix tests for extra divergence info
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Jul 2, 2020
1 parent 9a3ba4a commit ea71951
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 66 deletions.
3 changes: 3 additions & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Release Notes

## PyMC3 3.9.x (on deck)
- Fix an error on Windows and Mac where error messages from unpickling models did not show up in the notebook (see [#3991](https://github.com/pymc-devs/pymc3/pull/3991)).
- Introduce an optional argument `mp_ctx` to control how the processes for parallel sampling are started (see [#3991](https://github.com/pymc-devs/pymc3/pull/3991)).

### Documentation
- Notebook on [multilevel modeling](https://docs.pymc.io/notebooks/multilevel_modeling.html) has been rewritten to showcase ArviZ and xarray usage for inference result analysis (see [#3963](https://github.com/pymc-devs/pymc3/pull/3963))

Expand Down
95 changes: 38 additions & 57 deletions pymc3/parallel_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from collections import namedtuple
import traceback
from pymc3.exceptions import SamplingError
import errno

import numpy as np
from fastprogress.fastprogress import progress_bar
Expand All @@ -31,37 +30,6 @@
logger = logging.getLogger("pymc3")


def _get_broken_pipe_exception():
import sys

if sys.platform == "win32":
return RuntimeError(
"The communication pipe between the main process "
"and its spawned children is broken.\n"
"In Windows OS, this usually means that the child "
"process raised an exception while it was being "
"spawned, before it was setup to communicate to "
"the main process.\n"
"The exceptions raised by the child process while "
"spawning cannot be caught or handled from the "
"main process, and when running from an IPython or "
"jupyter notebook interactive kernel, the child's "
"exception and traceback appears to be lost.\n"
"A known way to see the child's error, and try to "
"fix or handle it, is to run the problematic code "
"as a batch script from a system's Command Prompt. "
"The child's exception will be printed to the "
"Command Promt's stderr, and it should be visible "
"above this error and traceback.\n"
"Note that if running a jupyter notebook that was "
"invoked from a Command Prompt, the child's "
"exception should have been printed to the Command "
"Prompt on which the notebook is running."
)
else:
return None


class ParallelSamplingError(Exception):
def __init__(self, message, chain, warnings=None):
super().__init__(message)
Expand Down Expand Up @@ -133,18 +101,37 @@ def __init__(
self._tune = tune
self._pickle_backend = pickle_backend

def _unpickle_step_method(self):
unpickle_error = (
"The model could not be unpickled. This is required for sampling "
"with more than one core and multiprocessing context spawn "
"or forkserver."
)
if self._step_method_is_pickled:
if self._pickle_backend == 'pickle':
try:
self._step_method = pickle.loads(self._step_method)
except Exception:
raise ValueError(unpickle_error)
elif self._pickle_backend == 'dill':
try:
import dill
except ImportError:
raise ValueError(
"dill must be installed for pickle_backend='dill'."
)
try:
self._step_method = dill.loads(self._step_method)
except Exception:
raise ValueError(unpickle_error)
else:
raise ValueError("Unknown pickle backend")

def run(self):
try:
# We do not create this in __init__, as pickling this
# would destroy the shared memory.
if self._step_method_is_pickled:
if self._pickle_backend == 'pickle':
self._step_method = pickle.loads(self._step_method)
elif self._pickle_backend == 'dill':
import dill
self._step_method = dill.loads(self._step_method)
else:
raise ValueError("Unknown pickle backend")
self._unpickle_step_method()
self._point = self._make_numpy_refs()
self._start_loop()
except KeyboardInterrupt:
Expand Down Expand Up @@ -289,7 +276,7 @@ def __init__(

self._process = mp_ctx.Process(
daemon=True,
name=name,
name=process_name,
target=_run_process,
args=(
process_name,
Expand All @@ -303,21 +290,10 @@ def __init__(
pickle_backend,
)
)
try:
self._process.start()
# Close the remote pipe, so that we get notified if the other
# end is closed.
remote_conn.close()
except IOError as e:
# Something may have gone wrong during the fork / spawn
if e.errno == errno.EPIPE:
exc = _get_broken_pipe_exception()
if exc is not None:
# Sleep a little to give the child process time to flush
# all its error message
time.sleep(0.2)
raise exc
raise
self._process.start()
# Close the remote pipe, so that we get notified if the other
# end is closed.
remote_conn.close()

@property
def shared_point_view(self):
Expand Down Expand Up @@ -451,7 +427,12 @@ def __init__(
if pickle_backend == 'pickle':
step_method_pickled = pickle.dumps(step_method, protocol=-1)
elif pickle_backend == 'dill':
import dill
try:
import dill
except ImportError:
raise ValueError(
"dill must be installed for pickle_backend='dill'."
)
step_method_pickled = dill.dumps(step_method, protocol=-1)

self._samplers = [
Expand Down
16 changes: 11 additions & 5 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,13 @@ def sample(
Defaults to `False`, but we'll switch to `True` in an upcoming release.
idata_kwargs : dict, optional
Keyword arguments for `arviz.from_pymc3`
mp_ctx : str
The name of a multiprocessing context. One of `fork`, `spawn` or `forkserver`.
See multiprocessing documentation for details.
mp_ctx : multiprocessing.context.BaseContext
A multiprocessing context for parallel sampling. See multiprocessing
documentation for details.
pickle_backend : str
One of `'pickle'` or `'dill'`. The library used to pickle models
in parallel sampling if the multiprocessing context is not of type
`fork`.
Returns
-------
Expand Down Expand Up @@ -508,8 +512,10 @@ def sample(
"cores": cores,
"callback": callback,
"discard_tuned_samples": discard_tuned_samples,
"mp_ctx": mp_ctx,
}
parallel_args = {
"pickle_backend": pickle_backend,
"mp_ctx": mp_ctx,
}

sample_args.update(kwargs)
Expand All @@ -527,7 +533,7 @@ def sample(
_log.info("Multiprocess sampling ({} chains in {} jobs)".format(chains, cores))
_print_step_hierarchy(step)
try:
trace = _mp_sample(**sample_args)
trace = _mp_sample(**sample_args, **parallel_args)
except pickle.PickleError:
_log.warning("Could not pickle model, sampling singlethreaded.")
_log.debug("Pickling error:", exec_info=True)
Expand Down
42 changes: 38 additions & 4 deletions pymc3/tests/test_parallel_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,41 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing

import pytest
import pymc3.parallel_sampling as ps
import pymc3 as pm


def test_context():
    """Smoke-test parallel sampling with an explicitly chosen spawn context."""
    spawn_ctx = multiprocessing.get_context('spawn')
    with pm.Model():
        pm.Normal('x')
        pm.sample(tune=2, draws=2, chains=2, cores=2, mp_ctx=spawn_ctx)


class NoUnpickle:
    """Test helper that pickles fine but always fails to unpickle."""

    def __getstate__(self):
        # Pickling succeeds: hand over a copy of the instance dict.
        return dict(self.__dict__)

    def __setstate__(self, state):
        # Unpickling always fails, simulating a model that cannot be
        # transferred to a worker process.
        raise AttributeError("This fails")


def test_bad_unpickle():
    """Sampling with spawn must surface a clear error when the step
    method cannot be unpickled in the worker process."""
    with pm.Model() as model:
        pm.Normal('x')
        step = pm.NUTS()
        # Attach an object that raises during unpickling, so the worker
        # cannot reconstruct the step method.
        step.no_unpickle = NoUnpickle()
        with pytest.raises(Exception) as exc_info:
            pm.sample(tune=2, draws=2, mp_ctx='spawn', step=step,
                      cores=2, chains=2, compute_convergence_checks=False)
        assert 'could not be unpickled' in str(exc_info.getrepr(style='short'))


def test_abort():
with pm.Model() as model:
a = pm.Normal('a', shape=1)
Expand All @@ -25,8 +55,10 @@ def test_abort():

step = pm.CompoundStep([step1, step2])

proc = ps.ProcessAdapter(10, 10, step, chain=3, seed=1,
start={'a': 1., 'b_log__': 2.})
ctx = multiprocessing.get_context()
proc = ps.ProcessAdapter(10, 10, step, chain=3, seed=1, mp_ctx=ctx,
start={'a': 1., 'b_log__': 2.},
step_method_pickled=None, pickle_backend='pickle')
proc.start()
proc.write_next()
proc.abort()
Expand All @@ -42,8 +74,10 @@ def test_explicit_sample():

step = pm.CompoundStep([step1, step2])

proc = ps.ProcessAdapter(10, 10, step, chain=3, seed=1,
start={'a': 1., 'b_log__': 2.})
ctx = multiprocessing.get_context()
proc = ps.ProcessAdapter(10, 10, step, chain=3, seed=1, mp_ctx=ctx,
start={'a': 1., 'b_log__': 2.},
step_method_pickled=None, pickle_backend='pickle')
proc.start()
while True:
proc.write_next()
Expand Down

0 comments on commit ea71951

Please sign in to comment.