Skip to content

Commit

Permalink
Mark Data nodes in graph.
Browse files Browse the repository at this point in the history
  • Loading branch information
rpgoldman committed May 23, 2019
1 parent 29c72ec commit d03f5d1
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions pymc3/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@
from .util import get_default_varnames
import pymc3 as pm

# this is a placeholder for a better characterization of the type
# of variables in a model.
RV = Tensor


class ModelGraph:
def __init__(self, model):
Expand All @@ -30,7 +26,7 @@ def get_deterministics(self, var):
deterministics.append(v)
return deterministics

def _get_ancestors(self, var, func) -> MutableSet[RV]:
def _get_ancestors(self, var, func) -> MutableSet[Tensor]:
"""Get all ancestors of a function, doing some accounting for deterministics.
"""

Expand Down Expand Up @@ -108,7 +104,15 @@ def _make_node(self, var_name, graph):
if hasattr(v, 'distribution'):
distribution = v.distribution.__class__.__name__
else:
distribution = 'Deterministic'
is_data = False # type: bool
try:
is_data = v.is_data
except AttributeError:
pass
if is_data:
distribution = 'Data'
else:
distribution = 'Deterministic'
attrs['shape'] = 'box'

graph.node(var_name.replace(':', '&'),
Expand Down

0 comments on commit d03f5d1

Please sign in to comment.