Skip to content

Commit

Permalink
New var_names kwarg for pm.model_to_graphviz (#5634)
Browse files Browse the repository at this point in the history
Enables positive selection of model variables to be included in the model graph.

Co-authored-by: Michael Osthege <michael.osthege@outlook.com>
  • Loading branch information
larryshamalama and michaelosthege authored May 22, 2022
1 parent 66fba38 commit d0af6b1
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 90 deletions.
168 changes: 79 additions & 89 deletions pymc/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
# limitations under the License.
import warnings

from collections import defaultdict, deque
from typing import Dict, Iterator, NewType, Optional, Set
from collections import defaultdict
from typing import Dict, Iterable, List, NewType, Optional, Set

from aesara import function
from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.basic import walk
from aesara.graph import Apply
from aesara.graph.basic import ancestors, walk
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorConstant, TensorVariable

Expand All @@ -32,85 +33,64 @@
class ModelGraph:
def __init__(self, model):
self.model = model
self.var_names = get_default_varnames(self.model.named_vars, include_transformed=False)
self._all_var_names = get_default_varnames(self.model.named_vars, include_transformed=False)
self.var_list = self.model.named_vars.values()
self.transform_map = {
v.transformed: v.name for v in self.var_list if hasattr(v, "transformed")
}
self._deterministics = None

def get_deterministics(self, var):
"""Compute the deterministic nodes of the graph, **not** including var itself."""
deterministics = []
attrs = ("transformed", "logpt")
for v in self.var_list:
if v != var and all(not hasattr(v, attr) for attr in attrs):
deterministics.append(v)
return deterministics

def _get_ancestors(self, var: TensorVariable, func) -> Set[TensorVariable]:
"""Get all ancestors of a function, doing some accounting for deterministics."""

# this contains all of the variables in the model EXCEPT var...
vars = set(self.var_list)
vars.remove(var)

blockers = set() # type: Set[TensorVariable]
retval = set() # type: Set[TensorVariable]

def _expand(node) -> Optional[Iterator[TensorVariable]]:
if node in blockers:
return None
elif node in vars:
blockers.add(node)
retval.add(node)
return None
elif node.owner:
blockers.add(node)
return reversed(node.owner.inputs)
else:
return None

list(walk(deque([func]), _expand, bfs=True))
return retval

def _filter_parents(self, var, parents) -> Set[VarName]:
"""Get direct parents of a var, as strings"""
keep = set() # type: Set[VarName]
for p in parents:
if p == var:
continue
elif p.name in self.var_names:
keep.add(p.name)
elif p in self.transform_map:
if self.transform_map[p] != var.name:
keep.add(self.transform_map[p])
else:
raise AssertionError(f"Do not know what to do with {get_var_name(p)}")
return keep

def get_parents(self, var: TensorVariable) -> Set[VarName]:
"""Get the named nodes that are direct inputs to the var"""
# TODO: Update these lines, variables no longer have a `logpt` attribute
if hasattr(var, "transformed"):
func = var.transformed.logpt
elif hasattr(var, "logpt"):
func = var.logpt
else:
func = var

parents = self._get_ancestors(var, func)
return self._filter_parents(var, parents)
def get_parent_names(self, var: TensorVariable) -> Set[VarName]:
if var.owner is None or var.owner.inputs is None:
return set()

def _expand(x):
if x.name:
return [x]
if isinstance(x.owner, Apply):
return reversed(x.owner.inputs)
return []

parents = {get_var_name(x) for x in walk(nodes=var.owner.inputs, expand=_expand) if x.name}

return parents

def vars_to_plot(self, var_names: Optional[Iterable[VarName]] = None) -> List[VarName]:
if var_names is None:
return self._all_var_names

selected_names = set(var_names)

# .copy() because sets cannot change in size during iteration
for var_name in selected_names.copy():
if var_name not in self._all_var_names:
raise ValueError(f"{var_name} is not in this model.")

for model_var in self.var_list:
if hasattr(model_var.tag, "observations"):
if model_var.tag.observations == self.model[var_name]:
selected_names.add(model_var.name)

def make_compute_graph(self) -> Dict[str, Set[VarName]]:
selected_ancestors = set(
filter(
lambda rv: rv.name in self._all_var_names,
list(ancestors([self.model[var_name] for var_name in selected_names])),
)
)

for var in selected_ancestors.copy():
if hasattr(var.tag, "observations"):
selected_ancestors.add(var.tag.observations)

# ordering of self._all_var_names is important
return [var.name for var in selected_ancestors]

def make_compute_graph(
self, var_names: Optional[Iterable[VarName]] = None
) -> Dict[VarName, Set[VarName]]:
"""Get map of var_name -> set(input var names) for the model"""
input_map = defaultdict(set) # type: Dict[str, Set[VarName]]
input_map: Dict[VarName, Set[VarName]] = defaultdict(set)

for var_name in self.var_names:
for var_name in self.vars_to_plot(var_names):
var = self.model[var_name]
key = var_name
val = self.get_parents(var)
input_map[key] = input_map[key].union(val)
parent_name = self.get_parent_names(var)
input_map[var_name] = input_map[var_name].union(parent_name)

if hasattr(var.tag, "observations"):
try:
Expand All @@ -120,6 +100,7 @@ def make_compute_graph(self) -> Dict[str, Set[VarName]]:
input_map[obs_name] = input_map[obs_name].union({var_name})
except AttributeError:
pass

return input_map

def _make_node(self, var_name, graph, *, formatting: str = "plain"):
Expand Down Expand Up @@ -168,18 +149,20 @@ def _make_node(self, var_name, graph, *, formatting: str = "plain"):
def _eval(self, var):
return function([], var, mode="FAST_COMPILE")()

def get_plates(self):
def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> Dict[str, Set[VarName]]:
"""Rough but surprisingly accurate plate detection.
Just groups by the shape of the underlying distribution. Will be wrong
if there are two plates with the same shape.
Returns
-------
dict: str -> set[str]
dict
Maps plate labels to the set of ``VarName``s inside the plate.
"""
plates = defaultdict(set)
for var_name in self.var_names:

for var_name in self.vars_to_plot(var_names):
v = self.model[var_name]
if var_name in self.model.RV_dims:
plate_label = " x ".join(
Expand All @@ -189,9 +172,10 @@ def get_plates(self):
else:
plate_label = " x ".join(map(str, self._eval(v.shape)))
plates[plate_label].add(var_name)

return plates

def make_graph(self, formatting: str = "plain"):
def make_graph(self, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain"):
"""Make graphviz Digraph of PyMC model
Returns
Expand All @@ -207,25 +191,29 @@ def make_graph(self, formatting: str = "plain"):
"\tconda install -c conda-forge python-graphviz"
)
graph = graphviz.Digraph(self.model.name)
for plate_label, var_names in self.get_plates().items():
for plate_label, all_var_names in self.get_plates(var_names).items():
if plate_label:
# must be preceded by 'cluster' to get a box around it
with graph.subgraph(name="cluster" + plate_label) as sub:
for var_name in var_names:
for var_name in all_var_names:
self._make_node(var_name, sub, formatting=formatting)
# plate label goes bottom right
sub.attr(label=plate_label, labeljust="r", labelloc="b", style="rounded")
else:
for var_name in var_names:
for var_name in all_var_names:
self._make_node(var_name, graph, formatting=formatting)

for key, values in self.make_compute_graph().items():
for value in values:
graph.edge(value.replace(":", "&"), key.replace(":", "&"))
for child, parents in self.make_compute_graph(var_names=var_names).items():
# parents is a set of rv names that precede child rv nodes
for parent in parents:
graph.edge(parent.replace(":", "&"), child.replace(":", "&"))

return graph


def model_to_graphviz(model=None, *, formatting: str = "plain"):
def model_to_graphviz(
model=None, *, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain"
):
"""Produce a graphviz Digraph from a PyMC model.
Requires graphviz, which may be installed most easily with
Expand All @@ -240,7 +228,9 @@ def model_to_graphviz(model=None, *, formatting: str = "plain"):
----------
model : pm.Model
The model to plot. Not required when called from inside a modelcontext.
formatting : str
var_names : iterable of variable names, optional
Subset of variables to be plotted that identify a subgraph with respect to the entire model graph
formatting : str, optional
one of { "plain" }
Examples
Expand Down Expand Up @@ -275,4 +265,4 @@ def model_to_graphviz(model=None, *, formatting: str = "plain"):
"Formattings other than 'plain' are currently not supported.", UserWarning, stacklevel=2
)
model = pm.modelcontext(model)
return ModelGraph(model).make_graph(formatting=formatting)
return ModelGraph(model).make_graph(var_names=var_names, formatting=formatting)
81 changes: 80 additions & 1 deletion pymc/tests/test_model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def setup_class(cls):
def test_inputs(self):
for child, parents_in_plot in self.compute_graph.items():
var = self.model[child]
parents_in_graph = self.model_graph.get_parents(var)
parents_in_graph = self.model_graph.get_parent_names(var)
if isinstance(var, SharedVariable):
# observed data also doesn't have parents in the compute graph!
# But for the visualization we like them to become descendants of the
Expand Down Expand Up @@ -183,6 +183,85 @@ def test_checks_formatting(self):
model_to_graphviz(self.model, formatting="plain_with_params")


def model_with_different_descendants():
    """Build a small model whose variables have different sets of descendants.

    Used to exercise variable selection in the model graph. The model was
    proposed by Michael in
    https://github.com/pymc-devs/pymc/pull/5634#pullrequestreview-916297509
    """
    with pm.Model() as model:
        a = pm.Normal("a")
        b = pm.Normal("b")

        # "c" descends from a and b only; it is not an ancestor of "L".
        pm.Normal("c", a * b)

        # Chain leading to the observed variable: a, b -> intermediate -> pred -> L
        summed = pm.Deterministic("intermediate", a + b)
        scaled = pm.Deterministic("pred", summed * 3)

        data = pm.ConstantData("obs", 1.75)
        pm.Normal("L", mu=1 + 0.5 * scaled, observed=data)

    return model


class TestParents:
    """Check the parent names reported by ``ModelGraph.get_parent_names``."""

    @pytest.mark.parametrize(
        "var_name, parent_names",
        [
            ("L", {"pred"}),
            ("pred", {"intermediate"}),
            ("intermediate", {"a", "b"}),
            ("c", {"a", "b"}),
            ("a", set()),
            ("b", set()),
        ],
    )
    def test_get_parent_names(self, var_name, parent_names):
        """The parents of ``var_name`` must be exactly ``parent_names``."""
        mg = ModelGraph(model_with_different_descendants())
        # The comparison must be asserted: as a bare expression its result is
        # discarded and the test would pass regardless of what is returned.
        assert mg.get_parent_names(mg.model[var_name]) == parent_names


class TestVariableSelection:
    """Check the variables and edges produced when only some names are selected."""

    @pytest.mark.parametrize(
        "var_names, vars_to_plot, compute_graph",
        [
            (["c"], ["a", "b", "c"], {"c": {"a", "b"}, "a": set(), "b": set()}),
            (
                ["L"],
                ["pred", "obs", "L", "intermediate", "a", "b"],
                {
                    "pred": {"intermediate"},
                    "obs": {"L"},
                    "L": {"pred"},
                    "intermediate": {"a", "b"},
                    "a": set(),
                    "b": set(),
                },
            ),
            # selecting the data container "obs" is expected to give the same
            # variables and edges as selecting "L"
            (
                ["obs"],
                ["pred", "obs", "L", "intermediate", "a", "b"],
                {
                    "pred": {"intermediate"},
                    "obs": {"L"},
                    "L": {"pred"},
                    "intermediate": {"a", "b"},
                    "a": set(),
                    "b": set(),
                },
            ),
            # selecting ["c", "L"] is akin to selecting the entire graph
            (
                ["c", "L"],
                ModelGraph(model_with_different_descendants()).vars_to_plot(),
                ModelGraph(model_with_different_descendants()).make_compute_graph(),
            ),
        ],
    )
    def test_subgraph(self, var_names, vars_to_plot, compute_graph):
        """Compare the selected variable names and the compute graph with the expected values."""
        model_graph = ModelGraph(model_with_different_descendants())

        # Names are compared as sets, so their ordering is ignored here.
        selected = model_graph.vars_to_plot(var_names=var_names)
        assert set(selected) == set(vars_to_plot)

        subgraph = model_graph.make_compute_graph(var_names=var_names)
        assert subgraph == compute_graph


class TestImputationModel(BaseModelGraphTest):
model_func = model_with_imputations

Expand Down

0 comments on commit d0af6b1

Please sign in to comment.