Skip to content

Commit

Permalink
adding option to display parameters of distributions to graphviz gene…
Browse files Browse the repository at this point in the history
…ration
  • Loading branch information
Spaak committed Aug 10, 2020
1 parent 6504275 commit c6d5cdf
Showing 1 changed file with 26 additions and 9 deletions.
35 changes: 26 additions & 9 deletions pymc3/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .util import get_default_varnames
from .model import ObservedRV
import pymc3 as pm
from pymc3.util import get_variable_name


class ModelGraph:
Expand All @@ -33,6 +34,10 @@ def __init__(self, model):
self.var_list = self.model.named_vars.values()
self.transform_map = {v.transformed: v.name for v in self.var_list if hasattr(v, 'transformed')}
self._deterministics = None
self._distr_params = {
'Normal': ['mu', 'sigma'],
'Uniform': ['lower', 'upper'],
}

def get_deterministics(self, var):
"""Compute the deterministic nodes of the graph, **not** including var itself."""
Expand Down Expand Up @@ -120,7 +125,7 @@ def update_input_map(key: str, val: Set[VarName]):
pass
return input_map

def _make_node(self, var_name, graph):
def _make_node(self, var_name, graph, include_prior_params):
"""Attaches the given variable to a graphviz Digraph"""
v = self.model[var_name]

Expand All @@ -146,9 +151,21 @@ def _make_node(self, var_name, graph):
distribution = 'Deterministic'
attrs['shape'] = 'box'

graph.node(var_name.replace(':', '&'),
'{var_name}\n~\n{distribution}'.format(var_name=var_name, distribution=distribution),
**attrs)
node_text = '{var_name}\n~\n{distribution}'.format(var_name=var_name, distribution=distribution)
if include_prior_params and distribution in self._distr_params:
param_strings = []
for param in self._distr_params[distribution]:
val = get_variable_name(getattr(v.distribution, param))
if type(val) is str and len(val) > 100:
val = '<long expression>'
try:
val = '{val:.3g}'.format(val=float(val))
except ValueError:
pass
param_strings.append('{param}={val}'.format(param=param,
val=val))
node_text += '(' + ', '.join(param_strings) + ')'
graph.node(var_name.replace(':', '&'), node_text, **attrs)

def get_plates(self):
""" Rough but surprisingly accurate plate detection.
Expand Down Expand Up @@ -181,7 +198,7 @@ def get_plates(self):
plates[shape].add(var_name)
return plates

def make_graph(self):
def make_graph(self, include_prior_params=False):
"""Make graphviz Digraph of PyMC3 model
Returns
Expand All @@ -203,20 +220,20 @@ def make_graph(self):
# must be preceded by 'cluster' to get a box around it
with graph.subgraph(name='cluster' + label) as sub:
for var_name in var_names:
self._make_node(var_name, sub)
self._make_node(var_name, sub, include_prior_params)
# plate label goes bottom right
sub.attr(label=label, labeljust='r', labelloc='b', style='rounded')
else:
for var_name in var_names:
self._make_node(var_name, graph)
self._make_node(var_name, graph, include_prior_params)

for key, values in self.make_compute_graph().items():
for value in values:
graph.edge(value.replace(':', '&'), key.replace(':', '&'))
return graph


def model_to_graphviz(model=None):
def model_to_graphviz(model=None, **kwargs):
"""Produce a graphviz Digraph from a PyMC3 model.
Requires graphviz, which may be installed most easily with
Expand All @@ -228,4 +245,4 @@ def model_to_graphviz(model=None):
for more information.
"""
model = pm.modelcontext(model)
return ModelGraph(model).make_graph()
return ModelGraph(model).make_graph(**kwargs)

0 comments on commit c6d5cdf

Please sign in to comment.