Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use dill to serialize logp functions in DensityDist #4053

Merged
merged 5 commits into from
Aug 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### Maintenance
- Mentioned the way to do any random walk with `theano.tensor.cumsum()` in `GaussianRandomWalk` docstrings (see [#4048](https://github.com/pymc-devs/pymc3/pull/4048)).
- Fixed numerical instability in ExGaussian's logp by preventing `logpow` from returning `-inf` (see [#4050](https://github.com/pymc-devs/pymc3/pull/4050)).
- Use dill to serialize user defined logp functions in `DensityDist`. The previous serialization code fails if it is used in notebooks on Windows and Mac. `dill` is now a required dependency. (see [#3844](https://github.com/pymc-devs/pymc3/issues/3844)).

### Documentation

Expand Down
2 changes: 1 addition & 1 deletion environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ dependencies:
- dataclasses # python_version < 3.7
- contextvars # python_version < 3.7
- mkl-service
- dill
- libblas=*=*mkl
- pip:
- black_nbconvert
- dill
14 changes: 14 additions & 0 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import numbers
import contextvars
import dill
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Callable
Expand Down Expand Up @@ -419,6 +420,19 @@ def __init__(
self.wrap_random_with_dist_shape = wrap_random_with_dist_shape
self.check_shape_in_random = check_shape_in_random

def __getstate__(self):
# We use dill to serialize the logp function, as this is almost
# always defined in the notebook and won't be pickled correctly.
# Fix https://github.com/pymc-devs/pymc3/issues/3844
logp = dill.dumps(self.logp)
vals = self.__dict__.copy()
vals['logp'] = logp
return vals

def __setstate__(self, vals):
vals['logp'] = dill.loads(vals['logp'])
self.__dict__ = vals

def random(self, point=None, size=None, **kwargs):
if self.rand is not None:
not_broadcast_kwargs = dict(point=point)
Expand Down
14 changes: 14 additions & 0 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@

from ..distributions import continuous
from pymc3.theanof import floatX
import pymc3 as pm
from numpy import array, inf, log, exp
from numpy.testing import assert_almost_equal, assert_allclose, assert_equal
import numpy.random as nr
Expand Down Expand Up @@ -1872,3 +1873,16 @@ def test_issue_3051(self, dims, dist_cls, kwargs):
assert isinstance(actual_a, np.ndarray)
assert actual_a.shape == (X.shape[0],)
pass


def test_serialize_density_dist():
def func(x):
return -2 * (x ** 2).sum()

with pm.Model():
pm.Normal('x')
y = pm.DensityDist('y', func)
pm.sample(draws=5, tune=1, mp_ctx="spawn")

import pickle
pickle.loads(pickle.dumps(y))
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ sphinx-autobuild==0.7.1
sphinx>=1.5.5
watermark
parameterized
dill
dill
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ h5py>=2.7.0
typing-extensions>=3.7.4
dataclasses; python_version < '3.7'
contextvars; python_version < '3.7'
dill
5 changes: 4 additions & 1 deletion scripts/create_testenv.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,18 @@ command -v conda >/dev/null 2>&1 || {
ENVNAME="${ENVNAME:-testenv}" # if no ENVNAME is specified, use testenv

if [ -z ${GLOBAL} ]; then
source $(dirname $(dirname $(which conda)))/etc/profile.d/conda.sh
if conda env list | grep -q ${ENVNAME}; then
echo "Environment ${ENVNAME} already exists, keeping up to date"
conda activate ${ENVNAME}
mamba env update -f environment-dev.yml
else
conda config --add channels conda-forge
conda config --set channel_priority strict
conda install -c conda-forge mamba --yes
mamba env create -f environment-dev.yml
conda activate ${ENVNAME}
fi
source activate ${ENVNAME}
fi

# Install editable using the setup.py
Expand Down