diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index d1e7ee41d3..3ed800e889 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -1,5 +1,10 @@ # Release Notes +## PyMC3 3.8 (on deck) + +### New features +- Distinguish between `Data` and `Deterministic` variables when graphing models with graphviz. PR [#3491](https://github.com/pymc-defs/pymc3/pulls/3491). + ## PyMC3 3.7 (May 29 2019) ### New features diff --git a/pymc3/data.py b/pymc3/data.py index 0afa5e1a69..cac3b97ba3 100644 --- a/pymc3/data.py +++ b/pymc3/data.py @@ -390,7 +390,7 @@ def align_minibatches(batches=None): class Data: """Data container class that wraps the theano SharedVariable class - and let the model be aware of its inputs and outputs. + and lets the model be aware of its inputs and outputs. Parameters ---------- diff --git a/pymc3/model_graph.py b/pymc3/model_graph.py index 80718fec56..1bcafa4bcb 100644 --- a/pymc3/model_graph.py +++ b/pymc3/model_graph.py @@ -1,17 +1,16 @@ from collections import deque -from typing import Iterator, Optional, MutableSet +from typing import Dict, Iterator, Set, Optional + +VarName = str from theano.gof.graph import stack_search from theano.compile import SharedVariable from theano.tensor import Tensor from .util import get_default_varnames +from .model import ObservedRV import pymc3 as pm -# this is a placeholder for a better characterization of the type -# of variables in a model. -RV = Tensor - class ModelGraph: def __init__(self, model): @@ -30,7 +29,7 @@ def get_deterministics(self, var): deterministics.append(v) return deterministics - def _get_ancestors(self, var, func) -> MutableSet[RV]: + def _get_ancestors(self, var: Tensor, func) -> Set[Tensor]: """Get all ancestors of a function, doing some accounting for deterministics. """ @@ -38,8 +37,8 @@ def _get_ancestors(self, var, func) -> MutableSet[RV]: vars = set(self.var_list) vars.remove(var) - blockers = set() - retval = set() + blockers = set() # type: Set[Tensor] + retval = set() # type: Set[Tensor] def _expand(node) -> Optional[Iterator[Tensor]]: if node in blockers: return None @@ -58,9 +57,9 @@ def _expand(node) -> Optional[Iterator[Tensor]]: mode='bfs') return retval - def _filter_parents(self, var, parents): + def _filter_parents(self, var, parents) -> Set[VarName]: """Get direct parents of a var, as strings""" - keep = set() + keep = set() # type: Set[VarName] for p in parents: if p == var: continue @@ -73,7 +72,7 @@ def _filter_parents(self, var, parents): raise AssertionError('Do not know what to do with {}'.format(str(p))) return keep - def get_parents(self, var): + def get_parents(self, var: Tensor) -> Set[VarName]: """Get the named nodes that are direct inputs to the var""" if hasattr(var, 'transformed'): func = var.transformed.logpt @@ -85,11 +84,26 @@ def get_parents(self, var): parents = self._get_ancestors(var, func) return self._filter_parents(var, parents) - def make_compute_graph(self): + def make_compute_graph(self) -> Dict[str, Set[VarName]]: """Get map of var_name -> set(input var names) for the model""" - input_map = {} + input_map = {} # type: Dict[str, Set[VarName]] + def update_input_map(key: str, val: Set[VarName]): + if key in input_map: + input_map[key] = input_map[key].union(val) + else: + input_map[key] = val + for var_name in self.var_names: - input_map[var_name] = self.get_parents(self.model[var_name]) + var = self.model[var_name] + update_input_map(var_name, self.get_parents(var)) + if isinstance(var, ObservedRV): + try: + obs_name = var.observations.name + if obs_name: + input_map[var_name] = input_map[var_name].difference(set([obs_name])) + update_input_map(obs_name, set([var_name])) + except AttributeError: + pass return input_map def _make_node(self, var_name, graph): @@ -101,12 +115,19 @@ def _make_node(self, var_name, graph): if isinstance(v, pm.model.ObservedRV): attrs['style'] = 'filled' + # make Data be roundtangle, instead of rectangle if isinstance(v, SharedVariable): - attrs['style'] = 'filled' + attrs['style'] = 'rounded, filled' # Get name for node - if hasattr(v, 'distribution'): + if v in self.model.potentials: + distribution = 'Potential' + attrs['shape'] = 'octagon' + elif hasattr(v, 'distribution'): distribution = v.distribution.__class__.__name__ + elif isinstance(v, SharedVariable): + distribution = 'Data' + attrs['shape'] = 'box' else: distribution = 'Deterministic' attrs['shape'] = 'box' diff --git a/pymc3/tests/test_data_container.py b/pymc3/tests/test_data_container.py index 9db76754b4..051e2e3c7b 100644 --- a/pymc3/tests/test_data_container.py +++ b/pymc3/tests/test_data_container.py @@ -101,5 +101,12 @@ def test_model_to_graphviz_for_model_with_data_container(self): pm.sample(1000, init=None, tune=1000, chains=1) g = pm.model_to_graphviz(model) - text = 'x [label="x ~ Deterministic" shape=box style=filled]' + + # Data node rendered correctly? + text = 'x [label="x ~ Data" shape=box style="rounded, filled"]' + assert text in g.source + # Didn't break ordinary variables? + text = 'beta [label="beta ~ Normal"]' + assert text in g.source + text = 'obs [label="obs ~ Normal" style=filled]' assert text in g.source