Skip to content

Commit

Permalink
Replace custom memoize module with cachetools
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Mar 7, 2021
1 parent 2901888 commit e6cf347
Show file tree
Hide file tree
Showing 11 changed files with 124 additions and 210 deletions.
6 changes: 3 additions & 3 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@
from aesara.graph.basic import Constant
from aesara.tensor.type import TensorType as AesaraTensorType
from aesara.tensor.var import TensorVariable
from cachetools import LRUCache, cached

from pymc3.distributions.shape_utils import (
broadcast_dist_samples_shape,
get_broadcastable_dist_samples,
to_tuple,
)
from pymc3.memoize import memoize
from pymc3.model import (
ContextMeta,
FreeRV,
Expand All @@ -52,7 +52,7 @@
ObservedRV,
build_named_node_tree,
)
from pymc3.util import get_repr_for_variable, get_var_name
from pymc3.util import get_repr_for_variable, get_var_name, hash_key
from pymc3.vartypes import string_types

__all__ = [
Expand Down Expand Up @@ -841,7 +841,7 @@ def draw_values(params, point=None, size=None):
return [evaluated[j] for j in params] # set the order back


@memoize
@cached(LRUCache(128), key=hash_key)
def _compile_aesara_function(param, vars, givens=None):
"""Compile aesara function for a given parameter and input variables.
Expand Down
113 changes: 0 additions & 113 deletions pymc3/memoize.py

This file was deleted.

8 changes: 5 additions & 3 deletions pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from aesara.graph.basic import Apply, Variable
from aesara.tensor.type import TensorType as AesaraTensorType
from aesara.tensor.var import TensorVariable
from cachetools import LRUCache, cachedmethod
from pandas import Series

import pymc3 as pm
Expand All @@ -40,8 +41,7 @@
from pymc3.blocking import ArrayOrdering, DictToArrayBijection
from pymc3.exceptions import ImputationWarning
from pymc3.math import flatten_list
from pymc3.memoize import WithMemoization, memoize
from pymc3.util import get_transformed_name, get_var_name
from pymc3.util import WithMemoization, get_transformed_name, get_var_name, hash_key
from pymc3.vartypes import continuous_types, discrete_types, isgenerator, typefilter

__all__ = [
Expand Down Expand Up @@ -946,7 +946,9 @@ def isroot(self):
return self.parent is None

@property # type: ignore
@memoize(bound=True)
@cachedmethod(
lambda self: self.__dict__.setdefault("_bijection_cache", LRUCache(128)), key=hash_key
)
def bijection(self):
vars = inputvars(self.vars)

Expand Down
68 changes: 0 additions & 68 deletions pymc3/tests/test_memo.py

This file was deleted.

22 changes: 22 additions & 0 deletions pymc3/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from pymc3.distributions.transforms import Transform
from pymc3.tests.helpers import SeededTest
from pymc3.util import hashable


class TestTransformName:
Expand Down Expand Up @@ -167,3 +168,24 @@ def test_dtype_error(self):
raise pm.exceptions.DtypeError("With types.", actual=int, expected=str)
assert "int" in exinfo.value.args[0] and "str" in exinfo.value.args[0]
pass


def test_hashing_of_rv_tuples():
obs = np.random.normal(-1, 0.1, size=10)
with pm.Model() as pmodel:
mu = pm.Normal("mu", 0, 1)
sd = pm.Gamma("sd", 1, 2)
dd = pm.DensityDist(
"dd",
pm.Normal.dist(mu, sd).logp,
random=pm.Normal.dist(mu, sd).random,
observed=obs,
)
for freerv in [mu, sd, dd] + pmodel.free_RVs:
for structure in [
freerv,
{"alpha": freerv, "omega": None},
[freerv, []],
(freerv, []),
]:
assert isinstance(hashable(structure), int)
7 changes: 2 additions & 5 deletions pymc3/tests/test_variational_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import pytest

import pymc3 as pm
import pymc3.memoize
import pymc3.util

from pymc3.aesaraf import intX
Expand Down Expand Up @@ -757,22 +756,20 @@ def test_remove_scan_op():
def test_clear_cache():
import pickle

pymc3.memoize.clear_cache()
assert all(len(c) == 0 for c in pymc3.memoize.CACHE_REGISTRY)
with pm.Model():
pm.Normal("n", 0, 1)
inference = ADVI()
inference.fit(n=10)
assert any(len(c) != 0 for c in inference.approx._cache.values())
pymc3.memoize.clear_cache(inference.approx)
inference.approx._cache.clear()
# should not be cleared at this call
assert all(len(c) == 0 for c in inference.approx._cache.values())
new_a = pickle.loads(pickle.dumps(inference.approx))
assert not hasattr(new_a, "_cache")
inference_new = pm.KLqp(new_a)
inference_new.fit(n=10)
assert any(len(c) != 0 for c in inference_new.approx._cache.values())
pymc3.memoize.clear_cache(inference_new.approx)
inference_new.approx._cache.clear()
assert all(len(c) == 0 for c in inference_new.approx._cache.values())


Expand Down
Loading

0 comments on commit e6cf347

Please sign in to comment.