Skip to content

Commit

Permalink
mg
Browse files Browse the repository at this point in the history
  • Loading branch information
kmantel committed Nov 10, 2023
1 parent 87a4f69 commit c407667
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 7 deletions.
13 changes: 13 additions & 0 deletions psyneulink/core/scheduling/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,19 @@ def _create_as_pnl_condition(condition):
if not issubclass(pnl_class, gs_condition_base_class):
return None

if (
graph_structure_conditions_available
and isinstance(condition, graph_scheduler.condition.GraphStructureCondition)
):
print(f'creating new cond for {id(condition)} {condition}')
try:
return pnl_class(
*condition.nodes,
**{k: v for k, v in condition.kwargs.items() if k != 'nodes'}
)
except AttributeError:
return pnl_class(**condition.kwargs)

new_args = [_create_as_pnl_condition(a) or a for a in condition.args]
new_kwargs = {k: _create_as_pnl_condition(v) or v for k, v in condition.kwargs.items()}
sig = inspect.signature(pnl_class)
Expand Down
16 changes: 14 additions & 2 deletions psyneulink/core/scheduling/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
import copy
import logging
import typing
from typing import Hashable

import graph_scheduler
import pint

import psyneulink as pnl
from psyneulink import _unit_registry
from psyneulink.core.globals.context import Context, handle_external_context
from psyneulink.core.globals.mdf import MDFSerializable
Expand Down Expand Up @@ -76,7 +78,7 @@ def replace_term_conds(term_conds):
def _validate_conditions(self):
unspecified_nodes = []
for node in self.nodes:
if node not in self.conditions:
if node not in self.conditions.conditions_basic:
dependencies = list(self.dependency_dict[node])
if len(dependencies) == 0:
cond = graph_scheduler.Always()
Expand All @@ -91,7 +93,7 @@ def _validate_conditions(self):
if len(unspecified_nodes) > 0:
logger.info(
'These nodes have no Conditions specified, and will be scheduled with conditions: {0}'.format(
{node: self.conditions[node] for node in unspecified_nodes}
{node: self.conditions.conditions_basic[node] for node in unspecified_nodes}
)
)

Expand Down Expand Up @@ -158,6 +160,16 @@ def as_mdf_model(self):
def get_clock(self, context):
return super().get_clock(context.execution_id)

def add_graph_edge(self, sender: Hashable, receiver: Hashable) -> 'pnl.AddEdgeTo':
cond = pnl.AddEdgeTo(receiver)
self.add_condition(sender, cond)
return cond

def remove_graph_edge(self, sender: Hashable, receiver: Hashable) -> 'pnl.RemoveEdgeFrom':
cond = pnl.RemoveEdgeFrom(sender)
self.add_condition(receiver, cond)
return cond


_doc_subs = {
None: [
Expand Down
12 changes: 7 additions & 5 deletions tests/scheduling/test_condition.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

import graph_scheduler
import numpy as np
import psyneulink as pnl
import pytest
Expand Down Expand Up @@ -1168,9 +1169,10 @@ def add_condition(owner, condition):
scheduler.add_condition(owner, condition)
return condition

comp, _, (A, B, C, D, E) = pytest.helpers.composition_from_string_pathways(
comp, _, mechanisms = pytest.helpers.composition_from_string_pathways(
[['A', 'B', 'C', 'D', 'E']]
)
A, B, C, D, E = mechanisms
scheduler = comp.scheduler
initial_conds = {A: pnl.AddEdgeTo(C)}
initial_graph = scheduler.graph
Expand Down Expand Up @@ -1214,7 +1216,7 @@ def add_condition(owner, condition):
addl_conditions[i][1] for i in range(addl_conds_sub_idx)
if addl_conditions[i][0] == k
]
for k in initial_graph
for k in mechanisms
},
A: initial_conds[A],
})
Expand Down Expand Up @@ -1245,7 +1247,7 @@ def test_run_graph_structure_conditions(self, pathways, conditions, expected_out
for owner, cond in conditions.items()
}
)
comp.run()
comp.run({n: [0] for n in mechanisms.values() if len(comp.scheduler._graphs[0][n]) == 0})
output = comp.scheduler.execution_list[comp.default_execution_id]

assert output == [{mechanisms[n] for n in eset} for eset in expected_output]
Expand All @@ -1265,8 +1267,8 @@ def test_gsc_creates_cyclic_graph(self):
assert len(comp.scheduler._graphs) == 3
assert len(comp.scheduler.conditions.structural_condition_order) == 2

with pytest.raises(pnl.SchedulerError, match='contains a cycle'):
comp.run()
with pytest.raises(graph_scheduler.SchedulerError, match='contains a cycle'):
comp.run({A: [0]})

def test_gsc_exact_time_warning(self):
A = ProcessingMechanism(name='A')
Expand Down

0 comments on commit c407667

Please sign in to comment.