From 05a05809ff371f25da3fd485f0c1a839f13e7b54 Mon Sep 17 00:00:00 2001 From: "Robert P. Goldman" Date: Thu, 23 May 2019 15:11:06 -0500 Subject: [PATCH] Unsuccessful attempt to graph observations. I don't have a clear enough notion of what can be in the `observations` property of a random variable. --- pymc3/model_graph.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/pymc3/model_graph.py b/pymc3/model_graph.py index 846bda391ef..c6c01cc0aa8 100644 --- a/pymc3/model_graph.py +++ b/pymc3/model_graph.py @@ -1,5 +1,5 @@ from collections import deque -from typing import Iterator, Optional, MutableSet +from typing import Iterator, Optional, MutableSet, FrozenSet, Tuple from theano.gof.graph import stack_search from theano.compile import SharedVariable @@ -34,7 +34,7 @@ def _get_ancestors(self, var, func) -> MutableSet[Tensor]: vars = set(self.var_list) vars.remove(var) - blockers = set() + blockers = set() # type: MutableSet[Tensor] retval = set() def _expand(node) -> Optional[Iterator[Tensor]]: if node in blockers: @@ -54,9 +54,9 @@ def _expand(node) -> Optional[Iterator[Tensor]]: mode='bfs') return retval - def _filter_parents(self, var, parents): + def _filter_parents(self, var, parents) -> Tuple[FrozenSet[str], FrozenSet[str]]: """Get direct parents of a var, as strings""" - keep = set() + keep = set() # type: MutableSet[str] for p in parents: if p == var: continue @@ -67,9 +67,14 @@ def _filter_parents(self, var, parents): keep.add(self.transform_map[p]) else: raise AssertionError('Do not know what to do with {}'.format(str(p))) - return keep + children = frozenset() # type: FrozenSet[str] + try: + children = frozenset(var.observations.name) + except AttributeError: + pass + return frozenset(keep - children), children - def get_parents(self, var): + def get_parents(self, var) -> Tuple[FrozenSet[str], FrozenSet[str]]: """Get the named nodes that are direct inputs to the var""" if hasattr(var, 'transformed'): func = var.transformed.logpt @@ -83,9 +88,18 @@ def get_parents(self, var): def make_compute_graph(self): """Get map of var_name -> set(input var names) for the model""" - input_map = {} + input_map = {} # type: Dict[str, FrozenSet[str]] for var_name in self.var_names: - input_map[var_name] = self.get_parents(self.model[var_name]) + # the "children" here are exclusively observations + parents, children = self.get_parents(self.model[var_name]) + input_map[var_name] = parents + for child in children: + if child in input_map: + # slightly awkward union because input map values are not + # mutable + input_map[child] = frozenset(input_map[child] + set(var_name)) + else: + input_map[child] = frozenset(var_name) return input_map def _make_node(self, var_name, graph):