Skip to content

Commit

Permalink
Non-contiguous subgroups for synapses: summed variables
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Nov 14, 2023
1 parent bb7ddb9 commit ce5bb3b
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@

# Set all the target variable values to zero
for _target_idx in range({{_target_size_name}}):
{% if _target_contiguous %}
{{_target_var_array}}[_target_idx + {{_target_start}}] = 0
{% else %}
{{_target_var_array}}[{{_target_indices}}[_target_idx]] = 0
{% endif %}

# scalar code
_vectorisation_idx = 1
Expand Down
4 changes: 4 additions & 0 deletions brian2/codegen/runtime/numpy_rt/templates/summed_variable.py_
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ _vectorisation_idx = LazyArange(N)
# We write to the array, using the name provided as a keyword argument to the
# template
# Note that for subgroups, we do not want to overwrite the full array
{% if not _target_contiguous %}
{{_target_var_array}}[{{_target_indices}}] = _numpy.broadcast_to(_synaptic_var, (N, ))
{% else %}
{% if _target_start > 0 %}
_indices = {{_index_array}} - {{_target_start}}
{% else %}
Expand All @@ -32,4 +35,5 @@ _length = _target_stop - {{_target_start}}
{{_target_var_array}}[{{_target_start}}:_target_stop] = _numpy.bincount(_indices,
minlength=_length,
weights=_numpy.broadcast_to(_synaptic_var, (N, )))
{% endif %}
{% endblock %}
34 changes: 22 additions & 12 deletions brian2/core/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from collections import Counter, defaultdict, namedtuple
from collections.abc import Mapping, Sequence

import numpy as np

from brian2.core.base import BrianObject, BrianObjectException
from brian2.core.clocks import Clock, defaultclock
from brian2.core.names import Nameable
Expand Down Expand Up @@ -312,18 +314,26 @@ def _check_multiple_summed_updaters(objects):
"the target group instead."
)
raise NotImplementedError(msg)
elif (
obj.target.start < other_target.stop
and other_target.start < obj.target.stop
):
# Overlapping subgroups
msg = (
"Multiple 'summed variables' target the "
f"variable '{obj.target_var.name}' in overlapping "
f"groups '{other_target.name}' and '{obj.target.name}'. "
"Use separate variables in the target groups instead."
)
raise NotImplementedError(msg)
else:
if getattr(obj.target, "contiguous", True):
target_indices = np.arange(obj.target.start, obj.target.stop)
else:
target_indices = obj.target.indices[:]
if getattr(other_target, "contiguous", True):
other_indices = np.arange(other_target.start, other_target.stop)
else:
other_indices = other_target.indices[:]
if np.intersect1d(
target_indices, other_indices, assume_unique=True
).size:
# Overlapping subgroups
msg = (
"Multiple 'summed variables' target the "
f"variable '{obj.target_var.name}' in overlapping "
f"groups '{other_target.name}' and '{obj.target.name}'. "
"Use separate variables in the target groups instead."
)
raise NotImplementedError(msg)
summed_targets[obj.target_var] = obj.target


Expand Down
5 changes: 4 additions & 1 deletion brian2/devices/cpp_standalone/templates/summed_variable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
//// MAIN CODE ////////////
{# This enables summed variables for connections to a synapse #}
const int _target_size = {{constant_or_scalar(_target_size_name, variables[_target_size_name])}};

// Set all the target variable values to zero
{{ openmp_pragma('parallel-static') }}
for (int _target_idx=0; _target_idx<_target_size; _target_idx++)
{
{% if _target_contiguous %}
{{_target_var_array}}[_target_idx + {{_target_start}}] = 0;
{% else %}
{{_target_var_array}}[{{_target_indices}}[_target_idx]] = 0;
{% endif %}
}

// scalar code
Expand Down
11 changes: 9 additions & 2 deletions brian2/synapses/synapses.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class SummedVariableUpdater(CodeRunner):
def __init__(
self, expression, target_varname, synapses, target, target_size_name, index_var
):
# Handling sumped variables using the standard mechanisms is not
# Handling summed variables using the standard mechanisms is not
# possible, we therefore also directly give the names of the arrays
# to the template.

Expand All @@ -139,14 +139,21 @@ def __init__(
"_index_var": synapses.variables[index_var],
"_target_start": getattr(target, "start", 0),
"_target_stop": getattr(target, "stop", -1),
"_target_contiguous": True,
}
needed_variables = [target_varname, target_size_name, index_var]
self.variables = Variables(synapses)
if not getattr(target, "contiguous", True):
self.variables.add_reference("_target_indices", target, "_sub_idx")
needed_variables.append("_target_indices")
template_kwds["_target_contiguous"] = False

CodeRunner.__init__(
self,
group=synapses,
template="summed_variable",
code=code,
needed_variables=[target_varname, target_size_name, index_var],
needed_variables=needed_variables,
# We want to update the summed variable before
# the target group gets updated
clock=target.clock,
Expand Down
22 changes: 15 additions & 7 deletions brian2/tests/test_subgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,20 +644,28 @@ def test_synapses_access_subgroups_problematic():
@pytest.mark.standalone_compatible
def test_subgroup_summed_variable():
# Check in particular that only neurons targeted are reset to 0 (see github issue #925)
source = NeuronGroup(1, "")
target = NeuronGroup(5, "Iin : 1")
source = NeuronGroup(1, "x : 1")
target = NeuronGroup(
7,
"""Iin : 1
x : 1""",
)
source.x = 5
target.Iin = 10
target.x = "i"
target1 = target[1:2]
target2 = target[3:]

syn1 = Synapses(source, target1, "Iin_post = 5 : 1 (summed)")
target2 = target[3:5]
target3 = target[[0, 6]]
syn1 = Synapses(source, target1, "Iin_post = x_pre + x_post : 1 (summed)")
syn1.connect(True)
syn2 = Synapses(source, target2, "Iin_post = 1 : 1 (summed)")
syn2 = Synapses(source, target2, "Iin_post = x_pre + x_post : 1 (summed)")
syn2.connect(True)
syn3 = Synapses(source, target3, "Iin_post = x_pre + x_post : 1 (summed)")
syn3.connect(True)

run(2 * defaultclock.dt)

assert_array_equal(target.Iin, [10, 5, 10, 1, 1])
assert_array_equal(target.Iin, [5, 6, 10, 8, 9, 10, 11])


def test_subexpression_references():
Expand Down

0 comments on commit ce5bb3b

Please sign in to comment.