diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 172313e36f3..ac313f2877b 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -5,6 +5,7 @@ ### Maintenance - Mentioned the way to do any random walk with `theano.tensor.cumsum()` in `GaussianRandomWalk` docstrings (see [#4048](https://github.com/pymc-devs/pymc3/pull/4048)). - Fixed numerical instability in ExGaussian's logp by preventing `logpow` from returning `-inf` (see [#4050](https://github.com/pymc-devs/pymc3/pull/4050)). +- Use dill to serialize user defined logp functions in `DensityDist`. The previous serialization code fails if it is used in notebooks on Windows and Mac. `dill` is now a required dependency. (see [#3844](https://github.com/pymc-devs/pymc3/issues/3844)). ### Documentation diff --git a/environment-dev.yml b/environment-dev.yml index 5cfa3ec6025..23982a66352 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -37,7 +37,7 @@ dependencies: - dataclasses # python_version < 3.7 - contextvars # python_version < 3.7 - mkl-service + - dill - libblas=*=*mkl - pip: - black_nbconvert - - dill diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index 7176fb47add..2bdf9d88c09 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -14,6 +14,7 @@ import numbers import contextvars +import dill from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Optional, Callable @@ -419,6 +420,19 @@ def __init__( self.wrap_random_with_dist_shape = wrap_random_with_dist_shape self.check_shape_in_random = check_shape_in_random + def __getstate__(self): + # We use dill to serialize the logp function, as this is almost + # always defined in the notebook and won't be pickled correctly. + # Fix https://github.com/pymc-devs/pymc3/issues/3844 + logp = dill.dumps(self.logp) + vals = self.__dict__.copy() + vals['logp'] = logp + return vals + + def __setstate__(self, vals): + vals['logp'] = dill.loads(vals['logp']) + self.__dict__ = vals + def random(self, point=None, size=None, **kwargs): if self.rand is not None: not_broadcast_kwargs = dict(point=point) diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index 2897211e9bc..a52ff63a4a3 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -79,6 +79,7 @@ from ..distributions import continuous from pymc3.theanof import floatX +import pymc3 as pm from numpy import array, inf, log, exp from numpy.testing import assert_almost_equal, assert_allclose, assert_equal import numpy.random as nr @@ -1872,3 +1873,16 @@ def test_issue_3051(self, dims, dist_cls, kwargs): assert isinstance(actual_a, np.ndarray) assert actual_a.shape == (X.shape[0],) pass + + +def test_serialize_density_dist(): + def func(x): + return -2 * (x ** 2).sum() + + with pm.Model(): + pm.Normal('x') + y = pm.DensityDist('y', func) + pm.sample(draws=5, tune=1, mp_ctx="spawn") + + import pickle + pickle.loads(pickle.dumps(y)) diff --git a/requirements-dev.txt b/requirements-dev.txt index f11a4da9b73..13caa0d577d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -17,4 +17,4 @@ sphinx-autobuild==0.7.1 sphinx>=1.5.5 watermark parameterized -dill \ No newline at end of file +dill diff --git a/requirements.txt b/requirements.txt index c65a007ef10..75aaba286da 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ h5py>=2.7.0 typing-extensions>=3.7.4 dataclasses; python_version < '3.7' contextvars; python_version < '3.7' +dill diff --git a/scripts/create_testenv.sh b/scripts/create_testenv.sh index 6a1e1e516ad..418d3dacc69 100755 --- a/scripts/create_testenv.sh +++ b/scripts/create_testenv.sh @@ -22,15 +22,18 @@ command -v conda >/dev/null 2>&1 || { ENVNAME="${ENVNAME:-testenv}" # if no ENVNAME is specified, use testenv if [ -z ${GLOBAL} ]; then + source $(dirname $(dirname $(which conda)))/etc/profile.d/conda.sh if conda env list | grep -q ${ENVNAME}; then echo "Environment ${ENVNAME} already exists, keeping up to date" + conda activate ${ENVNAME} + mamba env update -f environment-dev.yml else conda config --add channels conda-forge conda config --set channel_priority strict conda install -c conda-forge mamba --yes mamba env create -f environment-dev.yml + conda activate ${ENVNAME} fi - source activate ${ENVNAME} fi # Install editable using the setup.py