diff --git a/models/synapses/clopath_synapse.nestml b/models/synapses/clopath_synapse.nestml new file mode 100644 index 000000000..8efb3b594 --- /dev/null +++ b/models/synapses/clopath_synapse.nestml @@ -0,0 +1,54 @@ +""" +""" +synapse clopath_synapse: + state: + w real = 1. @nest::weight # Synaptic weight + pre_trace real = 0. + post_membrane_potential_avg_plus mV = -70 mV + post_membrane_potential_avg_minus mV = -70 mV + post_membrane_avg_avg mV = -70 mV + + parameters: + d ms = 1 ms @nest::delay # Synaptic transmission delay + w_max real = 100 + tau_post_membrane_avg_plus ms = 7 ms + tau_post_membrane_avg_minus ms = 10 ms + tau_post_membrane_avg_avg ms = 500 ms + tau_pre_tr ms = 15 ms + theta_minus mV = -70.6 mV + theta_plus mV = -45.3 mV # should be greater than theta_minus + A_LTD real = 14.0e-5 + A_LTP real = 8.0e-5 + + equations: + pre_trace' = -pre_trace / tau_pre_tr + post_membrane_potential_avg_plus' = (-post_membrane_potential_avg_plus + post_membrane_potential) / tau_post_membrane_avg_plus + post_membrane_potential_avg_minus' = (-post_membrane_potential_avg_minus + post_membrane_potential) / tau_post_membrane_avg_minus + post_membrane_avg_avg' = (-post_membrane_avg_avg + post_membrane_potential_avg_minus) / tau_post_membrane_avg_avg + + input: + pre_spikes real <- spike + post_spikes real <- spike + post_membrane_potential mV <- continuous + + output: + spike + + onReceive(post_spikes): + if post_membrane_potential > theta_plus and post_membrane_potential_avg_plus > theta_minus: + # potentiate synapse + # w += A_LTP * pre_trace * (post_membrane_potential - theta_plus) * (post_membrane_potential_avg_plus(t - membrane_potential_delay) - theta_minus) + w += A_LTP * pre_trace * (post_membrane_potential - theta_plus) * (post_membrane_potential_avg_plus - theta_minus) + w = min(w, w_max) + + onReceive(pre_spikes): + pre_trace += 1 / tau_pre_tr + + if post_membrane_potential_avg_minus > theta_minus: + # depress synapse + #w -= A_LTD * (post_membrane_potential_avg_minus(t - membrane_potential_delay) - theta_minus) + w -= A_LTD * (post_membrane_potential_avg_minus - theta_minus) + w = max(w, 0) + + # deliver spike to postsynaptic partner + deliver_spike(w, d) diff --git a/tests/nest_tests/test_clopath_synapse.py b/tests/nest_tests/test_clopath_synapse.py new file mode 100644 index 000000000..100bfc527 --- /dev/null +++ b/tests/nest_tests/test_clopath_synapse.py @@ -0,0 +1,327 @@ +# -*- coding: utf-8 -*- +# +# clopath_synapse_test.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +import numpy as np +import os +import pytest + +import nest + +from pynestml.codegeneration.nest_tools import NESTTools +from pynestml.frontend.pynestml_frontend import generate_nest_target + +try: + import matplotlib + matplotlib.use("Agg") + import matplotlib.ticker + import matplotlib.pyplot as plt + TEST_PLOTS = True +except Exception: + TEST_PLOTS = False + +sim_mdl = True +sim_ref = True + + +class TestClopathSynapse: + + neuron_model_name = "iaf_psc_exp_nestml__with_clopath_nestml" + ref_neuron_model_name = "iaf_psc_exp_nestml_non_jit" + + synapse_model_name = "clopath_nestml__with_iaf_psc_exp_nestml" + ref_synapse_model_name = "clopath_synapse" + + @pytest.fixture(scope="module", autouse=True) + def setUp(self): + """Generate the model code""" + + jit_codegen_opts = {"neuron_synapse_pairs": [{"neuron": "iaf_psc_exp", + "synapse": "clopath_synapse", + "post_ports": ["post_spikes", + "post_membrane_potential", "V_m"]}]} + + files = [os.path.join("models", "neurons", "iaf_psc_exp.nestml"), + os.path.join("models", "synapses", "clopath_synapse.nestml")] + input_path = [os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join( + os.pardir, os.pardir, s))) for s in files] + generate_nest_target(input_path=input_path, + logging_level="DEBUG", + suffix="_nestml", + codegen_opts=jit_codegen_opts) + + def test_nest_clopath_synapse(self): + fname_snip = "" + + pre_spike_times = [1., 11., 21.] # [ms] + post_spike_times = [6., 16., 26.] # [ms] + + post_spike_times = np.sort(np.unique(1 + np.round(10 * np.sort(np.abs(np.random.randn(10)))))) # [ms] + pre_spike_times = np.sort(np.unique(1 + np.round(10 * np.sort(np.abs(np.random.randn(10)))))) # [ms] + + post_spike_times = np.sort(np.unique(1 + np.round(100 * np.sort(np.abs(np.random.randn(100)))))) # [ms] + pre_spike_times = np.sort(np.unique(1 + np.round(100 * np.sort(np.abs(np.random.randn(100)))))) # [ms] + + pre_spike_times = np.array([2., 4., 7., 8., 12., 13., 19., 23., 24., 28., 29., 30., 33., 34., + 35., 36., 38., 40., 42., 46., 51., 53., 54., 55., 56., 59., 63., 64., + 65., 66., 68., 72., 73., 76., 79., 80., 83., 84., 86., 87., 90., 95., + 99., 100., 103., 104., 105., 111., 112., 126., 131., 133., 134., 139., 147., 150., + 152., 155., 172., 175., 176., 181., 196., 197., 199., 202., 213., 215., 217., 265.]) + post_spike_times = np.array([4., 5., 6., 7., 10., 11., 12., 16., 17., 18., 19., 20., 22., 23., + 25., 27., 29., 30., 31., 32., 34., 36., 37., 38., 39., 42., 44., 46., + 48., 49., 50., 54., 56., 57., 59., 60., 61., 62., 67., 74., 76., 79., + 80., 81., 83., 88., 93., 94., 97., 99., 100., 105., 111., 113., 114., 115., + 116., 119., 123., 130., 132., 134., 135., 145., 152., 155., 158., 166., 172., 174., + 188., 194., 202., 245., 249., 289., 454.]) + + self.run_synapse_test(neuron_model_name=self.neuron_model_name, + ref_neuron_model_name=self.ref_neuron_model_name, + synapse_model_name=self.synapse_model_name, + ref_synapse_model_name=self.ref_synapse_model_name, + resolution=.5, # [ms] + delay=1.5, # [ms] + pre_spike_times=pre_spike_times, + post_spike_times=post_spike_times, + fname_snip=fname_snip) + + def run_synapse_test(self, neuron_model_name, + ref_neuron_model_name, + synapse_model_name, + ref_synapse_model_name, + resolution=1., # [ms] + delay=1., # [ms] + sim_time=None, # if None, computed from pre and post spike times + pre_spike_times=None, + post_spike_times=None, + fname_snip=""): + + if pre_spike_times is None: + pre_spike_times = [] + + if post_spike_times is None: + post_spike_times = [] + + if sim_time is None: + sim_time = max(np.amax(pre_spike_times), np.amax(post_spike_times)) + 5 * delay + + nest.set_verbosity("M_ALL") + nest.ResetKernel() + nest.Install("nestml_jit_module") + nest.Install("nestml_non_jit_module") + + print("Pre spike times: " + str(pre_spike_times)) + print("Post spike times: " + str(post_spike_times)) + + # nest.set_verbosity("M_WARNING") + nest.set_verbosity("M_ERROR") + + post_weights = {"parrot": []} + + nest.ResetKernel() + nest.SetKernelStatus({"resolution": resolution}) + + wr = nest.Create("weight_recorder") + wr_ref = nest.Create("weight_recorder") + nest.CopyModel(synapse_model_name, "clopath_nestml_rec", + {"weight_recorder": wr[0], "w": 1., "d": 1., "receptor_type": 0}) + nest.CopyModel(ref_synapse_model_name, "clopath_ref_rec", + {"weight_recorder": wr_ref[0], "weight": 1., "delay": 1., "receptor_type": 0}) + + # create spike_generators with these times + pre_sg = nest.Create("spike_generator", + params={"spike_times": pre_spike_times}) + post_sg = nest.Create("spike_generator", + params={"spike_times": post_spike_times, + "allow_offgrid_times": True}) + + # create parrot neurons and connect spike_generators + if sim_mdl: + pre_neuron = nest.Create("parrot_neuron") + post_neuron = nest.Create(neuron_model_name) + + if sim_ref: + pre_neuron_ref = nest.Create("parrot_neuron") + post_neuron_ref = nest.Create(ref_neuron_model_name) + + if sim_mdl: + if NESTTools.detect_nest_version().startswith("v2"): + spikedet_pre = nest.Create("spike_detector") + spikedet_post = nest.Create("spike_detector") + else: + spikedet_pre = nest.Create("spike_recorder") + spikedet_post = nest.Create("spike_recorder") + mm = nest.Create("multimeter", params={"record_from": [ + "V_m", "post_trace__for_clopath_nestml"]}) + if sim_ref: + if NESTTools.detect_nest_version().startswith("v2"): + spikedet_pre_ref = nest.Create("spike_detector") + spikedet_post_ref = nest.Create("spike_detector") + else: + spikedet_pre_ref = nest.Create("spike_recorder") + spikedet_post_ref = nest.Create("spike_recorder") + mm_ref = nest.Create("multimeter", params={"record_from": ["V_m"]}) + + if sim_mdl: + nest.Connect(pre_sg, pre_neuron, "one_to_one", syn_spec={"delay": 1.}) + nest.Connect(post_sg, post_neuron, "one_to_one", syn_spec={"delay": 1., "weight": 9999.}) + if NESTTools.detect_nest_version().startswith("v2"): + nest.Connect(pre_neuron, post_neuron, "all_to_all", syn_spec={"model": "clopath_nestml_rec"}) + else: + nest.Connect(pre_neuron, post_neuron, "all_to_all", syn_spec={"synapse_model": "clopath_nestml_rec"}) + nest.Connect(mm, post_neuron) + nest.Connect(pre_neuron, spikedet_pre) + nest.Connect(post_neuron, spikedet_post) + if sim_ref: + nest.Connect(pre_sg, pre_neuron_ref, "one_to_one", syn_spec={"delay": 1.}) + nest.Connect(post_sg, post_neuron_ref, "one_to_one", syn_spec={"delay": 1., "weight": 9999.}) + if NESTTools.detect_nest_version().startswith("v2"): + nest.Connect(pre_neuron_ref, post_neuron_ref, "all_to_all", + syn_spec={"model": ref_synapse_model_name}) + else: + nest.Connect(pre_neuron_ref, post_neuron_ref, "all_to_all", + syn_spec={"synapse_model": ref_synapse_model_name}) + nest.Connect(mm_ref, post_neuron_ref) + nest.Connect(pre_neuron_ref, spikedet_pre_ref) + nest.Connect(post_neuron_ref, spikedet_post_ref) + + # get Clopath synapse and weight before protocol + if sim_mdl: + syn = nest.GetConnections(source=pre_neuron, synapse_model="clopath_nestml_rec") + if sim_ref: + syn_ref = nest.GetConnections(source=pre_neuron_ref, synapse_model=ref_synapse_model_name) + + n_steps = int(np.ceil(sim_time / resolution)) + 1 + t = 0. + t_hist = [] + if sim_mdl: + w_hist = [] + if sim_ref: + w_hist_ref = [] + while t <= sim_time: + nest.Simulate(resolution) + t += resolution + t_hist.append(t) + if sim_ref: + w_hist_ref.append(nest.GetStatus(syn_ref)[0]["weight"]) + if sim_mdl: + w_hist.append(nest.GetStatus(syn)[0]["w"]) + + # plot + if TEST_PLOTS: + fig, ax = plt.subplots(nrows=2) + ax1, ax2 = ax + + if sim_mdl: + timevec = nest.GetStatus(mm, "events")[0]["times"] + V_m = nest.GetStatus(mm, "events")[0]["V_m"] + ax2.plot(timevec, nest.GetStatus(mm, "events")[0]["post_trace__for_clopath_nestml"], label="post_tr nestml") + ax1.plot(timevec, V_m, label="nestml", alpha=.7, linestyle=":") + if sim_ref: + pre_ref_spike_times_ = nest.GetStatus(spikedet_pre_ref, "events")[0]["times"] + timevec = nest.GetStatus(mm_ref, "events")[0]["times"] + V_m = nest.GetStatus(mm_ref, "events")[0]["V_m"] + ax1.plot(timevec, V_m, label="nest ref", alpha=.7) + ax1.set_ylabel("V_m") + + for _ax in ax: + _ax.grid(which="major", axis="both") + _ax.grid(which="minor", axis="x", linestyle=":", alpha=.4) + # _ax.minorticks_on() + _ax.set_xlim(0., sim_time) + _ax.legend() + fig.savefig("/tmp/clopath_synapse_test" + fname_snip + "_V_m.png", dpi=300) + + # plot + if TEST_PLOTS: + fig, ax = plt.subplots(nrows=3) + ax1, ax2, ax3 = ax + + if sim_mdl: + pre_spike_times_ = nest.GetStatus(spikedet_pre, "events")[0]["times"] + print("Actual pre spike times: " + str(pre_spike_times_)) + if sim_ref: + pre_ref_spike_times_ = nest.GetStatus(spikedet_pre_ref, "events")[0]["times"] + print("Actual pre ref spike times: " + str(pre_ref_spike_times_)) + + if sim_mdl: + n_spikes = len(pre_spike_times_) + for i in range(n_spikes): + if i == 0: + _lbl = "nestml" + else: + _lbl = None + ax1.plot(2 * [pre_spike_times_[i] + delay], [0, 1], linewidth=2, color="blue", alpha=.4, label=_lbl) + + if sim_mdl: + post_spike_times_ = nest.GetStatus(spikedet_post, "events")[0]["times"] + print("Actual post spike times: " + str(post_spike_times_)) + if sim_ref: + post_ref_spike_times_ = nest.GetStatus(spikedet_post_ref, "events")[0]["times"] + print("Actual post ref spike times: " + str(post_ref_spike_times_)) + + if sim_ref: + n_spikes = len(pre_ref_spike_times_) + for i in range(n_spikes): + if i == 0: + _lbl = "nest ref" + else: + _lbl = None + ax1.plot(2 * [pre_ref_spike_times_[i] + delay], [0, 1], + linewidth=2, color="cyan", label=_lbl, alpha=.4) + ax1.set_ylabel("Pre spikes") + + if sim_mdl: + n_spikes = len(post_spike_times_) + for i in range(n_spikes): + if i == 0: + _lbl = "nestml" + else: + _lbl = None + ax2.plot(2 * [post_spike_times_[i]], [0, 1], linewidth=2, color="black", alpha=.4, label=_lbl) + if sim_ref: + n_spikes = len(post_ref_spike_times_) + for i in range(n_spikes): + if i == 0: + _lbl = "nest ref" + else: + _lbl = None + ax2.plot(2 * [post_ref_spike_times_[i]], [0, 1], linewidth=2, color="red", alpha=.4, label=_lbl) + ax2.plot(timevec, nest.GetStatus(mm, "events")[0]["post_trace__for_clopath_nestml"], label="nestml post tr") + ax2.set_ylabel("Post spikes") + + if sim_mdl: + ax3.plot(t_hist, w_hist, marker="o", label="nestml") + if sim_ref: + ax3.plot(t_hist, w_hist_ref, linestyle="--", marker="x", label="ref") + + ax3.set_xlabel("Time [ms]") + ax3.set_ylabel("w") + for _ax in ax: + _ax.grid(which="major", axis="both") + _ax.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.arange(0, np.ceil(sim_time)))) + _ax.set_xlim(0., sim_time) + _ax.legend() + fig.savefig("/tmp/clopath_synapse_test" + fname_snip + ".png", dpi=300) + + # verify + MAX_ABS_ERROR = 1E-6 + assert np.any(np.abs(np.array(w_hist) - 1) > MAX_ABS_ERROR), "No change in the weight!" + assert np.all(np.abs(np.array(w_hist) - np.array(w_hist_ref)) < MAX_ABS_ERROR), \ + "Difference between NESTML model and reference model!"