Skip to content

Commit

Permalink
OPVI speedup (#2759)
Browse files Browse the repository at this point in the history
* fix scan op redundancy

* fix clear cache

* Better solution for caching

* Redundant usages of memoize

* Clear cache function

* clear cache

* fix testing

* fix unused import

* fix imports

* fix imports
  • Loading branch information
ferrine authored and Junpeng Lao committed Dec 19, 2017
1 parent de1b8c8 commit 1cdd163
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 36 deletions.
27 changes: 26 additions & 1 deletion pymc3/memoize.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
import functools
import pickle

CACHE_REGISTRY = []


def memoize(obj):
"""
An expensive memoizer that works with unhashables
"""
cache = obj.cache = {}
CACHE_REGISTRY.append(cache)

@functools.wraps(obj)
def memoizer(*args, **kwargs):
key = (hashable(args), hashable(kwargs))
# remember first argument as well, used to clear cache for particular instance
key = (hashable(args[:1]), hashable(args), hashable(kwargs))

if key not in cache:
cache[key] = obj(*args, **kwargs)
Expand All @@ -19,6 +23,27 @@ def memoizer(*args, **kwargs):
return memoizer


def clear_cache():
for c in CACHE_REGISTRY:
c.clear()


class WithMemoization(object):
def __hash__(self):
return hash(id(self))

def __del__(self):
# regular property call with args (self, )
key = hash((self, ))
to_del = []
for c in CACHE_REGISTRY:
for k in c.keys():
if k[0] == key:
to_del.append((c, k))
for (c, k) in to_del:
del c[k]


def hashable(a):
"""
Turn some unhashable objects into hashable ones.
Expand Down
4 changes: 2 additions & 2 deletions pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pymc3.theanof import set_theano_conf
import pymc3 as pm
from pymc3.math import flatten_list
from .memoize import memoize
from .memoize import memoize, WithMemoization
from .theanof import gradient, hessian, inputvars, generator
from .vartypes import typefilter, discrete_types, continuous_types, isgenerator
from .blocking import DictToArrayBijection, ArrayOrdering
Expand Down Expand Up @@ -487,7 +487,7 @@ def _build_joined(self, cost, args, vmap):
return args_joined, theano.clone(cost, replace=replace)


class Model(six.with_metaclass(InitContextMeta, Context, Factor)):
class Model(six.with_metaclass(InitContextMeta, Context, Factor, WithMemoization)):
"""Encapsulates the variables and likelihood factors of a model.
Model class can be used for creating class based models. To create
Expand Down
31 changes: 26 additions & 5 deletions pymc3/tests/test_variational_inference.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import pytest
import six
import functools
import operator
import numpy as np
from theano import theano, tensor as tt


import pymc3 as pm
import pymc3.memoize
import pymc3.util
from pymc3.theanof import change_flags
from pymc3.variational.approximations import (
Expand Down Expand Up @@ -596,11 +598,30 @@ def test_fit_oo(inference,


def test_profile(inference):
try:
inference.run_profiling(n=100).summary()
except ZeroDivisionError:
# weird error in SVGD, ASVGD
pass
inference.run_profiling(n=100).summary()


def test_remove_scan_op():
with pm.Model():
pm.Normal('n', 0, 1)
inference = ADVI()
buff = six.StringIO()
inference.run_profiling(n=10).summary(buff)
assert 'theano.scan_module.scan_op.Scan' not in buff.getvalue()
buff.close()


def test_clear_cache():
pymc3.memoize.clear_cache()
with pm.Model():
pm.Normal('n', 0, 1)
inference = ADVI()
inference.fit(n=10)
assert len(pm.variational.opvi.Approximation.logp.fget.cache) == 1
del inference
assert len(pm.variational.opvi.Approximation.logp.fget.cache) == 0
for c in pymc3.memoize.CACHE_REGISTRY:
assert len(c) == 0


@pytest.fixture('module')
Expand Down
9 changes: 2 additions & 7 deletions pymc3/theanof.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from .blocking import ArrayOrdering
from .data import GeneratorAdapter
from .memoize import memoize
from .vartypes import typefilter, continuous_types

__all__ = ['gradient',
Expand Down Expand Up @@ -85,10 +84,10 @@ def gradient1(f, v):
"""flat gradient of f wrt v"""
return tt.flatten(tt.grad(f, v, disconnected_inputs='warn'))


empty_gradient = tt.zeros(0, dtype='float32')


@memoize
def gradient(f, vars=None):
if vars is None:
vars = cont_inputs(f)
Expand All @@ -110,7 +109,6 @@ def grad_i(i):
return theano.map(grad_i, idx)[0]


@memoize
def jacobian(f, vars=None):
if vars is None:
vars = cont_inputs(f)
Expand All @@ -132,7 +130,6 @@ def grad_ii(i):
name='jacobian_diag')[0]


@memoize
@change_flags(compute_test_value='ignore')
def hessian(f, vars=None):
return -jacobian(gradient(f, vars), vars)
Expand All @@ -149,7 +146,6 @@ def hess_ii(i):
return theano.map(hess_ii, idx)[0]


@memoize
@change_flags(compute_test_value='ignore')
def hessian_diag(f, vars=None):
if vars is None:
Expand Down Expand Up @@ -276,6 +272,7 @@ def __call__(self, input):
oldinput, = inputvars(self.tensor)
return theano.clone(self.tensor, {oldinput: input}, strict=False)


scalar_identity = IdentityOp(scalar.upgrade_to_float, name='scalar_identity')
identity = tt.Elemwise(scalar_identity, name='identity')

Expand Down Expand Up @@ -463,5 +460,3 @@ def largest_common_dtype(tensors):
else smartfloatX(np.asarray(t)).dtype
for t in tensors)
return np.stack([np.ones((), dtype=dtype) for dtype in dtypes]).dtype


10 changes: 7 additions & 3 deletions pymc3/variational/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
import theano
from theano import tensor as tt

from pymc3.distributions.dist_math import rho2sd
from pymc3.theanof import change_flags
from ..distributions.dist_math import rho2sd
from ..theanof import change_flags
from ..memoize import WithMemoization
from .opvi import node_property, collect_shared_to_list
from . import opvi

__all__ = [
'Formula',
'PlanarFlow',
'HouseholderFlow',
'RadialFlow',
'LocFlow',
'ScaleFlow'
]
Expand Down Expand Up @@ -97,7 +100,7 @@ def seems_like_flow_params(params):
return False


class AbstractFlow(object):
class AbstractFlow(WithMemoization):
shared_params = None
__param_spec__ = dict()
short_name = ''
Expand Down Expand Up @@ -255,6 +258,7 @@ def __repr__(self):
def __str__(self):
return self.short_name


flow_for_params = AbstractFlow.flow_for_params
flow_for_short_name = AbstractFlow.flow_for_short_name

Expand Down
48 changes: 32 additions & 16 deletions pymc3/variational/opvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@
ArrayOrdering, DictToArrayBijection, VarMap
)
from ..model import modelcontext
from ..theanof import tt_rng, memoize, change_flags, identity
from ..theanof import tt_rng, change_flags, identity
from ..util import get_default_varnames
from ..memoize import WithMemoization, memoize

__all__ = [
'ObjectiveFunction',
Expand Down Expand Up @@ -86,10 +87,29 @@ class LocalGroupError(BatchedGroupError, AEVBInferenceError):
"""Error raised in case of bad local_rv usage"""


def append_name(name):
def wrap(f):
if name is None:
return f

def inner(*args, **kwargs):
res = f(*args, **kwargs)
res.name = name
return res
return inner
return wrap


def node_property(f):
"""A shortcut for wrapping method to accessible tensor
"""
return property(memoize(change_flags(compute_test_value='off')(f)))
if isinstance(f, str):

def wrapper(fn):
return property(memoize(change_flags(compute_test_value='off')(append_name(f)(fn))))
return wrapper
else:
return property(memoize(change_flags(compute_test_value='off')(f)))


@change_flags(compute_test_value='ignore')
Expand Down Expand Up @@ -134,7 +154,6 @@ class ObjectiveFunction(object):
tf : :class:`TestFunction`
OPVI TestFunction
"""
__hash__ = id

def __init__(self, op, tf):
self.op = op
Expand Down Expand Up @@ -351,7 +370,6 @@ class Operator(object):
-----
For implementing custom operator it is needed to define :func:`Operator.apply` method
"""
__hash__ = id

has_test_function = False
returns_loss = True
Expand Down Expand Up @@ -444,8 +462,6 @@ def collect_shared_to_list(params):


class TestFunction(object):
__hash__ = id

def __init__(self):
self._inited = False
self.shared_params = None
Expand All @@ -469,7 +485,7 @@ def from_function(cls, f):
return obj


class Group(object):
class Group(WithMemoization):
R"""**Base class for grouping variables in VI**
Grouped Approximation is used for modelling mutual dependencies
Expand Down Expand Up @@ -682,8 +698,7 @@ class Group(object):
- Kingma, D. P., & Welling, M. (2014).
`Auto-Encoding Variational Bayes. stat, 1050, 1. <https://arxiv.org/abs/1312.6114>`_
"""
__hash__ = id
# need to be defined in init
# needs to be defined in init
shared_params = None
symbolic_initial = None
replacements = None
Expand Down Expand Up @@ -1064,14 +1079,14 @@ def set_size_and_deterministic(self, node, s, d, more_replacements=None):
:class:`Variable` with applied replacements, ready to use
"""
flat2rand = self.make_size_and_deterministic_replacements(s, d, more_replacements)
node_out = theano.clone(node, flat2rand, strict=False)
node_out = theano.clone(node, flat2rand)
try_to_set_test_value(node, node_out, s)
return node_out

def to_flat_input(self, node):
"""*Dev* - replace vars with flattened view stored in `self.inputs`
"""
return theano.clone(node, self.replacements, strict=False)
return theano.clone(node, self.replacements)

def symbolic_sample_over_posterior(self, node):
"""*Dev* - performs sampling of node applying independent samples from posterior each time.
Expand Down Expand Up @@ -1184,11 +1199,12 @@ def cov(self):
def mean(self):
raise NotImplementedError


group_for_params = Group.group_for_params
group_for_short_name = Group.group_for_short_name


class Approximation(object):
class Approximation(WithMemoization):
"""**Wrapper for grouped approximations**
Wraps list of groups, creates an Approximation instance that collects
Expand Down Expand Up @@ -1217,7 +1233,6 @@ class Approximation(object):
--------
:class:`Group`
"""
__hash__ = id

def __init__(self, groups, model=None):
self._scale_cost_to_minibatch = theano.shared(np.int8(1))
Expand Down Expand Up @@ -1374,12 +1389,13 @@ def set_size_and_deterministic(self, node, s, d, more_replacements=None):
-------
:class:`Variable` with applied replacements, ready to use
"""
_node = node
optimizations = self.get_optimization_replacements(s, d)
flat2rand = self.make_size_and_deterministic_replacements(s, d, more_replacements)
node = theano.clone(node, optimizations)
node_out = theano.clone(node, flat2rand, strict=False)
try_to_set_test_value(node, node_out, s)
return node_out
node = theano.clone(node, flat2rand)
try_to_set_test_value(_node, node, s)
return node

def to_flat_input(self, node):
"""*Dev* - replace vars with flattened view stored in `self.inputs`
Expand Down
5 changes: 3 additions & 2 deletions pymc3/variational/stein.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from theano import theano, tensor as tt
from pymc3.variational.opvi import node_property
from pymc3.variational.test_functions import rbf
from pymc3.theanof import memoize, floatX, change_flags
from pymc3.theanof import floatX, change_flags
from pymc3.memoize import WithMemoization, memoize

__all__ = [
'Stein'
]


class Stein(object):
class Stein(WithMemoization):
def __init__(self, approx, kernel=rbf, use_histogram=True, temperature=1):
self.approx = approx
self.temperature = floatX(temperature)
Expand Down

0 comments on commit 1cdd163

Please sign in to comment.