diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index a540918ff7..60a2fbc210 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -73,6 +73,7 @@ jobs:
             pymc3/tests/test_gp.py
             pymc3/tests/test_model.py
             pymc3/tests/test_model_func.py
+            pymc3/tests/test_model_graph.py
             pymc3/tests/test_ode.py
             pymc3/tests/test_posdef_sym.py
             pymc3/tests/test_quadpotential.py
@@ -149,6 +150,7 @@ jobs:
             pymc3/tests/test_model.py
             pymc3/tests/test_model_func.py
             pymc3/tests/test_modelcontext.py
+            pymc3/tests/test_model_graph.py
             pymc3/tests/test_pickling.py
 
       fail-fast: false
diff --git a/pymc3/model_graph.py b/pymc3/model_graph.py
index 1602296066..1d45b341f6 100644
--- a/pymc3/model_graph.py
+++ b/pymc3/model_graph.py
@@ -11,8 +11,9 @@
 #   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #   See the License for the specific language governing permissions and
 #   limitations under the License.
+import warnings
 
-from collections import deque
+from collections import defaultdict, deque
 from typing import Dict, Iterator, NewType, Optional, Set
 
 from aesara.compile.sharedvalue import SharedVariable
@@ -101,23 +102,20 @@ def get_parents(self, var: TensorVariable) -> Set[VarName]:
 
     def make_compute_graph(self) -> Dict[str, Set[VarName]]:
         """Get map of var_name -> set(input var names) for the model"""
-        input_map = {}  # type: Dict[str, Set[VarName]]
-
-        def update_input_map(key: str, val: Set[VarName]):
-            if key in input_map:
-                input_map[key] = input_map[key].union(val)
-            else:
-                input_map[key] = val
+        input_map = defaultdict(set)  # type: Dict[str, Set[VarName]]
 
         for var_name in self.var_names:
             var = self.model[var_name]
-            update_input_map(var_name, self.get_parents(var))
+            key = var_name
+            val = self.get_parents(var)
+            input_map[key] = input_map[key].union(val)
+
             if hasattr(var.tag, "observations"):
                 try:
                     obs_name = var.tag.observations.name
                     if obs_name:
                         input_map[var_name] = input_map[var_name].difference({obs_name})
-                        update_input_map(obs_name, {var_name})
+                        input_map[obs_name] = input_map[obs_name].union({var_name})
                 except AttributeError:
                     pass
         return input_map
@@ -126,30 +124,40 @@ def _make_node(self, var_name, graph, *, formatting: str = "plain"):
         """Attaches the given variable to a graphviz Digraph"""
         v = self.model[var_name]
 
-        # styling for node
-        attrs = {}
-        if v.owner and isinstance(v.owner.op, RandomVariable) and hasattr(v.tag, "observations"):
-            attrs["style"] = "filled"
-
-        # make Data be roundtangle, instead of rectangle
-        if isinstance(v, SharedVariable):
-            attrs["style"] = "rounded, filled"
-
-        # determine the shape for this node (default (Distribution) is ellipse)
-        if v in self.model.potentials:
-            attrs["shape"] = "octagon"
-        elif isinstance(v, SharedVariable) or not hasattr(v, "distribution"):
-            # shared variables and Deterministic represented by a box
-            attrs["shape"] = "box"
+        shape = None
+        style = None
+        label = str(v)
 
         if v in self.model.potentials:
+            shape = "octagon"
+            style = "filled"
             label = f"{var_name}\n~\nPotential"
         elif isinstance(v, SharedVariable):
+            shape = "box"
+            style = "rounded, filled"
             label = f"{var_name}\n~\nData"
+        elif v.owner and isinstance(v.owner.op, RandomVariable):
+            shape = "ellipse"
+            if hasattr(v.tag, "observations"):
+                # observed RV
+                style = "filled"
+            symbol = v.owner.op.__class__.__name__
+            # strip("RV") removes any leading/trailing R or V characters; drop only the suffix
+            if symbol.endswith("RV"):
+                symbol = symbol[:-2]
+            label = f"{var_name}\n~\n{symbol}"
         else:
-            label = v._str_repr(formatting=formatting).replace(" ~ ", "\n~\n")
+            shape = "box"
+            style = None
+            label = f"{var_name}\n~\nDeterministic"
+
+        kwargs = {
+            "shape": shape,
+            "style": style,
+            "label": label,
+        }
 
-        graph.node(var_name.replace(":", "&"), label, **attrs)
+        graph.node(var_name.replace(":", "&"), **kwargs)
 
     def get_plates(self):
         """Rough but surprisingly accurate plate detection.
@@ -161,26 +169,17 @@ def get_plates(self):
         -------
         dict: str -> set[str]
         """
-        plates = {}
+        plates = defaultdict(set)
         for var_name in self.var_names:
             v = self.model[var_name]
