Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement utility to change value variable transforms #216

Merged
merged 2 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ Model Transformations

conditioning.do
conditioning.observe
conditioning.change_value_transforms
conditioning.remove_value_transforms


Utils
Expand Down
11 changes: 11 additions & 0 deletions pymc_experimental/model_transform/basic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import List, Sequence, Union

from pymc import Model
from pytensor import Variable
from pytensor.graph import ancestors

from pymc_experimental.utils.model_fgraph import (
Expand All @@ -8,6 +11,8 @@
model_from_fgraph,
)

ModelVariable = Union[Variable, str]


def prune_vars_detached_from_observed(model: Model) -> Model:
"""Prune model variables that are not related to any observed variable in the Model."""
Expand All @@ -33,3 +38,9 @@ def prune_vars_detached_from_observed(model: Model) -> Model:
for node_to_remove in nodes_to_remove:
fgraph.remove_node(node_to_remove)
return model_from_fgraph(fgraph)


def parse_vars(model: Model, vars: Union[ModelVariable, Sequence[ModelVariable]]) -> List[Variable]:
if not isinstance(vars, (list, tuple)):
vars = (vars,)
return [model[var] if isinstance(var, str) else var for var in vars]
148 changes: 144 additions & 4 deletions pymc_experimental/model_transform/conditioning.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
from typing import Any, Dict, List, Sequence, Union
from typing import Any, List, Mapping, Optional, Sequence, Union

from pymc import Model
from pymc.logprob.transforms import RVTransform
from pymc.pytensorf import _replace_vars_in_graphs
from pymc.util import get_transformed_name, get_untransformed_name
from pytensor.tensor import TensorVariable

