From d005f54dfd826af19b6d851c95b002ab750276a7 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sun, 27 Nov 2022 12:57:10 +0100 Subject: [PATCH 1/2] Extract `ModelGraph._eval` to a function --- pymc/model_graph.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 3bc272a91fb..95abe65cc84 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -32,6 +32,17 @@ VarName = NewType("VarName", str) +__all__ = ( + "ModelGraph", + "model_to_graphviz", + "model_to_networkx", +) + + +def fast_eval(var): + return function([], var, mode="FAST_COMPILE")() + + class ModelGraph: def __init__(self, model): self.model = model @@ -183,9 +194,6 @@ def _make_node(self, var_name, graph, *, nx=False, cluster=False, formatting: st else: graph.node(var_name.replace(":", "&"), **kwargs) - def _eval(self, var): - return function([], var, mode="FAST_COMPILE")() - def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> Dict[str, Set[VarName]]: """Rough but surprisingly accurate plate detection. @@ -202,11 +210,11 @@ def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> Dict[str, v = self.model[var_name] if var_name in self.model.named_vars_to_dims: plate_label = " x ".join( - f"{d} ({self._eval(self.model.dim_lengths[d])})" + f"{d} ({fast_eval(self.model.dim_lengths[d])})" for d in self.model.named_vars_to_dims[var_name] ) else: - plate_label = " x ".join(map(str, self._eval(v.shape))) + plate_label = " x ".join(map(str, fast_eval(v.shape))) plates[plate_label].add(var_name) return plates From 01fb0a823fc31483d66da2e394ecdf12c626c1c8 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sun, 27 Nov 2022 13:50:17 +0100 Subject: [PATCH 2/2] More robustness against unlabeled `dims` entries Closes #6335 --- pymc/model.py | 4 +++- pymc/model_graph.py | 28 +++++++++++++++++++++------- pymc/tests/test_model.py | 24 ++++++++++++++++++++++++ pymc/tests/test_model_graph.py | 13 +++++++++++++ 4 files changed, 61 insertions(+), 8 deletions(-) diff --git a/pymc/model.py b/pymc/model.py index efb50b32bce..14fc75d7ead 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -1498,6 +1498,8 @@ def add_named_variable(self, var, dims: Optional[Tuple[Union[str, None], ...]] = This can include several types of variables such basic_RVs, Data, Deterministics, and Potentials. """ + if var.name is None: + raise ValueError("Variable is unnamed.") if self.named_vars.tree_contains(var.name): raise ValueError(f"Variable name {var.name} already exists.") @@ -1507,7 +1509,7 @@ def add_named_variable(self, var, dims: Optional[Tuple[Union[str, None], ...]] = for dim in dims: if dim not in self.coords and dim is not None: raise ValueError(f"Dimension {dim} is not specified in `coords`.") - if any(var.name == dim for dim in dims): + if any(var.name == dim for dim in dims if dim is not None): raise ValueError(f"Variable `{var.name}` has the same name as its dimension label.") self.named_vars_to_dims[var.name] = dims diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 95abe65cc84..cb4f8038fe8 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -14,7 +14,7 @@ import warnings from collections import defaultdict -from typing import Dict, Iterable, List, NewType, Optional, Set +from typing import Dict, Iterable, List, NewType, Optional, Sequence, Set from aesara import function from aesara.compile.sharedvalue import SharedVariable @@ -206,18 +206,32 @@ def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> Dict[str, """ plates = defaultdict(set) + # TODO: Evaluate all RV shapes and dim_length at once. + # This should help to find discrepancies, and + # avoids unncessary function compiles for deetermining labels. + for var_name in self.vars_to_plot(var_names): v = self.model[var_name] + shape: Sequence[int] = fast_eval(v.shape) + dim_labels = [] if var_name in self.model.named_vars_to_dims: - plate_label = " x ".join( - f"{d} ({fast_eval(self.model.dim_lengths[d])})" - for d in self.model.named_vars_to_dims[var_name] - ) + # The RV is associated with `dims` information. + for d, dname in enumerate(self.model.named_vars_to_dims[var_name]): + if dname is None: + # Unnamed dimension in a `dims` tuple! + dlen = shape[d] + dname = f"{var_name}_dim{d}" + else: + dlen = fast_eval(self.model.dim_lengths[dname]) + dim_labels.append(f"{dname} ({dlen})") + plate_label = " x ".join(dim_labels) else: - plate_label = " x ".join(map(str, fast_eval(v.shape))) + # The RV has no `dims` information. + dim_labels = map(str, shape) + plate_label = " x ".join(map(str, shape)) plates[plate_label].add(var_name) - return plates + return dict(plates) def make_graph(self, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain"): """Make graphviz Digraph of PyMC model diff --git a/pymc/tests/test_model.py b/pymc/tests/test_model.py index c020bdd90b8..0c8049a6029 100644 --- a/pymc/tests/test_model.py +++ b/pymc/tests/test_model.py @@ -844,6 +844,30 @@ def test_set_dim_with_coords(): assert pmodel.coords["mdim"] == ("A", "B", "C") +def test_add_named_variable_checks_dim_name(): + with pm.Model() as pmodel: + rv = pm.Normal.dist(mu=[1, 2]) + + # Checks that vars are named + with pytest.raises(ValueError, match="is unnamed"): + pmodel.add_named_variable(rv) + rv.name = "nomnom" + + # Coords must be available already + with pytest.raises(ValueError, match="not specified in `coords`"): + pmodel.add_named_variable(rv, dims="nomnom") + pmodel.add_coord("nomnom", [1, 2]) + + # No name collisions + with pytest.raises(ValueError, match="same name as"): + pmodel.add_named_variable(rv, dims="nomnom") + + # This should work (regression test against #6335) + rv2 = rv[:, None] + rv2.name = "yumyum" + pmodel.add_named_variable(rv2, dims=("nomnom", None)) + + def test_set_data_indirect_resize(): with pm.Model() as pmodel: pmodel.add_coord("mdim", mutable=True, length=2) diff --git a/pymc/tests/test_model_graph.py b/pymc/tests/test_model_graph.py index b19c3b7f305..59d77db329a 100644 --- a/pymc/tests/test_model_graph.py +++ b/pymc/tests/test_model_graph.py @@ -14,6 +14,7 @@ import warnings import aesara +import aesara.tensor as at import numpy as np import pytest @@ -340,6 +341,18 @@ class TestImputationModel(BaseModelGraphTest): class TestModelWithDims(BaseModelGraphTest): model_func = model_with_dims + def test_issue_6335_dims_containing_none(self): + with pm.Model(coords=dict(time=np.arange(5))) as pmodel: + data = at.as_tensor(np.ones((3, 5))) + pm.Deterministic("n", data, dims=(None, "time")) + + mg = ModelGraph(pmodel) + plates_actual = mg.get_plates() + plates_expected = { + "n_dim0 (3) x time (5)": {"n"}, + } + assert plates_actual == plates_expected + class TestUnnamedObservedNodes(BaseModelGraphTest): model_func = model_unnamed_observed_node