-            if hasattr(v, "observations"):
-                try:
-                    # To get shape of _observed_ data container `pm.Data`
-                    # (wrapper for aesara.SharedVariable) we evaluate it.
-                    shape = tuple(v.observations.shape.eval())
-                except AttributeError:
-                    shape = v.observations.shape
-            # XXX: This needs to be refactored
-            # elif hasattr(v, "dshape"):
-            #     shape = v.dshape
+            if var_name in self.model.RV_dims:
+                plate_label = " x ".join(
+                    f"{d} ({self.model.dim_lengths[d].eval()})"
+                    for d in self.model.RV_dims[var_name]
+                )
             else:
-                shape = v.tag.test_value.shape
-            if shape == (1,):
-                shape = tuple()
-            if shape not in plates:
-                plates[shape] = set()
-            plates[shape].add(var_name)
+                plate_label = " x ".join(map(str, v.shape.eval()))
+            plates[plate_label].add(var_name)
         return plates
 
     def make_graph(self, formatting: str = "plain"):
@@ -199,17 +198,14 @@ def make_graph(self, formatting: str = "plain"):
                 "\tconda install -c conda-forge python-graphviz"
             )
         graph = graphviz.Digraph(self.model.name)
-        for shape, var_names in self.get_plates().items():
-            if isinstance(shape, SharedVariable):
-                shape = shape.eval()
-            label = " x ".join(map("{:,d}".format, shape))
-            if label:
+        for plate_label, var_names in self.get_plates().items():
+            if plate_label:
                 # must be preceded by 'cluster' to get a box around it
-                with graph.subgraph(name="cluster" + label) as sub:
+                with graph.subgraph(name="cluster" + plate_label) as sub:
                     for var_name in var_names:
                         self._make_node(var_name, sub, formatting=formatting)
                     # plate label goes bottom right
-                    sub.attr(label=label, labeljust="r", labelloc="b", style="rounded")
+                    sub.attr(label=plate_label, labeljust="r", labelloc="b", style="rounded")
             else:
                 for var_name in var_names:
                     self._make_node(var_name, graph, formatting=formatting)
@@ -236,9 +232,13 @@ def model_to_graphviz(model=None, *, formatting: str = "plain"):
     model : pm.Model
         The model to plot. Not required when called from inside a modelcontext.
     formatting : str
