From e59ebba399be4a874d4e1b3bbb9834a816b90623 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Wed, 31 Mar 2021 08:43:21 +0000 Subject: [PATCH 01/82] resolve merge conflicts --- pymc/tests/test_variational_inference.py | 2 +- pymc/variational/opvi.py | 13 +++++-------- pymc/variational/updates.py | 0 3 files changed, 6 insertions(+), 9 deletions(-) mode change 100755 => 100644 pymc/variational/updates.py diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 49d3df979ca..50b0c7958ad 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -961,7 +961,7 @@ def test_discrete_not_allowed(): with pm.Model(): mu = pm.Normal("mu", mu=0, sigma=10, size=3) - z = pm.Categorical("z", p=at.ones(3) / 3, size=len(y)) + z = pm.Categorical("z", p=aet.ones(3) / 3, size=len(y)) pm.Normal("y_obs", mu=mu[z], sigma=1.0, observed=y) with pytest.raises(opvi.ParametrizationError): pm.fit(n=1) # fails diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 0503e204a7b..d6abfda5e13 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -833,9 +833,6 @@ def __init__( options=None, **kwargs, ): - # XXX: Needs to be refactored for v4 - raise NotImplementedError("This class needs to be refactored for v4") - if local and not self.supports_batched: raise LocalGroupError("%s does not support local groups" % self.__class__) if local and rowwise: @@ -959,7 +956,7 @@ def __init_group__(self, group): self.group = [get_transformed(var) for var in self.group] # XXX: This needs to be refactored - # self.ordering = ArrayOrdering([]) + self.point_map_info = [] self.replacements = dict() for var in self.group: if var.type.numpy_dtype.name in discrete_types: @@ -975,18 +972,18 @@ def __init_group__(self, group): # self.ordering.size += None # (np.prod(var.dshape[1:])).astype(int) if self.local: # XXX: This needs to be refactored - shape = None # (-1,) + var.dshape[1:] + shape = (-1,) + var.dshape[1:] else: # XXX: This needs to be refactored - shape = None # var.dshape + shape = var.dshape else: # XXX: This needs to be refactored # self.ordering.size += None # var.dsize # XXX: This needs to be refactored - shape = None # var.dshape + shape = var.dshape # end = self.ordering.size # XXX: This needs to be refactored - vmap = None # VarMap(var.name, slice(begin, end), shape, var.dtype) + vmap = (var.name, shape, var.dtype) # self.ordering.vmap.append(vmap) # self.ordering.by_name[vmap.var] = vmap vr = self.input[..., vmap.slc].reshape(shape).astype(vmap.dtyp) diff --git a/pymc/variational/updates.py b/pymc/variational/updates.py old mode 100755 new mode 100644 From 8aa290f6fd3b5a9f7810bfa3d26ff64d03afe049 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 6 Jun 2021 10:53:20 +0000 Subject: [PATCH 02/82] start fixing things --- pymc/tests/test_variational_inference.py | 8 ++--- pymc/variational/approximations.py | 2 +- pymc/variational/opvi.py | 43 +++++++++++++----------- 3 files changed, 29 insertions(+), 24 deletions(-) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 50b0c7958ad..971bbcd729f 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -40,10 +40,10 @@ from pymc.variational.inference import ADVI, ASVGD, NFVI, SVGD, FullRankADVI, fit from pymc.variational.opvi import Approximation, Group -# pytestmark = pytest.mark.usefixtures("strict_float32", "seeded_test") -pytestmark = pytest.mark.xfail( - reason="These tests rely on Group, which hasn't been refactored for v4" -) +pytestmark = pytest.mark.usefixtures("strict_float32", "seeded_test") +# pytestmark = pytest.mark.xfail( +# reason="These tests rely on Group, which hasn't been refactored for v4" +# ) @pytest.mark.parametrize("diff", ["relative", "absolute"]) diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py index 5392e631332..5d1563610bc 100644 --- a/pymc/variational/approximations.py +++ b/pymc/variational/approximations.py @@ -84,7 +84,7 @@ def create_shared_params(self, start=None): start = np.tile(start, (self.bdim, 1)) rho = np.tile(rho, (self.bdim, 1)) return { - "mu": aesara.shared(pm.floatX(start), "mu"), + "mu": aesara.shared(pm.floatX(start.data), "mu"), "rho": aesara.shared(pm.floatX(rho), "rho"), } diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index d6abfda5e13..277e10b5b5b 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -851,6 +851,7 @@ def __init__( self.group = group self.user_params = params self._user_params = None + self.replacements = dict() # save this stuff to use in __init_group__ later self._kwargs = kwargs if self.group is not None: @@ -953,16 +954,21 @@ def __init_group__(self, group): self.input = self._input_type(self.__class__.__name__ + "_symbolic_input") # I do some staff that is not supported by standard __init__ # so I have to to it by myself - self.group = [get_transformed(var) for var in self.group] - # XXX: This needs to be refactored - self.point_map_info = [] - self.replacements = dict() + # 1) we need initial point (transformed space) + model_initial_point = self.model.initial_point + + # 2) we'll work with a single group, a subset of the model + # here we need to create a mapping to replace value_vars with slices from the approximation + start_idx = 0 for var in self.group: if var.type.numpy_dtype.name in discrete_types: raise ParametrizationError(f"Discrete variables are not supported by VI: {var}") - begin = self.ddim + # 3) This is the way to infer shape and dtype of the variable + test_var = model_initial_point[var.tag.value_var.name] if self.batched: + # Leave a more complicated case for future work + raise NotImplementedError("not yet ready") if var.ndim < 1: if self.local: raise LocalGroupError("Local variable should not be scalar") @@ -977,18 +983,16 @@ def __init_group__(self, group): # XXX: This needs to be refactored shape = var.dshape else: - # XXX: This needs to be refactored - # self.ordering.size += None # var.dsize - # XXX: This needs to be refactored - shape = var.dshape - # end = self.ordering.size - # XXX: This needs to be refactored - vmap = (var.name, shape, var.dtype) - # self.ordering.vmap.append(vmap) - # self.ordering.by_name[vmap.var] = vmap - vr = self.input[..., vmap.slc].reshape(shape).astype(vmap.dtyp) - vr.name = vmap.var + "_vi_replacement" - self.replacements[var] = vr + shape = test_var.shape + dtype = test_var.dtype + size = test_var.size + # TODO: There was self.ordering used in other util funcitons + vr = self.input[..., start_idx:start_idx+size].reshape(shape).astype(dtype) + vr.name = var.tag.value_var.name + "_vi_replacement" + self.replacements[var.tag.value_var] = vr + + start_idx += size + self._ddim = start_idx def _finalize_init(self): """*Dev* - clean up after init""" @@ -1054,8 +1058,8 @@ def ndim(self): @property def ddim(self): - # XXX: This needs to be refactored - return None # self.ordering.size + # TODO: This needs to be refactored + return self._ddim # self.ordering.size def _new_initial(self, size, deterministic, more_replacements=None): """*Dev* - allocates new initial random generator @@ -1602,6 +1606,7 @@ def vars_names(vs): @node_property def sample_dict_fn(self): + # TODO: this breaks s = at.iscalar() names = [v.name for v in self.model.free_RVs] sampled = [self.rslice(name) for name in names] From f96c626384eaf462bda3a3e9de8a292db863fa52 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Mon, 7 Jun 2021 08:55:02 +0000 Subject: [PATCH 03/82] make a simple test pass --- pymc/variational/opvi.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 277e10b5b5b..6b47d80635e 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -852,6 +852,7 @@ def __init__( self.user_params = params self._user_params = None self.replacements = dict() + self.ordering = dict() # save this stuff to use in __init_group__ later self._kwargs = kwargs if self.group is not None: @@ -990,7 +991,12 @@ def __init_group__(self, group): vr = self.input[..., start_idx:start_idx+size].reshape(shape).astype(dtype) vr.name = var.tag.value_var.name + "_vi_replacement" self.replacements[var.tag.value_var] = vr - + self.ordering[var.tag.value_var.name] = ( + var.tag.value_var.name, + slice(start_idx, start_idx+size), + shape, + dtype + ) start_idx += size self._ddim = start_idx @@ -1590,7 +1596,7 @@ def rslice(self, name): """ def vars_names(vs): - return {v.name for v in vs} + return {v.tag.value_var.name for v in vs} for vars_, random, ordering in zip( self.collect("group"), self.symbolic_randoms, self.collect("ordering") @@ -1608,43 +1614,37 @@ def vars_names(vs): def sample_dict_fn(self): # TODO: this breaks s = at.iscalar() - names = [v.name for v in self.model.free_RVs] + names = [v.tag.value_var.name for v in self.model.free_RVs] sampled = [self.rslice(name) for name in names] sampled = self.set_size_and_deterministic(sampled, s, 0) sample_fn = aesara.function([s], sampled) def inner(draws=100): _samples = sample_fn(draws) - return {v_.name: s_ for v_, s_ in zip(self.model.free_RVs, _samples)} + return {v_: s_ for v_, s_ in zip(names, _samples)} return inner - def sample(self, draws=500, include_transformed=True): + def sample(self, draws=500): """Draw samples from variational posterior. Parameters ---------- draws: `int` Number of random samples. - include_transformed: `bool` - If True, transformed variables are also sampled. Default is False. Returns ------- trace: :class:`pymc.backends.base.MultiTrace` Samples drawn from variational posterior. """ - vars_sampled = get_default_varnames( - [v.tag.value_var for v in self.model.unobserved_RVs], - include_transformed=include_transformed, - ) + # TODO: check for include_transformed case + samples = self.sample_dict_fn(draws) # type: dict points = ({name: records[i] for name, records in samples.items()} for i in range(draws)) trace = NDArray( model=self.model, - vars=vars_sampled, - test_point={name: records[0] for name, records in samples.items()}, ) try: trace.setup(draws=draws, chain=0) From 0af6dac97a5a2b40f48b2d2ba7f44152a4acd906 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Fri, 11 Jun 2021 14:41:15 +0000 Subject: [PATCH 04/82] fix some more tests --- pymc/tests/test_variational_inference.py | 3 ++- pymc/variational/approximations.py | 18 ++++++++++++------ pymc/variational/opvi.py | 10 +++++++--- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 971bbcd729f..b2857ab4524 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -129,7 +129,8 @@ def test_init_groups(three_var_model, raises, grouping): else: assert {pm.util.get_transformed(z) for z in g} == set(ig.group) else: - assert approx.ndim == three_var_model.ndim + model_dim = sum(v.size for v in three_var_model.initial_point.values()) + assert approx.ndim == model_dim @pytest.fixture( diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py index 5d1563610bc..0734a15d3d3 100644 --- a/pymc/variational/approximations.py +++ b/pymc/variational/approximations.py @@ -100,7 +100,7 @@ def symbolic_logq_not_scaled(self): z0 = self.symbolic_initial std = rho2sigma(self.rho) logdet = at.log(std) - logq = pm.Normal.dist().logp(z0) - logdet + logq = pm.Normal.logp(z0, 0, 1) - logdet return logq.sum(range(1, logq.ndim)) @@ -150,6 +150,8 @@ def L(self): else: L = at.zeros((self.ddim, self.ddim)) L = at.set_subtensor(L[self.tril_indices], self.params_dict["L_tril"]) + Ld = L[..., np.arange(self.ddim), np.arange(self.ddim)] + L = at.set_subtensor(Ld, rho2sigma(Ld)) return L @node_property @@ -159,7 +161,7 @@ def mean(self): @node_property def cov(self): L = self.L - if self.batched: + if self.batched: return at.batched_dot(L, L.swapaxes(-1, -2)) else: return L.dot(L.T) @@ -182,9 +184,9 @@ def tril_indices(self): @node_property def symbolic_logq_not_scaled(self): - z = self.symbolic_random + z0 = self.symbolic_initial if self.batched: - + raise NotImplementedError def logq(z_b, mu_b, L_b): return pm.MvNormal.dist(mu=mu_b, chol=L_b).logp(z_b) @@ -193,7 +195,11 @@ def logq(z_b, mu_b, L_b): # output shape is (batch, samples) return aesara.scan(logq, [z.swapaxes(0, 1), self.mean, self.L])[0].sum(0) else: - return pm.MvNormal.dist(mu=self.mean, chol=self.L).logp(z) + # return pm.MvNormal.dist(mu=self.mean, chol=self.L).logp(z) + logdet = at.sum(at.diagonal(self.L, 0, self.L.ndim - 2, self.L.ndim - 1), axis=-1) + logq = pm.Normal.logp(z0, 0, 1) - logdet + return logq.sum(range(1, logq.ndim)) + @node_property def symbolic_random(self): @@ -483,7 +489,7 @@ def params(self): @node_property def symbolic_logq_not_scaled(self): z0 = self.symbolic_initial - q0 = pm.Normal.dist().logp(z0).sum(range(1, z0.ndim)) + q0 = pm.Normal.logp(z0, 0, 1).sum(range(1, z0.ndim)) return q0 - self.flow.sum_logdets @property diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 6b47d80635e..7b624806835 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -998,7 +998,6 @@ def __init_group__(self, group): dtype ) start_idx += size - self._ddim = start_idx def _finalize_init(self): """*Dev* - clean up after init""" @@ -1060,12 +1059,17 @@ def bdim(self): @node_property def ndim(self): # XXX: This needs to be refactored - return None # self.ordering.size * self.bdim + if self.batched: + # return None # self.ordering.size * self.bdim + # TODO: add support for batching + raise NotImplementedError("not implemented for batching") + else: + return self.ddim @property def ddim(self): # TODO: This needs to be refactored - return self._ddim # self.ordering.size + return sum(s.stop - s.start for _, s, _, _ in self.ordering.values()) def _new_initial(self, size, deterministic, more_replacements=None): """*Dev* - allocates new initial random generator From 7e60bcc8ee230cf425cab896a74b772d8348f4c1 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Fri, 11 Jun 2021 15:18:37 +0000 Subject: [PATCH 05/82] fix some more tests --- pymc/variational/approximations.py | 6 +++--- pymc/variational/opvi.py | 7 +++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py index 0734a15d3d3..e6cabb057f2 100644 --- a/pymc/variational/approximations.py +++ b/pymc/variational/approximations.py @@ -133,7 +133,7 @@ def create_shared_params(self, start=None): if self.batched: start = start[self.group[0].name][0] else: - start = DictToArrayBijection.map(start) + start = DictToArrayBijection.map(start).data n = self.ddim L_tril = np.eye(n)[np.tril_indices(n)].astype(aesara.config.floatX) if self.batched: @@ -250,7 +250,7 @@ def create_shared_params(self, trace=None, size=None, jitter=1, start=None): start_ = self.model.initial_point.copy() self.model.update_start_vals(start_, start) start = start_ - start = pm.floatX(DictToArrayBijection.map(start)) + start = pm.floatX(DictToArrayBijection.map(start).data) # Initialize particles histogram = np.tile(start, (size, 1)) histogram += pm.floatX(np.random.normal(0, jitter, histogram.shape)) @@ -260,7 +260,7 @@ def create_shared_params(self, trace=None, size=None, jitter=1, start=None): i = 0 for t in trace.chains: for j in range(len(trace)): - histogram[i] = DictToArrayBijection.map(trace.point(j, t)) + histogram[i] = DictToArrayBijection.map(trace.point(j, t)).data i += 1 return dict(histogram=aesara.shared(pm.floatX(histogram), "histogram")) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 7b624806835..22b30c900a0 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -1537,11 +1537,10 @@ def symbolic_sample_over_posterior(self, node): """ node = self.to_flat_input(node) - def sample(*post, node, inputs): - node, inputs = post[-2:] - return aesara.clone_replace(node, dict(zip(inputs, post))) + def sample(*post): + return aesara.clone_replace(node, dict(zip(self.inputs, post))) - nodes, _ = aesara.scan(sample, self.symbolic_randoms, non_sequences=[node, inputs]) + nodes, _ = aesara.scan(sample, self.symbolic_randoms) return nodes def symbolic_single_sample(self, node): From e0fbb98350b4feac982b25d817834ef6c999244f Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Fri, 18 Jun 2021 14:34:02 +0000 Subject: [PATCH 06/82] add scaling for VI --- pymc/model.py | 65 +++++++++++++++++++++++++++++++++++++++- pymc/variational/opvi.py | 4 +-- 2 files changed, 66 insertions(+), 3 deletions(-) diff --git a/pymc/model.py b/pymc/model.py index e578c6feffb..eb1c005cfd4 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -1223,6 +1223,7 @@ def register_rv( name = self.name_for(name) rv_var.name = name rv_var.tag.total_size = total_size + rv_var.tag.scaling = _get_scaling(total_size, shape=rv_var.shape, ndim=rv_var.ndim) # Associate previously unknown dimension names with # the length of the corresponding RV dimension. @@ -1860,6 +1861,68 @@ def Deterministic(name, var, model=None, dims=None, auto=False): return var +def _get_scaling(total_size, shape, ndim): + """ + Gets scaling constant for logp + Parameters + ---------- + total_size: int or list[int] + shape: shape + shape to scale + ndim: int + ndim hint + Returns + ------- + scalar + """ + if total_size is None: + coef = 1. + elif isinstance(total_size, int): + if ndim >= 1: + denom = shape[0] + else: + denom = 1 + coef = total_size / denom + elif isinstance(total_size, (list, tuple)): + if not all(isinstance(i, int) for i in total_size if (i is not Ellipsis and i is not None)): + raise TypeError( + "Unrecognized `total_size` type, expected " + "int or list of ints, got %r" % total_size + ) + if Ellipsis in total_size: + sep = total_size.index(Ellipsis) + begin = total_size[:sep] + end = total_size[sep + 1 :] + if Ellipsis in end: + raise ValueError( + "Double Ellipsis in `total_size` is restricted, got %r" % total_size + ) + else: + begin = total_size + end = [] + if (len(begin) + len(end)) > ndim: + raise ValueError( + "Length of `total_size` is too big, " + "number of scalings is bigger that ndim, got %r" % total_size + ) + elif (len(begin) + len(end)) == 0: + coef = 1. + if len(end) > 0: + shp_end = shape[-len(end) :] + else: + shp_end = np.asarray([]) + shp_begin = shape[: len(begin)] + begin_coef = [t / shp_begin[i] for i, t in enumerate(begin) if t is not None] + end_coef = [t / shp_end[i] for i, t in enumerate(end) if t is not None] + coefs = begin_coef + end_coef + coef = at.prod(coefs) + else: + raise TypeError( + "Unrecognized `total_size` type, expected int or list of ints, got %r" % total_size + ) + return at.as_tensor(coef, dtype=aesara.config.floatX) + + def Potential(name, var, model=None): """Add an arbitrary factor potential to the model likelihood @@ -1874,7 +1937,7 @@ def Potential(name, var, model=None): """ model = modelcontext(model) var.name = model.name_for(name) - var.tag.scaling = None + var.tag.scaling = 1. model.potentials.append(var) model.add_random_variable(var) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 22b30c900a0..ca3574c4165 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -1217,7 +1217,7 @@ def make_size_and_deterministic_replacements(self, s, d, more_replacements=None) @node_property def symbolic_normalizing_constant(self): """*Dev* - normalizing constant for `self.logq`, scales it to `minibatch_size` instead of `total_size`""" - t = self.to_flat_input(at.max([v.scaling for v in self.group])) + t = self.to_flat_input(at.max([v.tag.scaling for v in self.group])) t = self.symbolic_single_sample(t) return pm.floatX(t) @@ -1370,7 +1370,7 @@ def symbolic_normalizing_constant(self): """ t = at.max( self.collect("symbolic_normalizing_constant") - + [var.scaling for var in self.model.observed_RVs] + + [var.tag.scaling for var in self.model.observed_RVs] ) t = at.switch(self._scale_cost_to_minibatch, t, at.constant(1, dtype=t.dtype)) return pm.floatX(t) From e5152171bc00e47fabd3333096703f827dfeb981 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Fri, 18 Jun 2021 15:32:29 +0000 Subject: [PATCH 07/82] add shape check --- pymc/tests/test_variational_inference.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index b2857ab4524..e55c804825c 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -922,6 +922,8 @@ def test_replacements(binomial_model_inference): assert p_s.tag.test_value.shape == p_t.tag.test_value.shape sampled = [p_s.eval() for _ in range(100)] assert any(map(operator.ne, sampled[1:], sampled[:-1])) # stochastic + p_z = approx.sample_node(p_t, deterministic=True, size=10) + assert p_z.shape.eval() == (10, ) p_d = approx.sample_node(p_t, deterministic=True) sampled = [p_d.eval() for _ in range(100)] From 6dfc18c8db3f8abb0a43e521119d72a49fe790ce Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Fri, 18 Jun 2021 15:59:14 +0000 Subject: [PATCH 08/82] aet -> at --- pymc/tests/test_variational_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index e55c804825c..eed31bfc1e6 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -964,7 +964,7 @@ def test_discrete_not_allowed(): with pm.Model(): mu = pm.Normal("mu", mu=0, sigma=10, size=3) - z = pm.Categorical("z", p=aet.ones(3) / 3, size=len(y)) + z = pm.Categorical("z", p=at.ones(3) / 3, size=len(y)) pm.Normal("y_obs", mu=mu[z], sigma=1.0, observed=y) with pytest.raises(opvi.ParametrizationError): pm.fit(n=1) # fails From 39e635b38c45a2fdbfb3a7bef442263a29c774cc Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Mon, 21 Jun 2021 15:20:57 +0000 Subject: [PATCH 09/82] use rvs_to_values from the model in opi.py --- pymc/variational/opvi.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index ca3574c4165..3c7539ffa50 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -966,7 +966,8 @@ def __init_group__(self, group): if var.type.numpy_dtype.name in discrete_types: raise ParametrizationError(f"Discrete variables are not supported by VI: {var}") # 3) This is the way to infer shape and dtype of the variable - test_var = model_initial_point[var.tag.value_var.name] + value_var = self.model.rvs_to_values[var] + test_var = model_initial_point[value_var.name] if self.batched: # Leave a more complicated case for future work raise NotImplementedError("not yet ready") @@ -989,10 +990,10 @@ def __init_group__(self, group): size = test_var.size # TODO: There was self.ordering used in other util funcitons vr = self.input[..., start_idx:start_idx+size].reshape(shape).astype(dtype) - vr.name = var.tag.value_var.name + "_vi_replacement" - self.replacements[var.tag.value_var] = vr - self.ordering[var.tag.value_var.name] = ( - var.tag.value_var.name, + vr.name = value_var.name + "_vi_replacement" + self.replacements[value_var] = vr + self.ordering[value_var.name] = ( + value_var.name, slice(start_idx, start_idx+size), shape, dtype @@ -1599,7 +1600,7 @@ def rslice(self, name): """ def vars_names(vs): - return {v.tag.value_var.name for v in vs} + return {self.model.rvs_to_values[v].name for v in vs} for vars_, random, ordering in zip( self.collect("group"), self.symbolic_randoms, self.collect("ordering") @@ -1617,7 +1618,7 @@ def vars_names(vs): def sample_dict_fn(self): # TODO: this breaks s = at.iscalar() - names = [v.tag.value_var.name for v in self.model.free_RVs] + names = [self.model.rvs_to_values[v].name for v in self.model.free_RVs] sampled = [self.rslice(name) for name in names] sampled = self.set_size_and_deterministic(sampled, s, 0) sample_fn = aesara.function([s], sampled) From 9f61021e228975a42779a1b914f8a321e2326ed6 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Mon, 21 Jun 2021 16:39:59 +0000 Subject: [PATCH 10/82] refactor cloning routines (fix pymc references) --- pymc/variational/approximations.py | 4 +-- pymc/variational/opvi.py | 43 +++++++++++++++++++----------- 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py index e6cabb057f2..5750a825aa1 100644 --- a/pymc/variational/approximations.py +++ b/pymc/variational/approximations.py @@ -161,7 +161,7 @@ def mean(self): @node_property def cov(self): L = self.L - if self.batched: + if self.batched: return at.batched_dot(L, L.swapaxes(-1, -2)) else: return L.dot(L.T) @@ -187,6 +187,7 @@ def symbolic_logq_not_scaled(self): z0 = self.symbolic_initial if self.batched: raise NotImplementedError + def logq(z_b, mu_b, L_b): return pm.MvNormal.dist(mu=mu_b, chol=L_b).logp(z_b) @@ -199,7 +200,6 @@ def logq(z_b, mu_b, L_b): logdet = at.sum(at.diagonal(self.L, 0, self.L.ndim - 2, self.L.ndim - 1), axis=-1) logq = pm.Normal.logp(z0, 0, 1) - logdet return logq.sum(range(1, logq.ndim)) - @node_property def symbolic_random(self): diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 3c7539ffa50..a8991f8bfd7 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -851,8 +851,9 @@ def __init__( self.group = group self.user_params = params self._user_params = None - self.replacements = dict() - self.ordering = dict() + self.replacements = collections.OrderedDict() + self.value_replacements = collections.OrderedDict() + self.ordering = collections.OrderedDict() # save this stuff to use in __init_group__ later self._kwargs = kwargs if self.group is not None: @@ -958,7 +959,8 @@ def __init_group__(self, group): # 1) we need initial point (transformed space) model_initial_point = self.model.initial_point - + _, replace_to_value_vars = rvs_to_value_vars(self.group, apply_transforms=True) + self.value_replacements.update(replace_to_value_vars) # 2) we'll work with a single group, a subset of the model # here we need to create a mapping to replace value_vars with slices from the approximation start_idx = 0 @@ -989,14 +991,14 @@ def __init_group__(self, group): dtype = test_var.dtype size = test_var.size # TODO: There was self.ordering used in other util funcitons - vr = self.input[..., start_idx:start_idx+size].reshape(shape).astype(dtype) + vr = self.input[..., start_idx : start_idx + size].reshape(shape).astype(dtype) vr.name = value_var.name + "_vi_replacement" self.replacements[value_var] = vr self.ordering[value_var.name] = ( value_var.name, - slice(start_idx, start_idx+size), + slice(start_idx, start_idx + size), shape, - dtype + dtype, ) start_idx += size @@ -1166,6 +1168,7 @@ def set_size_and_deterministic(self, node, s, d, more_replacements=None): def to_flat_input(self, node): """*Dev* - replace vars with flattened view stored in `self.inputs`""" + node = aesara.clone_replace(node, self.value_replacements) return aesara.clone_replace(node, self.replacements) def symbolic_sample_over_posterior(self, node): @@ -1468,6 +1471,13 @@ def datalogp_norm(self): """*Dev* - normalized :math:`E_{q}(data term)`""" return self.datalogp / self.symbolic_normalizing_constant + @property + def value_replacements(self): + """*Dev* - all replacements from groups to replace PyMC random variables with approximation""" + return collections.OrderedDict( + itertools.chain.from_iterable(g.value_replacements.items() for g in self.groups) + ) + @property def replacements(self): """*Dev* - all replacements from groups to replace PyMC random variables with approximation""" @@ -1528,15 +1538,17 @@ def set_size_and_deterministic(self, node, s, d, more_replacements=None): try_to_set_test_value(_node, node, s) return node - def to_flat_input(self, node): + def to_flat_input(self, node, more_replacements=None): """*Dev* - replace vars with flattened view stored in `self.inputs`""" + more_replacements = more_replacements or {} + node = aesara.clone_replace(node, {**self.value_replacements, **more_replacements}) return aesara.clone_replace(node, self.replacements) - def symbolic_sample_over_posterior(self, node): + def symbolic_sample_over_posterior(self, node, more_replacements=None): """*Dev* - performs sampling of node applying independent samples from posterior each time. Note that it is done symbolically and this node needs :func:`set_size_and_deterministic` call """ - node = self.to_flat_input(node) + node = self.to_flat_input(node, more_replacements=more_replacements) def sample(*post): return aesara.clone_replace(node, dict(zip(self.inputs, post))) @@ -1544,12 +1556,12 @@ def sample(*post): nodes, _ = aesara.scan(sample, self.symbolic_randoms) return nodes - def symbolic_single_sample(self, node): + def symbolic_single_sample(self, node, more_replacements=None): """*Dev* - performs sampling of node applying single sample from posterior. Note that it is done symbolically and this node needs :func:`set_size_and_deterministic` call with `size=1` """ - node = self.to_flat_input(node) + node = self.to_flat_input(node, more_replacements=more_replacements) post = [v[0] for v in self.symbolic_randoms] inp = self.inputs return aesara.clone_replace(node, dict(zip(inp, post))) @@ -1585,12 +1597,13 @@ def sample_node(self, node, size=None, deterministic=False, more_replacements=No sampled node(s) with replacements """ node_in = node - node = aesara.clone_replace(node, more_replacements) if size is None: - node_out = self.symbolic_single_sample(node) + node_out = self.symbolic_single_sample(node, more_replacements=more_replacements) else: - node_out = self.symbolic_sample_over_posterior(node) - node_out = self.set_size_and_deterministic(node_out, size, deterministic, more_replacements) + node_out = self.symbolic_sample_over_posterior( + node, more_replacements=more_replacements + ) + node_out = self.set_size_and_deterministic(node_out, size, deterministic) try_to_set_test_value(node_in, node_out, size) return node_out From 8909ac7b10e3a31d490bf3b7d43a4f9e1b126f4c Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Fri, 2 Jul 2021 22:03:19 +0200 Subject: [PATCH 11/82] Run pre-commit and include VI tests in pytest workflow (rebase) --- .github/workflows/pytest.yml | 6 ++++-- pymc/model.py | 6 +++--- pymc/tests/test_variational_inference.py | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 95ae3918727..5288e652029 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -61,8 +61,9 @@ jobs: --ignore=pymc/tests/test_idata_conversion.py - | - pymc/tests/test_initvals.py - pymc/tests/test_distributions.py + pymc3/tests/test_initvals.py + pymc3/tests/test_distributions.py + - pymc3/tests/test_variational_inference.py - | pymc/tests/test_modelcontext.py @@ -153,6 +154,7 @@ jobs: os: [windows-latest] floatx: [float32, float64] test-subset: + - pymc3/tests/test_variational_inference.py - | pymc/tests/test_initvals.py pymc/tests/test_distributions_random.py diff --git a/pymc/model.py b/pymc/model.py index eb1c005cfd4..cfe154a6666 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -1876,7 +1876,7 @@ def _get_scaling(total_size, shape, ndim): scalar """ if total_size is None: - coef = 1. + coef = 1.0 elif isinstance(total_size, int): if ndim >= 1: denom = shape[0] @@ -1906,7 +1906,7 @@ def _get_scaling(total_size, shape, ndim): "number of scalings is bigger that ndim, got %r" % total_size ) elif (len(begin) + len(end)) == 0: - coef = 1. + coef = 1.0 if len(end) > 0: shp_end = shape[-len(end) :] else: @@ -1937,7 +1937,7 @@ def Potential(name, var, model=None): """ model = modelcontext(model) var.name = model.name_for(name) - var.tag.scaling = 1. + var.tag.scaling = 1.0 model.potentials.append(var) model.add_random_variable(var) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index eed31bfc1e6..bbe183dbd20 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -923,7 +923,7 @@ def test_replacements(binomial_model_inference): sampled = [p_s.eval() for _ in range(100)] assert any(map(operator.ne, sampled[1:], sampled[:-1])) # stochastic p_z = approx.sample_node(p_t, deterministic=True, size=10) - assert p_z.shape.eval() == (10, ) + assert p_z.shape.eval() == (10,) p_d = approx.sample_node(p_t, deterministic=True) sampled = [p_d.eval() for _ in range(100)] From 1076fa10c254c10053b983570218880c82eed682 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Fri, 2 Jul 2021 22:03:19 +0200 Subject: [PATCH 12/82] Run pre-commit and include VI tests in pytest workflow --- pymc/variational/opvi.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index a8991f8bfd7..80b955ba710 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -57,13 +57,11 @@ import pymc as pm -from pymc.aesaraf import at_rng, identity +from pymc.aesaraf import at_rng, identity, rvs_to_value_vars from pymc.backends import NDArray from pymc.model import modelcontext from pymc.util import ( WithMemoization, - get_default_varnames, - get_transformed, locally_cachedmethod, ) from pymc.variational.updates import adagrad_window From 7e73cd74619b08efd905fbe1cd34575de48db3c3 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Wed, 28 Jul 2021 15:37:15 +0000 Subject: [PATCH 13/82] seems like Grouped inference not working --- pymc/tests/test_variational_inference.py | 3 ++- pymc/variational/flows.py | 6 ------ pymc/variational/opvi.py | 25 +++++++++++------------- 3 files changed, 13 insertions(+), 21 deletions(-) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index bbe183dbd20..89a4d246853 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -233,7 +233,8 @@ def three_var_aevb_approx(three_var_model, three_var_aevb_groups): def test_sample_aevb(three_var_aevb_approx, aevb_initial): - pm.KLqp(three_var_aevb_approx).fit( + inf = pm.KLqp(three_var_aevb_approx) + inf.fit( 1, more_replacements={aevb_initial: np.zeros_like(aevb_initial.get_value())[:1]} ) aevb_initial.set_value(np.random.rand(7, 7).astype("float32")) diff --git a/pymc/variational/flows.py b/pymc/variational/flows.py index cd13069f484..956b1c2ff26 100644 --- a/pymc/variational/flows.py +++ b/pymc/variational/flows.py @@ -225,12 +225,6 @@ def logdet(self): @aesara.config.change_flags(compute_test_value="off") def forward_pass(self, z0): ret = aesara.clone_replace(self.forward, {self.root.z0: z0}) - try: - ret.tag.test_value = np.random.normal(size=z0.tag.test_value.shape).astype( - self.z0.dtype - ) - except AttributeError: - ret.tag.test_value = self.root.z0.tag.test_value return ret __call__ = forward_pass diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 80b955ba710..8c095f3190e 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -73,6 +73,8 @@ class VariationalInferenceError(Exception): """Exception for VI specific cases""" +class NotImplementedInference(VariationalInferenceError, NotImplementedError): + """Marking non functional parts of code""" class ExplicitInferenceError(VariationalInferenceError, TypeError): """Exception for bad explicit inference""" @@ -98,6 +100,7 @@ class LocalGroupError(BatchedGroupError, AEVBInferenceError): """Error raised in case of bad local_rv usage""" + def append_name(name): def wrap(f): if name is None: @@ -855,6 +858,7 @@ def __init__( # save this stuff to use in __init_group__ later self._kwargs = kwargs if self.group is not None: + raise NotImplementedInference("Grouped Inference is not yet supported, open an issue once you need it https://github.com/pymc-devs/pymc3/issues") # init can be delayed self.__init_group__(self.group) @@ -970,25 +974,21 @@ def __init_group__(self, group): test_var = model_initial_point[value_var.name] if self.batched: # Leave a more complicated case for future work - raise NotImplementedError("not yet ready") if var.ndim < 1: if self.local: raise LocalGroupError("Local variable should not be scalar") else: raise BatchedGroupError("Batched variable should not be scalar") - # XXX: This needs to be refactored - # self.ordering.size += None # (np.prod(var.dshape[1:])).astype(int) + size = test_var[0].size if self.local: - # XXX: This needs to be refactored - shape = (-1,) + var.dshape[1:] + shape = (-1,) + test_var.shape[1:] else: - # XXX: This needs to be refactored - shape = var.dshape + shape = test_var.shape else: shape = test_var.shape - dtype = test_var.dtype size = test_var.size # TODO: There was self.ordering used in other util funcitons + dtype = test_var.dtype vr = self.input[..., start_idx : start_idx + size].reshape(shape).astype(dtype) vr.name = value_var.name + "_vi_replacement" self.replacements[value_var] = vr @@ -1050,8 +1050,7 @@ def _new_initial_shape(self, size, dim, more_replacements=None): def bdim(self): if not self.local: if self.batched: - # XXX: This needs to be refactored - return None # self.ordering.vmap[0].shp[0] + return next(iter(self.ordering.values()))[1][0] else: return 1 else: @@ -1061,9 +1060,7 @@ def bdim(self): def ndim(self): # XXX: This needs to be refactored if self.batched: - # return None # self.ordering.size * self.bdim - # TODO: add support for batching - raise NotImplementedError("not implemented for batching") + return self.ordering.size * self.bdim else: return self.ddim @@ -1235,7 +1232,7 @@ def symbolic_logq_not_scaled(self): def symbolic_logq(self): """*Dev* - correctly scaled `self.symbolic_logq_not_scaled`""" if self.local: - s = self.group[0].scaling + s = self.group[0].tag.scaling s = self.to_flat_input(s) s = self.symbolic_single_sample(s) return self.symbolic_logq_not_scaled * s From 64ba83728f7650d36edb6077421cf788d9be43a3 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Tue, 3 Aug 2021 08:04:17 +0000 Subject: [PATCH 14/82] spot an error in a simple test case --- pymc/tests/test_variational_inference.py | 1 + pymc/variational/opvi.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 89a4d246853..9b9a3e5221d 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -131,6 +131,7 @@ def test_init_groups(three_var_model, raises, grouping): else: model_dim = sum(v.size for v in three_var_model.initial_point.values()) assert approx.ndim == model_dim + trace = approx.sample(100) @pytest.fixture( diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 8c095f3190e..107c2acc16e 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -858,7 +858,6 @@ def __init__( # save this stuff to use in __init_group__ later self._kwargs = kwargs if self.group is not None: - raise NotImplementedInference("Grouped Inference is not yet supported, open an issue once you need it https://github.com/pymc-devs/pymc3/issues") # init can be delayed self.__init_group__(self.group) @@ -947,6 +946,7 @@ def __init_group__(self, group): self.group = group if self.batched and len(group) > 1: if self.local: # better error message + raise NotImplementedInference("Grouped Inference is not yet supported, open an issue once you need it https://github.com/pymc-devs/pymc3/issues") raise LocalGroupError("Local groups with more than 1 variable are not supported") else: raise BatchedGroupError( From 4b91bce647525b3b3a57a53ce58d47e19b82ac5e Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Tue, 3 Aug 2021 08:36:21 +0000 Subject: [PATCH 15/82] fix the test case with grouping --- pymc/tests/test_variational_inference.py | 3 +- pymc/variational/approximations.py | 36 +++++------------------- pymc/variational/opvi.py | 28 ++++++++++++++---- 3 files changed, 31 insertions(+), 36 deletions(-) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 9b9a3e5221d..2dbd0c575ba 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -131,7 +131,8 @@ def test_init_groups(three_var_model, raises, grouping): else: model_dim = sum(v.size for v in three_var_model.initial_point.values()) assert approx.ndim == model_dim - trace = approx.sample(100) + trace = approx.sample(100) + @pytest.fixture( diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py index 5750a825aa1..fe6e4112bf4 100644 --- a/pymc/variational/approximations.py +++ b/pymc/variational/approximations.py @@ -69,22 +69,13 @@ def __init_group__(self, group): self._finalize_init() def create_shared_params(self, start=None): - if start is None: - start = self.model.initial_point - else: - start_ = start.copy() - self.model.update_start_vals(start_, self.model.initial_point) - start = start_ - if self.batched: - start = start[self.group[0].name][0] - else: - start = DictToArrayBijection.map(start) + start = self._prepare_start(start) rho = np.zeros((self.ddim,)) if self.batched: start = np.tile(start, (self.bdim, 1)) rho = np.tile(rho, (self.bdim, 1)) return { - "mu": aesara.shared(pm.floatX(start.data), "mu"), + "mu": aesara.shared(pm.floatX(start), "mu"), "rho": aesara.shared(pm.floatX(rho), "rho"), } @@ -124,16 +115,7 @@ def __init_group__(self, group): self._finalize_init() def create_shared_params(self, start=None): - if start is None: - start = self.model.initial_point - else: - start_ = start.copy() - self.model.update_start_vals(start_, self.model.initial_point) - start = start_ - if self.batched: - start = start[self.group[0].name][0] - else: - start = DictToArrayBijection.map(start).data + start = self._prepare_start(start) n = self.ddim L_tril = np.eye(n)[np.tril_indices(n)].astype(aesara.config.floatX) if self.batched: @@ -244,13 +226,7 @@ def create_shared_params(self, trace=None, size=None, jitter=1, start=None): if size is None: raise opvi.ParametrizationError("Need `trace` or `size` to initialize") else: - if start is None: - start = self.model.initial_point - else: - start_ = self.model.initial_point.copy() - self.model.update_start_vals(start_, start) - start = start_ - start = pm.floatX(DictToArrayBijection.map(start).data) + start = self._prepare_start(start) # Initialize particles histogram = np.tile(start, (size, 1)) histogram += pm.floatX(np.random.normal(0, jitter, histogram.shape)) @@ -266,7 +242,9 @@ def create_shared_params(self, trace=None, size=None, jitter=1, start=None): def _check_trace(self): trace = self._kwargs.get("trace", None) - if trace is not None and not all([var.name in trace.varnames for var in self.group]): + if trace is not None and not all( + [self.model.rvs_to_values[var].name in trace.varnames for var in self.group] + ): raise ValueError("trace has not all free RVs in the group") def randidx(self, size=None): diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 107c2acc16e..0b403f322df 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -60,12 +60,10 @@ from pymc.aesaraf import at_rng, identity, rvs_to_value_vars from pymc.backends import NDArray from pymc.model import modelcontext -from pymc.util import ( - WithMemoization, - locally_cachedmethod, -) +from pymc.util import WithMemoization, locally_cachedmethod from pymc.variational.updates import adagrad_window from pymc.vartypes import discrete_types +from pymc.blocking import DictToArrayBijection __all__ = ["ObjectiveFunction", "Operator", "TestFunction", "Group", "Approximation"] @@ -73,9 +71,11 @@ class VariationalInferenceError(Exception): """Exception for VI specific cases""" + class NotImplementedInference(VariationalInferenceError, NotImplementedError): """Marking non functional parts of code""" + class ExplicitInferenceError(VariationalInferenceError, TypeError): """Exception for bad explicit inference""" @@ -100,7 +100,6 @@ class LocalGroupError(BatchedGroupError, AEVBInferenceError): """Error raised in case of bad local_rv usage""" - def append_name(name): def wrap(f): if name is None: @@ -861,6 +860,21 @@ def __init__( # init can be delayed self.__init_group__(self.group) + def _prepare_start(self, start=None): + if start is None: + start = self.model.initial_point + else: + start_ = start.copy() + self.model.update_start_vals(start_, self.model.initial_point) + start = start_ + group_vars = {self.model.rvs_to_values[v].name for v in self.group} + start = {k: v for k, v in start.items() if k in group_vars} + if self.batched: + start = start[self.model.rvs_to_values[self.group[0]].name][0] + else: + start = DictToArrayBijection.map(start).data + return start + @classmethod def get_param_spec_for(cls, **kwargs): res = dict() @@ -946,7 +960,9 @@ def __init_group__(self, group): self.group = group if self.batched and len(group) > 1: if self.local: # better error message - raise NotImplementedInference("Grouped Inference is not yet supported, open an issue once you need it https://github.com/pymc-devs/pymc3/issues") + raise NotImplementedInference( + "Grouped Inference is not yet supported, open an issue once you need it https://github.com/pymc-devs/pymc3/issues" + ) raise LocalGroupError("Local groups with more than 1 variable are not supported") else: raise BatchedGroupError( From c81458a2e2842a630e9ac3e56df5db20bacade45 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Tue, 3 Aug 2021 09:02:19 +0000 Subject: [PATCH 16/82] fix sampling with changed shape --- pymc/variational/opvi.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 0b403f322df..f791a90184e 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -1673,6 +1673,7 @@ def sample(self, draws=500): trace = NDArray( model=self.model, + test_point={name: records[1] for name, records in samples.items()}, ) try: trace.setup(draws=draws, chain=0) From 11ef0b6561a78106f25677544c70924cb5c23b6e Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Tue, 3 Aug 2021 09:03:37 +0000 Subject: [PATCH 17/82] remove not implemented error for local inference --- pymc/variational/opvi.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index f791a90184e..293adac2efa 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -960,9 +960,6 @@ def __init_group__(self, group): self.group = group if self.batched and len(group) > 1: if self.local: # better error message - raise NotImplementedInference( - "Grouped Inference is not yet supported, open an issue once you need it https://github.com/pymc-devs/pymc3/issues" - ) raise LocalGroupError("Local groups with more than 1 variable are not supported") else: raise BatchedGroupError( From 98dd81dae838993e3c52ea772d729b5408c21305 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 8 Aug 2021 13:18:26 +0000 Subject: [PATCH 18/82] support inferencedata --- pymc/variational/opvi.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 293adac2efa..a9baaff8f31 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -1650,13 +1650,15 @@ def inner(draws=100): return inner - def sample(self, draws=500): + def sample(self, draws=500, return_inferencedata=True): """Draw samples from variational posterior. Parameters ---------- draws: `int` Number of random samples. + return_inferencedata: `bool` + Return trace in Arviz format Returns ------- @@ -1678,7 +1680,13 @@ def sample(self, draws=500): trace.record(point) finally: trace.close() - return pm.sampling.MultiTrace([trace]) + + trace = pm.sampling.MultiTrace([trace]) + if not return_inferencedata: + return trace + else: + import arviz as az + return az.from_pymc3(trace) @property def ndim(self): From c08eea3fd35c05fa0c0faf917ddf39d04f97cab8 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 8 Aug 2021 14:02:23 +0000 Subject: [PATCH 19/82] get rid of shape error for batched mvnormal --- pymc/variational/approximations.py | 18 +++--------------- pymc/variational/opvi.py | 2 +- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py index fe6e4112bf4..62abf306815 100644 --- a/pymc/variational/approximations.py +++ b/pymc/variational/approximations.py @@ -167,21 +167,9 @@ def tril_indices(self): @node_property def symbolic_logq_not_scaled(self): z0 = self.symbolic_initial - if self.batched: - raise NotImplementedError - - def logq(z_b, mu_b, L_b): - return pm.MvNormal.dist(mu=mu_b, chol=L_b).logp(z_b) - - # it's gonna be so slow - # scan is computed over batch and then summed up - # output shape is (batch, samples) - return aesara.scan(logq, [z.swapaxes(0, 1), self.mean, self.L])[0].sum(0) - else: - # return pm.MvNormal.dist(mu=self.mean, chol=self.L).logp(z) - logdet = at.sum(at.diagonal(self.L, 0, self.L.ndim - 2, self.L.ndim - 1), axis=-1) - logq = pm.Normal.logp(z0, 0, 1) - logdet - return logq.sum(range(1, logq.ndim)) + logdet = at.sum(at.diagonal(self.L, 0, self.L.ndim - 2, self.L.ndim - 1), axis=-1, keepdims=True) + logq = pm.Normal.logp(z0, 0, 1) - logdet + return logq.sum(range(1, logq.ndim)) @node_property def symbolic_random(self): diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index a9baaff8f31..327d48debad 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -1672,7 +1672,7 @@ def sample(self, draws=500, return_inferencedata=True): trace = NDArray( model=self.model, - test_point={name: records[1] for name, records in samples.items()}, + test_point={name: records[0] for name, records in samples.items()}, ) try: trace.setup(draws=draws, chain=0) From 77443f5a228500d235378fd46a29889fd49d01ba Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 8 Aug 2021 14:26:09 +0000 Subject: [PATCH 20/82] do not support AEVB with an error message --- pymc/tests/test_variational_inference.py | 5 +++-- pymc/variational/approximations.py | 1 + pymc/variational/opvi.py | 2 ++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 2dbd0c575ba..83f24971d08 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -835,8 +835,9 @@ def aevb_model(): pm.Normal("y", size=(2,)) x = model.x y = model.y - mu = aesara.shared(x.init_value) - rho = aesara.shared(np.zeros_like(x.init_value)) + xr = model.initial_values[model.rvs_to_values[x]] + mu = aesara.shared(xr) + rho = aesara.shared(np.zeros_like(xr)) return {"model": model, "y": y, "x": x, "replace": dict(mu=mu, rho=rho)} diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py index 62abf306815..a3943ebeda5 100644 --- a/pymc/variational/approximations.py +++ b/pymc/variational/approximations.py @@ -346,6 +346,7 @@ class NormalizingFlowGroup(Group): @aesara.config.change_flags(compute_test_value="off") def __init_group__(self, group): + raise NotImplementedInference("Normalizing flows are not yet ported to v4") super().__init_group__(group) # objects to be resolved # 1. string formula diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 327d48debad..625c9058ff8 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -955,6 +955,8 @@ def _input_type(self, name): def __init_group__(self, group): if not group: raise GroupError("Got empty group") + if self.local: + raise NotImplementedInference("Local inferene aka AEVB is not fully supported in v4") if self.group is None: # delayed init self.group = group From 215f92b8fb470c850b1d5f11af090908de931854 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Wed, 8 Sep 2021 09:11:30 +0000 Subject: [PATCH 21/82] fix some meore tests --- pymc/backends/arviz.py | 1 + pymc/tests/test_variational_inference.py | 22 +++++++++++++++++++--- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 1aed0b6c4e1..847bbb03c88 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -527,6 +527,7 @@ def is_data(name, var) -> bool: and var not in self.model.observed_RVs and var not in self.model.free_RVs and var not in self.model.potentials + and var not in self.model.value_vars and (self.observations is None or name not in self.observations) and isinstance(var, (Constant, SharedVariable)) ) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 83f24971d08..20471b1e818 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -41,9 +41,15 @@ from pymc.variational.opvi import Approximation, Group pytestmark = pytest.mark.usefixtures("strict_float32", "seeded_test") -# pytestmark = pytest.mark.xfail( -# reason="These tests rely on Group, which hasn't been refactored for v4" -# ) + +def ignore_not_implemented_inference(func): + @functools.wraps(func) + def new_test(*args, **kwargs): + try: + return func(*args, **kwargs) + except NotImplementedInference: + pytest.xfail("NotImplementedInference") + return new_test @pytest.mark.parametrize("diff", ["relative", "absolute"]) @@ -114,6 +120,7 @@ def three_var_model(): (not_raises(), {MeanFieldGroup: ["one"], FullRankGroup: ["two", "three"]}), ], ) +@ignore_not_implemented_inference def test_init_groups(three_var_model, raises, grouping): with raises, three_var_model: approxes, groups = zip(*grouping.items()) @@ -152,6 +159,7 @@ def test_init_groups(three_var_model, raises, grouping): ], ids=lambda t: ", ".join(f"{k.__name__}: {v[0]}" for k, v in t[1].items()), ) +@ignore_not_implemented_inference def three_var_groups(request, three_var_model): kw, grouping = request.param approxes, groups = zip(*grouping.items()) @@ -211,6 +219,7 @@ def parametric_grouped_approxes(request): @pytest.fixture +@ignore_not_implemented_inference def three_var_aevb_groups(parametric_grouped_approxes, three_var_model, aevb_initial): one_initial_value = three_var_model.initial_point[three_var_model.one.tag.value_var.name] dsize = np.prod(one_initial_value.shape[1:]) @@ -267,6 +276,7 @@ def test_replacements_in_sample_node_aevb(three_var_aevb_approx, aevb_initial): ).eval({inp: np.random.rand(7, 7).astype("float32")}) +@ignore_not_implemented_inference def test_vae(): minibatch_size = 10 data = pm.floatX(np.random.rand(100)) @@ -298,6 +308,7 @@ def test_vae(): ) +@ignore_not_implemented_inference def test_logq_mini_1_sample_1_var(parametric_grouped_approxes, three_var_model): cls, kw = parametric_grouped_approxes approx = cls([three_var_model.one], model=three_var_model, **kw) @@ -306,6 +317,7 @@ def test_logq_mini_1_sample_1_var(parametric_grouped_approxes, three_var_model): logq.eval() +@ignore_not_implemented_inference def test_logq_mini_2_sample_2_var(parametric_grouped_approxes, three_var_model): cls, kw = parametric_grouped_approxes approx = cls([three_var_model.one, three_var_model.two], model=three_var_model, **kw) @@ -388,6 +400,7 @@ def test_logq_globals(three_var_approx): (not_raises(), "empirical", EmpiricalGroup, {"size": 100}), ], ) +@ignore_not_implemented_inference def test_group_api_vfam(three_var_model, raises, vfam, type_, kw): with three_var_model, raises: g = Group([three_var_model.one], vfam, **kw) @@ -463,6 +476,7 @@ def test_group_api_vfam(three_var_model, raises, vfam, type_, kw): (not_raises(), dict(histogram=np.ones((20, 10, 2), "float32")), EmpiricalGroup, {}, None), ], ) +@ignore_not_implemented_inference def test_group_api_params(three_var_model, raises, params, type_, kw, formula): with three_var_model, raises: g = Group([three_var_model.one], params=params, **kw) @@ -487,6 +501,7 @@ def test_group_api_params(three_var_model, raises, params, type_, kw, formula): (NormalizingFlowGroup, NormalizingFlow, {}), ], ) +@ignore_not_implemented_inference def test_single_group_shortcuts(three_var_model, approx, kw, gcls): with three_var_model: a = approx(**kw) @@ -696,6 +711,7 @@ def init_(**kw): @pytest.fixture(scope="function") +@ignore_not_implemented_inference def inference(inference_spec, simple_model): with simple_model: return inference_spec() From 94a28e5d1139691d47aeabf812402a6d68844fcc Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 16 Sep 2021 11:28:03 +0000 Subject: [PATCH 22/82] fix some more tests --- pymc/tests/test_variational_inference.py | 5 +++-- pymc/variational/approximations.py | 6 ++++-- pymc/variational/opvi.py | 8 +++++--- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 20471b1e818..fc80c5e607c 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -758,8 +758,8 @@ def test_fit_oo(inference, fit_kwargs, simple_model_data): trace = inference.fit(**fit_kwargs).sample(10000) mu_post = simple_model_data["mu_post"] d = simple_model_data["d"] - np.testing.assert_allclose(np.mean(trace["mu"]), mu_post, rtol=0.05) - np.testing.assert_allclose(np.std(trace["mu"]), np.sqrt(1.0 / d), rtol=0.2) + np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_post, rtol=0.05) + np.testing.assert_allclose(np.std(trace.posterior["mu"]), np.sqrt(1.0 / d), rtol=0.2) def test_profile(inference): @@ -857,6 +857,7 @@ def aevb_model(): return {"model": model, "y": y, "x": x, "replace": dict(mu=mu, rho=rho)} +@ignore_not_implemented_inference def test_aevb(inference_spec, aevb_model): # add to inference that supports aevb x = aevb_model["x"] diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py index a3943ebeda5..8d9826f3557 100644 --- a/pymc/variational/approximations.py +++ b/pymc/variational/approximations.py @@ -167,8 +167,10 @@ def tril_indices(self): @node_property def symbolic_logq_not_scaled(self): z0 = self.symbolic_initial - logdet = at.sum(at.diagonal(self.L, 0, self.L.ndim - 2, self.L.ndim - 1), axis=-1, keepdims=True) - logq = pm.Normal.logp(z0, 0, 1) - logdet + logdet = at.sum( + at.log(at.diagonal(self.L, 0, self.L.ndim - 2, self.L.ndim - 1)), axis=-1, keepdims=True + ) + logq = pm.Normal.logp(z0, 0, 1).sum(-1, keepdims=True) + logdet return logq.sum(range(1, logq.ndim)) @node_property diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 625c9058ff8..7eea855918d 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -1652,7 +1652,7 @@ def inner(draws=100): return inner - def sample(self, draws=500, return_inferencedata=True): + def sample(self, draws=500, return_inferencedata=True, **kwargs): """Draw samples from variational posterior. Parameters @@ -1668,6 +1668,7 @@ def sample(self, draws=500, return_inferencedata=True): Samples drawn from variational posterior. """ # TODO: check for include_transformed case + kwargs["log_likelihood"] = False samples = self.sample_dict_fn(draws) # type: dict points = ({name: records[i] for name, records in samples.items()} for i in range(draws)) @@ -1682,13 +1683,14 @@ def sample(self, draws=500, return_inferencedata=True): trace.record(point) finally: trace.close() - + trace = pm.sampling.MultiTrace([trace]) if not return_inferencedata: return trace else: import arviz as az - return az.from_pymc3(trace) + + return az.from_pymc3(trace, **kwargs) @property def ndim(self): From 509f7baaaecfa0ac3005113f05b41ae57ab6e851 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 16 Sep 2021 11:48:50 +0000 Subject: [PATCH 23/82] fix full rank test --- pymc/tests/test_variational_inference.py | 2 +- pymc/variational/approximations.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index fc80c5e607c..7a058aefeac 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -731,7 +731,7 @@ def fit_kwargs(inference, use_minibatch): obj_optimizer=pm.adagrad_window(learning_rate=0.01, n_win=50), n=12000 ), (FullRankADVI, "full"): dict( - obj_optimizer=pm.adagrad_window(learning_rate=0.007, n_win=50), n=6000 + obj_optimizer=pm.adagrad_window(learning_rate=0.01, n_win=50), n=6000 ), (FullRankADVI, "mini"): dict( obj_optimizer=pm.adagrad_window(learning_rate=0.007, n_win=50), n=12000 diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py index 8d9826f3557..95f72be0a42 100644 --- a/pymc/variational/approximations.py +++ b/pymc/variational/approximations.py @@ -167,10 +167,10 @@ def tril_indices(self): @node_property def symbolic_logq_not_scaled(self): z0 = self.symbolic_initial - logdet = at.sum( - at.log(at.diagonal(self.L, 0, self.L.ndim - 2, self.L.ndim - 1)), axis=-1, keepdims=True - ) - logq = pm.Normal.logp(z0, 0, 1).sum(-1, keepdims=True) + logdet + diag = at.diagonal(self.L, 0, self.L.ndim - 2, self.L.ndim - 1) + logdet = at.log(diag) + quaddist = ((z0) ** 2 + at.log(np.pi / 2.0)) / 2.0 + logq = quaddist - logdet return logq.sum(range(1, logq.ndim)) @node_property From c0c8fb97cdc06269d06283b3ca8d7bb6cc8e89f3 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 16 Sep 2021 14:26:38 +0000 Subject: [PATCH 24/82] fix tests --- pymc/tests/test_variational_inference.py | 21 ++++++++++++++++----- pymc/variational/approximations.py | 10 ++++++++-- pymc/variational/opvi.py | 4 +++- 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 7a058aefeac..135b0ed166f 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -835,6 +835,7 @@ def fit_method_with_object(request, another_simple_model): ("nfvi=bad-formula", dict(start={}), KeyError), ], ) +@ignore_not_implemented_inference def test_fit_fn_text(method, kwargs, error, another_simple_model): with another_simple_model: if error is not None: @@ -876,6 +877,7 @@ def test_aevb(inference_spec, aevb_model): pytest.skip("Does not support AEVB") +@ignore_not_implemented_inference def test_rowwise_approx(three_var_model, parametric_grouped_approxes): # add to inference that supports aevb cls, kw = parametric_grouped_approxes @@ -927,6 +929,7 @@ def binomial_model(): @pytest.fixture(scope="module") +@ignore_not_implemented_inference def binomial_model_inference(binomial_model, inference_spec): with binomial_model: return inference_spec() @@ -943,12 +946,20 @@ def test_replacements(binomial_model_inference): assert p_s.tag.test_value.shape == p_t.tag.test_value.shape sampled = [p_s.eval() for _ in range(100)] assert any(map(operator.ne, sampled[1:], sampled[:-1])) # stochastic - p_z = approx.sample_node(p_t, deterministic=True, size=10) + p_z = approx.sample_node(p_t, deterministic=False, size=10) assert p_z.shape.eval() == (10,) - - p_d = approx.sample_node(p_t, deterministic=True) - sampled = [p_d.eval() for _ in range(100)] - assert all(map(operator.eq, sampled[1:], sampled[:-1])) # deterministic + try: + p_z = approx.sample_node(p_t, deterministic=True, size=10) + assert p_z.shape.eval() == (10,) + except NotImplementedInference: + pass + + try: + p_d = approx.sample_node(p_t, deterministic=True) + sampled = [p_d.eval() for _ in range(100)] + assert all(map(operator.eq, sampled[1:], sampled[:-1])) # deterministic + except NotImplementedInference: + pass p_r = approx.sample_node(p_t, deterministic=d) sampled = [p_r.eval({d: 1}) for _ in range(100)] diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py index 95f72be0a42..8e51541ef05 100644 --- a/pymc/variational/approximations.py +++ b/pymc/variational/approximations.py @@ -255,15 +255,21 @@ def randidx(self, size=None): def _new_initial(self, size, deterministic, more_replacements=None): aesara_condition_is_here = isinstance(deterministic, Variable) + if size is None: + size = 1 + size = at.as_tensor(size) if aesara_condition_is_here: return at.switch( deterministic, - at.repeat(self.mean.dimshuffle("x", 0), size if size is not None else 1, -1), + at.repeat(self.mean.reshape((1, -1)), size, -1), self.histogram[self.randidx(size)], ) else: if deterministic: - return at.repeat(self.mean.dimshuffle("x", 0), size if size is not None else 1, -1) + raise NotImplementedInference( + "Deterministic sampling from a Histogram is broken in v4" + ) + return at.repeat(self.mean.reshape((1, -1)), size, -1) else: return self.histogram[self.randidx(size)] diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 7eea855918d..e5e13cc7ac7 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -957,6 +957,8 @@ def __init_group__(self, group): raise GroupError("Got empty group") if self.local: raise NotImplementedInference("Local inferene aka AEVB is not fully supported in v4") + if self.batched: + raise NotImplementedInference("Batched inferene is not fully supported in v4") if self.group is None: # delayed init self.group = group @@ -1065,7 +1067,7 @@ def _new_initial_shape(self, size, dim, more_replacements=None): def bdim(self): if not self.local: if self.batched: - return next(iter(self.ordering.values()))[1][0] + return next(iter(self.ordering.values()))[2][0] else: return 1 else: From 7745ac6ac1d2b3f051a39e85c5737a5e848dc405 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 16 Sep 2021 14:50:21 +0000 Subject: [PATCH 25/82] test vi --- pymc/tests/test_variational_inference.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 135b0ed166f..9aeaa4ff414 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -42,6 +42,7 @@ pytestmark = pytest.mark.usefixtures("strict_float32", "seeded_test") + def ignore_not_implemented_inference(func): @functools.wraps(func) def new_test(*args, **kwargs): @@ -49,6 +50,7 @@ def new_test(*args, **kwargs): return func(*args, **kwargs) except NotImplementedInference: pytest.xfail("NotImplementedInference") + return new_test @@ -141,7 +143,6 @@ def test_init_groups(three_var_model, raises, grouping): trace = approx.sample(100) - @pytest.fixture( params=[ ({}, {MeanFieldGroup: (None, {})}), @@ -245,9 +246,7 @@ def three_var_aevb_approx(three_var_model, three_var_aevb_groups): def test_sample_aevb(three_var_aevb_approx, aevb_initial): inf = pm.KLqp(three_var_aevb_approx) - inf.fit( - 1, more_replacements={aevb_initial: np.zeros_like(aevb_initial.get_value())[:1]} - ) + inf.fit(1, more_replacements={aevb_initial: np.zeros_like(aevb_initial.get_value())[:1]}) aevb_initial.set_value(np.random.rand(7, 7).astype("float32")) trace = three_var_aevb_approx.sample(500, return_inferencedata=False) assert set(trace.varnames) == {"one", "one_log__", "two", "three"} From 3dafc10d2eadca27f3e81ecef4970c7d906279cb Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 16 Sep 2021 16:09:37 +0000 Subject: [PATCH 26/82] fix conversion function --- pymc/variational/opvi.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index e5e13cc7ac7..fdf18369878 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -1690,9 +1690,7 @@ def sample(self, draws=500, return_inferencedata=True, **kwargs): if not return_inferencedata: return trace else: - import arviz as az - - return az.from_pymc3(trace, **kwargs) + return pm.to_inference_data(trace, **kwargs) @property def ndim(self): From 2752ebdf4350889936b4f5604f12f0967a064d01 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 16 Sep 2021 17:51:58 +0000 Subject: [PATCH 27/82] propagate model --- pymc/variational/opvi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index fdf18369878..e9477e04305 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -1690,7 +1690,7 @@ def sample(self, draws=500, return_inferencedata=True, **kwargs): if not return_inferencedata: return trace else: - return pm.to_inference_data(trace, **kwargs) + return pm.to_inference_data(trace, model=self.model, **kwargs) @property def ndim(self): From ff5f8c852c32ee4dffd4435b98f5fdc77edb0de2 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 16 Sep 2021 19:15:09 +0000 Subject: [PATCH 28/82] fix --- pymc/variational/approximations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py index 8e51541ef05..a888f7b6933 100644 --- a/pymc/variational/approximations.py +++ b/pymc/variational/approximations.py @@ -91,7 +91,8 @@ def symbolic_logq_not_scaled(self): z0 = self.symbolic_initial std = rho2sigma(self.rho) logdet = at.log(std) - logq = pm.Normal.logp(z0, 0, 1) - logdet + quaddist = ((z0) ** 2 + at.log(np.pi / 2.0)) / 2.0 + logq = quaddist - logdet return logq.sum(range(1, logq.ndim)) From c154063986829fba82755e9adfe0dd428bc9603c Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 19 Sep 2021 11:11:40 +0000 Subject: [PATCH 29/82] fix elbo --- pymc/variational/approximations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py index a888f7b6933..1e86310e9c5 100644 --- a/pymc/variational/approximations.py +++ b/pymc/variational/approximations.py @@ -91,7 +91,7 @@ def symbolic_logq_not_scaled(self): z0 = self.symbolic_initial std = rho2sigma(self.rho) logdet = at.log(std) - quaddist = ((z0) ** 2 + at.log(np.pi / 2.0)) / 2.0 + quaddist = -0.5 * z0 ** 2 - at.log((2 * np.pi) ** 0.5) logq = quaddist - logdet return logq.sum(range(1, logq.ndim)) From af9c24d8e569c5093ec94a71d5fc999745b0a74e Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 19 Sep 2021 11:13:43 +0000 Subject: [PATCH 30/82] fix elbo full rank --- pymc/variational/approximations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py index 1e86310e9c5..af0c342a83e 100644 --- a/pymc/variational/approximations.py +++ b/pymc/variational/approximations.py @@ -170,7 +170,7 @@ def symbolic_logq_not_scaled(self): z0 = self.symbolic_initial diag = at.diagonal(self.L, 0, self.L.ndim - 2, self.L.ndim - 1) logdet = at.log(diag) - quaddist = ((z0) ** 2 + at.log(np.pi / 2.0)) / 2.0 + quaddist = -0.5 * z0 ** 2 - at.log((2 * np.pi) ** 0.5) logq = quaddist - logdet return logq.sum(range(1, logq.ndim)) From a9d40efe46d55bb35015910343b5739d2656980e Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 23 Sep 2021 21:56:49 +0000 Subject: [PATCH 31/82] Fixing broken scaling with float32 --- pymc/model.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pymc/model.py b/pymc/model.py index cfe154a6666..fd0d7634552 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -50,6 +50,7 @@ from pymc.aesaraf import ( compile_rv_inplace, + floatX, gradient, hessian, inputvars, @@ -1882,7 +1883,7 @@ def _get_scaling(total_size, shape, ndim): denom = shape[0] else: denom = 1 - coef = total_size / denom + coef = floatX(total_size) / floatX(denom) elif isinstance(total_size, (list, tuple)): if not all(isinstance(i, int) for i in total_size if (i is not Ellipsis and i is not None)): raise TypeError( @@ -1912,8 +1913,10 @@ def _get_scaling(total_size, shape, ndim): else: shp_end = np.asarray([]) shp_begin = shape[: len(begin)] - begin_coef = [t / shp_begin[i] for i, t in enumerate(begin) if t is not None] - end_coef = [t / shp_end[i] for i, t in enumerate(end) if t is not None] + begin_coef = [ + floatX(t) / floatX(shp_begin[i]) for i, t in enumerate(begin) if t is not None + ] + end_coef = [floatX(t) / floatX(shp_end[i]) for i, t in enumerate(end) if t is not None] coefs = begin_coef + end_coef coef = at.prod(coefs) else: From 54d2a437b804f91a2cc9d33e816a3401f97b127e Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 23 Sep 2021 22:56:15 +0000 Subject: [PATCH 32/82] ignore a nasty test --- pymc/tests/test_variational_inference.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 9aeaa4ff414..f8a2bf991ee 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -935,6 +935,8 @@ def binomial_model_inference(binomial_model, inference_spec): def test_replacements(binomial_model_inference): + if aesara.config.warn_float64 == "raise": + pytest.skip("float32 is unreasonably strict here") d = at.bscalar() d.tag.test_value = 1 approx = binomial_model_inference.approx From 6d46a2f8f5d7972b802269709b801420b3efbf29 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 26 Sep 2021 08:08:34 +0000 Subject: [PATCH 33/82] xfail one test with float 32 --- pymc/tests/test_variational_inference.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index f8a2bf991ee..b5f48285548 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -934,9 +934,8 @@ def binomial_model_inference(binomial_model, inference_spec): return inference_spec() +@pytest.mark.xfail("aesara.config.warn_float64 == 'raise'", reason="too strict float32") def test_replacements(binomial_model_inference): - if aesara.config.warn_float64 == "raise": - pytest.skip("float32 is unreasonably strict here") d = at.bscalar() d.tag.test_value = 1 approx = binomial_model_inference.approx From 2ce5a7d5dd0241602d4b97d762217b41bb26b3d8 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 26 Sep 2021 09:41:03 +0000 Subject: [PATCH 34/82] fix pre commit --- .github/workflows/pytest.yml | 8 ++++---- .pre-commit-config.yaml | 6 +++--- pymc/variational/opvi.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 5288e652029..288a973caae 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -61,9 +61,9 @@ jobs: --ignore=pymc/tests/test_idata_conversion.py - | - pymc3/tests/test_initvals.py - pymc3/tests/test_distributions.py - - pymc3/tests/test_variational_inference.py + pymc/tests/test_initvals.py + pymc/tests/test_distributions.py + - pymc/tests/test_variational_inference.py - | pymc/tests/test_modelcontext.py @@ -154,7 +154,7 @@ jobs: os: [windows-latest] floatx: [float32, float64] test-subset: - - pymc3/tests/test_variational_inference.py + - pymc/tests/test_variational_inference.py - | pymc/tests/test_initvals.py pymc/tests/test_distributions_random.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bbe7f227a9a..7e70355acdf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,4 @@ -exclude: ^(docs/logos|pymc3/tests/data)/ +exclude: ^(docs/logos|pymc/tests/data)/ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.0.1 @@ -32,12 +32,12 @@ repos: hooks: - id: pylint args: [--rcfile=.pylintrc] - files: ^pymc3/ + files: ^pymc/ - repo: https://github.com/MarcoGorelli/madforhooks rev: 0.2.1 hooks: - id: no-print-statements - files: ^pymc3/ + files: ^pymc/ - repo: local hooks: - id: check-no-tests-are-ignored diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index e9477e04305..728d69cdfd6 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -59,11 +59,11 @@ from pymc.aesaraf import at_rng, identity, rvs_to_value_vars from pymc.backends import NDArray +from pymc.blocking import DictToArrayBijection from pymc.model import modelcontext from pymc.util import WithMemoization, locally_cachedmethod from pymc.variational.updates import adagrad_window from pymc.vartypes import discrete_types -from pymc.blocking import DictToArrayBijection __all__ = ["ObjectiveFunction", "Operator", "TestFunction", "Group", "Approximation"] From 69b9486ba321ba451eb288c69378e31fd62c4400 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 26 Sep 2021 11:28:14 +0000 Subject: [PATCH 35/82] fix import --- pymc/variational/approximations.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py index af0c342a83e..dc60a9ac3c3 100644 --- a/pymc/variational/approximations.py +++ b/pymc/variational/approximations.py @@ -25,7 +25,12 @@ from pymc.distributions.dist_math import rho2sigma from pymc.math import batched_diag from pymc.variational import flows, opvi -from pymc.variational.opvi import Approximation, Group, node_property +from pymc.variational.opvi import ( + Approximation, + Group, + NotImplementedInference, + node_property, +) __all__ = ["MeanField", "FullRank", "Empirical", "NormalizingFlow", "sample_approx"] From 1beec122dbd3ac73b30f82b11e9cf955b4d91450 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 26 Sep 2021 17:26:24 +0000 Subject: [PATCH 36/82] fix import.1 --- pymc/tests/test_variational_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index b5f48285548..5834f558f28 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -38,7 +38,7 @@ NormalizingFlowGroup, ) from pymc.variational.inference import ADVI, ASVGD, NFVI, SVGD, FullRankADVI, fit -from pymc.variational.opvi import Approximation, Group +from pymc.variational.opvi import Approximation, Group, NotImplementedInference pytestmark = pytest.mark.usefixtures("strict_float32", "seeded_test") From 894d5ce9e09c45da65035c4e942b33a4407cb803 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Mon, 27 Sep 2021 11:10:31 +0300 Subject: [PATCH 37/82] Update pymc/variational/opvi.py Co-authored-by: Thomas Wiecki --- pymc/variational/opvi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 728d69cdfd6..418555926e4 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -958,7 +958,7 @@ def __init_group__(self, group): if self.local: raise NotImplementedInference("Local inferene aka AEVB is not fully supported in v4") if self.batched: - raise NotImplementedInference("Batched inferene is not fully supported in v4") + raise NotImplementedInference("Batched inferene is not supported in v4") if self.group is None: # delayed init self.group = group From 8d2ec8ba7577a1e602882d6d1bf4997e61b2018d Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Mon, 27 Sep 2021 08:24:09 +0000 Subject: [PATCH 38/82] fix docstrings --- pymc/model.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pymc/model.py b/pymc/model.py index fd0d7634552..421f0f6cb32 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -1864,14 +1864,18 @@ def Deterministic(name, var, model=None, dims=None, auto=False): def _get_scaling(total_size, shape, ndim): """ - Gets scaling constant for logp + Gets scaling constant for logp. + Parameters ---------- - total_size: int or list[int] + total_size: Optional[int|List[int]] + size of a fully observed data without minibatching, + `None` means data is fully observed shape: shape - shape to scale + shape of an observed data ndim: int ndim hint + Returns ------- scalar From c03352e65fb25bfb4256176158427c36a82c1ec4 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 14 Oct 2021 09:04:20 +0000 Subject: [PATCH 39/82] fix error with nans --- pymc/tests/test_variational_inference.py | 13 +++++++++++++ pymc/variational/inference.py | 18 +++++++----------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 5834f558f28..fc091bc90be 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -1103,3 +1103,16 @@ def test_flow_formula(formula, length, order): if order is not None: assert flows_list == order spec(dim=2, jitter=1)(at.ones((3, 2))).eval() # should work + + +@pytest.mark.parametrize("score", [True, False]) +def test_fit_with_nans(score): + X_mean = pm.floatX(np.linspace(0, 10, 10)) + y = pm.floatX(np.random.normal(X_mean * 4, 0.05)) + with pm.Model(): + inp = pm.Normal("X", X_mean, size=X_mean.shape) + coef = pm.Normal("b", 4.0) + mean = inp * coef + pm.Normal("y", mean, 0.1, observed=y) + with pytest.raises(FloatingPointError) as e: + advi = pm.fit(100, score=score, obj_optimizer=pm.adam(learning_rate=float("nan"))) diff --git a/pymc/variational/inference.py b/pymc/variational/inference.py index 1ab75957fff..4bb29897c25 100644 --- a/pymc/variational/inference.py +++ b/pymc/variational/inference.py @@ -166,12 +166,10 @@ def _iterate_without_loss(self, s, _, step_func, progress, callbacks): if np.isnan(current_param).any(): name_slc = [] tmp_hold = list(range(current_param.size)) - # XXX: This needs to be refactored - vmap = None # self.approx.groups[0].bij.ordering.vmap - for vmap_ in vmap: - slclen = len(tmp_hold[vmap_.slc]) + for varname, slice_info in self.approx.groups[0].ordering.items(): + slclen = len(tmp_hold[slice_info[1]]) for j in range(slclen): - name_slc.append((vmap_.var, j)) + name_slc.append((varname, j)) index = np.where(np.isnan(current_param))[0] errmsg = ["NaN occurred in optimization. "] suggest_solution = ( @@ -210,18 +208,16 @@ def _infmean(input_array): try: for i in progress: e = step_func() - if np.isnan(e): # pragma: no cover + if np.isnan(e): scores = scores[:i] self.hist = np.concatenate([self.hist, scores]) current_param = self.approx.params[0].get_value() name_slc = [] tmp_hold = list(range(current_param.size)) - # XXX: This needs to be refactored - vmap = None # self.approx.groups[0].bij.ordering.vmap - for vmap_ in vmap: - slclen = len(tmp_hold[vmap_.slc]) + for varname, slice_info in self.approx.groups[0].ordering.items(): + slclen = len(tmp_hold[slice_info[1]]) for j in range(slclen): - name_slc.append((vmap_.var, j)) + name_slc.append((varname, j)) index = np.where(np.isnan(current_param))[0] errmsg = ["NaN occurred in optimization. "] suggest_solution = ( From 00c1d14e7185df873e5845a4bddf41ff55b3205b Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 14 Oct 2021 09:14:03 +0000 Subject: [PATCH 40/82] remove TODO comments --- pymc/variational/opvi.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 418555926e4..b369c22d6ad 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -1075,7 +1075,6 @@ def bdim(self): @node_property def ndim(self): - # XXX: This needs to be refactored if self.batched: return self.ordering.size * self.bdim else: @@ -1083,7 +1082,6 @@ def ndim(self): @property def ddim(self): - # TODO: This needs to be refactored return sum(s.stop - s.start for _, s, _, _ in self.ordering.values()) def _new_initial(self, size, deterministic, more_replacements=None): From 694286a96f331fde1331516a575d66369b2e76a8 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 14 Oct 2021 10:27:27 +0000 Subject: [PATCH 41/82] print statements to logging --- pymc/__init__.py | 4 ++-- pymc/tests/test_bart.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc/__init__.py b/pymc/__init__.py index cb30dd2d487..10c5a2651d3 100644 --- a/pymc/__init__.py +++ b/pymc/__init__.py @@ -32,7 +32,7 @@ def _check_install_compatibilitites(): try: import theano - print( + _log.info( "!" * 60 + f"\nYour Python environment has Theano(-PyMC) {theano.__version__} installed, " + f"but you are importing PyMC {__version__} which uses Aesara as its backend." @@ -46,7 +46,7 @@ def _check_install_compatibilitites(): try: import pymc3 - print( + _log.info( "!" * 60 + f"\nYou are importing PyMC {__version__}, but your environment also has" + f" the legacy version PyMC3 {pymc3.__version__} installed." diff --git a/pymc/tests/test_bart.py b/pymc/tests/test_bart.py index 20b14c6966d..04e6cef8d33 100644 --- a/pymc/tests/test_bart.py +++ b/pymc/tests/test_bart.py @@ -1,7 +1,6 @@ import numpy as np from numpy.random import RandomState -from numpy.testing import assert_almost_equal import pymc as pm @@ -62,6 +61,7 @@ def test_bart_random(): rng = RandomState(12345) pred_first = mu.owner.op.rng_fn(rng, X_new=X[:10]) + np.testing.assert_almost_equal(pred_first, pred_all[0, :10], decimal=4) assert pred_all.shape == (2, 50) assert pred_first.shape == (10,) From 8dba7d58538d5468bc8f6b6bd3ea91e49b77b200 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 14 Oct 2021 11:19:10 +0000 Subject: [PATCH 42/82] revert bart test --- pymc/tests/test_bart.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pymc/tests/test_bart.py b/pymc/tests/test_bart.py index 04e6cef8d33..57e031b0be4 100644 --- a/pymc/tests/test_bart.py +++ b/pymc/tests/test_bart.py @@ -61,7 +61,8 @@ def test_bart_random(): rng = RandomState(12345) pred_first = mu.owner.op.rng_fn(rng, X_new=X[:10]) - np.testing.assert_almost_equal(pred_first, pred_all[0, :10], decimal=4) + # XXX: is this supposed to work? + # np.testing.assert_almost_equal(pred_first, pred_all[0, :10], decimal=4) assert pred_all.shape == (2, 50) assert pred_first.shape == (10,) From 3a5915aee8eea47c92303a5c29b84f05b5823a69 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Fri, 15 Oct 2021 18:25:01 +0000 Subject: [PATCH 43/82] fix pylint issues --- pymc/distributions/distribution.py | 2 -- pymc/sampling.py | 4 ++-- pymc/tests/test_model.py | 1 - 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index efb648c0870..444d800591d 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -22,7 +22,6 @@ from typing import Callable, Optional, Sequence import aesara -import numpy as np from aesara.tensor.basic import as_tensor_variable from aesara.tensor.random.op import RandomVariable @@ -42,7 +41,6 @@ maybe_resize, resize_from_dims, resize_from_observed, - to_tuple, ) from pymc.printing import str_for_dist from pymc.util import UNSET diff --git a/pymc/sampling.py b/pymc/sampling.py index 692a28bcc41..ba44ac9bf1c 100644 --- a/pymc/sampling.py +++ b/pymc/sampling.py @@ -22,7 +22,7 @@ import warnings from collections import defaultdict -from copy import copy, deepcopy +from copy import copy from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union, cast import aesara.gradient as tg @@ -51,7 +51,7 @@ filter_rvs_to_jitter, make_initial_point_fns_per_chain, ) -from pymc.model import Model, Point, modelcontext +from pymc.model import Model, modelcontext from pymc.parallel_sampling import Draw, _cpu_count from pymc.step_methods import ( NUTS, diff --git a/pymc/tests/test_model.py b/pymc/tests/test_model.py index a5b1bf14872..83f1fe26056 100644 --- a/pymc/tests/test_model.py +++ b/pymc/tests/test_model.py @@ -28,7 +28,6 @@ from aesara.tensor.random.op import RandomVariable from aesara.tensor.var import TensorConstant -from numpy.testing import assert_almost_equal import pymc as pm From f6d9b9850f5709f8c157565a543d798ac24dbdd1 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Wed, 20 Oct 2021 10:36:03 +0000 Subject: [PATCH 44/82] fix test bart --- pymc/tests/test_bart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/tests/test_bart.py b/pymc/tests/test_bart.py index b6e8d0eb40b..04e6cef8d33 100644 --- a/pymc/tests/test_bart.py +++ b/pymc/tests/test_bart.py @@ -43,7 +43,7 @@ def test_bart_vi(): ) var_imp /= var_imp.sum() assert var_imp[0] > var_imp[1:].sum() - assert_almost_equal(var_imp.sum(), 1) + np.testing.assert_almost_equal(var_imp.sum(), 1) def test_bart_random(): From 9a79e275c04934dd5f97b1931ff3a46b98bcd333 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Wed, 20 Oct 2021 14:07:23 +0000 Subject: [PATCH 45/82] fix interence_data in init --- pymc/sampling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc/sampling.py b/pymc/sampling.py index ba44ac9bf1c..c74158e8b17 100644 --- a/pymc/sampling.py +++ b/pymc/sampling.py @@ -2229,7 +2229,7 @@ def init_nuts( progressbar=progressbar, obj_optimizer=pm.adagrad_window, ) - initial_points = list(approx.sample(draws=chains)) + initial_points = list(approx.sample(draws=chains, return_inferencedata=False)) std_apoint = approx.std.eval() cov = std_apoint ** 2 mean = approx.mean.get_value() @@ -2246,7 +2246,7 @@ def init_nuts( progressbar=progressbar, obj_optimizer=pm.adagrad_window, ) - initial_points = list(approx.sample(draws=chains)) + initial_points = list(approx.sample(draws=chains, return_inferencedata=False)) cov = approx.std.eval() ** 2 potential = quadpotential.QuadPotentialDiag(cov) elif init == "advi_map": @@ -2260,7 +2260,7 @@ def init_nuts( progressbar=progressbar, obj_optimizer=pm.adagrad_window, ) - initial_points = list(approx.sample(draws=chains)) + initial_points = list(approx.sample(draws=chains, return_inferencedata=False)) cov = approx.std.eval() ** 2 potential = quadpotential.QuadPotentialDiag(cov) elif init == "map": From deafa96a59252731ad055abf7298b61a12536c1d Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Tue, 26 Oct 2021 07:56:24 +0000 Subject: [PATCH 46/82] ignore pickling problems --- pymc/tests/test_variational_inference.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index fc091bc90be..2cb50125c99 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -893,6 +893,7 @@ def test_rowwise_approx(three_var_model, parametric_grouped_approxes): pytest.skip("Does not support rowwise grouping") +@pytest.mark.xfail("https://github.com/pymc-devs/pymc/issues/5090") def test_pickle_approx(three_var_approx): import cloudpickle @@ -901,6 +902,7 @@ def test_pickle_approx(three_var_approx): assert new.sample(1) +@pytest.mark.xfail("https://github.com/pymc-devs/pymc/issues/5090") def test_pickle_single_group(three_var_approx_single_group_mf): import cloudpickle @@ -909,6 +911,7 @@ def test_pickle_single_group(three_var_approx_single_group_mf): assert new.sample(1) +@pytest.mark.xfail("https://github.com/pymc-devs/pymc/issues/5090") def test_pickle_approx_aevb(three_var_aevb_approx): import cloudpickle From 0f45e7323581bdb501f5e3f06f7fba237cdd78d7 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Tue, 26 Oct 2021 08:01:50 +0000 Subject: [PATCH 47/82] fix aevb test --- pymc/tests/test_variational_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 2cb50125c99..40bf103bf3d 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -851,7 +851,7 @@ def aevb_model(): pm.Normal("y", size=(2,)) x = model.x y = model.y - xr = model.initial_values[model.rvs_to_values[x]] + xr = model.recompute_initial_point()[model.rvs_to_values[x].name] mu = aesara.shared(xr) rho = aesara.shared(np.zeros_like(xr)) return {"model": model, "y": y, "x": x, "replace": dict(mu=mu, rho=rho)} From 8d48870091155b36778c6ede4b828e2cd3d0bb3c Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 7 Nov 2021 09:04:42 +0000 Subject: [PATCH 48/82] fix name error --- pymc/tests/test_bart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/tests/test_bart.py b/pymc/tests/test_bart.py index 6b0ebb8bec9..b65ebfa7119 100644 --- a/pymc/tests/test_bart.py +++ b/pymc/tests/test_bart.py @@ -84,7 +84,7 @@ def test_predict(self): rng = RandomState(12345) pred_first = pm.bart.utils.predict(self.idata, rng, X_new=self.X[:10]) - assert_almost_equal(pred_first, pred_all[0, :10], decimal=4) + np.testing.assert_almost_equal(pred_first, pred_all[0, :10], decimal=4) assert pred_all.shape == (2, 50) assert pred_first.shape == (10,) From 6efd630ebdfea7fdb32083db618c5429f0add714 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 7 Nov 2021 09:19:58 +0000 Subject: [PATCH 49/82] xfail test ramdom fn --- pymc/tests/test_bart.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pymc/tests/test_bart.py b/pymc/tests/test_bart.py index b65ebfa7119..30c5f712df3 100644 --- a/pymc/tests/test_bart.py +++ b/pymc/tests/test_bart.py @@ -58,6 +58,14 @@ def test_missing_data(): y = pm.Normal("y", mu, sigma, observed=Y) idata = pm.sample(random_seed=3415, chains=1) + +@pytest.xfail("random fn is not yet implemented") +def test_random_fn(): + X = np.random.normal(0, 1, size=(2, 50)).T + Y = np.random.normal(0, 1, size=50) + + with pm.Model() as model: + mu = pm.BART("mu", X, Y, m=10) rng = RandomState(12345) pred_all = mu.owner.op.rng_fn(rng, size=2) rng = RandomState(12345) From b2e9c0f3c2e3db06685ecd99437ed988cde61654 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 7 Nov 2021 10:01:34 +0000 Subject: [PATCH 50/82] mark xfail --- pymc/tests/test_bart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/tests/test_bart.py b/pymc/tests/test_bart.py index 30c5f712df3..a586ae58412 100644 --- a/pymc/tests/test_bart.py +++ b/pymc/tests/test_bart.py @@ -59,7 +59,7 @@ def test_missing_data(): idata = pm.sample(random_seed=3415, chains=1) -@pytest.xfail("random fn is not yet implemented") +@pytest.mark.xfail("random fn is not yet implemented") def test_random_fn(): X = np.random.normal(0, 1, size=(2, 50)).T Y = np.random.normal(0, 1, size=50) From a92aad86fa0499131a27dd464a47316ed8fb9c99 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 7 Nov 2021 11:21:18 +0000 Subject: [PATCH 51/82] refactor test --- pymc/tests/test_bart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/tests/test_bart.py b/pymc/tests/test_bart.py index a586ae58412..97368b263e9 100644 --- a/pymc/tests/test_bart.py +++ b/pymc/tests/test_bart.py @@ -71,9 +71,9 @@ def test_random_fn(): rng = RandomState(12345) pred_first = mu.owner.op.rng_fn(rng, X_new=X[:10]) - np.testing.assert_almost_equal(pred_first, pred_all[0, :10], decimal=4) assert pred_all.shape == (2, 50) assert pred_first.shape == (10,) + np.testing.assert_almost_equal(pred_first, pred_all[0, :10], decimal=4) class TestUtils: From f253417134b713e7efe11bfb303756ff3ca2d6bb Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 7 Nov 2021 11:43:53 +0000 Subject: [PATCH 52/82] xfail fix --- pymc/tests/test_bart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/tests/test_bart.py b/pymc/tests/test_bart.py index 97368b263e9..4958b89af2f 100644 --- a/pymc/tests/test_bart.py +++ b/pymc/tests/test_bart.py @@ -59,7 +59,7 @@ def test_missing_data(): idata = pm.sample(random_seed=3415, chains=1) -@pytest.mark.xfail("random fn is not yet implemented") +@pytest.mark.xfail(reason="random fn is not yet implemented") def test_random_fn(): X = np.random.normal(0, 1, size=(2, 50)).T Y = np.random.normal(0, 1, size=50) From f09d33adef9c8be5d6eaa9606d83afabda954081 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Mon, 8 Nov 2021 11:32:22 +0000 Subject: [PATCH 53/82] fix xfail syntax --- pymc/tests/test_variational_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 40bf103bf3d..305c0ef2943 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -893,7 +893,7 @@ def test_rowwise_approx(three_var_model, parametric_grouped_approxes): pytest.skip("Does not support rowwise grouping") -@pytest.mark.xfail("https://github.com/pymc-devs/pymc/issues/5090") +@pytest.mark.xfail(reason="https://github.com/pymc-devs/pymc/issues/5090") def test_pickle_approx(three_var_approx): import cloudpickle From 19ea8c97b281cbdc3fd92381a353300d14de35a5 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Mon, 8 Nov 2021 12:06:49 +0000 Subject: [PATCH 54/82] pytest --- pymc/tests/test_variational_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 305c0ef2943..416ab0f6ab7 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -902,7 +902,7 @@ def test_pickle_approx(three_var_approx): assert new.sample(1) -@pytest.mark.xfail("https://github.com/pymc-devs/pymc/issues/5090") +@pytest.mark.xfail(reason="https://github.com/pymc-devs/pymc/issues/5090") def test_pickle_single_group(three_var_approx_single_group_mf): import cloudpickle @@ -911,7 +911,7 @@ def test_pickle_single_group(three_var_approx_single_group_mf): assert new.sample(1) -@pytest.mark.xfail("https://github.com/pymc-devs/pymc/issues/5090") +@pytest.mark.xfail(reason="https://github.com/pymc-devs/pymc/issues/5090") def test_pickle_approx_aevb(three_var_aevb_approx): import cloudpickle From f14cbc1d05ef5bb619fd8ddb3bcf624fac9670c4 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Mon, 8 Nov 2021 14:16:11 +0000 Subject: [PATCH 55/82] test fixed --- pymc/tests/test_variational_inference.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 416ab0f6ab7..5526754aa1b 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -893,7 +893,6 @@ def test_rowwise_approx(three_var_model, parametric_grouped_approxes): pytest.skip("Does not support rowwise grouping") -@pytest.mark.xfail(reason="https://github.com/pymc-devs/pymc/issues/5090") def test_pickle_approx(three_var_approx): import cloudpickle From 02fc30fedbc492afc180151ac6adebdca9b35edb Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Mon, 8 Nov 2021 14:44:54 +0000 Subject: [PATCH 56/82] 5090 fixed --- pymc/tests/test_variational_inference.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 5526754aa1b..92303ab953d 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -901,7 +901,6 @@ def test_pickle_approx(three_var_approx): assert new.sample(1) -@pytest.mark.xfail(reason="https://github.com/pymc-devs/pymc/issues/5090") def test_pickle_single_group(three_var_approx_single_group_mf): import cloudpickle @@ -910,7 +909,6 @@ def test_pickle_single_group(three_var_approx_single_group_mf): assert new.sample(1) -@pytest.mark.xfail(reason="https://github.com/pymc-devs/pymc/issues/5090") def test_pickle_approx_aevb(three_var_aevb_approx): import cloudpickle From baefac657cc0d7eebdf8315135bebe0ff496ca62 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Mon, 15 Nov 2021 11:43:25 +0000 Subject: [PATCH 57/82] do not test local flows --- pymc/tests/test_variational_inference.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 92303ab953d..225e6c1012e 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -1061,6 +1061,7 @@ def test_flow_det(flow_spec): np.testing.assert_allclose(logJdet.eval(), det.eval(), atol=0.0001) +@pytest.mark.skip("normalizing flows are not fully supported") def test_flow_det_local(flow_spec): z0 = at.arange(0, 12).astype("float32") spec = flow_spec.cls.get_param_spec_for(d=12) From beb75bae4a8a6a366b252e90660ae644edca1e45 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Tue, 16 Nov 2021 13:28:35 +0000 Subject: [PATCH 58/82] change model.logpt not to return float --- pymc/model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pymc/model.py b/pymc/model.py index e4d59f8bfd9..dcd6b41a249 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -772,7 +772,7 @@ def varlogpt(self): if rv_values: return logpt(self.free_RVs, rv_values) else: - return 0 + return at.as_tensor(0.0) @property def varlogp_nojact(self): @@ -784,7 +784,7 @@ def varlogp_nojact(self): if rv_values: return logpt(self.free_RVs, rv_values, jacobian=False) else: - return 0 + return at.as_tensor(0.0) @property def observedlogpt(self): @@ -795,7 +795,7 @@ def observedlogpt(self): if obs_values: return logpt(self.observed_RVs, obs_values) else: - return 0 + return at.as_tensor(0.0) @property def potentiallogpt(self): @@ -806,7 +806,7 @@ def potentiallogpt(self): if potentials: return at.sum([at.sum(factor) for factor in potentials]) else: - return 0 + return at.as_tensor(0.0) @property def vars(self): From c2d24de2a2baf5f57c769cf1dfb258447859ac7c Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sat, 27 Nov 2021 13:06:18 +0000 Subject: [PATCH 59/82] add a test for the replacenent in the graph --- pymc/tests/test_variational_inference.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 225e6c1012e..9a4d6897c49 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -942,6 +942,10 @@ def test_replacements(binomial_model_inference): p = approx.model.p p_t = p ** 3 p_s = approx.sample_node(p_t) + assert not any( + isinstance(n.owner.op, aesara.tensor.random.basic.BetaRV) + for n in aesara.graph.ancestors([p_s]) + ), "p should be replaced" if aesara.config.compute_test_value != "off": assert p_s.tag.test_value.shape == p_t.tag.test_value.shape sampled = [p_s.eval() for _ in range(100)] From 13a970e47dc0d8c9e6c95bc4b5627aa051abcf12 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 16 Jan 2022 09:59:04 +0000 Subject: [PATCH 60/82] fix sample node functionality --- pymc/variational/opvi.py | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 48be4f3d64f..8e4ea05b397 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -853,7 +853,6 @@ def __init__( self.user_params = params self._user_params = None self.replacements = collections.OrderedDict() - self.value_replacements = collections.OrderedDict() self.ordering = collections.OrderedDict() # save this stuff to use in __init_group__ later self._kwargs = kwargs @@ -979,9 +978,7 @@ def __init_group__(self, group): # so I have to to it by myself # 1) we need initial point (transformed space) - model_initial_point = self.model.initial_point - _, replace_to_value_vars = rvs_to_value_vars(self.group, apply_transforms=True) - self.value_replacements.update(replace_to_value_vars) + model_initial_point = self.model.recompute_initial_point(0) # 2) we'll work with a single group, a subset of the model # here we need to create a mapping to replace value_vars with slices from the approximation start_idx = 0 @@ -1180,7 +1177,6 @@ def set_size_and_deterministic(self, node, s, d, more_replacements=None): def to_flat_input(self, node): """*Dev* - replace vars with flattened view stored in `self.inputs`""" - node = aesara.clone_replace(node, self.value_replacements) return aesara.clone_replace(node, self.replacements) def symbolic_sample_over_posterior(self, node): @@ -1483,13 +1479,6 @@ def datalogp_norm(self): """*Dev* - normalized :math:`E_{q}(data term)`""" return self.datalogp / self.symbolic_normalizing_constant - @property - def value_replacements(self): - """*Dev* - all replacements from groups to replace PyMC random variables with approximation""" - return collections.OrderedDict( - itertools.chain.from_iterable(g.value_replacements.items() for g in self.groups) - ) - @property def replacements(self): """*Dev* - all replacements from groups to replace PyMC random variables with approximation""" @@ -1553,7 +1542,7 @@ def set_size_and_deterministic(self, node, s, d, more_replacements=None): def to_flat_input(self, node, more_replacements=None): """*Dev* - replace vars with flattened view stored in `self.inputs`""" more_replacements = more_replacements or {} - node = aesara.clone_replace(node, {**self.value_replacements, **more_replacements}) + node = aesara.clone_replace(node, more_replacements) return aesara.clone_replace(node, self.replacements) def symbolic_sample_over_posterior(self, node, more_replacements=None): @@ -1609,12 +1598,17 @@ def sample_node(self, node, size=None, deterministic=False, more_replacements=No sampled node(s) with replacements """ node_in = node + if not isinstance(node, (list, tuple)): + node = [node] + node, _ = rvs_to_value_vars( + node, apply_transforms=True, initial_replacements=more_replacements + ) + if not isinstance(node_in, (list, tuple)): + node = node[0] if size is None: - node_out = self.symbolic_single_sample(node, more_replacements=more_replacements) + node_out = self.symbolic_single_sample(node) else: - node_out = self.symbolic_sample_over_posterior( - node, more_replacements=more_replacements - ) + node_out = self.symbolic_sample_over_posterior(node) node_out = self.set_size_and_deterministic(node_out, size, deterministic) try_to_set_test_value(node_in, node_out, size) return node_out From 994fba52b85a2d1ba228abd2de4830cd553ec1ba Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 16 Jan 2022 10:30:43 +0000 Subject: [PATCH 61/82] Fix test with var replacement --- pymc/variational/opvi.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 8e4ea05b397..0f1515cca97 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -1598,11 +1598,11 @@ def sample_node(self, node, size=None, deterministic=False, more_replacements=No sampled node(s) with replacements """ node_in = node + if more_replacements: + node = aesara.clone_replace(node, more_replacements) if not isinstance(node, (list, tuple)): node = [node] - node, _ = rvs_to_value_vars( - node, apply_transforms=True, initial_replacements=more_replacements - ) + node, _ = rvs_to_value_vars(node, apply_transforms=True) if not isinstance(node_in, (list, tuple)): node = node[0] if size is None: From 60900291786555e7268e2ecd955c8b4d9dc8be9f Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 16 Jan 2022 11:20:29 +0000 Subject: [PATCH 62/82] add uncommited changes --- pymc/tests/test_variational_inference.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index ef9437f3b00..38b7fc0b306 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -947,6 +947,7 @@ def test_replacements(binomial_model_inference): assert not any( isinstance(n.owner.op, aesara.tensor.random.basic.BetaRV) for n in aesara.graph.ancestors([p_s]) + if n.owner ), "p should be replaced" if aesara.config.compute_test_value != "off": assert p_s.tag.test_value.shape == p_t.tag.test_value.shape From 48041f5ccadbdaf56333d47ca0c1cdeebd8f65ad Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 23 Jan 2022 10:09:33 +0000 Subject: [PATCH 63/82] resolve @ricardoV94's comment about initial point --- pymc/tests/test_variational_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 38b7fc0b306..929b20b5190 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -222,7 +222,7 @@ def parametric_grouped_approxes(request): @pytest.fixture @ignore_not_implemented_inference def three_var_aevb_groups(parametric_grouped_approxes, three_var_model, aevb_initial): - one_initial_value = three_var_model.recompute_initial_point()[ + one_initial_value = three_var_model.recompute_initial_point(0)[ three_var_model.one.tag.value_var.name ] dsize = np.prod(one_initial_value.shape[1:]) @@ -853,7 +853,7 @@ def aevb_model(): pm.Normal("y", size=(2,)) x = model.x y = model.y - xr = model.recompute_initial_point()[model.rvs_to_values[x].name] + xr = model.recompute_initial_point(0)[model.rvs_to_values[x].name] mu = aesara.shared(xr) rho = aesara.shared(np.zeros_like(xr)) return {"model": model, "y": y, "x": x, "replace": dict(mu=mu, rho=rho)} From cb0fee992ef70ccbc659de461269d0a1b2de9d48 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 23 Jan 2022 10:25:07 +0000 Subject: [PATCH 64/82] restore test_bart.py as in main branch --- pymc/tests/test_bart.py | 24 ++++-------------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/pymc/tests/test_bart.py b/pymc/tests/test_bart.py index 0114128afe2..3b5df81bf8d 100644 --- a/pymc/tests/test_bart.py +++ b/pymc/tests/test_bart.py @@ -2,6 +2,7 @@ import pytest from numpy.random import RandomState +from numpy.testing import assert_almost_equal import pymc as pm @@ -46,7 +47,7 @@ def test_bart_vi(): ) var_imp /= var_imp.sum() assert var_imp[0] > var_imp[1:].sum() - np.testing.assert_almost_equal(var_imp.sum(), 1) + assert_almost_equal(var_imp.sum(), 1) def test_missing_data(): @@ -58,24 +59,7 @@ def test_missing_data(): mu = pm.BART("mu", X, Y, m=10) sigma = pm.HalfNormal("sigma", 1) y = pm.Normal("y", mu, sigma, observed=Y) - idata = pm.sample(random_seed=3415, chains=1) - - -@pytest.mark.xfail(reason="random fn is not yet implemented") -def test_random_fn(): - X = np.random.normal(0, 1, size=(2, 50)).T - Y = np.random.normal(0, 1, size=50) - - with pm.Model() as model: - mu = pm.BART("mu", X, Y, m=10) - rng = RandomState(12345) - pred_all = mu.owner.op.rng_fn(rng, size=2) - rng = RandomState(12345) - pred_first = mu.owner.op.rng_fn(rng, X_new=X[:10]) - - assert pred_all.shape == (2, 50) - assert pred_first.shape == (10,) - np.testing.assert_almost_equal(pred_first, pred_all[0, :10], decimal=4) + idata = pm.sample(random_seed=3415) class TestUtils: @@ -96,7 +80,7 @@ def test_predict(self): rng = RandomState(12345) pred_first = pm.bart.utils.predict(self.idata, rng, X_new=self.X[:10]) - np.testing.assert_almost_equal(pred_first, pred_all[0, :10], decimal=4) + assert_almost_equal(pred_first, pred_all[0, :10], decimal=4) assert pred_all.shape == (2, 50) assert pred_first.shape == (10,) From c5911ac68cfb0a8edf6fdac8c4ad5e7847e4e5ea Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 23 Jan 2022 10:29:45 +0000 Subject: [PATCH 65/82] resolve duplicated _get_scaling function --- pymc/distributions/logprob.py | 21 +++++++---- pymc/model.py | 70 +---------------------------------- 2 files changed, 14 insertions(+), 77 deletions(-) diff --git a/pymc/distributions/logprob.py b/pymc/distributions/logprob.py index 9fe2b94b994..e59384c55db 100644 --- a/pymc/distributions/logprob.py +++ b/pymc/distributions/logprob.py @@ -16,6 +16,7 @@ from functools import singledispatch from typing import Dict, List, Optional, Union +import aesara import aesara.tensor as at import numpy as np @@ -45,13 +46,15 @@ def logp_transform(op: Op): def _get_scaling(total_size, shape, ndim): """ - Gets scaling constant for logp + Gets scaling constant for logp. Parameters ---------- - total_size: int or list[int] + total_size: Optional[int|List[int]] + size of a fully observed data without minibatching, + `None` means data is fully observed shape: shape - shape to scale + shape of an observed data ndim: int ndim hint @@ -60,7 +63,7 @@ def _get_scaling(total_size, shape, ndim): scalar """ if total_size is None: - coef = floatX(1) + coef = 1.0 elif isinstance(total_size, int): if ndim >= 1: denom = shape[0] @@ -90,21 +93,23 @@ def _get_scaling(total_size, shape, ndim): "number of scalings is bigger that ndim, got %r" % total_size ) elif (len(begin) + len(end)) == 0: - return floatX(1) + coef = 1.0 if len(end) > 0: shp_end = shape[-len(end) :] else: shp_end = np.asarray([]) shp_begin = shape[: len(begin)] - begin_coef = [floatX(t) / shp_begin[i] for i, t in enumerate(begin) if t is not None] - end_coef = [floatX(t) / shp_end[i] for i, t in enumerate(end) if t is not None] + begin_coef = [ + floatX(t) / floatX(shp_begin[i]) for i, t in enumerate(begin) if t is not None + ] + end_coef = [floatX(t) / floatX(shp_end[i]) for i, t in enumerate(end) if t is not None] coefs = begin_coef + end_coef coef = at.prod(coefs) else: raise TypeError( "Unrecognized `total_size` type, expected int or list of ints, got %r" % total_size ) - return at.as_tensor(floatX(coef)) + return at.as_tensor(coef, dtype=aesara.config.floatX) subtensor_types = ( diff --git a/pymc/model.py b/pymc/model.py index beb76e68809..9c59287ba7a 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -50,7 +50,6 @@ from pymc.aesaraf import ( compile_pymc, - floatX, gradient, hessian, inputvars, @@ -60,6 +59,7 @@ from pymc.blocking import DictToArrayBijection, RaveledVars from pymc.data import GenTensorVariable, Minibatch from pymc.distributions import joint_logpt, logp_transform +from pymc.distributions.logprob import _get_scaling from pymc.exceptions import ImputationWarning, SamplingError, ShapeError from pymc.initial_point import make_initial_point_fn from pymc.math import flatten_list @@ -1858,74 +1858,6 @@ def Deterministic(name, var, model=None, dims=None, auto=False): return var -def _get_scaling(total_size, shape, ndim): - """ - Gets scaling constant for logp. - - Parameters - ---------- - total_size: Optional[int|List[int]] - size of a fully observed data without minibatching, - `None` means data is fully observed - shape: shape - shape of an observed data - ndim: int - ndim hint - - Returns - ------- - scalar - """ - if total_size is None: - coef = 1.0 - elif isinstance(total_size, int): - if ndim >= 1: - denom = shape[0] - else: - denom = 1 - coef = floatX(total_size) / floatX(denom) - elif isinstance(total_size, (list, tuple)): - if not all(isinstance(i, int) for i in total_size if (i is not Ellipsis and i is not None)): - raise TypeError( - "Unrecognized `total_size` type, expected " - "int or list of ints, got %r" % total_size - ) - if Ellipsis in total_size: - sep = total_size.index(Ellipsis) - begin = total_size[:sep] - end = total_size[sep + 1 :] - if Ellipsis in end: - raise ValueError( - "Double Ellipsis in `total_size` is restricted, got %r" % total_size - ) - else: - begin = total_size - end = [] - if (len(begin) + len(end)) > ndim: - raise ValueError( - "Length of `total_size` is too big, " - "number of scalings is bigger that ndim, got %r" % total_size - ) - elif (len(begin) + len(end)) == 0: - coef = 1.0 - if len(end) > 0: - shp_end = shape[-len(end) :] - else: - shp_end = np.asarray([]) - shp_begin = shape[: len(begin)] - begin_coef = [ - floatX(t) / floatX(shp_begin[i]) for i, t in enumerate(begin) if t is not None - ] - end_coef = [floatX(t) / floatX(shp_end[i]) for i, t in enumerate(end) if t is not None] - coefs = begin_coef + end_coef - coef = at.prod(coefs) - else: - raise TypeError( - "Unrecognized `total_size` type, expected int or list of ints, got %r" % total_size - ) - return at.as_tensor(coef, dtype=aesara.config.floatX) - - def Potential(name, var, model=None): """Add an arbitrary factor potential to the model likelihood From 78ca582c95019eaf91d265105d649408d7ac7368 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 23 Jan 2022 11:00:20 +0000 Subject: [PATCH 66/82] change job order --- .github/workflows/pytest.yml | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 94a816a9c88..af0923a1153 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -48,7 +48,6 @@ jobs: --ignore=pymc/tests/test_step.py --ignore=pymc/tests/test_tuning.py --ignore=pymc/tests/test_transforms.py - --ignore=pymc/tests/test_variational_inference.py --ignore=pymc/tests/test_sampling_jax.py --ignore=pymc/tests/test_dist_math.py --ignore=pymc/tests/test_minibatches.py @@ -67,10 +66,7 @@ jobs: --ignore=pymc/tests/test_bart.py --ignore=pymc/tests/test_missing.py - - | - pymc/tests/test_distributions.py - - pymc/tests/test_variational_inference.py - + - pymc/tests/test_distributions.py - | pymc/tests/test_modelcontext.py pymc/tests/test_dist_math.py @@ -161,12 +157,12 @@ jobs: os: [windows-latest] floatx: [float32, float64] test-subset: - - pymc/tests/test_variational_inference.py - | pymc/tests/test_initial_point.py pymc/tests/test_distributions_random.py pymc/tests/test_distributions_moments.py pymc/tests/test_distributions_timeseries.py + pymc/tests/test_variational_inference.py - | pymc/tests/test_parallel_sampling.py pymc/tests/test_sampling.py From e4cbb33a84ad9a395a9013010ddefb0a433c8f5a Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 23 Jan 2022 11:56:35 +0000 Subject: [PATCH 67/82] use commit initial point in the test file --- pymc/tests/test_variational_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 929b20b5190..2ace9ae3264 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -222,7 +222,7 @@ def parametric_grouped_approxes(request): @pytest.fixture @ignore_not_implemented_inference def three_var_aevb_groups(parametric_grouped_approxes, three_var_model, aevb_initial): - one_initial_value = three_var_model.recompute_initial_point(0)[ + one_initial_value = three_var_model.compute_initial_point(0)[ three_var_model.one.tag.value_var.name ] dsize = np.prod(one_initial_value.shape[1:]) @@ -853,7 +853,7 @@ def aevb_model(): pm.Normal("y", size=(2,)) x = model.x y = model.y - xr = model.recompute_initial_point(0)[model.rvs_to_values[x].name] + xr = model.compute_initial_point(0)[model.rvs_to_values[x].name] mu = aesara.shared(xr) rho = aesara.shared(np.zeros_like(xr)) return {"model": model, "y": y, "x": x, "replace": dict(mu=mu, rho=rho)} From 8fad15714b3a4f184af31ffee829a6ddae795ff3 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 23 Jan 2022 11:57:35 +0000 Subject: [PATCH 68/82] use compute initial point in the opvi.py --- pymc/variational/opvi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 0f1515cca97..c2a4b7e5b0d 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -978,7 +978,7 @@ def __init_group__(self, group): # so I have to to it by myself # 1) we need initial point (transformed space) - model_initial_point = self.model.recompute_initial_point(0) + model_initial_point = self.model.compute_initial_point(0) # 2) we'll work with a single group, a subset of the model # here we need to create a mapping to replace value_vars with slices from the approximation start_idx = 0 From 7f281bd8f4d919c7b0d691c9bb700263674cf3d0 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Mon, 24 Jan 2022 09:55:00 +0000 Subject: [PATCH 69/82] remove unnessesary pattern broadcast --- pymc/variational/opvi.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index c2a4b7e5b0d..39b018e8318 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -1200,7 +1200,6 @@ def symbolic_single_sample(self, node): """ node = self.to_flat_input(node) random = self.symbolic_random.astype(self.symbolic_initial.dtype) - random = at.patternbroadcast(random, self.symbolic_initial.broadcastable) return aesara.clone_replace(node, {self.input: random[0]}) def make_size_and_deterministic_replacements(self, s, d, more_replacements=None): From 8e8f63ecb8d9a06c515c55fcac9099992c1df06f Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Mon, 24 Jan 2022 16:24:44 +0000 Subject: [PATCH 70/82] mark test as xfail before aesara release --- pymc/tests/test_variational_inference.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 2ace9ae3264..326580756cc 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -936,7 +936,8 @@ def binomial_model_inference(binomial_model, inference_spec): return inference_spec() -@pytest.mark.xfail("aesara.config.warn_float64 == 'raise'", reason="too strict float32") +# @pytest.mark.xfail("aesara.config.warn_float64 == 'raise'", reason="too strict float32") +@pytest.mark.xfail(reason="waits for aesara release") def test_replacements(binomial_model_inference): d = at.bscalar() d.tag.test_value = 1 From 72a75569c26702080138983642bc8be7df2587bd Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Mon, 24 Jan 2022 16:27:38 +0000 Subject: [PATCH 71/82] Do not mark anything but just wait for the new release --- pymc/tests/test_variational_inference.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 326580756cc..2ace9ae3264 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -936,8 +936,7 @@ def binomial_model_inference(binomial_model, inference_spec): return inference_spec() -# @pytest.mark.xfail("aesara.config.warn_float64 == 'raise'", reason="too strict float32") -@pytest.mark.xfail(reason="waits for aesara release") +@pytest.mark.xfail("aesara.config.warn_float64 == 'raise'", reason="too strict float32") def test_replacements(binomial_model_inference): d = at.bscalar() d.tag.test_value = 1 From a6f54ac3d51b0abb6b598cc0353e6b3c48e85100 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 13 Feb 2022 09:28:35 +0000 Subject: [PATCH 72/82] use compute_initial_point --- pymc/tests/test_variational_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 2ace9ae3264..2cb6e3bd7e5 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -138,7 +138,7 @@ def test_init_groups(three_var_model, raises, grouping): else: assert {pm.util.get_transformed(z) for z in g} == set(ig.group) else: - model_dim = sum(v.size for v in three_var_model.initial_point.values()) + model_dim = sum(v.size for v in three_var_model.compute_initial_point(0).values()) assert approx.ndim == model_dim trace = approx.sample(100) From b4a2f62aea54233c8f572ce411a8356cf3245657 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 20 Feb 2022 12:10:40 +0300 Subject: [PATCH 73/82] Update pymc/variational/opvi.py Co-authored-by: Thomas Wiecki --- pymc/variational/opvi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 39b018e8318..f1cb1077108 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -957,7 +957,7 @@ def __init_group__(self, group): if not group: raise GroupError("Got empty group") if self.local: - raise NotImplementedInference("Local inferene aka AEVB is not fully supported in v4") + raise NotImplementedInference("Local inferene aka AEVB is not supported in v4") if self.batched: raise NotImplementedInference("Batched inferene is not supported in v4") if self.group is None: From f9d16a7d2e60febe68418d3aa3cd4561ef453756 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 20 Feb 2022 09:18:00 +0000 Subject: [PATCH 74/82] run upgraded pre-commit --- pymc/variational/opvi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 39b018e8318..65031bd0db2 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -867,7 +867,7 @@ def _prepare_start(self, start=None): jitter_rvs={}, return_transformed=True, ) - start = ipfn(self.model.rng_seeder.randint(2 ** 30, dtype=np.int64)) + start = ipfn(self.model.rng_seeder.randint(2**30, dtype=np.int64)) group_vars = {self.model.rvs_to_values[v].name for v in self.group} start = {k: v for k, v in start.items() if k in group_vars} if self.batched: From 6a3ee610e506a617af649333947dd8b45a4b1b67 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sun, 20 Feb 2022 13:00:57 +0000 Subject: [PATCH 75/82] move pipe back --- .github/workflows/pytest.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index ebf015cf00a..ccf1b64d8bf 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -66,7 +66,8 @@ jobs: --ignore=pymc/tests/test_bart.py --ignore=pymc/tests/test_missing.py - - pymc/tests/test_distributions.py + - | + pymc/tests/test_distributions.py - | pymc/tests/test_modelcontext.py pymc/tests/test_dist_math.py From cd2cda917072ff92f7a4dcb35de723083952479e Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Wed, 23 Feb 2022 10:09:37 +0300 Subject: [PATCH 76/82] Update pymc/variational/opvi.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc/variational/opvi.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 23b22e89977..cca96d02c6c 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -1003,7 +1003,6 @@ def __init_group__(self, group): else: shape = test_var.shape size = test_var.size - # TODO: There was self.ordering used in other util funcitons dtype = test_var.dtype vr = self.input[..., start_idx : start_idx + size].reshape(shape).astype(dtype) vr.name = value_var.name + "_vi_replacement" From 670edb977fb9ffa888289172316f72a88cc969c0 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Wed, 23 Feb 2022 10:10:34 +0300 Subject: [PATCH 77/82] Update pymc/variational/opvi.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc/variational/opvi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index cca96d02c6c..41aed75cf35 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -1661,7 +1661,7 @@ def sample(self, draws=500, return_inferencedata=True, **kwargs): trace: :class:`pymc.backends.base.MultiTrace` Samples drawn from variational posterior. """ - # TODO: check for include_transformed case + # TODO: add tests for include_transformed case kwargs["log_likelihood"] = False samples = self.sample_dict_fn(draws) # type: dict From 01fb2232d5cebc360bc2da629658974334c6444f Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Wed, 23 Feb 2022 10:11:08 +0300 Subject: [PATCH 78/82] Update pymc/variational/opvi.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc/variational/opvi.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 41aed75cf35..6a55b83ff62 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -1633,7 +1633,6 @@ def vars_names(vs): @node_property def sample_dict_fn(self): - # TODO: this breaks s = at.iscalar() names = [self.model.rvs_to_values[v].name for v in self.model.free_RVs] sampled = [self.rslice(name) for name in names] From 32006cdb722a6da046ec7d7f7431e6dd659b47e9 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Wed, 23 Feb 2022 10:08:37 +0100 Subject: [PATCH 79/82] Add removed newline --- .github/workflows/pytest.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index ccf1b64d8bf..9f6129d45e1 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -68,6 +68,7 @@ jobs: - | pymc/tests/test_distributions.py + - | pymc/tests/test_modelcontext.py pymc/tests/test_dist_math.py From 1cb1418b5d06d4ea7a1bcf4ff5411372c8dd50bb Mon Sep 17 00:00:00 2001 From: Ricardo Date: Wed, 23 Feb 2022 10:25:47 +0100 Subject: [PATCH 80/82] Use compile_pymc instead of aesara.function --- pymc/variational/opvi.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 6a55b83ff62..948805297a9 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -57,7 +57,7 @@ import pymc as pm -from pymc.aesaraf import at_rng, identity, rvs_to_value_vars +from pymc.aesaraf import at_rng, compile_pymc, identity, rvs_to_value_vars from pymc.backends import NDArray from pymc.blocking import DictToArrayBijection from pymc.initial_point import make_initial_point_fn @@ -363,9 +363,9 @@ def step_function( total_grad_norm_constraint=total_grad_norm_constraint, ) if score: - step_fn = aesara.function([], updates.loss, updates=updates, **fn_kwargs) + step_fn = compile_pymc([], updates.loss, updates=updates, **fn_kwargs) else: - step_fn = aesara.function([], None, updates=updates, **fn_kwargs) + step_fn = compile_pymc([], None, updates=updates, **fn_kwargs) return step_fn @aesara.config.change_flags(compute_test_value="off") @@ -394,7 +394,7 @@ def score_function( if more_replacements is None: more_replacements = {} loss = self(sc_n_mc, more_replacements=more_replacements) - return aesara.function([], loss, **fn_kwargs) + return compile_pymc([], loss, **fn_kwargs) @aesara.config.change_flags(compute_test_value="off") def __call__(self, nmc, **kwargs): @@ -1637,7 +1637,7 @@ def sample_dict_fn(self): names = [self.model.rvs_to_values[v].name for v in self.model.free_RVs] sampled = [self.rslice(name) for name in names] sampled = self.set_size_and_deterministic(sampled, s, 0) - sample_fn = aesara.function([s], sampled) + sample_fn = compile_pymc([s], sampled) def inner(draws=100): _samples = sample_fn(draws) From ceddb5c0fb913ef5bcd36a31f3f31efa24ca3065 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Wed, 23 Feb 2022 12:05:30 +0100 Subject: [PATCH 81/82] Replace None by empty list in output --- pymc/variational/opvi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 948805297a9..582d4bae19b 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -365,7 +365,7 @@ def step_function( if score: step_fn = compile_pymc([], updates.loss, updates=updates, **fn_kwargs) else: - step_fn = compile_pymc([], None, updates=updates, **fn_kwargs) + step_fn = compile_pymc([], [], updates=updates, **fn_kwargs) return step_fn @aesara.config.change_flags(compute_test_value="off") From ef5f91b41d733c64cf3ea785ec2760caae5d1bd4 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 24 Feb 2022 16:55:14 +0300 Subject: [PATCH 82/82] Apply suggestions from code review Co-authored-by: Michael Osthege --- pymc/distributions/logprob.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/distributions/logprob.py b/pymc/distributions/logprob.py index e59384c55db..9c3ef883ae3 100644 --- a/pymc/distributions/logprob.py +++ b/pymc/distributions/logprob.py @@ -14,7 +14,7 @@ from collections.abc import Mapping from functools import singledispatch -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Sequence, Union import aesara import aesara.tensor as at @@ -44,7 +44,7 @@ def logp_transform(op: Op): return None -def _get_scaling(total_size, shape, ndim): +def _get_scaling(total_size: Optional[Union[int, Sequence[int]]], shape, ndim: int): """ Gets scaling constant for logp.