Skip to content

Commit

Permalink
WIP: Allow non-contiguous subgroups for synapses
Browse files Browse the repository at this point in the history
Only C++ standalone supported until now
  • Loading branch information
mstimberg committed Nov 2, 2023
1 parent 285dffb commit 4cb865b
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,16 @@ const size_t _new_num_synapses = _old_num_synapses + _numsources;
constants or scalar arrays#}
const size_t _N_pre = {{constant_or_scalar('N_pre', variables['N_pre'])}};
const size_t _N_post = {{constant_or_scalar('N_post', variables['N_post'])}};
{% if "_target_sub_idx" in variables %}
{{_dynamic_N_incoming}}.resize({{get_array_name(variables['_target_sub_idx'])}}[_num_target_sub_idx - 1] + 1);
{% else %}
{{_dynamic_N_incoming}}.resize(_N_post + _target_offset);
{% endif %}
{% if "_source_sub_idx" in variables %}
{{_dynamic_N_outgoing}}.resize({{get_array_name(variables['_source_sub_idx'])}}[_num_source_sub_idx - 1] + 1);
{% else %}
{{_dynamic_N_outgoing}}.resize(_N_pre + _source_offset);
{% endif %}

for (size_t _idx=0; _idx<_numsources; _idx++) {
{# After this code has been executed, the arrays _real_sources and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,16 @@
constants or scalar arrays#}
const size_t _N_pre = {{constant_or_scalar('N_pre', variables['N_pre'])}};
const size_t _N_post = {{constant_or_scalar('N_post', variables['N_post'])}};
{% if "_target_sub_idx" in variables %}
{{_dynamic_N_incoming}}.resize({{get_array_name(variables['_target_sub_idx'])}}[_num_target_sub_idx - 1] + 1);
{% else %}
{{_dynamic_N_incoming}}.resize(_N_post + _target_offset);
{% endif %}
{% if "_source_sub_idx" in variables %}
{{_dynamic_N_outgoing}}.resize({{get_array_name(variables['_source_sub_idx'])}}[_num_source_sub_idx - 1] + 1);
{% else %}
{{_dynamic_N_outgoing}}.resize(_N_pre + _source_offset);
{% endif %}
size_t _raw_pre_idx, _raw_post_idx;
{# For a connect call j='k+i for k in range(0, N_post, 2) if k+i < N_post'
"j" is called the "result index" (and "_post_idx" the "result index array", etc.)
Expand All @@ -35,7 +43,11 @@
for(size_t _{{outer_index}}=0; _{{outer_index}}<_{{outer_index_size}}; _{{outer_index}}++)
{
bool __cond, _cond;
{% if outer_contiguous %}
_raw{{outer_index_array}} = _{{outer_index}} + {{outer_index_offset}};
{% else %}
_raw{{outer_index_array}} = {{get_array_name(variables[outer_sub_idx])}}[_{{outer_index}}];
{% endif %}
{% if not result_index_condition %}
{
{{vector_code['create_cond']|autoindent}}
Expand Down Expand Up @@ -181,7 +193,11 @@
}
_{{result_index}} = __{{result_index}}; // make the previously locally scoped var available
{{outer_index_array}} = _{{outer_index_array}};
{% if result_contiguous %}
_raw{{result_index_array}} = _{{result_index}} + {{result_index_offset}};
{% else %}
_raw{{result_index_array}} = {{get_array_name(variables[result_sub_idx])}}[_{{result_index}}];
{% endif %}
{% if result_index_condition %}
{
{% if result_index_used %}
Expand Down
110 changes: 78 additions & 32 deletions brian2/synapses/synapses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1363,45 +1363,55 @@ def _create_variables(self, equations, user_dtype=None):
self.variables.add_reference("_presynaptic_idx", self, "_synaptic_pre")
self.variables.add_reference("_postsynaptic_idx", self, "_synaptic_post")

# Except for subgroups (which potentially add an offset), the "i" and
# "j" variables are simply equivalent to `_synaptic_pre` and
# `_synaptic_post`
if getattr(self.source, "start", 0) == 0:
self.variables.add_reference("i", self, "_synaptic_pre")
else:
if isinstance(self.source, Subgroup) and not self.source.contiguous:
raise TypeError(
"Cannot use a non-contiguous subgroup as a "
"source group for Synapses."
)
# Except for subgroups, the "i" and "j" variables are simply equivalent to
# `_synaptic_pre` and `_synaptic_post`
if (
isinstance(self.source, Subgroup)
and not getattr(self.source, "start", -1) == 0
):
self.variables.add_reference(
"_source_i", self.source.source, "i", index="_presynaptic_idx"
)
self.variables.add_reference("_source_offset", self.source, "_offset")
self.variables.add_subexpression(
"i",
dtype=self.source.source.variables["i"].dtype,
expr="_source_i - _source_offset",
index="_presynaptic_idx",
)
if getattr(self.target, "start", 0) == 0:
self.variables.add_reference("j", self, "_synaptic_post")
else:
if isinstance(self.target, Subgroup) and not self.target.contiguous:
raise TypeError(
"Cannot use a non-contiguous subgroup as a "
"target group for Synapses."
if getattr(self.source, "contiguous", True):
# Contiguous subgroup simply shift the indices
self.variables.add_subexpression(
"i",
dtype=self.source.source.variables["i"].dtype,
expr="_source_i - _source_offset",
index="_presynaptic_idx",
)
else:
# Non-contiguous subgroups need a full translation
self.variables.add_reference(
"i", self.source, "i", index="_presynaptic_idx"
)
else:
# No subgroup or subgroup starting at 0
self.variables.add_reference("i", self, "_synaptic_pre")

if (
isinstance(self.target, Subgroup)
and not getattr(self.target, "start", -1) == 0
):
self.variables.add_reference(
"_target_j", self.target.source, "i", index="_postsynaptic_idx"
)
self.variables.add_reference("_target_offset", self.target, "_offset")
self.variables.add_subexpression(
"j",
dtype=self.target.source.variables["i"].dtype,
expr="_target_j - _target_offset",
index="_postsynaptic_idx",
)
if getattr(self.target, "contiguous", True):
# Contiguous subgroup simply shift the indices
self.variables.add_subexpression(
"j",
dtype=self.target.source.variables["i"].dtype,
expr="_target_j - _target_offset",
index="_postsynaptic_idx",
)
else:
# Non-contiguous subgroups need a full translation
self.variables.add_reference(
"j", self.target, "i", index="_postsynaptic_idx"
)
else:
# No subgroup or subgroup starting at 0
self.variables.add_reference("j", self, "_synaptic_post")

# Add the standard variables
self.variables.add_array(
Expand Down Expand Up @@ -1934,11 +1944,21 @@ def _add_synapses_from_arrays(self, sources, targets, n, p, namespace=None):
if "_offset" in self.source.variables:
variables.add_reference("_source_offset", self.source, "_offset")
abstract_code += "_real_sources = sources + _source_offset\n"
elif not getattr(self.source, "contiguous", True):
variables.add_reference(
"_source_sub_idx", self.source, "_sub_idx", index="sources"
)
abstract_code += "_real_sources = _source_sub_idx\n"
else:
abstract_code += "_real_sources = sources\n"
if "_offset" in self.target.variables:
variables.add_reference("_target_offset", self.target, "_offset")
abstract_code += "_real_targets = targets + _target_offset\n"
elif not getattr(self.target, "contiguous", True):
variables.add_reference(
"_target_sub_idx", self.target, "_sub_idx", index="targets"
)
abstract_code += "_real_targets = _target_sub_idx\n"
else:
abstract_code += "_real_targets = targets"
logger.debug(
Expand Down Expand Up @@ -2022,23 +2042,41 @@ def _add_synapses_generator(
outer_index_size = "N_pre" if over_presynaptic else "N_post"
outer_index_array = "_pre_idx" if over_presynaptic else "_post_idx"
outer_index_offset = "_source_offset" if over_presynaptic else "_target_offset"
outer_sub_idx = "_source_sub_idx" if over_presynaptic else "_target_sub_idx"
outer_contiguous = (
getattr(self.source, "contiguous", True)
if over_presynaptic
else getattr(self.target, "contiguous", True)
)

result_index = "j" if over_presynaptic else "i"
result_index_size = "N_post" if over_presynaptic else "N_pre"
target_idx = "_postsynaptic_idx" if over_presynaptic else "_presynaptic_idx"
result_index_array = "_post_idx" if over_presynaptic else "_pre_idx"
result_index_offset = "_target_offset" if over_presynaptic else "_source_offset"
result_sub_idx = "_target_sub_idx" if over_presynaptic else "_source_sub_idx"
result_contiguous = (
getattr(self.target, "contiguous", True)
if over_presynaptic
else getattr(self.source, "contiguous", True)
)

result_index_name = "postsynaptic" if over_presynaptic else "presynaptic"
template_kwds.update(
{
"outer_index": outer_index,
"outer_index_size": outer_index_size,
"outer_index_array": outer_index_array,
"outer_index_offset": outer_index_offset,
"outer_sub_idx": outer_sub_idx,
"outer_contiguous": outer_contiguous,
"result_index": result_index,
"result_index_size": result_index_size,
"result_index_name": result_index_name,
"result_index_array": result_index_array,
"result_index_offset": result_index_offset,
"result_sub_idx": result_sub_idx,
"result_contiguous": result_contiguous,
}
)
abstract_code = {
Expand Down Expand Up @@ -2126,11 +2164,19 @@ def _add_synapses_generator(
else:
variables.add_constant("_source_offset", value=0)

if not getattr(self.source, "contiguous", True):
variables.add_reference("_source_sub_idx", self.source, "_sub_idx")
needed_variables.append("_source_sub_idx")

if "_offset" in self.target.variables:
variables.add_reference("_target_offset", self.target, "_offset")
else:
variables.add_constant("_target_offset", value=0)

if not getattr(self.target, "contiguous", True):
variables.add_reference("_target_sub_idx", self.target, "_sub_idx")
needed_variables.append("_target_sub_idx")

variables.add_auxiliary_variable("_raw_pre_idx", dtype=np.int32)
variables.add_auxiliary_variable("_raw_post_idx", dtype=np.int32)

Expand Down
48 changes: 48 additions & 0 deletions brian2/tests/test_subgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,54 @@ def test_synapse_creation_generator():
assert all(S5.v_pre[:] < 25)


@pytest.mark.standalone_compatible
def test_synapse_creation_generator_non_contiguous():
G1 = NeuronGroup(10, "v:1")
G2 = NeuronGroup(20, "v:1")
G1.v = "i"
G2.v = "10 + i"
SG1 = G1[[0, 2, 4, 6, 8]]
SG2 = G2[1::2]
S = Synapses(SG1, SG2, "w:1")
S.connect(j="i*2 + k for k in range(2)") # diverging connections

# connect based on pre-/postsynaptic state variables
S2 = Synapses(SG1, SG2, "w:1")
S2.connect(j="k for k in range(N_post) if v_pre > 2")

S3 = Synapses(SG1, SG2, "w:1")
S3.connect(j="k for k in range(N_post) if v_post < 25")

S4 = Synapses(SG2, SG1, "w:1")
S4.connect(j="k for k in range(N_post) if v_post > 2")

S5 = Synapses(SG2, SG1, "w:1")
S5.connect(j="k for k in range(N_post) if v_pre < 25")

run(0 * ms) # for standalone

# Internally, the "real" neuron indices should be used
assert_equal(S._synaptic_pre[:], np.arange(0, 10, 2).repeat(2))
assert_equal(S._synaptic_post[:], np.arange(1, 20, 2))
# For the user, the subgroup-relative indices should be presented
assert_equal(S.i[:], np.arange(5).repeat(2))
assert_equal(S.j[:], np.arange(10))

# N_incoming and N_outgoing should also be correct
assert all(S.N_outgoing[:] == 2)
assert all(S.N_incoming[:] == 1)

assert len(S2) == 3 * len(SG2), str(len(S2))
assert all(S2.v_pre[:] > 2)
assert len(S3) == 7 * len(SG1), f"{len(S3)} != {7 * len(SG1)} "
assert all(S3.v_post[:] < 25)

assert len(S4) == 3 * len(SG2), str(len(S4))
assert all(S4.v_post[:] > 2)
assert len(S5) == 7 * len(SG1), f"{len(S5)} != {7 * len(SG1)} "
assert all(S5.v_pre[:] < 25)


@pytest.mark.standalone_compatible
def test_synapse_creation_generator_multiple_synapses():
G1 = NeuronGroup(10, "v:1")
Expand Down
6 changes: 0 additions & 6 deletions brian2/tests/test_synapses.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,6 @@ def test_creation_errors():
Synapses(G, G, "w:1", pre="v+=w", on_pre="v+=w")
with pytest.raises(TypeError):
Synapses(G, G, "w:1", post="v+=w", on_post="v+=w")
# We do not allow non-contiguous subgroups as source/target groups at the
# moment
with pytest.raises(TypeError):
Synapses(G[::2], G, "")
with pytest.raises(TypeError):
Synapses(G, G[::2], "")


@pytest.mark.codegen_independent
Expand Down

0 comments on commit 4cb865b

Please sign in to comment.