From e6cf347ce769fa5c5e2245ef2e40ce1e0a3f618e Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sun, 7 Mar 2021 14:18:17 -0600 Subject: [PATCH] Replace custom memoize module with cachetools --- pymc3/distributions/distribution.py | 6 +- pymc3/memoize.py | 113 ---------------------- pymc3/model.py | 8 +- pymc3/tests/test_memo.py | 68 ------------- pymc3/tests/test_util.py | 22 +++++ pymc3/tests/test_variational_inference.py | 7 +- pymc3/util.py | 75 ++++++++++++++ pymc3/variational/flows.py | 2 +- pymc3/variational/opvi.py | 27 +++--- pymc3/variational/stein.py | 5 +- requirements.txt | 1 + 11 files changed, 124 insertions(+), 210 deletions(-) delete mode 100644 pymc3/memoize.py delete mode 100644 pymc3/tests/test_memo.py diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index d0ef10b236c..cd79cf3ad62 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -37,13 +37,13 @@ from aesara.graph.basic import Constant from aesara.tensor.type import TensorType as AesaraTensorType from aesara.tensor.var import TensorVariable +from cachetools import LRUCache, cached from pymc3.distributions.shape_utils import ( broadcast_dist_samples_shape, get_broadcastable_dist_samples, to_tuple, ) -from pymc3.memoize import memoize from pymc3.model import ( ContextMeta, FreeRV, @@ -52,7 +52,7 @@ ObservedRV, build_named_node_tree, ) -from pymc3.util import get_repr_for_variable, get_var_name +from pymc3.util import get_repr_for_variable, get_var_name, hash_key from pymc3.vartypes import string_types __all__ = [ @@ -841,7 +841,7 @@ def draw_values(params, point=None, size=None): return [evaluated[j] for j in params] # set the order back -@memoize +@cached(LRUCache(128), key=hash_key) def _compile_aesara_function(param, vars, givens=None): """Compile aesara function for a given parameter and input variables. diff --git a/pymc3/memoize.py b/pymc3/memoize.py deleted file mode 100644 index cbe791f10ce..00000000000 --- a/pymc3/memoize.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright 2020 The PyMC Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -import functools - -import dill - -from pymc3.util import biwrap - -CACHE_REGISTRY = [] - - -@biwrap -def memoize(obj, bound=False): - """ - Decorator to apply memoization to expensive functions. - It uses a custom `hashable` helper function to hash typically unhashable Python objects. - - Parameters - ---------- - obj : callable - the function to apply the caching to - bound : bool - indicates if the [obj] is a bound method (self as first argument) - For bound methods, the cache is kept in a `_cache` attribute on [self]. 
- """ - # this is declared not to be a bound method, so just attach new attr to obj - if not bound: - obj.cache = {} - CACHE_REGISTRY.append(obj.cache) - - @functools.wraps(obj) - def memoizer(*args, **kwargs): - if not bound: - key = (hashable(args), hashable(kwargs)) - cache = obj.cache - else: - # bound methods have self as first argument, remove it to compute key - key = (hashable(args[1:]), hashable(kwargs)) - if not hasattr(args[0], "_cache"): - setattr(args[0], "_cache", collections.defaultdict(dict)) - # do not add to cache registry - cache = getattr(args[0], "_cache")[obj.__name__] - if key not in cache: - cache[key] = obj(*args, **kwargs) - - return cache[key] - - return memoizer - - -def clear_cache(obj=None): - if obj is None: - for c in CACHE_REGISTRY: - c.clear() - else: - if isinstance(obj, WithMemoization): - for v in getattr(obj, "_cache", {}).values(): - v.clear() - else: - obj.cache.clear() - - -class WithMemoization: - def __hash__(self): - return hash(id(self)) - - def __getstate__(self): - state = self.__dict__.copy() - state.pop("_cache", None) - return state - - def __setstate__(self, state): - self.__dict__.update(state) - - -def hashable(a) -> int: - """ - Hashes many kinds of objects, including some that are unhashable through the builtin `hash` function. - Lists and tuples are hashed based on their elements. - """ - if isinstance(a, dict): - # first hash the keys and values with hashable - # then hash the tuple of int-tuples with the builtin - return hash(tuple((hashable(k), hashable(v)) for k, v in a.items())) - if isinstance(a, (tuple, list)): - # lists are mutable and not hashable by default - # for memoization, we need the hash to depend on the items - return hash(tuple(hashable(i) for i in a)) - try: - return hash(a) - except TypeError: - pass - # Not hashable >>> - try: - return hash(dill.dumps(a)) - except Exception: - if hasattr(a, "__dict__"): - return hashable(a.__dict__) - else: - return id(a) diff --git a/pymc3/model.py b/pymc3/model.py index a5a0a635c8b..76800e79600 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -32,6 +32,7 @@ from aesara.graph.basic import Apply, Variable from aesara.tensor.type import TensorType as AesaraTensorType from aesara.tensor.var import TensorVariable +from cachetools import LRUCache, cachedmethod from pandas import Series import pymc3 as pm @@ -40,8 +41,7 @@ from pymc3.blocking import ArrayOrdering, DictToArrayBijection from pymc3.exceptions import ImputationWarning from pymc3.math import flatten_list -from pymc3.memoize import WithMemoization, memoize -from pymc3.util import get_transformed_name, get_var_name +from pymc3.util import WithMemoization, get_transformed_name, get_var_name, hash_key from pymc3.vartypes import continuous_types, discrete_types, isgenerator, typefilter __all__ = [ @@ -946,7 +946,9 @@ def isroot(self): return self.parent is None @property # type: ignore - @memoize(bound=True) + @cachedmethod( + lambda self: self.__dict__.setdefault("_bijection_cache", LRUCache(128)), key=hash_key + ) def bijection(self): vars = inputvars(self.vars) diff --git a/pymc3/tests/test_memo.py b/pymc3/tests/test_memo.py deleted file mode 100644 index 6653662e32e..00000000000 --- a/pymc3/tests/test_memo.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2020 The PyMC Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import numpy as np - -import pymc3 as pm - -from pymc3 import memoize - - -def test_memo(): - def fun(inputs, suffix="_a"): - return str(inputs) + str(suffix) - - inputs = ["i1", "i2"] - assert fun(inputs) == "['i1', 'i2']_a" - assert fun(inputs, "_b") == "['i1', 'i2']_b" - - funmem = memoize.memoize(fun) - assert hasattr(fun, "cache") - assert isinstance(fun.cache, dict) - assert len(fun.cache) == 0 - - # call the memoized function with a list input - # and check the size of the cache! - assert funmem(inputs) == "['i1', 'i2']_a" - assert funmem(inputs) == "['i1', 'i2']_a" - assert len(fun.cache) == 1 - assert funmem(inputs, "_b") == "['i1', 'i2']_b" - assert funmem(inputs, "_b") == "['i1', 'i2']_b" - assert len(fun.cache) == 2 - - # add items to the inputs list (the list instance remains identical !!) - inputs.append("i3") - assert funmem(inputs) == "['i1', 'i2', 'i3']_a" - assert funmem(inputs) == "['i1', 'i2', 'i3']_a" - assert len(fun.cache) == 3 - - -def test_hashing_of_rv_tuples(): - obs = np.random.normal(-1, 0.1, size=10) - with pm.Model() as pmodel: - mu = pm.Normal("mu", 0, 1) - sd = pm.Gamma("sd", 1, 2) - dd = pm.DensityDist( - "dd", - pm.Normal.dist(mu, sd).logp, - random=pm.Normal.dist(mu, sd).random, - observed=obs, - ) - for freerv in [mu, sd, dd] + pmodel.free_RVs: - for structure in [ - freerv, - {"alpha": freerv, "omega": None}, - [freerv, []], - (freerv, []), - ]: - assert isinstance(memoize.hashable(structure), int) diff --git a/pymc3/tests/test_util.py b/pymc3/tests/test_util.py index adb334fb8af..ab049643bb7 100644 --- a/pymc3/tests/test_util.py +++ b/pymc3/tests/test_util.py @@ -21,6 +21,7 @@ from pymc3.distributions.transforms import Transform from pymc3.tests.helpers import SeededTest +from pymc3.util import hashable class TestTransformName: @@ -167,3 +168,24 @@ def test_dtype_error(self): raise pm.exceptions.DtypeError("With types.", actual=int, expected=str) assert "int" in exinfo.value.args[0] and "str" in exinfo.value.args[0] pass + + +def test_hashing_of_rv_tuples(): + obs = np.random.normal(-1, 0.1, size=10) + with pm.Model() as pmodel: + mu = pm.Normal("mu", 0, 1) + sd = pm.Gamma("sd", 1, 2) + dd = pm.DensityDist( + "dd", + pm.Normal.dist(mu, sd).logp, + random=pm.Normal.dist(mu, sd).random, + observed=obs, + ) + for freerv in [mu, sd, dd] + pmodel.free_RVs: + for structure in [ + freerv, + {"alpha": freerv, "omega": None}, + [freerv, []], + (freerv, []), + ]: + assert isinstance(hashable(structure), int) diff --git a/pymc3/tests/test_variational_inference.py b/pymc3/tests/test_variational_inference.py index 8e115350b49..820c5b4dc8b 100644 --- a/pymc3/tests/test_variational_inference.py +++ b/pymc3/tests/test_variational_inference.py @@ -22,7 +22,6 @@ import pytest import pymc3 as pm -import pymc3.memoize import pymc3.util from pymc3.aesaraf import intX @@ -757,14 +756,12 @@ def test_remove_scan_op(): def test_clear_cache(): import pickle - pymc3.memoize.clear_cache() - assert all(len(c) == 0 for c in pymc3.memoize.CACHE_REGISTRY) with pm.Model(): pm.Normal("n", 0, 1) inference = ADVI() inference.fit(n=10) assert any(len(c) != 0 for c in 
inference.approx._cache.values()) - pymc3.memoize.clear_cache(inference.approx) + inference.approx._cache.clear() # should not be cleared at this call assert all(len(c) == 0 for c in inference.approx._cache.values()) new_a = pickle.loads(pickle.dumps(inference.approx)) @@ -772,7 +769,7 @@ def test_clear_cache(): inference_new = pm.KLqp(new_a) inference_new.fit(n=10) assert any(len(c) != 0 for c in inference_new.approx._cache.values()) - pymc3.memoize.clear_cache(inference_new.approx) + inference_new.approx._cache.clear() assert all(len(c) == 0 for c in inference_new.approx._cache.values()) diff --git a/pymc3/util.py b/pymc3/util.py index f0429901f8e..b2aa1e64d73 100644 --- a/pymc3/util.py +++ b/pymc3/util.py @@ -19,10 +19,12 @@ from typing import Dict, List, Tuple, Union import arviz +import dill import numpy as np import xarray from aesara.tensor.var import TensorVariable +from cachetools import cachedmethod from pymc3.exceptions import SamplingError @@ -304,3 +306,76 @@ def chains_and_samples(data: Union[xarray.Dataset, arviz.InferenceData]) -> Tupl nchains = coords["chain"].sizes["chain"] nsamples = coords["draw"].sizes["draw"] return nchains, nsamples + + +def hashable(a=None) -> int: + """ + Hashes many kinds of objects, including some that are unhashable through the builtin `hash` function. + Lists and tuples are hashed based on their elements. + """ + if isinstance(a, dict): + # first hash the keys and values with hashable + # then hash the tuple of int-tuples with the builtin + return hash(tuple((hashable(k), hashable(v)) for k, v in a.items())) + if isinstance(a, (tuple, list)): + # lists are mutable and not hashable by default + # for memoization, we need the hash to depend on the items + return hash(tuple(hashable(i) for i in a)) + try: + return hash(a) + except TypeError: + pass + # Not hashable >>> + try: + return hash(dill.dumps(a)) + except Exception: + if hasattr(a, "__dict__"): + return hashable(a.__dict__) + else: + return id(a) + + +def hash_key(*args, **kwargs): + return tuple(HashableWrapper(a) for a in args + tuple(kwargs.items())) + + +class HashableWrapper: + __slots__ = ("obj",) + + def __init__(self, obj): + self.obj = obj + + def __hash__(self): + return hashable(self.obj) + + def __eq__(self, other): + return self.obj == other + + def __repr__(self): + return f"{type(self).__name__}({self.obj})" + + +class WithMemoization: + def __hash__(self): + return hash(id(self)) + + def __getstate__(self): + state = self.__dict__.copy() + state.pop("_cache", None) + return state + + def __setstate__(self, state): + self.__dict__.update(state) + + +def locally_cachedmethod(f): + + from collections import defaultdict + + def self_cache_fn(f_name): + def cf(self): + return self.__dict__.setdefault("_cache", defaultdict(dict))[f_name] + + return cf + + return cachedmethod(self_cache_fn(f.__name__), key=hashable)(f) diff --git a/pymc3/variational/flows.py b/pymc3/variational/flows.py index f78c32e69bb..68f6095afb5 100644 --- a/pymc3/variational/flows.py +++ b/pymc3/variational/flows.py @@ -18,7 +18,7 @@ from aesara import tensor as aet from pymc3.distributions.dist_math import rho2sigma -from pymc3.memoize import WithMemoization +from pymc3.util import WithMemoization from pymc3.variational import opvi from pymc3.variational.opvi import collect_shared_to_list, node_property diff --git a/pymc3/variational/opvi.py b/pymc3/variational/opvi.py index 115c0abcaef..cf749b24567 100644 --- a/pymc3/variational/opvi.py +++ b/pymc3/variational/opvi.py @@ -60,9 +60,13 @@ from pymc3.aesaraf 
import aet_rng, identity from pymc3.backends import NDArray from pymc3.blocking import ArrayOrdering, DictToArrayBijection, VarMap -from pymc3.memoize import WithMemoization, memoize from pymc3.model import modelcontext -from pymc3.util import get_default_varnames, get_transformed +from pymc3.util import ( + WithMemoization, + get_default_varnames, + get_transformed, + locally_cachedmethod, +) from pymc3.variational.updates import adagrad_window __all__ = ["ObjectiveFunction", "Operator", "TestFunction", "Group", "Approximation"] @@ -113,21 +117,18 @@ def inner(*args, **kwargs): def node_property(f): """A shortcut for wrapping method to accessible tensor""" + if isinstance(f, str): def wrapper(fn): - return property( - memoize( - aesara.config.change_flags(compute_test_value="off")(append_name(f)(fn)), - bound=True, - ) - ) + ff = append_name(f)(fn) + f_ = aesara.config.change_flags(compute_test_value="off")(ff) + return property(locally_cachedmethod(f_)) return wrapper else: - return property( - memoize(aesara.config.change_flags(compute_test_value="off")(f), bound=True) - ) + f_ = aesara.config.change_flags(compute_test_value="off")(f) + return property(locally_cachedmethod(f_)) @aesara.config.change_flags(compute_test_value="ignore") @@ -1588,9 +1589,7 @@ def vars_names(vs): raise KeyError("%r not found" % name) return found - @property - @memoize(bound=True) - @aesara.config.change_flags(compute_test_value="off") + @node_property def sample_dict_fn(self): s = aet.iscalar() names = [v.name for v in self.model.free_RVs] diff --git a/pymc3/variational/stein.py b/pymc3/variational/stein.py index 79a7d78183c..b5701e5240a 100644 --- a/pymc3/variational/stein.py +++ b/pymc3/variational/stein.py @@ -16,7 +16,7 @@ import aesara.tensor as aet from pymc3.aesaraf import floatX -from pymc3.memoize import WithMemoization, memoize +from pymc3.util import WithMemoization, locally_cachedmethod from pymc3.variational.opvi import node_property from pymc3.variational.test_functions import rbf @@ -90,7 +90,6 @@ def logp_norm(self): ) return sized_symbolic_logp / self.approx.symbolic_normalizing_constant - @memoize - @aesara.config.change_flags(compute_test_value="off") + @locally_cachedmethod def _kernel(self): return self._kernel_f(self.input_joint_matrix) diff --git a/requirements.txt b/requirements.txt index 9ec84e75387..66b4f0e2dee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ aesara>=2.0.1 arviz>=0.11.1 +cachetools>=4.2.1 dill fastprogress>=0.2.0 numpy>=1.15.0
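
Usage sketch (illustrative, not part of the diff above): the patch replaces pymc3's hand-rolled memoize module with cachetools, keyed through the new hash_key/hashable helpers in pymc3/util.py so that otherwise-unhashable arguments (lists, dicts, random variables) can still index a cache. The snippet below assumes a pymc3 checkout with this patch applied and cachetools installed; cached and LRUCache are cachetools' public API, while expensive_build, CachedExample, and compute are hypothetical names that only mirror the _compile_aesara_function and stein.py changes.

    from cachetools import LRUCache, cached

    from pymc3.util import WithMemoization, hash_key, locally_cachedmethod


    # Function-level caching, as done for _compile_aesara_function: hash_key
    # wraps every argument so unhashable inputs can still act as cache keys.
    @cached(LRUCache(128), key=hash_key)
    def expensive_build(param, vars, givens=None):  # hypothetical stand-in
        return (param, tuple(vars), givens)


    # Per-instance method caching, as done for stein.py's _kernel: the
    # per-method sub-caches live in self._cache, which WithMemoization
    # drops in __getstate__ so cached values never travel through pickle.
    class CachedExample(WithMemoization):  # hypothetical stand-in
        @locally_cachedmethod
        def compute(self):
            return object()  # placeholder for an expensive computation


    r1 = expensive_build("x", ["a", "b"])          # list argument, keyed via hash_key
    assert expensive_build("x", ["a", "b"]) is r1  # second call is an LRUCache hit

    ex = CachedExample()
    assert ex.compute() is ex.compute()            # second call is served from ex._cache

Model.bijection applies the same idea with an explicit cachedmethod and a dedicated per-instance _bijection_cache, while the variational Approximation objects keep a per-instance _cache dict via locally_cachedmethod, which is why test_variational_inference.py now clears it with inference.approx._cache.clear() instead of the removed pymc3.memoize.clear_cache.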