Skip to content

Commit

Permalink
Fix/matrix transform l0 (#3113)
Browse files Browse the repository at this point in the history
• transformfunctions.py
  - MatrixTransform:  allow normaliation for L0

• emcomposition.py
  - enforce normalize_memories for len(keys)==1

• test_emcomposition.py
  - test_simple_execution_without_learning():  add tests for scalar keys & use of L0 in MatrixTransform
  • Loading branch information
jdcpni authored Nov 14, 2024
1 parent 3fc73e1 commit ee61d35
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 95 deletions.
11 changes: 0 additions & 11 deletions docs/source/CombinationFunctions.rst

This file was deleted.

4 changes: 2 additions & 2 deletions docs/source/Core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ Core

- `NonStatefulFunctions`

- `CombinationFunctions`

- `DistributionFunctions`

- `LearningFunctions`
Expand All @@ -71,6 +69,8 @@ Core

- `TransferFunctions`

- `TransformFunctions`

- `StatefulFunctions`

- `IntegratorFunctions`
Expand Down
5 changes: 3 additions & 2 deletions docs/source/NonStatefulFunctions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ Functions that do *not* depend on a previous value.
.. toctree::
:maxdepth: 1

CombinationFunctions

DistributionFunctions
LearningFunctions
ObjectiveFunctions
OptimizationFunctions
SelectionFunctions
TransferFunctions
TransferFunctions
TransformFunctions
11 changes: 11 additions & 0 deletions docs/source/TransformFunctions.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
TransformFunctions
==================

.. toctree::
:maxdepth: 3

.. automodule:: psyneulink.core.components.functions.transformfunctions
:members: Concatenate, Rearrange, Reduce, LinearCombination, CombineMeans, MatrixTransform, PredictionErrorDeltaFunction
:private-members:
:exclude-members: Parameters

Original file line number Diff line number Diff line change
Expand Up @@ -1628,20 +1628,48 @@ class MatrixTransform(TransformFunction): # -----------------------------------
Matrix transform of `variable <MatrixTransform.variable>`.
`function <MatrixTransform._function>` returns dot product of variable with matrix:
`function <MatrixTransform._function>` returns a matrix transform of `variable <MatrixTransform.variable>`
based on the **operation** argument.
.. math::
variable \\bullet matrix
**operation** = *DOT_PRODUCT*:
If *DOT_PRODUCT* is specified as the **operation*, the result is the dot product of `variable
<MatrixTransform.variable>` and `matrix <MatrixTransform.matrix>`; if *L0* is specified, the result is the
difference between `variable <MatrixTransform.variable>` and `matrix <MatrixTransform.matrix>` (see
`operation <MatrixTransform.operation>` for additional details).
Returns the dot (inner) product of `variable <MatrixTransform.variable>` and `matrix <MatrixTransform.matrix>`:
If **normalize** is True, the result is normalized by the product of the norms of the variable and matrix:
.. math::
{variable} \\bullet |matrix|
If **normalize** =True, the result is normalized by the product of the norms of the variable and matrix:
.. math::
\\frac{variable \\bullet matrix}{\\|variable\\| \\cdot \\|matrix\\|}
.. note::
For **normalize** =True, the result is the same as the cosine of the angle between pairs of vectors.
**operation** = *L0*:
Returns the absolute value of the difference between `variable <MatrixTransform.variable>` and `matrix
<MatrixTransform.matrix>`:
.. math::
|variable - matrix|
If **normalize** =True, the result is normalized by the norm of the sum of differences between the variable and
matrix, which is then subtracted from 1:
.. math::
1 - \\frac{|variable - matrix|}{\\|variable - matrix\\|}
.. note::
For **normalize** =True, the result has the same effect as the normalized *DOT_PRODUCT* operation,
with more similar pairs of vectors producing larger values (closer to 1).
.. warning::
For **normalize** =False, the result is smaller (closer to 0) for more similar pairs of vectors,
which is **opposite** the effect of the *DOT_PRODUCT* and normalized *L0* operations. If the desired
result is that more similar pairs of vectors produce larger values, set **normalize** =True or
use the *DOT_PRODUCT* operation.
.. math::
\\frac{variable \\bullet matrix}{\\|variable\\| \\cdot \\|matrix\\|}
COMMENT: [CONVERT TO FIGURE]
----------------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -1679,7 +1707,7 @@ class MatrixTransform(TransformFunction): # -----------------------------------
specifies matrix used to transform `variable <MatrixTransform.variable>`
(see `matrix <MatrixTransform.matrix>` for specification details).
When MatrixTransform is the `function <Projection_Base._function>` of a projection:
When MatrixTransform is the `function <Projection_Base.function>` of a projection:
- the matrix specification must be compatible with the variables of the `sender <Projection_Base.sender>`
and `receiver <Projection_Base.receiver>`
Expand Down Expand Up @@ -1795,15 +1823,6 @@ class Parameters(TransformFunction.Parameters):
normalize = Parameter(False)
bounds = None

# def is_matrix_spec(m):
# if m is None:
# return True
# if m in MATRIX_KEYWORD_VALUES:
# return True
# if isinstance(m, (list, np.ndarray, types.FunctionType)):
# return True
# return False

@check_user_specified
@beartype
def __init__(self,
Expand Down Expand Up @@ -1833,25 +1852,6 @@ def __init__(self,
skip_log=True,
)

# def _validate_variable(self, variable, context=None):
# """Insure that variable passed to MatrixTransform is a max 2D array
#
# :param variable: (max 2D array)
# :param context:
# :return:
# """
# variable = super()._validate_variable(variable, context)
#
# # Check that variable <= 2D
# try:
# if not variable.ndim <= 2:
# raise FunctionError("variable ({0}) for {1} must be a numpy.ndarray of dimension at most 2".format(variable, self.__class__.__name__))
# except AttributeError:
# raise FunctionError("PROGRAM ERROR: variable ({0}) for {1} should be a numpy.ndarray".
# format(variable, self.__class__.__name__))
#
# return variable


def _validate_params(self, request_set, target_set=None, context=None):
"""Validate params and assign to targets
Expand Down Expand Up @@ -2013,15 +2013,6 @@ def _validate_params(self, request_set, target_set=None, context=None):
self.name,
self.owner_name,
MATRIX_KEYWORD_NAMES))

# operation param
elif param_name == OPERATION:
if param_value == L0 and NORMALIZE in param_set and param_set[NORMALIZE]:
raise FunctionError(f"The 'operation' parameter for the {self.name} function of "
f"{self.owner_name} is set to 'L0', so the 'normalize' parameter "
f"should not be set to True "
f"(normalization is not needed, and can cause a divide by zero error). "
f"Set 'normalize' to False or change 'operation' to 'DOT_PRODUCT'.")
else:
continue

Expand Down Expand Up @@ -2176,7 +2167,7 @@ def diff_with_normalization(vector, matrix):
if normalize:
return diff_with_normalization
else:
return lambda x, y: torch.sum((1 - torch.abs(x - y)),axis=0)
return lambda x, y: torch.sum(torch.abs(x - y),axis=0)

else:
from psyneulink.library.compositions.autodiffcomposition import AutodiffCompositionError
Expand Down Expand Up @@ -2224,10 +2215,11 @@ def _function(self,
result = np.dot(vector, matrix)

elif operation == L0:
normalization = 1
if normalize:
normalization = np.sum(np.abs(vector - matrix))
result = np.sum(((1 - np.abs(vector - matrix)) / normalization),axis=0)
result = np.sum((1 - (np.abs(vector - matrix)) / normalization),axis=0)
else:
result = np.sum((np.abs(vector - matrix)),axis=0)

return self.convert_output_type(result)

Expand Down
7 changes: 5 additions & 2 deletions psyneulink/library/compositions/emcomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2201,7 +2201,11 @@ def _construct_match_nodes(self, memory_template, memory_capacity, concatenate_q
"""
OPERATION = 0
NORMALIZE = 1
args = [(L0,False) if len(key) == 1 else (DOT_PRODUCT,normalize_memories) for key in memory_template[0]]
# Enforce normalization of memories if key is a scalar
# (this is to allow 1-L0 distance to be used as similarity measure, so that better matches
# (more similar memories) have higher match values; see `MatrixTransform` for explanation)
args = [(L0,True) if len(key) == 1 else (DOT_PRODUCT,normalize_memories)
for key in memory_template[0]]

if concatenate_queries:
# Get fields of memory structure corresponding to the keys
Expand Down Expand Up @@ -2238,7 +2242,6 @@ def _construct_match_nodes(self, memory_template, memory_capacity, concatenate_q
for i in range(self.num_keys)
]


return match_nodes

# FIX: CONVERT TO _construct_weight_control_nodes
Expand Down
Loading

0 comments on commit ee61d35

Please sign in to comment.