Skip to content

Commit

Permalink
Fix/input port combine (#2755)
Browse files Browse the repository at this point in the history
• inputport.py
  _parse_port_specific_specs(): fix bug in which COMBINE was not parsed when specified in an InputPort specification dict (though still some weirdness in passing spec through to constructor, requiring function assignment in place)

• port.py
  _parse_port_spec():  add passing of Context.string for local handling based on caller (e.g., warning messages)
  • Loading branch information
jdcpni authored Jul 29, 2023
1 parent 4c22a0a commit 2890038
Show file tree
Hide file tree
Showing 12 changed files with 133 additions and 66 deletions.
12 changes: 7 additions & 5 deletions psyneulink/core/components/mechanisms/mechanism.py
Original file line number Diff line number Diff line change
Expand Up @@ -1865,9 +1865,9 @@ def _handle_arg_input_ports(self, input_ports):

try:
parsed_input_port_spec = _parse_port_spec(owner=self,
port_type=InputPort,
port_spec=s,
)
port_type=InputPort,
port_spec=s,
context=Context(string='handle_arg_input_ports'))
except AttributeError as e:
if DEFER_VARIABLE_SPEC_TO_MECH_MSG in e.args[0]:
default_variable_from_input_ports.append(InputPort.defaults.variable)
Expand Down Expand Up @@ -1980,9 +1980,11 @@ def _validate_params(self, request_set, target_set=None, context=None):
try:
try:
for port_spec in params[INPUT_PORTS]:
_parse_port_spec(owner=self, port_type=InputPort, port_spec=port_spec)
_parse_port_spec(owner=self, port_type=InputPort, port_spec=port_spec,
context=Context(string='mechanism.validate_params'))
except TypeError:
_parse_port_spec(owner=self, port_type=InputPort, port_spec=params[INPUT_PORTS])
_parse_port_spec(owner=self, port_type=InputPort, port_spec=params[INPUT_PORTS],
context=Context(string='mechanism.validate_params'))
except AttributeError as e:
if DEFER_VARIABLE_SPEC_TO_MECH_MSG in e.args[0]:
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,7 @@
OBJECTIVE_MECHANISM, OUTCOME, OWNER_VALUE, PARAMS, PORT_TYPE, PRODUCT, PROJECTION_TYPE, PROJECTIONS, \
SEPARATE, SIZE
from psyneulink.core.globals.parameters import Parameter, check_user_specified
from psyneulink.core.globals.context import Context
from psyneulink.core.globals.preferences.basepreferenceset import ValidPrefSet
from psyneulink.core.globals.preferences.preferenceset import PreferenceLevel
from psyneulink.core.globals.utilities import ContentAddressableList, convert_all_elements_to_np_array, convert_to_list, convert_to_np_array
Expand Down Expand Up @@ -672,24 +673,17 @@ def __init__(self, message, data=None):


def validate_monitored_port_spec(owner, spec_list):
context = Context(string='ControlMechanism.validate_monitored_port_spec')
for spec in spec_list:
if isinstance(spec, MonitoredOutputPortTuple):
spec = spec.output_port
elif isinstance(spec, tuple):
spec = _parse_port_spec(
owner=owner,
port_type=InputPort,
port_spec=spec,
)
spec = _parse_port_spec(owner=owner, port_type=InputPort, port_spec=spec, context=context)
spec = spec['params'][PROJECTIONS][0][0]
elif isinstance(spec, dict):
# If it is a dict, parse to validate that it is an InputPort specification dict
# (for InputPort of ObjectiveMechanism to be assigned to the monitored_output_port)
spec = _parse_port_spec(
owner=owner,
port_type=InputPort,
port_spec=spec,
)
spec = _parse_port_spec(owner=owner, port_type=InputPort, port_spec=spec, context=context)
# Get the OutputPort, to validate that it is in the ControlMechanism's Composition (below);
# presumes that the monitored_output_port is the first in the list of projection_specs
# in the InputPort port specification dictionary returned from the parse,
Expand Down Expand Up @@ -1263,15 +1257,10 @@ def _validate_output_ports(self, control):

