Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfixes to increase robustness against unnamed dims #6339

Merged
merged 2 commits into from
Nov 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pymc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1498,6 +1498,8 @@ def add_named_variable(self, var, dims: Optional[Tuple[Union[str, None], ...]] =
This can include several types of variables such as basic_RVs, Data, Deterministics,
and Potentials.
"""
if var.name is None:
raise ValueError("Variable is unnamed.")
if self.named_vars.tree_contains(var.name):
raise ValueError(f"Variable name {var.name} already exists.")

Expand All @@ -1507,7 +1509,7 @@ def add_named_variable(self, var, dims: Optional[Tuple[Union[str, None], ...]] =
for dim in dims:
if dim not in self.coords and dim is not None:
raise ValueError(f"Dimension {dim} is not specified in `coords`.")
if any(var.name == dim for dim in dims):
if any(var.name == dim for dim in dims if dim is not None):
raise ValueError(f"Variable `{var.name}` has the same name as its dimension label.")
self.named_vars_to_dims[var.name] = dims

Expand Down
42 changes: 32 additions & 10 deletions pymc/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import warnings

from collections import defaultdict
from typing import Dict, Iterable, List, NewType, Optional, Set
from typing import Dict, Iterable, List, NewType, Optional, Sequence, Set

from aesara import function
from aesara.compile.sharedvalue import SharedVariable
Expand All @@ -32,6 +32,17 @@
VarName = NewType("VarName", str)


# Names exported by ``from <this module> import *``.
__all__ = (
    "ModelGraph",
    "model_to_graphviz",
    "model_to_networkx",
)


def fast_eval(var):
    """Compile ``var`` into a zero-argument function and return its value.

    The function is built with ``mode="FAST_COMPILE"`` and called right away,
    so the caller receives the evaluated result rather than the compiled
    function object.
    """
    compiled = function([], var, mode="FAST_COMPILE")
    return compiled()


class ModelGraph:
def __init__(self, model):
self.model = model
Expand Down Expand Up @@ -183,9 +194,6 @@ def _make_node(self, var_name, graph, *, nx=False, cluster=False, formatting: st
else:
graph.node(var_name.replace(":", "&"), **kwargs)

def _eval(self, var):
return function([], var, mode="FAST_COMPILE")()

def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> Dict[str, Set[VarName]]:
"""Rough but surprisingly accurate plate detection.

Expand All @@ -198,18 +206,32 @@ def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> Dict[str,
"""
plates = defaultdict(set)

# TODO: Evaluate all RV shapes and dim_length at once.
# This should help to find discrepancies, and
# avoids unnecessary function compiles for determining labels.

Comment on lines +209 to +212
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not needed for the bugfix, but I commented it here for the next person touching this code.

for var_name in self.vars_to_plot(var_names):
v = self.model[var_name]
shape: Sequence[int] = fast_eval(v.shape)
dim_labels = []
if var_name in self.model.named_vars_to_dims:
plate_label = " x ".join(
f"{d} ({self._eval(self.model.dim_lengths[d])})"
for d in self.model.named_vars_to_dims[var_name]
)
# The RV is associated with `dims` information.
for d, dname in enumerate(self.model.named_vars_to_dims[var_name]):
if dname is None:
# Unnamed dimension in a `dims` tuple!
dlen = shape[d]
dname = f"{var_name}_dim{d}"
else:
dlen = fast_eval(self.model.dim_lengths[dname])
dim_labels.append(f"{dname} ({dlen})")
plate_label = " x ".join(dim_labels)
else:
plate_label = " x ".join(map(str, self._eval(v.shape)))
# The RV has no `dims` information.
dim_labels = map(str, shape)
plate_label = " x ".join(map(str, shape))
plates[plate_label].add(var_name)

return plates
return dict(plates)

def make_graph(self, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain"):
"""Make graphviz Digraph of PyMC model
Expand Down
24 changes: 24 additions & 0 deletions pymc/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,30 @@ def test_set_dim_with_coords():
assert pmodel.coords["mdim"] == ("A", "B", "C")


def test_add_named_variable_checks_dim_name():
    """Walk through the validation errors raised by ``add_named_variable``."""
    with pm.Model() as pmodel:
        base_rv = pm.Normal.dist(mu=[1, 2])

        # A variable without a name is rejected.
        with pytest.raises(ValueError, match="is unnamed"):
            pmodel.add_named_variable(base_rv)
        base_rv.name = "nomnom"

        # A dimension that is not yet registered in the coords is rejected.
        with pytest.raises(ValueError, match="not specified in `coords`"):
            pmodel.add_named_variable(base_rv, dims="nomnom")
        pmodel.add_coord("nomnom", [1, 2])

        # A variable may not share its name with one of its dimensions.
        with pytest.raises(ValueError, match="same name as"):
            pmodel.add_named_variable(base_rv, dims="nomnom")

        # Regression test against #6335: a `None` entry in `dims` must be accepted.
        expanded_rv = base_rv[:, None]
        expanded_rv.name = "yumyum"
        pmodel.add_named_variable(expanded_rv, dims=("nomnom", None))


def test_set_data_indirect_resize():
with pm.Model() as pmodel:
pmodel.add_coord("mdim", mutable=True, length=2)
Expand Down
13 changes: 13 additions & 0 deletions pymc/tests/test_model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import warnings

import aesara
import aesara.tensor as at
import numpy as np
import pytest

Expand Down Expand Up @@ -340,6 +341,18 @@ class TestImputationModel(BaseModelGraphTest):
class TestModelWithDims(BaseModelGraphTest):
    """Model-graph tests built on ``model_with_dims``."""

    model_func = model_with_dims

    def test_issue_6335_dims_containing_none(self):
        """Check the plate label produced when ``dims`` contains ``None`` (#6335)."""
        with pm.Model(coords={"time": np.arange(5)}) as pmodel:
            ones = at.as_tensor(np.ones((3, 5)))
            pm.Deterministic("n", ones, dims=(None, "time"))

        # The unnamed first dimension is labelled "<var>_dim<index>" with its length.
        expected_plates = {"n_dim0 (3) x time (5)": {"n"}}
        actual_plates = ModelGraph(pmodel).get_plates()
        assert actual_plates == expected_plates


class TestUnnamedObservedNodes(BaseModelGraphTest):
model_func = model_unnamed_observed_node
Expand Down