from pymc_experimental.model_transform.basic import prune_vars_detached_from_observed
from pymc_experimental.model_transform.basic import (
ModelVariable,
parse_vars,
prune_vars_detached_from_observed,
)
from pymc_experimental.utils.model_fgraph import (
ModelDeterministic,
ModelFreeRV,
extract_dims,
fgraph_from_model,
model_deterministic,
model_free_rv,
model_from_fgraph,
model_named,
model_observed_rv,
Expand All @@ -19,7 +26,9 @@
from pymc_experimental.utils.pytensorf import rvs_in_graph


def observe(model: Model, vars_to_observations: Dict[Union["str", TensorVariable], Any]) -> Model:
def observe(
model: Model, vars_to_observations: Mapping[Union["str", TensorVariable], Any]
) -> Model:
"""Convert free RVs or Deterministics to observed RVs.

Parameters
Expand Down Expand Up @@ -115,7 +124,9 @@ def replacement_fn(var, inner_replacements):


def do(
model: Model, vars_to_interventions: Dict[Union["str", TensorVariable], Any], prune_vars=False
model: Model,
vars_to_interventions: Mapping[Union["str", TensorVariable], Any],
prune_vars=False,
) -> Model:
"""Replace model variables by intervention variables.

Expand Down Expand Up @@ -206,3 +217,132 @@ def do(
if prune_vars:
return prune_vars_detached_from_observed(model)
return model


def change_value_transforms(
model: Model,
vars_to_transforms: Mapping[ModelVariable, Union[RVTransform, None]],
) -> Model:
"""Change the value variables transforms in the model

Parameters
----------
model : Model
vars_to_transforms : Dict
Dictionary that maps RVs to new transforms to be applied to the respective value variables

Returns
-------
new_model : Model
Model with the updated transformed value variables

Examples
--------
Extract untransformed space Hessian after finding transformed space MAP

.. code-block:: python

import pymc as pm
from pymc.distributions.transforms import logodds
from pymc_experimental.model_transform.conditioning import change_value_transforms

with pm.Model() as base_m:
p = pm.Uniform("p", 0, 1, transform=None)
w = pm.Binomial("w", n=9, p=p, observed=6)

with change_value_transforms(base_m, {"p": logodds}) as transformed_p:
mean_q = pm.find_MAP()

with change_value_transforms(transformed_p, {"p": None}) as untransformed_p:
new_p = untransformed_p['p']
std_q = ((1 / pm.find_hessian(mean_q, vars=[new_p])) ** 0.5)[0]

print(f" Mean, Standard deviation\np {mean_q['p']:.2}, {std_q[0]:.2}")
# Mean, Standard deviation
# p 0.67, 0.16

"""
vars_to_transforms = {
parse_vars(model, var)[0]: transform for var, transform in vars_to_transforms.items()
}

if set(vars_to_transforms.keys()) - set(model.free_RVs):
raise ValueError(f"All keys must be free variables in the model: {model.free_RVs}")

fgraph, memo = fgraph_from_model(model)

vars_to_transforms = {memo[var]: transform for var, transform in vars_to_transforms.items()}
replacements = {}
for node in fgraph.apply_nodes:
if not isinstance(node.op, ModelFreeRV):
continue

[dummy_rv] = node.outputs
if dummy_rv not in vars_to_transforms:
continue

transform = vars_to_transforms[dummy_rv]

rv, value, *dims = node.inputs

new_value = rv.type()
try:
untransformed_name = get_untransformed_name(value.name)
except ValueError:
untransformed_name = value.name
if transform:
new_name = get_transformed_name(untransformed_name, transform)
else:
new_name = untransformed_name
new_value.name = new_name

new_dummy_rv = model_free_rv(rv, new_value, transform, *dims)
replacements[dummy_rv] = new_dummy_rv

toposort_replace(fgraph, tuple(replacements.items()))
return model_from_fgraph(fgraph)


def remove_value_transforms(
model: Model,
vars: Optional[Sequence[ModelVariable]] = None,
) -> Model:
"""Remove the value variables transforms in the model

Parameters
----------
model : Model
vars : Model variables, optional
Model variables for which to remove transforms. Defaults to all transformed variables

Returns
-------
new_model : Model
Model with the removed transformed value variables

Examples
--------
Extract untransformed space Hessian after finding transformed space MAP

.. code-block:: python

import pymc as pm
from pymc_experimental.model_transform.conditioning import remove_value_transforms

with pm.Model() as transformed_m:
p = pm.Uniform("p", 0, 1)
w = pm.Binomial("w", n=9, p=p, observed=6)
mean_q = pm.find_MAP()

with remove_value_transforms(transformed_m) as untransformed_m:
new_p = untransformed_m["p"]
std_q = ((1 / pm.find_hessian(mean_q, vars=[new_p])) ** 0.5)[0]
print(f" Mean, Standard deviation\np {mean_q['p']:.2}, {std_q[0]:.2}")

# Mean, Standard deviation
# p 0.67, 0.16

"""
if vars is None:
vars = model.free_RVs
return change_value_transforms(model, {var: None for var in vars})
65 changes: 64 additions & 1 deletion pymc_experimental/tests/model_transform/test_conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@
import numpy as np
import pymc as pm
import pytest
from pymc.distributions.transforms import logodds
from pymc.variational.minibatch_rv import create_minibatch_rv
from pytensor import config

from pymc_experimental.model_transform.conditioning import do, observe
from pymc_experimental.model_transform.conditioning import (
change_value_transforms,
do,
observe,
remove_value_transforms,
)


def test_observe():
Expand Down Expand Up @@ -214,3 +220,60 @@ def test_do_prune(prune):
assert set(do_m.named_vars) == {"x1", "z", "llike"}
else:
assert set(do_m.named_vars) == orig_named_vars


def test_change_value_transforms():
with pm.Model() as base_m:
p = pm.Uniform("p", 0, 1, transform=None)
w = pm.Binomial("w", n=9, p=p, observed=6)
assert base_m.rvs_to_transforms[p] is None
assert base_m.rvs_to_values[p].name == "p"

with change_value_transforms(base_m, {"p": logodds}) as transformed_p:
new_p = transformed_p["p"]
assert transformed_p.rvs_to_transforms[new_p] == logodds
assert transformed_p.rvs_to_values[new_p].name == "p_logodds__"
mean_q = pm.find_MAP(progressbar=False)

with change_value_transforms(transformed_p, {"p": None}) as untransformed_p:
new_p = untransformed_p["p"]
assert untransformed_p.rvs_to_transforms[new_p] is None
assert untransformed_p.rvs_to_values[new_p].name == "p"
std_q = ((1 / pm.find_hessian(mean_q, vars=[new_p])) ** 0.5)[0]

np.testing.assert_allclose(np.round(mean_q["p"], 2), 0.67)
np.testing.assert_allclose(np.round(std_q[0], 2), 0.16)


def test_change_value_transforms_error():
with pm.Model() as m:
x = pm.Uniform("x", observed=5.0)

with pytest.raises(ValueError, match="All keys must be free variables in the model"):
change_value_transforms(m, {x: logodds})


def test_remove_value_transforms():
with pm.Model() as base_m:
p = pm.Uniform("p", transform=logodds)
q = pm.Uniform("q", transform=logodds)

new_m = remove_value_transforms(base_m)
new_p = new_m["p"]
new_q = new_m["q"]
assert new_m.rvs_to_transforms == {new_p: None, new_q: None}

new_m = remove_value_transforms(base_m, [p, q])
new_p = new_m["p"]
new_q = new_m["q"]
assert new_m.rvs_to_transforms == {new_p: None, new_q: None}

new_m = remove_value_transforms(base_m, [p])
new_p = new_m["p"]
new_q = new_m["q"]
assert new_m.rvs_to_transforms == {new_p: None, new_q: logodds}

new_m = remove_value_transforms(base_m, ["q"])
new_p = new_m["p"]
new_q = new_m["q"]
assert new_m.rvs_to_transforms == {new_p: logodds, new_q: None}