Skip to content

Commit

Permalink
Testing and fix for 'round_robin_halt policy'; 'expsyn_curr' mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
jlubo committed May 10, 2022
1 parent 56d7051 commit 4483daf
Show file tree
Hide file tree
Showing 5 changed files with 290 additions and 32 deletions.
12 changes: 10 additions & 2 deletions arbor/label_resolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,17 @@ cell_lid_type resolver::resolve(const cell_global_label_type& iden) {
}
const auto& range_set = label_map_->at(iden.gid, iden.label.tag);

// Construct state if if doesn't exist
// Selected policy round_robin_halt
if (iden.label.policy == lid_selection_policy::round_robin_halt) {
// Use state of round_robin policy if it exists
if (state_map_[iden.gid][iden.label.tag].count(lid_selection_policy::round_robin)) {
state_map_[iden.gid][iden.label.tag][iden.label.policy] = state_map_[iden.gid][iden.label.tag][lid_selection_policy::round_robin];
}
}

// Construct state if it doesn't exist
if (!state_map_[iden.gid][iden.label.tag].count(iden.label.policy)) {
state_map_[iden.gid][iden.label.tag][iden.label.policy] = construct_state(iden.label.policy);
state_map_[iden.gid][iden.label.tag][iden.label.policy] = construct_state(iden.label.policy);
}

