diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 18ddc24b..f6a43a42 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -12,8 +12,8 @@ methods in the current release of PyMC experimental. :toctree: generated/ as_model - MarginalModel marginalize + recover_marginals model_builder.ModelBuilder Inference @@ -53,6 +53,7 @@ Utils spline.bspline_interpolation prior.prior_from_idata + model_equivalence.equivalent_models Statespace Models diff --git a/pymc_experimental/__init__.py b/pymc_experimental/__init__.py index c19097e6..30bd3c56 100644 --- a/pymc_experimental/__init__.py +++ b/pymc_experimental/__init__.py @@ -16,7 +16,11 @@ from pymc_experimental import gp, statespace, utils from pymc_experimental.distributions import * from pymc_experimental.inference.fit import fit -from pymc_experimental.model.marginal.marginal_model import MarginalModel, marginalize +from pymc_experimental.model.marginal.marginal_model import ( + MarginalModel, + marginalize, + recover_marginals, +) from pymc_experimental.model.model_api import as_model from pymc_experimental.version import __version__ diff --git a/pymc_experimental/distributions/timeseries.py b/pymc_experimental/distributions/timeseries.py index d4cd9435..62da8229 100644 --- a/pymc_experimental/distributions/timeseries.py +++ b/pymc_experimental/distributions/timeseries.py @@ -214,8 +214,8 @@ def transition(*args): discrete_mc_op = DiscreteMarkovChainRV( inputs=[P_, steps_, init_dist_, state_rng], outputs=[state_next_rng, discrete_mc_], - ndim_supp=1, n_lags=n_lags, + extended_signature="(p,p),(),(p),[rng]->[rng],(t)", ) discrete_mc = discrete_mc_op(P, steps, init_dist, state_rng) diff --git a/pymc_experimental/model/marginal/distributions.py b/pymc_experimental/model/marginal/distributions.py index 661665e9..cb211949 100644 --- a/pymc_experimental/model/marginal/distributions.py +++ b/pymc_experimental/model/marginal/distributions.py @@ -1,20 +1,25 @@ +import warnings + from collections.abc import Sequence import numpy as np import pytensor.tensor as pt from pymc.distributions import Bernoulli, Categorical, DiscreteUniform +from pymc.distributions.distribution import _support_point, support_point from pymc.logprob.abstract import MeasurableOp, _logprob from pymc.logprob.basic import conditional_logp, logp from pymc.pytensorf import constant_fold from pytensor import Variable from pytensor.compile.builders import OpFromGraph from pytensor.compile.mode import Mode -from pytensor.graph import Op, vectorize_graph +from pytensor.graph import FunctionGraph, Op, vectorize_graph +from pytensor.graph.basic import equal_computations from pytensor.graph.replace import clone_replace, graph_replace from pytensor.scan import map as scan_map from pytensor.scan import scan from pytensor.tensor import TensorVariable +from pytensor.tensor.random.type import RandomType from pymc_experimental.distributions import DiscreteMarkovChain @@ -43,6 +48,74 @@ def support_axes(self) -> tuple[tuple[int]]: ) return tuple(support_axes_vars) + def __eq__(self, other): + # Just to allow easy testing of equivalent models, + # This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed + if type(self) is not type(other): + return False + + return equal_computations( + self.inner_outputs, + other.inner_outputs, + self.inner_inputs, + other.inner_inputs, + ) + + def __hash__(self): + # Just to allow easy testing of equivalent models, + # This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed + return 
hash((type(self), len(self.inner_inputs), len(self.inner_outputs))) + + +@_support_point.register +def support_point_marginal_rv(op: MarginalRV, rv, *inputs): + """Support point for a marginalized RV. + + The support point of a marginalized RV is the support point of the inner RV, + conditioned on the marginalized RV taking its support point. + """ + outputs = rv.owner.outputs + + inner_rv = op.inner_outputs[outputs.index(rv)] + marginalized_inner_rv, *other_dependent_inner_rvs = ( + out + for out in op.inner_outputs + if out is not inner_rv and not isinstance(out.type, RandomType) + ) + + # Replace references to inner rvs by the dummy variables (including the marginalized RV) + # This is necessary because the inner RVs may depend on each other + marginalized_inner_rv_dummy = marginalized_inner_rv.clone() + other_dependent_inner_rv_to_dummies = { + inner_rv: inner_rv.clone() for inner_rv in other_dependent_inner_rvs + } + inner_rv = clone_replace( + inner_rv, + replace={marginalized_inner_rv: marginalized_inner_rv_dummy} + | other_dependent_inner_rv_to_dummies, + ) + + # Get support point of inner RV and marginalized RV + inner_rv_support_point = support_point(inner_rv) + marginalized_inner_rv_support_point = support_point(marginalized_inner_rv) + + replacements = [ + # Replace the marginalized RV dummy by its support point + (marginalized_inner_rv_dummy, marginalized_inner_rv_support_point), + # Replace other dependent RVs dummies by the respective outer outputs. + # PyMC will replace them by their support points later + *( + (v, outputs[op.inner_outputs.index(k)]) + for k, v in other_dependent_inner_rv_to_dummies.items() + ), + # Replace outer input RVs + *zip(op.inner_inputs, inputs), + ] + fgraph = FunctionGraph(outputs=[inner_rv_support_point], clone=False) + fgraph.replace_all(replacements, import_missing=True) + [rv_support_point] = fgraph.outputs + return rv_support_point + class MarginalFiniteDiscreteRV(MarginalRV): """Base class for Marginalized Finite Discrete RVs""" @@ -132,12 +205,22 @@ def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Var Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps" the inner graph. """ - return clone_replace( + return graph_replace( op.inner_outputs, replace=tuple(zip(op.inner_inputs, inputs)), ) +def warn_logp_non_separable(values): + if len(values) > 1: + warnings.warn( + "There are multiple dependent variables in a FiniteDiscreteMarginalRV. " + f"Their joint logp terms will be assigned to the first value: {values[0]}.", + UserWarning, + stacklevel=2, + ) + + DUMMY_ZERO = pt.constant(0, name="dummy_zero") @@ -200,6 +283,7 @@ def logp_fn(marginalized_rv_const, *non_sequences): joint_logp = align_logp_dims(dims=op.dims_connections[0], logp=joint_logp) # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise + warn_logp_non_separable(values) dummy_logps = (DUMMY_ZERO,) * (len(values) - 1) return joint_logp, *dummy_logps @@ -272,5 +356,6 @@ def step_alpha(logp_emission, log_alpha, log_P): # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first # return is the joint probability of everything together, but PyMC still expects one logp for each emission stream. 
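+    # warn_logp_non_separable emits a UserWarning when there is more than one
+    # value variable: the joint logp below cannot be split per emission stream,
+    # so it is assigned entirely to the first value.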
+    warn_logp_non_separable(values)
     dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
     return joint_logp, *dummy_logps
diff --git a/pymc_experimental/model/marginal/graph_analysis.py b/pymc_experimental/model/marginal/graph_analysis.py
index 62ac2abb..87b242e4 100644
--- a/pymc_experimental/model/marginal/graph_analysis.py
+++ b/pymc_experimental/model/marginal/graph_analysis.py
@@ -4,6 +4,7 @@
 from itertools import zip_longest

 from pymc import SymbolicRandomVariable
+from pymc.model.fgraph import ModelVar
 from pytensor.compile import SharedVariable
 from pytensor.graph import Constant, Variable, ancestors
 from pytensor.graph.basic import io_toposort
@@ -35,12 +36,12 @@ def static_shape_ancestors(vars):

 def find_conditional_input_rvs(output_rvs, all_rvs):
     """Find conditionally independent input RVs."""
-    blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
-    blockers += static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
+    other_rvs = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
+    blockers = other_rvs + static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
     return [
         var
         for var in ancestors(output_rvs, blockers=blockers)
-        if var in blockers or (var.owner is None and not isinstance(var, Constant | SharedVariable))
+        if var in other_rvs
     ]
@@ -141,6 +142,9 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
             # None of the inputs are related to the batch_axes of the input_vars
             continue

+        elif isinstance(node.op, ModelVar):
+            var_dims[node.outputs[0]] = inputs_dims[0]
+
         elif isinstance(node.op, DimShuffle):
             [input_dims] = inputs_dims
             output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order)
diff --git a/pymc_experimental/model/marginal/marginal_model.py b/pymc_experimental/model/marginal/marginal_model.py
index b4700c3d..b0e44b2d 100644
--- a/pymc_experimental/model/marginal/marginal_model.py
+++ b/pymc_experimental/model/marginal/marginal_model.py
@@ -1,7 +1,6 @@
 import warnings

 from collections.abc import Sequence
-from typing import Union

 import numpy as np
 import pymc
@@ -13,21 +12,41 @@
 from pymc.distributions.transforms import Chain
 from pymc.logprob.transforms import IntervalTransform
 from pymc.model import Model
-from pymc.pytensorf import compile_pymc, constant_fold
-from pymc.util import RandomState, _get_seeds_per_chain, treedict
+from pymc.model.fgraph import (
+    ModelFreeRV,
+    ModelValuedVar,
+    fgraph_from_model,
+    model_free_rv,
+    model_from_fgraph,
+)
+from pymc.pytensorf import collect_default_updates, compile_pymc, constant_fold, toposort_replace
+from pymc.util import RandomState, _get_seeds_per_chain
+from pytensor import In, Out
 from pytensor.compile import SharedVariable
-from pytensor.graph import FunctionGraph, clone_replace, graph_inputs
-from pytensor.graph.replace import vectorize_graph
+from pytensor.graph import (
+    FunctionGraph,
+    Variable,
+    clone_replace,
+    graph_inputs,
+    graph_replace,
+    node_rewriter,
+    vectorize_graph,
+)
+from pytensor.graph.rewriting.basic import in2out
 from pytensor.tensor import TensorVariable
-from pytensor.tensor.special import log_softmax

 __all__ = ["MarginalModel", "marginalize"]

+from pytensor.tensor.random.type import RandomType
+from pytensor.tensor.special import log_softmax
+
 from pymc_experimental.distributions import DiscreteMarkovChain
 from pymc_experimental.model.marginal.distributions import (
     MarginalDiscreteMarkovChainRV,
     MarginalFiniteDiscreteRV,
+    MarginalRV,
     get_domain_of_finite_discrete_rv,
+    inline_ofg_outputs,
reduce_batch_dependent_logps, ) from pymc_experimental.model.marginal.graph_analysis import ( @@ -87,479 +106,442 @@ class MarginalModel(Model): """ def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.marginalized_rvs = [] - self._marginalized_named_vars_to_dims = {} + raise TypeError( + "MarginalModel was deprecated in favor of `marginalize` which now returns a PyMC model" + ) - def _delete_rv_mappings(self, rv: TensorVariable) -> None: - """Remove all model mappings referring to rv - This can be used to "delete" an RV from a model - """ - assert rv in self.basic_RVs, "rv is not part of the Model" +def _warn_interval_transform(rv_to_marginalize, replaced_vars: Sequence[ModelValuedVar]) -> None: + for replaced_var in replaced_vars: + if not isinstance(replaced_var.owner.op, ModelValuedVar): + raise TypeError(f"{replaced_var} is not a ModelValuedVar") - name = rv.name - self.named_vars.pop(name) - if name in self.named_vars_to_dims: - self.named_vars_to_dims.pop(name) + if not isinstance(replaced_var.owner.op, ModelFreeRV): + continue - value = self.rvs_to_values.pop(rv) - self.values_to_rvs.pop(value) + if replaced_var is rv_to_marginalize: + continue - self.rvs_to_transforms.pop(rv) - if rv in self.free_RVs: - self.free_RVs.remove(rv) - self.rvs_to_initial_values.pop(rv) - else: - self.observed_RVs.remove(rv) - - def _transfer_rv_mappings(self, old_rv: TensorVariable, new_rv: TensorVariable) -> None: - """Transfer model mappings from old_rv to new_rv""" - - assert old_rv in self.basic_RVs, "old_rv is not part of the Model" - assert new_rv not in self.basic_RVs, "new_rv is already part of the Model" - - self.named_vars.pop(old_rv.name) - new_rv.name = old_rv.name - self.named_vars[new_rv.name] = new_rv - if old_rv in self.named_vars_to_dims: - self._RV_dims[new_rv] = self._RV_dims.pop(old_rv) - - value = self.rvs_to_values.pop(old_rv) - self.rvs_to_values[new_rv] = value - self.values_to_rvs[value] = new_rv - - self.rvs_to_transforms[new_rv] = self.rvs_to_transforms.pop(old_rv) - if old_rv in self.free_RVs: - index = self.free_RVs.index(old_rv) - self.free_RVs.pop(index) - self.free_RVs.insert(index, new_rv) - self.rvs_to_initial_values[new_rv] = self.rvs_to_initial_values.pop(old_rv) - elif old_rv in self.observed_RVs: - index = self.observed_RVs.index(old_rv) - self.observed_RVs.pop(index) - self.observed_RVs.insert(index, new_rv) - - def _marginalize(self, user_warnings=False): - fg = FunctionGraph(outputs=self.basic_RVs + self.marginalized_rvs, clone=False) - - toposort = fg.toposort() - rvs_left_to_marginalize = self.marginalized_rvs - for rv_to_marginalize in sorted( - self.marginalized_rvs, - key=lambda rv: toposort.index(rv.owner), - reverse=True, + transform = replaced_var.owner.op.transform + + if isinstance(transform, IntervalTransform) or ( + isinstance(transform, Chain) + and any(isinstance(tr, IntervalTransform) for tr in transform.transform_list) ): - # Check that no deterministics or potentials dependend on the rv to marginalize - for det in self.deterministics: - if is_conditional_dependent( - det, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize - ): - raise NotImplementedError( - f"Cannot marginalize {rv_to_marginalize} due to dependent Deterministic {det}" - ) - for pot in self.potentials: - if is_conditional_dependent( - pot, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize - ): - raise NotImplementedError( - f"Cannot marginalize {rv_to_marginalize} due to dependent Potential {pot}" - ) - - old_rvs, new_rvs = 
replace_finite_discrete_marginal_subgraph( - fg, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize + warnings.warn( + f"The transform {transform} for the variable {replaced_var}, which depends on the " + f"marginalized {rv_to_marginalize} may no longer work if bounds depended on other variables.", + UserWarning, ) - if user_warnings and len(new_rvs) > 2: - warnings.warn( - "There are multiple dependent variables in a FiniteDiscreteMarginalRV. " - f"Their joint logp terms will be assigned to the first RV: {old_rvs[1]}", - UserWarning, - ) - rvs_left_to_marginalize.remove(rv_to_marginalize) - - for old_rv, new_rv in zip(old_rvs, new_rvs): - new_rv.name = old_rv.name - if old_rv in self.marginalized_rvs: - idx = self.marginalized_rvs.index(old_rv) - self.marginalized_rvs.pop(idx) - self.marginalized_rvs.insert(idx, new_rv) - if old_rv in self.basic_RVs: - self._transfer_rv_mappings(old_rv, new_rv) - if user_warnings: - # Interval transforms for dependent variable won't work for non-constant bounds because - # the RV inputs are now different and may depend on another RV that also depends on the - # same marginalized RV - transform = self.rvs_to_transforms[new_rv] - if isinstance(transform, IntervalTransform) or ( - isinstance(transform, Chain) - and any( - isinstance(tr, IntervalTransform) for tr in transform.transform_list - ) - ): - warnings.warn( - f"The transform {transform} for the variable {old_rv}, which depends on the " - f"marginalized {rv_to_marginalize} may no longer work if bounds depended on other variables.", - UserWarning, - ) - return self - - def _logp(self, *args, **kwargs): - return super().logp(*args, **kwargs) - - def logp(self, vars=None, **kwargs): - m = self.clone()._marginalize() - if vars is not None: - if not isinstance(vars, Sequence): - vars = (vars,) - vars = [m[var.name] for var in vars] - return m._logp(vars=vars, **kwargs) - - @staticmethod - def from_model(model: Union[Model, "MarginalModel"]) -> "MarginalModel": - new_model = MarginalModel(coords=model.coords) - if isinstance(model, MarginalModel): - marginalized_rvs = model.marginalized_rvs - marginalized_named_vars_to_dims = model._marginalized_named_vars_to_dims - else: - marginalized_rvs = [] - marginalized_named_vars_to_dims = {} - - model_vars = model.basic_RVs + model.potentials + model.deterministics + marginalized_rvs - data_vars = [var for name, var in model.named_vars.items() if var not in model_vars] - vars = model_vars + data_vars - cloned_vars = clone_replace(vars) - vars_to_clone = {var: cloned_var for var, cloned_var in zip(vars, cloned_vars)} - new_model.vars_to_clone = vars_to_clone - - new_model.named_vars = treedict( - {name: vars_to_clone[var] for name, var in model.named_vars.items()} - ) - new_model.named_vars_to_dims = model.named_vars_to_dims - new_model.values_to_rvs = {vv: vars_to_clone[rv] for vv, rv in model.values_to_rvs.items()} - new_model.rvs_to_values = {vars_to_clone[rv]: vv for rv, vv in model.rvs_to_values.items()} - new_model.rvs_to_transforms = { - vars_to_clone[rv]: tr for rv, tr in model.rvs_to_transforms.items() - } - new_model.rvs_to_initial_values = { - vars_to_clone[rv]: iv for rv, iv in model.rvs_to_initial_values.items() - } - new_model.free_RVs = [vars_to_clone[rv] for rv in model.free_RVs] - new_model.observed_RVs = [vars_to_clone[rv] for rv in model.observed_RVs] - new_model.potentials = [vars_to_clone[pot] for pot in model.potentials] - new_model.deterministics = [vars_to_clone[det] for det in model.deterministics] - - new_model.marginalized_rvs = 
[vars_to_clone[rv] for rv in marginalized_rvs]
-        new_model._marginalized_named_vars_to_dims = marginalized_named_vars_to_dims
-        return new_model
-
-    def clone(self):
-        return self.from_model(self)
-
-    def marginalize(
-        self,
-        rvs_to_marginalize: ModelRVs,
-    ):
-        if not isinstance(rvs_to_marginalize, Sequence):
-            rvs_to_marginalize = (rvs_to_marginalize,)
-        rvs_to_marginalize = [
-            self[var] if isinstance(var, str) else var for var in rvs_to_marginalize
-        ]
+def _unique(seq: Sequence) -> list:
+    """Copied from https://stackoverflow.com/a/480227"""
+    seen = set()
+    seen_add = seen.add
+    return [x for x in seq if not (x in seen or seen_add(x))]

-        for rv_to_marginalize in rvs_to_marginalize:
-            if rv_to_marginalize not in self.free_RVs:
-                raise ValueError(
-                    f"Marginalized RV {rv_to_marginalize} is not a free RV in the model"
-                )

-            rv_op = rv_to_marginalize.owner.op
-            if isinstance(rv_op, DiscreteMarkovChain):
-                if rv_op.n_lags > 1:
-                    raise NotImplementedError(
-                        "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
-                    )
-                if rv_to_marginalize.owner.inputs[0].type.ndim > 2:
-                    raise NotImplementedError(
-                        "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
-                    )
-            elif not isinstance(rv_op, Bernoulli | Categorical | DiscreteUniform):
-                raise NotImplementedError(
-                    f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
-                )
-
-            if rv_to_marginalize.name in self.named_vars_to_dims:
-                dims = self.named_vars_to_dims[rv_to_marginalize.name]
-                self._marginalized_named_vars_to_dims[rv_to_marginalize.name] = dims
-
-            self._delete_rv_mappings(rv_to_marginalize)
-            self.marginalized_rvs.append(rv_to_marginalize)
-
-        # Raise errors and warnings immediately
-        self.clone()._marginalize(user_warnings=True)
-
-    def _to_transformed(self):
-        """Create a function from the untransformed space to the transformed space"""
-        transformed_rvs = []
-        transformed_names = []
-
-        for rv in self.free_RVs:
-            transform = self.rvs_to_transforms.get(rv)
-            if transform is None:
-                transformed_rvs.append(rv)
-                transformed_names.append(rv.name)
-            else:
-                transformed_rv = transform.forward(rv, *rv.owner.inputs)
-                transformed_rvs.append(transformed_rv)
-                transformed_names.append(self.rvs_to_values[rv].name)
-
-        fn = self.compile_fn(inputs=self.free_RVs, outs=transformed_rvs)
-        return fn, transformed_names
-
-    def unmarginalize(self, rvs_to_unmarginalize: Sequence[TensorVariable | str]):
-        for rv in rvs_to_unmarginalize:
-            if isinstance(rv, str):
-                rv = self[rv]
-            self.marginalized_rvs.remove(rv)
-            if rv.name in self._marginalized_named_vars_to_dims:
-                dims = self._marginalized_named_vars_to_dims.pop(rv.name)
-            else:
-                dims = None
-            self.register_rv(rv, name=rv.name, dims=dims)
-
-    def recover_marginals(
-        self,
-        idata: InferenceData,
-        var_names: Sequence[str] | None = None,
-        return_samples: bool = True,
-        extend_inferencedata: bool = True,
-        random_seed: RandomState = None,
-    ):
-        """Computes posterior log-probabilities and samples of marginalized variables
-        conditioned on parameters of the model given InferenceData with posterior group
+def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> Model:
+    """Marginalize a subset of variables in a PyMC model.

-        When there are multiple marginalized variables, each marginalized variable is
-        conditioned on both the parameters and the other variables still marginalized
+    This creates a new PyMC `Model` from an existing `Model`, with the specified
+    variables marginalized out.

-        All log-probabilities are within the transformed space
+    See `recover_marginals` for computing posterior log-probabilities and draws
+    of the marginalized variables after sampling.

-        Parameters
-        ----------
-        idata : InferenceData
-            InferenceData with posterior group
-        var_names : sequence of str, optional
-            List of variable names for which to compute posterior log-probabilities and samples. Defaults to all marginalized variables
-        return_samples : bool, default True
-            If True, also return samples of the marginalized variables
-        extend_inferencedata : bool, default True
-            Whether to extend the original InferenceData or return a new one
-        random_seed: int, array-like of int or SeedSequence, optional
-            Seed used to generating samples
+    Parameters
+    ----------
+    model : Model
+        PyMC model to marginalize. Original variables will be cloned.
+    rvs_to_marginalize : Sequence[TensorVariable]
+        Variables to marginalize in the returned model.

-        Returns
-        -------
-        idata : InferenceData
-            InferenceData with where a lp_{varname} and {varname} for each marginalized variable in var_names added to the posterior group
+    Returns
+    -------
+    marginal_model : Model
+        New model with the specified variables marginalized.
+    """
+    if isinstance(rvs_to_marginalize, str | Variable):
+        rvs_to_marginalize = (rvs_to_marginalize,)

-        .. code-block:: python
+    rvs_to_marginalize = [model[rv] if isinstance(rv, str) else rv for rv in rvs_to_marginalize]

-            import pymc as pm
-            from pymc_experimental import MarginalModel
+    if not rvs_to_marginalize:
+        return model

-            with MarginalModel() as m:
-                p = pm.Beta("p", 1, 1)
-                x = pm.Bernoulli("x", p=p, shape=(3,))
-                y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])
+    for rv_to_marginalize in rvs_to_marginalize:
+        if rv_to_marginalize not in model.free_RVs:
+            raise ValueError(f"Marginalized RV {rv_to_marginalize} is not a free RV in the model")

-                m.marginalize([x])
+        rv_op = rv_to_marginalize.owner.op
+        if isinstance(rv_op, DiscreteMarkovChain):
+            if rv_op.n_lags > 1:
+                raise NotImplementedError(
+                    "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
+                )
+            if rv_to_marginalize.owner.inputs[0].type.ndim > 2:
+                raise NotImplementedError(
+                    "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
+                )
+        elif not isinstance(rv_op, Bernoulli | Categorical | DiscreteUniform):
+            raise NotImplementedError(
+                f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
+            )

-            idata = pm.sample()
-            m.recover_marginals(idata, var_names=["x"])
+    fg, memo = fgraph_from_model(model)
+    rvs_to_marginalize = [memo[rv] for rv in rvs_to_marginalize]
+    toposort = fg.toposort()
+    for rv_to_marginalize in sorted(
+        rvs_to_marginalize,
+        key=lambda rv: toposort.index(rv.owner),
+        reverse=True,
+    ):
+        all_rvs = [node.out for node in fg.toposort() if isinstance(node.op, ModelValuedVar)]

-        """
-        if var_names is None:
-            var_names = [var.name for var in self.marginalized_rvs]
+        dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs)
+        if not dependent_rvs:
+            # TODO: This should at most be a warning, not an error
+            raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")

-        var_names = [var if isinstance(var, str) else var.name for var in var_names]
-        
vars_to_recover = [v for v in self.marginalized_rvs if v.name in var_names]
-        missing_names = [v.name for v in vars_to_recover if v not in self.marginalized_rvs]
-        if missing_names:
-            raise ValueError(f"Unrecognized var_names: {missing_names}")
+        # Issue warning for IntervalTransform on dependent RVs
+        for dependent_rv in dependent_rvs:
+            transform = dependent_rv.owner.op.transform

-        if return_samples and random_seed is not None:
-            seeds = _get_seeds_per_chain(random_seed, len(vars_to_recover))
-        else:
-            seeds = [None] * len(vars_to_recover)
+            if isinstance(transform, IntervalTransform) or (
+                isinstance(transform, Chain)
+                and any(isinstance(tr, IntervalTransform) for tr in transform.transform_list)
+            ):
+                warnings.warn(
+                    f"The transform {transform} for the variable {dependent_rv}, which depends on the "
+                    f"marginalized {rv_to_marginalize} may no longer work if bounds depended on other variables.",
+                    UserWarning,
+                )

-        posterior = idata.posterior
+        # Check that no deterministics or potentials depend on the rv to marginalize
+        for det in model.deterministics:
+            if is_conditional_dependent(memo[det], rv_to_marginalize, all_rvs):
+                raise NotImplementedError(
+                    f"Cannot marginalize {rv_to_marginalize} due to dependent Deterministic {det}"
+                )
+        for pot in model.potentials:
+            if is_conditional_dependent(memo[pot], rv_to_marginalize, all_rvs):
+                raise NotImplementedError(
+                    f"Cannot marginalize {rv_to_marginalize} due to dependent Potential {pot}"
+                )

-        # Remove Deterministics
-        posterior_values = posterior[
-            [rv.name for rv in self.free_RVs if rv not in self.marginalized_rvs]
+        marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
+        other_direct_rv_ancestors = [
+            rv
+            for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)
+            if rv is not rv_to_marginalize
         ]
+        input_rvs = _unique((*marginalized_rv_input_rvs, *other_direct_rv_ancestors))

-        sample_dims = ("chain", "draw")
-        posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims)
+        replace_finite_discrete_marginal_subgraph(fg, rv_to_marginalize, dependent_rvs, input_rvs)

-        # Handle Transforms
-        transform_fn, transform_names = self._to_transformed()
+    return model_from_fgraph(fg, mutate_fgraph=True)

-        def transform_input(inputs):
-            return dict(zip(transform_names, transform_fn(inputs)))

-        posterior_pts = [transform_input(vs) for vs in posterior_pts]
+@node_rewriter(tracks=[MarginalRV])
+def local_unmarginalize(fgraph, node):
+    unmarginalized_rv, *dependent_rvs_and_rngs = inline_ofg_outputs(node.op, node.inputs)
+    rngs = [rng for rng in dependent_rvs_and_rngs if isinstance(rng.type, RandomType)]
+    dependent_rvs = [rv for rv in dependent_rvs_and_rngs if rv not in rngs]

-        rv_dict = {}
-        rv_dims = {}
-        for seed, marginalized_rv in zip(seeds, vars_to_recover):
-            supported_dists = (Bernoulli, Categorical, DiscreteUniform)
-            if not isinstance(marginalized_rv.owner.op, supported_dists):
-                raise NotImplementedError(
-                    f"RV with distribution {marginalized_rv.owner.op} cannot be recovered. "
-                    f"Supported distribution include {supported_dists}"
-                )
+    # Wrap the marginalized RV in a FreeRV
+    # TODO: Preserve dims and transform in MarginalRV
+    value = unmarginalized_rv.clone()
+    fgraph.add_input(value)
+    unmarginalized_free_rv = model_free_rv(unmarginalized_rv, value, transform=None)

-            m = self.clone()
-            marginalized_rv = m.vars_to_clone[marginalized_rv]
-            m.unmarginalize([marginalized_rv])
-            dependent_rvs = find_conditional_dependent_rvs(marginalized_rv, m.basic_RVs)
-            logps = m.logp(vars=[marginalized_rv, *dependent_rvs], sum=False)
+    # Replace references to the marginalized RV with the FreeRV in the dependent RVs
+    dependent_rvs = graph_replace(dependent_rvs, {unmarginalized_rv: unmarginalized_free_rv})

-            # Handle batch dims for marginalized value and its dependent RVs
-            dependent_rvs_dim_connections = subgraph_batch_dim_connection(
-                marginalized_rv, dependent_rvs
-            )
-            marginalized_logp, *dependent_logps = logps
-            joint_logp = marginalized_logp + reduce_batch_dependent_logps(
-                dependent_rvs_dim_connections,
-                [dependent_var.owner.op for dependent_var in dependent_rvs],
-                dependent_logps,
-            )
+    return [unmarginalized_free_rv, *dependent_rvs, *rngs]

-            marginalized_value = m.rvs_to_values[marginalized_rv]
-            other_values = [v for v in m.value_vars if v is not marginalized_value]
-
-            rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False)
-            rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
-            rv_domain_tensor = pt.moveaxis(
-                pt.full(
-                    (*rv_shape, len(rv_domain)),
-                    rv_domain,
-                    dtype=marginalized_rv.dtype,
-                ),
-                -1,
-                0,
-            )
+unmarginalize_rewrite = in2out(local_unmarginalize, ignore_newtrees=False)

-            batched_joint_logp = vectorize_graph(
-                joint_logp,
-                replace={marginalized_value: rv_domain_tensor},
-            )
-            batched_joint_logp = pt.moveaxis(batched_joint_logp, 0, -1)
-
-            joint_logp_norm = log_softmax(batched_joint_logp, axis=-1)
-            if return_samples:
-                rv_draws = pymc.Categorical.dist(logit_p=batched_joint_logp)
-                if isinstance(marginalized_rv.owner.op, DiscreteUniform):
-                    rv_draws += rv_domain[0]
-                outputs = [joint_logp_norm, rv_draws]
-            else:
-                outputs = joint_logp_norm
-
-            rv_loglike_fn = compile_pymc(
-                inputs=other_values,
-                outputs=outputs,
-                on_unused_input="ignore",
-                random_seed=seed,
-            )
+def unmarginalize(model: Model, rvs_to_unmarginalize: str | Sequence[str] | None = None) -> Model:
+    """Unmarginalize a subset of variables in a PyMC model.

-            logvs = [rv_loglike_fn(**vs) for vs in posterior_pts]

-            if return_samples:
-                logps, samples = zip(*logvs)
-                logps = np.array(logps)
-                samples = np.array(samples)
-                rv_dict[marginalized_rv.name] = samples.reshape(
-                    tuple(len(coord) for coord in stacked_dims.values()) + samples.shape[1:],
-                )
-            else:
-                logps = np.array(logvs)
+    Parameters
+    ----------
+    model : Model
+        PyMC model to unmarginalize. Original variables will be cloned.
+    rvs_to_unmarginalize : str or sequence of str, optional
+        Variables to unmarginalize in the returned model. If None, all variables are
+        unmarginalized.
+
+    Returns
+    -------
+    unmarginal_model: Model
+        Model with the specified variables unmarginalized.
+    """
+
+    # Unmarginalize all the MarginalRVs
+    fg, memo = fgraph_from_model(model)
+    unmarginalize_rewrite(fg)
+    unmarginalized_model = model_from_fgraph(fg, mutate_fgraph=True)
+    if rvs_to_unmarginalize is None:
+        return unmarginalized_model
+
+    # Re-marginalize the variables we want to keep marginalized
+    if not isinstance(rvs_to_unmarginalize, list | tuple):
+        rvs_to_unmarginalize = (rvs_to_unmarginalize,)
+    rvs_to_unmarginalize = set(rvs_to_unmarginalize)
+
+    old_free_rv_names = set(rv.name for rv in model.free_RVs)
+    new_free_rv_names = set(
+        rv.name for rv in unmarginalized_model.free_RVs if rv.name not in old_free_rv_names
+    )
+    if rvs_to_unmarginalize - new_free_rv_names:
+        raise ValueError(
+            f"Unrecognized rvs_to_unmarginalize: {rvs_to_unmarginalize - new_free_rv_names}"
+        )
+    rvs_to_keep_marginalized = tuple(new_free_rv_names - rvs_to_unmarginalize)
+    return marginalize(unmarginalized_model, rvs_to_keep_marginalized)
+
+
+def transform_posterior_pts(model, posterior_pts):
+    """Transform a list of posterior points from the untransformed space to the transformed space"""
+    # TODO: This should be a utility in PyMC
+    transformed_rvs = []
+    transformed_names = []

    for rv in model.free_RVs:
        transform = model.rvs_to_transforms.get(rv)
        if transform is None:
            transformed_rvs.append(rv)
            transformed_names.append(rv.name)
        else:
            transformed_rv = transform.forward(rv, *rv.owner.inputs)
            transformed_rvs.append(transformed_rv)
            transformed_names.append(model.rvs_to_values[rv].name)
+    fn = compile_pymc(
+        inputs=[In(inp, borrow=True) for inp in model.free_RVs],
+        outputs=[Out(out, borrow=True) for out in transformed_rvs],
+    )
+    fn.trust_input = True

-def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
+    # TODO: This should work with vectorized inputs
+    return [dict(zip(transformed_names, fn(**point))) for point in posterior_pts]

-    This creates a class of `MarginalModel` from an existing `Model`, with the specified
-    variables marginalized.

-    See documentation for `MarginalModel` for more information.
+def recover_marginals(
+    model: Model,
+    idata: InferenceData,
+    var_names: Sequence[str] | None = None,
+    return_samples: bool = True,
+    extend_inferencedata: bool = True,
+    random_seed: RandomState = None,
+):
+    """Computes posterior log-probabilities and samples of marginalized variables,
+    conditioned on parameters of the model, given an InferenceData with a posterior group.
+
+    When there are multiple marginalized variables, each marginalized variable is
+    conditioned on both the parameters and the other variables still marginalized.
+
+    All log-probabilities are within the transformed space.

     Parameters
     ----------
-    model : Model
-        PyMC model to marginalize. Original variables well be cloned.
-    rvs_to_marginalize : Sequence[TensorVariable]
-        Variables to marginalize in the returned model.
+    model : Model
+        PyMC model with marginalized variables to recover
+    idata : InferenceData
+        InferenceData with posterior group
+    var_names : sequence of str, optional
+        List of variable names for which to compute posterior log-probabilities and samples. Defaults to all marginalized variables
+    return_samples : bool, default True
+        If True, also return samples of the marginalized variables
+    extend_inferencedata : bool, default True
+        Whether to extend the original InferenceData or return a new one
+    random_seed : int, array-like of int or SeedSequence, optional
+        Seed used to generate samples

     Returns
     -------
-    marginal_model: MarginalModel
-        Marginal model with the specified variables marginalized.
+    idata : InferenceData
+        InferenceData where a lp_{varname} and {varname} entry for each marginalized
+        variable in var_names is added to the posterior group
+
+    .. code-block:: python
+
+        import pymc as pm
+        from pymc_experimental import marginalize, recover_marginals
+
+        with pm.Model() as m:
+            p = pm.Beta("p", 1, 1)
+            x = pm.Bernoulli("x", p=p, shape=(3,))
+            y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])
+
+        marginal_m = marginalize(m, [x])
+
+        with marginal_m:
+            idata = pm.sample()
+
+        recover_marginals(marginal_m, idata, var_names=["x"])
+
+
+    """
-    if not isinstance(rvs_to_marginalize, tuple | list):
-        rvs_to_marginalize = (rvs_to_marginalize,)
-    rvs_to_marginalize = [rv if isinstance(rv, str) else rv.name for rv in rvs_to_marginalize]
+    unmarginal_model = unmarginalize(model)
+
+    # Find the names of the marginalized variables
+    model_var_names = set(rv.name for rv in model.free_RVs)
+    marginalized_rv_names = [
+        rv.name for rv in unmarginal_model.free_RVs if rv.name not in model_var_names
+    ]
+
+    if var_names is None:
+        var_names = marginalized_rv_names
+
+    var_names = [var if isinstance(var, str) else var.name for var in var_names]
+    var_names_to_recover = [name for name in marginalized_rv_names if name in var_names]
+    missing_names = [name for name in var_names if name not in marginalized_rv_names]
+    if missing_names:
+        raise ValueError(f"Unrecognized var_names: {missing_names}")

-    marginal_model = MarginalModel.from_model(model)
-    marginal_model.marginalize(rvs_to_marginalize)
-    return marginal_model
+    if return_samples and random_seed is not None:
+        seeds = _get_seeds_per_chain(random_seed, len(var_names_to_recover))
+    else:
+        seeds = [None] * len(var_names_to_recover)
+
+    posterior_pts, stacked_dims = dataset_to_point_list(
+        # Remove Deterministics
+        idata.posterior[[rv.name for rv in model.free_RVs]],
+        sample_dims=("chain", "draw"),
+    )
+    transformed_posterior_pts = transform_posterior_pts(model, posterior_pts)
+
+    rv_dict = {}
+    rv_dims = {}
+    for seed, var_name_to_recover in zip(seeds, var_names_to_recover):
+        var_to_recover = unmarginal_model[var_name_to_recover]
+        supported_dists = (Bernoulli, Categorical, DiscreteUniform)
+        if not isinstance(var_to_recover.owner.op, supported_dists):
+            raise NotImplementedError(
+                f"RV with distribution {var_to_recover.owner.op} cannot be recovered. "
+                f"Supported distributions include {supported_dists}"
+            )
+
+        other_marginalized_rvs_names = marginalized_rv_names.copy()
+        other_marginalized_rvs_names.remove(var_name_to_recover)
+        dependent_rvs = find_conditional_dependent_rvs(var_to_recover, unmarginal_model.basic_RVs)
+        # Handle batch dims for marginalized value and its dependent RVs
+        dependent_rvs_dim_connections = subgraph_batch_dim_connection(var_to_recover, dependent_rvs)
+
+        marginalized_model = marginalize(unmarginal_model, other_marginalized_rvs_names)
+
+        var_to_recover = marginalized_model[var_name_to_recover]
+        dependent_rvs = [marginalized_model[rv.name] for rv in dependent_rvs]
+        logps = marginalized_model.logp(vars=[var_to_recover, *dependent_rvs], sum=False)
+
+        marginalized_logp, *dependent_logps = logps
+        joint_logp = marginalized_logp + reduce_batch_dependent_logps(
+            dependent_rvs_dim_connections,
+            [dependent_var.owner.op for dependent_var in dependent_rvs],
+            dependent_logps,
+        )
+
+        marginalized_value = marginalized_model.rvs_to_values[var_to_recover]
+        other_values = [v for v in marginalized_model.value_vars if v is not marginalized_value]
+
+        rv_shape = constant_fold(tuple(var_to_recover.shape), raise_not_constant=False)
+        rv_domain = get_domain_of_finite_discrete_rv(var_to_recover)
+        rv_domain_tensor = pt.moveaxis(
+            pt.full(
+                (*rv_shape, len(rv_domain)),
+                rv_domain,
+                dtype=var_to_recover.dtype,
+            ),
+            -1,
+            0,
+        )
+
+        batched_joint_logp = vectorize_graph(
+            joint_logp,
+            replace={marginalized_value: rv_domain_tensor},
+        )
+        batched_joint_logp = pt.moveaxis(batched_joint_logp, 0, -1)
+
+        joint_logp_norm = log_softmax(batched_joint_logp, axis=-1)
+        if return_samples:
+            rv_draws = Categorical.dist(logit_p=batched_joint_logp)
+            if isinstance(var_to_recover.owner.op, DiscreteUniform):
+                rv_draws += rv_domain[0]
+            outputs = [joint_logp_norm, rv_draws]
+        else:
+            outputs = joint_logp_norm
+
+        rv_loglike_fn = compile_pymc(
+            inputs=other_values,
+            outputs=outputs,
+            on_unused_input="ignore",
+            random_seed=seed,
+        )
+
+        logvs = [rv_loglike_fn(**vs) for vs in transformed_posterior_pts]
+
+        if return_samples:
+            logps, samples = zip(*logvs)
+            logps = np.asarray(logps)
+            samples = np.asarray(samples)
+            rv_dict[var_name_to_recover] = samples.reshape(
+                tuple(len(coord) for coord in stacked_dims.values()) + samples.shape[1:],
+            )
+        else:
+            logps = np.asarray(logvs)
+
+        rv_dict["lp_" + var_name_to_recover] = logps.reshape(
+            tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:],
+        )
+        if var_name_to_recover in unmarginal_model.named_vars_to_dims:
+            rv_dims[var_name_to_recover] = list(
+                unmarginal_model.named_vars_to_dims[var_name_to_recover]
+            )
+            rv_dims["lp_" + var_name_to_recover] = rv_dims[var_name_to_recover] + [
+                "lp_" + var_name_to_recover + "_dim"
+            ]
+
+    coords, dims = coords_and_dims_for_inferencedata(unmarginal_model)
+    dims.update(rv_dims)
+    rv_dataset = dict_to_dataset(
+        rv_dict,
+        library=pymc,
+        dims=dims,
+        coords=coords,
+        skip_event_dims=True,
+    )
+
+    if extend_inferencedata:
+        idata.posterior = idata.posterior.assign(rv_dataset)
+        return idata
+    else:
+        return rv_dataset


 def collect_shared_vars(outputs, blockers):
     return [
-        inp for inp in graph_inputs(outputs, blockers=blockers) if isinstance(inp, SharedVariable)
+        inp
+        for inp in graph_inputs(outputs, blockers=blockers)
+        if (isinstance(inp, SharedVariable) and inp not in blockers)
     ]


-def 
replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs): - dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs) - if not dependent_rvs: - raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}") +def remove_model_vars(vars): + """Remove ModelVars from the graph of vars.""" + model_vars = [var for var in vars if isinstance(var.owner.op, ModelValuedVar)] + replacements = [(model_var, model_var.owner.inputs[0]) for model_var in model_vars] + fgraph = FunctionGraph(outputs=vars, clone=False) + toposort_replace(fgraph, replacements) + return fgraph.outputs - marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs) - other_direct_rv_ancestors = [ - rv - for rv in find_conditional_input_rvs(dependent_rvs, all_rvs) - if rv is not rv_to_marginalize - ] +def replace_finite_discrete_marginal_subgraph( + fgraph, rv_to_marginalize, dependent_rvs, input_rvs +) -> None: # If the marginalized RV has multiple dimensions, check that graph between # marginalized RV and dependent RVs does not mix information from batch dimensions # (otherwise logp would require enumerating over all combinations of batch dimension values) @@ -574,22 +556,45 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs "You can try splitting the marginalized RV into separate components and marginalizing them separately." ) from e - input_rvs = list(set((*marginalized_rv_input_rvs, *other_direct_rv_ancestors))) output_rvs = [rv_to_marginalize, *dependent_rvs] + rng_updates = collect_default_updates(output_rvs, inputs=input_rvs, must_be_shared=False) + outputs = output_rvs + list(rng_updates.values()) + inputs = input_rvs + list(rng_updates.keys()) + # Add any other shared variable inputs + inputs += collect_shared_vars(output_rvs, blockers=inputs) - # We are strict about shared variables in SymbolicRandomVariables - inputs = input_rvs + collect_shared_vars(output_rvs, blockers=input_rvs) + inner_inputs = [inp.clone() for inp in inputs] + inner_outputs = clone_replace(outputs, replace=dict(zip(inputs, inner_inputs))) + inner_outputs = remove_model_vars(inner_outputs) - if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain): + if isinstance(inner_outputs[0].owner.op, DiscreteMarkovChain): marginalize_constructor = MarginalDiscreteMarkovChainRV else: marginalize_constructor = MarginalFiniteDiscreteRV marginalization_op = marginalize_constructor( - inputs=inputs, - outputs=output_rvs, # TODO: Add RNG updates to outputs so this can be used in the generative graph + inputs=inner_inputs, + outputs=inner_outputs, dims_connections=dependent_rvs_dim_connections, ) - new_output_rvs = marginalization_op(*inputs) - fgraph.replace_all(tuple(zip(output_rvs, new_output_rvs))) - return output_rvs, new_output_rvs + + new_outputs = marginalization_op(*inputs) + for old_output, new_output in zip(outputs, new_outputs): + new_output.name = old_output.name + + outer_replacements = [ + ( + # Remove the marginalized FreeRV, but keep the dependent ones as Free/ObservedRVs + ( + old_output + if ( + old_output is rv_to_marginalize + or not isinstance(old_output.owner.op, ModelValuedVar) + ) + else old_output.owner.inputs[0] + ), + new_output, + ) + for old_output, new_output in zip(outputs, new_outputs) + ] + fgraph.replace_all(outer_replacements) diff --git a/pymc_experimental/utils/model_equivalence.py b/pymc_experimental/utils/model_equivalence.py new file mode 100644 index 00000000..cdbb8c90 --- /dev/null +++ 
b/pymc_experimental/utils/model_equivalence.py @@ -0,0 +1,66 @@ +from collections.abc import Sequence + +from pymc.model.core import Model +from pymc.model.fgraph import fgraph_from_model +from pytensor import Variable +from pytensor.compile import SharedVariable +from pytensor.graph import Constant, graph_inputs +from pytensor.graph.basic import equal_computations +from pytensor.tensor.random.type import RandomType + + +def equal_computations_up_to_root( + xs: Sequence[Variable], ys: Sequence[Variable], ignore_rng_values=True +) -> bool: + # Check if graphs are equivalent even if root variables have distinct identities + + x_graph_inputs = [var for var in graph_inputs(xs) if not isinstance(var, Constant)] + y_graph_inputs = [var for var in graph_inputs(ys) if not isinstance(var, Constant)] + if len(x_graph_inputs) != len(y_graph_inputs): + return False + for x, y in zip(x_graph_inputs, y_graph_inputs): + if x.type != y.type: + return False + if x.name != y.name: + return False + if isinstance(x, SharedVariable): + # if not isinstance(y, SharedVariable): + # return False + if isinstance(x.type, RandomType) and ignore_rng_values: + continue + if not x.type.values_eq(x.get_value(), y.get_value()): + return False + + return equal_computations(xs, ys, in_xs=x_graph_inputs, in_ys=y_graph_inputs) + + +def equivalent_models(model1: Model, model2: Model) -> bool: + """Check whether two PyMC models are equivalent. + + Examples + -------- + + .. code-block:: python + + import pymc as pm + from pymc_experimental.utils.model_equivalence import equivalent_models + + with pm.Model() as m1: + x = pm.Normal("x") + y = pm.Normal("y", x) + + with pm.Model() as m2: + x = pm.Normal("x") + y = pm.Normal("y", x + 1) + + with pm.Model() as m3: + x = pm.Normal("x") + y = pm.Normal("y", x) + + assert not equivalent_models(m1, m2) + assert equivalent_models(m1, m3) + + """ + fgraph1, _ = fgraph_from_model(model1) + fgraph2, _ = fgraph_from_model(model2) + return equal_computations_up_to_root(fgraph1.outputs, fgraph2.outputs) diff --git a/requirements.txt b/requirements.txt index b992ad37..571fa1b9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -pymc>=5.17.0 +pymc>=5.18.1 scikit-learn diff --git a/tests/model/marginal/test_distributions.py b/tests/model/marginal/test_distributions.py index ecbc8817..45a543aa 100644 --- a/tests/model/marginal/test_distributions.py +++ b/tests/model/marginal/test_distributions.py @@ -6,7 +6,7 @@ from pytensor import tensor as pt from scipy.stats import norm -from pymc_experimental import MarginalModel +from pymc_experimental import marginalize from pymc_experimental.distributions import DiscreteMarkovChain from pymc_experimental.model.marginal.distributions import MarginalFiniteDiscreteRV @@ -43,7 +43,7 @@ def test_marginalized_hmm_normal_emission(batch_chain, batch_emission): if batch_chain and not batch_emission: pytest.skip("Redundant implicit combination") - with MarginalModel() as m: + with pm.Model() as m: P = [[0, 1], [1, 0]] init_dist = pm.Categorical.dist(p=[1, 0]) chain = DiscreteMarkovChain( @@ -53,8 +53,8 @@ def test_marginalized_hmm_normal_emission(batch_chain, batch_emission): "emission", mu=chain * 2 - 1, sigma=1e-1, shape=(3, 4) if batch_emission else None ) - m.marginalize([chain]) - logp_fn = m.compile_logp() + marginal_m = marginalize(m, [chain]) + logp_fn = marginal_m.compile_logp() test_value = np.array([-1, 1, -1, 1]) expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval() @@ -70,7 +70,7 @@ def 
test_marginalized_hmm_normal_emission(batch_chain, batch_emission): ) def test_marginalized_hmm_categorical_emission(categorical_emission): """Example adapted from https://www.youtube.com/watch?v=9-sPm4CfcD0""" - with MarginalModel() as m: + with pm.Model() as m: P = np.array([[0.5, 0.5], [0.3, 0.7]]) init_dist = pm.Categorical.dist(p=[0.375, 0.625]) chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2) @@ -78,11 +78,11 @@ def test_marginalized_hmm_categorical_emission(categorical_emission): emission = pm.Categorical("emission", p=pt.constant([[0.8, 0.2], [0.4, 0.6]])[chain]) else: emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6)) - m.marginalize([chain]) + marginal_m = marginalize(m, [chain]) test_value = np.array([0, 0, 1]) expected_logp = np.log(0.1344) # Shown at the 10m22s mark in the video - logp_fn = m.compile_logp() + logp_fn = marginal_m.compile_logp() np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp) @@ -95,7 +95,7 @@ def test_marginalized_hmm_multiple_emissions(batch_chain, batch_emission1, batch (2, *reversed(chain_shape)) if batch_emission1 else tuple(reversed(chain_shape)) ) emission2_shape = (*chain_shape, 2) if batch_emission2 else chain_shape - with MarginalModel() as m: + with pm.Model() as m: P = [[0, 1], [1, 0]] init_dist = pm.Categorical.dist(p=[1, 0]) chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, shape=chain_shape) @@ -109,9 +109,9 @@ def test_marginalized_hmm_multiple_emissions(batch_chain, batch_emission1, batch emission_2 = pm.Normal("emission_2", mu=emission2_mu, sigma=1e-1, shape=emission2_shape) with pytest.warns(UserWarning, match="multiple dependent variables"): - m.marginalize([chain]) + marginal_m = marginalize(m, [chain]) - logp_fn = m.compile_logp(sum=False) + logp_fn = marginal_m.compile_logp(sum=False) test_value = np.array([-1, 1, -1, 1]) multiplier = 2 + batch_emission1 + batch_emission2 diff --git a/tests/model/marginal/test_marginal_model.py b/tests/model/marginal/test_marginal_model.py index 88625e24..22b1f1bc 100644 --- a/tests/model/marginal/test_marginal_model.py +++ b/tests/model/marginal/test_marginal_model.py @@ -9,25 +9,28 @@ import pytest from arviz import InferenceData, dict_to_dataset +from pymc import Model, draw from pymc.distributions import transforms from pymc.distributions.transforms import ordered -from pymc.model.fgraph import fgraph_from_model -from pymc.pytensorf import inputvars +from pymc.initial_point import make_initial_point_expression +from pymc.pytensorf import constant_fold, inputvars from pymc.util import UNSET from scipy.special import log_softmax, logsumexp from scipy.stats import halfnorm, norm +from pymc_experimental.model.marginal.distributions import MarginalRV from pymc_experimental.model.marginal.marginal_model import ( - MarginalModel, marginalize, + recover_marginals, + unmarginalize, ) -from tests.utils import equal_computations_up_to_root +from pymc_experimental.utils.model_equivalence import equivalent_models def test_basic_marginalized_rv(): data = [2] * 5 - with MarginalModel() as m: + with Model() as m: sigma = pm.HalfNormal("sigma") idx = pm.Categorical("idx", p=[0.1, 0.3, 0.6]) mu = pt.switch( @@ -42,45 +45,68 @@ def test_basic_marginalized_rv(): y = pm.Normal("y", mu=mu, sigma=sigma) z = pm.Normal("z", y, observed=data) - m.marginalize([idx]) - assert idx not in m.free_RVs - assert [rv.name for rv in m.marginalized_rvs] == ["idx"] + marginal_m = marginalize(m, [idx]) + assert isinstance(marginal_m["y"].owner.op, 
MarginalRV) + assert ["idx"] not in [rv.name for rv in marginal_m.free_RVs] + + # Test forward draws + y_draws, z_draws = draw( + [marginal_m["y"], marginal_m["z"]], + # Make sigma very small to make draws deterministic + givens={marginal_m["sigma"]: 0.001}, + draws=1000, + random_seed=54, + ) + assert sorted(np.unique(y_draws.round()) == [-1.0, 0.0, 1.0]) + assert z_draws[y_draws < 0].mean() < z_draws[y_draws > 0].mean() + + # Test initial_point + ips = make_initial_point_expression( + # Use basic_RVs to include the observed RV + free_rvs=marginal_m.basic_RVs, + rvs_to_transforms=marginal_m.rvs_to_transforms, + initval_strategies={}, + ) + # After simplification, we should have only constants in the graph (expect alloc which isn't constant folded): + ip_sigma, ip_y, ip_z = constant_fold(ips) + np.testing.assert_allclose(ip_sigma, 1.0) + np.testing.assert_allclose(ip_y, 1.0) + np.testing.assert_allclose(ip_z, np.full((5,), 1.0)) + + marginal_ip = marginal_m.initial_point() + expected_ip = m.initial_point() + expected_ip.pop("idx") + assert marginal_ip == expected_ip # Test logp - with pm.Model() as m_ref: + with pm.Model() as ref_m: sigma = pm.HalfNormal("sigma") y = pm.NormalMixture("y", w=[0.1, 0.3, 0.6], mu=[-1, 0, 1], sigma=sigma) z = pm.Normal("z", y, observed=data) - test_point = m_ref.initial_point() - ref_logp = m_ref.compile_logp()(test_point) - ref_dlogp = m_ref.compile_dlogp([m_ref["y"]])(test_point) - - # Assert we can marginalize and unmarginalize internally non-destructively - for i in range(3): - np.testing.assert_almost_equal( - m.compile_logp()(test_point), - ref_logp, - ) - np.testing.assert_almost_equal( - m.compile_dlogp([m["y"]])(test_point), - ref_dlogp, - ) + np.testing.assert_almost_equal( + marginal_m.compile_logp()(marginal_ip), + ref_m.compile_logp()(marginal_ip), + ) + np.testing.assert_almost_equal( + marginal_m.compile_dlogp([marginal_m["y"]])(marginal_ip), + ref_m.compile_dlogp([ref_m["y"]])(marginal_ip), + ) def test_one_to_one_marginalized_rvs(): """Test case with multiple, independent marginalized RVs.""" - with MarginalModel() as m: + with Model() as m: sigma = pm.HalfNormal("sigma") idx1 = pm.Bernoulli("idx1", p=0.75) x = pm.Normal("x", mu=idx1, sigma=sigma) idx2 = pm.Bernoulli("idx2", p=0.75, shape=(5,)) y = pm.Normal("y", mu=(idx2 * 2 - 1), sigma=sigma, shape=(5,)) - m.marginalize([idx1, idx2]) - m["x"].owner is not m["y"].owner - _m = m.clone()._marginalize() - _m["x"].owner is not _m["y"].owner + m = marginalize(m, [idx1, idx2]) + assert isinstance(m["x"].owner.op, MarginalRV) + assert isinstance(m["y"].owner.op, MarginalRV) + assert m["x"].owner is not m["y"].owner with pm.Model() as m_ref: sigma = pm.HalfNormal("sigma") @@ -97,24 +123,25 @@ def test_one_to_one_marginalized_rvs(): def test_one_to_many_marginalized_rvs(): """Test that marginalization works when there is more than one dependent RV""" - with MarginalModel() as m: + with Model() as m: sigma = pm.HalfNormal("sigma") idx = pm.Bernoulli("idx", p=0.75) x = pm.Normal("x", mu=idx, sigma=sigma) y = pm.Normal("y", mu=(idx * 2 - 1), sigma=sigma, shape=(5,)) - ref_logp_x_y_fn = m.compile_logp([idx, x, y]) - with pytest.warns(UserWarning, match="There are multiple dependent variables"): - m.marginalize([idx]) + marginal_m = marginalize(m, [idx]) - m["x"].owner is not m["y"].owner - _m = m.clone()._marginalize() - _m["x"].owner is _m["y"].owner + marginal_x = marginal_m["x"] + marginal_y = marginal_m["y"] + assert isinstance(marginal_x.owner.op, MarginalRV) + assert isinstance(marginal_y.owner.op, 
MarginalRV) + assert marginal_x.owner is marginal_y.owner - tp = m.initial_point() + ref_logp_x_y_fn = m.compile_logp([idx, x, y]) + tp = marginal_m.initial_point() ref_logp_x_y = logsumexp([ref_logp_x_y_fn({**tp, **{"idx": idx}}) for idx in (0, 1)]) - logp_x_y = m.compile_logp([x, y])(tp) + logp_x_y = marginal_m.compile_logp([marginal_x, marginal_y])(tp) np.testing.assert_array_almost_equal(logp_x_y, ref_logp_x_y) @@ -122,7 +149,7 @@ def test_one_to_many_unaligned_marginalized_rvs(): """Test that marginalization works when there is more than one dependent RV with batch dimensions that are not aligned""" def build_model(build_batched: bool): - with MarginalModel() as m: + with Model() as m: if build_batched: idx = pm.Bernoulli("idx", p=[0.75, 0.4], shape=(3, 2)) else: @@ -134,12 +161,9 @@ def build_model(build_batched: bool): return m - m = build_model(build_batched=True) - ref_m = build_model(build_batched=False) - with pytest.warns(UserWarning, match="There are multiple dependent variables"): - m.marginalize(["idx"]) - ref_m.marginalize([f"idx_{i}" for i in range(6)]) + m = marginalize(build_model(build_batched=True), ["idx"]) + ref_m = marginalize(build_model(build_batched=False), [f"idx_{i}" for i in range(6)]) test_point = m.initial_point() np.testing.assert_allclose( @@ -150,28 +174,27 @@ def build_model(build_batched: bool): def test_many_to_one_marginalized_rvs(): """Test when random variables depend on multiple marginalized variables""" - with MarginalModel() as m: + with Model() as m: x = pm.Bernoulli("x", 0.1) y = pm.Bernoulli("y", 0.3) z = pm.DiracDelta("z", c=x + y) - m.marginalize([x, y]) - logp = m.compile_logp() + logp_fn = marginalize(m, [x, y]).compile_logp() - np.testing.assert_allclose(np.exp(logp({"z": 0})), 0.9 * 0.7) - np.testing.assert_allclose(np.exp(logp({"z": 1})), 0.9 * 0.3 + 0.1 * 0.7) - np.testing.assert_allclose(np.exp(logp({"z": 2})), 0.1 * 0.3) + np.testing.assert_allclose(np.exp(logp_fn({"z": 0})), 0.9 * 0.7) + np.testing.assert_allclose(np.exp(logp_fn({"z": 1})), 0.9 * 0.3 + 0.1 * 0.7) + np.testing.assert_allclose(np.exp(logp_fn({"z": 2})), 0.1 * 0.3) @pytest.mark.parametrize("batched", (False, "left", "right")) def test_nested_marginalized_rvs(batched): """Test that marginalization works when there are nested marginalized RVs""" - def build_model(build_batched: bool) -> MarginalModel: + def build_model(build_batched: bool) -> Model: idx_shape = (3,) if build_batched else () sub_idx_shape = (5,) if not build_batched else (5, 3) if batched == "left" else (3, 5) - with MarginalModel() as m: + with Model() as m: sigma = pm.HalfNormal("sigma") idx = pm.Bernoulli("idx", p=0.75, shape=idx_shape) @@ -186,10 +209,34 @@ def build_model(build_batched: bool) -> MarginalModel: return m - m = build_model(build_batched=batched) with pytest.warns(UserWarning, match="There are multiple dependent variables"): - m.marginalize(["idx", "sub_idx"]) - assert sorted(m.name for m in m.marginalized_rvs) == ["idx", "sub_idx"] + marginal_m = marginalize(build_model(build_batched=batched), ["idx", "sub_idx"]) + assert all(rv.name not in ("idx", "sub_idx") for rv in marginal_m.free_RVs) + + # Test forward draws and initial_point, shouldn't depend on batching, so we only test one case + if not batched: + # Test forward draws + dep_draws, sub_dep_draws = draw( + [marginal_m["dep"], marginal_m["sub_dep"]], + # Make sigma very small to make draws deterministic + givens={marginal_m["sigma"]: 0.001}, + draws=1000, + random_seed=214, + ) + assert sorted(np.unique(dep_draws.round()) == 
 
 
 @pytest.mark.parametrize("batched", (False, "left", "right"))
 def test_nested_marginalized_rvs(batched):
     """Test that marginalization works when there are nested marginalized RVs"""
 
-    def build_model(build_batched: bool) -> MarginalModel:
+    def build_model(build_batched: bool) -> Model:
         idx_shape = (3,) if build_batched else ()
         sub_idx_shape = (5,) if not build_batched else (5, 3) if batched == "left" else (3, 5)
 
-        with MarginalModel() as m:
+        with Model() as m:
             sigma = pm.HalfNormal("sigma")
 
             idx = pm.Bernoulli("idx", p=0.75, shape=idx_shape)
@@ -186,10 +209,34 @@ def build_model(build_batched: bool) -> MarginalModel:
 
         return m
 
-    m = build_model(build_batched=batched)
     with pytest.warns(UserWarning, match="There are multiple dependent variables"):
-        m.marginalize(["idx", "sub_idx"])
-    assert sorted(m.name for m in m.marginalized_rvs) == ["idx", "sub_idx"]
+        marginal_m = marginalize(build_model(build_batched=batched), ["idx", "sub_idx"])
+    assert all(rv.name not in ("idx", "sub_idx") for rv in marginal_m.free_RVs)
+
+    # Test forward draws and initial_point, shouldn't depend on batching, so we only test one case
+    if not batched:
+        # Test forward draws
+        dep_draws, sub_dep_draws = draw(
+            [marginal_m["dep"], marginal_m["sub_dep"]],
+            # Make sigma very small to make draws deterministic
+            givens={marginal_m["sigma"]: 0.001},
+            draws=1000,
+            random_seed=214,
+        )
+        assert sorted(np.unique(dep_draws.round())) == [-1000.0, 1000.0]
+        assert sorted(np.unique(sub_dep_draws.round())) == [-1000.0, -900.0, 1000.0, 1100.0]
+
+        # Test initial_point
+        ips = make_initial_point_expression(
+            free_rvs=marginal_m.free_RVs,
+            rvs_to_transforms=marginal_m.rvs_to_transforms,
+            initval_strategies={},
+        )
+        # After simplification, we should have only constants in the graph
+        ip_sigma, ip_dep, ip_sub_dep = constant_fold(ips)
+        np.testing.assert_allclose(ip_sigma, 1.0)
+        np.testing.assert_allclose(ip_dep, 1000.0)
+        np.testing.assert_allclose(ip_sub_dep, np.full((5,), 1100.0))
 
     # Test logp
     ref_m = build_model(build_batched=False)
@@ -210,14 +257,68 @@ def build_model(build_batched: bool) -> MarginalModel:
     if batched:
         ref_logp *= 3
 
-    test_point = m.initial_point()
+    test_point = marginal_m.initial_point()
     test_point["dep"] = np.full_like(test_point["dep"], 1000)
     test_point["sub_dep"] = np.full_like(test_point["sub_dep"], 1000 + 100)
 
-    logp = m.compile_logp(vars=[m["dep"], m["sub_dep"]])(test_point)
+    logp = marginal_m.compile_logp(vars=[marginal_m["dep"], marginal_m["sub_dep"]])(test_point)
     np.testing.assert_almost_equal(logp, ref_logp)
 
 
+def test_interdependent_rvs():
+    """Test marginalization when dependent RVs are interdependent."""
+    with Model() as m:
+        idx = pm.Bernoulli("idx", p=0.75)
+        x = pm.Normal("x", mu=idx * 2, sigma=1e-3)
+        # Y depends on both x and idx
+        y = pm.Normal("y", mu=x * idx * 2, sigma=1e-3)
+
+    with pytest.warns(UserWarning, match="There are multiple dependent variables"):
+        marginal_m = marginalize(m, "idx")
+
+    marginal_x = marginal_m["x"]
+    marginal_y = marginal_m["y"]
+    assert isinstance(marginal_x.owner.op, MarginalRV)
+    assert isinstance(marginal_y.owner.op, MarginalRV)
+    assert marginal_x.owner is marginal_y.owner
+
+    # Test forward draws
+    x_draws, y_draws = draw([marginal_x, marginal_y], draws=1000, random_seed=54)
+    assert sorted(np.unique(x_draws.round())) == [0, 2]
+    assert sorted(np.unique(y_draws.round())) == [0, 4]
+    assert np.unique(y_draws[x_draws < 1].round()) == [0]
+    assert np.unique(y_draws[x_draws > 1].round()) == [4]
+
+    # Test initial_point
+    ips = make_initial_point_expression(
+        free_rvs=marginal_m.free_RVs,
+        rvs_to_transforms={},
+        initval_strategies={},
+    )
+    # After simplification, we should have only constants in the graph
+    ip_x, ip_y = constant_fold(ips)
+    np.testing.assert_allclose(ip_x, 2.0)
+    np.testing.assert_allclose(ip_y, 4.0)
+
+    # Test custom initval strategy
+    ips = make_initial_point_expression(
+        # Test that order does not matter
+        free_rvs=marginal_m.free_RVs[::-1],
+        rvs_to_transforms={},
+        initval_strategies={marginal_x: pt.constant(5.0)},
+    )
+    ip_y, ip_x = constant_fold(ips)
+    np.testing.assert_allclose(ip_x, 5.0)
+    np.testing.assert_allclose(ip_y, 10.0)
+
+    # Test logp
+    test_point = marginal_m.initial_point()
+    ref_logp_fn = m.compile_logp([m["idx"], m["x"], m["y"]])
+    ref_logp = logsumexp([ref_logp_fn({**test_point, **{"idx": idx}}) for idx in (0, 1)])
+    logp = marginal_m.compile_logp([marginal_m["x"], marginal_m["y"]])(test_point)
+    np.testing.assert_almost_equal(logp, ref_logp)
+
+
 @pytest.mark.parametrize("advanced_indexing", (False, True))
 def test_marginalized_index_as_key(advanced_indexing):
     """Test we can marginalize graphs where indexing is used as a mapping."""
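The reference logps computed throughout these tests all use the same identity: the marginal density of the dependent variables is the logsumexp, over the support of the discrete variable, of the joint logp. A self-contained scipy sketch of that identity for a Bernoulli index with a Normal dependent:

import numpy as np
from scipy.special import logsumexp
from scipy.stats import bernoulli, norm

p, x_obs = 0.75, 0.5
# logsumexp over k of logp(idx = k) + logp(x | idx = k), for x ~ Normal(idx, 1)
marginal = logsumexp([bernoulli.logpmf(k, p) + norm.logpdf(x_obs, loc=k) for k in (0, 1)])
direct = np.log((1 - p) * norm.pdf(x_obs, loc=0) + p * norm.pdf(x_obs, loc=1))
assert np.isclose(marginal, direct)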
@@ -232,13 +333,13 @@ def test_marginalized_index_as_key(advanced_indexing):
     y_val = -1
     shape = ()
 
-    with MarginalModel() as m:
+    with Model() as m:
         x = pm.Categorical("x", p=w, shape=shape)
         y = pm.Normal("y", mu[x].T, sigma=1, observed=y_val)
 
-        m.marginalize(x)
+    marginal_m = marginalize(m, x)
 
-    marginal_logp = m.compile_logp(sum=False)({})[0]
+    marginal_logp = marginal_m.compile_logp(sum=False)({})[0]
     ref_logp = pm.logp(pm.NormalMixture.dist(w=w, mu=mu.T, sigma=1, shape=shape), y_val).eval()
 
     np.testing.assert_allclose(marginal_logp, ref_logp)
@@ -247,8 +348,8 @@ def test_marginalized_index_as_value_and_key():
     """Test we can marginalize graphs were marginalized_rv is indexed."""
 
-    def build_model(build_batched: bool) -> MarginalModel:
-        with MarginalModel() as m:
+    def build_model(build_batched: bool) -> Model:
+        with Model() as m:
             if build_batched:
                 latent_state = pm.Bernoulli("latent_state", p=0.3, size=(4,))
             else:
@@ -270,16 +371,16 @@ def build_model(build_batched: bool):
     m = build_model(build_batched=True)
     ref_m = build_model(build_batched=False)
 
-    m.marginalize(["latent_state"])
-    ref_m.marginalize([f"latent_state_{i}" for i in range(4)])
+    m = marginalize(m, ["latent_state"])
+    ref_m = marginalize(ref_m, [f"latent_state_{i}" for i in range(4)])
     test_point = {"picked_intensity": 1}
     np.testing.assert_allclose(
         m.compile_logp()(test_point),
         ref_m.compile_logp()(test_point),
     )
 
-    m.marginalize(["picked_intensity"])
-    ref_m.marginalize(["picked_intensity"])
+    m = marginalize(m, ["picked_intensity"])
+    ref_m = marginalize(ref_m, ["picked_intensity"])
     test_point = {}
     np.testing.assert_allclose(
         m.compile_logp()(test_point),
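Because `marginalize` returns an ordinary model, the test above can apply it twice, first eliminating `latent_state` and then `picked_intensity`. A minimal sketch of that chaining on a toy model (the model itself is illustrative, not from the test suite):

import pymc as pm
from pymc_experimental import marginalize

with pm.Model() as m:
    a = pm.Bernoulli("a", p=0.3)
    b = pm.Bernoulli("b", p=0.6)
    y = pm.Normal("y", mu=a + b)

m1 = marginalize(m, ["a"])   # eliminate one discrete variable...
m2 = marginalize(m1, ["b"])  # ...then the other, on the returned model
assert all(rv.name not in ("a", "b") for rv in m2.free_RVs)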
@@ -291,76 +392,77 @@ class TestNotSupportedMixedDims:
     """Test lack of support for models where batch dims of marginalized variables are mixed."""
 
     def test_mixed_dims_via_transposed_dot(self):
-        with MarginalModel() as m:
+        with Model() as m:
             idx = pm.Bernoulli("idx", p=0.7, shape=2)
             y = pm.Normal("y", mu=idx @ idx.T)
-            with pytest.raises(NotImplementedError):
-                m.marginalize(idx)
+
+        with pytest.raises(NotImplementedError):
+            marginalize(m, idx)
 
     def test_mixed_dims_via_indexing(self):
         mean = pt.as_tensor([[0.1, 0.9], [0.6, 0.4]])
 
-        with MarginalModel() as m:
+        with Model() as m:
             idx = pm.Bernoulli("idx", p=0.7, shape=2)
             y = pm.Normal("y", mu=mean[idx, :] + mean[:, idx])
-            with pytest.raises(NotImplementedError):
-                m.marginalize(idx)
+        with pytest.raises(NotImplementedError):
+            marginalize(m, idx)
 
-        with MarginalModel() as m:
+        with Model() as m:
             idx = pm.Bernoulli("idx", p=0.7, shape=2)
             y = pm.Normal("y", mu=mean[idx, None] + mean[None, idx])
-            with pytest.raises(NotImplementedError):
-                m.marginalize(idx)
+        with pytest.raises(NotImplementedError):
+            marginalize(m, idx)
 
-        with MarginalModel() as m:
+        with Model() as m:
             idx = pm.Bernoulli("idx", p=0.7, shape=2)
             mu = pt.specify_broadcastable(mean[:, None][idx], 1) + pt.specify_broadcastable(
                 mean[None, :][:, idx], 0
             )
             y = pm.Normal("y", mu=mu)
-            with pytest.raises(NotImplementedError):
-                m.marginalize(idx)
+        with pytest.raises(NotImplementedError):
+            marginalize(m, idx)
 
-        with MarginalModel() as m:
+        with Model() as m:
             idx = pm.Bernoulli("idx", p=0.7, shape=2)
             y = pm.Normal("y", mu=idx[0] + idx[1])
-            with pytest.raises(NotImplementedError):
-                m.marginalize(idx)
+        with pytest.raises(NotImplementedError):
+            marginalize(m, idx)
 
     def test_mixed_dims_via_vector_indexing(self):
-        with MarginalModel() as m:
+        with Model() as m:
             idx = pm.Bernoulli("idx", p=0.7, shape=2)
             y = pm.Normal("y", mu=idx[[0, 1, 0, 0]])
-            with pytest.raises(NotImplementedError):
-                m.marginalize(idx)
+        with pytest.raises(NotImplementedError):
+            marginalize(m, idx)
 
-        with MarginalModel() as m:
+        with Model() as m:
             idx = pm.Categorical("key", p=[0.1, 0.3, 0.6], shape=(2, 2))
             y = pm.Normal("y", pt.as_tensor([[0, 1], [2, 3]])[idx.astype(bool)])
-            with pytest.raises(NotImplementedError):
-                m.marginalize(idx)
+        with pytest.raises(NotImplementedError):
+            marginalize(m, idx)
 
     def test_mixed_dims_via_support_dimension(self):
-        with MarginalModel() as m:
+        with Model() as m:
             x = pm.Bernoulli("x", p=0.7, shape=3)
             y = pm.Dirichlet("y", a=x * 10 + 1)
-            with pytest.raises(NotImplementedError):
-                m.marginalize(x)
+        with pytest.raises(NotImplementedError):
+            marginalize(m, x)
 
     def test_mixed_dims_via_nested_marginalization(self):
-        with MarginalModel() as m:
+        with Model() as m:
             x = pm.Bernoulli("x", p=0.7, shape=(3,))
             y = pm.Bernoulli("y", p=0.7, shape=(2,))
             z = pm.Normal("z", mu=pt.add.outer(x, y), shape=(3, 2))
 
-            with pytest.raises(NotImplementedError):
-                m.marginalize([x, y])
+        with pytest.raises(NotImplementedError):
+            marginalize(m, [x, y])
 
 
 def test_marginalized_deterministic_and_potential():
     rng = np.random.default_rng(299)
 
-    with MarginalModel() as m:
+    with Model() as m:
         x = pm.Bernoulli("x", p=0.7)
         y = pm.Normal("y", x)
         z = pm.Normal("z", x)
@@ -368,22 +470,22 @@ def test_marginalized_deterministic_and_potential():
         pot = pm.Potential("pot", y + z + 1)
 
     with pytest.warns(UserWarning, match="There are multiple dependent variables"):
-        m.marginalize([x])
+        marginal_m = marginalize(m, [x])
 
     y_draw, z_draw, det_draw, pot_draw = pm.draw([y, z, det, pot], draws=5, random_seed=rng)
     np.testing.assert_almost_equal(y_draw + z_draw, det_draw)
     np.testing.assert_almost_equal(det_draw, pot_draw - 1)
 
-    y_value = m.rvs_to_values[y]
-    z_value = m.rvs_to_values[z]
-    det_value, pot_value = m.replace_rvs_by_values([det, pot])
+    y_value = marginal_m.rvs_to_values[marginal_m["y"]]
+    z_value = marginal_m.rvs_to_values[marginal_m["z"]]
+    det_value, pot_value = marginal_m.replace_rvs_by_values([marginal_m["det"], marginal_m["pot"]])
     assert set(inputvars([det_value, pot_value])) == {y_value, z_value}
     assert det_value.eval({y_value: 2, z_value: 5}) == 7
     assert pot_value.eval({y_value: 2, z_value: 5}) == 8
 
 
 def test_not_supported_marginalized_deterministic_and_potential():
-    with MarginalModel() as m:
+    with Model() as m:
         x = pm.Bernoulli("x", p=0.7)
         y = pm.Normal("y", x)
         det = pm.Deterministic("det", x + y)
@@ -391,9 +493,9 @@
     with pytest.raises(
         NotImplementedError, match="Cannot marginalize x due to dependent Deterministic det"
     ):
-        m.marginalize([x])
+        marginalize(m, [x])
 
-    with MarginalModel() as m:
+    with Model() as m:
         x = pm.Bernoulli("x", p=0.7)
         y = pm.Normal("y", x)
         pot = pm.Potential("pot", x + y)
@@ -401,7 +503,7 @@
     with pytest.raises(
         NotImplementedError, match="Cannot marginalize x due to dependent Potential pot"
     ):
-        m.marginalize([x])
+        marginalize(m, [x])
@@ -440,7 +542,7 @@ def test_marginalized_transforms(transform, expected_warning):
     )
     y = pm.Normal("y", 0, sigma, observed=data)
 
-    with MarginalModel() as m:
+    with Model() as m:
         idx = pm.Categorical("idx", p=w)
         sigma = pm.HalfNormal(
             "sigma",
@@ -453,32 +555,32 @@ def test_marginalized_transforms(transform, expected_warning):
                     3,
                 ),
             ),
-            initval=initval,
             default_transform=transform,
         )
         y = pm.Normal("y", 0, sigma, observed=data)
 
     with expected_warning:
-        m.marginalize([idx])
+        marginal_m = marginalize(m, [idx])
 
-    ip = m.initial_point()
+    marginal_m.set_initval(marginal_m["sigma"], initval)
+    ip = marginal_m.initial_point()
     if transform is not None:
         if transform is UNSET:
             transform_name = "log"
         else:
             transform_name = transform.name
         assert -np.inf < ip[f"sigma_{transform_name}__"] < 0.0
 
-    np.testing.assert_allclose(m.compile_logp()(ip), m_ref.compile_logp()(ip))
+    np.testing.assert_allclose(marginal_m.compile_logp()(ip), m_ref.compile_logp()(ip))
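With the `initval=` kwargs removed from the model-building code, the tests now attach initial values to the marginalized model after the rewrite via `Model.set_initval`, as in `marginal_m.set_initval(marginal_m["sigma"], initval)` above. A sketch of the pattern on a toy model, assuming PyMC's standard `set_initval` and the default log transform of `HalfNormal`:

import numpy as np
import pymc as pm
from pymc_experimental import marginalize

with pm.Model() as m:
    idx = pm.Bernoulli("idx", p=0.5)
    sigma = pm.HalfNormal("sigma")
    y = pm.Normal("y", mu=idx, sigma=sigma, observed=[0.1, -0.2])

marginal_m = marginalize(m, [idx])
marginal_m.set_initval(marginal_m["sigma"], 0.5)  # instead of pm.HalfNormal("sigma", initval=0.5)
assert np.isclose(marginal_m.initial_point()["sigma_log__"], np.log(0.5))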
 
 
 def test_data_container():
     """Test that MarginalModel can handle Data containers."""
-    with MarginalModel(coords={"obs": [0]}) as marginal_m:
+    with Model(coords={"obs": [0]}) as m:
         x = pm.Data("x", 2.5)
         idx = pm.Bernoulli("idx", p=0.7, dims="obs")
         y = pm.Normal("y", idx * x, dims="obs")
 
-    marginal_m.marginalize([idx])
+    marginal_m = marginalize(m, [idx])
 
     logp_fn = marginal_m.compile_logp()
 
@@ -501,7 +603,7 @@ def test_mutable_indexing_jax_backend():
     pytest.importorskip("jax")
     from pymc.sampling.jax import get_jaxified_logp
 
-    with MarginalModel() as model:
+    with Model() as model:
         data = pm.Data("data", np.zeros(10))
 
         cat_effect = pm.Normal("cat_effect", sigma=1, shape=5)
@@ -509,38 +611,8 @@ def test_mutable_indexing_jax_backend():
         is_outlier = pm.Bernoulli("is_outlier", 0.4, shape=10)
         pm.LogNormal("y", mu=cat_effect[cat_effect_idx], sigma=1 + is_outlier, observed=data)
 
-    model.marginalize(["is_outlier"])
-    get_jaxified_logp(model)
-
-
-def test_marginal_model_func():
-    def create_model(model_class):
-        with model_class(coords={"trial": range(10)}) as m:
-            idx = pm.Bernoulli("idx", p=0.5, dims="trial")
-            mu = pt.where(idx, 1, -1)
-            sigma = pm.HalfNormal("sigma")
-            y = pm.Normal("y", mu=mu, sigma=sigma, dims="trial", observed=[1] * 10)
-        return m
-
-    marginal_m = marginalize(create_model(pm.Model), ["idx"])
-    assert isinstance(marginal_m, MarginalModel)
-
-    reference_m = create_model(MarginalModel)
-    reference_m.marginalize(["idx"])
-
-    # Check forward graph representation is the same
-    marginal_fgraph, _ = fgraph_from_model(marginal_m)
-    reference_fgraph, _ = fgraph_from_model(reference_m)
-    assert equal_computations_up_to_root(marginal_fgraph.outputs, reference_fgraph.outputs)
-
-    # Check logp graph is the same
-    # This fails because OpFromGraphs comparison is broken
-    # assert equal_computations_up_to_root([marginal_m.logp()], [reference_m.logp()])
-    ip = marginal_m.initial_point()
-    np.testing.assert_allclose(
-        marginal_m.compile_logp()(ip),
-        reference_m.compile_logp()(ip),
-    )
+    marginal_model = marginalize(model, ["is_outlier"])
+    get_jaxified_logp(marginal_model)
 
 
 class TestFullModels:
@@ -559,10 +631,10 @@ def disaster_model(self):
         # fmt: on
         years = np.arange(1851, 1962)
 
-        with MarginalModel() as disaster_model:
+        with Model() as disaster_model:
             switchpoint = pm.DiscreteUniform("switchpoint", lower=years.min(), upper=years.max())
-            early_rate = pm.Exponential("early_rate", 1.0, initval=3)
-            late_rate = pm.Exponential("late_rate", 1.0, initval=1)
+            early_rate = pm.Exponential("early_rate", 1.0)
+            late_rate = pm.Exponential("late_rate", 1.0)
             rate = pm.math.switch(switchpoint >= years, early_rate, late_rate)
             with pytest.warns(Warning):
                 disasters = pm.Poisson("disasters", rate, observed=disaster_data)
@@ -573,6 +645,8 @@ def test_change_point_model(self, disaster_model):
         m, years = disaster_model
 
         ip = m.initial_point()
+        ip["late_rate_log__"] += 1.0  # Make the early_rate and late_rate initial points differ
+        ip.pop("switchpoint")
 
         ref_logp_fn = m.compile_logp(
             [m["switchpoint"], m["disasters_observed"], m["disasters_unobserved"]]
@@ -580,10 +654,12 @@ def test_change_point_model(self, disaster_model):
         )
         ref_logp = logsumexp([ref_logp_fn({**ip, **{"switchpoint": year}}) for year in years])
 
         with pytest.warns(UserWarning, match="There are multiple dependent variables"):
-            m.marginalize(m["switchpoint"])
+            marginal_m = marginalize(m, m["switchpoint"])
 
-        logp = m.compile_logp([m["disasters_observed"], m["disasters_unobserved"]])(ip)
-        np.testing.assert_almost_equal(logp, ref_logp)
+        marginal_m_logp = marginal_m.compile_logp(
+            [marginal_m["disasters_observed"], marginal_m["disasters_unobserved"]]
+        )(ip)
+        np.testing.assert_almost_equal(marginal_m_logp, ref_logp)
 
     @pytest.mark.slow
     def test_change_point_model_sampling(self, disaster_model):
@@ -597,9 +673,9 @@ def test_change_point_model_sampling(self, disaster_model):
         )
 
         with pytest.warns(UserWarning, match="There are multiple dependent variables"):
-            m.marginalize([m["switchpoint"]])
+            marginal_m = marginalize(m, "switchpoint")
 
-        with m:
+        with marginal_m:
            after_marg = pm.sample(chains=2, random_seed=rng).posterior.stack(
                 sample=("draw", "chain")
             )
@@ -618,7 +694,7 @@ def test_change_point_model_sampling(self, disaster_model):
 
     @pytest.mark.parametrize("univariate", (True, False))
     def test_vector_univariate_mixture(self, univariate):
-        with MarginalModel() as m:
+        with Model() as m:
             idx = pm.Bernoulli("idx", p=0.5, shape=(2,) if univariate else ())
 
             def dist(idx, size):
@@ -630,8 +706,8 @@ def dist(idx, size):
 
             pm.CustomDist("norm", idx, dist=dist)
 
-        m.marginalize(idx)
-        logp_fn = m.compile_logp()
+        marginal_m = marginalize(m, idx)
+        logp_fn = marginal_m.compile_logp()
 
         if univariate:
             with pm.Model() as ref_m:
@@ -659,16 +735,17 @@ def dist(idx, size):
         np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt))
 
     def test_k_censored_clusters_model(self):
-        def build_model(build_batched: bool) -> MarginalModel:
-            data = np.array([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]])
-            nobs = data.shape[0]
-            n_clusters = 5
+        data = np.array([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]])
+        nobs = data.shape[0]
+        n_clusters = 5
+
+        def build_model(build_batched: bool) -> Model:
             coords = {
                 "cluster": range(n_clusters),
                 "ndim": ("x", "y"),
                 "obs": range(nobs),
             }
-            with MarginalModel(coords=coords) as m:
+            with Model(coords=coords) as m:
                 if build_batched:
                     idx = pm.Categorical("idx", p=np.ones(n_clusters) / n_clusters, dims=["obs"])
                 else:
@@ -683,7 +760,6 @@ def build_model(build_batched: bool):
                     "mu_x",
                     dims=["cluster"],
                     transform=ordered,
-                    initval=np.linspace(-1, 1, n_clusters),
                 )
                 mu_y = pm.Normal("mu_y", dims=["cluster"])
                 mu = pm.math.stack([mu_x, mu_y], axis=-1)  # (cluster, ndim)
@@ -702,12 +778,10 @@ def build_model(build_batched: bool):
 
             return m
 
-        m = build_model(build_batched=True)
-        ref_m = build_model(build_batched=False)
-
-        m.marginalize([m["idx"]])
-        ref_m.marginalize([n for n in ref_m.named_vars if n.startswith("idx_")])
+        m = marginalize(build_model(build_batched=True), "idx")
+        m.set_initval(m["mu_x"], np.linspace(-1, 1, n_clusters))
+        ref_m = marginalize(build_model(build_batched=False), [f"idx_{i}" for i in range(nobs)])
 
         test_point = m.initial_point()
         np.testing.assert_almost_equal(
             m.compile_logp()(test_point),
@@ -715,9 +789,32 @@ def test_k_censored_clusters_model(self):
             ref_m.compile_logp()(test_point),
         )
 
 
+def test_unmarginalize():
+    with pm.Model() as m:
+        idx = pm.Bernoulli("idx", p=0.5)
+        sub_idx = pm.Bernoulli("sub_idx", p=pt.as_tensor([0.3, 0.7])[idx])
+        x = pm.Normal("x", mu=(idx + sub_idx) - 1)
+
+    marginal_m = marginalize(m, [idx, sub_idx])
+    assert not equivalent_models(marginal_m, m)
+
+    unmarginal_m = unmarginalize(marginal_m)
+    assert equivalent_models(unmarginal_m, m)
+
+    unmarginal_idx_explicit = unmarginalize(marginal_m, ("idx", "sub_idx"))
+    assert equivalent_models(unmarginal_idx_explicit, m)
+
+    # Test partial unmarginalize
+    unmarginal_idx = unmarginalize(marginal_m, "idx")
+    assert equivalent_models(unmarginal_idx, marginalize(m, "sub_idx"))
+
+    unmarginal_sub_idx = unmarginalize(marginal_m, "sub_idx")
+    assert equivalent_models(unmarginal_sub_idx, marginalize(m, "idx"))
+
+
 class TestRecoverMarginals:
     def test_basic(self):
-        with MarginalModel() as m:
+        with Model() as m:
             sigma = pm.HalfNormal("sigma")
             p = np.array([0.5, 0.2, 0.3])
             k = pm.Categorical("k", p=p)
@@ -725,7 +822,7 @@ def test_basic(self):
             mu_ = pt.as_tensor_variable(mu)
             y = pm.Normal("y", mu=mu_[k], sigma=sigma)
 
-        m.marginalize([k])
+        m = marginalize(m, [k])
 
         rng = np.random.default_rng(211)
 
@@ -737,7 +834,7 @@ def test_basic(self):
         )
         idata = InferenceData(posterior=dict_to_dataset(prior))
 
-        idata = m.recover_marginals(idata, return_samples=True)
+        idata = recover_marginals(m, idata, return_samples=True)
         post = idata.posterior
         assert "k" in post
         assert "lp_k" in post
@@ -763,12 +860,12 @@ def true_logp(y, sigma):
 
     def test_coords(self):
         """Test if coords can be recovered with marginalized value had it originally"""
-        with MarginalModel(coords={"year": [1990, 1991, 1992]}) as m:
+        with Model(coords={"year": [1990, 1991, 1992]}) as m:
             sigma = pm.HalfNormal("sigma")
             idx = pm.Bernoulli("idx", p=0.75, dims="year")
             x = pm.Normal("x", mu=idx, sigma=sigma, dims="year")
 
-        m.marginalize([idx])
+        m = marginalize(m, [idx])
 
         rng = np.random.default_rng(211)
         with m:
@@ -781,19 +878,19 @@ def test_coords(self):
             posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior})
         )
 
-        idata = m.recover_marginals(idata, return_samples=True)
+        idata = recover_marginals(m, idata, return_samples=True)
         post = idata.posterior
         assert post.idx.dims == ("chain", "draw", "year")
         assert post.lp_idx.dims == ("chain", "draw", "year", "lp_idx_dim")
 
     def test_batched(self):
         """Test that marginalization works for batched random variables"""
-        with MarginalModel() as m:
+        with Model() as m:
             sigma = pm.HalfNormal("sigma")
             idx = pm.Bernoulli("idx", p=0.7, shape=(3, 2))
             y = pm.Normal("y", mu=idx.T, sigma=sigma, shape=(2, 3))
 
-        m.marginalize([idx])
+        m = marginalize(m, [idx])
 
         rng = np.random.default_rng(211)
 
@@ -807,7 +904,7 @@ def test_batched(self):
             posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior})
        )
 
-        idata = m.recover_marginals(idata, return_samples=True)
+        idata = recover_marginals(m, idata, return_samples=True)
         post = idata.posterior
         assert post["y"].shape == (1, 20, 2, 3)
         assert post["idx"].shape == (1, 20, 3, 2)
@@ -816,12 +913,12 @@ def test_batched(self):
 
     def test_nested(self):
         """Test that marginalization works when there are nested marginalized RVs"""
-        with MarginalModel() as m:
+        with Model() as m:
             idx = pm.Bernoulli("idx", p=0.75)
             sub_idx = pm.Bernoulli("sub_idx", p=pt.switch(pt.eq(idx, 0), 0.15, 0.95))
             sub_dep = pm.Normal("y", mu=idx + sub_idx, sigma=1.0)
 
-        m.marginalize([idx, sub_idx])
+        m = marginalize(m, [idx, sub_idx])
 
         rng = np.random.default_rng(211)
 
@@ -833,7 +930,7 @@ def test_nested(self):
         )
         idata = InferenceData(posterior=dict_to_dataset(prior))
 
-        idata = m.recover_marginals(idata, return_samples=True)
+        idata = recover_marginals(m, idata, return_samples=True)
         post = idata.posterior
         assert "idx" in post
         assert "lp_idx" in post
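`recover_marginals` is likewise a module-level function now: it takes the marginalized model and an `InferenceData` and augments the posterior with draws of the marginalized variables (and their log-probabilities, prefixed `lp_`). A sketch of the call pattern used in `TestRecoverMarginals`, with a sampled trace instead of the hand-built prior datasets above:

import pymc as pm
from pymc_experimental import marginalize, recover_marginals

with pm.Model() as m:
    idx = pm.Bernoulli("idx", p=0.7)
    y = pm.Normal("y", mu=idx, sigma=1.0)

marginal_m = marginalize(m, [idx])
with marginal_m:
    idata = pm.sample(draws=50, chains=1, random_seed=0)

idata = recover_marginals(marginal_m, idata, return_samples=True)
assert "idx" in idata.posterior and "lp_idx" in idata.posterior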
diff --git a/tests/utils.py b/tests/utils.py
index 9576b4be..e69de29b 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,31 +0,0 @@
-from collections.abc import Sequence
-
-from pytensor.compile import SharedVariable
-from pytensor.graph import Constant, graph_inputs
-from pytensor.graph.basic import Variable, equal_computations
-from pytensor.tensor.random.type import RandomType
-
-
-def equal_computations_up_to_root(
-    xs: Sequence[Variable], ys: Sequence[Variable], ignore_rng_values=True
-) -> bool:
-    # Check if graphs are equivalent even if root variables have distinct identities
-
-    x_graph_inputs = [var for var in graph_inputs(xs) if not isinstance(var, Constant)]
-    y_graph_inputs = [var for var in graph_inputs(ys) if not isinstance(var, Constant)]
-    if len(x_graph_inputs) != len(y_graph_inputs):
-        return False
-    for x, y in zip(x_graph_inputs, y_graph_inputs):
-        if x.type != y.type:
-            return False
-        if x.name != y.name:
-            return False
-        if isinstance(x, SharedVariable):
-            if not isinstance(y, SharedVariable):
-                return False
-            if isinstance(x.type, RandomType) and ignore_rng_values:
-                continue
-            if not x.type.values_eq(x.get_value(), y.get_value()):
-                return False
-
-    return equal_computations(xs, ys, in_xs=x_graph_inputs, in_ys=y_graph_inputs)
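The hand-rolled `equal_computations_up_to_root` helper deleted here is superseded by the `equivalent_models` utility that `test_unmarginalize` relies on. A sketch of the round-trip property it checks; the import paths below are assumptions based on the modules this PR touches:

import pymc as pm
from pymc_experimental import marginalize
from pymc_experimental.model.marginal.marginal_model import unmarginalize  # assumed location
from pymc_experimental.utils.model_equivalence import equivalent_models  # assumed location

with pm.Model() as m:
    idx = pm.Bernoulli("idx", p=0.5)
    x = pm.Normal("x", mu=idx, sigma=1.0)

marginal_m = marginalize(m, [idx])
assert not equivalent_models(marginal_m, m)  # the marginalized graph differs
assert equivalent_models(unmarginalize(marginal_m), m)  # round-trip restores the original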