diff --git a/Scripts/Models (Under Development)/EGO/Using EMComposition/DeclanParams.py b/Scripts/Models (Under Development)/EGO/Using EMComposition/DeclanParams.py
index 9f5f652b28a..ddc95037997 100644
--- a/Scripts/Models (Under Development)/EGO/Using EMComposition/DeclanParams.py
+++ b/Scripts/Models (Under Development)/EGO/Using EMComposition/DeclanParams.py
@@ -50,14 +50,14 @@ def calc_prob(em_preds, test_ys):
     previous_state_d = 11, # length of state vector
     context_d = 11, # length of context vector
     memory_capacity = ALL, # number of entries in EM memory; ALL=> match to number of stims
-    memory_init = (0,.0001), # Initialize memory with random values in interval
-    # memory_init = None, # Initialize with zeros
+    # memory_init = (0,.0001), # Initialize memory with random values in interval
+    memory_init = None, # Initialize with zeros
     concatenate_queries = False,
     # concatenate_queries = True,
     # environment
-    # curriculum_type = 'Interleaved',
-    curriculum_type = 'Blocked',
+    curriculum_type = 'Interleaved',
+    # curriculum_type = 'Blocked',
     # num_stims = 100, # Integer or ALL
     num_stims = ALL, # Integer or ALL
diff --git a/Scripts/Models (Under Development)/EGO/Using EMComposition/EGO Model - CSW with Simple Integrator.py b/Scripts/Models (Under Development)/EGO/Using EMComposition/EGO Model - CSW with Simple Integrator.py
index aabdecfd655..432bacf4c3e 100644
--- a/Scripts/Models (Under Development)/EGO/Using EMComposition/EGO Model - CSW with Simple Integrator.py
+++ b/Scripts/Models (Under Development)/EGO/Using EMComposition/EGO Model - CSW with Simple Integrator.py
@@ -363,7 +363,7 @@ def construct_model(model_name:str=model_params['name'],
 if RUN_MODEL:
     import timeit
     def print_stuff(**kwargs):
-        print(f"\n**************\n BATCH: {kwargs['batch']}\n**************\n")
+        print(f"\n**************\n BATCH: {kwargs['minibatch']}\n**************\n")
         print(kwargs)
         print('\nContext internal: \n', model.nodes['CONTEXT'].function.parameters.value.get(kwargs['context']))
         print('\nContext hidden: \n', model.nodes['CONTEXT'].parameters.value.get(kwargs['context']))
@@ -407,8 +407,8 @@ def print_stuff(**kwargs):
         )
     stop_time = timeit.default_timer()
     print(f"Elapsed time: {stop_time - start_time}")
-    if DISPLAY_MODEL is not None:
-        model.show_graph(**DISPLAY_MODEL)
+    # if DISPLAY_MODEL is not None:
+    #     model.show_graph(**DISPLAY_MODEL)
     if PRINT_RESULTS:
         print("MEMORY:")
         print(np.round(model.nodes['EM'].parameters.memory.get(model.name),3))
@@ -450,7 +450,7 @@ def eval_weights(weight_mat):
     axes[1].set_xlabel('Stimuli')
     axes[1].set_ylabel(model_params['loss_spec'])
     # Logit of loss
-    axes[2].plot( (model.results[1:TOTAL_NUM_STIMS,2]*TARGETS[:TOTAL_NUM_STIMS-1]).sum(-1) )
+    axes[2].plot( (model.results[2:TOTAL_NUM_STIMS,2]*TARGETS[:TOTAL_NUM_STIMS-2]).sum(-1) )
     axes[2].set_xlabel('Stimuli')
     axes[2].set_ylabel('Correct Logit')
     plt.suptitle(f"{model_params['curriculum_type']} Training")
diff --git a/Scripts/Models (Under Development)/EGO/Using EMComposition/Environment.py b/Scripts/Models (Under Development)/EGO/Using EMComposition/Environment.py
index 0ce08fafaaf..78aca55b459 100644
--- a/Scripts/Models (Under Development)/EGO/Using EMComposition/Environment.py
+++ b/Scripts/Models (Under Development)/EGO/Using EMComposition/Environment.py
@@ -2,7 +2,7 @@
 import torch
 from torch.utils.data import dataset
 from torch import utils
-from numpy.random import randint
+from random import randint

 def one_hot_encode(labels, num_classes):
     """
diff --git a/Scripts/Models (Under Development)/EGO/Using EMComposition/ScriptControl.py b/Scripts/Models (Under Development)/EGO/Using EMComposition/ScriptControl.py
index 04027649aa3..a06c4a95058 100644
--- a/Scripts/Models (Under Development)/EGO/Using EMComposition/ScriptControl.py
+++ b/Scripts/Models (Under Development)/EGO/Using EMComposition/ScriptControl.py
@@ -24,6 +24,6 @@
 PRINT_RESULTS = False # don't print model.results to console after execution
 # PRINT_RESULTS = True # print model.results to console after execution
 SAVE_RESULTS = False # save model.results to disk
-PLOT_RESULTS = False # don't plot results (PREDICTIONS) vs. TARGETS
-# PLOT_RESULTS = True # plot results (PREDICTIONS) vs. TARGETS
+# PLOT_RESULTS = False # don't plot results (PREDICTIONS) vs. TARGETS
+PLOT_RESULTS = True # plot results (PREDICTIONS) vs. TARGETS
 ANIMATE = False # {UNIT:EXECUTION_SET} # Specifies whether to generate animation of execution
diff --git a/conftest.py b/conftest.py
index 22050caa6b5..ea4a1b2f206 100644
--- a/conftest.py
+++ b/conftest.py
@@ -155,28 +155,24 @@ def comp_mode_no_llvm():
     # dummy fixture to allow 'comp_mode' filtering
     pass

-class FirstBench():
-    def __init__(self, benchmark):
-        super().__setattr__("benchmark", benchmark)
+@pytest.fixture
+def benchmark(benchmark):

-    def __call__(self, f, *args, **kwargs):
-        res = []
-        # Compute the first result if benchmark is enabled
-        if self.benchmark.enabled:
-            res.append(f(*args, **kwargs))
+    orig_class = type(benchmark)

-        res.append(self.benchmark(f, *args, **kwargs))
-        return res[0]
+    class _FirstBench(orig_class):
+        def __call__(self, f, *args, **kwargs):
+            res = []
+            # Compute the first result if benchmark is enabled
+            if self.enabled:
+                res.append(f(*args, **kwargs))

-    def __getattr__(self, attr):
-        return getattr(self.benchmark, attr)
+            res.append(orig_class.__call__(self, f, *args, **kwargs))
+            return res[0]

-    def __setattr__(self, attr, val):
-        return setattr(self.benchmark, attr, val)
+    benchmark.__class__ = _FirstBench

-@pytest.fixture
-def benchmark(benchmark):
-    return FirstBench(benchmark)
+    return benchmark

 @pytest.helpers.register
 def llvm_current_fp_precision():
diff --git a/dev_requirements.txt b/dev_requirements.txt
index d82bd9d7ca6..9f06766b7a6 100644
--- a/dev_requirements.txt
+++ b/dev_requirements.txt
@@ -1,7 +1,7 @@
 jupyter<1.1.2
 packaging<25.0
 pytest<8.3.4
-pytest-benchmark<4.0.1
+pytest-benchmark<5.1.1
 pytest-cov<5.0.1
 pytest-forked<1.7.0
 pytest-helpers-namespace<2021.12.30
diff --git a/psyneulink/core/components/mechanisms/modulatory/control/gating/gatingmechanism.py b/psyneulink/core/components/mechanisms/modulatory/control/gating/gatingmechanism.py
index f76fbb90981..6df5e05f0be 100644
--- a/psyneulink/core/components/mechanisms/modulatory/control/gating/gatingmechanism.py
+++ b/psyneulink/core/components/mechanisms/modulatory/control/gating/gatingmechanism.py
@@ -28,7 +28,7 @@
 --------
 A GatingMechanism is a subclass of `ControlMechanism` that is restricted to using only `GatingSignals `,
-which modulate the `input ` or `output ` of a `Mechanism `,
+which modulate the `input ` or `output ` of a `Mechanism `,
 but not the paramaters of its `function `. Accordingly, its constructor has a **gate** argument in place
 of a **control** argument. It also lacks several attributes related to control, including those related to costs
 and net_outcome. In all other respects it is identical to its parent class, ControlMechanism.
@@ -58,7 +58,7 @@
 *Specifying gating*
 ~~~~~~~~~~~~~~~~~~~

-A GatingMechanism is used to modulate the value of an `InputPort` or `OutputPort`. An InputPort or OutputPort can
+A GatingMechanism is used to modulate the value of an `InputPort` or `OutputPort`. An InputPort or OutputPort can
 be specified for gating by assigning it a `GatingProjection` or `GatingSignal` anywhere that the Projections to a Port
 or its `ModulatorySignals can be specified `. A `Mechanism ` can also be specified for gating, in which case the
 `primary InputPort ` of the specified Mechanism is used. Ports
diff --git a/psyneulink/core/compositions/composition.py b/psyneulink/core/compositions/composition.py
index bd66556cc86..8ceb616d9cd 100644
--- a/psyneulink/core/compositions/composition.py
+++ b/psyneulink/core/compositions/composition.py
@@ -12948,22 +12948,16 @@ def disable_all_history(self):
         self._set_all_parameter_properties_recursively(history_max_length=0)

     def _get_processing_condition_set(self, node):
-        dep_group = []
-        for group in self.scheduler.consideration_queue:
+        for index, group in enumerate(self.scheduler.consideration_queue):
             if node in group:
                 break
-            dep_group = group
-
-        # This condition is used to check of the step has passed.
-        # Not all nodes in the previous step need to execute
-        # (they might have other conditions), but if any one does we're good
-        # FIXME: This will fail if none of the previously considered
-        #        nodes executes in this pass, but that is unlikely.
-        conds = [Any(*(AllHaveRun(dep, time_scale=TimeScale.PASS) for dep in dep_group))] if len(dep_group) else []
+
+        assert index is not None
+
         if node in self.scheduler.conditions:
-            conds.append(self.scheduler.conditions[node])
+            return index, self.scheduler.conditions[node]

-        return All(*conds)
+        return index, Always()

     def _input_matches_variable(self, input_value, var):
         var_shape = convert_to_np_array(var).shape
diff --git a/psyneulink/core/llvm/codegen.py b/psyneulink/core/llvm/codegen.py
index cd14fadc52e..c476f5280d5 100644
--- a/psyneulink/core/llvm/codegen.py
+++ b/psyneulink/core/llvm/codegen.py
@@ -18,7 +18,7 @@
 from psyneulink.core.globals.keywords import AFTER, BEFORE
 from psyneulink.core.scheduling.condition import Never
 from psyneulink.core.scheduling.time import TimeScale
-from . import helpers
+from . import helpers, scheduler
 from .debug import debug_env
 from .warnings import PNLCompilerWarning

@@ -604,7 +604,7 @@ def gen_node_assembly(ctx, composition, node, *, tags:frozenset):
     if not is_mech and "reset" not in tags:
         # Add condition struct of the parent composition
         # This includes structures of all nested compositions
-        cond_gen = helpers.ConditionGenerator(ctx, composition)
+        cond_gen = scheduler.ConditionGenerator(ctx, composition)
         cond_ty = cond_gen.get_condition_struct_type().as_pointer()
         args.append(cond_ty)

@@ -762,7 +762,7 @@ def gen_node_assembly(ctx, composition, node, *, tags:frozenset):

 @contextmanager
 def _gen_composition_exec_context(ctx, composition, *, tags:frozenset, suffix="", extra_args=[]):
-    cond_gen = helpers.ConditionGenerator(ctx, composition)
+    cond_gen = scheduler.ConditionGenerator(ctx, composition)

     name = "_".join(("wrap_exec", *tags, composition.name + suffix))
     args = [ctx.get_state_struct_type(composition).as_pointer(),
@@ -782,7 +782,15 @@ def _gen_composition_exec_context(ctx, composition, *, tags:frozenset, suffix="
         params = builder.alloca(const_params.type, name="const_params_loc")
         builder.store(const_params, params)

+    for scale in TimeScale:
+        num_executions_ptr = helpers.get_state_ptr(builder, composition, state, "num_executions")
+        num_exec_time_ptr = builder.gep(num_executions_ptr, [ctx.int32_ty(0), ctx.int32_ty(scale.value)])
+        num_exec = builder.load(num_exec_time_ptr)
+        num_exec = builder.add(num_exec, num_exec.type(1))
+        builder.store(num_exec, num_exec_time_ptr)
+
     node_tags = tags.union({"node_assembly"})
+
     # Call input CIM
     input_cim_w = ctx.get_node_assembly(composition, composition.input_CIM)
     input_cim_f = ctx.import_llvm_function(input_cim_w, tags=node_tags)
@@ -801,8 +809,19 @@ def _gen_composition_exec_context(ctx, composition, *, tags:frozenset, suffix="
     builder.ret_void()


+def _reset_composition_nodes_exec_counts(ctx, builder, composition, comp_state, time_scales):
+    nodes_states = helpers.get_state_ptr(builder, composition, comp_state, "nodes")
+    for idx, node in enumerate(composition._all_nodes):
+        node_state = builder.gep(nodes_states, [ctx.int32_ty(0), ctx.int32_ty(idx)])
+        num_exec_vec_ptr = helpers.get_state_ptr(builder, node, node_state, "num_executions")
+
+        for scale in time_scales:
+            num_exec_time_ptr = builder.gep(num_exec_vec_ptr, [ctx.int32_ty(0), ctx.int32_ty(scale.value)])
+            builder.store(num_exec_time_ptr.type.pointee(0), num_exec_time_ptr)
+
+
 def gen_composition_exec(ctx, composition, *, tags:frozenset):
-    simulation = "simulation" in tags
+    is_simulation = "simulation" in tags
     node_tags = tags.union({"node_assembly"})

     with _gen_composition_exec_context(ctx, composition, tags=tags) as (builder, data, params, cond_gen):
@@ -827,19 +846,16 @@ def gen_composition_exec(ctx, composition, *, tags:frozenset):
             is_finished_callbacks[node] = (wrapper, args)

-        # Reset internal TRIAL/PASS/TIME_STEP clock for each node
-        # This also resets TIME_STEP counter for input_CIM and parameter_CIM
-        # executed above
-        for time_loc in num_exec_locs.values():
-            for scale in (TimeScale.TRIAL, TimeScale.PASS, TimeScale.TIME_STEP):
-                num_exec_time_ptr = builder.gep(time_loc, [ctx.int32_ty(0), ctx.int32_ty(scale.value)])
-                builder.store(num_exec_time_ptr.type.pointee(0), num_exec_time_ptr)
+        # Resetting internal TRIAL/PASS/TIME_STEP clock for each node
+        # also resets TIME_STEP counter for input_CIM and parameter_CIM
+        # executed when setting up the context
+        _reset_composition_nodes_exec_counts(ctx, builder, composition, state, [TimeScale.TRIAL, TimeScale.PASS, TimeScale.TIME_STEP])

-        # Check if there's anything to reset
+        # Check if there's any stateful node to reset
         for node in composition._all_nodes:
-            # FIXME: This should not be necessary. The code gets DCE'd,
-            #        but there are still some problems with generation
-            #        'reset' function
+            # FIXME: This should not be necessary. The code gets DCE'd, but
+            #        there are still some issues with generating the 'reset'
+            #        function.
             if node is composition.controller:
                 continue

@@ -848,7 +864,6 @@ def gen_composition_exec(ctx, composition, *, tags:frozenset):
                                                          cond,
                                                          node,
                                                          is_finished_callbacks,
-                                                         num_exec_locs,
                                                          nodes_states)
             with builder.if_then(reinit_cond):
                 node_w = ctx.get_node_assembly(composition, node)
                 node_reinit_f = ctx.import_llvm_function(node_w, tags=node_tags)
                 builder.call(node_reinit_f, [state, params, comp_in, data, data])

         # Run controller if it's enabled in 'BEFORE' mode
-        if simulation is False and composition.enable_controller and composition.controller_mode == BEFORE:
+        if is_simulation is False and composition.enable_controller and composition.controller_mode == BEFORE:
             assert composition.controller is not None
+            assert composition.controller_time_scale == TimeScale.TRIAL
+
+            helpers.printf(ctx,
+                           builder,
+                           "<%u/%u/%u> Executing: {}/{}\n".format(composition.name, composition.controller.name),
+                           cond_gen.get_global_trial(builder, cond),
+                           cond_gen.get_global_pass(builder, cond),
+                           cond_gen.get_global_step(builder, cond),
+                           tags={"scheduler"})
+
             controller_w = ctx.get_node_assembly(composition, composition.controller)
             controller_f = ctx.import_llvm_function(controller_w, tags=node_tags)
             builder.call(controller_f, [state, params, comp_in, data, data])

-        # Allocate run set structure
         run_set_type = ir.ArrayType(ctx.bool_ty, len(composition.nodes))
         run_set_ptr = builder.alloca(run_set_type, name="run_set")
         builder.store(run_set_type(None), run_set_ptr)
-
-        iter_ptr = builder.alloca(ctx.int32_ty, name="iter_counter")
-        builder.store(iter_ptr.type.pointee(0), iter_ptr)
+        consideration_index_ptr = builder.alloca(ctx.int32_ty, name="consideration_index_loc")
+        builder.store(consideration_index_ptr.type.pointee(0), consideration_index_ptr)

         # Start the main loop structure
         loop_condition = builder.append_basic_block(name="scheduling_loop_condition")
@@ -884,7 +907,6 @@ def gen_composition_exec(ctx, composition, *, tags:frozenset):
                                                            cond,
                                                            None,
                                                            is_finished_callbacks,
-                                                           num_exec_locs,
                                                            nodes_states)
         trial_cond = builder.not_(trial_term_cond, name="not_trial_term_cond")

         builder.position_at_end(loop_body)
         previous_step = builder.load(run_set_ptr)
-
         zero = ctx.int32_ty(0)
         any_cond = ctx.bool_ty(0)
+        consideration_index = builder.load(consideration_index_ptr)
+
         # Calculate execution set before running the mechanisms
         for idx, node in enumerate(composition.nodes):
             run_set_node_ptr = builder.gep(run_set_ptr, [zero, ctx.int32_ty(idx)], name="run_cond_ptr_" + node.name)
-            node_cond = cond_gen.generate_sched_condition(builder,
-                                                          composition._get_processing_condition_set(node),
-                                                          cond,
-                                                          node,
-                                                          is_finished_callbacks,
-                                                          num_exec_locs,
-                                                          nodes_states)
-            ran = cond_gen.generate_ran_this_pass(builder, cond, node)
-            node_cond = builder.and_(node_cond, builder.not_(ran), name="run_cond_" + node.name)
+            node_consideration_index, node_condition = composition._get_processing_condition_set(node)
+
+            is_consideration_turn = builder.icmp_unsigned("==", consideration_index, consideration_index.type(node_consideration_index))
+            node_cond = cond_gen.generate_sched_condition(builder, node_condition, cond, node, is_finished_callbacks, nodes_states)
+            node_cond = builder.and_(node_cond, is_consideration_turn, name="run_cond_" + node.name)
+
             any_cond = builder.or_(any_cond, node_cond, name="any_ran_cond")
             builder.store(node_cond, run_set_node_ptr)

+            prefix = "[SIMULATION] " if is_simulation else ""
+            helpers.printf(ctx,
+                           builder,
+                           "{}<%u/%u/%u> Considered: {}/{}: %d\n".format(prefix, composition.name, node.name),
+                           cond_gen.get_global_trial(builder, cond),
+                           cond_gen.get_global_pass(builder, cond),
+                           cond_gen.get_global_step(builder, cond),
+                           builder.select(node_cond, zero.type(1), zero),
+                           tags={"scheduler" if not is_simulation else "simulation_scheduler"})
+
+
         # Reset internal TIME_STEP clock for each node
         # NOTE: This is done _after_ condition evaluation, otherwise
         #       TIME_STEP related conditions will only see 0 executions
@@ -932,14 +963,25 @@ def gen_composition_exec(ctx, composition, *, tags:frozenset):
                 node_w = ctx.get_node_assembly(composition, node)
                 node_f = ctx.import_llvm_function(node_w, tags=node_tags)
                 builder.block.name = "invoke_" + node_f.name
+
+                prefix = "[SIMULATION] " if is_simulation else ""
+                helpers.printf(ctx,
+                               builder,
+                               "{}<%u/%u/%u> Executing: {}/{}\n".format(prefix, composition.name, node.name),
+                               cond_gen.get_global_trial(builder, cond),
+                               cond_gen.get_global_pass(builder, cond),
+                               cond_gen.get_global_step(builder, cond),
+                               tags={"scheduler" if not is_simulation else "simulation_scheduler"})
+
                 # Wrappers do proper indexing of all structures
                 # Mechanisms have only 5 args
                 args = [state, params, comp_in, data, output_storage]
                 if len(node_f.args) >= 6:  # Composition wrappers have 6 args
                     args.append(cond)
+
                 builder.call(node_f, args)
+                cond_gen.generate_update_after_node_execution(builder, cond, node)
-                cond_gen.generate_update_after_run(builder, cond, node)
                 builder.block.name = "post_invoke_" + node_f.name

         # Writeback results
@@ -959,31 +1001,40 @@ def gen_composition_exec(ctx, composition, *, tags:frozenset):
             cond_gen.bump_ts(builder, cond)

         builder.block.name = "update_iter_count"
+
         # Increment number of iterations
-        iters = builder.load(iter_ptr, name="iterw")
-        iters = builder.add(iters, iters.type(1), name="iterw_inc")
-        builder.store(iters, iter_ptr)
+        consideration_index = builder.add(consideration_index, consideration_index.type(1), name="consideration_index_inc")
+        builder.store(consideration_index, consideration_index_ptr)
+
+        max_considerations = consideration_index.type(len(composition.scheduler.consideration_queue))
+        completed_pass = builder.icmp_unsigned("==", consideration_index, max_considerations, name="completed_pass")

-        max_iters = len(composition.scheduler.consideration_queue)
-        completed_pass = builder.icmp_unsigned("==", iters, iters.type(max_iters), name="completed_pass")
         # Increment pass and reset time step
         with builder.if_then(completed_pass):
             builder.block.name = "inc_pass"
-            builder.store(zero, iter_ptr)
+            builder.store(consideration_index_ptr.type.pointee(0), consideration_index_ptr)
+
             # Bumping automatically zeros lower elements
             cond_gen.bump_ts(builder, cond, (0, 1, 0))
-            # Reset internal PASS clock for each node
-            for time_loc in num_exec_locs.values():
-                num_exec_time_ptr = builder.gep(time_loc, [zero, ctx.int32_ty(TimeScale.PASS.value)])
-                builder.store(num_exec_time_ptr.type.pointee(0), num_exec_time_ptr)
+
+            _reset_composition_nodes_exec_counts(ctx, builder, composition, state, [TimeScale.PASS])
         builder.branch(loop_condition)
     builder.position_at_end(exit_block)

-    if simulation is False and composition.enable_controller and \
-       composition.controller_mode == AFTER:
+    if is_simulation is False and composition.enable_controller and composition.controller_mode == AFTER:
         assert composition.controller is not None
+        assert composition.controller_time_scale == TimeScale.TRIAL
+
+        helpers.printf(ctx,
+                       builder,
+                       "<%u/%u/%u> Executing: {}/{}\n".format(composition.name, composition.controller.name),
+                       cond_gen.get_global_trial(builder, cond),
+                       cond_gen.get_global_pass(builder, cond),
+                       cond_gen.get_global_step(builder, cond),
+                       tags={"scheduler"})
+
         controller_w = ctx.get_node_assembly(composition, composition.controller)
         controller_f = ctx.import_llvm_function(controller_w, tags=node_tags)
         builder.call(controller_f, [state, params, comp_in, data, data])
@@ -1044,15 +1095,10 @@ def gen_composition_run(ctx, composition, *, tags:frozenset):
         builder.store(data_in.type.pointee(input_init), data_in)
         builder.store(inputs_ptr.type.pointee(1), inputs_ptr)

-    # Reset internal 'RUN' clocks of each node
-    for idx, node in enumerate(composition._all_nodes):
-        node_state = builder.gep(state, [ctx.int32_ty(0), ctx.int32_ty(0), ctx.int32_ty(idx)])
-        num_executions_ptr = helpers.get_state_ptr(builder, node, node_state, "num_executions")
-        num_exec_time_ptr = builder.gep(num_executions_ptr, [ctx.int32_ty(0), ctx.int32_ty(TimeScale.RUN.value)])
-        builder.store(num_exec_time_ptr.type.pointee(None), num_exec_time_ptr)
+    _reset_composition_nodes_exec_counts(ctx, builder, composition, state, [TimeScale.RUN])

     # Allocate and initialize condition structure
-    cond_gen = helpers.ConditionGenerator(ctx, composition)
+    cond_gen = scheduler.ConditionGenerator(ctx, composition)
     cond_type = cond_gen.get_condition_struct_type()
     cond = builder.alloca(cond_type, name="scheduler_metadata")
     cond_init = cond_type(cond_gen.get_condition_initializer())
@@ -1069,9 +1115,12 @@ def gen_composition_run(ctx, composition, *, tags:frozenset):
     # Generate a while not 'end condition' loop
     builder.position_at_end(loop_condition)
-    run_term_cond = cond_gen.generate_sched_condition(
-        builder, composition.termination_processing[TimeScale.RUN],
-        cond, None, None, None, nodes_states)
+    run_term_cond = cond_gen.generate_sched_condition(builder,
+                                                      composition.termination_processing[TimeScale.RUN],
+                                                      cond,
+                                                      None,
+                                                      None,
+                                                      nodes_states)
     run_cond = builder.not_(run_term_cond, name="not_run_term_cond")

     # Iter cond
diff --git a/psyneulink/core/llvm/execution.py b/psyneulink/core/llvm/execution.py
index e877bc0a5a7..5a172a83eec 100644
--- a/psyneulink/core/llvm/execution.py
+++ b/psyneulink/core/llvm/execution.py
@@ -23,7 +23,7 @@
 from psyneulink.core import llvm as pnlvm
 from psyneulink.core.globals.context import Context

-from . import helpers, jit_engine, builder_context
+from . import builder_context, jit_engine, scheduler
 from .debug import debug_env

 __all__ = ['CompExecution', 'FuncExecution', 'MechExecution']
@@ -347,7 +347,7 @@ def _set_bin_node(self, node):
     @property
     def _conditions(self):
         if self.__conditions is None:
-            gen = helpers.ConditionGenerator(None, self._composition)
+            gen = scheduler.ConditionGenerator(None, self._composition)

             conditions_ctype = self._bin_func.byref_arg_types[4]
             conditions_initializer = gen.get_condition_initializer()
diff --git a/psyneulink/core/llvm/helpers.py b/psyneulink/core/llvm/helpers.py
index 9581b62bebc..c3dc3336bfe 100644
--- a/psyneulink/core/llvm/helpers.py
+++ b/psyneulink/core/llvm/helpers.py
@@ -14,11 +14,7 @@

 from llvmlite import ir

-
 from .debug import debug_env
-from psyneulink.core.scheduling.condition import All, AllHaveRun, Always, Any, AtPass, AtTrial, BeforeNCalls, AtNCalls, AfterNCalls, \
-    EveryNCalls, Never, Not, WhenFinished, WhenFinishedAny, WhenFinishedAll, Threshold
-from psyneulink.core.scheduling.time import TimeScale


 @contextmanager
@@ -448,329 +444,3 @@ def printf_float_matrix(ctx, builder, matrix, prefix="", suffix="\n", *, tags:se
         printf_float_array(ctx, b1, row, suffix="\n", tags=tags)

     printf(ctx, builder, suffix, tags=tags)
-
-
-class ConditionGenerator:
-    def __init__(self, ctx, composition):
-        self.ctx = ctx
-        self.composition = composition
-        self._zero = ctx.int32_ty(0) if ctx is not None else None
-
-    def get_private_condition_struct_type(self, composition):
-        time_stamp_struct = ir.LiteralStructType([self.ctx.int32_ty,  # Trial
-                                                  self.ctx.int32_ty,  # Pass
-                                                  self.ctx.int32_ty]) # Step
-
-        status_struct = ir.LiteralStructType([
-            self.ctx.int32_ty,  # number of executions in this run
-            time_stamp_struct   # time stamp of last execution
-        ])
-        structure = ir.LiteralStructType([
-            time_stamp_struct,  # current time stamp
-            ir.ArrayType(status_struct, len(composition.nodes))  # for each node
-        ])
-        return structure
-
-    def get_private_condition_initializer(self, composition):
-        return ((0, 0, 0),
-                tuple((0, (-1, -1, -1)) for _ in composition.nodes))
-
-    def get_condition_struct_type(self, node=None):
-        node = self.composition if node is None else node
-
-        subnodes = getattr(node, 'nodes', [])
-        structs = [self.get_condition_struct_type(n) for n in subnodes]
-        if len(structs) != 0:
-            structs.insert(0, self.get_private_condition_struct_type(node))
-
-        return ir.LiteralStructType(structs)
-
-    def get_condition_initializer(self, node=None):
-        node = self.composition if node is None else node
-
-        subnodes = getattr(node, 'nodes', [])
-        data = [self.get_condition_initializer(n) for n in subnodes]
-        if len(data) != 0:
-            data.insert(0, self.get_private_condition_initializer(node))
-
-        return tuple(data)
-
-    def bump_ts(self, builder, cond_ptr, count=(0, 0, 1)):
-        """
-        Increments the time structure of the composition.
-        Count should be a tuple where there is a number in only one spot, and zeroes elsewhere.
-        Indices greater than that of the one are zeroed.
-        """
-
-        # Only one element should be non-zero
-        assert count.count(0) == len(count) - 1
-
-        # Get timestruct pointer
-        ts_ptr = builder.gep(cond_ptr, [self._zero, self._zero, self._zero])
-        ts = builder.load(ts_ptr)
-
-        assert len(ts.type) == len(count)
-        # Update run, pass, step of ts
-        for idx in range(len(ts.type)):
-            if all(v == 0 for v in count[:idx]):
-                el = builder.extract_value(ts, idx)
-                el = builder.add(el, el.type(count[idx]))
-            else:
-                el = self.ctx.int32_ty(0)
-            ts = builder.insert_value(ts, el, idx)
-
-        builder.store(ts, ts_ptr)
-        return builder
-
-    def ts_compare(self, builder, ts1, ts2, comp):
-        assert comp == '<'
-
-        # True if all elements to the left of the current one are equal
-        prefix_eq = self.ctx.bool_ty(1)
-        result = self.ctx.bool_ty(0)
-
-        assert ts1.type == ts2.type
-        for element in range(len(ts1.type)):
-            a = builder.extract_value(ts1, element)
-            b = builder.extract_value(ts2, element)
-
-            # Use existing prefix_eq to construct expression
-            # for the current element
-            element_comp = builder.icmp_signed(comp, a, b)
-            current_comp = builder.and_(prefix_eq, element_comp)
-            result = builder.or_(result, current_comp)
-
-            # Update prefix_eq
-            element_eq = builder.icmp_signed('==', a, b)
-            prefix_eq = builder.and_(prefix_eq, element_eq)
-
-        return result
-
-    def __get_node_status_ptr(self, builder, cond_ptr, node):
-        node_idx = self.ctx.int32_ty(self.composition.nodes.index(node))
-        return builder.gep(cond_ptr, [self._zero, self._zero, self.ctx.int32_ty(1), node_idx])
-
-    def __get_node_ts(self, builder, cond_ptr, node):
-        status_ptr = self.__get_node_status_ptr(builder, cond_ptr, node)
-        ts_ptr = builder.gep(status_ptr, [self.ctx.int32_ty(0),
-                                          self.ctx.int32_ty(1)])
-        return builder.load(ts_ptr)
-
-    def get_global_ts(self, builder, cond_ptr):
-        ts_ptr = builder.gep(cond_ptr, [self._zero, self._zero, self._zero])
-        return builder.load(ts_ptr)
-
-    def generate_update_after_run(self, builder, cond_ptr, node):
-        status_ptr = self.__get_node_status_ptr(builder, cond_ptr, node)
-        status = builder.load(status_ptr)
-
-        # Update number of runs
-        runs = builder.extract_value(status, 0)
-        runs = builder.add(runs, runs.type(1))
-        status = builder.insert_value(status, runs, 0)
-
-        # Update time stamp
-        ts = self.get_global_ts(builder, cond_ptr)
-        status = builder.insert_value(status, ts, 1)
-
-        builder.store(status, status_ptr)
-
-    def generate_ran_this_pass(self, builder, cond_ptr, node):
-        global_ts = self.get_global_ts(builder, cond_ptr)
-        global_trial = builder.extract_value(global_ts, 0)
-        global_pass = builder.extract_value(global_ts, 1)
-
-        node_ts = self.__get_node_ts(builder, cond_ptr, node)
-        node_trial = builder.extract_value(node_ts, 0)
-        node_pass = builder.extract_value(node_ts, 1)
-
-        pass_eq = builder.icmp_signed("==", node_pass, global_pass)
-        trial_eq = builder.icmp_signed("==", node_trial, global_trial)
-        return builder.and_(pass_eq, trial_eq)
-
-    def generate_ran_this_trial(self, builder, cond_ptr, node):
-        global_ts = self.get_global_ts(builder, cond_ptr)
-        global_trial = builder.extract_value(global_ts, 0)
-
-        node_ts = self.__get_node_ts(builder, cond_ptr, node)
-        node_trial = builder.extract_value(node_ts, 0)
-
-        return builder.icmp_signed("==", node_trial, global_trial)
-
-    # TODO: replace num_exec_locs use with equivalent from nodes_states
-    def generate_sched_condition(self, builder, condition, cond_ptr, node,
-                                 is_finished_callbacks, num_exec_locs,
-                                 nodes_states):
-
-
-        if isinstance(condition, Always):
-            return self.ctx.bool_ty(1)
-
-        if isinstance(condition, Never):
-            return self.ctx.bool_ty(0)
-
-        elif isinstance(condition, Not):
-            orig_condition = self.generate_sched_condition(builder, condition.condition, cond_ptr, node, is_finished_callbacks, num_exec_locs, nodes_states)
-            return builder.not_(orig_condition)
-
-        elif isinstance(condition, All):
-            agg_cond = self.ctx.bool_ty(1)
-            for cond in condition.args:
-                cond_res = self.generate_sched_condition(builder, cond, cond_ptr, node, is_finished_callbacks, num_exec_locs, nodes_states)
-                agg_cond = builder.and_(agg_cond, cond_res)
-            return agg_cond
-
-        elif isinstance(condition, AllHaveRun):
-            # Extract dependencies
-            dependencies = self.composition.nodes
-            if len(condition.args) > 0:
-                dependencies = condition.args
-
-            run_cond = self.ctx.bool_ty(1)
-            for node in dependencies:
-                if condition.time_scale == TimeScale.TRIAL:
-                    node_ran = self.generate_ran_this_trial(builder, cond_ptr, node)
-                elif condition.time_scale == TimeScale.PASS:
-                    node_ran = self.generate_ran_this_pass(builder, cond_ptr, node)
-                else:
-                    assert False, "Unsupported 'AllHaveRun' time scale: {}".format(condition.time_scale)
-                run_cond = builder.and_(run_cond, node_ran)
-            return run_cond
-
-        elif isinstance(condition, Any):
-            agg_cond = self.ctx.bool_ty(0)
-            for cond in condition.args:
-                cond_res = self.generate_sched_condition(builder, cond, cond_ptr, node, is_finished_callbacks, num_exec_locs, nodes_states)
-                agg_cond = builder.or_(agg_cond, cond_res)
-            return agg_cond
-
-        elif isinstance(condition, AtTrial):
-            trial_num = condition.args[0]
-            global_ts = self.get_global_ts(builder, cond_ptr)
-            trial = builder.extract_value(global_ts, 0)
-            return builder.icmp_unsigned("==", trial, trial.type(trial_num))
-
-        elif isinstance(condition, AtPass):
-            pass_num = condition.args[0]
-            global_ts = self.get_global_ts(builder, cond_ptr)
-            current_pass = builder.extract_value(global_ts, 1)
-            return builder.icmp_unsigned("==", current_pass,
-                                         current_pass.type(pass_num))
-
-        elif isinstance(condition, EveryNCalls):
-            target, count = condition.args
-            assert count == 1, "EveryNCalls isonly supprted with count == 1"
-
-            target_ts = self.__get_node_ts(builder, cond_ptr, target)
-            node_ts = self.__get_node_ts(builder, cond_ptr, node)
-
-            # If target ran after node did its TS will be greater node's
-            return self.ts_compare(builder, node_ts, target_ts, '<')
-
-        elif isinstance(condition, BeforeNCalls):
-            target, count = condition.args
-            scale = condition.time_scale.value
-            target_num_execs_in_scale = builder.gep(num_exec_locs[target],
-                                                    [self.ctx.int32_ty(0),
-                                                     self.ctx.int32_ty(scale)])
-            num_execs = builder.load(target_num_execs_in_scale)
-
-            return builder.icmp_unsigned('<', num_execs, num_execs.type(count))
-
-        elif isinstance(condition, AtNCalls):
-            target, count = condition.args
-            scale = condition.time_scale.value
-            target_num_execs_in_scale = builder.gep(num_exec_locs[target],
-                                                    [self.ctx.int32_ty(0),
-                                                     self.ctx.int32_ty(scale)])
-            num_execs = builder.load(target_num_execs_in_scale)
-            return builder.icmp_unsigned('==', num_execs, num_execs.type(count))
-
-        elif isinstance(condition, AfterNCalls):
-            target, count = condition.args
-            scale = condition.time_scale.value
-            target_num_execs_in_scale = builder.gep(num_exec_locs[target],
-                                                    [self.ctx.int32_ty(0),
-                                                     self.ctx.int32_ty(scale)])
-            num_execs = builder.load(target_num_execs_in_scale)
-            return builder.icmp_unsigned('>=', num_execs, num_execs.type(count))
-
-        elif isinstance(condition, WhenFinished):
-            # The first argument is the target node
-            assert len(condition.args) == 1
-            target = is_finished_callbacks[condition.args[0]]
-            is_finished_f = self.ctx.import_llvm_function(target[0], tags=frozenset({"is_finished", "node_assembly"}))
-            return builder.call(is_finished_f, target[1])
-
-        elif isinstance(condition, WhenFinishedAny):
-            assert len(condition.args) > 0
-
-            run_cond = self.ctx.bool_ty(0)
-            for node in condition.args:
-                target = is_finished_callbacks[node]
-                is_finished_f = self.ctx.import_llvm_function(target[0], tags=frozenset({"is_finished", "node_assembly"}))
-                node_is_finished = builder.call(is_finished_f, target[1])
-
-                run_cond = builder.or_(run_cond, node_is_finished)
-
-            return run_cond
-
-        elif isinstance(condition, WhenFinishedAll):
-            assert len(condition.args) > 0
-
-            run_cond = self.ctx.bool_ty(1)
-            for node in condition.args:
-                target = is_finished_callbacks[node]
-                is_finished_f = self.ctx.import_llvm_function(target[0], tags=frozenset({"is_finished", "node_assembly"}))
-                node_is_finished = builder.call(is_finished_f, target[1])
-
-                run_cond = builder.and_(run_cond, node_is_finished)
-
-            return run_cond
-
-        elif isinstance(condition, Threshold):
-            target = condition.dependency
-            param = condition.parameter
-            threshold = condition.threshold
-            comparator = condition.comparator
-            indices = condition.indices
-
-            # Convert execution_count to ('num_executions', TimeScale.LIFE).
-            # These two are identical in compiled semantics.
-            if param == 'execution_count':
-                assert indices is None
-                param = 'num_executions'
-                indices = TimeScale.LIFE
-
-            assert param in target.llvm_state_ids, (
-                f"Threshold for {target} only supports items in llvm_state_ids"
-                f" ({target.llvm_state_ids})"
-            )
-
-            node_idx = self.composition._get_node_index(target)
-            node_state = builder.gep(nodes_states, [self.ctx.int32_ty(0), self.ctx.int32_ty(node_idx)])
-            param_ptr = get_state_ptr(builder, target, node_state, param)
-
-            # parameters in state include history of at least one element
-            # so they are always arrays.
-            assert isinstance(param_ptr.type.pointee, ir.ArrayType)
-
-            if indices is None:
-                indices = [0, 0]
-            elif isinstance(indices, TimeScale):
-                indices = [indices.value]
-
-            param_ptr = builder.gep(param_ptr, [self.ctx.int32_ty(x) for x in [0] + list(indices)])
-
-            val = builder.load(param_ptr)
-            val = convert_type(builder, val, ir.DoubleType())
-            threshold = val.type(threshold)
-
-            if comparator == '==':
-                return is_close(self.ctx, builder, val, threshold, condition.rtol, condition.atol)
-            elif comparator == '!=':
-                return builder.not_(is_close(self.ctx, builder, val, threshold, condition.rtol, condition.atol))
-            else:
-                return builder.fcmp_ordered(comparator, val, threshold)
-
-        assert False, "Unsupported scheduling condition: {}".format(condition)
diff --git a/psyneulink/core/llvm/scheduler.py b/psyneulink/core/llvm/scheduler.py
new file mode 100644
index 00000000000..e98bafd1423
--- /dev/null
+++ b/psyneulink/core/llvm/scheduler.py
@@ -0,0 +1,311 @@
+from enum import Enum
+from llvmlite import ir
+
+from . import helpers
+from psyneulink.core.scheduling.time import TimeScale
+from psyneulink.core.scheduling.condition import All, AllHaveRun, Always, Any, AtPass, AtTrial, BeforeNCalls, AtNCalls, AfterNCalls, \
+    EveryNCalls, Never, Not, WhenFinished, WhenFinishedAny, WhenFinishedAll, Threshold
+
+class ConditionGenerator:
+    class TimeIndex(Enum):
+        TRIAL = 0,
+        PASS = 1,
+        STEP = 2,
+
+    def __init__(self, ctx, composition):
+        self.ctx = ctx
+        self.composition = composition
+        self._zero = ctx.int32_ty(0) if ctx is not None else None
+
+    def get_private_condition_struct_type(self, composition):
+        time_stamp_struct = ir.LiteralStructType([self.ctx.int32_ty for _ in self.TimeIndex])
+        nodes_time_stamps_array = ir.ArrayType(time_stamp_struct, len(composition.nodes))
+
+        return ir.LiteralStructType((time_stamp_struct, nodes_time_stamps_array))
+
+    def get_private_condition_initializer(self, composition):
+        init_global = tuple(0 for _ in self.TimeIndex)
+        init_node = tuple(-1 for _ in self.TimeIndex)
+
+        return (init_global, tuple(init_node for _ in composition.nodes))
+
+    def get_condition_struct_type(self, node=None):
+        node = self.composition if node is None else node
+
+        subnodes = getattr(node, 'nodes', [])
+        structs = [self.get_condition_struct_type(n) for n in subnodes]
+        if len(structs) != 0:
+            structs.insert(0, self.get_private_condition_struct_type(node))
+
+        return ir.LiteralStructType(structs)
+
+    def get_condition_initializer(self, node=None):
+        node = self.composition if node is None else node
+
+        subnodes = getattr(node, 'nodes', [])
+        data = [self.get_condition_initializer(n) for n in subnodes]
+        if len(data) != 0:
+            data.insert(0, self.get_private_condition_initializer(node))
+
+        return tuple(data)
+
+    def bump_ts(self, builder, cond_ptr, count=(0, 0, 1)):
+        """
+        Increments the time structure of the composition.
+        Count should be a tuple where there is a number in only one spot, and zeroes elsewhere.
+        Indices greater than the incremented one are zeroed.
+        """
+
+        # Only one element should be non-zero
+        assert count.count(0) == len(count) - 1
+
+        # Get timestruct pointer
+        ts_ptr = self.__get_global_ts_ptr(builder, cond_ptr)
+        ts = builder.load(ts_ptr)
+
+        assert len(ts.type) == len(count)
+
+        # Update run, pass, step of ts
+        for idx in range(len(ts.type)):
+            if all(v == 0 for v in count[:idx]):
+                el = builder.extract_value(ts, idx)
+                el = builder.add(el, el.type(count[idx]))
+            else:
+                el = self.ctx.int32_ty(0)
+
+            ts = builder.insert_value(ts, el, idx)
+
+        builder.store(ts, ts_ptr)
+        return builder
+
+    def ts_compare(self, builder, ts1, ts2, comp):
+        assert comp == '<'
+
+        # True if all elements to the left of the current one are equal
+        prefix_eq = self.ctx.bool_ty(1)
+        result = self.ctx.bool_ty(0)
+
+        assert ts1.type == ts2.type
+        for element in range(len(ts1.type)):
+            a = builder.extract_value(ts1, element)
+            b = builder.extract_value(ts2, element)
+
+            # Use existing prefix_eq to construct expression
+            # for the current element
+            element_comp = builder.icmp_signed(comp, a, b)
+            current_comp = builder.and_(prefix_eq, element_comp)
+            result = builder.or_(result, current_comp)
+
+            # Update prefix_eq
+            element_eq = builder.icmp_signed('==', a, b)
+            prefix_eq = builder.and_(prefix_eq, element_eq)
+
+        return result
+
+    def __get_global_ts_ptr(self, builder, cond_ptr):
+        # dereference the structure, the first element (private structure),
+        # and the first element of the private structure is the global ts.
+        return builder.gep(cond_ptr, [self._zero, self._zero, self._zero])
+
+    def __get_node_ts_ptr(self, builder, cond_ptr, node):
+        node_idx = self.ctx.int32_ty(self.composition.nodes.index(node))
+
+        # dereference the structure, the first element (private structure), the
+        # second element is the node time stamp array, use index in the array
+        return builder.gep(cond_ptr, [self._zero, self._zero, self.ctx.int32_ty(1), node_idx])
+
+    def __get_node_ts(self, builder, cond_ptr, node):
+        ts_ptr = self.__get_node_ts_ptr(builder, cond_ptr, node)
+        return builder.load(ts_ptr)
+
+    def get_global_ts(self, builder, cond_ptr):
+        ts_ptr = builder.gep(cond_ptr, [self._zero, self._zero, self._zero])
+        return builder.load(ts_ptr)
+
+    def _extract_global_time(self, builder, cond_ptr, time_index):
+        global_ts = self.get_global_ts(builder, cond_ptr)
+        return builder.extract_value(global_ts, time_index.value)
+
+    def get_global_trial(self, builder, cond_ptr):
+        return self._extract_global_time(builder, cond_ptr, self.TimeIndex.TRIAL)
+
+    def get_global_pass(self, builder, cond_ptr):
+        return self._extract_global_time(builder, cond_ptr, self.TimeIndex.PASS)
+
+    def get_global_step(self, builder, cond_ptr):
+        return self._extract_global_time(builder, cond_ptr, self.TimeIndex.STEP)
+
+    def generate_update_after_node_execution(self, builder, cond_ptr, node):
+        # Update time stamp of the last execution
+        global_ts_ptr = self.__get_global_ts_ptr(builder, cond_ptr)
+        node_ts_ptr = self.__get_node_ts_ptr(builder, cond_ptr, node)
+
+        global_ts = builder.load(global_ts_ptr)
+        builder.store(global_ts, node_ts_ptr)
+
+    def _node_executions_for_scale(self, builder, node, node_states, time_scale:TimeScale):
+        node_idx = self.composition._get_node_index(node)
+        node_state = builder.gep(node_states, [self._zero, self.ctx.int32_ty(node_idx)])
+        num_exec_ptr = helpers.get_state_ptr(builder, node, node_state, "num_executions")
+
+        count_ptr = builder.gep(num_exec_ptr, [self._zero, self.ctx.int32_ty(time_scale.value)])
+        return builder.load(count_ptr)
+
+    def generate_sched_condition(self, builder, condition, cond_ptr, self_node, is_finished_callbacks, nodes_states):
+
+        if isinstance(condition, Always):
+            return self.ctx.bool_ty(1)
+
+        if isinstance(condition, Never):
+            return self.ctx.bool_ty(0)
+
+        elif isinstance(condition, Not):
+            orig_condition = self.generate_sched_condition(builder, condition.condition, cond_ptr, self_node, is_finished_callbacks, nodes_states)
+            return builder.not_(orig_condition)
+
+        elif isinstance(condition, All):
+            agg_cond = self.ctx.bool_ty(1)
+            for cond in condition.args:
+                cond_res = self.generate_sched_condition(builder, cond, cond_ptr, self_node, is_finished_callbacks, nodes_states)
+                agg_cond = builder.and_(agg_cond, cond_res)
+            return agg_cond
+
+        elif isinstance(condition, AllHaveRun):
+            # Extract dependencies
+            dependencies = self.composition.nodes
+            if len(condition.args) > 0:
+                dependencies = condition.args
+
+            run_cond = self.ctx.bool_ty(1)
+            for node in dependencies:
+                count = self._node_executions_for_scale(builder, node, nodes_states, condition.time_scale)
+
+                node_ran = builder.icmp_unsigned(">", count, count.type(0))
+                run_cond = builder.and_(run_cond, node_ran)
+
+            return run_cond
+
+        elif isinstance(condition, Any):
+            agg_cond = self.ctx.bool_ty(0)
+            for cond in condition.args:
+                cond_res = self.generate_sched_condition(builder, cond, cond_ptr, self_node, is_finished_callbacks, nodes_states)
+                agg_cond = builder.or_(agg_cond, cond_res)
+            return agg_cond
+
+        elif isinstance(condition, AtTrial):
+            trial_num = condition.args[0]
+            current_trial = self.get_global_trial(builder, cond_ptr)
+            return builder.icmp_unsigned("==", current_trial, current_trial.type(trial_num))
+
+        elif isinstance(condition, AtPass):
+            pass_num = condition.args[0]
+            current_pass = self.get_global_pass(builder, cond_ptr)
+            return builder.icmp_unsigned("==", current_pass, current_pass.type(pass_num))
+
+        elif isinstance(condition, EveryNCalls):
+            target, count = condition.args
+            assert count == 1, "EveryNCalls is only supported with count == 1 (count: {})".format(count)
+
+            target_ts = self.__get_node_ts(builder, cond_ptr, target)
+            node_ts = self.__get_node_ts(builder, cond_ptr, self_node)
+
+            # If target ran after this node did, its TS will be greater than this node's
+            return self.ts_compare(builder, node_ts, target_ts, '<')
+
+        elif isinstance(condition, BeforeNCalls):
+            node, count = condition.args
+            num_execs = self._node_executions_for_scale(builder, node, nodes_states, condition.time_scale)
+
+            return builder.icmp_unsigned('<', num_execs, num_execs.type(count))
+
+        elif isinstance(condition, AtNCalls):
+            node, count = condition.args
+            num_execs = self._node_executions_for_scale(builder, node, nodes_states, condition.time_scale)
+
+            return builder.icmp_unsigned('==', num_execs, num_execs.type(count))
+
+        elif isinstance(condition, AfterNCalls):
+            node, count = condition.args
+            num_execs = self._node_executions_for_scale(builder, node, nodes_states, condition.time_scale)
+
+            return builder.icmp_unsigned('>=', num_execs, num_execs.type(count))
+
+        elif isinstance(condition, WhenFinished):
+            # The first argument is the target node
+            assert len(condition.args) == 1
+            target = is_finished_callbacks[condition.args[0]]
+            is_finished_f = self.ctx.import_llvm_function(target[0], tags=frozenset({"is_finished", "node_assembly"}))
+            return builder.call(is_finished_f, target[1])
+
+        elif isinstance(condition, WhenFinishedAny):
+            assert len(condition.args) > 0
+
+            run_cond = self.ctx.bool_ty(0)
+            for node in condition.args:
+                target = is_finished_callbacks[node]
+                is_finished_f = self.ctx.import_llvm_function(target[0], tags=frozenset({"is_finished", "node_assembly"}))
+                node_is_finished = builder.call(is_finished_f, target[1])
+
+                run_cond = builder.or_(run_cond, node_is_finished)
+
+            return run_cond
+
+        elif isinstance(condition, WhenFinishedAll):
+            assert len(condition.args) > 0
+
+            run_cond = self.ctx.bool_ty(1)
+            for node in condition.args:
+                target = is_finished_callbacks[node]
+                is_finished_f = self.ctx.import_llvm_function(target[0], tags=frozenset({"is_finished", "node_assembly"}))
+                node_is_finished = builder.call(is_finished_f, target[1])
+
+                run_cond = builder.and_(run_cond, node_is_finished)
+
+            return run_cond
+
+        elif isinstance(condition, Threshold):
+            target = condition.dependency
+            param = condition.parameter
+            threshold = condition.threshold
+            comparator = condition.comparator
+            indices = condition.indices
+
+            # Convert execution_count to ('num_executions', TimeScale.LIFE).
+            # These two are identical in compiled semantics.
+            if param == 'execution_count':
+                assert indices is None
+                param = 'num_executions'
+                indices = TimeScale.LIFE
+
+            assert param in target.llvm_state_ids, (
+                f"Threshold for {target} only supports items in llvm_state_ids"
+                f" ({target.llvm_state_ids})"
+            )
+
+            node_idx = self.composition._get_node_index(target)
+            node_state = builder.gep(nodes_states, [self.ctx.int32_ty(0), self.ctx.int32_ty(node_idx)])
+            param_ptr = helpers.get_state_ptr(builder, target, node_state, param)
+
+            # parameters in state include history of at least one element
+            # so they are always arrays.
+            assert isinstance(param_ptr.type.pointee, ir.ArrayType)
+
+            if indices is None:
+                indices = [0, 0]
+            elif isinstance(indices, TimeScale):
+                indices = [indices.value]
+
+            param_ptr = builder.gep(param_ptr, [self.ctx.int32_ty(x) for x in [0] + list(indices)])
+
+            val = builder.load(param_ptr)
+            val = helpers.convert_type(builder, val, ir.DoubleType())
+            threshold = val.type(threshold)
+
+            if comparator == '==':
+                return helpers.is_close(self.ctx, builder, val, threshold, condition.rtol, condition.atol)
+            elif comparator == '!=':
+                return builder.not_(helpers.is_close(self.ctx, builder, val, threshold, condition.rtol, condition.atol))
+            else:
+                return builder.fcmp_ordered(comparator, val, threshold)
+
+        assert False, "Unsupported scheduling condition: {}".format(condition)
diff --git a/psyneulink/core/scheduling/condition.py b/psyneulink/core/scheduling/condition.py
index 632cd9fa752..f3b9d20fcac 100644
--- a/psyneulink/core/scheduling/condition.py
+++ b/psyneulink/core/scheduling/condition.py
@@ -25,28 +25,21 @@

 __all__ = [  # noqa: F822 (dynamically generated)
-    'AbsoluteCondition', 'AddEdgeTo', 'AfterCall',
-    'AfterConsiderationSetExecution', 'AfterEnvironmentSequence',
-    'AfterEnvironmentStateUpdate', 'AfterNCalls', 'AfterNCallsCombined',
-    'AfterNConsiderationSetExecutions', 'AfterNEnvironmentSequences',
-    'AfterNEnvironmentStateUpdates', 'AfterNPasses', 'AfterNRuns',
-    'AfterNTimeSteps', 'AfterNTrials', 'AfterNode', 'AfterNodes',
-    'AfterPass', 'AfterRun', 'AfterTimeStep', 'AfterTrial', 'All',
-    'AllHaveRun', 'Always', 'And', 'Any', 'AtConsiderationSetExecution',
-    'AtEnvironmentSequence', 'AtEnvironmentSequenceNStart',
-    'AtEnvironmentSequenceStart', 'AtEnvironmentStateUpdate',
-    'AtEnvironmentStateUpdateNStart', 'AtEnvironmentStateUpdateStart',
-    'AtNCalls', 'AtPass', 'AtRun', 'AtRunNStart', 'AtRunStart',
-    'AtTimeStep', 'AtTrial', 'AtTrialNStart', 'AtTrialStart',
-    'BeforeConsiderationSetExecution', 'BeforeEnvironmentStateUpdate',
-    'BeforeNCalls', 'BeforeNode', 'BeforeNodes', 'BeforePass',
-    'BeforeTimeStep', 'BeforeTrial', 'CompositeCondition', 'Condition',
-    'ConditionBase', 'ConditionError', 'ConditionSet',
-    'CustomGraphStructureCondition', 'EveryNCalls', 'EveryNPasses',
-    'GraphStructureCondition', 'JustRan', 'NWhen', 'Never', 'Not',
-    'Operation', 'Or', 'RemoveEdgeFrom', 'Threshold', 'TimeInterval',
-    'TimeTermination', 'WhenFinished', 'WhenFinishedAll',
-    'WhenFinishedAny', 'When', 'While', 'WhileNot', 'WithNode',
+    'AbsoluteCondition', 'AddEdgeTo', 'AfterCall', 'AfterNCalls',
+    'AfterNCallsCombined', 'AfterNode', 'AfterNodes', 'AfterNPasses',
+    'AfterNRuns', 'AfterNTimeSteps', 'AfterNTrials', 'AfterPass',
+    'AfterRun', 'AfterTimeStep', 'AfterTrial', 'All', 'AllHaveRun',
+    'Always', 'And', 'Any', 'AtNCalls', 'AtPass', 'AtRun',
+    'AtRunNStart', 'AtRunStart', 'AtTimeStep', 'AtTrial',
+    'AtTrialNStart', 'AtTrialStart', 'BeforeNCalls', 'BeforeNode',
+    'BeforeNodes', 'BeforePass', 'BeforeTimeStep', 'BeforeTrial',
+    'CompositeCondition', 'Condition', 'ConditionBase',
+    'ConditionError', 'ConditionSet', 'CustomGraphStructureCondition',
+    'EveryNCalls', 'EveryNPasses', 'GraphStructureCondition', 'JustRan',
+    'Never', 'Not', 'NWhen', 'Operation', 'Or', 'RemoveEdgeFrom',
+    'Threshold', 'TimeInterval', 'TimeTermination', 'When',
+    'WhenFinished', 'WhenFinishedAll', 'WhenFinishedAny', 'While',
+    'WhileNot', 'WithNode',
 ]
diff --git a/psyneulink/core/scheduling/condition.pyi b/psyneulink/core/scheduling/condition.pyi
index 1b32f3554c2..6dc7fda5d3c 100644
--- a/psyneulink/core/scheduling/condition.pyi
+++ b/psyneulink/core/scheduling/condition.pyi
@@ -6,7 +6,7 @@
 import graph_scheduler.time
 import pint
 from _typeshed import Incomplete

-__all__ = ['Operation', 'ConditionError', 'ConditionSet', 'ConditionBase', 'Condition', 'AbsoluteCondition', 'While', 'When', 'WhileNot', 'Always', 'Never', 'CompositeCondition', 'All', 'Any', 'And', 'Or', 'Not', 'NWhen', 'TimeInterval', 'TimeTermination', 'BeforeConsiderationSetExecution', 'AtConsiderationSetExecution', 'AfterConsiderationSetExecution', 'AfterNConsiderationSetExecutions', 'BeforePass', 'AtPass', 'AfterPass', 'AfterNPasses', 'EveryNPasses', 'BeforeEnvironmentStateUpdate', 'AtEnvironmentStateUpdate', 'AfterEnvironmentStateUpdate', 'AfterNEnvironmentStateUpdates', 'AtEnvironmentSequence', 'AfterEnvironmentSequence', 'AfterNEnvironmentSequences', 'BeforeNCalls', 'AtNCalls', 'AfterCall', 'AfterNCalls', 'AfterNCallsCombined', 'EveryNCalls', 'JustRan', 'AllHaveRun', 'WhenFinished', 'WhenFinishedAny', 'WhenFinishedAll', 'AtEnvironmentStateUpdateStart', 'AtEnvironmentStateUpdateNStart', 'AtEnvironmentSequenceStart', 'AtEnvironmentSequenceNStart', 'Threshold', 'GraphStructureCondition', 'CustomGraphStructureCondition', 'BeforeNodes', 'BeforeNode', 'WithNode', 'AfterNodes', 'AfterNode', 'AddEdgeTo', 'RemoveEdgeFrom']
+__all__ = ['Operation', 'ConditionError', 'ConditionSet', 'ConditionBase', 'Condition', 'AbsoluteCondition', 'While', 'When', 'WhileNot', 'Always', 'Never', 'CompositeCondition', 'All', 'Any', 'And', 'Or', 'Not', 'NWhen', 'TimeInterval', 'TimeTermination', 'BeforeTimeStep', 'AtTimeStep', 'AfterTimeStep', 'AfterNTimeSteps', 'BeforePass', 'AtPass', 'AfterPass', 'AfterNPasses', 'EveryNPasses', 'BeforeTrial', 'AtTrial', 'AfterTrial', 'AfterNTrials', 'AtRun', 'AfterRun', 'AfterNRuns', 'BeforeNCalls', 'AtNCalls', 'AfterCall', 'AfterNCalls', 'AfterNCallsCombined', 'EveryNCalls', 'JustRan', 'AllHaveRun', 'WhenFinished', 'WhenFinishedAny', 'WhenFinishedAll', 'AtTrialStart', 'AtTrialNStart', 'AtRunStart', 'AtRunNStart', 'Threshold', 'GraphStructureCondition', 'CustomGraphStructureCondition', 'BeforeNodes', 'BeforeNode', 'WithNode', 'AfterNodes', 'AfterNode', 'AddEdgeTo', 'RemoveEdgeFrom']

 SubjectOperation = Union['Operation', str, Dict[Hashable, Union['Operation', str]]]
@@ -635,87 +635,87 @@ class TimeTermination(AbsoluteCondition):
     @property
     def absolute_fixed_points(self): ...

-class BeforeConsiderationSetExecution(Condition):
+class BeforeTimeStep(Condition):

-    """BeforeConsiderationSetExecution
+    """BeforeTimeStep

     Parameters:

-        n(int): the 'CONSIDERATION_SET_EXECUTION' before which the Condition is satisfied
+        n(int): the 'TIME_STEP' before which the Condition is satisfied

-        time_scale(TimeScale): the TimeScale used as basis for counting `CONSIDERATION_SET_EXECUTION`\\ s (default: TimeScale.ENVIRONMENT_STATE_UPDATE)
+        time_scale(TimeScale): the TimeScale used as basis for counting `TIME_STEP`\\ s (default: TimeScale.TRIAL)

     Satisfied when:

-        - at most n-1 `CONSIDERATION_SET_EXECUTION`\\ s have occurred within one unit of time at the `TimeScale` specified by **time_scale**.
+        - at most n-1 `TIME_STEP`\\ s have occurred within one unit of time at the `TimeScale` specified by **time_scale**.

     Notes:

-        - Counts of TimeScales are zero-indexed (that is, the first `CONSIDERATION_SET_EXECUTION` is 0, the second `CONSIDERATION_SET_EXECUTION` is 1, etc.);
-          so, `BeforeConsiderationSetExecution(2)` is satisfied at `CONSIDERATION_SET_EXECUTION` 0 and `CONSIDERATION_SET_EXECUTION` 1.
+        - Counts of TimeScales are zero-indexed (that is, the first `TIME_STEP` is 0, the second `TIME_STEP` is 1, etc.);
+          so, `BeforeTimeStep(2)` is satisfied at `TIME_STEP` 0 and `TIME_STEP` 1.

     """
     def __init__(self, n, time_scale: graph_scheduler.time.TimeScale = ...) -> None: ...

-class AtConsiderationSetExecution(Condition):
+class AtTimeStep(Condition):

-    """AtConsiderationSetExecution
+    """AtTimeStep

     Parameters:

-        n(int): the `CONSIDERATION_SET_EXECUTION` at which the Condition is satisfied
+        n(int): the `TIME_STEP` at which the Condition is satisfied

-        time_scale(TimeScale): the TimeScale used as basis for counting `CONSIDERATION_SET_EXECUTION`\\ s (default: TimeScale.ENVIRONMENT_STATE_UPDATE)
+        time_scale(TimeScale): the TimeScale used as basis for counting `TIME_STEP`\\ s (default: TimeScale.TRIAL)

     Satisfied when:

-        - exactly n `CONSIDERATION_SET_EXECUTION`\\ s have occurred within one unit of time at the `TimeScale` specified by **time_scale**.
+        - exactly n `TIME_STEP`\\ s have occurred within one unit of time at the `TimeScale` specified by **time_scale**.

     Notes:

-        - Counts of TimeScales are zero-indexed (that is, the first 'CONSIDERATION_SET_EXECUTION' is pass 0, the second 'CONSIDERATION_SET_EXECUTION' is 1, etc.);
-          so, `AtConsiderationSetExecution(1)` is satisfied when a single `CONSIDERATION_SET_EXECUTION` (`CONSIDERATION_SET_EXECUTION` 0) has occurred, and `AtConsiderationSetExecution(2)` is satisfied
-          when two `CONSIDERATION_SET_EXECUTION`\\ s have occurred (`CONSIDERATION_SET_EXECUTION` 0 and `CONSIDERATION_SET_EXECUTION` 1), etc..
+        - Counts of TimeScales are zero-indexed (that is, the first 'TIME_STEP' is 0, the second 'TIME_STEP' is 1, etc.);
+          so, `AtTimeStep(1)` is satisfied when a single `TIME_STEP` (`TIME_STEP` 0) has occurred, and `AtTimeStep(2)` is satisfied
+          when two `TIME_STEP`\\ s have occurred (`TIME_STEP` 0 and `TIME_STEP` 1), etc..

     """
     def __init__(self, n, time_scale: graph_scheduler.time.TimeScale = ...) -> None: ...

-class AfterConsiderationSetExecution(Condition):
+class AfterTimeStep(Condition):

-    """AfterConsiderationSetExecution
+    """AfterTimeStep

     Parameters:

-        n(int): the `CONSIDERATION_SET_EXECUTION` after which the Condition is satisfied
+        n(int): the `TIME_STEP` after which the Condition is satisfied

-        time_scale(TimeScale): the TimeScale used as basis for counting `CONSIDERATION_SET_EXECUTION`\\ s (default: TimeScale.ENVIRONMENT_STATE_UPDATE)
+        time_scale(TimeScale): the TimeScale used as basis for counting `TIME_STEP`\\ s (default: TimeScale.TRIAL)

     Satisfied when:

-        - at least n+1 `CONSIDERATION_SET_EXECUTION`\\ s have occurred within one unit of time at the `TimeScale` specified by **time_scale**.
+        - at least n+1 `TIME_STEP`\\ s have occurred within one unit of time at the `TimeScale` specified by **time_scale**.

     Notes:

-        - Counts of TimeScals are zero-indexed (that is, the first `CONSIDERATION_SET_EXECUTION` is 0, the second `CONSIDERATION_SET_EXECUTION` is 1, etc.); so,
-          `AfterConsiderationSetExecution(1)` is satisfied after `CONSIDERATION_SET_EXECUTION` 1 has occurred and thereafter (i.e., in `CONSIDERATION_SET_EXECUTION`\\ s 2, 3, 4, etc.).
+        - Counts of TimeScales are zero-indexed (that is, the first `TIME_STEP` is 0, the second `TIME_STEP` is 1, etc.); so,
+          `AfterTimeStep(1)` is satisfied after `TIME_STEP` 1 has occurred and thereafter (i.e., in `TIME_STEP`\\ s 2, 3, 4, etc.).

     """
     def __init__(self, n, time_scale: graph_scheduler.time.TimeScale = ...) -> None: ...

-class AfterNConsiderationSetExecutions(Condition):
+class AfterNTimeSteps(Condition):

-    """AfterNConsiderationSetExecutions
+    """AfterNTimeSteps

     Parameters:

-        n(int): the number of `CONSIDERATION_SET_EXECUTION`\\ s after which the Condition is satisfied
+        n(int): the number of `TIME_STEP`\\ s after which the Condition is satisfied

-        time_scale(TimeScale): the TimeScale used as basis for counting `CONSIDERATION_SET_EXECUTION`\\ s (default: TimeScale.ENVIRONMENT_STATE_UPDATE)
+        time_scale(TimeScale): the TimeScale used as basis for counting `TIME_STEP`\\ s (default: TimeScale.TRIAL)

     Satisfied when:

-        - at least n `CONSIDERATION_SET_EXECUTION`\\ s have occurred within one unit of time at the `TimeScale` specified by **time_scale**.
+        - at least n `TIME_STEP`\\ s have occurred within one unit of time at the `TimeScale` specified by **time_scale**.

     """
     def __init__(self, n, time_scale: graph_scheduler.time.TimeScale = ...) -> None: ...

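The stub renames above map graph_scheduler's generic vocabulary back onto PsyNeuLink's (TIME_STEP for CONSIDERATION_SET_EXECUTION, TRIAL for ENVIRONMENT_STATE_UPDATE, RUN for ENVIRONMENT_SEQUENCE). A minimal usage sketch of these conditions, assuming two hypothetical mechanisms A and B; note that the compiled ConditionGenerator above only accepts EveryNCalls with count == 1:

    import psyneulink as pnl

    A = pnl.TransferMechanism(name='A')
    B = pnl.TransferMechanism(name='B')
    comp = pnl.Composition(pathways=[[A, B]])

    # Run B only on passes where A has run since B last did; count == 1 is
    # the only count the compiled scheduler path supports.
    comp.scheduler.add_condition(B, pnl.EveryNCalls(A, 1))

    # End each trial once B has executed twice.
    comp.run(inputs={A: [[1.0]]},
             termination_processing={pnl.TimeScale.TRIAL: pnl.AfterNCalls(B, 2)})
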
@@ -728,7 +728,7 @@ class BeforePass(Condition): n(int): the 'PASS' before which the Condition is satisfied - time_scale(TimeScale): the TimeScale used as basis for counting `PASS`\\ es (default: TimeScale.ENVIRONMENT_STATE_UPDATE) + time_scale(TimeScale): the TimeScale used as basis for counting `PASS`\\ es (default: TimeScale.TRIAL) Satisfied when: @@ -750,7 +750,7 @@ class AtPass(Condition): n(int): the `PASS` at which the Condition is satisfied - time_scale(TimeScale): the TimeScale used as basis for counting `PASS`\\ es (default: TimeScale.ENVIRONMENT_STATE_UPDATE) + time_scale(TimeScale): the TimeScale used as basis for counting `PASS`\\ es (default: TimeScale.TRIAL) Satisfied when: @@ -773,7 +773,7 @@ class AfterPass(Condition): n(int): the `PASS` after which the Condition is satisfied - time_scale(TimeScale): the TimeScale used as basis for counting `PASS`\\ es (default: TimeScale.ENVIRONMENT_STATE_UPDATE) + time_scale(TimeScale): the TimeScale used as basis for counting `PASS`\\ es (default: TimeScale.TRIAL) Satisfied when: @@ -795,7 +795,7 @@ class AfterNPasses(Condition): n(int): the number of `PASS`\\ es after which the Condition is satisfied - time_scale(TimeScale): the TimeScale used as basis for counting `PASS`\\ es (default: TimeScale.ENVIRONMENT_STATE_UPDATE) + time_scale(TimeScale): the TimeScale used as basis for counting `PASS`\\ es (default: TimeScale.TRIAL) Satisfied when: @@ -813,7 +813,7 @@ class EveryNPasses(Condition): n(int): the frequency of passes with which this condition is satisfied - time_scale(TimeScale): the TimeScale used as basis for counting `PASS`\\ es (default: TimeScale.ENVIRONMENT_STATE_UPDATE) + time_scale(TimeScale): the TimeScale used as basis for counting `PASS`\\ es (default: TimeScale.TRIAL) Satisfied when: @@ -825,150 +825,150 @@ class EveryNPasses(Condition): """ def __init__(self, n, time_scale: graph_scheduler.time.TimeScale = ...) -> None: ... -class BeforeEnvironmentStateUpdate(Condition): +class BeforeTrial(Condition): - """BeforeEnvironmentStateUpdate + """BeforeTrial Parameters: - n(int): the `ENVIRONMENT_STATE_UPDATE ` before which the Condition is satisfied + n(int): the `TRIAL ` before which the Condition is satisfied - time_scale(TimeScale): the TimeScale used as basis for counting `ENVIRONMENT_STATE_UPDATE `\\ s - (default: TimeScale.ENVIRONMENT_SEQUENCE) + time_scale(TimeScale): the TimeScale used as basis for counting `TRIAL `\\ s + (default: TimeScale.RUN) Satisfied when: - - at most n-1 `ENVIRONMENT_STATE_UPDATE `\\ s have occurred within one unit of time at the `TimeScale` + - at most n-1 `TRIAL `\\ s have occurred within one unit of time at the `TimeScale` specified by **time_scale**. Notes: - - Counts of TimeScales are zero-indexed (that is, the first `ENVIRONMENT_STATE_UPDATE ` is 0, the second - `ENVIRONMENT_STATE_UPDATE ` is 1, etc.); so, `BeforeEnvironmentStateUpdate(2)` is satisfied at `ENVIRONMENT_STATE_UPDATE ` 0 - and `ENVIRONMENT_STATE_UPDATE ` 1. + - Counts of TimeScales are zero-indexed (that is, the first `TRIAL ` is 0, the second + `TRIAL ` is 1, etc.); so, `BeforeTrial(2)` is satisfied at `TRIAL ` 0 + and `TRIAL ` 1. """ def __init__(self, n, time_scale: graph_scheduler.time.TimeScale = ...) -> None: ... 
-class AtEnvironmentStateUpdate(Condition): +class AtTrial(Condition): - """AtEnvironmentStateUpdate + """AtTrial Parameters: - n(int): the `ENVIRONMENT_STATE_UPDATE ` at which the Condition is satisfied + n(int): the `TRIAL ` at which the Condition is satisfied - time_scale(TimeScale): the TimeScale used as basis for counting `ENVIRONMENT_STATE_UPDATE `\\ s - (default: TimeScale.ENVIRONMENT_SEQUENCE) + time_scale(TimeScale): the TimeScale used as basis for counting `TRIAL `\\ s + (default: TimeScale.RUN) Satisfied when: - - exactly n `ENVIRONMENT_STATE_UPDATE `\\ s have occurred within one unit of time at the `TimeScale` + - exactly n `TRIAL `\\ s have occurred within one unit of time at the `TimeScale` specified by **time_scale**. Notes: - - Counts of TimeScales are zero-indexed (that is, the first `ENVIRONMENT_STATE_UPDATE ` is 0, - the second `ENVIRONMENT_STATE_UPDATE ` is 1, etc.); so, `AtEnvironmentStateUpdate(1)` is satisfied when one - `ENVIRONMENT_STATE_UPDATE ` (`ENVIRONMENT_STATE_UPDATE ` 0) has already occurred. + - Counts of TimeScales are zero-indexed (that is, the first `TRIAL ` is 0, + the second `TRIAL ` is 1, etc.); so, `AtTrial(1)` is satisfied when one + `TRIAL ` (`TRIAL ` 0) has already occurred. """ def __init__(self, n, time_scale: graph_scheduler.time.TimeScale = ...) -> None: ... -class AfterEnvironmentStateUpdate(Condition): +class AfterTrial(Condition): - """AfterEnvironmentStateUpdate + """AfterTrial Parameters: - n(int): the `ENVIRONMENT_STATE_UPDATE ` after which the Condition is satisfied + n(int): the `TRIAL ` after which the Condition is satisfied - time_scale(TimeScale): the TimeScale used as basis for counting `ENVIRONMENT_STATE_UPDATE `\\ s. - (default: TimeScale.ENVIRONMENT_SEQUENCE) + time_scale(TimeScale): the TimeScale used as basis for counting `TRIAL `\\ s. + (default: TimeScale.RUN) Satisfied when: - - at least n+1 `ENVIRONMENT_STATE_UPDATE `\\ s have occurred within one unit of time at the `TimeScale` + - at least n+1 `TRIAL `\\ s have occurred within one unit of time at the `TimeScale` specified by **time_scale**. Notes: - - Counts of TimeScales are zero-indexed (that is, the first `ENVIRONMENT_STATE_UPDATE ` is 0, the second - `ENVIRONMENT_STATE_UPDATE ` is 1, etc.); so, `AfterPass(1)` is satisfied after `ENVIRONMENT_STATE_UPDATE ` 1 - has occurred and thereafter (i.e., in `ENVIRONMENT_STATE_UPDATE `\\ s 2, 3, 4, etc.). + - Counts of TimeScales are zero-indexed (that is, the first `TRIAL ` is 0, the second + `TRIAL ` is 1, etc.); so, `AfterTrial(1)` is satisfied after `TRIAL ` 1 + has occurred and thereafter (i.e., in `TRIAL `\\ s 2, 3, 4, etc.). """ def __init__(self, n, time_scale: graph_scheduler.time.TimeScale = ...) -> None: ... -class AfterNEnvironmentStateUpdates(Condition): +class AfterNTrials(Condition): - """AfterNEnvironmentStateUpdates + """AfterNTrials Parameters: - n(int): the number of `ENVIRONMENT_STATE_UPDATE `\\ s after which the Condition is satisfied + n(int): the number of `TRIAL `\\ s after which the Condition is satisfied - time_scale(TimeScale): the TimeScale used as basis for counting `ENVIRONMENT_STATE_UPDATE `\\ s - (default: TimeScale.ENVIRONMENT_SEQUENCE) + time_scale(TimeScale): the TimeScale used as basis for counting `TRIAL `\\ s + (default: TimeScale.RUN) Satisfied when: - - at least n `ENVIRONMENT_STATE_UPDATE `\\ s have occured within one unit of time at the `TimeScale` + - at least n `TRIAL `\\ s have occurred within one unit of time at the `TimeScale` specified by **time_scale**.
""" def __init__(self, n, time_scale: graph_scheduler.time.TimeScale = ...) -> None: ... -class AtEnvironmentSequence(Condition): +class AtRun(Condition): - """AtEnvironmentSequence + """AtRun Parameters: - n(int): the `ENVIRONMENT_SEQUENCE` at which the Condition is satisfied + n(int): the `RUN` at which the Condition is satisfied Satisfied when: - - exactly n `ENVIRONMENT_SEQUENCE`\\ s have occurred. + - exactly n `RUN`\\ s have occurred. Notes: - - `ENVIRONMENT_SEQUENCE`\\ s are managed by the environment using the Scheduler (e.g. `end_environment_sequence ` ) and are not automatically updated by this package. + - `RUN`\\ s are managed by the environment using the Scheduler (e.g. `end_environment_sequence ` ) and are not automatically updated by this package. """ def __init__(self, n) -> None: ... -class AfterEnvironmentSequence(Condition): +class AfterRun(Condition): - """AfterEnvironmentSequence + """AfterRun Parameters: - n(int): the `ENVIRONMENT_SEQUENCE` after which the Condition is satisfied + n(int): the `RUN` after which the Condition is satisfied Satisfied when: - - at least n+1 `ENVIRONMENT_SEQUENCE`\\ s have occurred. + - at least n+1 `RUN`\\ s have occurred. Notes: - - `ENVIRONMENT_SEQUENCE`\\ s are managed by the environment using the Scheduler (e.g. `end_environment_sequence ` ) and are not automatically updated by this package. + - `RUN`\\ s are managed by the environment using the Scheduler (e.g. `end_environment_sequence ` ) and are not automatically updated by this package. """ def __init__(self, n) -> None: ... -class AfterNEnvironmentSequences(Condition): +class AfterNRuns(Condition): - """AfterNEnvironmentSequences + """AfterNRuns Parameters: - n(int): the number of `ENVIRONMENT_SEQUENCE`\\ s after which the Condition is satisfied + n(int): the number of `RUN`\\ s after which the Condition is satisfied Satisfied when: - - at least n `ENVIRONMENT_SEQUENCE`\\ s have occured. + - at least n `RUN`\\ s have occured. Notes: - - `ENVIRONMENT_SEQUENCE`\\ s are managed by the environment using the Scheduler (e.g. `end_environment_sequence ` ) and are not automatically updated by this package. + - `RUN`\\ s are managed by the environment using the Scheduler (e.g. `end_environment_sequence ` ) and are not automatically updated by this package. """ def __init__(self, n) -> None: ... 
@@ -984,7 +984,7 @@ class BeforeNCalls(_DependencyValidation, Condition): n(int): the number of executions of **dependency** before which the Condition is satisfied time_scale(TimeScale): the TimeScale used as basis for counting executions of **dependency** - (default: TimeScale.ENVIRONMENT_STATE_UPDATE) + (default: TimeScale.TRIAL) Satisfied when: @@ -1005,7 +1005,7 @@ class AtNCalls(_DependencyValidation, Condition): n(int): the number of executions of **dependency** at which the Condition is satisfied time_scale(TimeScale): the TimeScale used as basis for counting executions of **dependency** - (default: TimeScale.ENVIRONMENT_STATE_UPDATE) + (default: TimeScale.TRIAL) Satisfied when: @@ -1026,7 +1026,7 @@ class AfterCall(_DependencyValidation, Condition): n(int): the number of executions of **dependency** after which the Condition is satisfied time_scale(TimeScale): the TimeScale used as basis for counting executions of **dependency** - (default: TimeScale.ENVIRONMENT_STATE_UPDATE) + (default: TimeScale.TRIAL) Satisfied when: @@ -1047,7 +1047,7 @@ class AfterNCalls(_DependencyValidation, Condition): n(int): the number of executions of **dependency** after which the Condition is satisfied time_scale(TimeScale): the TimeScale used as basis for counting executions of **dependency** - (default: TimeScale.ENVIRONMENT_STATE_UPDATE) + (default: TimeScale.TRIAL) Satisfied when: @@ -1069,7 +1069,7 @@ class AfterNCallsCombined(_DependencyValidation, Condition): Condition is satisfied (default: None) time_scale(TimeScale): the TimeScale used as basis for counting executions of **dependency** - (default: TimeScale.ENVIRONMENT_STATE_UPDATE) + (default: TimeScale.TRIAL) Satisfied when: @@ -1115,13 +1115,13 @@ class JustRan(_DependencyValidation, Condition): Satisfied when: - - the node specified in **dependency** executed in the previous `CONSIDERATION_SET_EXECUTION`. + - the node specified in **dependency** executed in the previous `TIME_STEP`. Notes: - This Condition can transcend divisions between `TimeScales `. - For example, if A runs in the final `CONSIDERATION_SET_EXECUTION` of an `ENVIRONMENT_STATE_UPDATE `, - JustRan(A) is satisfied at the beginning of the next `ENVIRONMENT_STATE_UPDATE `. + For example, if A runs in the final `TIME_STEP` of a `TRIAL `, + JustRan(A) is satisfied at the beginning of the next `TRIAL `. """ def __init__(self, dependency) -> None: ... @@ -1135,7 +1135,7 @@ class AllHaveRun(_DependencyValidation, Condition): *dependencies (Hashable): an iterable of nodes on which the Condition depends time_scale(TimeScale): the TimeScale used as basis for counting executions of **dependency** - (default: TimeScale.ENVIRONMENT_STATE_UPDATE) + (default: TimeScale.TRIAL) Satisfied when: @@ -1222,13 +1222,13 @@ class WhenFinishedAll(_DependencyValidation, Condition): """ def __init__(self, *dependencies) -> None: ... -class AtEnvironmentStateUpdateStart(AtPass): +class AtTrialStart(AtPass): - """AtEnvironmentStateUpdateStart + """AtTrialStart Satisfied when: - - at the beginning of an `ENVIRONMENT_STATE_UPDATE ` + - at the beginning of a `TRIAL ` Notes: @@ -1236,57 +1236,57 @@ class AtEnvironmentStateUpdateStart(AtPass): """ def __init__(self) -> None: ...
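
For orientation, a brief usage sketch of the renamed dependency- and trial-scale Conditions in PsyNeuLink (node names and input values here are placeholders; a hedged example, not taken from this PR):

import psyneulink as pnl

A = pnl.ProcessingMechanism(name='A')
B = pnl.ProcessingMechanism(name='B')
comp = pnl.Composition(pathways=[[A, B]], name='comp')

# B executes only after every second execution of A, and each TRIAL ends
# once B has executed once (AfterNCalls counts within a TRIAL by default):
comp.scheduler.add_condition(B, pnl.EveryNCalls(A, 2))
comp.run(inputs={A: [[1.0]]},
         termination_processing={pnl.TimeScale.TRIAL: pnl.AfterNCalls(B, 1)})
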
-class AtEnvironmentStateUpdateNStart(All): +class AtTrialNStart(All): - """AtEnvironmentStateUpdateNStart + """AtTrialNStart Parameters: - n(int): the `ENVIRONMENT_STATE_UPDATE ` on which the Condition is satisfied + n(int): the `TRIAL ` on which the Condition is satisfied - time_scale(TimeScale): the TimeScale used as basis for counting `ENVIRONMENT_STATE_UPDATE `\\ s - (default: TimeScale.ENVIRONMENT_SEQUENCE) + time_scale(TimeScale): the TimeScale used as basis for counting `TRIAL `\\ s + (default: TimeScale.RUN) Satisfied when: - - on `PASS` 0 of the specified `ENVIRONMENT_STATE_UPDATE ` counted using 'TimeScale` + - on `PASS` 0 of the specified `TRIAL ` counted using `TimeScale` Notes: - - identical to All(AtPass(0), AtEnvironmentStateUpdate(n, time_scale)) + - identical to All(AtPass(0), AtTrial(n, time_scale)) """ def __init__(self, n, time_scale: graph_scheduler.time.TimeScale = ...) -> None: ... -class AtEnvironmentSequenceStart(AtEnvironmentStateUpdate): +class AtRunStart(AtTrial): - """AtEnvironmentSequenceStart + """AtRunStart Satisfied when: - - at the beginning of an `ENVIRONMENT_SEQUENCE` + - at the beginning of a `RUN` Notes: - - identical to `AtEnvironmentStateUpdate(0) ` + - identical to `AtTrial(0) ` """ def __init__(self) -> None: ... -class AtEnvironmentSequenceNStart(All): +class AtRunNStart(All): - """AtEnvironmentSequenceNStart + """AtRunNStart Parameters: - n(int): the `ENVIRONMENT_SEQUENCE` on which the Condition is satisfied + n(int): the `RUN` on which the Condition is satisfied Satisfied when: - - on `ENVIRONMENT_STATE_UPDATE ` 0 of the specified `ENVIRONMENT_SEQUENCE` counted using 'TimeScale` + - on `TRIAL ` 0 of the specified `RUN` counted using `TimeScale` Notes: - - identical to `All(AtEnvironmentStateUpdate(0), AtEnvironmentSequence(n))` + - identical to `All(AtTrial(0), AtRun(n))` """ def __init__(self, n) -> None: ... diff --git a/psyneulink/library/compositions/emcomposition.py b/psyneulink/library/compositions/emcomposition.py index 807f316fe55..11eac6e1d20 100644 --- a/psyneulink/library/compositions/emcomposition.py +++ b/psyneulink/library/compositions/emcomposition.py @@ -264,7 +264,6 @@ - `Field Weights ` * `EMComposition_Class_Reference` - .. _EMComposition_Overview: Overview -------- @@ -315,14 +314,15 @@ **Operation** *Retrieval.* The values retrieved from `memory ` (one for each field) are based -on the relative similarity of the keys to the entries in memory, computed as the dot product of each key and the -values in the corresponding field for each entry in memory. These dot products are then softmaxed, and those -softmax distributions are weighted by the corresponding `field_weights ` for each field -and then combined, to produce a single softmax distribution over the entries in memory. That is then used to generate -a weighted average of the retrieved values across all fields, which is returned as the `result ` -of the EMComposition's `execution ` (an EMComposition can also be configured to return the -entry with the highest dot product weighted by field, however then it is not compatible with learning; -see `softmax_choice `). +on the relative similarity of the keys to the entries in memory, computed as the distance between each key and the +values in the corresponding field for each entry in memory. By default, normalized dot products (comparable to cosine +similarity) are used to compute the similarity of each query to each key in memory.
These distances are then +weighted by the corresponding `field_weights ` for each field (if specified) and then +summed, and the sum is softmaxed to produce a softmax distribution over the entries in memory. That is then used to +generate a softmax-weighted average of the retrieved values across all fields, which is returned as the `result +` of the EMComposition's `execution ` (an EMComposition can also be +configured to return the entry with the lowest distance weighted by field; however, it is then not compatible +with learning; see `softmax_choice `). COMMENT: TBD DISTANCE ATTRIBUTES: @@ -428,11 +428,11 @@ process, but the values of which are retrieved and assigned as the `value ` of the corresponding `retrieved_node `. This distinction between keys and value corresponds to the format of a standard "dictionary," though in that case only a single key and value are allowed, whereas - here there can be one or more keys and any number of values; if all fields are keys, this implements a full form of - content-addressable memory. If **learn_field_weight** is True (and `enable_learning` + here there can be one or more keys and any number of values; if all fields are keys, this implements a full form of + content-addressable memory. If **learn_field_weights** is True (and `enable_learning` is either True or a list with True for at least one entry), then the field_weights can be modified during training (this functions similarly to the attention head of a Transformer model, although at present the field can only be - scalar values rather than vecdtors); if **learn_field_weight** is False, then the field_weights are fixed. + scalar values rather than vectors); if **learn_field_weights** is False, then the field_weights are fixed. The following options can be used to specify **field_weights**: * *None* (the default): all fields except the last are treated as keys, and are weighted equally for retrieval, @@ -458,13 +458,6 @@ are used to weight the retrieved value of each field. This setting is ignored if **field_weights** is None or `concatenate_queries ` is in effect. - .. warning:: - If **normalize_field_weights** is False and **enable_learning** is True, a warning is issued indicating that - this may produce an error if the `loss_spec ` for the EMComposition (or an - `AutodiffComposition` that contains it) requires all values to be between 0 and 1, and calling the - EMComposition's `learn ` method will generate an error if the loss_spec is specified is - one known to be incompatible (e.g., `BINARY_CROSS_ENTROPY `). - .. _EMComposition_Field_Names: * **field_names**: specifies names that can be assigned to the fields. The number of names specified must @@ -485,9 +478,8 @@ .. note:: While this is computationally more efficient, it can affect the outcome of the `matching process - `, since computing the normalized dot product of a single vector comprised of the - concatentated inputs is not identical to computing the normalized dot product of each field independently and - then combining the results. + `, since computing the distance of a single vector comprised of the concatenated + inputs is not identical to computing the distance of each field independently and then combining the results. .. note:: All `query_input_nodes ` and `retrieved_nodes ` @@ -514,41 +506,41 @@ .. _EMComposition_Softmax_Gain: -* **softmax_gain** : specifies the gain (inverse temperature) used for softmax normalizing the dot products of - queries and keys in memory (see `EMComposition_Execution` below).
The following options can be used: +* **softmax_gain** : specifies the gain (inverse temperature) used for softmax normalizing the combined distances + used for retrieval (see `EMComposition_Execution` below). The following options can be used: * numeric value: the value is used as the gain of the `SoftMax` Function for the EMComposition's - `softmax_nodes `. + `softmax_node `. * *ADAPTIVE*: the `adapt_gain ` method of the `SoftMax` Function is used to adaptively set - the `softmax_gain ` based on the entropy of the dot products, in order to preserve + the `softmax_gain ` based on the entropy of the distances, in order to preserve the distribution over non- (or near) zero entries irrespective of how many (near) zero entries there are (see `Thresholding and Adaptive Gain ` for additional details). * *CONTROL*: a `ControlMechanism` is created, and its `ControlSignal` is used to modulate the `softmax_gain - ` parameter of the `SoftMax` function of the EMComposition's `softmax_nodes - `. + ` parameter of the `SoftMax` function of the EMComposition's `softmax_node + `. If *None* is specified, the default value for the `SoftMax` function is used. .. _EMComposition_Softmax_Threshold: * **softmax_threshold**: if this is specified, and **softmax_gain** is specified with a numeric value, - then any values below the specified threshold are set to 0 before the dot products are softmaxed + then any values below the specified threshold are set to 0 before the distances are softmaxed (see *mask_threshold* under `Thresholding and Adaptive Gain ` for additional details). .. _EMComposition_Softmax_Choice: -* **softmax_choice** : specifies how the `SoftMax` Function of each of the EMComposition's `softmax_nodes - ` is used, with the dot products of queries and keys, to generate a retrieved item; +* **softmax_choice** : specifies how the `SoftMax` Function of the EMComposition's `softmax_node + ` is used, with the combined distances, to generate a retrieved item; the following are the options that can be used and the retrieved value they produce: - * *WEIGHTED_AVG* (default): softmax-weighted average of entries, based on their dot products with the key(s). + * *WEIGHTED_AVG* (default): softmax-weighted average based on combined distances of queries and keys in memory. - * *ARG_MAX*: entry with the largest dot product (one with lowest index in `memory `)\ if there are identical ones). + * *ARG_MAX*: entry with the smallest distance (the one with the lowest index in `memory `, if there are identical ones). - * *PROBABISTIC*: probabilistically chosen entry based on softmax-transformed distribution of dot products. + * *PROBABILISTIC*: probabilistically chosen entry based on softmax-transformed distribution of combined distance. .. warning:: Use of the *ARG_MAX* and *PROBABILISTIC* options is not compatible with learning, as these implement a discrete @@ -585,13 +577,14 @@ `, and each entry must be a boolean that specifies whether the corresponding `retrieved_node ` is used for learning. -* **learn_field_weight** : specifies whether `field_weights ` are modifiable during +* **learn_field_weights** : specifies whether `field_weights ` are modifiable during learning (see `field_weights ` and `Learning ` for additional information). For learning of `field_weights ` to occur, **enable_learning** must - also be True, or it must be a list with at least one True entry. + also be True, or it must be a list with at least one True entry. If **learn_field_weights** is True, + **use_gating_for_weighting** must be False (see `note `).
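
The retrieval computation described above (normalized dot-product matches, field weighting, summation, and a gain- and threshold-modulated softmax) reduces to a few lines of numpy. The sketch below is illustrative only, under simplifying assumptions (numeric softmax_gain, a list-of-entries memory, the *WEIGHTED_AVG* choice); it is not the EMComposition implementation:

import numpy as np

def em_retrieve(queries, memory, field_weights, gain=1.0, mask_threshold=None):
    """queries: one vector per field; memory: list of entries, each a list of field vectors."""
    def unit(v):
        n = np.linalg.norm(v)
        return v / n if n else v
    # Match memories by field: normalized dot product of each query with each stored key
    weighted_matches = []
    for f, (q, w) in enumerate(zip(queries, field_weights)):
        if w:  # only key fields (nonzero field_weights) enter the match
            match = np.array([unit(q) @ unit(entry[f]) for entry in memory])
            weighted_matches.append(w * match)            # weight match by field_weight
    combined = np.sum(weighted_matches, axis=0)           # combine matches across key fields
    if mask_threshold is not None:                        # softmax_threshold masking
        combined = np.where(combined < mask_threshold, 0.0, combined)
    e = np.exp(gain * (combined - combined.max()))        # gain = inverse temperature
    softmax = e / e.sum()                                 # distribution over entries
    # Retrieve values by field: softmax-weighted average over all entries
    return [softmax @ np.array([entry[f] for entry in memory])
            for f in range(len(memory[0]))]
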
* **learning_rate** : specifies the rate at which `field_weights ` are learned if - **learn_field_weight** is True; see `Learning ` for additional information. + **learn_field_weights** is True; see `Learning ` for additional information. .. _EMComposition_Structure: @@ -621,20 +614,19 @@ .. _EMComposition_Memory_Storage: .. technical_note:: The memories are actually stored in the `matrix ` parameters of the`MappingProjections` - from the `combined_softmax_node ` to each of the `retrieved_nodes - `. Memories associated with each key are also stored (in inverted form) - in the `matrix ` parameters of the `MappingProjection ` - from the `query_input_nodes ` to each of the corresponding `match_nodes + from the `combined_matches_node ` to each of the `retrieved_nodes + `. Memories associated with each key are also stored (in inverted form) in the + `matrix ` parameters of the `MappingProjection ` from the + `query_input_nodes ` to each of the corresponding `match_nodes `. This is done so that the match of each query to the keys in memory for the corresponding field can be computed simply by passing the input for each query through the Projection (which - computes the dot product of the input with the Projection's `matrix ` parameter) to - the corresponding match_node; and, similarly, retrieivals can be computed by passing the softmax distributions - and weighting for each field computed in the `combined_softmax_node ` - through its Projection to each `retrieved_node ` (which are inverted versions - of the matrices of the `MappingProjections ` from the `query_input_nodes - ` to each of the corresponding `match_nodes `), - to compute the dot product of the weighted softmax over entries with the corresponding field of each entry - that yields the retreieved value for each field. + computes the distance of the input with the Projection's `matrix ` parameter) to the + corresponding match_node; and, similarly, retrievals can be computed by passing the softmax distributions for + each field computed in the `combined_matches_node ` through its Projection + to each `retrieved_node ` (which are inverted versions of the matrices of the + `MappingProjections ` from the `query_input_nodes ` to each + of the corresponding `match_nodes `), to compute the distance of the weighted + softmax over entries with the corresponding field of each entry that yields the retrieved value for each field. .. _EMComposition_Output: @@ -668,7 +660,7 @@ * **Concatenation**. By default, the input to every `query_input_node ` is passed to its own `match_node ` through a `MappingProjection` that computes its - dot product with the corresponding field of each entry in `memory `. In this way, each + distance with the corresponding field of each entry in `memory `. In this way, each match is normalized so that, absent `field_weighting `, all keys contribute equally to retrieval irrespective of relative differences in the norms of the queries or the keys in memory. However, if the `field_weights ` are the same for all `keys ` and @@ -685,32 +677,65 @@ however it will not necessarily produce the same results as passing each query through its own `match_node ` (see `concatenate keys <`concatenate_queries_node>` for additional information). -* **Match memories by field**.
The values of each `query_input_node ` (or the - `concatenate_queries_node ` if `concatenate_queries - ` attribute is True) are passed through a `MappingProjection` that computes - the dot product of the input with each memory for the corresponding field, the result of which is passed to the - corresponding `match_node `. - -* **Softmax normalize matches over fields**. The dot product for each key field is passed from the `match_node - ` to the corresponding `softmax_node `, which applies - the `SoftMax` Function to normalize the dot products for each key field. If a numerical value is specified for - `softmax_gain `, that is used as the gain (inverse temperature) for the SoftMax Function; - if *ADAPTIVE* is specified, then the `SoftMax.adapt_gain` function is used to adaptively set the gain based on the - dot products in each field; if *CONTROL* is specified, then the dot products are monitored by a `ControlMechanism` - that uses the `adapt_gain ` method of the SoftMax Function to modulate its `gain ` - parameter; if None is specified, the default value of the `Softmax` Function is used as the `gain ` - parameter (see `Softmax_Gain ` for additional details). - -* **Weight fields**. If `field weights ` are specified, then the softmax normalized dot - product for each key field is passed to the corresponding `field_weight_node ` - where it is multiplied by the corresponding `field_weight ` (if - `use_gating_for_weighting ` is True, this is done by using the `field_weight - ` to output gate the `softmax_node `). The weighted softmax - vectors for all key fields are then passed to the `combined_softmax_node `, - where they are haddamard summed to produce a single weighting for each memory. - -* **Retrieve values by field**. The vector of softmax weights for each memory generated by the `combined_softmax_node - ` is passed through the Projections to the each of the `retrieved_nodes +.. _EMComposition_Distance_Computation: + +* **Match memories by field**. The values of each `query_input_node ` + (or the `concatenate_queries_node ` if `concatenate_queries + ` attribute is True) are passed through a `MappingProjection` that + computes the distance between the corresponding input (query) and each memory (key) for the corresponding field, + the result of which is passed to the corresponding `match_node `. By default, the + distance is computed as the normalized dot product (i.e., the dot product of the normalized query vector and the + normalized key for the corresponding `field `, which is comparable to cosine similarity). However, + if `normalize_memories ` is set to False, just the raw dot product is computed. + The distance can also be customized by specifying a different `function ` for the + `MappingProjection` to the `match_node `. The result is assigned as the `value + ` of the corresponding `match_node `. + +.. _EMComposition_Field_Weighting: + +* **Weight distances**. If `field weights ` are specified, then the distance computed + by the `MappingProjection` to each `match_node ` is multiplied by the corresponding + `field_weight ` using the `field_weight_node `. + By default (if `use_gating_for_weighting ` is False), this is done using + the `weighted_match_nodes `, each of which receives a Projection from a + `match_node ` and the corresponding `field_weight_node ` + and multiplies them to produce the weighted distance for that field as its output.
However, if + `use_gating_for_weighting ` is True, the `field_weight_nodes` are implemented + as `GatingMechanisms `, each of which uses its `field weight ` as a + `GatingSignal ` to output gate (i.e., multiplicatively modulate the output of) the corresponding + `match_node `. In this case, the `weighted_match_nodes` are not implemented, + and the output of the `match_node ` is passed directly to the `combined_matches_node + `. + + + .. _EMComposition_Gating_For_Weighting: + .. note:: + Setting `use_gating_for_weighting ` to True reduces the size and + complexity of the EMComposition by eliminating the `weighted_match_nodes `. + However, doing so precludes the ability to learn the `field_weights `, + since `GatingSignals ` are `ModulatorySignals ` that cannot be learned. If learning is required, + then `use_gating_for_weighting` should be set to False. + +* **Combine distances**. If `field weights ` are used to specify more than one `key field + `, then the (weighted) distances computed for each field (see above) are summed across fields + by the `combined_matches_node `, before being passed to the `softmax_node + `. If only one key field is specified, then the output of the `match_node + ` is passed directly to the `softmax_node `. + +* **Softmax normalize distances**. The distances, passed either from the `combined_matches_node + `, or directly from the `match_node ` if there is + only one key field, are passed to the `softmax_node `, which applies the `SoftMax` + Function to generate the softmax distribution used to retrieve entries from `memory `. + If a numerical value is specified for `softmax_gain `, that is used as the gain (inverse + temperature) for the SoftMax Function; if *ADAPTIVE* is specified, then the `SoftMax.adapt_gain` function is used + to adaptively set the gain based on the summed distance (i.e., the output of the `combined_matches_node + `); if *CONTROL* is specified, then the summed distance is monitored by a + `ControlMechanism` that uses the `adapt_gain ` method of the `SoftMax` Function to modulate its + `gain ` parameter; if None is specified, the default value of the `Softmax` Function is used as the + `gain ` parameter (see `Softmax_Gain ` for additional details). + +* **Retrieve values by field**. The vector of softmax weights for each memory generated by the `softmax_node + ` is passed through the Projections to each of the `retrieved_nodes + ` to compute the retrieved value for each field. * **Decay memories**. If `memory_decay ` is True, then each of the memories is decayed @@ -718,7 +743,7 @@ .. technical_note:: This is done by multiplying the `matrix ` parameter of the `MappingProjection` from - the `combined_softmax_node ` to each of the `retrieved_nodes + the `combined_matches_node ` to each of the `retrieved_nodes `, as well as the `matrix ` parameter of the `MappingProjection` from each `query_input_node ` to the corresponding `match_node ` by `memory_decay `, @@ -733,7 +758,7 @@ ..
technical_note:: This is done by adding the input vectors to the corresponding rows of the `matrix ` - of the `MappingProjection` from the `retreival_weighting_node ` to each + of the `MappingProjection` from the `combined_matches_node ` to each of the `retrieved_nodes `, as well as the `matrix ` parameter of the `MappingProjection` from each `query_input_node ` to the corresponding `match_node ` (see note `above ` for @@ -964,7 +989,7 @@ **Use of field_weights to specify keys and values.** -Note that the figure now shows `RETRIEVAL WEIGHTING ` `nodes `, +Note that the figure now shows ` [WEIGHT] ` `nodes `, that are used to implement the relative contribution that each key field makes to the matching process specified in `field_weights ` argument. By default, these are equal (all assigned a value of 1), but different values can be used to weight the relative contribution of each key field. The values are normalized so @@ -983,7 +1008,7 @@ **Use of field_weights to specify relative contribution of fields to matching process.** Note that in this case, the `concatenate_queries_node ` has been replaced by -a pair of `retreival_weighting_nodes `, one for each key field. This is because +a pair of `weighted_match_nodes `, one for each key field. This is because the keys were assigned different weights; when they are assigned equal weights, or if no weights are specified, and `normalize_memories ` is `True`, then the keys are concatenated for efficiency of processing. This can be suppressed by specifying `concatenate_queries` as `False` @@ -1029,25 +1054,37 @@ WEIGHTED_AVG = ALL PROBABILISTIC = PROB_INDICATOR -QUERY_AFFIX = ' [QUERY]' -VALUE_AFFIX = ' [VALUE]' -MATCH_TO_KEYS_AFFIX = ' [MATCH to KEYS]' +QUERY_NODE_NAME = 'QUERY' +QUERY_AFFIX = f' [{QUERY_NODE_NAME}]' +VALUE_NODE_NAME = 'VALUE' +VALUE_AFFIX = f' [{VALUE_NODE_NAME}]' +MATCH = 'MATCH' +MATCH_AFFIX = f' [{MATCH}]' +MATCH_TO_KEYS_NODE_NAME = f'{MATCH} to KEYS' +WEIGHT = 'WEIGHT' +WEIGHT_AFFIX = f' [{WEIGHT}]' +MATCH_TO_KEYS_AFFIX = f' [{MATCH_TO_KEYS_NODE_NAME}]' +WEIGHTED_MATCH_NODE_NAME = 'WEIGHTED MATCH' +WEIGHTED_MATCH_AFFIX = f' [{WEIGHTED_MATCH_NODE_NAME}]' +CONCATENATE_QUERIES_NAME = 'CONCATENATE QUERIES' +COMBINE_MATCHES_NODE_NAME = 'COMBINE MATCHES' +COMBINE_MATCHES_AFFIX = f' [{COMBINE_MATCHES_NODE_NAME}]' +SOFTMAX_NODE_NAME = 'RETRIEVE' +SOFTMAX_AFFIX = f' [{SOFTMAX_NODE_NAME}]' +RETRIEVED_NODE_NAME = 'RETRIEVED' RETRIEVED_AFFIX = ' [RETRIEVED]' -WEIGHTED_SOFTMAX_AFFIX = ' [WEIGHTED SOFTMAX]' -COMBINED_SOFTMAX_NODE_NAME = 'RETRIEVE' STORE_NODE_NAME = 'STORE' - def _memory_getter(owning_component=None, context=None)->list: """Return list of memories in which rows (outer dimension) are memories for each field. These are derived from `matrix ` parameter of the `afferent - ` MappingProjections to each of the `retrieved_nodes `. + ` MappingProjections to each of the `retrieved_nodes `. """ # If storage_node (EMstoragemechanism) is implemented, get memory from that if owning_component.is_initializing: return None - if owning_component.use_storage_node: + if owning_component._use_storage_node: return owning_component.storage_node.parameters.memory_matrix.get(context) # Otherwise, get memory from Projection(s) to each retrieved_node @@ -1142,7 +1179,7 @@ class EMComposition(AutodiffComposition): see `Match memories by field ` for additional details.
softmax_gain : float, ADAPTIVE or CONTROL : default 1.0 - specifies the temperature used for softmax normalizing the dot products of keys and memories; + specifies the temperature used for softmax normalizing the distance of queries and keys in memory; see `Softmax normalize matches over fields ` for additional details. softmax_threshold : float : default .0001 @@ -1150,7 +1187,7 @@ class EMComposition(AutodiffComposition): see *mask_threshold* under `Thresholding and Adaptive Gain ` for details). softmax_choice : WEIGHTED_AVG, ARG_MAX, PROBABILISTIC : default WEIGHTED_AVG - specifies how the softmax over dot products of keys and memories is used for retrieval; + specifies how the softmax over distances of queries and keys in memory is used for retrieval; see `softmax_choice ` for a description of each option. storage_prob : float : default 1.0 @@ -1170,8 +1207,8 @@ class EMComposition(AutodiffComposition): learn_field_weights : bool : default True specifies whether `field_weights ` are learnable during training; - requires **enable_learning** to be True to have any effect; see `learn_field_weights - ` for additional details. + requires **enable_learning** to be True to have any effect, and **use_gating_for_weighting** must be False; + see `learn_field_weights ` for additional details. learning_rate : float : default .01 specifies rate at which `field_weights ` are learned @@ -1186,14 +1223,8 @@ class EMComposition(AutodiffComposition): the EMComposition into another Composition; to do so, use_storage_node must be True (default). use_gating_for_weighting : bool : default False - specifies whether to use a `GatingMechanism` to modulate the `combined_softmax_node - ` instead of a standard ProcessingMechanism. If True, then - a GatingMechanism is constructed and used to gate the `OutputPort` of each `field_weight_node - EMComposition.field_weight_nodes`; otherwise, the output of each `field_weight_node - EMComposition.field_weight_nodes` projects to the `InputPort` of the `combined_softmax_node - EMComposition.combined_softmax_node` that receives a Projection from the corresponding - `field_weight_node `, and multiplies its `value - `. + specifies whether to use output gating to weight the `match_nodes ` instead of + dedicated `weighted_match_nodes ` (see `Weight distances ` for additional details). Attributes ---------- @@ -1217,9 +1248,10 @@ class EMComposition(AutodiffComposition): field_weights : tuple[float] determines which fields of the input are treated as "keys" (non-zero values) that are used to match entries in - `memory ` for retrieval, and which are used as "values" (zero values), that are stored - and retrieved from memory, but not used in the match process (see `Match memories by field - `. see `field_weights ` additional details. + `memory ` for retrieval, and which are used as "values" (zero values) that are stored + and retrieved from memory but not used in the match process (see `Match memories by field + `); also determines the relative contribution of each key field to the match process; + see `field_weights ` for additional details. normalize_field_weights : bool : default True determines whether `field_weights ` are normalized over the number of keys, or @@ -1239,16 +1271,16 @@ class EMComposition(AutodiffComposition): see `Match memories by field ` for additional details.
softmax_gain : float, ADAPTIVE or CONTROL - determines gain (inverse temperature) used for softmax normalizing the dot products of keys and memories - by the `softmax` function of the `softmax_nodes `; see `Softmax normalize matches - over fields ` for additional details. + determines gain (inverse temperature) used for softmax normalizing the summed distances of queries and keys in + memory by the `SoftMax` Function of the `softmax_node `; see `Softmax normalize + distances ` for additional details. softmax_threshold : float determines the threshold used to mask out small values in the softmax calculation; see *mask_threshold* under `Thresholding and Adaptive Gain ` for details). softmax_choice : WEIGHTED_AVG, ARG_MAX or PROBABILISTIC - determines how the softmax over dot products of keys and memories is used for retrieval; + determines how the softmax over distances of queries and keys in memory is used for retrieval; see `softmax_choice ` for a description of each option. storage_prob : float @@ -1300,61 +1332,47 @@ class EMComposition(AutodiffComposition): into a single vector used for the matching processing if `concatenate keys ` is True. This is not created if the **concatenate_queries** argument to the EMComposition's constructor is False or is overridden (see `concatenate_queries `), or there is only one - query_input_node. This node is named *CONCATENATE_KEYS* + query_input_node. This node is named *CONCATENATE_QUERIES* match_nodes : list[ProcessingMechanism] - `ProcessingMechanisms ` that receive the dot product of each key and those stored in + `ProcessingMechanisms ` that compute the dot product of each query and the key stored in the corresponding field of `memory ` (see `Match memories by field ` for additional details). These are named the same as the corresponding `query_input_nodes ` appended with the suffix *[MATCH to KEYS]*. - softmax_gain_control_nodes : list[ControlMechanism] + field_weight_nodes : list[ProcessingMechanism or GatingMechanism] + Nodes used to weight the distances computed by the `match_nodes ` with the + `field weight ` for the corresponding `key field ` + (see `Weight distances ` for implementation). These are named the same + as the corresponding `query_input_nodes ` appended with the suffix *[WEIGHT]*. + + weighted_match_nodes : list[ProcessingMechanism] + `ProcessingMechanisms ` that combine the `field weight ` + for each `key field ` with the dot product computed by the corresponding + `match_node `. These are only implemented if `use_gating_for_weighting + ` is False (see `Weight distances ` + for details), and are named the same as the corresponding `query_input_nodes ` + appended with the suffix *[WEIGHTED MATCH]*. + + combined_matches_node : ProcessingMechanism + `ProcessingMechanism` that receives the weighted distances from the `weighted_match_nodes + ` if more than one `key field ` is specified + (or directly from `match_nodes ` if `use_gating_for_weighting + ` is True), and combines them into a single vector that is passed + to the `softmax_node ` for retrieval. This node is named *COMBINE MATCHES*. + + softmax_node : ProcessingMechanism + `ProcessingMechanism` that computes the softmax over the summed distances of keys + and memories (output of the `combined_matches_node `) + from the corresponding `match_nodes ` (see `Softmax over summed distances + ` for additional details). This is named *RETRIEVE* (as it yields the + softmax-weighted average over the keys in `memory `).
+ + softmax_gain_control_node : ControlMechanism `ControlMechanisms ` that adaptively control the `softmax_gain ` - for the corresponding `softmax_nodes `. These are implemented only if - `softmax_gain ` is specified as *CONTROL* (see `softmax_gain - ` for details). - - softmax_nodes : list[ProcessingMechanism] - `ProcessingMechanisms ` that compute the softmax over the vectors received - from the corresponding `match_nodes ` (see `Softmax normalize matches over fields - ` for additional details). These are named the same as the corresponding - `query_input_nodes ` appended with the suffix *[SOFTMAX]*. - - field_weight_nodes : list[ProcessingMechanism] - `ProcessingMechanisms `, each of which use the `field weight ` - for a given `field ` as its (fixed) input and provides this to the corresponding - `weighted_softmax_node `. These are implemented only if more than one - `key field ` is specified (see `Fields ` for additional details), - and are replaced with `retrieval_gating_nodes ` if - `use_gating_for_weighting ` is True. These are named the same as the - corresponding `query_input_nodes ` appended with the suffix *[WEIGHT]*. - - weighted_softmax_nodes : list[ProcessingMechanism] - `ProcessingMechanisms `, each of which receives the output of the corresponding - `softmax_node ` and `field_weight_node ` - for a given `field `, and multiplies them to produce the weighted softmax for that field; - these are implemented only if more than one `key field ` is specified (see `Fields - ` for additional details) and `use_gating_for_weighting - ` is False (otherwise, `field_weights ` - are applied through output gating of the `softmax_nodes ` by the - `retrieval_gating_nodes `). These are named the same as the corresponding - `query_input_nodes ` appended with the suffix *[WEIGHTED SOFTMAX]*. - - retrieval_gating_nodes : list[GatingMechanism] - `GatingMechanisms ` that uses the `field weight ` for each - field to modulate the output of the corresponding `softmax_node ` before it - is passed to the `combined_softmax_node `. These are implemented - only if `use_gating_for_weighting ` is True and more than one - `key field ` is specified (see `Fields ` for additional details). - - combined_softmax_node : ProcessingMechanism - `ProcessingMechanism` that receives the softmax normalized dot products of the keys and memories from the - `softmax_nodes `, weighted by the `field_weights_nodes - ` if more than one `key field ` is specified - (or by `retrieval_gating_nodes ` if `use_gating_for_weighting - ` is True), and combines them into a single vector that is used to - retrieve the corresponding memory for each field from `memory ` (see `Retrieve values by - field ` for additional details). This node is named *RETRIEVE*. + of the `softmax_node `. This is implemented only if `softmax_gain + ` is specified as *CONTROL* (see `softmax_gain ` for + details).
retrieved_nodes : list[ProcessingMechanism] `ProcessingMechanisms ` that receive the vector retrieved for each field in `memory @@ -1615,7 +1633,8 @@ def __init__(self, if memory_decay_rate is AUTO: memory_decay_rate = 1 / memory_capacity - self.use_storage_node = use_storage_node + self._use_storage_node = use_storage_node + self._use_gating_for_weighting = use_gating_for_weighting if softmax_gain == CONTROL: self.parameters.softmax_gain.modulable = False @@ -1643,7 +1662,10 @@ def __init__(self, **kwargs ) - self._validate_options_with_learning(softmax_choice, normalize_field_weights, enable_learning) + self._validate_options_with_learning(enable_learning, + use_gating_for_weighting, + learn_field_weights, + softmax_choice) self._construct_pathways(self.memory_template, self.memory_capacity, @@ -1655,10 +1677,10 @@ def __init__(self, self.softmax_choice, self.storage_prob, self.memory_decay_rate, - self.use_storage_node, + self._use_storage_node, self.enable_learning, self.learn_field_weights, - use_gating_for_weighting) + self._use_gating_for_weighting) # if torch_available: # from psyneulink.library.compositions.pytorchEMcompositionwrapper import PytorchEMCompositionWrapper @@ -1669,7 +1691,7 @@ def __init__(self, # Assign learning-related attributes self._set_learning_attributes() - if self.use_storage_node: + if self._use_storage_node: # --------------------------------------- # # CONDITION: @@ -1709,7 +1731,7 @@ def __init__(self, # Suppress warnings for no efferent Projections for node in self.value_input_nodes: node.output_port.parameters.require_projection_in_composition.set(False, override=True) - self.combined_softmax_node.output_port.parameters.require_projection_in_composition.set(False, override=True) + self.softmax_node.output_port.parameters.require_projection_in_composition.set(False, override=True) # Suppress field_weight_nodes as INPUT nodes of the Composition for node in self.field_weight_nodes: @@ -1861,7 +1883,7 @@ def _construct_entries(entry_template, num_entries, memory_fill=None)->np.ndarra # Get remaining entries populated with memory_fill remaining_entries = _construct_entries(memory_template[0], num_entries_needed, memory_fill) assert bool(num_entries_needed == len(remaining_entries)) - # I any remaining entries, concatenate them with the entries that were specified + # If any remaining entries, concatenate them with the entries that were specified if num_entries_needed: memory = np.concatenate((np.array(memory_template, dtype=object), np.array(remaining_entries, dtype=object))) @@ -1942,23 +1964,23 @@ def _parse_fields(self, and normalize_memories) # if concatenate_queries was forced to be False when user specified it as True, issue warning if user_specified_concatenate_queries and not parsed_concatenate_queries: - # Issue warning if concatenate_queries is True but either - # field weights are not all equal and/or normalize_memories is False + # Issue warning if concatenate_queries is True but: + # field weights are not all equal and/or + # normalize_memories is False and/or + # there is only one key fw_error_msg = nm_error_msg = fw_correction_msg = nm_correction_msg = None - if not all(np.all(keys_weights[i] == keys_weights[0] for i in range(len(keys_weights)))): - fw_error_msg = f" field weights ({field_weights}) are not all equal" - fw_correction_msg = f"remove `field_weights` specification or make them all the same." 
- if not normalize_memories: - nm_error_msg = f" normalize_memories is False" - nm_correction_msg = f" or set normalize_memories to True" - if fw_error_msg and nm_error_msg: - error_msg = f"{fw_error_msg} and {nm_error_msg}" - correction_msg = f"{fw_correction_msg} and/or {nm_correction_msg}" - else: - error_msg = fw_error_msg or nm_error_msg - correction_msg = fw_correction_msg or nm_correction_msg + if self.num_keys == 1: + error_msg = f"there is only one key" + correction_msg = "" + elif not all(np.all(keys_weights[i] == keys_weights[0] for i in range(len(keys_weights)))): + error_msg = f" field weights ({field_weights}) are not all equal" + correction_msg = (f" To use concatenation, remove `field_weights` " + f"specification or make them all the same.") + elif not normalize_memories: + error_msg = f" normalize_memories is False" + correction_msg = f" To use concatenation, set normalize_memories to True." warnings.warn(f"The 'concatenate_queries' arg for '{name}' is True but {error_msg}; " - f"concatenation will be ignored. To use concatenation, {correction_msg}.") + f"concatenation will be ignored.{correction_msg}") self.learning_rate = learning_rate return parsed_field_weights, parsed_field_names, parsed_concatenate_queries @@ -2021,19 +2043,21 @@ def _construct_pathways(self, self.concatenate_queries_node = self._construct_concatenate_queries_node(concatenate_queries) self.match_nodes = self._construct_match_nodes(memory_template, memory_capacity, concatenate_queries,normalize_memories) - self.softmax_nodes = self._construct_softmax_nodes(memory_capacity, - field_weights, - softmax_gain, - softmax_threshold, - softmax_choice) self.field_weight_nodes = self._construct_field_weight_nodes(field_weights, concatenate_queries, use_gating_for_weighting) - self.weighted_softmax_nodes = self._construct_weighted_softmax_nodes(memory_capacity, use_gating_for_weighting) - self.softmax_gain_control_nodes = self._construct_softmax_gain_control_nodes(softmax_gain) - self.combined_softmax_node = self._construct_combined_softmax_node(memory_capacity, + self.weighted_match_nodes = self._construct_weighted_match_nodes(memory_capacity, field_weights) + + self.combined_matches_node = self._construct_combined_matches_node(memory_capacity, field_weighting, use_gating_for_weighting) + self.softmax_node = self._construct_softmax_node(memory_capacity, + softmax_gain, + softmax_threshold, + softmax_choice) + + self.softmax_gain_control_node = self._construct_softmax_gain_control_node(softmax_gain) + self.retrieved_nodes = self._construct_retrieved_nodes(memory_template) if use_storage_node: @@ -2043,86 +2067,73 @@ def _construct_pathways(self, # Do some validation and get singleton softmax and match Nodes for concatenated queries if self.concatenate_queries: - softmax_node = self.softmax_nodes.pop() - assert not self.softmax_nodes, \ - f"PROGRAM ERROR: Too many softmax_nodes ({len(self.softmax_nodes)}) for concatenated queries." - assert len(self.softmax_gain_control_nodes) <= 1, \ - (f"PROGRAM ERROR: Too many softmax_gain_control_nodes " - f"{len(self.softmax_gain_control_nodes)}) for concatenated queries.") - match_node = self.match_nodes.pop() - assert not self.match_nodes, \ + assert len(self.match_nodes) == 1, \ f"PROGRAM ERROR: Too many match_nodes ({len(self.match_nodes)}) for concatenated queries." assert not self.field_weight_nodes, \ f"PROGRAM ERROR: There should be no field_weight_nodes for concatenated queries." 
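# --------------------------------------------------------------------------
# Illustrative aside (not part of the PR): given the parsing logic above, a
# hedged construction sketch showing when the concatenate_queries warning
# fires (argument values are placeholders):
#
#     import psyneulink as pnl
#     em = pnl.EMComposition(memory_template=(3, 5),      # 3 fields of length 5
#                            memory_capacity=10,
#                            field_weights=(2., 1., 0.),  # 2 keys (unequal) + 1 value
#                            concatenate_queries=True)    # ignored, with a warning
#
# This should emit a UserWarning similar to: "The 'concatenate_queries' arg
# for '<name>' is True but field weights ... are not all equal; concatenation
# will be ignored. To use concatenation, remove `field_weights` specification
# or make them all the same."
# --------------------------------------------------------------------------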
# Construct Pathways -------------------------------------------------------------------------------- + # LEARNING NOT ENABLED -------------------------------------------------- # Set up pathways WITHOUT PsyNeuLink learning pathways if not self.enable_learning: self.add_nodes(self.query_input_nodes + self.value_input_nodes) - if self.concatenate_queries: - self.add_nodes([self.concatenate_queries_node, match_node, softmax_node]) - else: - self.add_nodes(self.match_nodes + - self.softmax_nodes + - self.field_weight_nodes + - self.weighted_softmax_nodes) - self.add_nodes(self.softmax_gain_control_nodes + - [self.combined_softmax_node] + - self.retrieved_nodes) if use_storage_node: self.add_node(self.storage_node) - # self.add_projections(proj for proj in self.storage_node.efferents) - + if self.concatenate_queries_node: + self.add_node(self.concatenate_queries_node) + self.add_nodes(self.match_nodes + self.field_weight_nodes + self.weighted_match_nodes) + if self.combined_matches_node: + self.add_node(self.combined_matches_node) + self.add_nodes([self.softmax_node] + self.retrieved_nodes) + if self.softmax_gain_control_node: + self.add_node(self.softmax_gain_control_node) + + # LEARNING ENABLED ----------------------------------------------------- # Set up pathways WITH psyneulink backpropagation learning field weights else: - - # Key pathways - for i in range(self.num_keys): - # Regular pathways - if not self.concatenate_queries: - pathway = [self.query_input_nodes[i], - self.match_nodes[i], - self.softmax_nodes[i], - self.combined_softmax_node] - if self.weighted_softmax_nodes: - pathway.insert(3, self.weighted_softmax_nodes[i]) - # if self.softmax_gain_control_nodes: - # pathway.insert(4, self.softmax_gain_control_nodes[i]) - # Key-concatenated pathways + # Query-specific pathways + if not self.concatenate_queries: + if self.num_keys == 1: + self.add_linear_processing_pathway([self.query_input_nodes[0], + self.match_nodes[0], + self.softmax_node]) else: + for i in range(self.num_keys): + pathway = [self.query_input_nodes[i], + self.match_nodes[i], + self.combined_matches_node] + if self.weighted_match_nodes: + pathway.insert(2, self.weighted_match_nodes[i]) + self.add_linear_processing_pathway(pathway) + self.add_linear_processing_pathway([self.combined_matches_node, self.softmax_node]) + # Query-concatenated pathways + else: + for i in range(self.num_keys): pathway = [self.query_input_nodes[i], self.concatenate_queries_node, - match_node, - softmax_node, - self.combined_softmax_node] - # if self.softmax_gain_control_nodes: - # pathway.insert(4, self.softmax_gain_control_nodes[0]) # Only one, ensured above - # self.add_backpropagation_learning_pathway(pathway) - self.add_linear_processing_pathway(pathway) + self.match_nodes[0]] + self.add_linear_processing_pathway(pathway) + self.add_linear_processing_pathway([self.match_nodes[0], self.softmax_node]) # softmax gain control is specified: - for gain_control_node in self.softmax_gain_control_nodes: - self.add_node(gain_control_node) + if self.softmax_gain_control_node: + self.add_node(self.softmax_gain_control_node) # field_weights -> weighted_softmax pathways if self.field_weight_nodes: for i in range(self.num_keys): - # self.add_backpropagation_learning_pathway([self.field_weight_nodes[i], - # self.weighted_softmax_nodes[i]]) - self.add_linear_processing_pathway([self.field_weight_nodes[i], self.weighted_softmax_nodes[i]]) + self.add_linear_processing_pathway([self.field_weight_nodes[i], self.weighted_match_nodes[i]])
self.add_nodes(self.value_input_nodes) # Retrieval pathways for i in range(len(self.retrieved_nodes)): - # self.add_backpropagation_learning_pathway([self.combined_softmax_node, self.retrieved_nodes[i]]) - self.add_linear_processing_pathway([self.combined_softmax_node, self.retrieved_nodes[i]]) + self.add_linear_processing_pathway([self.softmax_node, self.retrieved_nodes[i]]) # Storage Nodes if use_storage_node: self.add_node(self.storage_node) - # self.add_projections(proj for proj in self.storage_node.efferents) def _construct_query_input_nodes(self, field_weights)->list: """Create one node for each key to be used as cue for retrieval (and then stored) in memory. @@ -2145,7 +2156,7 @@ def _construct_query_input_nodes(self, field_weights)->list: def _construct_value_input_nodes(self, field_weights)->list: """Create one input node for each value to be stored in memory. - Used to assign new set of weights for Projection for combined_softmax_node -> retrieved_node[i] + Used to assign new set of weights for Projection for combined_matches_node -> retrieved_node[i] where i is selected randomly without replacement from (0->memory_capacity) """ @@ -2171,14 +2182,14 @@ def _construct_concatenate_queries_node(self, concatenate_queries)->ProcessingMe return None else: return ProcessingMechanism(function=Concatenate, - input_ports=[{NAME: 'CONCATENATE_QUERIES', + input_ports=[{NAME: 'CONCATENATE', SIZE: len(self.query_input_nodes[i].output_port.value), PROJECTIONS: MappingProjection( name=f'{self.key_names[i]} to CONCATENATE', sender=self.query_input_nodes[i].output_port, matrix=IDENTITY_MATRIX)} for i in range(self.num_keys)], - name='CONCATENATE KEYS') + name=CONCATENATE_QUERIES_NAME) def _construct_match_nodes(self, memory_template, memory_capacity, concatenate_queries, normalize_memories)->list: """Create nodes that, for each key field, compute the similarity between the input and each item in memory. @@ -2225,60 +2236,7 @@ def _construct_match_nodes(self, memory_template, memory_capacity, concatenate_q return match_nodes - def _validate_options_with_learning(self, softmax_choice, normalize_field_weights, enable_learning): - if softmax_choice in {ARG_MAX, PROBABILISTIC} and enable_learning: - warnings.warn(f"The 'softmax_choice' arg of '{self.name}' is set to '{softmax_choice}' with " - f"'enable_learning' set to True (or a list); this will generate an error if its " - f"'learn' method is called. Set 'softmax_choice' to WEIGHTED_AVG before learning.") - - if enable_learning and not normalize_field_weights: - warnings.warn(f"The 'normalize_field_weights' arg of '{self.name}' is set to False with " - f"'enable_learning' set to True (or a list); this may generate an error if " - f"the 'loss_spec' used for learning requires values to be between 0 and 1.") - - def _construct_softmax_nodes(self, memory_capacity, field_weights, - softmax_gain, softmax_threshold, softmax_choice)->list: - """Create nodes that, for each key field, compute the softmax over the similarities between the input and the - memories in the corresponding match_node. 
- """ - - # Get indices of field_weights that specify keys: - key_weights = [field_weights[i] for i in self.key_indices] - - if softmax_choice == ARG_MAX: - # ARG_MAX would return entry multiplied by its dot product - # ARG_MAX_INDICATOR returns the entry unmodified - softmax_choice = ARG_MAX_INDICATOR - - softmax_nodes = [ProcessingMechanism(input_ports={SIZE:memory_capacity, - PROJECTIONS: MappingProjection( - sender=match_node.output_port, - matrix=IDENTITY_MATRIX, - name=f'MATCH to SOFTMAX for {self.key_names[i]}')}, - function=SoftMax(gain=softmax_gain, - mask_threshold=softmax_threshold, - output=softmax_choice, - adapt_entropy_weighting=.95), - name='SOFTMAX' if len(self.match_nodes) == 1 - else f'{self.key_names[i]} [SOFTMAX]') - for i, match_node in enumerate(self.match_nodes)] - - return softmax_nodes - - def _construct_softmax_gain_control_nodes(self, softmax_gain)->list: - """Create nodes that set the softmax gain (inverse temperature) for each softmax_node.""" - - softmax_gain_control_nodes = [] - if softmax_gain == CONTROL: - softmax_gain_control_nodes = [ControlMechanism(monitor_for_control=match_node, - control_signals=[(GAIN, self.softmax_nodes[i])], - function=get_softmax_gain, - name='SOFTMAX GAIN CONTROL' if len(self.softmax_nodes) == 1 - else f'SOFTMAX GAIN CONTROL {self.key_names[i]}') - for i, match_node in enumerate(self.match_nodes)] - - return softmax_gain_control_nodes - + # FIX: CONVERT TO _construct_weight_control_nodes def _construct_field_weight_nodes(self, field_weights, concatenate_queries, use_gating_for_weighting)->list: """Create ProcessingMechanisms that weight each key's softmax contribution to the retrieved values.""" @@ -2291,71 +2249,131 @@ def _construct_field_weight_nodes(self, field_weights, concatenate_queries, use_ PARAMS:{DEFAULT_INPUT: DEFAULT_VARIABLE}, NAME: 'OUTCOME'}, gate=[key_match_pair[1].output_ports[0]], - name= 'RETRIEVAL WEIGHTING' if self.num_keys == 1 - else f'RETRIEVAL WEIGHTING {i}') + name= 'WEIGHT' if self.num_keys == 1 + else f'{self.key_names[i]}{WEIGHT_AFFIX}') for i, key_match_pair in enumerate(zip(self.query_input_nodes, - self.softmax_nodes))] + self.match_nodes))] else: field_weight_nodes = [ProcessingMechanism(input_ports={VARIABLE: np.array(field_weights[self.key_indices[i]]), PARAMS: {DEFAULT_INPUT: DEFAULT_VARIABLE}, NAME: 'FIELD_WEIGHT'}, - name= 'WEIGHT' if self.num_keys == 1 - else f'{self.key_names[i]} [WEIGHT]') + name= WEIGHT if self.num_keys == 1 + else f'{self.key_names[i]}{WEIGHT_AFFIX}') for i in range(self.num_keys)] return field_weight_nodes - def _construct_weighted_softmax_nodes(self, memory_capacity, use_gating_for_weighting)->list: - - if use_gating_for_weighting: - return [] - - weighted_softmax_nodes = \ - [ProcessingMechanism( - default_variable=[self.softmax_nodes[i].output_port.value, - self.softmax_nodes[i].output_port.value], - input_ports=[ - {PROJECTIONS: MappingProjection(sender=sm_fw_pair[0], - matrix=IDENTITY_MATRIX, - name=f'SOFTMAX to WEIGHTED SOFTMAX for {self.key_names[i]}')}, - {PROJECTIONS: MappingProjection(sender=sm_fw_pair[1], - matrix=FULL_CONNECTIVITY_MATRIX, - name=f'WEIGHT to WEIGHTED SOFTMAX for {self.key_names[i]}')}], - function=LinearCombination(operation=PRODUCT), - name=self.key_names[i] + WEIGHTED_SOFTMAX_AFFIX) - for i, sm_fw_pair in enumerate(zip(self.softmax_nodes, + def _construct_weighted_match_nodes(self, memory_capacity, field_weights)->list: + """Create nodes that weight the output of the match node for each key.""" + + weighted_match_nodes = \ + 
[ProcessingMechanism(default_variable=[self.match_nodes[i].output_port.value,
+                                                   self.match_nodes[i].output_port.value],
+                                 input_ports=[{PROJECTIONS:
+                                                   MappingProjection(sender=match_fw_pair[0],
+                                                                     matrix=IDENTITY_MATRIX,
+                                                                     name=f'{MATCH} to {WEIGHTED_MATCH_NODE_NAME} '
+                                                                          f'for {self.key_names[i]}')},
+                                              {PROJECTIONS:
+                                                   MappingProjection(sender=match_fw_pair[1],
+                                                                     matrix=FULL_CONNECTIVITY_MATRIX,
+                                                                     name=f'{WEIGHT} to {WEIGHTED_MATCH_NODE_NAME} '
+                                                                          f'for {self.key_names[i]}')}],
+                                 function=LinearCombination(operation=PRODUCT),
+                                 name=self.key_names[i] + WEIGHTED_MATCH_AFFIX)
+             for i, match_fw_pair in enumerate(zip(self.match_nodes,
                                                    self.field_weight_nodes))]
-        return weighted_softmax_nodes

-    def _construct_combined_softmax_node(self,
+        return weighted_match_nodes
+
+    def _construct_softmax_gain_control_node(self, softmax_gain)->Optional[ControlMechanism]:
+        """Create node that sets the softmax gain (inverse temperature) for the softmax_node."""
+
+        if softmax_gain == CONTROL:
+            return ControlMechanism(monitor_for_control=self.combined_matches_node,
+                                    control_signals=[(GAIN, self.softmax_node)],
+                                    function=get_softmax_gain,
+                                    name='SOFTMAX GAIN CONTROL')
+        else:
+            return None
+
+    def _construct_combined_matches_node(self,
                                          memory_capacity,
                                          field_weighting,
                                          use_gating_for_weighting
                                          )->ProcessingMechanism:
-        """Create nodes that compute the weighting of each item in memory.
-        """
+        """Create node that combines weighted matches for all keys into one match vector."""
+
+        if self.num_keys == 1 or self.concatenate_queries_node:
+            return None

         if not field_weighting or use_gating_for_weighting:
-            # If use_gating_for_weighting, then softmax_nodes are output gated by gating nodes
-            input_source = self.softmax_nodes
+            input_source = self.match_nodes
         else:
-            input_source = self.weighted_softmax_nodes
+            input_source = self.weighted_match_nodes

-        combined_softmax_node = (
+        combined_matches_node = (
             ProcessingMechanism(input_ports=[{SIZE:memory_capacity,
-                                              # PROJECTIONS:[s for s in input_source]}],
                                               PROJECTIONS:[MappingProjection(sender=s,
                                                                              matrix=IDENTITY_MATRIX,
-                                                                             name=f'WEIGHTED SOFTMAX to RETRIEVAL for '
-                                                                                  f'{self.key_names[i]}')
+                                                                             name=f'{WEIGHTED_MATCH_NODE_NAME} '
+                                                                                  f'for {self.key_names[i]} to '
+                                                                                  f'{COMBINE_MATCHES_NODE_NAME}')
                                                             for i, s in enumerate(input_source)]}],
-                                name=COMBINED_SOFTMAX_NODE_NAME))
+                                name=COMBINE_MATCHES_NODE_NAME))
+
+        assert len(combined_matches_node.output_port.value) == memory_capacity, \
+            'PROGRAM ERROR: number of items in combined_matches_node ' \
+            f'({len(combined_matches_node.output_port.value)}) does not match memory_capacity ({self.memory_capacity})'

-        assert len(combined_softmax_node.output_port.value) == memory_capacity, \
-            'PROGRAM ERROR: number of items in combined_softmax_node ' \
-            '({len(combined_softmax_node.output_port)}) does not match memory_capacity ({self.memory_capacity})'
+        return combined_matches_node

-        return combined_softmax_node
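# Illustrative sketch (not part of the patch): the softmax_node constructed below applies a gain-
# and threshold-modulated softmax. A rough NumPy analogue of that masking behavior is shown here;
# PsyNeuLink's SoftMax (gain, mask_threshold, ARG_MAX_INDICATOR) may differ in detail, and the
# names below are hypothetical.

import numpy as np

def masked_softmax(x, gain=1.0, mask_threshold=None):
    x = np.asarray(x, dtype=float)
    mask = np.ones_like(x, bool) if mask_threshold is None else np.abs(x) > mask_threshold
    out = np.zeros_like(x)
    out[mask] = np.exp(gain * (x[mask] - x[mask].max()))  # stable exponentiation of kept entries
    return out / out.sum()                                # entries below threshold stay at 0

print(masked_softmax([0.1, 0.02, 0.7, 0.4], gain=2.0, mask_threshold=0.05))
# With output=ARG_MAX_INDICATOR, the node would instead return a one-hot vector at the argmax.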
+    def _construct_softmax_node(self, memory_capacity, softmax_gain, softmax_threshold, softmax_choice)->ProcessingMechanism:
+        """Create node that applies softmax to output of combined_matches_node."""
+
+        if self.num_keys == 1 or self.concatenate_queries_node:
+            input_source = self.match_nodes[0]
+            proj_name = f'{MATCH} to {SOFTMAX_NODE_NAME}'
+        # elif self.concatenate_queries_node:
+        #     input_source = self.concatenate_queries_node
+        #     proj_name = f'{CONCATENATE_QUERIES_NAME} to {SOFTMAX_NODE_NAME}'
+        else:
+            input_source = self.combined_matches_node
+            proj_name = f'{COMBINE_MATCHES_NODE_NAME} to {SOFTMAX_NODE_NAME}'
+
+        if softmax_choice == ARG_MAX:
+            # ARG_MAX would return entry multiplied by its dot product;
+            # ARG_MAX_INDICATOR returns the entry unmodified
+            softmax_choice = ARG_MAX_INDICATOR
+
+        softmax_node = ProcessingMechanism(input_ports={SIZE:memory_capacity,
+                                                        PROJECTIONS: MappingProjection(
+                                                            sender=input_source,
+                                                            matrix=IDENTITY_MATRIX,
+                                                            name=proj_name)},
+                                           function=SoftMax(gain=softmax_gain,
+                                                            mask_threshold=softmax_threshold,
+                                                            output=softmax_choice,
+                                                            adapt_entropy_weighting=.95),
+                                           name=SOFTMAX_NODE_NAME)
+
+        return softmax_node
+
+    def _validate_options_with_learning(self,
+                                        enable_learning,
+                                        use_gating_for_weighting,
+                                        learn_field_weights,
+                                        softmax_choice):
+        if use_gating_for_weighting and learn_field_weights:
+            warnings.warn(f"The 'learn_field_weights' option for '{self.name}' cannot be used with "
+                          f"'use_gating_for_weighting' set to True; this will generate an error if its "
+                          f"'learn' method is called. Set 'use_gating_for_weighting' to False in order "
+                          f"to enable learning of field weights.")
+
+        if softmax_choice in {ARG_MAX, PROBABILISTIC} and enable_learning:
+            warnings.warn(f"The 'softmax_choice' arg of '{self.name}' is set to '{softmax_choice}' with "
+                          f"'enable_learning' set to True (or a list); this will generate an error if its "
+                          f"'learn' method is called. Set 'softmax_choice' to WEIGHTED_AVG before learning.")

     def _construct_retrieved_nodes(self, memory_template)->list:
         """Create nodes that report the value field(s) for the item(s) matched in memory.

@@ -2364,7 +2382,7 @@ def _construct_retrieved_nodes(self, memory_template)->list:
         [ProcessingMechanism(input_ports={SIZE: len(self.query_input_nodes[i].variable[0]),
                                           PROJECTIONS:
                                               MappingProjection(
-                                                  sender=self.combined_softmax_node,
+                                                  sender=self.softmax_node,
                                                   matrix=memory_template[:,i],
                                                   name=f'MEMORY FOR {self.key_names[i]} [RETRIEVE KEY]')
                                           },
@@ -2375,7 +2393,7 @@
         [ProcessingMechanism(input_ports={SIZE: len(self.value_input_nodes[i].variable[0]),
                                           PROJECTIONS:
                                               MappingProjection(
-                                                  sender=self.combined_softmax_node,
+                                                  sender=self.softmax_node,
                                                   matrix=memory_template[:, i + self.num_keys],
                                                   name=f'MEMORY FOR {self.value_names[i]} [RETRIEVE VALUE]')},
@@ -2466,13 +2484,13 @@ def execute(self, **kwargs):
         """Set input to weights of Projections to match_nodes and retrieved_nodes if not use_storage_node."""
         results = super().execute(inputs=inputs, context=context, **kwargs)
-        if not self.use_storage_node:
+        if not self._use_storage_node:
             self._store_memory(inputs, context)
         return results

     def _store_memory(self, inputs, context):
         """Store inputs to query and value nodes in memory
-        Store memories in weights of Projections to softmax_nodes (queries) and retrieved_nodes (values).
+        Store memories in weights of Projections to match_nodes (queries) and retrieved_nodes (values).
-        Note: inputs argument is ignored (included for compatibility with function of MemoryFunctions class;
-              storage is handled by call to EMComopsition._encode_memory
+        Note: inputs argument is ignored (included for compatibility with function of MemoryFunctions class);
+              storage is handled by call to EMComposition._encode_memory
         """
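# Illustrative sketch (not part of the patch): storage amounts to writing the current inputs into
# the matrices of the match/retrieval Projections. The sketch below conveys the general idea only;
# the actual row-selection and decay rules are EMComposition's own, and the names here are
# hypothetical.

import numpy as np

retrieval_matrix = np.random.rand(4, 3)    # matrix of a softmax_node -> retrieved_node Projection
entry = np.array([1.0, 2.0, 3.0])          # current input to one field
decay_rate = 0.1

retrieval_matrix *= (1 - decay_rate)                       # decay existing memories
row = np.argmin(np.linalg.norm(retrieval_matrix, axis=1))  # e.g. overwrite the weakest entry
retrieval_matrix[row] = entry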
@@ -2549,13 +2567,18 @@ def _encode_memory(self, context=None):
     @handle_external_context()
     def learn(self, *args, **kwargs)->list:
         """Override to check for inappropriate use of ARG_MAX or PROBABILISTIC options for retrieval with learning"""
-        arg = self.parameters.softmax_choice.get(kwargs[CONTEXT])
-        if arg in {ARG_MAX, PROBABILISTIC}:
+        softmax_choice = self.parameters.softmax_choice.get(kwargs[CONTEXT])
+        use_gating_for_weighting = self._use_gating_for_weighting
+        learn_field_weights = self.parameters.learn_field_weights.get(kwargs[CONTEXT])
+
+        if use_gating_for_weighting and learn_field_weights:
+            raise EMCompositionError(f"Field weights cannot be learned when 'use_gating_for_weighting' is True; "
+                                     f"construct '{self.name}' with the 'learn_field_weights' arg set to False.")
+
+        if softmax_choice in {ARG_MAX, PROBABILISTIC}:
             raise EMCompositionError(f"The ARG_MAX and PROBABILISTIC options for the 'softmax_choice' arg "
                                      f"of '{self.name}' cannot be used during learning; change to WEIGHTED_AVG.")
-        if self.loss_spec in {Loss.BINARY_CROSS_ENTROPY} and not self.normalize_field_weights:
-            raise EMCompositionError(f"The 'loss_spec' arg of '{self.name}' is set to '{self.loss_spec.name}' with "
-                                     f"'normalize_field_weights' set to False; this must be True to use this loss_spec.")
+
         return super().learn(*args, **kwargs)

     def _get_execution_mode(self, execution_mode):
diff --git a/psyneulink/library/compositions/pytorchwrappers.py b/psyneulink/library/compositions/pytorchwrappers.py
index 69a4c9a74dc..cdf13733d80 100644
--- a/psyneulink/library/compositions/pytorchwrappers.py
+++ b/psyneulink/library/compositions/pytorchwrappers.py
@@ -1123,7 +1123,8 @@ class PytorchProjectionWrapper():

     """

-    def __init__(self, projection,
+    def __init__(self,
+                 projection,
                  pnl_proj,
                  component_idx,
                  port_idx,
                  device,
@@ -1131,7 +1132,7 @@ def __init__(self, projection,
                  sender=None,
                  receiver=None,
                  context=None):

         self._projection = projection  # Projection being wrapped (may *not* be the one being learned; see note above)
-        self._pnl_proj = pnl_proj # Projection that directly projects to/from sender/receiver (see above)
+        self._pnl_proj = pnl_proj      # Projection that directly projects to/from sender/receiver (see above)
         self._idx = component_idx      # Index of Projection in Composition's list of projections
         self._port_idx = port_idx      # Index of sender's port (used by LLVM)
         self._value_idx = 0            # Index of value in sender's value (used in aggregate_afferents)
diff --git a/tests/composition/test_emcomposition.py b/tests/composition/test_emcomposition.py
index d0206a020a4..55c01ad7b51 100644
--- a/tests/composition/test_emcomposition.py
+++ b/tests/composition/test_emcomposition.py
@@ -185,16 +185,14 @@ def test_structure(self,
             assert isinstance(em.concatenate_queries_node, Mechanism) == concatenate_node
         if em.concatenate_queries:
             assert em.field_weight_nodes == []
-            assert bool(softmax_gain == CONTROL) == bool(len(em.softmax_gain_control_nodes))
+            assert bool(softmax_gain == CONTROL) == bool(em.softmax_gain_control_node)
         else:
             if num_keys > 1:
                 assert len(em.field_weight_nodes) == num_keys
             else:
                 assert em.field_weight_nodes == []
             if softmax_gain == CONTROL:
-                assert len(em.softmax_gain_control_nodes) == num_keys
-            else:
-                assert em.softmax_gain_control_nodes == []
+                assert em.softmax_gain_control_node
         assert len(em.retrieved_nodes) == num_fields

     def test_memory_fill(start,
memory_fill): @@ -252,23 +250,6 @@ def test_softmax_choice(self): f"'learn' method is called. Set 'softmax_choice' to WEIGHTED_AVG before learning.") assert warning_msg in str(warning[0].message) - def test_normalize_field_weights_with_learning_enabled(self): - with pytest.warns(UserWarning) as warning: - em = EMComposition(normalize_field_weights=False, - enable_learning=True, - memory_fill=(0,.1), - loss_spec=pnl.Loss.BINARY_CROSS_ENTROPY) - warning_msg = (f"The 'normalize_field_weights' arg of 'EM_Composition' is set to False with " - f"'enable_learning' set to True (or a list); this may generate an error if the " - f"'loss_spec' used for learning requires values to be between 0 and 1.") - assert warning_msg in str(warning[0].message) - - with pytest.raises(EMCompositionError) as error_text: - em.learn() - assert (f"The 'loss_spec' arg of 'EM_Composition' is set to 'BINARY_CROSS_ENTROPY' with " - f"'normalize_field_weights' set to False; this must be True to use this loss_spec." - in str(error_text.value)) - @pytest.mark.pytorch class TestExecution: @@ -304,11 +285,11 @@ class TestExecution: (4, [[[1,2,3],[4,6]], # Equal field_weights (but not concatenated) [[1,2,5],[4,6]], [[1,2,10],[4,6]]], (0,.01), 4, 0, [1,1], None, None, 100, 0, [[[1, 2, 3]], - [[4, 6]]], [[0.90323092, - 1.80586151, - 4.00008914], - [3.61161172, - 5.41731422]] + [[4, 6]]], [[0.99750462, + 1.99499376, + 3.51623568], + [3.98998465, + 5.9849743]] ), (5, [[[1,2,3],[4,6]], # Equal field_weights with concatenation [[1,2,5],[4,8]], @@ -321,44 +302,44 @@ class TestExecution: (6, [[[1,2,3],[4,6]], # Unequal field_weights [[1,2,5],[4,8]], [[1,2,10],[4,10]]], (0,.01), 4, 0, [9,1], None, None, 100, 0, [[[1, 2, 3]], - [[4, 6]]], [[0.96869477, - 1.93719534, - 3.1307577], - [3.87435467, - 6.02081578]]), + [[4, 6]]], [[0.99996025, + 1.99992024, + 3.19317783], + [3.99984044, + 6.19219795]]), (7, [[[1,2,3],[4,6]], # Store + no decay [[1,2,5],[4,8]], [[1,2,10],[4,10]]], (0,.01), 4, 0, [9,1], None, None, 100, 1, [[[1, 2, 3]], - [[4, 6]]], [[0.96869477, - 1.93719534, - 3.1307577], - [3.87435467, - 6.02081578]]), + [[4, 6]]], [[0.99996025, + 1.99992024, + 3.19317783], + [3.99984044, + 6.19219795]]), (8, [[[1,2,3],[4,6]], # Store + default decay (should be AUTO) [[1,2,5],[4,8]], [[1,2,10],[4,10]]], (0,.01), 4, None, [9,1], None, None, 100, 1, [[[1, 2, 3]], - [[4, 6]]], [[0.96869477, - 1.93719534, - 3.1307577 ], - [3.87435467, - 6.02081578]]), + [[4, 6]]], [[0.99996025, + 1.99992024, + 3.19317783], + [3.99984044, + 6.19219795]]), (9, [[[1,2,3],[4,6]], # Store + explicit AUTO decay [[1,2,5],[4,8]], [[1,2,10],[4,10]]], (0,.01), 4, AUTO, [9,1], None, None, 100, 1, [[[1, 2, 3]], - [[4, 6]]], [[0.96869477, - 1.93719534, - 3.1307577 ], - [3.87435467, - 6.02081578]]), + [[4, 6]]], [[0.99996025, + 1.99992024, + 3.19317783], + [3.99984044, + 6.19219795]]), (10, [[[1,2,3],[4,6]], # Store + numerical decay [[1,2,5],[4,8]], [[1,2,10],[4,10]]], (0,.01), 4, .1, [9,1], None, None, 100, 1, [[[1, 2, 3]], - [[4, 6]]], [[0.96869477, - 1.93719534, - 3.1307577 ], - [3.87435467, - 6.02081578]]), - (11, [[[1,2,3],[4,6]], # Same as 10, but with equal weights and concatenate keysdd + [[4, 6]]], [[0.99996025, + 1.99992024, + 3.19317783], + [3.99984044, + 6.19219795]]), + (11, [[[1,2,3],[4,6]], # Same as 10, but with equal weights and concatenate keys [[1,2,5],[4,8]], [[1,2,10],[4,10]]], (0,.01), 4, .1, [1,1], True, None, 100, 1, [[[1, 2, 3]], [[4, 6]]], [[0.99922544, @@ -366,6 +347,8 @@ class TestExecution: 3.38989346], [3.99689126, 6.38682264]]), +# [3.99984044, 
+# 6.19219795]]), ] args_names = "test_num, memory_template, memory_fill, memory_capacity, memory_decay_rate, field_weights, " \ @@ -376,7 +359,7 @@ class TestExecution: @pytest.mark.parametrize('enable_learning', [False, True], ids=['no_learning','learning']) @pytest.mark.composition @pytest.mark.parametrize('exec_mode', [pnl.ExecutionMode.Python, pnl.ExecutionMode.PyTorch]) - def test_simple_execution_witemhout_learning(self, + def test_simple_execution_without_learning(self, exec_mode, enable_learning, test_num, @@ -441,10 +424,10 @@ def test_simple_execution_witemhout_learning(self, np.testing.assert_allclose(retrieved, expected) # Validate that sum of weighted softmax distributions in field_weight_node itself sums to 1 - np.testing.assert_allclose(np.sum(em.combined_softmax_node.value), 1.0, atol=1e-15) + np.testing.assert_allclose(np.sum(em.softmax_node.value), 1.0, atol=1e-15) # Validate that sum of its output ports also sums to 1 - np.testing.assert_allclose(np.sum([port.value for port in em.combined_softmax_node.output_ports]), + np.testing.assert_allclose(np.sum([port.value for port in em.softmax_node.output_ports]), 1.0, atol=1e-15) # Validate storage diff --git a/tests/scheduling/test_condition.py b/tests/scheduling/test_condition.py index 182d85bfd7a..a823f534f29 100644 --- a/tests/scheduling/test_condition.py +++ b/tests/scheduling/test_condition.py @@ -25,7 +25,15 @@ class TestModule: def test_all_attr_parity(self): - missing = set(gs.condition.__all__) - set(pnl.core.scheduling.condition.__all__) + missing = { + c for c + in set(gs.condition.__all__) - set(pnl.core.scheduling.condition.__all__) + if ( + 'ConsiderationSetExecution' not in c + and 'EnvironmentStateUpdate' not in c + and 'EnvironmentSequence' not in c + ) + } assert len(missing) == 0, (f'Conditions in graph_scheduler must be added to psyneulink condition.py: {missing}') diff --git a/tests/scheduling/test_scheduler.py b/tests/scheduling/test_scheduler.py index 07feba1fcad..7b768c7849d 100644 --- a/tests/scheduling/test_scheduler.py +++ b/tests/scheduling/test_scheduler.py @@ -1574,80 +1574,49 @@ def test_time_termination_measures(self, comp_mode, timescale, expected): np.testing.assert_allclose(result, expected) @pytest.mark.composition - @pytest.mark.parametrize("condition,scale,expected_result", - [(pnl.BeforeNCalls, TimeScale.TRIAL, [[.05, .05]]), - (pnl.BeforeNCalls, TimeScale.PASS, [[.05, .05]]), - (pnl.EveryNCalls, None, [[0.05, .05]]), - (pnl.AtNCalls, TimeScale.TRIAL, [[.25, .25]]), - (pnl.AtNCalls, TimeScale.RUN, [[.25, .25]]), - (pnl.AfterNCalls, TimeScale.TRIAL, [[.25, .25]]), - (pnl.AfterNCalls, TimeScale.PASS, [[.05, .05]]), - (pnl.WhenFinished, None, [[1.0, 1.0]]), - (pnl.WhenFinishedAny, None, [[1.0, 1.0]]), - (pnl.WhenFinishedAll, None, [[1.0, 1.0]]), - (pnl.All, None, [[1.0, 1.0]]), - (pnl.Any, None, [[1.0, 1.0]]), - (pnl.Not, None, [[.05, .05]]), - (pnl.AllHaveRun, None, [[.05, .05]]), - (pnl.Always, None, [[0.05, 0.05]]), - (pnl.AtPass, None, [[.3, .3]]), - (pnl.AtTrial, None, [[0.05, 0.05]]), + @pytest.mark.parametrize("condition,condition_params,expected_result", + [(pnl.BeforeNCalls, {"time_scale": TimeScale.TRIAL, "n": 5}, [[.05, .05]]), + (pnl.BeforeNCalls, {"time_scale": TimeScale.PASS, "n": 5}, [[.05, .05]]), + (pnl.EveryNCalls, {"n": 1}, [[0.05, .05]]), + (pnl.AtNCalls, {"time_scale": TimeScale.TRIAL, "n": 5}, [[.25, .25]]), + (pnl.AtNCalls, {"time_scale": TimeScale.RUN, "n": 5}, [[.25, .25]]), + (pnl.AfterNCalls, {"time_scale": TimeScale.TRIAL, "n": 5}, [[.25, .25]]), + # Mechanisms 
run only once per PASS unless they are in 'run_until_finished' mode. + (pnl.AfterNCalls, {"time_scale": TimeScale.PASS, "n": 1}, [[.05, .05]]), + (pnl.WhenFinished, {}, [[1.0, 1.0]]), + (pnl.WhenFinishedAny, {}, [[1.0, 1.0]]), + (pnl.WhenFinishedAll, {}, [[1.0, 1.0]]), + (pnl.All, {}, [[1.0, 1.0]]), + (pnl.Any, {}, [[1.0, 1.0]]), + (pnl.Not, {}, [[.05, .05]]), + (pnl.AllHaveRun, {}, [[.05, .05]]), + (pnl.Always, {}, [[0.05, 0.05]]), + (pnl.AtPass, {"n": 5}, [[.3, .3]]), + (pnl.AtTrial, {"n": 0}, [[0.05, 0.05]]), #(pnl.Never), #TODO: Find a good test case for this! ]) - # 'LLVM' mode is not supported, because synchronization of compiler and - # python values during execution is not implemented. - @pytest.mark.usefixtures("comp_mode_no_llvm") - def test_scheduler_conditions(self, comp_mode, condition, scale, expected_result): - decisionMaker = pnl.DDM( - function=pnl.DriftDiffusionIntegrator(starting_value=0, - threshold=1, - noise=0.0, - time_step_size=1.0), - reset_stateful_function_when=pnl.AtTrialStart(), - execute_until_finished=False, - # Use only the decision variable in this test - output_ports=[pnl.DECISION_VARIABLE], - name='DDM') + def test_scheduler_conditions(self, comp_mode, condition, condition_params, expected_result): + decisionMaker = pnl.DDM(function=pnl.DriftDiffusionIntegrator(starting_value=0, + threshold=1, + noise=0.0, + time_step_size=1.0), + reset_stateful_function_when=pnl.AtTrialStart(), + execute_until_finished=False, + # Use only the decision variable in this test + output_ports=[pnl.DECISION_VARIABLE], + name='DDM') response = pnl.ProcessingMechanism(size=2, name="GATE") comp = pnl.Composition() comp.add_linear_processing_pathway([decisionMaker, response]) - if condition is pnl.BeforeNCalls: - comp.scheduler.add_condition(response, condition(decisionMaker, 5, - time_scale=scale)) - elif condition is pnl.AtNCalls: - comp.scheduler.add_condition(response, condition(decisionMaker, 5, - time_scale=scale)) - elif condition is pnl.AfterNCalls: - # Mechanisms run only once per PASS unless they are in - # 'run_until_finished' mode. 
- c = 1 if scale is TimeScale.PASS else 5 - comp.scheduler.add_condition(response, condition(decisionMaker, c, - time_scale=scale)) - elif condition is pnl.EveryNCalls: - comp.scheduler.add_condition(response, condition(decisionMaker, 1)) - elif condition is pnl.WhenFinished: - comp.scheduler.add_condition(response, condition(decisionMaker)) - elif condition is pnl.WhenFinishedAny: - comp.scheduler.add_condition(response, condition(decisionMaker)) - elif condition is pnl.WhenFinishedAll: - comp.scheduler.add_condition(response, condition(decisionMaker)) - elif condition is pnl.All: - comp.scheduler.add_condition(response, condition(pnl.WhenFinished(decisionMaker))) - elif condition is pnl.Any: - comp.scheduler.add_condition(response, condition(pnl.WhenFinished(decisionMaker))) - elif condition is pnl.Not: + if condition in {pnl.All, pnl.Any, pnl.Not}: comp.scheduler.add_condition(response, condition(pnl.WhenFinished(decisionMaker))) - elif condition is pnl.AllHaveRun: - comp.scheduler.add_condition(response, condition(decisionMaker)) - elif condition is pnl.Always: - comp.scheduler.add_condition(response, condition()) - elif condition is pnl.AtPass: - comp.scheduler.add_condition(response, condition(5)) - elif condition is pnl.AtTrial: - comp.scheduler.add_condition(response, condition(0)) + elif condition in {pnl.Always, pnl.Never, pnl.AtPass, pnl.AtTrial}: + comp.scheduler.add_condition(response, condition(**condition_params)) + else: + comp.scheduler.add_condition(response, condition(decisionMaker, **condition_params)) result = comp.run([0.05], execution_mode=comp_mode) np.testing.assert_allclose(result, expected_result) @@ -1662,16 +1631,12 @@ def test_scheduler_conditions(self, comp_mode, condition, scale, expected_result [(pnl.AtTrial, None, [[[1.0]], [[2.0]]]), ]) def test_run_term_conditions(self, mode, condition, scale, expected_result): - incrementing_mechanism = pnl.ProcessingMechanism( - function=pnl.SimpleIntegrator - ) - comp = pnl.Composition( - pathways=[incrementing_mechanism] - ) - comp.scheduler.termination_conds = { - pnl.TimeScale.RUN: condition(2) - } + incrementing_mechanism = pnl.ProcessingMechanism(function=pnl.SimpleIntegrator) + comp = pnl.Composition(pathways=[incrementing_mechanism]) + + comp.scheduler.termination_conds = {pnl.TimeScale.RUN: condition(2)} r = comp.run(inputs=[1], num_trials=5, execution_mode=mode) + np.testing.assert_allclose(r, expected_result[-1]) np.testing.assert_allclose(comp.results, expected_result)
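The refactored test above replaces the long if/elif chain with a parametrized kwargs dict. The dispatch it relies on can be summarized as follows; this is a sketch for illustration only, using the same pnl names as the test, and make_condition is a hypothetical helper, not part of the patch:

import psyneulink as pnl

def make_condition(condition, node, params):
    """Build a scheduler Condition the way the refactored test does (hypothetical helper)."""
    if condition in {pnl.All, pnl.Any, pnl.Not}:
        # composite conditions wrap another Condition rather than a node
        return condition(pnl.WhenFinished(node))
    if condition in {pnl.Always, pnl.Never, pnl.AtPass, pnl.AtTrial}:
        # these take no component argument, only keyword parameters
        return condition(**params)
    # the remaining conditions monitor a specific component
    return condition(node, **params)

For example, make_condition(pnl.AfterNCalls, decisionMaker, {"time_scale": pnl.TimeScale.TRIAL, "n": 5}) reproduces comp.scheduler.add_condition(response, pnl.AfterNCalls(decisionMaker, time_scale=pnl.TimeScale.TRIAL, n=5)) from the parametrized test.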