Skip to content

Commit

Permalink
Refactored get_named_nodes_and_relations to make it consistent with t…
Browse files Browse the repository at this point in the history
…heano naming
  • Loading branch information
lucianopaz committed Oct 14, 2019
1 parent e811314 commit f28a692
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 55 deletions.
34 changes: 17 additions & 17 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,30 +565,30 @@ def draw_values(params, point=None, size=None):
# specified in the point. Need to find the node-inputs, their
# parents and children to replace them.
leaf_nodes = {}
named_nodes_parents = {}
named_nodes_children = {}
named_nodes_descendents = {}
named_nodes_ancestors = {}
for _, param in symbolic_params:
if hasattr(param, 'name'):
# Get the named nodes under the `param` node
nn, nnp, nnc = get_named_nodes_and_relations(param)
nn, nnd, nna = get_named_nodes_and_relations(param)
leaf_nodes.update(nn)
# Update the discovered parental relationships
for k in nnp.keys():
if k not in named_nodes_parents.keys():
named_nodes_parents[k] = nnp[k]
for k in nnd.keys():
if k not in named_nodes_descendents.keys():
named_nodes_descendents[k] = nnd[k]
else:
named_nodes_parents[k].update(nnp[k])
named_nodes_descendents[k].update(nnd[k])
# Update the discovered child relationships
for k in nnc.keys():
if k not in named_nodes_children.keys():
named_nodes_children[k] = nnc[k]
for k in nna.keys():
if k not in named_nodes_ancestors.keys():
named_nodes_ancestors[k] = nna[k]
else:
named_nodes_children[k].update(nnc[k])
stack = [k for k, v in named_nodes_children.items() if len(v) == 0]
named_nodes_ancestors[k].update(nna[k])

