Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mark pm.Data nodes in graphviz graphs. #3491

Merged
merged 1 commit into from
May 31, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Release Notes

## PyMC3 3.8 (on deck)

### New features
- Distinguish between `Data` and `Deterministic` variables when graphing models with graphviz. PR [#3491](https://github.com/pymc-defs/pymc3/pulls/3491).

## PyMC3 3.7 (May 29 2019)

### New features
Expand Down
2 changes: 1 addition & 1 deletion pymc3/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def align_minibatches(batches=None):

class Data:
"""Data container class that wraps the theano SharedVariable class
and let the model be aware of its inputs and outputs.
and lets the model be aware of its inputs and outputs.
Parameters
----------
Expand Down
53 changes: 37 additions & 16 deletions pymc3/model_graph.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from collections import deque
from typing import Iterator, Optional, MutableSet
from typing import Dict, Iterator, Set, Optional

VarName = str
rpgoldman marked this conversation as resolved.
Show resolved Hide resolved

from theano.gof.graph import stack_search
from theano.compile import SharedVariable
from theano.tensor import Tensor

from .util import get_default_varnames
from .model import ObservedRV
import pymc3 as pm

# this is a placeholder for a better characterization of the type
# of variables in a model.
RV = Tensor


class ModelGraph:
def __init__(self, model):
Expand All @@ -30,16 +29,16 @@ def get_deterministics(self, var):
deterministics.append(v)
return deterministics

def _get_ancestors(self, var, func) -> MutableSet[RV]:
def _get_ancestors(self, var: Tensor, func) -> Set[Tensor]:
"""Get all ancestors of a function, doing some accounting for deterministics.
"""

# this contains all of the variables in the model EXCEPT var...
vars = set(self.var_list)
vars.remove(var)

blockers = set()
retval = set()
blockers = set() # type: Set[Tensor]
retval = set() # type: Set[Tensor]
def _expand(node) -> Optional[Iterator[Tensor]]:
if node in blockers:
return None
Expand All @@ -58,9 +57,9 @@ def _expand(node) -> Optional[Iterator[Tensor]]:
mode='bfs')
return retval

def _filter_parents(self, var, parents):
def _filter_parents(self, var, parents) -> Set[VarName]:
"""Get direct parents of a var, as strings"""
keep = set()
keep = set() # type: Set[VarName]
for p in parents:
if p == var:
continue
Expand All @@ -73,7 +72,7 @@ def _filter_parents(self, var, parents):
raise AssertionError('Do not know what to do with {}'.format(str(p)))
return keep

def get_parents(self, var):
def get_parents(self, var: Tensor) -> Set[VarName]:
"""Get the named nodes that are direct inputs to the var"""
if hasattr(var, 'transformed'):
func = var.transformed.logpt
Expand All @@ -85,11 +84,26 @@ def get_parents(self, var):
parents = self._get_ancestors(var, func)
return self._filter_parents(var, parents)

def make_compute_graph(self):
def make_compute_graph(self) -> Dict[str, Set[VarName]]:
"""Get map of var_name -> set(input var names) for the model"""
input_map = {}
input_map = {} # type: Dict[str, Set[VarName]]
def update_input_map(key: str, val: Set[VarName]):
if key in input_map:
input_map[key] = input_map[key].union(val)
else:
input_map[key] = val

for var_name in self.var_names:
input_map[var_name] = self.get_parents(self.model[var_name])
var = self.model[var_name]
update_input_map(var_name, self.get_parents(var))
if isinstance(var, ObservedRV):
try:
obs_name = var.observations.name
if obs_name:
input_map[var_name] = input_map[var_name].difference(set([obs_name]))
update_input_map(obs_name, set([var_name]))
except AttributeError:
pass
return input_map

def _make_node(self, var_name, graph):
Expand All @@ -101,12 +115,19 @@ def _make_node(self, var_name, graph):
if isinstance(v, pm.model.ObservedRV):
attrs['style'] = 'filled'

# make Data be roundtangle, instead of rectangle
if isinstance(v, SharedVariable):
attrs['style'] = 'filled'
attrs['style'] = 'rounded, filled'

# Get name for node
if hasattr(v, 'distribution'):
if v in self.model.potentials:
distribution = 'Potential'
attrs['shape'] = 'octagon'
elif hasattr(v, 'distribution'):
distribution = v.distribution.__class__.__name__
elif isinstance(v, SharedVariable):
distribution = 'Data'
attrs['shape'] = 'box'
else:
distribution = 'Deterministic'
attrs['shape'] = 'box'
Expand Down
9 changes: 8 additions & 1 deletion pymc3/tests/test_data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,12 @@ def test_model_to_graphviz_for_model_with_data_container(self):
pm.sample(1000, init=None, tune=1000, chains=1)

g = pm.model_to_graphviz(model)
text = 'x [label="x ~ Deterministic" shape=box style=filled]'

# Data node rendered correctly?
text = 'x [label="x ~ Data" shape=box style="rounded, filled"]'
assert text in g.source
# Didn't break ordinary variables?
text = 'beta [label="beta ~ Normal"]'
assert text in g.source
text = 'obs [label="obs ~ Normal" style=filled]'
assert text in g.source