auto lid = std::visit([range_set](auto& state) { return state.update(range_set); }, state_map_[iden.gid][iden.label.tag][iden.label.policy]);
Expand Down
2 changes: 1 addition & 1 deletion mechanisms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ make_catalogue(
NAME default
SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/default"
OUTPUT "CAT_DEFAULT_SOURCES"
MOD exp2syn expsyn expsyn_stdp hh kamt kdrmt nax nernst pas gj
MOD exp2syn expsyn expsyn_curr expsyn_stdp hh kamt kdrmt nax nernst pas gj
CXX
PREFIX "${PROJECT_SOURCE_DIR}/mechanisms"
CXX_FLAGS_TARGET "${ARB_CXX_FLAGS_TARGET_FULL}"
Expand Down
46 changes: 46 additions & 0 deletions mechanisms/default/expsyn_curr.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
: Exponential current-based synapse

NEURON {
POINT_PROCESS expsyn_curr
RANGE w, tau, R_mem
NONSPECIFIC_CURRENT I
}

UNITS {
(ms) = (milliseconds)
(mV) = (millivolt)
(MOhm) = (megaohm)
}

PARAMETER {
R_mem = 10.0 (MOhm) : membrane resistance
tau = 5.0 (ms) : synaptic time constant
w = 4.20075 (mV) : weight
}

STATE {
g (mV) : instantaneous synaptic conductance
}

INITIAL {
I = 0
g = 0
}

BREAKPOINT {
:SOLVE state METHOD cnexp
SOLVE state METHOD sparse : to match with expsyn_curr_calcium_plasticity

I = -g / R_mem
}

DERIVATIVE state {
: Exponential decay of postsynaptic potential
g' = -g / tau
}

NET_RECEIVE(weight) {
: Start of postsynaptic potential
g = g + w
}

62 changes: 33 additions & 29 deletions python/test/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,34 +140,10 @@ class empty_recipe(arbor.recipe):
"""
pass


@_fixture
def cable_cell():
# (1) Create a morphology with a single (cylindrical) segment of length=diameter=6 μm
tree = arbor.segment_tree()
tree.append(
arbor.mnpos,
arbor.mpoint(-3, 0, 0, 3),
arbor.mpoint(3, 0, 0, 3),
tag=1,
)

# (2) Define the soma and its midpoint
labels = arbor.label_dict({'soma': '(tag 1)',
'midpoint': '(location 0 0.5)'})

# (3) Create cell and set properties
decor = arbor.decor()
decor.set_property(Vm=-40)
decor.paint('"soma"', arbor.density('hh'))
decor.place('"midpoint"', arbor.iclamp( 10, 2, 0.8), "iclamp")
decor.place('"midpoint"', arbor.spike_detector(-10), "detector")
return arbor.cable_cell(tree, labels, decor)

@_fixture
class art_spiker_recipe(arbor.recipe):
"""
Recipe fixture with 3 artificial spiking cells.
Recipe fixture with 3 artificial spiking cells and one cable cell.
"""
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -201,15 +177,43 @@ def probes(self, gid):
else:
return [arbor.cable_probe_membrane_voltage('"midpoint"')]

@cable_cell
def _cable_cell(self, cable_cell):
return cable_cell
def _cable_cell_elements(self):
# (1) Create a morphology with a single (cylindrical) segment of length=diameter=6 μm
tree = arbor.segment_tree()
tree.append(
arbor.mnpos,
arbor.mpoint(-3, 0, 0, 3),
arbor.mpoint(3, 0, 0, 3),
tag=1,
)

# (2) Define the soma and its midpoint
labels = arbor.label_dict({'soma': '(tag 1)',
'midpoint': '(location 0 0.5)'})

# (3) Create cell and set properties
decor = arbor.decor()
decor.set_property(Vm=-40)
decor.paint('"soma"', arbor.density('hh'))
decor.place('"midpoint"', arbor.iclamp( 10, 2, 0.8), "iclamp")
decor.place('"midpoint"', arbor.spike_detector(-10), "detector")

# return tuple of tree, labels, and decor for creating a cable cell (can still be modified before calling arbor.cable_cell())
return tree, labels, decor

def cell_description(self, gid):
if gid < 3:
return arbor.spike_source_cell("src", arbor.explicit_schedule(self.trains[gid]))
else:
return self._cable_cell()
tree, labels, decor = self._cable_cell_elements()
return arbor.cable_cell(tree, labels, decor)

@_fixture
def sum_weight_hh_spike():
""" Fixture returning connection weight which is just small enough to evoke an immediate spike
at t=1ms in the 'hh' neuron in 'art_spiker_recipe'
"""
return 47.5

@_fixture
@context
Expand Down
200 changes: 200 additions & 0 deletions python/test/unit/test_multiple_connections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
# -*- coding: utf-8 -*-
#
# test_multiple_connections.py

import unittest
import types
import numpy as np

import arbor as arb
from .. import fixtures

"""
tests for multiple excitatory and inhibitory connections onto the same postsynaptic label and for one connection that has the same net impact as the multiple-connection paradigm,
thereby testing the selection policies 'round_robin', 'round_robin_halt', and 'univalent' (in principle testing if the right instances of mechanisms are used)
"""

class TestMultipleConnections(unittest.TestCase):

@fixtures.context
@fixtures.art_spiker_recipe
@fixtures.sum_weight_hh_spike
def test_multiple_connections(self, context, art_spiker_recipe, sum_weight_hh_spike):
runtime = 2 # ms
dt = 0.01 # ms
weight = sum_weight_hh_spike # connection strength which is just small enough so that all connections (exc. and inh.) from neuron 0 added up will evoke an immediate spike

# define new method 'connections_on()' and overwrite the original one in the 'art_spiker_recipe' object
def connections_on(self, gid):
# incoming to neurons 0--2
if gid < 3:
return []

# incoming to neuron 3
elif gid == 3:
source_label_0 = arb.cell_global_label(0, "spike_source") # referring to the "spike_source" label of neuron 0
source_label_1 = arb.cell_global_label(1, "spike_source") # referring to the "spike_source" label of neuron 1

target_label_rr_halt = arb.cell_local_label("postsyn_target", arb.selection_policy.round_robin_halt) # referring to the current item in the "postsyn_target" label group of neuron 3
target_label_rr = arb.cell_local_label("postsyn_target", arb.selection_policy.round_robin) # referring to the current item in the "postsyn_target" label group of neuron 3, moving to the next item afterwards

conn_0_3_n1 = arb.connection(source_label_0, target_label_rr_halt, 1, 0.2) # first (exc.) connection from neuron 0 to 3
conn_0_3_n2 = arb.connection(source_label_0, target_label_rr_halt, 1, 0.2) # second (exc.) connection from neuron 0 to 3
conn_0_3_n3 = arb.connection(source_label_0, target_label_rr, 1, 0.2) # third (exc.) connection from neuron 0 to 3
conn_1_3_n1 = arb.connection(source_label_1, target_label_rr_halt, 1, 0.6) # first (inh.) connection from neuron 1 to 3
conn_1_3_n2 = arb.connection(source_label_1, target_label_rr, 1, 0.6) # second (inh.) connection from neuron 1 to 3

return [conn_0_3_n1, conn_0_3_n2, conn_0_3_n3, conn_1_3_n1, conn_1_3_n2]
art_spiker_recipe.connections_on = types.MethodType(connections_on, art_spiker_recipe)

# define new method 'cell_description()' and overwrite the original one in the 'art_spiker_recipe' object
def cell_description(self, gid):
# spike source neuron
if gid < 3:
return arb.spike_source_cell("spike_source", arb.explicit_schedule(self.trains[gid]))

# spike-receiving cable neuron
elif gid == 3:
tree, labels, decor = self._cable_cell_elements()

syn_mechanism1 = arb.mechanism("expsyn_curr")
syn_mechanism1.set('w', weight) # set weight for excitation
syn_mechanism1.set("tau", dt) # set minimal decay time

syn_mechanism2 = arb.mechanism("expsyn_curr")
syn_mechanism2.set('w', -weight) # set weight for inhibition
syn_mechanism2.set("tau", dt) # set minimal decay time

decor.place('"midpoint"', arb.synapse(syn_mechanism2), "postsyn_target") # place synaptic input from one presynaptic neuron at the center of the soma
decor.place('"midpoint"', arb.synapse(syn_mechanism1), "postsyn_target") # place synaptic input from another presynaptic neuron at the center of the soma
# (using the same label as above!)

return arb.cable_cell(tree, labels, decor)
art_spiker_recipe.cell_description = types.MethodType(cell_description, art_spiker_recipe)

# read connections from recipe for testing
connections_from_recipe = art_spiker_recipe.connections_on(3)

# connection #1 from neuron 0 to 3
self.assertEqual(connections_from_recipe[0].dest.label, "postsyn_target")
self.assertAlmostEqual(connections_from_recipe[0].weight, 1)
self.assertAlmostEqual(connections_from_recipe[0].delay, 0.2)

# connection #2 from neuron 0 to 3
self.assertEqual(connections_from_recipe[1].dest.label, "postsyn_target")
self.assertAlmostEqual(connections_from_recipe[1].weight, 1)
self.assertAlmostEqual(connections_from_recipe[1].delay, 0.2)

# connection #3 from neuron 0 to 3
self.assertEqual(connections_from_recipe[2].dest.label, "postsyn_target")
self.assertAlmostEqual(connections_from_recipe[2].weight, 1)
self.assertAlmostEqual(connections_from_recipe[2].delay, 0.2)

# connection #1 from neuron 1 to 3
self.assertEqual(connections_from_recipe[3].dest.label, "postsyn_target")
self.assertAlmostEqual(connections_from_recipe[3].weight, 1)
self.assertAlmostEqual(connections_from_recipe[3].delay, 0.6)

# connection #2 from neuron 1 to 3
self.assertEqual(connections_from_recipe[4].dest.label, "postsyn_target")
self.assertAlmostEqual(connections_from_recipe[4].weight, 1)
self.assertAlmostEqual(connections_from_recipe[4].delay, 0.6)

# construct domain_decomposition and simulation object
dd = arb.partition_load_balance(art_spiker_recipe, context)
sim = arb.simulation(art_spiker_recipe, dd, context)
sim.record(arb.spike_recording.all)

# create schedule and handle to record the membrane potential of neuron 3
reg_sched = arb.regular_schedule(0, dt, runtime)
handle_mem = sim.sample((3, 0), reg_sched)

# run the simulation
sim.run(runtime, dt)

# evaluate the outcome
data_mem, _ = sim.samples(handle_mem)[0]
#print(data_mem[(data_mem[:, 0] >= 1.0), 1])
self.assertGreater(data_mem[(np.round(data_mem[:, 0], 2) == 1.04), 1], -10)
spike_times = sim.spikes()["time"]
spike_gids = sim.spikes()["source"]["gid"]
#print(list(zip(*[spike_times, spike_gids])))
self.assertGreater(sum(spike_gids == 3), 0)
self.assertEqual([2, 1, 0, 3], spike_gids.tolist())

# spike in neuron 3 shall occur at around 1.0 ms, when the added input from all connections will cause threshold crossing
self.assertAlmostEqual(spike_times[(spike_gids == 3)][0], 1.00, delta=0.04)

@fixtures.context
@fixtures.art_spiker_recipe
@fixtures.sum_weight_hh_spike
def test_uni_connection(self, context, art_spiker_recipe, sum_weight_hh_spike):
runtime = 2 # ms
dt = 0.01 # ms
weight = sum_weight_hh_spike # set connection strength to the net sum of the connections in test_multiple_connections() (to evoke an immediate spike)

# define new method 'connections_on()' and overwrite the original one in the 'art_spiker_recipe' object
def connections_on(self, gid):
# incoming to neurons 0--2
if gid < 3:
return []

# incoming to neuron 3
elif gid == 3:
source_label_0 = arb.cell_global_label(0, "spike_source") # referring to the "spike_source" label of neuron 0
target_label_uni = arb.cell_local_label("postsyn_target", arb.selection_policy.univalent) # referring to an only item in the "postsyn_target" label group of neuron 3
conn_2_3 = arb.connection(source_label_0, target_label_uni, weight, 0.2) # connection from neuron 0 to 3

return [conn_2_3]
art_spiker_recipe.connections_on = types.MethodType(connections_on, art_spiker_recipe)

# define new method 'cell_description()' and overwrite the original one in the 'art_spiker_recipe' object
def cell_description(self, gid):
# spike source neuron
if gid < 3:
return arb.spike_source_cell("spike_source", arb.explicit_schedule(self.trains[gid]))

# spike-receiving cable neuron
elif gid == 3:
tree, labels, decor = self._cable_cell_elements()
syn_mechanism = arb.mechanism("expsyn_curr")
syn_mechanism.set('w', weight) # set weight for excitation
syn_mechanism.set("tau", dt) # set minimal decay time
decor.place('"midpoint"', arb.synapse(syn_mechanism), "postsyn_target") # place synaptic input for one neuron at the center of the soma
return arb.cable_cell(tree, labels, decor)

art_spiker_recipe.cell_description = types.MethodType(cell_description, art_spiker_recipe)

# read connections from recipe for testing
connections_from_recipe = art_spiker_recipe.connections_on(3)

# connection from neuron 0 to 3
self.assertEqual(connections_from_recipe[0].dest.label, "postsyn_target")
self.assertAlmostEqual(connections_from_recipe[0].weight, weight)
self.assertAlmostEqual(connections_from_recipe[0].delay, 0.2)

# construct domain_decomposition and simulation object
dd = arb.partition_load_balance(art_spiker_recipe, context)
sim = arb.simulation(art_spiker_recipe, dd, context)
sim.record(arb.spike_recording.all)

# create schedule and handle to record the membrane potential of neuron 3
reg_sched = arb.regular_schedule(0, dt, runtime)
handle_mem = sim.sample((3, 0), reg_sched)

# run the simulation
sim.run(runtime, dt)

# evaluate the outcome
data_mem, _ = sim.samples(handle_mem)[0]
#print(data_mem[(data_mem[:, 0] >= 1.0), 1])
self.assertGreater(data_mem[(np.round(data_mem[:, 0], 2) == 1.04), 1], -10)
spike_times = sim.spikes()["time"]
spike_gids = sim.spikes()["source"]["gid"]
#print(list(zip(*[spike_times, spike_gids])))
self.assertGreater(sum(spike_gids == 3), 0)
self.assertEqual([2, 1, 0, 3], spike_gids.tolist())

# spike in neuron 3 shall occur at around 1.0 ms, when the input will cause threshold crossing
self.assertAlmostEqual(spike_times[(spike_gids == 3)][0], 1.00, delta=0.04)

0 comments on commit 4483daf

Please sign in to comment.