From 839407b29b19edce5a0ea9e48ad70241368d8efd Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Sat, 15 Aug 2020 15:35:37 +0200 Subject: [PATCH 1/5] Use dill to serialize logp functions in DensityDist --- RELEASE-NOTES.md | 1 + environment-dev.yml | 2 +- pymc3/distributions/distribution.py | 14 ++++++++++++++ pymc3/tests/test_distributions.py | 11 +++++++++++ requirements-dev.txt | 2 +- requirements.txt | 1 + 6 files changed, 29 insertions(+), 2 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 172313e36f3..57b944c21a6 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -5,6 +5,7 @@ ### Maintenance - Mentioned the way to do any random walk with `theano.tensor.cumsum()` in `GaussianRandomWalk` docstrings (see [#4048](https://github.com/pymc-devs/pymc3/pull/4048)). - Fixed numerical instability in ExGaussian's logp by preventing `logpow` from returning `-inf` (see [#4050](https://github.com/pymc-devs/pymc3/pull/4050)). +- Use dill to serialize user defined logp functions in `DensityDist`. `dill` is now a required dependency. (see [#3844](https://github.com/pymc-devs/pymc3/issues/3844)). ### Documentation diff --git a/environment-dev.yml b/environment-dev.yml index 5cfa3ec6025..23982a66352 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -37,7 +37,7 @@ dependencies: - dataclasses # python_version < 3.7 - contextvars # python_version < 3.7 - mkl-service + - dill - libblas=*=*mkl - pip: - black_nbconvert - - dill diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index 7176fb47add..2bdf9d88c09 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -14,6 +14,7 @@ import numbers import contextvars +import dill from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Optional, Callable @@ -419,6 +420,19 @@ def __init__( self.wrap_random_with_dist_shape = wrap_random_with_dist_shape self.check_shape_in_random = check_shape_in_random + def __getstate__(self): + # We use dill to serialize the logp function, as this is almost + # always defined in the notebook and won't be pickled correctly. + # Fix https://github.com/pymc-devs/pymc3/issues/3844 + logp = dill.dumps(self.logp) + vals = self.__dict__.copy() + vals['logp'] = logp + return vals + + def __setstate__(self, vals): + vals['logp'] = dill.loads(vals['logp']) + self.__dict__ = vals + def random(self, point=None, size=None, **kwargs): if self.rand is not None: not_broadcast_kwargs = dict(point=point) diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index 2897211e9bc..dade81a02b8 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -79,6 +79,7 @@ from ..distributions import continuous from pymc3.theanof import floatX +import pymc3 as pm from numpy import array, inf, log, exp from numpy.testing import assert_almost_equal, assert_allclose, assert_equal import numpy.random as nr @@ -1872,3 +1873,13 @@ def test_issue_3051(self, dims, dist_cls, kwargs): assert isinstance(actual_a, np.ndarray) assert actual_a.shape == (X.shape[0],) pass + + +def test_serialize_density_dist(): + def func(x): + return -2 * (x ** 2).sum() + + with pm.Model(): + pm.Normal('x') + pm.DensityDist('y', func) + pm.sample(draws=1, tune=1, mp_ctx="spawn") diff --git a/requirements-dev.txt b/requirements-dev.txt index f11a4da9b73..13caa0d577d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -17,4 +17,4 @@ sphinx-autobuild==0.7.1 sphinx>=1.5.5 watermark parameterized -dill \ No newline at end of file +dill diff --git a/requirements.txt b/requirements.txt index c65a007ef10..75aaba286da 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ h5py>=2.7.0 typing-extensions>=3.7.4 dataclasses; python_version < '3.7' contextvars; python_version < '3.7' +dill From a92ce548a96b65572de426a3182fc41419d40694 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Sat, 15 Aug 2020 16:32:17 +0200 Subject: [PATCH 2/5] Update testenv based on yml file on travis --- scripts/create_testenv.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/create_testenv.sh b/scripts/create_testenv.sh index 6a1e1e516ad..e7b04b60c69 100755 --- a/scripts/create_testenv.sh +++ b/scripts/create_testenv.sh @@ -24,13 +24,15 @@ ENVNAME="${ENVNAME:-testenv}" # if no ENVNAME is specified, use testenv if [ -z ${GLOBAL} ]; then if conda env list | grep -q ${ENVNAME}; then echo "Environment ${ENVNAME} already exists, keeping up to date" + source activate ${ENVNAME} + mamba env update -f environment-dev.yml else conda config --add channels conda-forge conda config --set channel_priority strict conda install -c conda-forge mamba --yes mamba env create -f environment-dev.yml + source activate ${ENVNAME} fi - source activate ${ENVNAME} fi # Install editable using the setup.py From a216190e04c6f289c73722f41dc87e155ecbd9bd Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Sat, 15 Aug 2020 20:51:56 +0200 Subject: [PATCH 3/5] Explicitly test pickling and unpickling of DensityDist --- pymc3/tests/test_distributions.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index dade81a02b8..a52ff63a4a3 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -1881,5 +1881,8 @@ def func(x): with pm.Model(): pm.Normal('x') - pm.DensityDist('y', func) - pm.sample(draws=1, tune=1, mp_ctx="spawn") + y = pm.DensityDist('y', func) + pm.sample(draws=5, tune=1, mp_ctx="spawn") + + import pickle + pickle.loads(pickle.dumps(y)) From f3bbf0758b595d85229d5f7506dfa922d50a0785 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Sun, 16 Aug 2020 14:14:25 +0200 Subject: [PATCH 4/5] Improve release notes --- RELEASE-NOTES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 57b944c21a6..ac313f2877b 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -5,7 +5,7 @@ ### Maintenance - Mentioned the way to do any random walk with `theano.tensor.cumsum()` in `GaussianRandomWalk` docstrings (see [#4048](https://github.com/pymc-devs/pymc3/pull/4048)). - Fixed numerical instability in ExGaussian's logp by preventing `logpow` from returning `-inf` (see [#4050](https://github.com/pymc-devs/pymc3/pull/4050)). -- Use dill to serialize user defined logp functions in `DensityDist`. `dill` is now a required dependency. (see [#3844](https://github.com/pymc-devs/pymc3/issues/3844)). +- Use dill to serialize user defined logp functions in `DensityDist`. The previous serialization code fails if it is used in notebooks on Windows and Mac. `dill` is now a required dependency. (see [#3844](https://github.com/pymc-devs/pymc3/issues/3844)). ### Documentation From 712e3fbfc879aa7baaf2eab0f7bfc5afb46b41c7 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Sun, 16 Aug 2020 14:14:52 +0200 Subject: [PATCH 5/5] Use conda activate in create testenv --- scripts/create_testenv.sh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/create_testenv.sh b/scripts/create_testenv.sh index e7b04b60c69..418d3dacc69 100755 --- a/scripts/create_testenv.sh +++ b/scripts/create_testenv.sh @@ -22,16 +22,17 @@ command -v conda >/dev/null 2>&1 || { ENVNAME="${ENVNAME:-testenv}" # if no ENVNAME is specified, use testenv if [ -z ${GLOBAL} ]; then + source $(dirname $(dirname $(which conda)))/etc/profile.d/conda.sh if conda env list | grep -q ${ENVNAME}; then echo "Environment ${ENVNAME} already exists, keeping up to date" - source activate ${ENVNAME} + conda activate ${ENVNAME} mamba env update -f environment-dev.yml else conda config --add channels conda-forge conda config --set channel_priority strict conda install -c conda-forge mamba --yes mamba env create -f environment-dev.yml - source activate ${ENVNAME} + conda activate ${ENVNAME} fi fi