diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index a4644f01d4c..9f6129d45e1 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -48,7 +48,6 @@ jobs: --ignore=pymc/tests/test_step.py --ignore=pymc/tests/test_tuning.py --ignore=pymc/tests/test_transforms.py - --ignore=pymc/tests/test_variational_inference.py --ignore=pymc/tests/test_sampling_jax.py --ignore=pymc/tests/test_dist_math.py --ignore=pymc/tests/test_minibatches.py @@ -169,6 +168,7 @@ jobs: pymc/tests/test_distributions_random.py pymc/tests/test_distributions_moments.py pymc/tests/test_distributions_timeseries.py + pymc/tests/test_variational_inference.py - | pymc/tests/test_parallel_sampling.py pymc/tests/test_sampling.py diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 9d872e0724c..07406b09ea0 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -478,6 +478,7 @@ def is_data(name, var) -> bool: and var not in self.model.observed_RVs and var not in self.model.free_RVs and var not in self.model.potentials + and var not in self.model.value_vars and (self.observations is None or name not in self.observations) and isinstance(var, (Constant, SharedVariable)) ) diff --git a/pymc/distributions/logprob.py b/pymc/distributions/logprob.py index 9fe2b94b994..9c3ef883ae3 100644 --- a/pymc/distributions/logprob.py +++ b/pymc/distributions/logprob.py @@ -14,8 +14,9 @@ from collections.abc import Mapping from functools import singledispatch -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Sequence, Union +import aesara import aesara.tensor as at import numpy as np @@ -43,15 +44,17 @@ def logp_transform(op: Op): return None -def _get_scaling(total_size, shape, ndim): +def _get_scaling(total_size: Optional[Union[int, Sequence[int]]], shape, ndim: int): """ - Gets scaling constant for logp + Gets scaling constant for logp. 
Parameters ---------- - total_size: int or list[int] + total_size: Optional[int|List[int]] + size of a fully observed data without minibatching, + `None` means data is fully observed shape: shape - shape to scale + shape of an observed data ndim: int ndim hint @@ -60,7 +63,7 @@ def _get_scaling(total_size, shape, ndim): scalar """ if total_size is None: - coef = floatX(1) + coef = 1.0 elif isinstance(total_size, int): if ndim >= 1: denom = shape[0] @@ -90,21 +93,23 @@ def _get_scaling(total_size, shape, ndim): "number of scalings is bigger that ndim, got %r" % total_size ) elif (len(begin) + len(end)) == 0: - return floatX(1) + coef = 1.0 if len(end) > 0: shp_end = shape[-len(end) :] else: shp_end = np.asarray([]) shp_begin = shape[: len(begin)] - begin_coef = [floatX(t) / shp_begin[i] for i, t in enumerate(begin) if t is not None] - end_coef = [floatX(t) / shp_end[i] for i, t in enumerate(end) if t is not None] + begin_coef = [ + floatX(t) / floatX(shp_begin[i]) for i, t in enumerate(begin) if t is not None + ] + end_coef = [floatX(t) / floatX(shp_end[i]) for i, t in enumerate(end) if t is not None] coefs = begin_coef + end_coef coef = at.prod(coefs) else: raise TypeError( "Unrecognized `total_size` type, expected int or list of ints, got %r" % total_size ) - return at.as_tensor(floatX(coef)) + return at.as_tensor(coef, dtype=aesara.config.floatX) subtensor_types = ( diff --git a/pymc/model.py b/pymc/model.py index b97bf6e4b40..ec8f4d1c6ab 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -58,6 +58,7 @@ from pymc.blocking import DictToArrayBijection, RaveledVars from pymc.data import GenTensorVariable, Minibatch from pymc.distributions import joint_logpt, logp_transform +from pymc.distributions.logprob import _get_scaling from pymc.exceptions import ImputationWarning, SamplingError, ShapeError from pymc.initial_point import make_initial_point_fn from pymc.math import flatten_list @@ -1238,6 +1239,7 @@ def register_rv( name = self.name_for(name) rv_var.name = name rv_var.tag.total_size = total_size + rv_var.tag.scaling = _get_scaling(total_size, shape=rv_var.shape, ndim=rv_var.ndim) # Associate previously unknown dimension names with # the length of the corresponding RV dimension. 
@@ -1870,7 +1872,7 @@ def Potential(name, var, model=None): """ model = modelcontext(model) var.name = model.name_for(name) - var.tag.scaling = None + var.tag.scaling = 1.0 model.potentials.append(var) model.add_random_variable(var) diff --git a/pymc/sampling.py b/pymc/sampling.py index bd0e1fd2ec8..628d29bc05a 100644 --- a/pymc/sampling.py +++ b/pymc/sampling.py @@ -2385,7 +2385,7 @@ def init_nuts( progressbar=progressbar, obj_optimizer=pm.adagrad_window, ) - initial_points = list(approx.sample(draws=chains)) + initial_points = list(approx.sample(draws=chains, return_inferencedata=False)) std_apoint = approx.std.eval() cov = std_apoint**2 mean = approx.mean.get_value() @@ -2402,7 +2402,7 @@ def init_nuts( progressbar=progressbar, obj_optimizer=pm.adagrad_window, ) - initial_points = list(approx.sample(draws=chains)) + initial_points = list(approx.sample(draws=chains, return_inferencedata=False)) cov = approx.std.eval() ** 2 potential = quadpotential.QuadPotentialDiag(cov) elif init == "advi_map": @@ -2416,7 +2416,7 @@ def init_nuts( progressbar=progressbar, obj_optimizer=pm.adagrad_window, ) - initial_points = list(approx.sample(draws=chains)) + initial_points = list(approx.sample(draws=chains, return_inferencedata=False)) cov = approx.std.eval() ** 2 potential = quadpotential.QuadPotentialDiag(cov) elif init == "map": diff --git a/pymc/tests/test_variational_inference.py b/pymc/tests/test_variational_inference.py index 51d753afefa..54aef4fc06d 100644 --- a/pymc/tests/test_variational_inference.py +++ b/pymc/tests/test_variational_inference.py @@ -38,12 +38,20 @@ NormalizingFlowGroup, ) from pymc.variational.inference import ADVI, ASVGD, NFVI, SVGD, FullRankADVI, fit -from pymc.variational.opvi import Approximation, Group +from pymc.variational.opvi import Approximation, Group, NotImplementedInference -# pytestmark = pytest.mark.usefixtures("strict_float32", "seeded_test") -pytestmark = pytest.mark.xfail( - reason="These tests rely on Group, which hasn't been refactored for v4" -) +pytestmark = pytest.mark.usefixtures("strict_float32", "seeded_test") + + +def ignore_not_implemented_inference(func): + @functools.wraps(func) + def new_test(*args, **kwargs): + try: + return func(*args, **kwargs) + except NotImplementedInference: + pytest.xfail("NotImplementedInference") + + return new_test @pytest.mark.parametrize("diff", ["relative", "absolute"]) @@ -114,6 +122,7 @@ def three_var_model(): (not_raises(), {MeanFieldGroup: ["one"], FullRankGroup: ["two", "three"]}), ], ) +@ignore_not_implemented_inference def test_init_groups(three_var_model, raises, grouping): with raises, three_var_model: approxes, groups = zip(*grouping.items()) @@ -129,7 +138,9 @@ def test_init_groups(three_var_model, raises, grouping): else: assert {pm.util.get_transformed(z) for z in g} == set(ig.group) else: - assert approx.ndim == three_var_model.ndim + model_dim = sum(v.size for v in three_var_model.compute_initial_point(0).values()) + assert approx.ndim == model_dim + trace = approx.sample(100) @pytest.fixture( @@ -149,6 +160,7 @@ def test_init_groups(three_var_model, raises, grouping): ], ids=lambda t: ", ".join(f"{k.__name__}: {v[0]}" for k, v in t[1].items()), ) +@ignore_not_implemented_inference def three_var_groups(request, three_var_model): kw, grouping = request.param approxes, groups = zip(*grouping.items()) @@ -208,8 +220,9 @@ def parametric_grouped_approxes(request): @pytest.fixture +@ignore_not_implemented_inference def three_var_aevb_groups(parametric_grouped_approxes, three_var_model, aevb_initial): - 
one_initial_value = three_var_model.compute_initial_point()[ + one_initial_value = three_var_model.compute_initial_point(0)[ three_var_model.one.tag.value_var.name ] dsize = np.prod(one_initial_value.shape[1:]) @@ -234,9 +247,8 @@ def three_var_aevb_approx(three_var_model, three_var_aevb_groups): def test_sample_aevb(three_var_aevb_approx, aevb_initial): - pm.KLqp(three_var_aevb_approx).fit( - 1, more_replacements={aevb_initial: np.zeros_like(aevb_initial.get_value())[:1]} - ) + inf = pm.KLqp(three_var_aevb_approx) + inf.fit(1, more_replacements={aevb_initial: np.zeros_like(aevb_initial.get_value())[:1]}) aevb_initial.set_value(np.random.rand(7, 7).astype("float32")) trace = three_var_aevb_approx.sample(500, return_inferencedata=False) assert set(trace.varnames) == {"one", "one_log__", "two", "three"} @@ -265,6 +277,7 @@ def test_replacements_in_sample_node_aevb(three_var_aevb_approx, aevb_initial): ).eval({inp: np.random.rand(7, 7).astype("float32")}) +@ignore_not_implemented_inference def test_vae(): minibatch_size = 10 data = pm.floatX(np.random.rand(100)) @@ -296,6 +309,7 @@ def test_vae(): ) +@ignore_not_implemented_inference def test_logq_mini_1_sample_1_var(parametric_grouped_approxes, three_var_model): cls, kw = parametric_grouped_approxes approx = cls([three_var_model.one], model=three_var_model, **kw) @@ -304,6 +318,7 @@ def test_logq_mini_1_sample_1_var(parametric_grouped_approxes, three_var_model): logq.eval() +@ignore_not_implemented_inference def test_logq_mini_2_sample_2_var(parametric_grouped_approxes, three_var_model): cls, kw = parametric_grouped_approxes approx = cls([three_var_model.one, three_var_model.two], model=three_var_model, **kw) @@ -386,6 +401,7 @@ def test_logq_globals(three_var_approx): (not_raises(), "empirical", EmpiricalGroup, {"size": 100}), ], ) +@ignore_not_implemented_inference def test_group_api_vfam(three_var_model, raises, vfam, type_, kw): with three_var_model, raises: g = Group([three_var_model.one], vfam, **kw) @@ -461,6 +477,7 @@ def test_group_api_vfam(three_var_model, raises, vfam, type_, kw): (not_raises(), dict(histogram=np.ones((20, 10, 2), "float32")), EmpiricalGroup, {}, None), ], ) +@ignore_not_implemented_inference def test_group_api_params(three_var_model, raises, params, type_, kw, formula): with three_var_model, raises: g = Group([three_var_model.one], params=params, **kw) @@ -485,6 +502,7 @@ def test_group_api_params(three_var_model, raises, params, type_, kw, formula): (NormalizingFlowGroup, NormalizingFlow, {}), ], ) +@ignore_not_implemented_inference def test_single_group_shortcuts(three_var_model, approx, kw, gcls): with three_var_model: a = approx(**kw) @@ -694,6 +712,7 @@ def init_(**kw): @pytest.fixture(scope="function") +@ignore_not_implemented_inference def inference(inference_spec, simple_model): with simple_model: return inference_spec() @@ -713,7 +732,7 @@ def fit_kwargs(inference, use_minibatch): obj_optimizer=pm.adagrad_window(learning_rate=0.01, n_win=50), n=12000 ), (FullRankADVI, "full"): dict( - obj_optimizer=pm.adagrad_window(learning_rate=0.007, n_win=50), n=6000 + obj_optimizer=pm.adagrad_window(learning_rate=0.01, n_win=50), n=6000 ), (FullRankADVI, "mini"): dict( obj_optimizer=pm.adagrad_window(learning_rate=0.007, n_win=50), n=12000 @@ -740,8 +759,8 @@ def test_fit_oo(inference, fit_kwargs, simple_model_data): trace = inference.fit(**fit_kwargs).sample(10000) mu_post = simple_model_data["mu_post"] d = simple_model_data["d"] - np.testing.assert_allclose(np.mean(trace["mu"]), mu_post, rtol=0.05) - 
np.testing.assert_allclose(np.std(trace["mu"]), np.sqrt(1.0 / d), rtol=0.2) + np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_post, rtol=0.05) + np.testing.assert_allclose(np.std(trace.posterior["mu"]), np.sqrt(1.0 / d), rtol=0.2) def test_profile(inference): @@ -817,6 +836,7 @@ def fit_method_with_object(request, another_simple_model): ("nfvi=bad-formula", dict(start={}), KeyError), ], ) +@ignore_not_implemented_inference def test_fit_fn_text(method, kwargs, error, another_simple_model): with another_simple_model: if error is not None: @@ -833,11 +853,13 @@ def aevb_model(): pm.Normal("y", size=(2,)) x = model.x y = model.y - mu = aesara.shared(x.init_value) - rho = aesara.shared(np.zeros_like(x.init_value)) + xr = model.compute_initial_point(0)[model.rvs_to_values[x].name] + mu = aesara.shared(xr) + rho = aesara.shared(np.zeros_like(xr)) return {"model": model, "y": y, "x": x, "replace": dict(mu=mu, rho=rho)} +@ignore_not_implemented_inference def test_aevb(inference_spec, aevb_model): # add to inference that supports aevb x = aevb_model["x"] @@ -856,6 +878,7 @@ def test_aevb(inference_spec, aevb_model): pytest.skip("Does not support AEVB") +@ignore_not_implemented_inference def test_rowwise_approx(three_var_model, parametric_grouped_approxes): # add to inference that supports aevb cls, kw = parametric_grouped_approxes @@ -907,11 +930,13 @@ def binomial_model(): @pytest.fixture(scope="module") +@ignore_not_implemented_inference def binomial_model_inference(binomial_model, inference_spec): with binomial_model: return inference_spec() +@pytest.mark.xfail("aesara.config.warn_float64 == 'raise'", reason="too strict float32") def test_replacements(binomial_model_inference): d = at.bscalar() d.tag.test_value = 1 @@ -919,14 +944,29 @@ def test_replacements(binomial_model_inference): p = approx.model.p p_t = p**3 p_s = approx.sample_node(p_t) + assert not any( + isinstance(n.owner.op, aesara.tensor.random.basic.BetaRV) + for n in aesara.graph.ancestors([p_s]) + if n.owner + ), "p should be replaced" if aesara.config.compute_test_value != "off": assert p_s.tag.test_value.shape == p_t.tag.test_value.shape sampled = [p_s.eval() for _ in range(100)] assert any(map(operator.ne, sampled[1:], sampled[:-1])) # stochastic - - p_d = approx.sample_node(p_t, deterministic=True) - sampled = [p_d.eval() for _ in range(100)] - assert all(map(operator.eq, sampled[1:], sampled[:-1])) # deterministic + p_z = approx.sample_node(p_t, deterministic=False, size=10) + assert p_z.shape.eval() == (10,) + try: + p_z = approx.sample_node(p_t, deterministic=True, size=10) + assert p_z.shape.eval() == (10,) + except NotImplementedInference: + pass + + try: + p_d = approx.sample_node(p_t, deterministic=True) + sampled = [p_d.eval() for _ in range(100)] + assert all(map(operator.eq, sampled[1:], sampled[:-1])) # deterministic + except NotImplementedInference: + pass p_r = approx.sample_node(p_t, deterministic=d) sampled = [p_r.eval({d: 1}) for _ in range(100)] @@ -1028,6 +1068,7 @@ def test_flow_det(flow_spec): np.testing.assert_allclose(logJdet.eval(), det.eval(), atol=0.0001) +@pytest.mark.skip("normalizing flows are not fully supported") def test_flow_det_local(flow_spec): z0 = at.arange(0, 12).astype("float32") spec = flow_spec.cls.get_param_spec_for(d=12) @@ -1070,3 +1111,16 @@ def test_flow_formula(formula, length, order): if order is not None: assert flows_list == order spec(dim=2, jitter=1)(at.ones((3, 2))).eval() # should work + + +@pytest.mark.parametrize("score", [True, False]) +def 
test_fit_with_nans(score): + X_mean = pm.floatX(np.linspace(0, 10, 10)) + y = pm.floatX(np.random.normal(X_mean * 4, 0.05)) + with pm.Model(): + inp = pm.Normal("X", X_mean, size=X_mean.shape) + coef = pm.Normal("b", 4.0) + mean = inp * coef + pm.Normal("y", mean, 0.1, observed=y) + with pytest.raises(FloatingPointError) as e: + advi = pm.fit(100, score=score, obj_optimizer=pm.adam(learning_rate=float("nan"))) diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py index fbce2d345fe..f4e324160a1 100644 --- a/pymc/variational/approximations.py +++ b/pymc/variational/approximations.py @@ -23,10 +23,14 @@ from pymc.blocking import DictToArrayBijection from pymc.distributions.dist_math import rho2sigma -from pymc.initial_point import make_initial_point_fn from pymc.math import batched_diag from pymc.variational import flows, opvi -from pymc.variational.opvi import Approximation, Group, node_property +from pymc.variational.opvi import ( + Approximation, + Group, + NotImplementedInference, + node_property, +) __all__ = ["MeanField", "FullRank", "Empirical", "NormalizingFlow", "sample_approx"] @@ -70,17 +74,7 @@ def __init_group__(self, group): self._finalize_init() def create_shared_params(self, start=None): - ipfn = make_initial_point_fn( - model=self.model, - overrides=start, - jitter_rvs={}, - return_transformed=True, - ) - start = ipfn(self.model.rng_seeder.randint(2**30, dtype=np.int64)) - if self.batched: - start = start[self.group[0].name][0] - else: - start = DictToArrayBijection.map(start) + start = self._prepare_start(start) rho = np.zeros((self.ddim,)) if self.batched: start = np.tile(start, (self.bdim, 1)) @@ -102,7 +96,8 @@ def symbolic_logq_not_scaled(self): z0 = self.symbolic_initial std = rho2sigma(self.rho) logdet = at.log(std) - logq = pm.Normal.dist().logp(z0) - logdet + quaddist = -0.5 * z0**2 - at.log((2 * np.pi) ** 0.5) + logq = quaddist - logdet return logq.sum(range(1, logq.ndim)) @@ -126,17 +121,7 @@ def __init_group__(self, group): self._finalize_init() def create_shared_params(self, start=None): - ipfn = make_initial_point_fn( - model=self.model, - overrides=start, - jitter_rvs={}, - return_transformed=True, - ) - start = ipfn(self.model.rng_seeder.randint(2**30, dtype=np.int64)) - if self.batched: - start = start[self.group[0].name][0] - else: - start = DictToArrayBijection.map(start) + start = self._prepare_start(start) n = self.ddim L_tril = np.eye(n)[np.tril_indices(n)].astype(aesara.config.floatX) if self.batched: @@ -153,6 +138,8 @@ def L(self): else: L = at.zeros((self.ddim, self.ddim)) L = at.set_subtensor(L[self.tril_indices], self.params_dict["L_tril"]) + Ld = L[..., np.arange(self.ddim), np.arange(self.ddim)] + L = at.set_subtensor(Ld, rho2sigma(Ld)) return L @node_property @@ -185,18 +172,12 @@ def tril_indices(self): @node_property def symbolic_logq_not_scaled(self): - z = self.symbolic_random - if self.batched: - - def logq(z_b, mu_b, L_b): - return pm.MvNormal.dist(mu=mu_b, chol=L_b).logp(z_b) - - # it's gonna be so slow - # scan is computed over batch and then summed up - # output shape is (batch, samples) - return aesara.scan(logq, [z.swapaxes(0, 1), self.mean, self.L])[0].sum(0) - else: - return pm.MvNormal.dist(mu=self.mean, chol=self.L).logp(z) + z0 = self.symbolic_initial + diag = at.diagonal(self.L, 0, self.L.ndim - 2, self.L.ndim - 1) + logdet = at.log(diag) + quaddist = -0.5 * z0**2 - at.log((2 * np.pi) ** 0.5) + logq = quaddist - logdet + return logq.sum(range(1, logq.ndim)) @node_property def symbolic_random(self): 
@@ -241,14 +222,7 @@ def create_shared_params(self, trace=None, size=None, jitter=1, start=None): if size is None: raise opvi.ParametrizationError("Need `trace` or `size` to initialize") else: - ipfn = make_initial_point_fn( - model=self.model, - overrides=start, - jitter_rvs={}, - return_transformed=True, - ) - start = ipfn(self.model.rng_seeder.randint(2**30, dtype=np.int64)) - start = pm.floatX(DictToArrayBijection.map(start)) + start = self._prepare_start(start) # Initialize particles histogram = np.tile(start, (size, 1)) histogram += pm.floatX(np.random.normal(0, jitter, histogram.shape)) @@ -258,13 +232,15 @@ def create_shared_params(self, trace=None, size=None, jitter=1, start=None): i = 0 for t in trace.chains: for j in range(len(trace)): - histogram[i] = DictToArrayBijection.map(trace.point(j, t)) + histogram[i] = DictToArrayBijection.map(trace.point(j, t)).data i += 1 return dict(histogram=aesara.shared(pm.floatX(histogram), "histogram")) def _check_trace(self): trace = self._kwargs.get("trace", None) - if trace is not None and not all([var.name in trace.varnames for var in self.group]): + if trace is not None and not all( + [self.model.rvs_to_values[var].name in trace.varnames for var in self.group] + ): raise ValueError("trace has not all free RVs in the group") def randidx(self, size=None): @@ -285,15 +261,21 @@ def randidx(self, size=None): def _new_initial(self, size, deterministic, more_replacements=None): aesara_condition_is_here = isinstance(deterministic, Variable) + if size is None: + size = 1 + size = at.as_tensor(size) if aesara_condition_is_here: return at.switch( deterministic, - at.repeat(self.mean.dimshuffle("x", 0), size if size is not None else 1, -1), + at.repeat(self.mean.reshape((1, -1)), size, -1), self.histogram[self.randidx(size)], ) else: if deterministic: - return at.repeat(self.mean.dimshuffle("x", 0), size if size is not None else 1, -1) + raise NotImplementedInference( + "Deterministic sampling from a Histogram is broken in v4" + ) + return at.repeat(self.mean.reshape((1, -1)), size, -1) else: return self.histogram[self.randidx(size)] @@ -378,6 +360,7 @@ class NormalizingFlowGroup(Group): @aesara.config.change_flags(compute_test_value="off") def __init_group__(self, group): + raise NotImplementedInference("Normalizing flows are not yet ported to v4") super().__init_group__(group) # objects to be resolved # 1. 
string formula @@ -487,7 +470,7 @@ def params(self): @node_property def symbolic_logq_not_scaled(self): z0 = self.symbolic_initial - q0 = pm.Normal.dist().logp(z0).sum(range(1, z0.ndim)) + q0 = pm.Normal.logp(z0, 0, 1).sum(range(1, z0.ndim)) return q0 - self.flow.sum_logdets @property diff --git a/pymc/variational/flows.py b/pymc/variational/flows.py index 6b2ef8ca5bf..a509cce64f2 100644 --- a/pymc/variational/flows.py +++ b/pymc/variational/flows.py @@ -225,12 +225,6 @@ def logdet(self): @aesara.config.change_flags(compute_test_value="off") def forward_pass(self, z0): ret = aesara.clone_replace(self.forward, {self.root.z0: z0}) - try: - ret.tag.test_value = np.random.normal(size=z0.tag.test_value.shape).astype( - self.z0.dtype - ) - except AttributeError: - ret.tag.test_value = self.root.z0.tag.test_value return ret __call__ = forward_pass diff --git a/pymc/variational/inference.py b/pymc/variational/inference.py index 1ab75957fff..4bb29897c25 100644 --- a/pymc/variational/inference.py +++ b/pymc/variational/inference.py @@ -166,12 +166,10 @@ def _iterate_without_loss(self, s, _, step_func, progress, callbacks): if np.isnan(current_param).any(): name_slc = [] tmp_hold = list(range(current_param.size)) - # XXX: This needs to be refactored - vmap = None # self.approx.groups[0].bij.ordering.vmap - for vmap_ in vmap: - slclen = len(tmp_hold[vmap_.slc]) + for varname, slice_info in self.approx.groups[0].ordering.items(): + slclen = len(tmp_hold[slice_info[1]]) for j in range(slclen): - name_slc.append((vmap_.var, j)) + name_slc.append((varname, j)) index = np.where(np.isnan(current_param))[0] errmsg = ["NaN occurred in optimization. "] suggest_solution = ( @@ -210,18 +208,16 @@ def _infmean(input_array): try: for i in progress: e = step_func() - if np.isnan(e): # pragma: no cover + if np.isnan(e): scores = scores[:i] self.hist = np.concatenate([self.hist, scores]) current_param = self.approx.params[0].get_value() name_slc = [] tmp_hold = list(range(current_param.size)) - # XXX: This needs to be refactored - vmap = None # self.approx.groups[0].bij.ordering.vmap - for vmap_ in vmap: - slclen = len(tmp_hold[vmap_.slc]) + for varname, slice_info in self.approx.groups[0].ordering.items(): + slclen = len(tmp_hold[slice_info[1]]) for j in range(slclen): - name_slc.append((vmap_.var, j)) + name_slc.append((varname, j)) index = np.where(np.isnan(current_param))[0] errmsg = ["NaN occurred in optimization. 
"] suggest_solution = ( diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 0503e204a7b..582d4bae19b 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -57,15 +57,12 @@ import pymc as pm -from pymc.aesaraf import at_rng, identity +from pymc.aesaraf import at_rng, compile_pymc, identity, rvs_to_value_vars from pymc.backends import NDArray +from pymc.blocking import DictToArrayBijection +from pymc.initial_point import make_initial_point_fn from pymc.model import modelcontext -from pymc.util import ( - WithMemoization, - get_default_varnames, - get_transformed, - locally_cachedmethod, -) +from pymc.util import WithMemoization, locally_cachedmethod from pymc.variational.updates import adagrad_window from pymc.vartypes import discrete_types @@ -76,6 +73,10 @@ class VariationalInferenceError(Exception): """Exception for VI specific cases""" +class NotImplementedInference(VariationalInferenceError, NotImplementedError): + """Marking non functional parts of code""" + + class ExplicitInferenceError(VariationalInferenceError, TypeError): """Exception for bad explicit inference""" @@ -362,9 +363,9 @@ def step_function( total_grad_norm_constraint=total_grad_norm_constraint, ) if score: - step_fn = aesara.function([], updates.loss, updates=updates, **fn_kwargs) + step_fn = compile_pymc([], updates.loss, updates=updates, **fn_kwargs) else: - step_fn = aesara.function([], None, updates=updates, **fn_kwargs) + step_fn = compile_pymc([], [], updates=updates, **fn_kwargs) return step_fn @aesara.config.change_flags(compute_test_value="off") @@ -393,7 +394,7 @@ def score_function( if more_replacements is None: more_replacements = {} loss = self(sc_n_mc, more_replacements=more_replacements) - return aesara.function([], loss, **fn_kwargs) + return compile_pymc([], loss, **fn_kwargs) @aesara.config.change_flags(compute_test_value="off") def __call__(self, nmc, **kwargs): @@ -833,9 +834,6 @@ def __init__( options=None, **kwargs, ): - # XXX: Needs to be refactored for v4 - raise NotImplementedError("This class needs to be refactored for v4") - if local and not self.supports_batched: raise LocalGroupError("%s does not support local groups" % self.__class__) if local and rowwise: @@ -854,12 +852,30 @@ def __init__( self.group = group self.user_params = params self._user_params = None + self.replacements = collections.OrderedDict() + self.ordering = collections.OrderedDict() # save this stuff to use in __init_group__ later self._kwargs = kwargs if self.group is not None: # init can be delayed self.__init_group__(self.group) + def _prepare_start(self, start=None): + ipfn = make_initial_point_fn( + model=self.model, + overrides=start, + jitter_rvs={}, + return_transformed=True, + ) + start = ipfn(self.model.rng_seeder.randint(2**30, dtype=np.int64)) + group_vars = {self.model.rvs_to_values[v].name for v in self.group} + start = {k: v for k, v in start.items() if k in group_vars} + if self.batched: + start = start[self.group[0].name][0] + else: + start = DictToArrayBijection.map(start).data + return start + @classmethod def get_param_spec_for(cls, **kwargs): res = dict() @@ -940,6 +956,10 @@ def _input_type(self, name): def __init_group__(self, group): if not group: raise GroupError("Got empty group") + if self.local: + raise NotImplementedInference("Local inferene aka AEVB is not supported in v4") + if self.batched: + raise NotImplementedInference("Batched inferene is not supported in v4") if self.group is None: # delayed init self.group = group @@ -956,42 +976,44 @@ def 
__init_group__(self, group): self.input = self._input_type(self.__class__.__name__ + "_symbolic_input") # I do some staff that is not supported by standard __init__ # so I have to to it by myself - self.group = [get_transformed(var) for var in self.group] - # XXX: This needs to be refactored - # self.ordering = ArrayOrdering([]) - self.replacements = dict() + # 1) we need initial point (transformed space) + model_initial_point = self.model.compute_initial_point(0) + # 2) we'll work with a single group, a subset of the model + # here we need to create a mapping to replace value_vars with slices from the approximation + start_idx = 0 for var in self.group: if var.type.numpy_dtype.name in discrete_types: raise ParametrizationError(f"Discrete variables are not supported by VI: {var}") - begin = self.ddim + # 3) This is the way to infer shape and dtype of the variable + value_var = self.model.rvs_to_values[var] + test_var = model_initial_point[value_var.name] if self.batched: + # Leave a more complicated case for future work if var.ndim < 1: if self.local: raise LocalGroupError("Local variable should not be scalar") else: raise BatchedGroupError("Batched variable should not be scalar") - # XXX: This needs to be refactored - # self.ordering.size += None # (np.prod(var.dshape[1:])).astype(int) + size = test_var[0].size if self.local: - # XXX: This needs to be refactored - shape = None # (-1,) + var.dshape[1:] + shape = (-1,) + test_var.shape[1:] else: - # XXX: This needs to be refactored - shape = None # var.dshape + shape = test_var.shape else: - # XXX: This needs to be refactored - # self.ordering.size += None # var.dsize - # XXX: This needs to be refactored - shape = None # var.dshape - # end = self.ordering.size - # XXX: This needs to be refactored - vmap = None # VarMap(var.name, slice(begin, end), shape, var.dtype) - # self.ordering.vmap.append(vmap) - # self.ordering.by_name[vmap.var] = vmap - vr = self.input[..., vmap.slc].reshape(shape).astype(vmap.dtyp) - vr.name = vmap.var + "_vi_replacement" - self.replacements[var] = vr + shape = test_var.shape + size = test_var.size + dtype = test_var.dtype + vr = self.input[..., start_idx : start_idx + size].reshape(shape).astype(dtype) + vr.name = value_var.name + "_vi_replacement" + self.replacements[value_var] = vr + self.ordering[value_var.name] = ( + value_var.name, + slice(start_idx, start_idx + size), + shape, + dtype, + ) + start_idx += size def _finalize_init(self): """*Dev* - clean up after init""" @@ -1043,8 +1065,7 @@ def _new_initial_shape(self, size, dim, more_replacements=None): def bdim(self): if not self.local: if self.batched: - # XXX: This needs to be refactored - return None # self.ordering.vmap[0].shp[0] + return next(iter(self.ordering.values()))[2][0] else: return 1 else: @@ -1052,13 +1073,14 @@ def bdim(self): @node_property def ndim(self): - # XXX: This needs to be refactored - return None # self.ordering.size * self.bdim + if self.batched: + return self.ordering.size * self.bdim + else: + return self.ddim @property def ddim(self): - # XXX: This needs to be refactored - return None # self.ordering.size + return sum(s.stop - s.start for _, s, _, _ in self.ordering.values()) def _new_initial(self, size, deterministic, more_replacements=None): """*Dev* - allocates new initial random generator @@ -1177,7 +1199,6 @@ def symbolic_single_sample(self, node): """ node = self.to_flat_input(node) random = self.symbolic_random.astype(self.symbolic_initial.dtype) - random = at.patternbroadcast(random, 
self.symbolic_initial.broadcastable) return aesara.clone_replace(node, {self.input: random[0]}) def make_size_and_deterministic_replacements(self, s, d, more_replacements=None): @@ -1206,7 +1227,7 @@ def make_size_and_deterministic_replacements(self, s, d, more_replacements=None) @node_property def symbolic_normalizing_constant(self): """*Dev* - normalizing constant for `self.logq`, scales it to `minibatch_size` instead of `total_size`""" - t = self.to_flat_input(at.max([v.scaling for v in self.group])) + t = self.to_flat_input(at.max([v.tag.scaling for v in self.group])) t = self.symbolic_single_sample(t) return pm.floatX(t) @@ -1222,7 +1243,7 @@ def symbolic_logq_not_scaled(self): def symbolic_logq(self): """*Dev* - correctly scaled `self.symbolic_logq_not_scaled`""" if self.local: - s = self.group[0].scaling + s = self.group[0].tag.scaling s = self.to_flat_input(s) s = self.symbolic_single_sample(s) return self.symbolic_logq_not_scaled * s @@ -1359,7 +1380,7 @@ def symbolic_normalizing_constant(self): """ t = at.max( self.collect("symbolic_normalizing_constant") - + [var.scaling for var in self.model.observed_RVs] + + [var.tag.scaling for var in self.model.observed_RVs] ) t = at.switch(self._scale_cost_to_minibatch, t, at.constant(1, dtype=t.dtype)) return pm.floatX(t) @@ -1516,29 +1537,30 @@ def set_size_and_deterministic(self, node, s, d, more_replacements=None): try_to_set_test_value(_node, node, s) return node - def to_flat_input(self, node): + def to_flat_input(self, node, more_replacements=None): """*Dev* - replace vars with flattened view stored in `self.inputs`""" + more_replacements = more_replacements or {} + node = aesara.clone_replace(node, more_replacements) return aesara.clone_replace(node, self.replacements) - def symbolic_sample_over_posterior(self, node): + def symbolic_sample_over_posterior(self, node, more_replacements=None): """*Dev* - performs sampling of node applying independent samples from posterior each time. Note that it is done symbolically and this node needs :func:`set_size_and_deterministic` call """ - node = self.to_flat_input(node) + node = self.to_flat_input(node, more_replacements=more_replacements) - def sample(*post, node, inputs): - node, inputs = post[-2:] - return aesara.clone_replace(node, dict(zip(inputs, post))) + def sample(*post): + return aesara.clone_replace(node, dict(zip(self.inputs, post))) - nodes, _ = aesara.scan(sample, self.symbolic_randoms, non_sequences=[node, inputs]) + nodes, _ = aesara.scan(sample, self.symbolic_randoms) return nodes - def symbolic_single_sample(self, node): + def symbolic_single_sample(self, node, more_replacements=None): """*Dev* - performs sampling of node applying single sample from posterior. 
Note that it is done symbolically and this node needs :func:`set_size_and_deterministic` call with `size=1` """ - node = self.to_flat_input(node) + node = self.to_flat_input(node, more_replacements=more_replacements) post = [v[0] for v in self.symbolic_randoms] inp = self.inputs return aesara.clone_replace(node, dict(zip(inp, post))) @@ -1574,12 +1596,18 @@ def sample_node(self, node, size=None, deterministic=False, more_replacements=No sampled node(s) with replacements """ node_in = node - node = aesara.clone_replace(node, more_replacements) + if more_replacements: + node = aesara.clone_replace(node, more_replacements) + if not isinstance(node, (list, tuple)): + node = [node] + node, _ = rvs_to_value_vars(node, apply_transforms=True) + if not isinstance(node_in, (list, tuple)): + node = node[0] if size is None: node_out = self.symbolic_single_sample(node) else: node_out = self.symbolic_sample_over_posterior(node) - node_out = self.set_size_and_deterministic(node_out, size, deterministic, more_replacements) + node_out = self.set_size_and_deterministic(node_out, size, deterministic) try_to_set_test_value(node_in, node_out, size) return node_out @@ -1589,7 +1617,7 @@ def rslice(self, name): """ def vars_names(vs): - return {v.name for v in vs} + return {self.model.rvs_to_values[v].name for v in vs} for vars_, random, ordering in zip( self.collect("group"), self.symbolic_randoms, self.collect("ordering") @@ -1606,42 +1634,40 @@ def vars_names(vs): @node_property def sample_dict_fn(self): s = at.iscalar() - names = [v.name for v in self.model.free_RVs] + names = [self.model.rvs_to_values[v].name for v in self.model.free_RVs] sampled = [self.rslice(name) for name in names] sampled = self.set_size_and_deterministic(sampled, s, 0) - sample_fn = aesara.function([s], sampled) + sample_fn = compile_pymc([s], sampled) def inner(draws=100): _samples = sample_fn(draws) - return {v_.name: s_ for v_, s_ in zip(self.model.free_RVs, _samples)} + return {v_: s_ for v_, s_ in zip(names, _samples)} return inner - def sample(self, draws=500, include_transformed=True): + def sample(self, draws=500, return_inferencedata=True, **kwargs): """Draw samples from variational posterior. Parameters ---------- draws: `int` Number of random samples. - include_transformed: `bool` - If True, transformed variables are also sampled. Default is False. + return_inferencedata: `bool` + Return trace in Arviz format Returns ------- trace: :class:`pymc.backends.base.MultiTrace` Samples drawn from variational posterior. """ - vars_sampled = get_default_varnames( - [v.tag.value_var for v in self.model.unobserved_RVs], - include_transformed=include_transformed, - ) + # TODO: add tests for include_transformed case + kwargs["log_likelihood"] = False + samples = self.sample_dict_fn(draws) # type: dict points = ({name: records[i] for name, records in samples.items()} for i in range(draws)) trace = NDArray( model=self.model, - vars=vars_sampled, test_point={name: records[0] for name, records in samples.items()}, ) try: @@ -1650,7 +1676,12 @@ def sample(self, draws=500, include_transformed=True): trace.record(point) finally: trace.close() - return pm.sampling.MultiTrace([trace]) + + trace = pm.sampling.MultiTrace([trace]) + if not return_inferencedata: + return trace + else: + return pm.to_inference_data(trace, model=self.model, **kwargs) @property def ndim(self): diff --git a/pymc/variational/updates.py b/pymc/variational/updates.py old mode 100755 new mode 100644