
Commit

…Link into feat/pec_cond_all
davidt0x committed Nov 4, 2024
2 parents efa2875 + 68752cc commit 13ab7b9
Showing 19 changed files with 1,026 additions and 1,033 deletions.
@@ -50,14 +50,14 @@ def calc_prob(em_preds, test_ys):
previous_state_d = 11, # length of state vector
context_d = 11, # length of context vector
memory_capacity = ALL, # number of entries in EM memory; ALL=> match to number of stims
-memory_init = (0,.0001), # Initialize memory with random values in interval
-# memory_init = None, # Initialize with zeros
+# memory_init = (0,.0001), # Initialize memory with random values in interval
+memory_init = None, # Initialize with zeros
concatenate_queries = False,
# concatenate_queries = True,

# environment
-# curriculum_type = 'Interleaved',
-curriculum_type = 'Blocked',
+curriculum_type = 'Interleaved',
+# curriculum_type = 'Blocked',
# num_stims = 100, # Integer or ALL
num_stims = ALL, # Integer or ALL

@@ -363,7 +363,7 @@ def construct_model(model_name:str=model_params['name'],
if RUN_MODEL:
import timeit
def print_stuff(**kwargs):
-print(f"\n**************\n BATCH: {kwargs['batch']}\n**************\n")
+print(f"\n**************\n BATCH: {kwargs['minibatch']}\n**************\n")
print(kwargs)
print('\nContext internal: \n', model.nodes['CONTEXT'].function.parameters.value.get(kwargs['context']))
print('\nContext hidden: \n', model.nodes['CONTEXT'].parameters.value.get(kwargs['context']))
@@ -407,8 +407,8 @@ def print_stuff(**kwargs):
)
stop_time = timeit.default_timer()
print(f"Elapsed time: {stop_time - start_time}")
-if DISPLAY_MODEL is not None:
-model.show_graph(**DISPLAY_MODEL)
+# if DISPLAY_MODEL is not None:
+# model.show_graph(**DISPLAY_MODEL)
if PRINT_RESULTS:
print("MEMORY:")
print(np.round(model.nodes['EM'].parameters.memory.get(model.name),3))
@@ -450,7 +450,7 @@ def eval_weights(weight_mat):
axes[1].set_xlabel('Stimuli')
axes[1].set_ylabel(model_params['loss_spec'])
# Logit of loss
-axes[2].plot( (model.results[1:TOTAL_NUM_STIMS,2]*TARGETS[:TOTAL_NUM_STIMS-1]).sum(-1) )
+axes[2].plot( (model.results[2:TOTAL_NUM_STIMS,2]*TARGETS[:TOTAL_NUM_STIMS-2]).sum(-1) )
axes[2].set_xlabel('Stimuli')
axes[2].set_ylabel('Correct Logit')
plt.suptitle(f"{model_params['curriculum_type']} Training")
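A quick note on the slice change above: offsetting both operands by the same amount keeps the results and targets arrays the same length, so the elementwise product and the final sum over the class axis stay well defined. A minimal sketch with stand-in arrays (shapes and names here are illustrative, not the script's actual data):

import numpy as np

TOTAL_NUM_STIMS = 10
results = np.random.rand(TOTAL_NUM_STIMS, 3, 5)                  # stand-in for model.results
TARGETS = np.eye(5)[np.random.randint(0, 5, TOTAL_NUM_STIMS)]    # stand-in one-hot targets

# Dropping the first two results and the last two targets leaves both
# operands with length TOTAL_NUM_STIMS - 2, so the product aligns row for row.
correct_logit = (results[2:TOTAL_NUM_STIMS, 2] * TARGETS[:TOTAL_NUM_STIMS - 2]).sum(-1)
assert correct_logit.shape == (TOTAL_NUM_STIMS - 2,)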
@@ -2,7 +2,7 @@
import torch
from torch.utils.data import dataset
from torch import utils
-from numpy.random import randint
+from random import randint

def one_hot_encode(labels, num_classes):
"""
@@ -24,6 +24,6 @@
PRINT_RESULTS = False # don't print model.results to console after execution
# PRINT_RESULTS = True # print model.results to console after execution
SAVE_RESULTS = False # save model.results to disk
-PLOT_RESULTS = False # don't plot results (PREDICTIONS) vs. TARGETS
-# PLOT_RESULTS = True # plot results (PREDICTIONS) vs. TARGETS
+# PLOT_RESULTS = False # don't plot results (PREDICTIONS) vs. TARGETS
+PLOT_RESULTS = True # plot results (PREDICTIONS) vs. TARGETS
ANIMATE = False # {UNIT:EXECUTION_SET} # Specifies whether to generate animation of execution
30 changes: 13 additions & 17 deletions conftest.py
@@ -155,28 +155,24 @@ def comp_mode_no_llvm():
    # dummy fixture to allow 'comp_mode' filtering
    pass

-class FirstBench():
-    def __init__(self, benchmark):
-        super().__setattr__("benchmark", benchmark)
+@pytest.fixture
+def benchmark(benchmark):

-    def __call__(self, f, *args, **kwargs):
-        res = []
-        # Compute the first result if benchmark is enabled
-        if self.benchmark.enabled:
-            res.append(f(*args, **kwargs))
+    orig_class = type(benchmark)

-        res.append(self.benchmark(f, *args, **kwargs))
-        return res[0]
+    class _FirstBench(orig_class):
+        def __call__(self, f, *args, **kwargs):
+            res = []
+            # Compute the first result if benchmark is enabled
+            if self.enabled:
+                res.append(f(*args, **kwargs))

-    def __getattr__(self, attr):
-        return getattr(self.benchmark, attr)
+            res.append(orig_class.__call__(self, f, *args, **kwargs))
+            return res[0]

-    def __setattr__(self, attr, val):
-        return setattr(self.benchmark, attr, val)
+    benchmark.__class__ = _FirstBench

-@pytest.fixture
-def benchmark(benchmark):
-    return FirstBench(benchmark)
+    return benchmark

@pytest.helpers.register
def llvm_current_fp_precision():
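The rewritten fixture wraps the plugin-provided benchmark object by subclassing its runtime class and reassigning __class__, instead of proxying every attribute through __getattr__/__setattr__ as the old FirstBench did. A standalone sketch of that idiom, using a made-up Timer class rather than anything from pytest-benchmark:

class Timer:
    enabled = True
    def __call__(self, f, *args, **kwargs):
        return f(*args, **kwargs)

def wrap_call(obj):
    orig_class = type(obj)

    class _Wrapped(orig_class):
        def __call__(self, f, *args, **kwargs):
            # Extra behavior runs first; the original __call__ still executes.
            print("before original __call__")
            return orig_class.__call__(self, f, *args, **kwargs)

    # Swapping __class__ mutates the same instance (identity and state are kept),
    # so other code holding a reference to it sees the new behavior too.
    obj.__class__ = _Wrapped
    return obj

t = wrap_call(Timer())
print(t(sum, [1, 2, 3]))   # prints "before original __call__", then 6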
2 changes: 1 addition & 1 deletion dev_requirements.txt
@@ -1,7 +1,7 @@
jupyter<1.1.2
packaging<25.0
pytest<8.3.4
-pytest-benchmark<4.0.1
+pytest-benchmark<5.1.1
pytest-cov<5.0.1
pytest-forked<1.7.0
pytest-helpers-namespace<2021.12.30
@@ -28,7 +28,7 @@
--------
A GatingMechanism is a subclass of `ControlMechanism` that is restricted to using only `GatingSignals <GatingSignal>`,
-which modulate the `input <Mechanism_InputPorts>` or `output <Mechanism_InputPorts>` of a `Mechanism <Mechanism>`,
+which modulate the `input <Mechanism_InputPorts>` or `output <Mechanism_OutputPorts>` of a `Mechanism <Mechanism>`,
but not the paramaters of its `function <Mechanism_Base.function>`. Accordingly, its constructor has a **gate**
argument in place of a **control** argument. It also lacks several attributes related to control, including those
related to costs and net_outcome. In all other respects it is identical to its parent class, ControlMechanism.
@@ -58,7 +58,7 @@
*Specifying gating*
~~~~~~~~~~~~~~~~~~~
-A GatingMechanism is used to modulate the value of an `InputPort` or `OutputPort`. An InputPort or OutputPort can
+A GatingMechanism is used to modulate the value of an `InputPort` or `InputPort`. An InputPort or OutputPort can
be specified for gating by assigning it a `GatingProjection` or `GatingSignal` anywhere that the Projections to a Port
or its `ModulatorySignals can be specified <State_Creation>`. A `Mechanism <Mechanism>` can also be specified for
gating, in which case the `primary InputPort <InputPort_Primary>` of the specified Mechanism is used. Ports
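For readers skimming the docstring change above, a minimal usage sketch of the gating behavior it describes, built around the gate argument the docstring names (mechanism names here are illustrative, and constructor options may vary across PsyNeuLink versions):

import psyneulink as pnl

# Mechanism whose primary InputPort will be gated.
my_mech = pnl.TransferMechanism(name='MY MECH')

# Per the docstring, specifying a Mechanism in `gate` gates its primary InputPort;
# a specific InputPort or OutputPort could be passed instead.
gating_mech = pnl.GatingMechanism(name='GATING MECH', gate=[my_mech])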
18 changes: 6 additions & 12 deletions psyneulink/core/compositions/composition.py
@@ -12948,22 +12948,16 @@ def disable_all_history(self):
        self._set_all_parameter_properties_recursively(history_max_length=0)

    def _get_processing_condition_set(self, node):
-        dep_group = []
-        for group in self.scheduler.consideration_queue:
+        for index, group in enumerate(self.scheduler.consideration_queue):
            if node in group:
                break
-            dep_group = group

-        # This condition is used to check of the step has passed.
-        # Not all nodes in the previous step need to execute
-        # (they might have other conditions), but if any one does we're good
-        # FIXME: This will fail if none of the previously considered
-        # nodes executes in this pass, but that is unlikely.
-        conds = [Any(*(AllHaveRun(dep, time_scale=TimeScale.PASS) for dep in dep_group))] if len(dep_group) else []

+        assert index is not None

        if node in self.scheduler.conditions:
-            conds.append(self.scheduler.conditions[node])
+            return index, self.scheduler.conditions[node]

-        return All(*conds)
+        return index, Always()

    def _input_matches_variable(self, input_value, var):
        var_shape = convert_to_np_array(var).shape
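The enumerate/break pattern in the new version relies on standard Python scoping: the loop variable keeps the value from the iteration where break fired, and it is never rebound afterwards. A small, PsyNeuLink-independent illustration, with the for/else idiom shown as one common way to detect that no break occurred:

queue = [{'A', 'B'}, {'C'}, {'D', 'E'}]   # stand-in for a consideration queue

for index, group in enumerate(queue):
    if 'C' in group:
        break
print(index)          # 1 -- the index where break fired is still bound after the loop

for index, group in enumerate(queue):
    if 'Z' in group:
        break
else:
    index = None      # the else clause runs only when the loop completes without break
print(index)          # None -- 'Z' was not found in any group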