From f296a5cd824f30354d04421a05566c0c3f84a806 Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Tue, 13 Jul 2021 11:39:21 +0200 Subject: [PATCH] switch from pickle/dill to cloudpickle (#4858) * use cloudpickle for serialization * add cloudpickle to requirements * update tests for cloudpickle * update release notes with cloudpickle * update conda envs with cloudpickle * remove special case serialization for DensityDist.logp * add pickle import back in for pickle.PickleError * remove strict error message check in test --- RELEASE-NOTES.md | 1 + conda-envs/environment-dev-py37.yml | 2 +- conda-envs/environment-dev-py38.yml | 2 +- conda-envs/environment-dev-py39.yml | 2 +- conda-envs/windows-environment-dev-py38.yml | 2 +- pymc3/distributions/distribution.py | 22 ------------ pymc3/parallel_sampling.py | 37 ++++----------------- pymc3/sampling.py | 13 ++------ pymc3/tests/test_distributions.py | 4 +-- pymc3/tests/test_minibatches.py | 8 ++--- pymc3/tests/test_model.py | 10 ++---- pymc3/tests/test_parallel_sampling.py | 8 ----- pymc3/tests/test_pickling.py | 6 ++-- pymc3/tests/test_variational_inference.py | 22 ++++++------ pymc3/util.py | 4 +-- requirements-dev.txt | 2 +- requirements.txt | 2 +- 17 files changed, 42 insertions(+), 105 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index cd37f6e834f..8e3b51b736b 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -26,6 +26,7 @@ - Logp method of `Uniform` and `DiscreteUniform` no longer depends on `pymc3.distributions.dist_math.bound` for proper evaluation (see [#4541](https://github.com/pymc-devs/pymc3/pull/4541)). - `Model.RV_dims` and `Model.coords` are now read-only properties. To modify the `coords` dictionary use `Model.add_coord`. Also `dims` or coordinate values that are `None` will be auto-completed (see [#4625](https://github.com/pymc-devs/pymc3/pull/4625)). - The length of `dims` in the model is now tracked symbolically through `Model.dim_lengths` (see [#4625](https://github.com/pymc-devs/pymc3/pull/4625)). +- We now include `cloudpickle` as a required dependency, and no longer depend on `dill` (see [#4858](https://github.com/pymc-devs/pymc3/pull/4858)). - ... ## PyMC3 3.11.2 (14 March 2021) diff --git a/conda-envs/environment-dev-py37.yml b/conda-envs/environment-dev-py37.yml index 7e1ac0ae795..621330fbbb5 100644 --- a/conda-envs/environment-dev-py37.yml +++ b/conda-envs/environment-dev-py37.yml @@ -6,7 +6,7 @@ dependencies: - aesara>=2.0.9 - arviz>=0.11.2 - cachetools>=4.2.1 -- dill +- cloudpickle - fastprogress>=0.2.0 - h5py>=2.7 - ipython>=7.16 diff --git a/conda-envs/environment-dev-py38.yml b/conda-envs/environment-dev-py38.yml index c267d8b8541..a3175e9727c 100644 --- a/conda-envs/environment-dev-py38.yml +++ b/conda-envs/environment-dev-py38.yml @@ -6,7 +6,7 @@ dependencies: - aesara>=2.0.9 - arviz>=0.11.2 - cachetools>=4.2.1 -- dill +- cloudpickle - fastprogress>=0.2.0 - h5py>=2.7 - ipython>=7.16 diff --git a/conda-envs/environment-dev-py39.yml b/conda-envs/environment-dev-py39.yml index 98b7788edc1..c655f8d0a62 100644 --- a/conda-envs/environment-dev-py39.yml +++ b/conda-envs/environment-dev-py39.yml @@ -6,7 +6,7 @@ dependencies: - aesara>=2.0.9 - arviz>=0.11.2 - cachetools>=4.2.1 -- dill +- cloudpickle - fastprogress>=0.2.0 - h5py>=2.7 - ipython>=7.16 diff --git a/conda-envs/windows-environment-dev-py38.yml b/conda-envs/windows-environment-dev-py38.yml index 9f085f7d475..987397f34fa 100644 --- a/conda-envs/windows-environment-dev-py38.yml +++ b/conda-envs/windows-environment-dev-py38.yml @@ -7,7 +7,7 @@ dependencies: - aesara>=2.0.9 - arviz>=0.11.2 - cachetools>=4.2.1 -- dill +- cloudpickle - fastprogress>=0.2.0 - h5py>=2.7 - libpython diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index 3f3f3d1e2d2..2dfa0d01c00 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -23,7 +23,6 @@ import aesara import aesara.tensor as at -import dill from aesara.tensor.random.op import RandomVariable from aesara.tensor.random.var import RandomStateSharedVariable @@ -533,26 +532,5 @@ def __init__( self.wrap_random_with_dist_shape = wrap_random_with_dist_shape self.check_shape_in_random = check_shape_in_random - def __getstate__(self): - # We use dill to serialize the logp function, as this is almost - # always defined in the notebook and won't be pickled correctly. - # Fix https://github.com/pymc-devs/pymc3/issues/3844 - try: - logp = dill.dumps(self.logp) - except RecursionError as err: - if type(self.logp) == types.MethodType: - raise ValueError( - "logp for DensityDist is a bound method, leading to RecursionError while serializing" - ) from err - else: - raise err - vals = self.__dict__.copy() - vals["logp"] = logp - return vals - - def __setstate__(self, vals): - vals["logp"] = dill.loads(vals["logp"]) - self.__dict__ = vals - def _distr_parameters_for_repr(self): return [] diff --git a/pymc3/parallel_sampling.py b/pymc3/parallel_sampling.py index 9d8cb4d7ff8..5aa49071e33 100644 --- a/pymc3/parallel_sampling.py +++ b/pymc3/parallel_sampling.py @@ -16,13 +16,13 @@ import logging import multiprocessing import multiprocessing.sharedctypes -import pickle import platform import time import traceback from collections import namedtuple +import cloudpickle import numpy as np from fastprogress.fastprogress import progress_bar @@ -93,7 +93,6 @@ def __init__( draws: int, tune: int, seed, - pickle_backend, ): self._msg_pipe = msg_pipe self._step_method = step_method @@ -103,7 +102,6 @@ def __init__( self._at_seed = seed + 1 self._draws = draws self._tune = tune - self._pickle_backend = pickle_backend def _unpickle_step_method(self): unpickle_error = ( @@ -112,22 +110,10 @@ def _unpickle_step_method(self): "or forkserver." ) if self._step_method_is_pickled: - if self._pickle_backend == "pickle": - try: - self._step_method = pickle.loads(self._step_method) - except Exception: - raise ValueError(unpickle_error) - elif self._pickle_backend == "dill": - try: - import dill - except ImportError: - raise ValueError("dill must be installed for pickle_backend='dill'.") - try: - self._step_method = dill.loads(self._step_method) - except Exception: - raise ValueError(unpickle_error) - else: - raise ValueError("Unknown pickle backend") + try: + self._step_method = cloudpickle.loads(self._step_method) + except Exception: + raise ValueError(unpickle_error) def run(self): try: @@ -243,7 +229,6 @@ def __init__( seed, start, mp_ctx, - pickle_backend, ): self.chain = chain process_name = "worker_chain_%s" % chain @@ -287,7 +272,6 @@ def __init__( draws, tune, seed, - pickle_backend, ), ) self._process.start() @@ -406,7 +390,6 @@ def __init__( start_chain_num: int = 0, progressbar: bool = True, mp_ctx=None, - pickle_backend: str = "pickle", ): if any(len(arg) != chains for arg in [seeds, start_points]): @@ -420,14 +403,7 @@ def __init__( step_method_pickled = None if mp_ctx.get_start_method() != "fork": - if pickle_backend == "pickle": - step_method_pickled = pickle.dumps(step_method, protocol=-1) - elif pickle_backend == "dill": - try: - import dill - except ImportError: - raise ValueError("dill must be installed for pickle_backend='dill'.") - step_method_pickled = dill.dumps(step_method, protocol=-1) + step_method_pickled = cloudpickle.dumps(step_method, protocol=-1) self._samplers = [ ProcessAdapter( @@ -439,7 +415,6 @@ def __init__( seed, start, mp_ctx, - pickle_backend, ) for chain, seed, start in zip(range(chains), seeds, start_points) ] diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 607599f4817..922c28c7635 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -26,6 +26,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast import aesara.gradient as tg +import cloudpickle import numpy as np import xarray @@ -268,7 +269,6 @@ def sample( return_inferencedata=None, idata_kwargs: dict = None, mp_ctx=None, - pickle_backend: str = "pickle", **kwargs, ): r"""Draw samples from the posterior using the given step methods. @@ -362,10 +362,6 @@ def sample( mp_ctx : multiprocessing.context.BaseContent A multiprocessing context for parallel sampling. See multiprocessing documentation for details. - pickle_backend : str - One of `'pickle'` or `'dill'`. The library used to pickle models - in parallel sampling if the multiprocessing context is not of type - `fork`. Returns ------- @@ -548,7 +544,6 @@ def sample( "discard_tuned_samples": discard_tuned_samples, } parallel_args = { - "pickle_backend": pickle_backend, "mp_ctx": mp_ctx, } @@ -1100,7 +1095,7 @@ def __init__(self, steppers, parallelize, progressbar=True): enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers) ): secondary_end, primary_end = multiprocessing.Pipe() - stepper_dumps = pickle.dumps(stepper, protocol=4) + stepper_dumps = cloudpickle.dumps(stepper, protocol=4) process = multiprocessing.Process( target=self.__class__._run_secondary, args=(c, stepper_dumps, secondary_end), @@ -1159,7 +1154,7 @@ def _run_secondary(c, stepper_dumps, secondary_end): # re-seed each child process to make them unique np.random.seed(None) try: - stepper = pickle.loads(stepper_dumps) + stepper = cloudpickle.loads(stepper_dumps) # the stepper is not necessarily a PopulationArraySharedStep itself, # but rather a CompoundStep. PopulationArrayStepShared.population # has to be updated, therefore we identify the substeppers first. @@ -1418,7 +1413,6 @@ def _mp_sample( callback=None, discard_tuned_samples=True, mp_ctx=None, - pickle_backend="pickle", **kwargs, ): """Main iteration for multiprocess sampling. @@ -1491,7 +1485,6 @@ def _mp_sample( chain, progressbar, mp_ctx=mp_ctx, - pickle_backend=pickle_backend, ) try: try: diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index 5648ae48720..9bcec0a94a6 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -3109,9 +3109,9 @@ def func(x): y = pm.DensityDist("y", func) pm.sample(draws=5, tune=1, mp_ctx="spawn") - import pickle + import cloudpickle - pickle.loads(pickle.dumps(y)) + cloudpickle.loads(cloudpickle.dumps(y)) def test_distinct_rvs(): diff --git a/pymc3/tests/test_minibatches.py b/pymc3/tests/test_minibatches.py index 64a8cbc42df..cc42bcd92b4 100644 --- a/pymc3/tests/test_minibatches.py +++ b/pymc3/tests/test_minibatches.py @@ -13,9 +13,9 @@ # limitations under the License. import itertools -import pickle import aesara +import cloudpickle import numpy as np import pytest @@ -132,10 +132,10 @@ def gen(): def test_pickling(self, datagen): gen = generator(datagen) - pickle.loads(pickle.dumps(gen)) + cloudpickle.loads(cloudpickle.dumps(gen)) bad_gen = generator(integers()) - with pytest.raises(Exception): - pickle.dumps(bad_gen) + with pytest.raises(TypeError): + cloudpickle.dumps(bad_gen) def test_gen_cloning_with_shape_change(self, datagen): gen = generator(datagen) diff --git a/pymc3/tests/test_model.py b/pymc3/tests/test_model.py index 7968ff80fac..b94864ee4cb 100644 --- a/pymc3/tests/test_model.py +++ b/pymc3/tests/test_model.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import pickle import unittest from functools import reduce @@ -19,6 +18,7 @@ import aesara import aesara.sparse as sparse import aesara.tensor as at +import cloudpickle import numpy as np import numpy.ma as ma import numpy.testing as npt @@ -407,9 +407,7 @@ def test_model_pickle(tmpdir): x = pm.Normal("x") pm.Normal("y", observed=1) - file_path = tmpdir.join("model.p") - with open(file_path, "wb") as buff: - pickle.dump(model, buff) + cloudpickle.loads(cloudpickle.dumps(model)) def test_model_pickle_deterministic(tmpdir): @@ -420,9 +418,7 @@ def test_model_pickle_deterministic(tmpdir): pm.Deterministic("w", x / z) pm.Normal("y", observed=1) - file_path = tmpdir.join("model.p") - with open(file_path, "wb") as buff: - pickle.dump(model, buff) + cloudpickle.loads(cloudpickle.dumps(model)) def test_model_vars(): diff --git a/pymc3/tests/test_parallel_sampling.py b/pymc3/tests/test_parallel_sampling.py index 5348cd72ada..d58604b93e3 100644 --- a/pymc3/tests/test_parallel_sampling.py +++ b/pymc3/tests/test_parallel_sampling.py @@ -71,12 +71,6 @@ def _crash_remote_process(a, master_pid): return 2 * np.array(a) -def test_dill(): - with pm.Model(): - pm.Normal("x") - pm.sample(tune=1, draws=1, chains=2, cores=2, pickle_backend="dill", mp_ctx="spawn") - - def test_remote_pipe_closed(): master_pid = os.getpid() with pm.Model(): @@ -112,7 +106,6 @@ def test_abort(): mp_ctx=ctx, start={"a": np.array([1.0]), "b_log__": np.array(2.0)}, step_method_pickled=None, - pickle_backend="pickle", ) proc.start() while True: @@ -147,7 +140,6 @@ def test_explicit_sample(): mp_ctx=ctx, start={"a": np.array([1.0]), "b_log__": np.array(2.0)}, step_method_pickled=None, - pickle_backend="pickle", ) proc.start() while True: diff --git a/pymc3/tests/test_pickling.py b/pymc3/tests/test_pickling.py index edb7c0b07eb..06284587a8a 100644 --- a/pymc3/tests/test_pickling.py +++ b/pymc3/tests/test_pickling.py @@ -15,6 +15,8 @@ import pickle import traceback +import cloudpickle + from pymc3.tests.models import simple_model @@ -26,8 +28,8 @@ def test_model_roundtrip(self): m = self.model for proto in range(pickle.HIGHEST_PROTOCOL + 1): try: - s = pickle.dumps(m, proto) - pickle.loads(s) + s = cloudpickle.dumps(m, proto) + cloudpickle.loads(s) except Exception: raise AssertionError( "Exception while trying roundtrip with pickle protocol %d:\n" % proto diff --git a/pymc3/tests/test_variational_inference.py b/pymc3/tests/test_variational_inference.py index 46215242589..c5d193ba2cd 100644 --- a/pymc3/tests/test_variational_inference.py +++ b/pymc3/tests/test_variational_inference.py @@ -757,7 +757,7 @@ def test_remove_scan_op(): def test_clear_cache(): - import pickle + import cloudpickle with pm.Model(): pm.Normal("n", 0, 1) @@ -767,7 +767,7 @@ def test_clear_cache(): inference.approx._cache.clear() # should not be cleared at this call assert all(len(c) == 0 for c in inference.approx._cache.values()) - new_a = pickle.loads(pickle.dumps(inference.approx)) + new_a = cloudpickle.loads(cloudpickle.dumps(inference.approx)) assert not hasattr(new_a, "_cache") inference_new = pm.KLqp(new_a) inference_new.fit(n=10) @@ -871,26 +871,26 @@ def test_rowwise_approx(three_var_model, parametric_grouped_approxes): def test_pickle_approx(three_var_approx): - import pickle + import cloudpickle - dump = pickle.dumps(three_var_approx) - new = pickle.loads(dump) + dump = cloudpickle.dumps(three_var_approx) + new = cloudpickle.loads(dump) assert new.sample(1) def test_pickle_single_group(three_var_approx_single_group_mf): - import pickle + import cloudpickle - dump = pickle.dumps(three_var_approx_single_group_mf) - new = pickle.loads(dump) + dump = cloudpickle.dumps(three_var_approx_single_group_mf) + new = cloudpickle.loads(dump) assert new.sample(1) def test_pickle_approx_aevb(three_var_aevb_approx): - import pickle + import cloudpickle - dump = pickle.dumps(three_var_aevb_approx) - new = pickle.loads(dump) + dump = cloudpickle.dumps(three_var_aevb_approx) + new = cloudpickle.loads(dump) assert new.sample(1000) diff --git a/pymc3/util.py b/pymc3/util.py index d60f83caffa..13ca7882868 100644 --- a/pymc3/util.py +++ b/pymc3/util.py @@ -19,7 +19,7 @@ from typing import Dict, List, Tuple, Union import arviz -import dill +import cloudpickle import numpy as np import xarray @@ -347,7 +347,7 @@ def hashable(a=None) -> int: pass # Not hashable >>> try: - return hash(dill.dumps(a)) + return hash(cloudpickle.dumps(a)) except Exception: if hasattr(a, "__dict__"): return hashable(a.__dict__) diff --git a/requirements-dev.txt b/requirements-dev.txt index 22fd046502c..d9bd59b577f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,7 +4,7 @@ aesara>=2.0.9 arviz>=0.11.2 cachetools>=4.2.1 -dill +cloudpickle fastprogress>=0.2.0 h5py>=2.7 ipython>=7.16 diff --git a/requirements.txt b/requirements.txt index 70db77f661b..ec41fbaf99a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ aesara>=2.0.9 arviz>=0.11.2 cachetools>=4.2.1 -dill +cloudpickle fastprogress>=0.2.0 numpy>=1.15.0 pandas>=0.24.0