Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EpisodicMemoryMechanism: make memory (_memory_init) a FunctionParameter #2601

Merged
merged 2 commits into from
Feb 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

"""

import copy
import numbers
import warnings
from collections import deque
Expand Down Expand Up @@ -2541,8 +2542,8 @@ def reset(self, previous_value=None, context=None):
previous_value = self._get_current_parameter_value("initializer", context)

if previous_value == []:
self.parameters.previous_value._get(context).clear()
value = np.ndarray(shape=(2, 0, len(self.defaults.variable[0])))
self.parameters.previous_value._set(copy.deepcopy(value), context)

else:
value = self._initialize_previous_value(previous_value, context=context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@


"""
import copy
import warnings
from typing import Optional, Union

Expand All @@ -416,7 +417,7 @@
from psyneulink.core.components.mechanisms.processing.processingmechanism import ProcessingMechanism_Base
from psyneulink.core.components.ports.inputport import InputPort
from psyneulink.core.globals.keywords import EPISODIC_MEMORY_MECHANISM, INITIALIZER, NAME, OWNER_VALUE, VARIABLE
from psyneulink.core.globals.parameters import Parameter, check_user_specified
from psyneulink.core.globals.parameters import FunctionParameter, Parameter, check_user_specified
from psyneulink.core.globals.preferences.basepreferenceset import is_pref_set
from psyneulink.core.globals.utilities import deprecation_warning, convert_to_np_array, convert_all_elements_to_np_array

Expand Down Expand Up @@ -508,6 +509,13 @@ class Parameters(ProcessingMechanism_Base.Parameters):
"""
variable = Parameter([[0,0]], pnl_internal=True, constructor_argument='default_variable')
function = Parameter(ContentAddressableMemory, stateful=False, loggable=False)
memory = FunctionParameter(None, function_parameter_name='initializer')

def _parse_memory(self, memory):
if memory is None:
return memory

return ContentAddressableMemory._enforce_memory_shape(memory)

@check_user_specified
def __init__(self,
Expand Down Expand Up @@ -538,15 +546,14 @@ def __init__(self,
size += kwargs['assoc_size']
kwargs.pop('assoc_size')

self._memory_init = memory

super().__init__(
default_variable=default_variable,
size=size,
function=function,
params=params,
name=name,
prefs=prefs,
memory=memory,
**kwargs
)

Expand All @@ -564,18 +571,15 @@ def _handle_default_variable(self, default_variable=None, size=None, input_ports
variable_shape = convert_all_elements_to_np_array(default_variable).shape \
if default_variable is not None else None
function_instance = self.function if isinstance(self.function, Function) else None
function_type = self.function if isinstance(self.function, type) else self.function.__class__

# **memory** arg is specified in constructor, so use that to initialize or validate default_variable
if self._memory_init:
try:
self._memory_init = function_type._enforce_memory_shape(self._memory_init)
except:
pass
if self.parameters.memory._user_specified:
memory = self.defaults.memory

if default_variable is None:
default_variable = self._memory_init[0]
default_variable = copy.deepcopy(memory[0])
else:
entry_shape = convert_all_elements_to_np_array(self._memory_init[0]).shape
entry_shape = convert_all_elements_to_np_array(memory[0]).shape
if entry_shape != variable_shape:
raise EpisodicMemoryMechanismError(f"Shape of 'variable' for {self.name} ({variable_shape}) "
f"does not match the shape of entries ({entry_shape}) in "
Expand Down Expand Up @@ -610,14 +614,9 @@ def _instantiate_input_ports(self, context=None):

def _instantiate_function(self, function, function_params, context):
"""Assign memory to function if specified in Mechanism's constructor"""
if self._memory_init is not None:
if isinstance(function, type):
function_params.update({INITIALIZER:self._memory_init})
else:
if len(function.memory):
warnings.warn(f"The 'memory' argument specified for {self.name} will override the specification "
f"for the {repr(INITIALIZER)} argument of its function ({self.function.name}).")
function.reset(self._memory_init)
memory = self.parameters.memory._get(context)
if memory is not None:
function.reset(memory)
super()._instantiate_function(function, function_params, context)

def _instantiate_output_ports(self, context=None):
Expand Down
7 changes: 5 additions & 2 deletions tests/mechanisms/test_episodic_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,11 @@ def test_with_contentaddressablememory(name, func, func_params, mech_params, tes
def test_contentaddressable_memory_warnings_and_errors():

# both memory arg of Mechanism and initializer for its function are specified
text = "The 'memory' argument specified for EpisodicMemoryMechanism-0 will override the specification " \
"for the 'initializer' argument of its function"
text = (
r"Specification of the \"memory\" parameter[.\S\s]*The value"
+ r" specified on \(ContentAddressableMemory ContentAddressableMemory"
+ r" Function-\d\) will be used\."
)
with pytest.warns(UserWarning, match=text):
em = EpisodicMemoryMechanism(
memory = [[[1,2,3],[4,5,6]]],
Expand Down