Skip to content

Commit

Permalink
Component: set _needs_recompile flag for execution_id on parameter sh…
Browse files Browse the repository at this point in the history
…ape change
  • Loading branch information
kmantel committed Jan 30, 2024
1 parent b595ed4 commit 8eedec0
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 1 deletion.
15 changes: 15 additions & 0 deletions psyneulink/core/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -2616,6 +2616,21 @@ def _set_multiple_parameter_values(self, context, **kwargs):
for (k, v) in kwargs.items():
getattr(self.parameters, k)._set(v, context)

def _record_parameter_shapes(self, context, visited=None):
if visited is None:
visited = set([self])

if context is None:
context = self.most_recent_context

for p in self.parameters:
p._record_shape(context)

for obj in self._dependent_components:
if obj not in visited:
visited.add(obj)
obj._record_parameter_shapes(context, visited=visited)

# ------------------------------------------------------------------------------------------------------------------
# Parsing methods
# ------------------------------------------------------------------------------------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions psyneulink/core/compositions/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -13156,6 +13156,9 @@ def _gen_llvm_function(self, *, ctx:pnlvm.LLVMBuilderContext, tags:frozenset):
else:
return pnlvm.codegen.gen_composition_exec(ctx, self, tags=tags)

def _delete_compilation_data(self, context):
self._compilation_data.execution.delete(context)

def enable_logging(self):
for item in self.nodes + self.projections:
if isinstance(item, Composition):
Expand Down
25 changes: 24 additions & 1 deletion psyneulink/core/globals/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def _recurrent_transfer_mechanism_matrix_setter(value, owning_component=None, co
from psyneulink.core.globals.context import Context, ContextError, ContextFlags, _get_time, handle_external_context
from psyneulink.core.globals.context import time as time_object
from psyneulink.core.globals.log import LogCondition, LogEntry, LogError
from psyneulink.core.globals.utilities import call_with_pruned_args, convert_all_elements_to_np_array, copy_iterable_with_shared, \
from psyneulink.core.globals.utilities import call_with_pruned_args, convert_all_elements_to_np_array, copy_iterable_with_shared, extended_shape, \
get_alias_property_getter, get_alias_property_setter, get_deepcopy_with_shared, is_numeric, unproxy_weakproxy, create_union_set, safe_equals, get_function_sig_default_value, try_extract_0d_array_item
from psyneulink.core.rpc.graph_pb2 import Entry, ndArray

Expand Down Expand Up @@ -1003,6 +1003,7 @@ def __init__(
_inherited_source=None,
_user_specified=False,
_scalar_converted=False,
_shape=None,
**kwargs
):
if isinstance(aliases, str):
Expand Down Expand Up @@ -1065,6 +1066,7 @@ def __init__(
_user_specified=_user_specified,
_temp_uninherited=set(),
_scalar_converted=_scalar_converted,
_shape=_shape,
**kwargs
)

Expand Down Expand Up @@ -1537,6 +1539,10 @@ def _set(self, value, context, skip_history=False, skip_log=False, **kwargs):
return value

def _set_value(self, value, execution_id=None, context=None, skip_history=False, skip_log=False, skip_delivery=False):
# before compilation has happened, we don't need to track shape changes
if self._shape is not None:
self._record_shape(context, value)

# store history
if not skip_history:
if execution_id in self.values:
Expand Down Expand Up @@ -1765,6 +1771,23 @@ def _initialize_from_context(self, context=None, base_context=Context(execution_
except ParameterError as e:
raise ParameterError('Error when attempting to initialize from {0}: {1}'.format(base_context.execution_id, e))

def _record_shape(self, context, value=NotImplemented):
if value is NotImplemented:
value = self._get(context)

try:
new_shape = extended_shape(value)
except TypeError:
new_shape = type(value)

if self._shape is not None and new_shape != self._shape:
self._shape = new_shape
try:
for comp in self._owner._owner.compositions:
comp._delete_compilation_data(context)
except AttributeError:
pass

# KDM 7/30/18: the below is weird like this in order to use this like a property, but also include it
# in the interface for user simplicity: that is, inheritable (by this Parameter's children or from its parent),
# visible in a Parameter's repr, and easily settable by the user
Expand Down
2 changes: 2 additions & 0 deletions psyneulink/core/llvm/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,8 @@ def __init__(self, composition, execution_ids=[None], *, additional_tags=frozens

@staticmethod
def get(composition, context, additional_tags=frozenset()):
composition._record_parameter_shapes(context)

executions = composition._compilation_data.execution._get(context)
if executions is None:
executions = dict()
Expand Down

0 comments on commit 8eedec0

Please sign in to comment.