Skip to content

Commit

Permalink
Merge pull request #2601 from kmantel/emm
Browse files Browse the repository at this point in the history
EpisodicMemoryMechanism: make memory (_memory_init) a FunctionParameter
  • Loading branch information
kmantel authored Feb 11, 2023
2 parents b294d61 + b6c03ed commit bc85b0a
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"""

import copy
import numbers
import warnings
from collections import deque
Expand Down Expand Up @@ -2541,8 +2542,8 @@ def reset(self, previous_value=None, context=None):
previous_value = self._get_current_parameter_value("initializer", context)

if previous_value == []:
self.parameters.previous_value._get(context).clear()
value = np.ndarray(shape=(2, 0, len(self.defaults.variable[0])))
self.parameters.previous_value._set(copy.deepcopy(value), context)

else:
value = self._initialize_previous_value(previous_value, context=context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@
"""
import copy
import warnings
from typing import Optional, Union

Expand All @@ -416,7 +417,7 @@
from psyneulink.core.components.mechanisms.processing.processingmechanism import ProcessingMechanism_Base
from psyneulink.core.components.ports.inputport import InputPort
from psyneulink.core.globals.keywords import EPISODIC_MEMORY_MECHANISM, INITIALIZER, NAME, OWNER_VALUE, VARIABLE
from psyneulink.core.globals.parameters import Parameter, check_user_specified
from psyneulink.core.globals.parameters import FunctionParameter, Parameter, check_user_specified
from psyneulink.core.globals.preferences.basepreferenceset import is_pref_set
from psyneulink.core.globals.utilities import deprecation_warning, convert_to_np_array, convert_all_elements_to_np_array

Expand Down Expand Up @@ -508,6 +509,13 @@ class Parameters(ProcessingMechanism_Base.Parameters):
"""
variable = Parameter([[0,0]], pnl_internal=True, constructor_argument='default_variable')
function = Parameter(ContentAddressableMemory, stateful=False, loggable=False)
memory = FunctionParameter(None, function_parameter_name='initializer')

def _parse_memory(self, memory):
if memory is None:
return memory

return ContentAddressableMemory._enforce_memory_shape(memory)

@check_user_specified
def __init__(self,
Expand Down Expand Up @@ -538,15 +546,14 @@ def __init__(self,
size += kwargs['assoc_size']
kwargs.pop('assoc_size')

self._memory_init = memory

super().__init__(
default_variable=default_variable,
size=size,
function=function,
params=params,
name=name,
prefs=prefs,
memory=memory,
**kwargs
)

Expand All @@ -564,18 +571,15 @@ def _handle_default_variable(self, default_variable=None, size=None, input_ports
variable_shape = convert_all_elements_to_np_array(default_variable).shape \
if default_variable is not None else None
function_instance = self.function if isinstance(self.function, Function) else None
function_type = self.function if isinstance(self.function, type) else self.function.__class__

# **memory** arg is specified in constructor, so use that to initialize or validate default_variable
if self._memory_init:
try:
self._memory_init = function_type._enforce_memory_shape(self._memory_init)
except:
pass
if self.parameters.memory._user_specified:
memory = self.defaults.memory

if default_variable is None:
default_variable = self._memory_init[0]
default_variable = copy.deepcopy(memory[0])
else:
entry_shape = convert_all_elements_to_np_array(self._memory_init[0]).shape
entry_shape = convert_all_elements_to_np_array(memory[0]).shape
if entry_shape != variable_shape:
raise EpisodicMemoryMechanismError(f"Shape of 'variable' for {self.name} ({variable_shape}) "
f"does not match the shape of entries ({entry_shape}) in "
Expand Down Expand Up @@ -610,14 +614,9 @@ def _instantiate_input_ports(self, context=None):

def _instantiate_function(self, function, function_params, context):
"""Assign memory to function if specified in Mechanism's constructor"""
if self._memory_init is not None:
if isinstance(function, type):
function_params.update({INITIALIZER:self._memory_init})
else:
if len(function.memory):
warnings.warn(f"The 'memory' argument specified for {self.name} will override the specification "
f"for the {repr(INITIALIZER)} argument of its function ({self.function.name}).")
function.reset(self._memory_init)
memory = self.parameters.memory._get(context)
if memory is not None:
function.reset(memory)
super()._instantiate_function(function, function_params, context)

def _instantiate_output_ports(self, context=None):
Expand Down
7 changes: 5 additions & 2 deletions tests/mechanisms/test_episodic_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,11 @@ def test_with_contentaddressablememory(name, func, func_params, mech_params, tes
def test_contentaddressable_memory_warnings_and_errors():

# both memory arg of Mechanism and initializer for its function are specified
text = "The 'memory' argument specified for EpisodicMemoryMechanism-0 will override the specification " \
"for the 'initializer' argument of its function"
text = (
r"Specification of the \"memory\" parameter[.\S\s]*The value"
+ r" specified on \(ContentAddressableMemory ContentAddressableMemory"
+ r" Function-\d\) will be used\."
)
with pytest.warns(UserWarning, match=text):
em = EpisodicMemoryMechanism(
memory = [[[1,2,3],[4,5,6]]],
Expand Down

0 comments on commit bc85b0a

Please sign in to comment.