diff --git a/pymc_experimental/model_transform/basic.py b/pymc_experimental/model_transform/basic.py index 23e10cebe..8146f8fea 100644 --- a/pymc_experimental/model_transform/basic.py +++ b/pymc_experimental/model_transform/basic.py @@ -1,4 +1,7 @@ +from typing import List, Sequence, Union + from pymc import Model +from pytensor import Variable from pytensor.graph import ancestors from pymc_experimental.utils.model_fgraph import ( @@ -8,6 +11,8 @@ model_from_fgraph, ) +ModelVariable = Union[Variable, str] + def prune_vars_detached_from_observed(model: Model) -> Model: """Prune model variables that are not related to any observed variable in the Model.""" @@ -33,3 +38,9 @@ def prune_vars_detached_from_observed(model: Model) -> Model: for node_to_remove in nodes_to_remove: fgraph.remove_node(node_to_remove) return model_from_fgraph(fgraph) + + +def parse_vars(model: Model, vars: Union[ModelVariable, Sequence[ModelVariable]]) -> List[Variable]: + if not isinstance(vars, (list, tuple)): + vars = (vars,) + return [model[var] if isinstance(var, str) else var for var in vars] diff --git a/pymc_experimental/model_transform/conditioning.py b/pymc_experimental/model_transform/conditioning.py index fb4468c87..8919adbf0 100644 --- a/pymc_experimental/model_transform/conditioning.py +++ b/pymc_experimental/model_transform/conditioning.py @@ -1,16 +1,23 @@ -from typing import Any, Dict, List, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, Union from pymc import Model +from pymc.logprob.transforms import RVTransform from pymc.pytensorf import _replace_vars_in_graphs +from pymc.util import get_transformed_name, get_untransformed_name from pytensor.tensor import TensorVariable -from pymc_experimental.model_transform.basic import prune_vars_detached_from_observed +from pymc_experimental.model_transform.basic import ( + ModelVariable, + parse_vars, + prune_vars_detached_from_observed, +) from pymc_experimental.utils.model_fgraph import ( ModelDeterministic, ModelFreeRV, extract_dims, fgraph_from_model, model_deterministic, + model_free_rv, model_from_fgraph, model_named, model_observed_rv, @@ -206,3 +213,132 @@ def do( if prune_vars: return prune_vars_detached_from_observed(model) return model + + +def change_value_transforms( + model: Model, + vars_to_transforms: Dict[ModelVariable, Union[RVTransform, None]], +) -> Model: + """Change the value variables transforms in the model + + Parameters + ---------- + model: Model + vars_to_transforms: Dict + Mapping between RVs and new transforms to be applied to the respective value variables + + Returns + ------- + new_model: Model + Model with the updated transformed value variables + + Examples + -------- + Extract untransformed space Hessian after finding transformed space MAP + + .. code-block:: python + + import pymc as pm + from pymc.distributions.transforms import logodds + from pymc_experimental.model_transform.conditioning import change_value_transforms + + with pm.Model() as base_m: + p = pm.Uniform("p", 0, 1, transform=None) + w = pm.Binomial("w", n=9, p=p, observed=6) + + with change_value_transforms(base_m, {"p": logodds}) as transformed_p: + mean_q = pm.find_MAP() + + with change_value_transforms(transformed_p, {"p": None}) as untransformed_p: + new_p = untransformed_p['p'] + std_q = ((1 / pm.find_hessian(mean_q, vars=[new_p])) ** 0.5)[0] + + print(f" Mean, Standard deviation\np {mean_q['p']:.2}, {std_q[0]:.2}") + # Mean, Standard deviation + # p 0.67, 0.16 + + """ + vars_to_transforms = { + parse_vars(model, var)[0]: transform for var, transform in vars_to_transforms.items() + } + + if set(vars_to_transforms.keys()) - set(model.free_RVs): + raise ValueError(f"All keys must be free variables in the model: {vars_to_transforms}") + + fgraph, memo = fgraph_from_model(model) + + vars_to_transforms = {memo[var]: transform for var, transform in vars_to_transforms.items()} + replacements = {} + for node in fgraph.apply_nodes: + if not isinstance(node.op, ModelFreeRV): + continue + + [dummy_rv] = node.outputs + if dummy_rv not in vars_to_transforms: + continue + + transform = vars_to_transforms[dummy_rv] + + rv, value, *dims = node.inputs + + new_value = rv.type() + try: + untransformed_name = get_untransformed_name(value.name) + except ValueError: + untransformed_name = value.name + if transform: + new_name = get_transformed_name(untransformed_name, transform) + else: + new_name = untransformed_name + new_value.name = new_name + + new_dummy_rv = model_free_rv(rv, new_value, transform, *dims) + replacements[dummy_rv] = new_dummy_rv + + toposort_replace(fgraph, tuple(replacements.items())) + return model_from_fgraph(fgraph) + + +def remove_value_transforms( + model: Model, + vars: Optional[Sequence[ModelVariable]] = None, +) -> Model: + """Remove the value variables transforms in the model + + Parameters + ---------- + model: Model + vars: Model variables, optional + Model variables for which to remove transforms. Defaults to all transformed variables + + Returns + ------- + new_model: Model + Model with the removed transformed value variables + + Examples + -------- + Extract untransformed space Hessian after finding transformed space MAP + + .. code-block:: python + + import pymc as pm + from pymc_experimental.model_transform.conditioning import remove_value_transforms + + with pm.Model() as transformed_m: + p = pm.Uniform("p", 0, 1) + w = pm.Binomial("w", n=9, p=p, observed=6) + mean_q = pm.find_MAP() + + with remove_value_transforms(transformed_m) as untransformed_m: + new_p = untransformed_m["p"] + std_q = ((1 / pm.find_hessian(mean_q, vars=[new_p])) ** 0.5)[0] + print(f" Mean, Standard deviation\np {mean_q['p']:.2}, {std_q[0]:.2}") + + # Mean, Standard deviation + # p 0.67, 0.16 + + """ + if vars is None: + vars = model.free_RVs + return change_value_transforms(model, {var: None for var in vars}) diff --git a/pymc_experimental/tests/model_transform/test_conditioning.py b/pymc_experimental/tests/model_transform/test_conditioning.py index 5d455e698..fd16162e7 100644 --- a/pymc_experimental/tests/model_transform/test_conditioning.py +++ b/pymc_experimental/tests/model_transform/test_conditioning.py @@ -2,10 +2,16 @@ import numpy as np import pymc as pm import pytest +from pymc.distributions.transforms import logodds from pymc.variational.minibatch_rv import create_minibatch_rv from pytensor import config -from pymc_experimental.model_transform.conditioning import do, observe +from pymc_experimental.model_transform.conditioning import ( + change_value_transforms, + do, + observe, + remove_value_transforms, +) def test_observe(): @@ -214,3 +220,59 @@ def test_do_prune(prune): assert set(do_m.named_vars) == {"x1", "z", "llike"} else: assert set(do_m.named_vars) == orig_named_vars + + +def test_change_value_transforms(): + with pm.Model() as base_m: + p = pm.Uniform("p", 0, 1, transform=None) + w = pm.Binomial("w", n=9, p=p, observed=6) + assert base_m.rvs_to_transforms == {p: None, w: None} + + with change_value_transforms(base_m, {"p": logodds}) as transformed_p: + new_p = transformed_p["p"] + new_w = transformed_p["w"] + assert transformed_p.rvs_to_transforms == {new_p: logodds, new_w: None} + mean_q = pm.find_MAP(progressbar=False) + + with change_value_transforms(transformed_p, {"p": None}) as untransformed_p: + new_p = untransformed_p["p"] + new_w = untransformed_p["w"] + assert untransformed_p.rvs_to_transforms == {new_p: None, new_w: None} + std_q = ((1 / pm.find_hessian(mean_q, vars=[new_p])) ** 0.5)[0] + + assert np.round(mean_q["p"], 2) == 0.67 + assert np.round(std_q[0], 2) == 0.16 + + +def test_change_value_transforms_error(): + with pm.Model() as m: + x = pm.Uniform("x", observed=5.0) + + with pytest.raises(ValueError, match="All keys must be free variables in the model"): + change_value_transforms(m, {x: logodds}) + + +def test_remove_value_transforms(): + with pm.Model() as base_m: + p = pm.Uniform("p", transform=logodds) + q = pm.Uniform("q", transform=logodds) + + new_m = remove_value_transforms(base_m) + new_p = new_m["p"] + new_q = new_m["q"] + assert new_m.rvs_to_transforms == {new_p: None, new_q: None} + + new_m = remove_value_transforms(base_m, [p, q]) + new_p = new_m["p"] + new_q = new_m["q"] + assert new_m.rvs_to_transforms == {new_p: None, new_q: None} + + new_m = remove_value_transforms(base_m, [p]) + new_p = new_m["p"] + new_q = new_m["q"] + assert new_m.rvs_to_transforms == {new_p: None, new_q: logodds} + + new_m = remove_value_transforms(base_m, ["q"]) + new_p = new_m["p"] + new_q = new_m["q"] + assert new_m.rvs_to_transforms == {new_p: logodds, new_q: None}