-        one of { "plain", "plain_with_params" }
+        one of { "plain" }
     """
     if not "plain" in formatting:
         raise ValueError(f"Unsupported formatting for graph nodes: '{formatting}'. See docstring.")
+    if formatting != "plain":
+        warnings.warn(
+            "Formattings other than 'plain' are currently not supported.", UserWarning, stacklevel=2
+        )
     model = pm.modelcontext(model)
     return ModelGraph(model).make_graph(formatting=formatting)
diff --git a/pymc3/tests/test_model_graph.py b/pymc3/tests/test_model_graph.py
index b221f2fb2a..6dccdb1cd4 100644
--- a/pymc3/tests/test_model_graph.py
+++ b/pymc3/tests/test_model_graph.py
@@ -11,17 +11,17 @@
 #   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #   See the License for the specific language governing permissions and
 #   limitations under the License.
-import aesara as th
+import aesara
 import numpy as np
 import pytest
 
+from aesara.compile.sharedvalue import SharedVariable
+
 import pymc3 as pm
 
 from pymc3.model_graph import ModelGraph, model_to_graphviz
 from pymc3.tests.helpers import SeededTest
 
-pytestmark = pytest.mark.xfail(reason="ModelGraph not refactored yet")
-
 
 def radon_model():
     """Similar in shape to the Radon model"""
@@ -30,9 +30,6 @@ def radon_model():
     uranium = np.random.normal(-0.1, 0.4, size=n_homes)
     xbar = np.random.normal(1, 0.1, size=n_homes)
     floor_measure = np.random.randint(0, 2, size=n_homes)
-    log_radon = np.random.normal(1, 1, size=n_homes)
-
-    floor_measure = th.shared(floor_measure)
 
     d, r = divmod(919, 85)
     county = np.hstack((np.tile(np.arange(counties, dtype=int), d), np.arange(r)))
@@ -44,10 +41,16 @@ def radon_model():
         a = pm.Deterministic("a", mu_a + eps_a[county])
         b = pm.Normal("b", mu=0.0, sigma=1e15)
         sigma_y = pm.Uniform("sigma_y", lower=0, upper=100)
-        y_hat = a + b * floor_measure
+
+        # Anonymous SharedVariables don't show up
+        floor_measure = aesara.shared(floor_measure)
+        floor_measure_offset = pm.Data("floor_measure_offset", 1)
+        y_hat = a + b * floor_measure + floor_measure_offset
+        log_radon = pm.Data("log_radon", np.random.normal(1, 1, size=n_homes))
         y_like = pm.Normal("y_like", mu=y_hat, sigma=sigma_y, observed=log_radon)
 
     compute_graph = {
+        # variable_name : set of named parents in the graph
         "sigma_a": set(),
         "gamma": set(),
         "mu_a": {"gamma"},
@@ -55,31 +58,104 @@ def radon_model():
         "a": {"mu_a", "eps_a"},
         "b": set(),
         "sigma_y": set(),
-        "y_like": {"a", "b", "sigma_y"},
+        "y_like": {"a", "b", "sigma_y", "floor_measure_offset"},
+        "floor_measure_offset": set(),
+        # observed data don't have parents in the model graph, but are shown as descendants
+        # of the model variables that the observations belong to:
+        "log_radon": {"y_like"},
     }
     plates = {
-        (): {"b", "sigma_a", "sigma_y"},
-        (3,): {"gamma"},
-        (85,): {"eps_a"},
-        (919,): {"a", "mu_a", "y_like"},
+        "": {"b", "sigma_a", "sigma_y", "floor_measure_offset"},
+        "3": {"gamma"},
+        "85": {"eps_a"},
+        "919": {"a", "mu_a", "y_like", "log_radon"},
     }
     return model, compute_graph, plates
 
 
-class TestSimpleModel(SeededTest):
+def model_with_imputations():
+    """The example from https://github.com/pymc-devs/pymc3/issues/4043"""
+    x = np.random.randn(10) + 10.0
+    x = np.concatenate([x, [np.nan], [np.nan]])
+    x = np.ma.masked_array(x, np.isnan(x))
+
+    with pm.Model() as model:
+        a = pm.Normal("a")
+        pm.Normal("L", a, 1.0, observed=x)
+
+    compute_graph = {
+        "a": set(),
+        "L_missing": {"a"},
+        "L_observed": {"a"},
+        "L": {"L_missing", "L_observed"},
+    }
+    plates = {
+        "": {"a"},
+        "2": {"L_missing"},
+        "10": {"L_observed"},
+        "12": {"L"},
+    }
+    return model, compute_graph, plates
+
+
+def model_with_dims():
+    with pm.Model(coords={"city": ["Aachen", "Maastricht", "London", "Bergheim"]}) as pmodel:
+        economics = pm.Uniform("economics", lower=-1, upper=1, shape=(1,))
+
+        population = pm.HalfNormal("population", sd=5, dims=("city"))
+
+        time = pm.Data("year", [2014, 2015, 2016], dims="year")
+
+        n = pm.Deterministic(
+            "tax revenue", economics * population[None, :] * time[:, None], dims=("year", "city")
+        )
+
+        yobs = pm.Data("observed", np.ones((3, 4)))
+        L = pm.Normal("L", n, observed=yobs)
+
+    compute_graph = {
+        "economics": set(),
+        "population": set(),
+        "year": set(),
+        "tax revenue": {"economics", "population", "year"},
+        "L": {"tax revenue"},
+        "observed": {"L"},
+    }
+    plates = {
+        "1": {"economics"},
+        "city (4)": {"population"},
+        "year (3)": {"year"},
+        "year (3) x city (4)": {"tax revenue"},
+        "3 x 4": {"L", "observed"},
+    }
+
+    return pmodel, compute_graph, plates
+
+
+class BaseModelGraphTest(SeededTest):
+    model_func = None
+
     @classmethod
     def setup_class(cls):
-        cls.model, cls.compute_graph, cls.plates = radon_model()
+        cls.model, cls.compute_graph, cls.plates = cls.model_func()
         cls.model_graph = ModelGraph(cls.model)
 
     def test_inputs(self):
-        for child, parents in self.compute_graph.items():
+        for child, parents_in_plot in self.compute_graph.items():
             var = self.model[child]
-            found_parents = self.model_graph.get_parents(var)
-            assert found_parents == parents
+            parents_in_graph = self.model_graph.get_parents(var)
+            if isinstance(var, SharedVariable):
+                # observed data also doesn't have parents in the compute graph!
+                # But for the visualization we like them to become descendants of the
+                # RVs that these observations belong to.
+                assert not parents_in_graph
+            else:
+                assert parents_in_plot == parents_in_graph
 
     def test_compute_graph(self):
-        assert self.compute_graph == self.model_graph.make_compute_graph()
+        expected = self.compute_graph
+        actual = self.model_graph.make_compute_graph()
+        assert actual == expected
 
     def test_plates(self):
         assert self.plates == self.model_graph.get_plates()
@@ -93,3 +169,23 @@ def test_graphviz(self):
         g = model_to_graphviz(self.model)
         for key in self.compute_graph:
             assert key in g.source
+
+
+class TestRadonModel(BaseModelGraphTest):
+    model_func = radon_model
+
+    def test_checks_formatting(self):
+        with pytest.warns(None):
+            model_to_graphviz(self.model, formatting="plain")
+        with pytest.raises(ValueError, match="Unsupported formatting"):
+            model_to_graphviz(self.model, formatting="latex")
+        with pytest.warns(UserWarning, match="currently not supported"):
+            model_to_graphviz(self.model, formatting="plain_with_params")
+
+
+class TestImputationModel(BaseModelGraphTest):
+    model_func = model_with_imputations
+
+
+class TestModelWithDims(BaseModelGraphTest):
+    model_func = model_with_dims