port_types = self._owner.outputPortTypes
for ctl_spec in control:
ctl_spec = _parse_port_spec(
port_type=port_types, owner=self._owner, port_spec=ctl_spec
)
if not (
isinstance(ctl_spec, port_types)
or (
isinstance(ctl_spec, dict) and ctl_spec[PORT_TYPE] == port_types
)
):
ctl_spec = _parse_port_spec(port_type=port_types, owner=self._owner, port_spec=ctl_spec,
context=Context(string='ControlMechanism._validate_input_ports'))
if not (isinstance(ctl_spec, port_types)
or (isinstance(ctl_spec, dict) and ctl_spec[PORT_TYPE] == port_types)):
return 'invalid port specification'

# FIX 5/28/20:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2581,7 +2581,8 @@ def _validate_entries(spec=None, source=None):
self.state_feature_specs[i] = spec

# Get InputPort specification dictionary for state_input_port and update its entries
parsed_spec = _parse_port_spec(owner=self, port_type=InputPort, port_spec=spec)
parsed_spec = _parse_port_spec(owner=self, port_type=InputPort, port_spec=spec,
context=Context(string='OptimizationControlMechanism._parse_specs'))
parsed_spec[NAME] = state_input_port_names[i]
if parsed_spec[PARAMS] and SHADOW_INPUTS in parsed_spec[PARAMS]:
# Composition._update_shadow_projections will take care of PROJECTIONS specification
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@
from psyneulink.core.components.ports.modulatorysignals.learningsignal import LearningSignal
from psyneulink.core.components.ports.parameterport import ParameterPort
from psyneulink.core.components.shellclasses import Mechanism
from psyneulink.core.globals.context import ContextFlags, handle_external_context
from psyneulink.core.globals.context import Context, ContextFlags, handle_external_context
from psyneulink.core.globals.keywords import \
ADDITIVE, ASSERT, ENABLED, INPUT_PORTS, \
LEARNING, LEARNING_MECHANISM, LEARNING_PROJECTION, LEARNING_SIGNAL, LEARNING_SIGNALS, MATRIX, \
Expand Down Expand Up @@ -1161,7 +1161,8 @@ def _validate_params(self, request_set, target_set=None, context=None):
format(LEARNING_SIGNAL, self.name))

for spec in target_set[LEARNING_SIGNALS]:
learning_signal = _parse_port_spec(port_type=LearningSignal, owner=self, port_spec=spec)
learning_signal = _parse_port_spec(port_type=LearningSignal, owner=self, port_spec=spec,
context=Context(string='LearningMechanism.validate_params'))

# Validate that the receiver of the LearningProjection (if specified)
# is a MappingProjection and in the same Composition as self (if specified)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@
from psyneulink.core.components.ports.inputport import InputPort, INPUT_PORT
from psyneulink.core.components.ports.outputport import OutputPort
from psyneulink.core.components.ports.port import _parse_port_spec
from psyneulink.core.globals.context import ContextFlags, handle_external_context
from psyneulink.core.globals.context import Context, ContextFlags, handle_external_context
from psyneulink.core.globals.keywords import \
CONTROL, EXPONENT, EXPONENTS, LEARNING, MATRIX, NAME, OBJECTIVE_MECHANISM, OUTCOME, OWNER_VALUE, \
PARAMS, PREFERENCE_SET_NAME, PROJECTION, PROJECTIONS, PORT_TYPE, VARIABLE, WEIGHT, WEIGHTS
Expand Down Expand Up @@ -714,7 +714,8 @@ def add_to_monitor(self, monitor_specs, context=None):
monitor_specs[i] = spec

# Parse spec to get value of OutputPort and (possibly) the Projection from it
input_port = _parse_port_spec(owner=self, port_type = InputPort, port_spec=spec)
input_port = _parse_port_spec(owner=self, port_type = InputPort, port_spec=spec,
context=Context(string='objective_mechanism.add_to_monitor'))

