Skip to content

Commit

Permalink
Modify model_graph to handle new pm.Data
Browse files Browse the repository at this point in the history
Mark Data nodes in graph: previously these were incorrectly marked as
`Deterministic`.

Put observations below observed nodes: Now that there are `Data`
objects that are graphable, they need to be placed correctly with
respect to other nodes: if they are inputs they should be above, and
if they are observations, they should be below.
  • Loading branch information
rpgoldman committed May 28, 2019
1 parent 8317281 commit f31d7f4
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 17 deletions.
2 changes: 1 addition & 1 deletion pymc3/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def align_minibatches(batches=None):

class Data:
"""Data container class that wraps the theano SharedVariable class
and let the model be aware of its inputs and outputs.
and lets the model be aware of its inputs and outputs.
Parameters
----------
Expand Down
48 changes: 33 additions & 15 deletions pymc3/model_graph.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from collections import deque
from typing import Iterator, Optional, MutableSet
from typing import Dict, Iterator, Set, Optional

VarName = str

from theano.gof.graph import stack_search
from theano.compile import SharedVariable
from theano.tensor import Tensor

from .util import get_default_varnames
from .model import ObservedRV
import pymc3 as pm

# this is a placeholder for a better characterization of the type
# of variables in a model.
RV = Tensor


class ModelGraph:
def __init__(self, model):
Expand All @@ -30,16 +29,16 @@ def get_deterministics(self, var):
deterministics.append(v)
return deterministics

def _get_ancestors(self, var, func) -> MutableSet[RV]:
def _get_ancestors(self, var: Tensor, func) -> Set[Tensor]:
"""Get all ancestors of a function, doing some accounting for deterministics.
"""

# this contains all of the variables in the model EXCEPT var...
vars = set(self.var_list)
vars.remove(var)

blockers = set()
retval = set()
blockers = set() # type: Set[Tensor]
retval = set() # type: Set[Tensor]
def _expand(node) -> Optional[Iterator[Tensor]]:
if node in blockers:
return None
Expand All @@ -58,9 +57,9 @@ def _expand(node) -> Optional[Iterator[Tensor]]:
mode='bfs')
return retval

def _filter_parents(self, var, parents):
def _filter_parents(self, var, parents) -> Set[VarName]:
"""Get direct parents of a var, as strings"""
keep = set()
keep = set() # type: Set[VarName]
for p in parents:
if p == var:
continue
Expand All @@ -73,7 +72,7 @@ def _filter_parents(self, var, parents):
raise AssertionError('Do not know what to do with {}'.format(str(p)))
return keep

def get_parents(self, var):
def get_parents(self, var: Tensor) -> Set[VarName]:
"""Get the named nodes that are direct inputs to the var"""
if hasattr(var, 'transformed'):
func = var.transformed.logpt
Expand All @@ -85,11 +84,26 @@ def get_parents(self, var):
parents = self._get_ancestors(var, func)
return self._filter_parents(var, parents)

def make_compute_graph(self):
def make_compute_graph(self) -> Dict[str, Set[VarName]]:
"""Get map of var_name -> set(input var names) for the model"""
input_map = {}
input_map = {} # type: Dict[str, Set[VarName]]
def update_input_map(key: str, val: Set[VarName]):
if key in input_map:
input_map[key] = input_map[key].union(val)
else:
input_map[key] = val

for var_name in self.var_names:
input_map[var_name] = self.get_parents(self.model[var_name])
var = self.model[var_name]
update_input_map(var_name, self.get_parents(var))
if isinstance(var, ObservedRV):
try:
obs_name = var.observations.name
if obs_name:
input_map[var_name] = input_map[var_name].difference(set([obs_name]))
update_input_map(obs_name, set([var_name]))
except AttributeError:
pass
return input_map

def _make_node(self, var_name, graph):
Expand All @@ -101,12 +115,16 @@ def _make_node(self, var_name, graph):
if isinstance(v, pm.model.ObservedRV):
attrs['style'] = 'filled'

# make Data be roundtangle, instead of rectangle
if isinstance(v, SharedVariable):
attrs['style'] = 'filled'
attrs['style'] = 'rounded, filled'

# Get name for node
if hasattr(v, 'distribution'):
distribution = v.distribution.__class__.__name__
elif isinstance(v, SharedVariable):
distribution = 'Data'
attrs['shape'] = 'box'
else:
distribution = 'Deterministic'
attrs['shape'] = 'box'
Expand Down
9 changes: 8 additions & 1 deletion pymc3/tests/test_data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,12 @@ def test_model_to_graphviz_for_model_with_data_container(self):
pm.sample(1000, init=None, tune=1000, chains=1)

g = pm.model_to_graphviz(model)
text = 'x [label="x ~ Deterministic" shape=box style=filled]'

# Data node rendered correctly?
text = 'x [label="x ~ Data" shape=box style="rounded, filled"]'
assert text in g.source
# Didn't break ordinary variables?
text = 'beta [label="beta ~ Normal"]'
assert text in g.source
text = 'obs [label="obs ~ Normal" style=filled]'
assert text in g.source

0 comments on commit f31d7f4

Please sign in to comment.