# Init givens and the stack of nodes to try to `_draw_value` from
givens = {p.name: (p, v) for (p, size), v in drawn.items()
if getattr(p, 'name', None) is not None}
stack = list(leaf_nodes.values())
while stack:
next_ = stack.pop(0)
if (next_, size) in drawn:
Expand All @@ -609,7 +609,7 @@ def draw_values(params, point=None, size=None):
# of TensorConstants or SharedVariables, we must add them
# to the stack or risk evaluating deterministics with the
# wrong values (issue #3354)
stack.extend([node for node in named_nodes_parents[next_]
stack.extend([node for node in named_nodes_descendents[next_]
if isinstance(node, (ObservedRV,
MultiObservedRV))
and (node, size) not in drawn])
Expand All @@ -618,7 +618,7 @@ def draw_values(params, point=None, size=None):
# If the node does not have a givens value, try to draw it.
# The named node's children givens values must also be taken
# into account.
children = named_nodes_children[next_]
children = named_nodes_ancestors[next_]
temp_givens = [givens[k] for k in givens if k in children]
try:
# This may fail for autotransformed RVs, which don't
Expand All @@ -633,7 +633,7 @@ def draw_values(params, point=None, size=None):
# The node failed, so we must add the node's parents to
# the stack of nodes to try to draw from. We exclude the
# nodes in the `params` list.
stack.extend([node for node in named_nodes_parents[next_]
stack.extend([node for node in named_nodes_descendents[next_]
if node is not None and
(node, size) not in drawn])

Expand All @@ -657,8 +657,8 @@ def draw_values(params, point=None, size=None):
# This may set values for certain nodes in the drawn
# dictionary, but they don't get added to the givens
# dictionary. Here, we try to fix that.
if param in named_nodes_children:
for node in named_nodes_children[param]:
if param in named_nodes_ancestors:
for node in named_nodes_ancestors[param]:
if (
node.name not in givens and
(node, size) in drawn
Expand Down
83 changes: 45 additions & 38 deletions pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def incorporate_methods(source, destination, methods, default=None,
else:
setattr(destination, method, None)


def get_named_nodes_and_relations(graph):
"""Get the named nodes in a theano graph (i.e., nodes whose name
attribute is not None) along with their relationships (i.e., the
Expand All @@ -102,64 +103,70 @@ def get_named_nodes_and_relations(graph):
graph - a theano node
Returns:
leaf_nodes: A dictionary of name:node pairs, of the named nodes that
are also leafs of the graph
node_parents: A dictionary of node:set([parents]) pairs. Each key is
leafs: A dictionary of name:node pairs, of the named nodes that
have no named ancestors in the provided theano graph.
descendents: A dictionary of node:set([parents]) pairs. Each key is
a theano named node, and the corresponding value is the set of
theano named nodes that are parents of the node. These parental
relations skip unnamed intermediate nodes.
node_children: A dictionary of node:set([children]) pairs. Each key
theano named nodes that are direct descendents of the node in the
supplied ``graph``. These relations skip unnamed intermediate nodes.
ancestors: A dictionary of node:set([ancestors]) pairs. Each key
is a theano named node, and the corresponding value is the set
of theano named nodes that are children of the node. These child
relations skip unnamed intermediate nodes.
of theano named nodes that are direct ancestors in the of the node in
the supplied ``graph``. These relations skip unnamed intermediate
nodes.
"""
if graph.name is not None:
node_parents = {graph: set()}
node_children = {graph: set()}
ancestors = {graph: set()}
descendents = {graph: set()}
else:
node_parents = {}
node_children = {}
return _get_named_nodes_and_relations(graph, None, {}, node_parents, node_children)

def _get_named_nodes_and_relations(graph, parent, leaf_nodes,
node_parents, node_children):
ancestors = {}
descendents = {}
descendents, ancestors = _get_named_nodes_and_relations(
graph, None, ancestors, descendents
)
leafs = {
node.name: node for node, ancestor in ancestors.items()
if len(ancestor) == 0
}
return leafs, descendents, ancestors


def _get_named_nodes_and_relations(graph, descendent, descendents, ancestors):
if getattr(graph, 'owner', None) is None: # Leaf node
if graph.name is not None: # Named leaf node
leaf_nodes.update({graph.name: graph})
if parent is not None: # Is None for the root node
if descendent is not None: # Is None for the first node
try:
node_parents[graph].add(parent)
descendents[graph].add(descendent)
except KeyError:
node_parents[graph] = {parent}
node_children[parent].add(graph)
descendents[graph] = {descendent}
ancestors[descendent].add(graph)
else:
node_parents[graph] = set()
descendents[graph] = set()
# Flag that the leaf node has no children
node_children[graph] = set()
ancestors[graph] = set()
else: # Intermediate node
if graph.name is not None: # Intermediate named node
if parent is not None: # Is only None for the root node
if descendent is not None: # Is only None for the root node
try:
node_parents[graph].add(parent)
descendents[graph].add(descendent)
except KeyError:
node_parents[graph] = {parent}
node_children[parent].add(graph)
descendents[graph] = {descendent}
ancestors[descendent].add(graph)
else:
node_parents[graph] = set()
# The current node will be set as the parent of the next
descendents[graph] = set()
# The current node will be set as the descendent of the next
# nodes only if it is a named node
parent = graph
descendent = graph
# Init the nodes children to an empty set
node_children[graph] = set()
ancestors[graph] = set()
for i in graph.owner.inputs:
temp_nodes, temp_inter, temp_tree = \
_get_named_nodes_and_relations(i, parent, leaf_nodes,
node_parents, node_children)
leaf_nodes.update(temp_nodes)
node_parents.update(temp_inter)
node_children.update(temp_tree)
return leaf_nodes, node_parents, node_children
temp_desc, temp_ances = _get_named_nodes_and_relations(
i, descendent, descendents, ancestors
)
descendents.update(temp_desc)
ancestors.update(temp_ances)
return descendents, ancestors


class Context:
Expand Down

0 comments on commit f28a692

Please sign in to comment.