Skip to content

Refactor ModelGraph for v4 #4818

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ jobs:
pymc3/tests/test_gp.py
pymc3/tests/test_model.py
pymc3/tests/test_model_func.py
pymc3/tests/test_model_graph.py
pymc3/tests/test_ode.py
pymc3/tests/test_posdef_sym.py
pymc3/tests/test_quadpotential.py
Expand Down Expand Up @@ -149,6 +150,7 @@ jobs:
pymc3/tests/test_model.py
pymc3/tests/test_model_func.py
pymc3/tests/test_modelcontext.py
pymc3/tests/test_model_graph.py
pymc3/tests/test_pickling.py

fail-fast: false
Expand Down
104 changes: 52 additions & 52 deletions pymc3/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

from collections import deque
from collections import defaultdict, deque
from typing import Dict, Iterator, NewType, Optional, Set

from aesara.compile.sharedvalue import SharedVariable
Expand Down Expand Up @@ -101,23 +102,20 @@ def get_parents(self, var: TensorVariable) -> Set[VarName]:

def make_compute_graph(self) -> Dict[str, Set[VarName]]:
"""Get map of var_name -> set(input var names) for the model"""
input_map = {} # type: Dict[str, Set[VarName]]

def update_input_map(key: str, val: Set[VarName]):
if key in input_map:
input_map[key] = input_map[key].union(val)
else:
input_map[key] = val
input_map = defaultdict(set) # type: Dict[str, Set[VarName]]

for var_name in self.var_names:
var = self.model[var_name]
update_input_map(var_name, self.get_parents(var))
key = var_name
val = self.get_parents(var)
input_map[key] = input_map[key].union(val)

if hasattr(var.tag, "observations"):
try:
obs_name = var.tag.observations.name
if obs_name:
input_map[var_name] = input_map[var_name].difference({obs_name})
update_input_map(obs_name, {var_name})
input_map[obs_name] = input_map[obs_name].union({var_name})
except AttributeError:
pass
return input_map
Expand All @@ -126,30 +124,40 @@ def _make_node(self, var_name, graph, *, formatting: str = "plain"):
"""Attaches the given variable to a graphviz Digraph"""
v = self.model[var_name]

# styling for node
attrs = {}
if v.owner and isinstance(v.owner.op, RandomVariable) and hasattr(v.tag, "observations"):
attrs["style"] = "filled"

# make Data be roundtangle, instead of rectangle
if isinstance(v, SharedVariable):
attrs["style"] = "rounded, filled"

# determine the shape for this node (default (Distribution) is ellipse)
if v in self.model.potentials:
attrs["shape"] = "octagon"
elif isinstance(v, SharedVariable) or not hasattr(v, "distribution"):
# shared variables and Deterministic represented by a box
attrs["shape"] = "box"
shape = None
style = None
label = str(v)

if v in self.model.potentials:
shape = "octagon"
style = "filled"
label = f"{var_name}\n~\nPotential"
elif isinstance(v, SharedVariable):
shape = "box"
style = "rounded, filled"
label = f"{var_name}\n~\nData"
elif v.owner and isinstance(v.owner.op, RandomVariable):
shape = "ellipse"
if hasattr(v.tag, "observations"):
# observed RV
style = "filled"
else:
shape = "ellipse"
syle = None
symbol = v.owner.op.__class__.__name__.strip("RV")
label = f"{var_name}\n~\n{symbol}"
else:
label = v._str_repr(formatting=formatting).replace(" ~ ", "\n~\n")
shape = "box"
style = None
label = f"{var_name}\n~\nDeterministic"

kwargs = {
"shape": shape,
"style": style,
"label": label,
}

graph.node(var_name.replace(":", "&"), label, **attrs)
graph.node(var_name.replace(":", "&"), **kwargs)

def get_plates(self):
"""Rough but surprisingly accurate plate detection.
Expand All @@ -161,26 +169,17 @@ def get_plates(self):
-------
dict: str -> set[str]
"""
plates = {}
plates = defaultdict(set)
for var_name in self.var_names:
v = self.model[var_name]
if hasattr(v, "observations"):
try:
# To get shape of _observed_ data container `pm.Data`
# (wrapper for aesara.SharedVariable) we evaluate it.
shape = tuple(v.observations.shape.eval())
except AttributeError:
shape = v.observations.shape
# XXX: This needs to be refactored
# elif hasattr(v, "dshape"):
# shape = v.dshape
if var_name in self.model.RV_dims:
plate_label = " x ".join(
f"{d} ({self.model.dim_lengths[d].eval()})"
for d in self.model.RV_dims[var_name]
)
else:
shape = v.tag.test_value.shape
if shape == (1,):
shape = tuple()
if shape not in plates:
plates[shape] = set()
plates[shape].add(var_name)
plate_label = " x ".join(map(str, v.shape.eval()))
plates[plate_label].add(var_name)
return plates

