Skip to content

Commit

Permalink
mg ORIGIN/TERMINAL determ
Browse files Browse the repository at this point in the history
  • Loading branch information
kmantel committed Nov 23, 2023
1 parent 5581005 commit c315f29
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 36 deletions.
66 changes: 31 additions & 35 deletions psyneulink/core/compositions/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -3348,7 +3348,10 @@ class NodeRole(enum.Enum):
----------

ORIGIN
A `Node <Composition_Nodes>` that does not receive any `Projections <Projection>` from any other Nodes
A `Node <Composition_Nodes>` that has no scheduling dependencies
on any other Nodes within its own `Composition`. Typically,
an `ORIGIN` Node also do not receive any `Projections <Projection>` from
any other Nodes
within its own `Composition`, though if it is in a `nested Composition <Composition_Nested>` it may
receive Projections from the outer Composition. `Execution of a `Composition <Composition_Execution>`
always begins with an `ORIGIN` Node. A Composition may have many `ORIGIN` Nodes. This role cannot be
Expand Down Expand Up @@ -3448,7 +3451,10 @@ class NodeRole(enum.Enum):
<Composition_Node_Role_Assignment>` to Nodes. A Composition can have many `OUTPUT` Nodes.

TERMINAL
A `Node <Composition_Nodes>` that does not send any `Projections <Projection>` to any other Nodes within
A `Node <Composition_Nodes>` on which no other Nodes have
scheduling dependencies within its own `Composition`, excluding
`ObjectiveMechanism`. Typically, a `TERMINAL` Node does not send
any `Projections <Projection>` to any other Nodes within
its own `Composition`, though if it is in a `nested Composition <Composition_Nested>` it may send Projections
to the outer Composition. A Composition may have many `TERMINAL` Nodes. The `ObjectiveMechanism` associated
with the Composition's `controller <Composition.controller>` (assigned the role `CONTROLLER_OBJECTIVE`)
Expand Down Expand Up @@ -4977,51 +4983,41 @@ def _is_in_composition(self, component, nested=True):
if nested:
return any(component in comp._all_nodes for comp in self._get_nested_compositions())

def _get_terminal_nodes(self, graph, toposorted_graph=None) -> List[Component]:
def _get_terminal_nodes(self, graph) -> List[Component]:
"""
Returns a list of nodes in this composition that are
NodeRole.TERMINAL with respect to **toposorted_graph**. The
NodeRole.TERMINAL with respect to **graph**. The
result can change depending on whether the scheduler or
composition graph is used. The **toposorted_graph** of the
composition graph is used. The **graph** of the
scheduler graph is the scheduler's consideration_queue.

Includes all nodes in the last entry of **toposorted_graph**.
The ObjectiveMechanism of a Composition's
controller may not be NodeRole.TERMINAL, so if the
ObjectiveMechanism is the only node in the last entry of the
**toposorted_graph**, then the second-to-last entry is
NodeRole.TERMINAL instead.

Args:
toposorted_graph (List[Set[Component]]): the topological
sort of a graph
Includes all nodes that have no receivers in **graph**. The
ObjectiveMechanism of a Composition's controller cannot be
NodeRole.TERMINAL, so if the ObjectiveMechanism is the only node
with no receivers in **graph**, then that node's senders are
assigned NodeRole.TERMINAL instead.
"""
terminal_nodes = set()
if toposorted_graph is None:
toposorted_graph = list(toposort.toposort(graph))

if len(toposorted_graph) > 0:
for node in toposorted_graph[-1]:
if NodeRole.CONTROLLER_OBJECTIVE not in self.get_roles_by_node(node):
terminal_nodes.add(node)
elif len(toposorted_graph[-1]) < 2:
for previous_node in toposorted_graph[-2]:
terminal_nodes.add(previous_node)
receivers = {n: set() for n in graph}
for n in graph:
for sender in graph[n]:
receivers[sender].add(n)

candidates = {n for n in graph if len(receivers[n]) == 0}

# IMPLEMENTATION NOTE:
# The following is needed because the assignments above only identify nodes in the *last* consideration_set;
# however, the TERMINAL node(s) of a pathway with fewer nodes than the longest one may not be in the last
# consideration set. Identifying these assumes that graph_processing has been called/updated,
# which identifies and "breaks" cycles, and assigns FEEDBACK_SENDER to the appropriate consideration set(s).
children = {n: set() for n in graph}
for n in graph:
for sender in graph[n]:
children[sender].add(n)

terminal_nodes = terminal_nodes.union([
n for n in graph
if len(children[n]) == 0 and NodeRole.CYCLE not in self.nodes_to_roles[n]
])
if len(candidates) > 0:
for node in candidates:
if NodeRole.CONTROLLER_OBJECTIVE not in self.get_roles_by_node(node):
terminal_nodes.add(node)
elif len(candidates) < 2:
for previous_node in graph[node]:
terminal_nodes.add(previous_node)

return terminal_nodes

Expand All @@ -5036,7 +5032,7 @@ def _determine_origin_and_terminal_nodes_from_consideration_queue(self):
for node in list(queue)[0]:
self._add_node_role(node, NodeRole.ORIGIN)

for node in self._get_terminal_nodes(self.scheduler.graph):
for node in self._get_terminal_nodes(self.scheduler.dependency_dict):
self._add_node_role(node, NodeRole.TERMINAL)

def _add_node_aux_components(self, node, context=None):
Expand Down Expand Up @@ -5428,7 +5424,7 @@ def _determine_node_roles(self, context=None):
# the composition graph, not the scheduler graph, because OUTPUT
# is determined by composition structure, not scheduling order.
try:
output_nodes = self._get_terminal_nodes(comp_graph_dependencies, comp_graph_toposort)
output_nodes = self._get_terminal_nodes(comp_graph_dependencies)
except IndexError:
output_nodes = []

Expand Down
2 changes: 1 addition & 1 deletion tests/composition/test_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -7353,7 +7353,7 @@ def test_extended_loop(self):
comp = Composition(pathways=[[a, b, c, d],[e, c, f, b, d]])
comp.run(inputs={a: [2, 2], e: [0]})
assert set(comp.get_nodes_by_role(NodeRole.ORIGIN))=={a,e}
assert set(comp.get_nodes_by_role(NodeRole.TERMINAL))=={d}
assert set(comp.get_nodes_by_role(NodeRole.TERMINAL)) == {d, f}
assert set(comp.get_nodes_by_role(NodeRole.CYCLE))=={b,c,f}

def test_two_node_cycle(self):
Expand Down

0 comments on commit c315f29

Please sign in to comment.