# There should be only one ProjectionTuple specified,
# that designates the OutputPort and (possibly) a Projection from it
Expand Down
90 changes: 65 additions & 25 deletions psyneulink/core/components/ports/inputport.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,14 +519,14 @@
.. _InputPort_Function:
* `function <InputPort.function>` -- combines the `value <Projection_Base.value>` of all of the
`Projections <Projection>` received by the InputPort, and assigns the result to the InputPort's `value
<InputPort.value>` attribute. The default function is `LinearCombination` that performs an elementwise (Hadamard)
sums the values. However, the parameters of the `function <InputPort.function>` -- and thus the `value
<InputPort.value>` of the InputPort -- can be modified by any `GatingProjections <GatingProjection>` received by
the InputPort (listed in its `mod_afferents <Port_Base.mod_afferents>` attribute. A custom function can also be
specified, so long as it generates a result that is compatible with the item of the Mechanism's `variable
<Mechanism_Base.variable>` to which the `InputPort is assigned <Mechanism_InputPorts>`.
* `function <InputPort.function>` -- combines the `value <Projection_Base.value>` of all of the `path_afferent
<InputPort.path_afferents>` `Projections <Projection>` received by the InputPort, and assigns the result to the
InputPort's `value <InputPort.value>` attribute. The default function is `LinearCombination` that performs an
elementwise (Hadamard) sum of the afferent values. However, the parameters of the `function <InputPort.function>`
-- and thus the `value <InputPort.value>` of the InputPort -- can be modified by any `GatingProjections
<GatingProjection>` received by the InputPort (listed in its `mod_afferents <Port_Base.mod_afferents>` attribute.
A custom function can also be specified, so long as it generates a result that is compatible with the item of the
Mechanism's `variable <Mechanism_Base.variable>` to which the `InputPort is assigned <Mechanism_InputPorts>`.
.. _InputPort_Value:
Expand All @@ -551,7 +551,7 @@
An InputPort cannot be executed directly. It is executed when the Mechanism to which it belongs is executed.
When this occurs, the InputPort executes any `Projections <Projection>` it receives, calls its `function
<InputPort.function>` to combines the values received from any `MappingProjections <MappingProjection>` it receives
<InputPort.function>` to combine the values received from any `MappingProjections <MappingProjection>` it receives
(listed in its its `path_afferents <Port_Base.path_afferents>` attribute) and modulate them in response to any
`GatingProjections <GatingProjection>` (listed in its `mod_afferents <Port_Base.mod_afferents>` attribute),
and then assigns the result to the InputPort's `value <InputPort.value>` attribute. This, in turn, is assigned to
Expand Down Expand Up @@ -739,7 +739,7 @@ class InputPort(Port_Base):
applied and it will generate a value that is the same length as the Projection's `value
<Projection_Base.value>`. However, if the InputPort receives more than one Projection and
uses a function other than a CombinationFunction, a warning is generated and only the `value
<Projection_Base.value>` of the first Projection list in `path_afferents <Port_Base.path_afferents>`
<Projection_Base.value>` of the first Projection listed in `path_afferents <Port_Base.path_afferents>`
is used by the function, which may generate unexpected results when executing the Mechanism or Composition
to which it belongs.
Expand Down Expand Up @@ -1113,18 +1113,24 @@ def _get_all_projections(self):
return self._get_all_afferents()