def make_graph(self, formatting: str = "plain"):
Expand All @@ -199,17 +198,14 @@ def make_graph(self, formatting: str = "plain"):
"\tconda install -c conda-forge python-graphviz"
)
graph = graphviz.Digraph(self.model.name)
for shape, var_names in self.get_plates().items():
if isinstance(shape, SharedVariable):
shape = shape.eval()
label = " x ".join(map("{:,d}".format, shape))
if label:
for plate_label, var_names in self.get_plates().items():
if plate_label:
# must be preceded by 'cluster' to get a box around it
with graph.subgraph(name="cluster" + label) as sub:
with graph.subgraph(name="cluster" + plate_label) as sub:
for var_name in var_names:
self._make_node(var_name, sub, formatting=formatting)
# plate label goes bottom right
sub.attr(label=label, labeljust="r", labelloc="b", style="rounded")
sub.attr(label=plate_label, labeljust="r", labelloc="b", style="rounded")
else:
for var_name in var_names:
self._make_node(var_name, graph, formatting=formatting)
Expand All @@ -236,9 +232,13 @@ def model_to_graphviz(model=None, *, formatting: str = "plain"):
model : pm.Model
The model to plot. Not required when called from inside a modelcontext.
formatting : str
one of { "plain", "plain_with_params" }
one of { "plain" }
"""
if not "plain" in formatting:
raise ValueError(f"Unsupported formatting for graph nodes: '{formatting}'. See docstring.")
if formatting != "plain":
warnings.warn(
"Formattings other than 'plain' are currently not supported.", UserWarning, stacklevel=2
)
model = pm.modelcontext(model)
return ModelGraph(model).make_graph(formatting=formatting)
132 changes: 114 additions & 18 deletions pymc3/tests/test_model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import aesara as th
import aesara
import numpy as np
import pytest

from aesara.compile.sharedvalue import SharedVariable

import pymc3 as pm

from pymc3.model_graph import ModelGraph, model_to_graphviz
from pymc3.tests.helpers import SeededTest

pytestmark = pytest.mark.xfail(reason="ModelGraph not refactored yet")


def radon_model():
"""Similar in shape to the Radon model"""
Expand All @@ -30,9 +30,6 @@ def radon_model():
uranium = np.random.normal(-0.1, 0.4, size=n_homes)
xbar = np.random.normal(1, 0.1, size=n_homes)
floor_measure = np.random.randint(0, 2, size=n_homes)
log_radon = np.random.normal(1, 1, size=n_homes)

floor_measure = th.shared(floor_measure)

d, r = divmod(919, 85)
county = np.hstack((np.tile(np.arange(counties, dtype=int), d), np.arange(r)))
Expand All @@ -44,42 +41,121 @@ def radon_model():
a = pm.Deterministic("a", mu_a + eps_a[county])
b = pm.Normal("b", mu=0.0, sigma=1e15)
sigma_y = pm.Uniform("sigma_y", lower=0, upper=100)
y_hat = a + b * floor_measure

# Anonymous SharedVariables don't show up
floor_measure = aesara.shared(floor_measure)
floor_measure_offset = pm.Data("floor_measure_offset", 1)
y_hat = a + b * floor_measure + floor_measure_offset
log_radon = pm.Data("log_radon", np.random.normal(1, 1, size=n_homes))
y_like = pm.Normal("y_like", mu=y_hat, sigma=sigma_y, observed=log_radon)

compute_graph = {
# variable_name : set of named parents in the graph
"sigma_a": set(),
"gamma": set(),
"mu_a": {"gamma"},
"eps_a": {"sigma_a"},
"a": {"mu_a", "eps_a"},
"b": set(),
"sigma_y": set(),
"y_like": {"a", "b", "sigma_y"},
"y_like": {"a", "b", "sigma_y", "floor_measure_offset"},
"floor_measure_offset": set(),
# observed data don't have parents in the model graph, but are shown as descendants
# of the model variables that the observations belong to:
"log_radon": {"y_like"},
}
plates = {
(): {"b", "sigma_a", "sigma_y"},
(3,): {"gamma"},
(85,): {"eps_a"},
(919,): {"a", "mu_a", "y_like"},
"": {"b", "sigma_a", "sigma_y", "floor_measure_offset"},
"3": {"gamma"},
"85": {"eps_a"},
"919": {"a", "mu_a", "y_like", "log_radon"},
}
return model, compute_graph, plates


class TestSimpleModel(SeededTest):
def model_with_imputations():
    """Build the example from https://github.com/pymc-devs/pymc3/issues/4043.

    The observed array holds ten finite values followed by two NaNs, and is
    passed to the likelihood as a NumPy masked array (NaN entries masked).

    Returns
    -------
    tuple
        ``(model, compute_graph, plates)``: the model, the expected mapping
        of variable name -> set of parent names, and the expected mapping of
        plate label -> set of variable names.
    """
    # Ten draws centred around 10, then two NaN entries that get masked out.
    finite_part = np.random.randn(10) + 10.0
    with_gaps = np.concatenate([finite_part, [np.nan], [np.nan]])
    masked_obs = np.ma.masked_array(with_gaps, np.isnan(with_gaps))

    with pm.Model() as model:
        a = pm.Normal("a")
        pm.Normal("L", a, 1.0, observed=masked_obs)

    expected_parents = {
        "a": set(),
        "L_missing": {"a"},
        "L_observed": {"a"},
        "L": {"L_missing", "L_observed"},
    }
    # Plate labels are strings; the empty label holds variables drawn outside any plate.
    expected_plates = {
        "": {"a"},
        "2": {"L_missing"},
        "10": {"L_observed"},
        "12": {"L"},
    }
    return model, expected_parents, expected_plates


def model_with_dims():
    """Build a model whose variables mix named dims and plain shapes.

    Returns
    -------
    tuple
        ``(pmodel, compute_graph, plates)``: the model, the expected mapping
        of variable name -> set of parent names, and the expected mapping of
        plate label -> set of variable names.
    """
    coords = {"city": ["Aachen", "Maastricht", "London", "Bergheim"]}
    with pm.Model(coords=coords) as pmodel:
        economics = pm.Uniform("economics", lower=-1, upper=1, shape=(1,))

        # The dims value is the plain string "city" (not a 1-tuple).
        population = pm.HalfNormal("population", sd=5, dims="city")

        time = pm.Data("year", [2014, 2015, 2016], dims="year")

        # Broadcast to (year, city) before wrapping in a Deterministic.
        revenue = economics * population[None, :] * time[:, None]
        n = pm.Deterministic("tax revenue", revenue, dims=("year", "city"))

        yobs = pm.Data("observed", np.ones((3, 4)))
        pm.Normal("L", n, observed=yobs)

    expected_parents = {
        "economics": set(),
        "population": set(),
        "year": set(),
        "tax revenue": {"economics", "population", "year"},
        "L": {"tax revenue"},
        "observed": {"L"},
    }
    # Variables with named dims expect "name (length)" labels; the others
    # expect labels built from their shape.
    expected_plates = {
        "1": {"economics"},
        "city (4)": {"population"},
        "year (3)": {"year"},
        "year (3) x city (4)": {"tax revenue"},
        "3 x 4": {"L", "observed"},
    }

    return pmodel, expected_parents, expected_plates


class BaseModelGraphTest(SeededTest):
model_func = None

@classmethod
def setup_class(cls):
cls.model, cls.compute_graph, cls.plates = radon_model()
cls.model, cls.compute_graph, cls.plates = cls.model_func()
cls.model_graph = ModelGraph(cls.model)

def test_inputs(self):
for child, parents in self.compute_graph.items():
for child, parents_in_plot in self.compute_graph.items():
var = self.model[child]
found_parents = self.model_graph.get_parents(var)
assert found_parents == parents
parents_in_graph = self.model_graph.get_parents(var)
if isinstance(var, SharedVariable):
# observed data also doesn't have parents in the compute graph!
# But for the visualization we like them to become descendants of the
# RVs that these observations belong to.
assert not parents_in_graph
else:
assert parents_in_plot == parents_in_graph

def test_compute_graph(self):
assert self.compute_graph == self.model_graph.make_compute_graph()
expected = self.compute_graph
actual = self.model_graph.make_compute_graph()
assert actual == expected

def test_plates(self):
assert self.plates == self.model_graph.get_plates()
Expand All @@ -93,3 +169,23 @@ def test_graphviz(self):
g = model_to_graphviz(self.model)
for key in self.compute_graph:
assert key in g.source


class TestRadonModel(BaseModelGraphTest):
    """Run the shared model-graph checks against ``radon_model``."""

    model_func = radon_model

    def test_checks_formatting(self):
        """Check how ``model_to_graphviz`` reacts to each ``formatting`` value.

        * ``"plain"`` must not trigger the "not supported" warning,
        * a value that does not contain ``"plain"`` raises ``ValueError``,
        * a value that contains but is not equal to ``"plain"`` emits a
          ``UserWarning``.
        """
        # Local import: the module's top-level import block is left untouched.
        import warnings

        # `pytest.warns(None)` never asserted that no warning was raised, and
        # it is deprecated in pytest 7 and an error in pytest 8. Record the
        # warnings explicitly and check that the formatting warning is absent;
        # unrelated warnings from dependencies are tolerated.
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            model_to_graphviz(self.model, formatting="plain")
        assert not any("currently not supported" in str(w.message) for w in caught)

        with pytest.raises(ValueError, match="Unsupported formatting"):
            model_to_graphviz(self.model, formatting="latex")
        with pytest.warns(UserWarning, match="currently not supported"):
            model_to_graphviz(self.model, formatting="plain_with_params")


class TestImputationModel(BaseModelGraphTest):
    """Run the shared model-graph checks against ``model_with_imputations``."""

    model_func = model_with_imputations


class TestModelWithDims(BaseModelGraphTest):
    """Run the shared model-graph checks against ``model_with_dims``."""

    model_func = model_with_dims