Skip to content

Commit

Permalink
Unsuccessful attempt to graph observations.
Browse files Browse the repository at this point in the history
I don't have a clear enough notion of what can be in the `observations` property of a random variable.
  • Loading branch information
rpgoldman committed May 23, 2019
1 parent d03f5d1 commit 05a0580
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions pymc3/model_graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import deque
from typing import Iterator, Optional, MutableSet
from typing import Iterator, Optional, MutableSet, FrozenSet, Tuple

from theano.gof.graph import stack_search
from theano.compile import SharedVariable
Expand Down Expand Up @@ -34,7 +34,7 @@ def _get_ancestors(self, var, func) -> MutableSet[Tensor]:
vars = set(self.var_list)
vars.remove(var)

blockers = set()
blockers = set() # type: MutableSet[Tensor]
retval = set()
def _expand(node) -> Optional[Iterator[Tensor]]:
if node in blockers:
Expand All @@ -54,9 +54,9 @@ def _expand(node) -> Optional[Iterator[Tensor]]:
mode='bfs')
return retval

def _filter_parents(self, var, parents):
def _filter_parents(self, var, parents) -> Tuple[FrozenSet[str], FrozenSet[str]]:
"""Get direct parents of a var, as strings"""
keep = set()
keep = set() # type: MutableSet[str]
for p in parents:
if p == var:
continue
Expand All @@ -67,9 +67,14 @@ def _filter_parents(self, var, parents):
keep.add(self.transform_map[p])
else:
raise AssertionError('Do not know what to do with {}'.format(str(p)))
return keep
children = frozenset() # type: FrozenSet[str]
try:
children = frozenset(var.observations.name)
except AttributeError:
pass
return frozenset(keep - children), children

def get_parents(self, var):
def get_parents(self, var) -> Tuple[FrozenSet[str], FrozenSet[str]]:
"""Get the named nodes that are direct inputs to the var"""
if hasattr(var, 'transformed'):
func = var.transformed.logpt
Expand All @@ -83,9 +88,18 @@ def get_parents(self, var):

def make_compute_graph(self):
"""Get map of var_name -> set(input var names) for the model"""
input_map = {}
input_map = {} # type: Dict[str, FrozenSet[str]]
for var_name in self.var_names:
input_map[var_name] = self.get_parents(self.model[var_name])
# the "children" here are exclusively observations
parents, children = self.get_parents(self.model[var_name])
input_map[var_name] = parents
for child in children:
if child in input_map:
# slightly awkward union because input map values are not
# mutable
input_map[child] = frozenset(input_map[child] + set(var_name))
else:
input_map[child] = frozenset(var_name)
return input_map

def _make_node(self, var_name, graph):
Expand Down

0 comments on commit 05a0580

Please sign in to comment.