Skip to content

Commit

Permalink
add Clopath synapse
Browse files Browse the repository at this point in the history
  • Loading branch information
C.A.P. Linssen committed Jul 5, 2023
1 parent ff24afe commit a3b11bf
Show file tree
Hide file tree
Showing 2 changed files with 381 additions and 0 deletions.
54 changes: 54 additions & 0 deletions models/synapses/clopath_synapse.nestml
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""
"""
synapse clopath_synapse:
state:
w real = 1. @nest::weight # Synaptic weight
pre_trace real = 0.
post_membrane_potential_avg_plus mV = -70 mV
post_membrane_potential_avg_minus mV = -70 mV
post_membrane_avg_avg mV = -70 mV

parameters:
d ms = 1 ms @nest::delay # Synaptic transmission delay
w_max real = 100
tau_post_membrane_avg_plus ms = 7 ms
tau_post_membrane_avg_minus ms = 10 ms
tau_post_membrane_avg_avg ms = 500 ms
tau_pre_tr ms = 15 ms
theta_minus mV = -70.6 mV
theta_plus mV = -45.3 mV # should be greater than theta_minus
A_LTD real = 14.0e-5
A_LTP real = 8.0e-5

equations:
pre_trace' = -pre_trace / tau_pre_tr
post_membrane_potential_avg_plus' = (-post_membrane_potential_avg_plus + post_membrane_potential) / tau_post_membrane_avg_plus
post_membrane_potential_avg_minus' = (-post_membrane_potential_avg_minus + post_membrane_potential) / tau_post_membrane_avg_minus
post_membrane_avg_avg' = (-post_membrane_avg_avg + post_membrane_potential_avg_minus) / tau_post_membrane_avg_avg

input:
pre_spikes real <- spike
post_spikes real <- spike
post_membrane_potential mV <- continuous

output:
spike

onReceive(post_spikes):
if post_membrane_potential > theta_plus and post_membrane_potential_avg_plus > theta_minus:
# potentiate synapse
# w += A_LTP * pre_trace * (post_membrane_potential - theta_plus) * (post_membrane_potential_avg_plus(t - membrane_potential_delay) - theta_minus)
w += A_LTP * pre_trace * (post_membrane_potential - theta_plus) * (post_membrane_potential_avg_plus - theta_minus)
w = min(w, w_max)

onReceive(pre_spikes):
pre_trace += 1 / tau_pre_tr

if post_membrane_potential_avg_minus > theta_minus:
# depress synapse
#w -= A_LTD * (post_membrane_potential_avg_minus(t - membrane_potential_delay) - theta_minus)
w -= A_LTD * (post_membrane_potential_avg_minus - theta_minus)
w = max(w, 0)

# deliver spike to postsynaptic partner
deliver_spike(w, d)
327 changes: 327 additions & 0 deletions tests/nest_tests/test_clopath_synapse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,327 @@
# -*- coding: utf-8 -*-
#
# clopath_synapse_test.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

import numpy as np
import os
import pytest

import nest

from pynestml.codegeneration.nest_tools import NESTTools
from pynestml.frontend.pynestml_frontend import generate_nest_target

try:
import matplotlib
matplotlib.use("Agg")
import matplotlib.ticker
import matplotlib.pyplot as plt
TEST_PLOTS = True
except Exception:
TEST_PLOTS = False

sim_mdl = True
sim_ref = True


class TestClopathSynapse:

neuron_model_name = "iaf_psc_exp_nestml__with_clopath_nestml"
ref_neuron_model_name = "iaf_psc_exp_nestml_non_jit"

synapse_model_name = "clopath_nestml__with_iaf_psc_exp_nestml"
ref_synapse_model_name = "clopath_synapse"

@pytest.fixture(scope="module", autouse=True)
def setUp(self):
"""Generate the model code"""

jit_codegen_opts = {"neuron_synapse_pairs": [{"neuron": "iaf_psc_exp",
"synapse": "clopath_synapse",
"post_ports": ["post_spikes",
"post_membrane_potential", "V_m"]}]}

files = [os.path.join("models", "neurons", "iaf_psc_exp.nestml"),
os.path.join("models", "synapses", "clopath_synapse.nestml")]
input_path = [os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join(
os.pardir, os.pardir, s))) for s in files]
generate_nest_target(input_path=input_path,
logging_level="DEBUG",
suffix="_nestml",
codegen_opts=jit_codegen_opts)

def test_nest_clopath_synapse(self):
fname_snip = ""

pre_spike_times = [1., 11., 21.] # [ms]
post_spike_times = [6., 16., 26.] # [ms]

post_spike_times = np.sort(np.unique(1 + np.round(10 * np.sort(np.abs(np.random.randn(10)))))) # [ms]
pre_spike_times = np.sort(np.unique(1 + np.round(10 * np.sort(np.abs(np.random.randn(10)))))) # [ms]

post_spike_times = np.sort(np.unique(1 + np.round(100 * np.sort(np.abs(np.random.randn(100)))))) # [ms]
pre_spike_times = np.sort(np.unique(1 + np.round(100 * np.sort(np.abs(np.random.randn(100)))))) # [ms]

pre_spike_times = np.array([2., 4., 7., 8., 12., 13., 19., 23., 24., 28., 29., 30., 33., 34.,
35., 36., 38., 40., 42., 46., 51., 53., 54., 55., 56., 59., 63., 64.,
65., 66., 68., 72., 73., 76., 79., 80., 83., 84., 86., 87., 90., 95.,
99., 100., 103., 104., 105., 111., 112., 126., 131., 133., 134., 139., 147., 150.,
152., 155., 172., 175., 176., 181., 196., 197., 199., 202., 213., 215., 217., 265.])
post_spike_times = np.array([4., 5., 6., 7., 10., 11., 12., 16., 17., 18., 19., 20., 22., 23.,
25., 27., 29., 30., 31., 32., 34., 36., 37., 38., 39., 42., 44., 46.,
48., 49., 50., 54., 56., 57., 59., 60., 61., 62., 67., 74., 76., 79.,
80., 81., 83., 88., 93., 94., 97., 99., 100., 105., 111., 113., 114., 115.,
116., 119., 123., 130., 132., 134., 135., 145., 152., 155., 158., 166., 172., 174.,
188., 194., 202., 245., 249., 289., 454.])

self.run_synapse_test(neuron_model_name=self.neuron_model_name,
ref_neuron_model_name=self.ref_neuron_model_name,
synapse_model_name=self.synapse_model_name,
ref_synapse_model_name=self.ref_synapse_model_name,
resolution=.5, # [ms]
delay=1.5, # [ms]
pre_spike_times=pre_spike_times,
post_spike_times=post_spike_times,
fname_snip=fname_snip)

def run_synapse_test(self, neuron_model_name,
ref_neuron_model_name,
synapse_model_name,
ref_synapse_model_name,
resolution=1., # [ms]
delay=1., # [ms]
sim_time=None, # if None, computed from pre and post spike times
pre_spike_times=None,
post_spike_times=None,
fname_snip=""):

if pre_spike_times is None:
pre_spike_times = []

if post_spike_times is None:
post_spike_times = []

if sim_time is None:
sim_time = max(np.amax(pre_spike_times), np.amax(post_spike_times)) + 5 * delay

nest.set_verbosity("M_ALL")
nest.ResetKernel()
nest.Install("nestml_jit_module")
nest.Install("nestml_non_jit_module")

print("Pre spike times: " + str(pre_spike_times))
print("Post spike times: " + str(post_spike_times))

# nest.set_verbosity("M_WARNING")
nest.set_verbosity("M_ERROR")

post_weights = {"parrot": []}

nest.ResetKernel()
nest.SetKernelStatus({"resolution": resolution})

wr = nest.Create("weight_recorder")
wr_ref = nest.Create("weight_recorder")
nest.CopyModel(synapse_model_name, "clopath_nestml_rec",
{"weight_recorder": wr[0], "w": 1., "d": 1., "receptor_type": 0})
nest.CopyModel(ref_synapse_model_name, "clopath_ref_rec",
{"weight_recorder": wr_ref[0], "weight": 1., "delay": 1., "receptor_type": 0})

# create spike_generators with these times
pre_sg = nest.Create("spike_generator",
params={"spike_times": pre_spike_times})
post_sg = nest.Create("spike_generator",
params={"spike_times": post_spike_times,
"allow_offgrid_times": True})

# create parrot neurons and connect spike_generators
if sim_mdl:
pre_neuron = nest.Create("parrot_neuron")
post_neuron = nest.Create(neuron_model_name)

if sim_ref:
pre_neuron_ref = nest.Create("parrot_neuron")
post_neuron_ref = nest.Create(ref_neuron_model_name)

if sim_mdl:
if NESTTools.detect_nest_version().startswith("v2"):
spikedet_pre = nest.Create("spike_detector")
spikedet_post = nest.Create("spike_detector")
else:
spikedet_pre = nest.Create("spike_recorder")
spikedet_post = nest.Create("spike_recorder")
mm = nest.Create("multimeter", params={"record_from": [
"V_m", "post_trace__for_clopath_nestml"]})
if sim_ref:
if NESTTools.detect_nest_version().startswith("v2"):
spikedet_pre_ref = nest.Create("spike_detector")
spikedet_post_ref = nest.Create("spike_detector")
else:
spikedet_pre_ref = nest.Create("spike_recorder")
spikedet_post_ref = nest.Create("spike_recorder")
mm_ref = nest.Create("multimeter", params={"record_from": ["V_m"]})

if sim_mdl:
nest.Connect(pre_sg, pre_neuron, "one_to_one", syn_spec={"delay": 1.})
nest.Connect(post_sg, post_neuron, "one_to_one", syn_spec={"delay": 1., "weight": 9999.})
if NESTTools.detect_nest_version().startswith("v2"):
nest.Connect(pre_neuron, post_neuron, "all_to_all", syn_spec={"model": "clopath_nestml_rec"})
else:
nest.Connect(pre_neuron, post_neuron, "all_to_all", syn_spec={"synapse_model": "clopath_nestml_rec"})
nest.Connect(mm, post_neuron)
nest.Connect(pre_neuron, spikedet_pre)
nest.Connect(post_neuron, spikedet_post)
if sim_ref:
nest.Connect(pre_sg, pre_neuron_ref, "one_to_one", syn_spec={"delay": 1.})
nest.Connect(post_sg, post_neuron_ref, "one_to_one", syn_spec={"delay": 1., "weight": 9999.})
if NESTTools.detect_nest_version().startswith("v2"):
nest.Connect(pre_neuron_ref, post_neuron_ref, "all_to_all",
syn_spec={"model": ref_synapse_model_name})
else:
nest.Connect(pre_neuron_ref, post_neuron_ref, "all_to_all",
syn_spec={"synapse_model": ref_synapse_model_name})
nest.Connect(mm_ref, post_neuron_ref)
nest.Connect(pre_neuron_ref, spikedet_pre_ref)
nest.Connect(post_neuron_ref, spikedet_post_ref)

# get Clopath synapse and weight before protocol
if sim_mdl:
syn = nest.GetConnections(source=pre_neuron, synapse_model="clopath_nestml_rec")
if sim_ref:
syn_ref = nest.GetConnections(source=pre_neuron_ref, synapse_model=ref_synapse_model_name)

n_steps = int(np.ceil(sim_time / resolution)) + 1
t = 0.
t_hist = []
if sim_mdl:
w_hist = []
if sim_ref:
w_hist_ref = []
while t <= sim_time:
nest.Simulate(resolution)
t += resolution
t_hist.append(t)
if sim_ref:
w_hist_ref.append(nest.GetStatus(syn_ref)[0]["weight"])
if sim_mdl:
w_hist.append(nest.GetStatus(syn)[0]["w"])

# plot
if TEST_PLOTS:
fig, ax = plt.subplots(nrows=2)
ax1, ax2 = ax

if sim_mdl:
timevec = nest.GetStatus(mm, "events")[0]["times"]
V_m = nest.GetStatus(mm, "events")[0]["V_m"]
ax2.plot(timevec, nest.GetStatus(mm, "events")[0]["post_trace__for_clopath_nestml"], label="post_tr nestml")
ax1.plot(timevec, V_m, label="nestml", alpha=.7, linestyle=":")
if sim_ref:
pre_ref_spike_times_ = nest.GetStatus(spikedet_pre_ref, "events")[0]["times"]
timevec = nest.GetStatus(mm_ref, "events")[0]["times"]
V_m = nest.GetStatus(mm_ref, "events")[0]["V_m"]
ax1.plot(timevec, V_m, label="nest ref", alpha=.7)
ax1.set_ylabel("V_m")

for _ax in ax:
_ax.grid(which="major", axis="both")
_ax.grid(which="minor", axis="x", linestyle=":", alpha=.4)
# _ax.minorticks_on()
_ax.set_xlim(0., sim_time)
_ax.legend()
fig.savefig("/tmp/clopath_synapse_test" + fname_snip + "_V_m.png", dpi=300)

# plot
if TEST_PLOTS:
fig, ax = plt.subplots(nrows=3)
ax1, ax2, ax3 = ax

if sim_mdl:
pre_spike_times_ = nest.GetStatus(spikedet_pre, "events")[0]["times"]
print("Actual pre spike times: " + str(pre_spike_times_))
if sim_ref:
pre_ref_spike_times_ = nest.GetStatus(spikedet_pre_ref, "events")[0]["times"]
print("Actual pre ref spike times: " + str(pre_ref_spike_times_))

if sim_mdl:
n_spikes = len(pre_spike_times_)
for i in range(n_spikes):
if i == 0:
_lbl = "nestml"
else:
_lbl = None
ax1.plot(2 * [pre_spike_times_[i] + delay], [0, 1], linewidth=2, color="blue", alpha=.4, label=_lbl)

if sim_mdl:
post_spike_times_ = nest.GetStatus(spikedet_post, "events")[0]["times"]
print("Actual post spike times: " + str(post_spike_times_))
if sim_ref:
post_ref_spike_times_ = nest.GetStatus(spikedet_post_ref, "events")[0]["times"]
print("Actual post ref spike times: " + str(post_ref_spike_times_))

if sim_ref:
n_spikes = len(pre_ref_spike_times_)
for i in range(n_spikes):
if i == 0:
_lbl = "nest ref"
else:
_lbl = None
ax1.plot(2 * [pre_ref_spike_times_[i] + delay], [0, 1],
linewidth=2, color="cyan", label=_lbl, alpha=.4)
ax1.set_ylabel("Pre spikes")

if sim_mdl:
n_spikes = len(post_spike_times_)
for i in range(n_spikes):
if i == 0:
_lbl = "nestml"
else:
_lbl = None
ax2.plot(2 * [post_spike_times_[i]], [0, 1], linewidth=2, color="black", alpha=.4, label=_lbl)
if sim_ref:
n_spikes = len(post_ref_spike_times_)
for i in range(n_spikes):
if i == 0:
_lbl = "nest ref"
else:
_lbl = None
ax2.plot(2 * [post_ref_spike_times_[i]], [0, 1], linewidth=2, color="red", alpha=.4, label=_lbl)
ax2.plot(timevec, nest.GetStatus(mm, "events")[0]["post_trace__for_clopath_nestml"], label="nestml post tr")
ax2.set_ylabel("Post spikes")

if sim_mdl:
ax3.plot(t_hist, w_hist, marker="o", label="nestml")
if sim_ref:
ax3.plot(t_hist, w_hist_ref, linestyle="--", marker="x", label="ref")

ax3.set_xlabel("Time [ms]")
ax3.set_ylabel("w")
for _ax in ax:
_ax.grid(which="major", axis="both")
_ax.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.arange(0, np.ceil(sim_time))))
_ax.set_xlim(0., sim_time)
_ax.legend()
fig.savefig("/tmp/clopath_synapse_test" + fname_snip + ".png", dpi=300)

# verify
MAX_ABS_ERROR = 1E-6
assert np.any(np.abs(np.array(w_hist) - 1) > MAX_ABS_ERROR), "No change in the weight!"
assert np.all(np.abs(np.array(w_hist) - np.array(w_hist_ref)) < MAX_ABS_ERROR), \
"Difference between NESTML model and reference model!"

0 comments on commit a3b11bf

Please sign in to comment.