Skip to content

Commit

Permalink
Do not infer graph_model node types based on variable Op class
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Nov 3, 2022
1 parent fd5a6cc commit 48664c0
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
6 changes: 2 additions & 4 deletions pymc/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,11 @@ def _make_node(self, var_name, graph, *, nx=False, cluster=False, formatting: st
shape = "box"
style = "rounded, filled"
label = f"{var_name}\n~\nMutableData"
elif v.owner and isinstance(v.owner.op, RandomVariable):
elif v in self.model.basic_RVs:
shape = "ellipse"
if hasattr(v.tag, "observations"):
# observed RV
if v in self.model.observed_RVs:
style = "filled"
else:
shape = "ellipse"
style = None
symbol = v.owner.op.__class__.__name__
if symbol.endswith("RV"):
Expand Down
34 changes: 34 additions & 0 deletions pymc/tests/test_model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,36 @@ def model_observation_dtype_casting():
return model, compute_graph, plates


def model_non_random_variable_rvs():
"""Test that node types are not inferred based on the variable Op type, but
model properties
See https://github.com/pymc-devs/pymc/issues/5766
"""
with pm.Model() as model:
mu = pm.Normal(name="mu", mu=0.0, sigma=5.0)

y_raw = pm.Normal.dist(mu)
y = pm.math.clip(y_raw, -3, 3)
model.register_rv(y, name="y")

z_raw = pm.Normal.dist(y, shape=(5,))
z = pm.math.clip(z_raw, -1, 1)
model.register_rv(z, name="z", data=[0] * 5)

compute_graph = {
"mu": set(),
"y": {"mu"},
"z": {"y"},
}
plates = {
"": {"mu", "y"},
"5": {"z"},
}

return model, compute_graph, plates


class BaseModelGraphTest(SeededTest):
model_func = None

Expand Down Expand Up @@ -360,3 +390,7 @@ def test_subgraph(self, var_names, vars_to_plot, compute_graph):
mg = ModelGraph(model_with_different_descendants())
assert set(mg.vars_to_plot(var_names=var_names)) == set(vars_to_plot)
assert mg.make_compute_graph(var_names=var_names) == compute_graph


class TestModelNonRandomVariableRVs(BaseModelGraphTest):
model_func = model_non_random_variable_rvs

0 comments on commit 48664c0

Please sign in to comment.