@beartype
def _parse_port_specific_specs(self, owner, port_dict, port_specific_spec):
"""Get weights, exponents and/or any connections specified in an InputPort specification tuple
def _parse_port_specific_specs(self, owner, port_dict, port_specific_spec, context=None):
"""Parse any InputPort-specific specifications, including SIZE, COMBINE, WEIGHTS and EXPONENTS
Get SIZE and/or COMBINE specification in if port_specific_spec is a dict
Get weights, exponents and/or any connections specified if port_specific_spec is a tuple
Tuple specification can be:
(port_spec, connections)
(port_spec, weights, exponents, connections)
See Port._parse_port_specific_spec for additional info.
See Port._parse_port_specific_specs for additional info.
Returns:
- port_spec: 1st item of tuple if it is a numeric value; otherwise None
- params dict with WEIGHT, EXPONENT and/or PROJECTIONS entries if any of these was specified.
- port_spec:
- updated with SIZE and/or COMBINE specifications for dict;
- 1st item for tuple if it is a numeric value;
- otherwise None
- params dict:
- with WEIGHT, EXPONENT and/or PROJECTIONS entries if any of these was specified.
- purged of SIZE and/or COMBINE entries if they were specified in port_specific_spec
"""
# FIX: ADD FACILITY TO SPECIFY WEIGHTS AND/OR EXPONENTS FOR INDIVIDUAL OutputPort SPECS
Expand All @@ -1145,16 +1151,49 @@ def _parse_port_specific_specs(self, owner, port_dict, port_specific_spec):
# FIX: USE ObjectiveMechanism EXAMPLES
# if MECHANISM in port_specific_spec:
# if OUTPUT_PORTS in port_specific_spec
if SIZE in port_specific_spec:
if (VARIABLE in port_specific_spec or
any(key in port_dict and port_dict[key] is not None for key in {VARIABLE, SIZE})):
raise InputPortError(f"PROGRAM ERROR: SIZE specification found in port_specific_spec dict "
f"for {self.__name__} specification of {owner.name} when SIZE or VARIABLE "
f"is already present in its port_specific_spec dict or port_dict.")
port_dict.update({VARIABLE:np.zeros(port_specific_spec[SIZE])})
del port_specific_spec[SIZE]

if any(spec in port_specific_spec for spec in {SIZE, COMBINE}):

if SIZE in port_specific_spec:
if (VARIABLE in port_specific_spec or
any(key in port_dict and port_dict[key] is not None for key in {VARIABLE, SIZE})):
raise InputPortError(f"PROGRAM ERROR: SIZE specification found in port_specific_spec dict "
f"for {self.__name__} specification of {owner.name} when SIZE or VARIABLE "
f"is already present in its port_specific_spec dict or port_dict.")
port_dict.update({VARIABLE:np.zeros(port_specific_spec[SIZE])})
del port_specific_spec[SIZE]

if COMBINE in port_specific_spec:
fct_err = None
if (FUNCTION in port_specific_spec and port_specific_spec[FUNCTION] is not None):
fct_str = port_specific_spec[FUNCTION].componentName
fct_err = port_specific_spec[FUNCTION].operation != port_specific_spec[COMBINE]
del port_specific_spec[FUNCTION]
elif (FUNCTION in port_dict and port_dict[FUNCTION] is not None):
fct_str = port_dict[FUNCTION].componentName
fct_err = port_dict[FUNCTION].operation != port_specific_spec[COMBINE]
del port_dict[FUNCTION]
if fct_err is True:
raise InputPortError(f"COMBINE entry (='{port_specific_spec[COMBINE]}') of InputPort "
f"specification dictionary for '{self.__name__}' of '{owner.name}' "
f"conflicts with FUNCTION entry ({fct_str}); remove one or the other.")
if fct_err is False and any(source in context.string
for source in {'validate_params',
'_instantiate_input_ports',
'_instantiate_output_ports'}): # Suppress warning in earlier calls
warnings.warn(f"Both COMBINE ('{port_specific_spec[COMBINE]}') and FUNCTION ({fct_str}) "
f"specifications found in InputPort specification dictionary for '{self.__name__}' "
f"of '{owner.name}'; no need to specify both.")
# FIX: THE NEXT LINE, WHICH SHOULD JUST PASS THE COMBINE SPECIFICATION ON TO THE CONSTRUCTOR
# (AND HANDLE FUNCTION ASSIGNMENT THERE) CAUSES A CRASH (APPEARS TO BE A RECURSION ERROR);
# THEREFORE, NEED TO SET FUNCTION HERE
# port_dict.update({COMBINE: port_specific_spec[COMBINE]})
port_specific_spec[FUNCTION] = LinearCombination(operation=port_specific_spec[COMBINE])
del port_specific_spec[COMBINE]

return port_dict, port_specific_spec
return None, port_specific_spec
else:
return None, port_specific_spec

elif isinstance(port_specific_spec, tuple):

Expand Down Expand Up @@ -1520,6 +1559,7 @@ def _instantiate_input_ports(owner, input_ports=None, reference_value=None, cont
if input_ports is not None:
input_ports = _parse_shadow_inputs(owner, input_ports)

context.string = context.string or '_instantiate_input_ports'
port_list = _instantiate_port_list(owner=owner,
port_list=input_ports,
port_types=InputPort,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1025,7 +1025,7 @@ def _instantiate_cost_attributes(self, context=None):
self.duration_cost = 0
self.cost = self.defaults.cost = self.intensity_cost

def _parse_port_specific_specs(self, owner, port_dict, port_specific_spec):
def _parse_port_specific_specs(self, owner, port_dict, port_specific_spec, context=None):
"""Get ControlSignal specified for a parameter or in a 'control_signals' argument
Tuple specification can be:
Expand Down
Loading

0 comments on commit 2890038

Please sign in to comment.