Skip to content

Commit

Permalink
Bugfixes to increase robustness against unnamed dims (#6339)
Browse files Browse the repository at this point in the history
* Extract `ModelGraph._eval` to a function
* More robustness against unlabeled `dims` entries

Closes #6335
  • Loading branch information
michaelosthege authored Nov 27, 2022
1 parent e0d25c8 commit 3ff4e7a
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 11 deletions.
4 changes: 3 additions & 1 deletion pymc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1498,6 +1498,8 @@ def add_named_variable(self, var, dims: Optional[Tuple[Union[str, None], ...]] =
This can include several types of variables such as basic_RVs, Data, Deterministics,
and Potentials.
"""
if var.name is None:
raise ValueError("Variable is unnamed.")
if self.named_vars.tree_contains(var.name):
raise ValueError(f"Variable name {var.name} already exists.")

Expand All @@ -1507,7 +1509,7 @@ def add_named_variable(self, var, dims: Optional[Tuple[Union[str, None], ...]] =
for dim in dims:
if dim not in self.coords and dim is not None:
raise ValueError(f"Dimension {dim} is not specified in `coords`.")
if any(var.name == dim for dim in dims):
if any(var.name == dim for dim in dims if dim is not None):
raise ValueError(f"Variable `{var.name}` has the same name as its dimension label.")
self.named_vars_to_dims[var.name] = dims

Expand Down
42 changes: 32 additions & 10 deletions pymc/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import warnings

from collections import defaultdict
from typing import Dict, Iterable, List, NewType, Optional, Set
from typing import Dict, Iterable, List, NewType, Optional, Sequence, Set

from aesara import function
from aesara.compile.sharedvalue import SharedVariable
Expand All @@ -32,6 +32,17 @@
VarName = NewType("VarName", str)


__all__ = (
"ModelGraph",
"model_to_graphviz",
"model_to_networkx",
)


def fast_eval(var):
    """Evaluate ``var`` by compiling it into a function that takes no inputs.

    The function is compiled with ``mode="FAST_COMPILE"`` and called once;
    the value it returns is passed back to the caller.
    """
    compiled = function([], var, mode="FAST_COMPILE")
    return compiled()


class ModelGraph:
def __init__(self, model):
self.model = model
Expand Down Expand Up @@ -183,9 +194,6 @@ def _make_node(self, var_name, graph, *, nx=False, cluster=False, formatting: st
else:
graph.node(var_name.replace(":", "&"), **kwargs)

def _eval(self, var):
return function([], var, mode="FAST_COMPILE")()

def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> Dict[str, Set[VarName]]:
"""Rough but surprisingly accurate plate detection.
Expand All @@ -198,18 +206,32 @@ def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> Dict[str,
"""
plates = defaultdict(set)

# TODO: Evaluate all RV shapes and dim_length at once.
# This should help to find discrepancies, and
# avoids unnecessary function compiles for determining labels.

for var_name in self.vars_to_plot(var_names):
v = self.model[var_name]
shape: Sequence[int] = fast_eval(v.shape)
dim_labels = []
if var_name in self.model.named_vars_to_dims:
plate_label = " x ".join(
f"{d} ({self._eval(self.model.dim_lengths[d])})"
for d in self.model.named_vars_to_dims[var_name]
)
# The RV is associated with `dims` information.
for d, dname in enumerate(self.model.named_vars_to_dims[var_name]):
if dname is None:
# Unnamed dimension in a `dims` tuple!
dlen = shape[d]
dname = f"{var_name}_dim{d}"
else:
dlen = fast_eval(self.model.dim_lengths[dname])
dim_labels.append(f"{dname} ({dlen})")
plate_label = " x ".join(dim_labels)
else:
plate_label = " x ".join(map(str, self._eval(v.shape)))
# The RV has no `dims` information.
dim_labels = map(str, shape)
plate_label = " x ".join(map(str, shape))
plates[plate_label].add(var_name)

return plates
return dict(plates)

def make_graph(self, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain"):
"""Make graphviz Digraph of PyMC model
Expand Down
24 changes: 24 additions & 0 deletions pymc/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,30 @@ def test_set_dim_with_coords():
assert pmodel.coords["mdim"] == ("A", "B", "C")


def test_add_named_variable_checks_dim_name():
    """Exercise the name and ``dims`` checks performed by ``add_named_variable``."""
    with pm.Model() as pmodel:
        base_rv = pm.Normal.dist(mu=[1, 2])

        # A variable that has no name yet must be rejected.
        with pytest.raises(ValueError, match="is unnamed"):
            pmodel.add_named_variable(base_rv)
        base_rv.name = "nomnom"

        # A dim label is refused until a coord of that name has been added.
        with pytest.raises(ValueError, match="not specified in `coords`"):
            pmodel.add_named_variable(base_rv, dims="nomnom")
        pmodel.add_coord("nomnom", [1, 2])

        # A variable may not share its name with one of its dim labels.
        with pytest.raises(ValueError, match="same name as"):
            pmodel.add_named_variable(base_rv, dims="nomnom")

        # A `None` entry in `dims` must be accepted (regression test against #6335).
        expanded_rv = base_rv[:, None]
        expanded_rv.name = "yumyum"
        pmodel.add_named_variable(expanded_rv, dims=("nomnom", None))


def test_set_data_indirect_resize():
with pm.Model() as pmodel:
pmodel.add_coord("mdim", mutable=True, length=2)
Expand Down
13 changes: 13 additions & 0 deletions pymc/tests/test_model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import warnings

import aesara
import aesara.tensor as at
import numpy as np
import pytest

Expand Down Expand Up @@ -340,6 +341,18 @@ class TestImputationModel(BaseModelGraphTest):
class TestModelWithDims(BaseModelGraphTest):
    """Model-graph tests that run against ``model_with_dims``."""

    model_func = model_with_dims

    def test_issue_6335_dims_containing_none(self):
        """Check the plate label produced when ``dims`` contains ``None`` (issue 6335)."""
        coords = dict(time=np.arange(5))
        with pm.Model(coords=coords) as pmodel:
            values = at.as_tensor(np.ones((3, 5)))
            pm.Deterministic("n", values, dims=(None, "time"))

        # The unlabeled first axis is expected to show up as "<var>_dim<index>".
        plates_actual = ModelGraph(pmodel).get_plates()
        assert plates_actual == {"n_dim0 (3) x time (5)": {"n"}}


class TestUnnamedObservedNodes(BaseModelGraphTest):
model_func = model_unnamed_observed_node
Expand Down

0 comments on commit 3ff4e7a

Please sign in to comment.