Skip to content

Commit

Permalink
fix regression #4273 (#4297)
Browse files Browse the repository at this point in the history
* informative warnings on bound method logp in DensityDist

* black

* run test on single core only to avoid dill error on windows/macos

* adding tests for DensityDist serialize recursion handling

* forgot a test
  • Loading branch information
Spaak authored Dec 5, 2020
1 parent df3ae60 commit 3fa3d1f
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 7 deletions.
27 changes: 26 additions & 1 deletion pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import multiprocessing
import numbers
import contextvars
import dill
import inspect
import sys
import types
from typing import TYPE_CHECKING
import warnings

if TYPE_CHECKING:
from typing import Optional, Callable
Expand Down Expand Up @@ -505,6 +509,19 @@ def __init__(
dtype = theano.config.floatX
super().__init__(shape, dtype, testval, *args, **kwargs)
self.logp = logp
if type(self.logp) == types.MethodType:
if sys.platform != "linux":
warnings.warn(
"You are passing a bound method as logp for DensityDist, this can lead to "
+ "errors when sampling on platforms other than Linux. Consider using a "
+ "plain function instead, or subclass Distribution."
)
elif type(multiprocessing.get_context()) != multiprocessing.context.ForkContext:
warnings.warn(
"You are passing a bound method as logp for DensityDist, this can lead to "
+ "errors when sampling when multiprocessing cannot rely on forking. Consider using a "
+ "plain function instead, or subclass Distribution."
)
self.rand = random
self.wrap_random_with_dist_shape = wrap_random_with_dist_shape
self.check_shape_in_random = check_shape_in_random
Expand All @@ -513,7 +530,15 @@ def __getstate__(self):
# We use dill to serialize the logp function, as this is almost
# always defined in the notebook and won't be pickled correctly.
# Fix https://github.com/pymc-devs/pymc3/issues/3844
logp = dill.dumps(self.logp)
try:
logp = dill.dumps(self.logp)
except RecursionError as err:
if type(self.logp) == types.MethodType:
raise ValueError(
"logp for DensityDist is a bound method, leading to RecursionError while serializing"
) from err
else:
raise err
vals = self.__dict__.copy()
vals["logp"] = logp
return vals
Expand Down
12 changes: 6 additions & 6 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,7 +1171,7 @@ def test_density_dist_with_random_sampleable(self, shape):
shape=shape,
random=normal_dist.random,
)
trace = pm.sample(100)
trace = pm.sample(100, cores=1)

samples = 500
size = 100
Expand All @@ -1194,7 +1194,7 @@ def test_density_dist_with_random_sampleable_failure(self, shape):
random=normal_dist.random,
wrap_random_with_dist_shape=False,
)
trace = pm.sample(100)
trace = pm.sample(100, cores=1)

samples = 500
with pytest.raises(RuntimeError):
Expand All @@ -1217,7 +1217,7 @@ def test_density_dist_with_random_sampleable_hidden_error(self, shape):
wrap_random_with_dist_shape=False,
check_shape_in_random=False,
)
trace = pm.sample(100)
trace = pm.sample(100, cores=1)

samples = 500
ppc = pm.sample_posterior_predictive(trace, samples=samples, model=model)
Expand All @@ -1240,7 +1240,7 @@ def test_density_dist_with_random_sampleable_handcrafted_success(self):
random=rvs,
wrap_random_with_dist_shape=False,
)
trace = pm.sample(100)
trace = pm.sample(100, cores=1)

samples = 500
size = 100
Expand All @@ -1260,7 +1260,7 @@ def test_density_dist_with_random_sampleable_handcrafted_success_fast(self):
random=rvs,
wrap_random_with_dist_shape=False,
)
trace = pm.sample(100)
trace = pm.sample(100, cores=1)

samples = 500
size = 100
Expand All @@ -1273,7 +1273,7 @@ def test_density_dist_without_random_not_sampleable(self):
mu = pm.Normal("mu", 0, 1)
normal_dist = pm.Normal.dist(mu, 1)
pm.DensityDist("density_dist", normal_dist.logp, observed=np.random.randn(100))
trace = pm.sample(100)
trace = pm.sample(100, cores=1)

samples = 500
with pytest.raises(ValueError):
Expand Down
42 changes: 42 additions & 0 deletions pymc3/tests/test_parallel_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,45 @@ def test_iterator():
with sampler:
for draw in sampler:
pass


def test_spawn_densitydist_function():
    """Sample a DensityDist whose logp is a plain function under "spawn".

    The model is sampled on two cores with ``mp_ctx="spawn"``. There are no
    assertions: the test passes as long as building and sampling do not raise.
    """

    def func(x):
        # Custom log-probability callable handed to DensityDist as ``logp``.
        return -2 * (x ** 2).sum()

    with pm.Model():
        pm.Normal("mu", 0, 1)
        pm.DensityDist("density_dist", func, observed=np.random.randn(100))
        pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn")


def test_spawn_densitydist_bound_method():
    """A bound-method logp must fail with a clear ValueError under "spawn".

    The DensityDist is given ``normal_dist.logp`` (a bound method) as its logp and
    the model is sampled on two cores with ``mp_ctx="spawn"``. Serializing the
    distribution is expected to raise the ValueError that ``__getstate__`` produces
    when dill hits a RecursionError on a bound method.

    ``pytest.raises`` is used instead of a non-strict ``xfail`` marker so that the
    test fails if no error is raised at all, and so that the error message is
    checked rather than just the exception type.
    """
    with pm.Model():
        mu = pm.Normal("mu", 0, 1)
        normal_dist = pm.Normal.dist(mu, 1)
        pm.DensityDist("density_dist", normal_dist.logp, observed=np.random.randn(100))
        msg = "logp for DensityDist is a bound method, leading to RecursionError while serializing"
        # Only the sampling call is wrapped: model construction must succeed.
        with pytest.raises(ValueError, match=msg):
            pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn")


# cannot test this properly: monkeypatching sys.platform messes up Theano
# def test_spawn_densitydist_syswarning(monkeypatch):
# monkeypatch.setattr(sys, "platform", "win32")
# with pm.Model() as model:
# mu = pm.Normal('mu', 0, 1)
# normal_dist = pm.Normal.dist(mu, 1)
# with pytest.warns(UserWarning) as w:
# obs = pm.DensityDist('density_dist', normal_dist.logp, observed=np.random.randn(100))
# assert len(w) == 1 and "errors when sampling on platforms" in w[0].message.args[0]


def test_spawn_densitydist_mpctxwarning(monkeypatch):
    """Check the warning for a bound-method logp when the context is "spawn".

    ``multiprocessing.get_context`` is patched to always return a "spawn" context.
    Constructing a DensityDist with a bound method as logp must then record exactly
    one warning, whose message is the multiprocessing-specific one.
    """
    spawn_ctx = multiprocessing.get_context("spawn")

    def _patched_get_context():
        # Zero-argument replacement, mirroring how the patched function is called.
        return spawn_ctx

    monkeypatch.setattr(multiprocessing, "get_context", _patched_get_context)
    with pm.Model():
        mu = pm.Normal("mu", 0, 1)
        normal_dist = pm.Normal.dist(mu, 1)
        with pytest.warns(UserWarning) as record:
            pm.DensityDist("density_dist", normal_dist.logp, observed=np.random.randn(100))
        assert len(record) == 1
        assert "errors when sampling when multiprocessing" in record[0].message.args[0]

0 comments on commit 3fa3d1f

Please sign in to comment.