From 4482756b68c79d0001984e1c9190c87140b562a1 Mon Sep 17 00:00:00 2001 From: Brandon Zhang Date: Mon, 26 Jun 2023 21:56:30 +0800 Subject: [PATCH] [dyn] add reduce models, HH-type models and channels --- brainpy/_src/__init__.py | 1 - brainpy/_src/channels/__init__.py | 25 - brainpy/_src/dyn/_docs.py | 7 + brainpy/_src/dyn/base.py | 10 +- brainpy/_src/{ => dyn}/channels/Ca.py | 0 brainpy/_src/{ => dyn}/channels/IH.py | 0 brainpy/_src/{ => dyn}/channels/K.py | 0 brainpy/_src/{ => dyn}/channels/KCa.py | 0 brainpy/_src/{ => dyn}/channels/Na.py | 0 brainpy/_src/dyn/channels/__init__.py | 25 + brainpy/_src/{ => dyn}/channels/base.py | 0 brainpy/_src/{ => dyn}/channels/leaky.py | 0 .../_src/{ => dyn}/channels/tests/test_Ca.py | 2 +- .../_src/{ => dyn}/channels/tests/test_IH.py | 2 +- .../_src/{ => dyn}/channels/tests/test_K.py | 2 +- .../_src/{ => dyn}/channels/tests/test_KCa.py | 2 +- .../_src/{ => dyn}/channels/tests/test_Na.py | 2 +- .../{ => dyn}/channels/tests/test_leaky.py | 2 +- brainpy/_src/dyn/neurons/hh.py | 757 +++++++ brainpy/_src/dyn/neurons/input.py | 210 ++ brainpy/_src/dyn/neurons/lif.py | 1803 ++++++++++++++++- brainpy/_src/dyn/neurons/tests/test_hh.py | 140 ++ brainpy/_src/dyn/neurons/tests/test_input.py | 24 + brainpy/_src/dyn/neurons/tests/test_lif.py | 41 + brainpy/channels.py | 14 +- 25 files changed, 3011 insertions(+), 58 deletions(-) delete mode 100644 brainpy/_src/channels/__init__.py rename brainpy/_src/{ => dyn}/channels/Ca.py (100%) rename brainpy/_src/{ => dyn}/channels/IH.py (100%) rename brainpy/_src/{ => dyn}/channels/K.py (100%) rename brainpy/_src/{ => dyn}/channels/KCa.py (100%) rename brainpy/_src/{ => dyn}/channels/Na.py (100%) rename brainpy/_src/{ => dyn}/channels/base.py (100%) rename brainpy/_src/{ => dyn}/channels/leaky.py (100%) rename brainpy/_src/{ => dyn}/channels/tests/test_Ca.py (99%) rename brainpy/_src/{ => dyn}/channels/tests/test_IH.py (94%) rename brainpy/_src/{ => dyn}/channels/tests/test_K.py (97%) rename brainpy/_src/{ => dyn}/channels/tests/test_KCa.py (93%) rename brainpy/_src/{ => dyn}/channels/tests/test_Na.py (96%) rename brainpy/_src/{ => dyn}/channels/tests/test_leaky.py (93%) create mode 100644 brainpy/_src/dyn/neurons/input.py create mode 100644 brainpy/_src/dyn/neurons/tests/test_hh.py create mode 100644 brainpy/_src/dyn/neurons/tests/test_input.py create mode 100644 brainpy/_src/dyn/neurons/tests/test_lif.py diff --git a/brainpy/_src/__init__.py b/brainpy/_src/__init__.py index 40a96afc6..e69de29bb 100644 --- a/brainpy/_src/__init__.py +++ b/brainpy/_src/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/brainpy/_src/channels/__init__.py b/brainpy/_src/channels/__init__.py deleted file mode 100644 index 326e68b12..000000000 --- a/brainpy/_src/channels/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# -*- coding: utf-8 -*- - -""" - -Access through ``brainpy.channels``. -""" - -from . import base, Ca, IH, K, Na, KCa, leaky - -__all__ = [] -__all__ += base.__all__ -__all__ += K.__all__ -__all__ += Na.__all__ -__all__ += Ca.__all__ -__all__ += IH.__all__ -__all__ += KCa.__all__ -__all__ += leaky.__all__ - -from .base import * -from .K import * -from .Na import * -from .IH import * -from .Ca import * -from .KCa import * -from .leaky import * diff --git a/brainpy/_src/dyn/_docs.py b/brainpy/_src/dyn/_docs.py index 738300240..823be6787 100644 --- a/brainpy/_src/dyn/_docs.py +++ b/brainpy/_src/dyn/_docs.py @@ -18,6 +18,13 @@ has_ref_var: bool. Whether has the refractory variable. Default is ``False``. 
'''.strip() +if_doc = ''' + V_rest: float, ArrayType, callable. Resting membrane potential. + R: float, ArrayType, callable. Membrane resistance. + tau: float, ArrayType, callable. Membrane time constant. + V_initializer: ArrayType, callable. The initializer of membrane potential. +'''.strip() + lif_doc = ''' V_rest: float, ArrayType, callable. Resting membrane potential. V_reset: float, ArrayType, callable. Reset potential after spike. diff --git a/brainpy/_src/dyn/base.py b/brainpy/_src/dyn/base.py index 95c5dc269..74b8d19c9 100644 --- a/brainpy/_src/dyn/base.py +++ b/brainpy/_src/dyn/base.py @@ -28,6 +28,7 @@ def __init__( keep_size: bool = False, mode: bm.Mode = None, name: str = None, + method: str = 'exp_auto' ): super().__init__(size=size, mode=mode, @@ -37,6 +38,9 @@ def __init__( # axis names for parallelization self.sharding = sharding + # integration method + self.method = method + # the before- / after-updates used for computing self.before_updates: Dict[str, Callable] = bm.node_dict() self.after_updates: Dict[str, Callable] = bm.node_dict() @@ -109,21 +113,21 @@ def __init__( keep_size: bool = False, mode: Optional[bm.Mode] = None, name: Optional[str] = None, + method: str = 'exp_auto', spk_fun: Callable = bm.surrogate.InvSquareGrad(), spk_type: Any = None, detach_spk: bool = False, - method: str = 'exp_auto', ): super().__init__(size=size, mode=mode, keep_size=keep_size, name=name, - sharding=sharding) + sharding=sharding, + method=method) self.spk_fun = is_callable(spk_fun) self.detach_spk = detach_spk - self.method = method self._spk_type = spk_type @property diff --git a/brainpy/_src/channels/Ca.py b/brainpy/_src/dyn/channels/Ca.py similarity index 100% rename from brainpy/_src/channels/Ca.py rename to brainpy/_src/dyn/channels/Ca.py diff --git a/brainpy/_src/channels/IH.py b/brainpy/_src/dyn/channels/IH.py similarity index 100% rename from brainpy/_src/channels/IH.py rename to brainpy/_src/dyn/channels/IH.py diff --git a/brainpy/_src/channels/K.py b/brainpy/_src/dyn/channels/K.py similarity index 100% rename from brainpy/_src/channels/K.py rename to brainpy/_src/dyn/channels/K.py diff --git a/brainpy/_src/channels/KCa.py b/brainpy/_src/dyn/channels/KCa.py similarity index 100% rename from brainpy/_src/channels/KCa.py rename to brainpy/_src/dyn/channels/KCa.py diff --git a/brainpy/_src/channels/Na.py b/brainpy/_src/dyn/channels/Na.py similarity index 100% rename from brainpy/_src/channels/Na.py rename to brainpy/_src/dyn/channels/Na.py diff --git a/brainpy/_src/dyn/channels/__init__.py b/brainpy/_src/dyn/channels/__init__.py index e69de29bb..326e68b12 100644 --- a/brainpy/_src/dyn/channels/__init__.py +++ b/brainpy/_src/dyn/channels/__init__.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- + +""" + +Access through ``brainpy.channels``. +""" + +from . 
import base, Ca, IH, K, Na, KCa, leaky + +__all__ = [] +__all__ += base.__all__ +__all__ += K.__all__ +__all__ += Na.__all__ +__all__ += Ca.__all__ +__all__ += IH.__all__ +__all__ += KCa.__all__ +__all__ += leaky.__all__ + +from .base import * +from .K import * +from .Na import * +from .IH import * +from .Ca import * +from .KCa import * +from .leaky import * diff --git a/brainpy/_src/channels/base.py b/brainpy/_src/dyn/channels/base.py similarity index 100% rename from brainpy/_src/channels/base.py rename to brainpy/_src/dyn/channels/base.py diff --git a/brainpy/_src/channels/leaky.py b/brainpy/_src/dyn/channels/leaky.py similarity index 100% rename from brainpy/_src/channels/leaky.py rename to brainpy/_src/dyn/channels/leaky.py diff --git a/brainpy/_src/channels/tests/test_Ca.py b/brainpy/_src/dyn/channels/tests/test_Ca.py similarity index 99% rename from brainpy/_src/channels/tests/test_Ca.py rename to brainpy/_src/dyn/channels/tests/test_Ca.py index d58905385..3c08c9873 100644 --- a/brainpy/_src/channels/tests/test_Ca.py +++ b/brainpy/_src/dyn/channels/tests/test_Ca.py @@ -4,7 +4,7 @@ import brainpy as bp import brainpy.math as bm from absl.testing import parameterized -from brainpy._src.channels import Ca +from brainpy._src.dyn.channels import Ca class Test_Ca(parameterized.TestCase): def test_Ca(self): diff --git a/brainpy/_src/channels/tests/test_IH.py b/brainpy/_src/dyn/channels/tests/test_IH.py similarity index 94% rename from brainpy/_src/channels/tests/test_IH.py rename to brainpy/_src/dyn/channels/tests/test_IH.py index 4767622ff..f4e589a0d 100644 --- a/brainpy/_src/channels/tests/test_IH.py +++ b/brainpy/_src/dyn/channels/tests/test_IH.py @@ -4,7 +4,7 @@ import brainpy as bp import brainpy.math as bm from absl.testing import parameterized -from brainpy._src.channels import IH, Ca +from brainpy._src.dyn.channels import IH, Ca class Test_IH(parameterized.TestCase): diff --git a/brainpy/_src/channels/tests/test_K.py b/brainpy/_src/dyn/channels/tests/test_K.py similarity index 97% rename from brainpy/_src/channels/tests/test_K.py rename to brainpy/_src/dyn/channels/tests/test_K.py index dab67d359..1fc625b90 100644 --- a/brainpy/_src/channels/tests/test_K.py +++ b/brainpy/_src/dyn/channels/tests/test_K.py @@ -4,7 +4,7 @@ import brainpy as bp import brainpy.math as bm from absl.testing import parameterized -from brainpy._src.channels import K +from brainpy._src.dyn.channels import K class Test_K(parameterized.TestCase): bm.random.seed(1234) diff --git a/brainpy/_src/channels/tests/test_KCa.py b/brainpy/_src/dyn/channels/tests/test_KCa.py similarity index 93% rename from brainpy/_src/channels/tests/test_KCa.py rename to brainpy/_src/dyn/channels/tests/test_KCa.py index e907861bb..d422dc28a 100644 --- a/brainpy/_src/channels/tests/test_KCa.py +++ b/brainpy/_src/dyn/channels/tests/test_KCa.py @@ -4,7 +4,7 @@ import brainpy as bp import brainpy.math as bm from absl.testing import parameterized -from brainpy._src.channels import KCa, Ca +from brainpy._src.dyn.channels import KCa, Ca class Test_KCa(parameterized.TestCase): bm.random.seed(1234) diff --git a/brainpy/_src/channels/tests/test_Na.py b/brainpy/_src/dyn/channels/tests/test_Na.py similarity index 96% rename from brainpy/_src/channels/tests/test_Na.py rename to brainpy/_src/dyn/channels/tests/test_Na.py index c8a5f0f58..f2112162f 100644 --- a/brainpy/_src/channels/tests/test_Na.py +++ b/brainpy/_src/dyn/channels/tests/test_Na.py @@ -4,7 +4,7 @@ import brainpy as bp import brainpy.math as bm from absl.testing import parameterized 
-from brainpy._src.channels import Na
+from brainpy._src.dyn.channels import Na

 class Test_Na(parameterized.TestCase):
diff --git a/brainpy/_src/channels/tests/test_leaky.py b/brainpy/_src/dyn/channels/tests/test_leaky.py
similarity index 93%
rename from brainpy/_src/channels/tests/test_leaky.py
rename to brainpy/_src/dyn/channels/tests/test_leaky.py
index 43abd6fcf..341e7c213 100644
--- a/brainpy/_src/channels/tests/test_leaky.py
+++ b/brainpy/_src/dyn/channels/tests/test_leaky.py
@@ -4,7 +4,7 @@ import brainpy as bp
 import brainpy.math as bm
 from absl.testing import parameterized
-from brainpy._src.channels import leaky
+from brainpy._src.dyn.channels import leaky

 class Test_Leaky(parameterized.TestCase):
   bm.random.seed(1234)
diff --git a/brainpy/_src/dyn/neurons/hh.py b/brainpy/_src/dyn/neurons/hh.py
index e69de29bb..22a528136 100644
--- a/brainpy/_src/dyn/neurons/hh.py
+++ b/brainpy/_src/dyn/neurons/hh.py
@@ -0,0 +1,757 @@
+from functools import partial
+from typing import Union, Callable, Optional, Any, Sequence
+
+import brainpy.math as bm
+from brainpy._src.context import share
+from brainpy._src.initialize import ZeroInit, OneInit, Uniform
+from brainpy._src.integrators import odeint, JointEq
+from brainpy.check import is_initializer
+from brainpy.types import Shape, ArrayType, Sharding
+from brainpy._src.dyn.base import HHTypeNeuLTC
+
+
+__all__ = [
+  'HHLTC',
+  'HH',
+  'MorrisLecarLTC',
+  'MorrisLecar',
+  'WangBuzsakiModelLTC',
+  'WangBuzsakiModel'
+]
+
+
+class HHLTC(HHTypeNeuLTC):
+  r"""Hodgkin–Huxley neuron model.
+
+  **Model Descriptions**
+
+  The Hodgkin-Huxley (HH; Hodgkin & Huxley, 1952) model [1]_ for the generation of
+  the nerve action potential is one of the most successful mathematical models of
+  a complex biological process that has ever been formulated. The basic concepts
+  expressed in the model have proved a valid approach to the study of bio-electrical
+  activity from the most primitive single-celled organisms such as *Paramecium*,
+  right through to the neurons within our own brains.
+
+  Mathematically, the model is given by,
+
+  .. math::
+
+      C \frac {dV} {dt} &= -(\bar{g}_{Na} m^3 h (V - E_{Na})
+      + \bar{g}_K n^4 (V - E_K) + g_{leak} (V - E_{leak})) + I(t)
+
+      \frac {dx} {dt} &= \alpha_x (1-x) - \beta_x x, \quad x\in {\rm{\{m, h, n\}}}
+
+      &\alpha_m(V) = \frac {0.1(V+40)}{1-\exp(\frac{-(V + 40)} {10})}
+
+      &\beta_m(V) = 4.0 \exp(\frac{-(V + 65)} {18})
+
+      &\alpha_h(V) = 0.07 \exp(\frac{-(V+65)}{20})
+
+      &\beta_h(V) = \frac 1 {1 + \exp(\frac{-(V + 35)} {10})}
+
+      &\alpha_n(V) = \frac {0.01(V+55)}{1-\exp(-(V+55)/10)}
+
+      &\beta_n(V) = 0.125 \exp(\frac{-(V + 65)} {80})
+
+  For an illustrated example of the HH neuron model, please see `this notebook <../neurons/HH_model.ipynb>`_.
+
+  The Hodgkin–Huxley model can be thought of as a differential equation system with
+  four state variables, :math:`V_{m}(t), n(t), m(t)`, and :math:`h(t)`, that change
+  with respect to time :math:`t`. The system is difficult to study because it is a
+  nonlinear system and cannot be solved analytically. However, there are many numerical
+  methods available to analyze the system. Certain properties and general behaviors,
+  such as limit cycles, can be proven to exist.
+
+  *1. Center manifold*
+
+  Because there are four state variables, visualizing the path in phase space can
+  be difficult. Usually two variables are chosen, voltage :math:`V_{m}(t)` and the
+  potassium gating variable :math:`n(t)`, allowing one to visualize the limit cycle.
+ However, one must be careful because this is an ad-hoc method of visualizing the + 4-dimensional system. This does not prove the existence of the limit cycle. + + .. image:: ../../../_static/Hodgkin_Huxley_Limit_Cycle.png + :align: center + + A better projection can be constructed from a careful analysis of the Jacobian of + the system, evaluated at the equilibrium point. Specifically, the eigenvalues of + the Jacobian are indicative of the center manifold's existence. Likewise, the + eigenvectors of the Jacobian reveal the center manifold's orientation. The + Hodgkin–Huxley model has two negative eigenvalues and two complex eigenvalues + with slightly positive real parts. The eigenvectors associated with the two + negative eigenvalues will reduce to zero as time :math:`t` increases. The remaining + two complex eigenvectors define the center manifold. In other words, the + 4-dimensional system collapses onto a 2-dimensional plane. Any solution + starting off the center manifold will decay towards the *center manifold*. + Furthermore, the limit cycle is contained on the center manifold. + + *2. Bifurcations* + + If the injected current :math:`I` were used as a bifurcation parameter, then the + Hodgkin–Huxley model undergoes a Hopf bifurcation. As with most neuronal models, + increasing the injected current will increase the firing rate of the neuron. + One consequence of the Hopf bifurcation is that there is a minimum firing rate. + This means that either the neuron is not firing at all (corresponding to zero + frequency), or firing at the minimum firing rate. Because of the all-or-none + principle, there is no smooth increase in action potential amplitude, but + rather there is a sudden "jump" in amplitude. The resulting transition is + known as a `canard `_. + + .. image:: ../../../_static/Hodgkins_Huxley_bifurcation_by_I.gif + :align: center + + The following image shows the bifurcation diagram of the Hodgkin–Huxley model + as a function of the external drive :math:`I` [3]_. The green lines show the amplitude + of a stable limit cycle and the blue lines indicate unstable limit-cycle behaviour, + both born from Hopf bifurcations. The solid red line shows the stable fixed point + and the black line shows the unstable fixed point. + + .. image:: ../../../_static/Hodgkin_Huxley_bifurcation.png + :align: center + + **Model Examples** + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> group = bp.neurons.HH(2) + >>> runner = bp.DSRunner(group, monitors=['V'], inputs=('input', 10.)) + >>> runner.run(200.) + >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True) + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> + >>> group = bp.neurons.HH(2) + >>> + >>> I1 = bp.inputs.spike_input(sp_times=[500., 550., 1000, 1030, 1060, 1100, 1200], sp_lens=5, sp_sizes=5., duration=2000, ) + >>> I2 = bp.inputs.spike_input(sp_times=[600., 900, 950, 1500], sp_lens=5, sp_sizes=5., duration=2000, ) + >>> I1 += bp.math.random.normal(0, 3, size=I1.shape) + >>> I2 += bp.math.random.normal(0, 3, size=I2.shape) + >>> I = bm.stack((I1, I2), axis=-1) + >>> + >>> runner = bp.DSRunner(group, monitors=['V'], inputs=('input', I, 'iter')) + >>> runner.run(2000.) 
+  >>>
+  >>> fig, gs = bp.visualize.get_figure(1, 1, 3, 8)
+  >>> fig.add_subplot(gs[0, 0])
+  >>> plt.plot(runner.mon.ts, runner.mon.V[:, 0])
+  >>> plt.plot(runner.mon.ts, runner.mon.V[:, 1] + 130)
+  >>> plt.xlim(10, 2000)
+  >>> plt.xticks([])
+  >>> plt.yticks([])
+  >>> plt.show()
+
+  Parameters
+  ----------
+  size: sequence of int, int
+    The size of the neuron group.
+  ENa: float, ArrayType, Initializer, callable
+    The reversal potential of sodium. Default is 50 mV.
+  gNa: float, ArrayType, Initializer, callable
+    The maximum conductance of sodium channel. Default is 120 msiemens.
+  EK: float, ArrayType, Initializer, callable
+    The reversal potential of potassium. Default is -77 mV.
+  gK: float, ArrayType, Initializer, callable
+    The maximum conductance of potassium channel. Default is 36 msiemens.
+  EL: float, ArrayType, Initializer, callable
+    The reversal potential of the leaky channel. Default is -54.387 mV.
+  gL: float, ArrayType, Initializer, callable
+    The conductance of the leaky channel. Default is 0.03 msiemens.
+  V_th: float, ArrayType, Initializer, callable
+    The threshold of the membrane spike. Default is 20 mV.
+  C: float, ArrayType, Initializer, callable
+    The membrane capacitance. Default is 1 ufarad.
+  V_initializer: ArrayType, Initializer, callable
+    The initializer of membrane potential.
+  m_initializer: ArrayType, Initializer, callable
+    The initializer of m channel.
+  h_initializer: ArrayType, Initializer, callable
+    The initializer of h channel.
+  n_initializer: ArrayType, Initializer, callable
+    The initializer of n channel.
+  method: str
+    The numerical integration method.
+  name: str
+    The group name.
+
+  References
+  ----------
+
+  .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description
+         of membrane current and its application to conduction and excitation
+         in nerve." The Journal of physiology 117.4 (1952): 500.
+  .. [2] https://en.wikipedia.org/wiki/Hodgkin%E2%80%93Huxley_model
+  .. [3] Ashwin, Peter, Stephen Coombes, and Rachel Nicks. "Mathematical
+         frameworks for oscillatory network dynamics in neuroscience."
+         The Journal of Mathematical Neuroscience 6, no. 1 (2016): 1-92.
+ """ + def __init__( + self, + size: Union[int, Sequence[int]], + sharding: Any = None, + keep_size: bool = False, + mode: bm.Mode = None, + name: str = None, + method: str = 'exp_auto', + init_var: bool = True, + + # neuron parameters + ENa: Union[float, ArrayType, Callable] = 50., + gNa: Union[float, ArrayType, Callable] = 120., + EK: Union[float, ArrayType, Callable] = -77., + gK: Union[float, ArrayType, Callable] = 36., + EL: Union[float, ArrayType, Callable] = -54.387, + gL: Union[float, ArrayType, Callable] = 0.03, + V_th: Union[float, ArrayType, Callable] = 20., + C: Union[float, ArrayType, Callable] = 1.0, + V_initializer: Union[Callable, ArrayType] = Uniform(-70, -60.), + m_initializer: Optional[Union[Callable, ArrayType]] = None, + h_initializer: Optional[Union[Callable, ArrayType]] = None, + n_initializer: Optional[Union[Callable, ArrayType]] = None, + ): + # initialization + super().__init__(size=size, + sharding=sharding, + keep_size=keep_size, + mode=mode, + name=name, + method=method) + + # parameters + self.ENa = self.init_param(ENa) + self.EK = self.init_param(EK) + self.EL = self.init_param(EL) + self.gNa = self.init_param(gNa) + self.gK = self.init_param(gK) + self.gL = self.init_param(gL) + self.C = self.init_param(C) + self.V_th = self.init_param(V_th) + + # initializers + self._m_initializer = is_initializer(m_initializer, allow_none=True) + self._h_initializer = is_initializer(h_initializer, allow_none=True) + self._n_initializer = is_initializer(n_initializer, allow_none=True) + self._V_initializer = is_initializer(V_initializer) + + # integral + self.integral = odeint(method=method, f=self.derivative) + + # model + if init_var: + self.reset_state(self.mode) + + # m channel + m_alpha = lambda self, V: 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10)) + m_beta = lambda self, V: 4.0 * bm.exp(-(V + 65) / 18) + m_inf = lambda self, V: self.m_alpha(V) / (self.m_alpha(V) + self.m_beta(V)) + dm = lambda self, m, t, V: self.m_alpha(V) * (1 - m) - self.m_beta(V) * m + + # h channel + h_alpha = lambda self, V: 0.07 * bm.exp(-(V + 65) / 20.) 
+ h_beta = lambda self, V: 1 / (1 + bm.exp(-(V + 35) / 10)) + h_inf = lambda self, V: self.h_alpha(V) / (self.h_alpha(V) + self.h_beta(V)) + dh = lambda self, h, t, V: self.h_alpha(V) * (1 - h) - self.h_beta(V) * h + + # n channel + n_alpha = lambda self, V: 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10)) + n_beta = lambda self, V: 0.125 * bm.exp(-(V + 65) / 80) + n_inf = lambda self, V: self.n_alpha(V) / (self.n_alpha(V) + self.n_beta(V)) + dn = lambda self, n, t, V: self.n_alpha(V) * (1 - n) - self.n_beta(V) * n + + def reset_state(self, batch_size=None): + self.V = self.init_variable(self._V_initializer, batch_size) + if self._m_initializer is None: + self.m = bm.Variable(self.m_inf(self.V.value), batch_axis=self.V.batch_axis) + else: + self.m = self.init_variable(self._m_initializer, batch_size) + if self._h_initializer is None: + self.h = bm.Variable(self.h_inf(self.V.value), batch_axis=self.V.batch_axis) + else: + self.h = self.init_variable(self._h_initializer, batch_size) + if self._n_initializer is None: + self.n = bm.Variable(self.n_inf(self.V.value), batch_axis=self.V.batch_axis) + else: + self.n = self.init_variable(self._n_initializer, batch_size) + self.spike = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) + + def dV(self, V, t, m, h, n, I): + for out in self.cur_inputs.values(): + I += out(V) + I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa) + I_K = (self.gK * n ** 4.0) * (V - self.EK) + I_leak = self.gL * (V - self.EL) + dVdt = (- I_Na - I_K - I_leak + I) / self.C + return dVdt + + @property + def derivative(self): + return JointEq(self.dV, self.dm, self.dh, self.dn) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + + V, m, h, n = self.integral(self.V.value, self.m.value, self.h.value, self.n.value, t, x, dt) + self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) + self.V.value = V + self.m.value = m + self.h.value = h + self.n.value = n + return self.spike.value + + def return_for_delay(self): + return self.spike + + +class HH(HHLTC): + def dV(self, V, t, m, h, n, I): + I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa) + I_K = (self.gK * n ** 4.0) * (V - self.EK) + I_leak = self.gL * (V - self.EL) + dVdt = (- I_Na - I_K - I_leak + I) / self.C + return dVdt + + @property + def derivative(self): + return JointEq(self.dV, self.dm, self.dh, self.dn) + + def update(self, x=None): + x = 0. if x is None else x + for out in self.cur_inputs.values(): + x += out(self.V.value) + super().update(x) + + +class MorrisLecarLTC(HHTypeNeuLTC): + r"""The Morris-Lecar neuron model. + + **Model Descriptions** + + The Morris-Lecar model [4]_ (Also known as :math:`I_{Ca}+I_K`-model) + is a two-dimensional "reduced" excitation model applicable to + systems having two non-inactivating voltage-sensitive conductances. + This model was named after Cathy Morris and Harold Lecar, who + derived it in 1981. Because it is two-dimensional, the Morris-Lecar + model is one of the favorite conductance-based models in computational neuroscience. + + The original form of the model employed an instantaneously + responding voltage-sensitive Ca2+ conductance for excitation and a delayed + voltage-dependent K+ conductance for recovery. The equations of the model are: + + .. 
math:: + + \begin{aligned} + C\frac{dV}{dt} =& - g_{Ca} M_{\infty} (V - V_{Ca})- g_{K} W(V - V_{K}) - + g_{Leak} (V - V_{Leak}) + I_{ext} \\ + \frac{dW}{dt} =& \frac{W_{\infty}(V) - W}{ \tau_W(V)} + \end{aligned} + + Here, :math:`V` is the membrane potential, :math:`W` is the "recovery variable", + which is almost invariably the normalized :math:`K^+`-ion conductance, and + :math:`I_{ext}` is the applied current stimulus. + + **Model Examples** + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> + >>> group = bp.neurons.MorrisLecar(1) + >>> runner = bp.DSRunner(group, monitors=['V', 'W'], inputs=('input', 100.)) + >>> runner.run(1000) + >>> + >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8) + >>> fig.add_subplot(gs[0, 0]) + >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.W, ylabel='W') + >>> fig.add_subplot(gs[1, 0]) + >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', show=True) + + + **Model Parameters** + + ============= ============== ======== ======================================================= + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- ------------------------------------------------------- + V_Ca 130 mV Equilibrium potentials of Ca+.(mV) + g_Ca 4.4 \ Maximum conductance of corresponding Ca+.(mS/cm2) + V_K -84 mV Equilibrium potentials of K+.(mV) + g_K 8 \ Maximum conductance of corresponding K+.(mS/cm2) + V_Leak -60 mV Equilibrium potentials of leak current.(mV) + g_Leak 2 \ Maximum conductance of leak current.(mS/cm2) + C 20 \ Membrane capacitance.(uF/cm2) + V1 -1.2 \ Potential at which M_inf = 0.5.(mV) + V2 18 \ Reciprocal of slope of voltage dependence of M_inf.(mV) + V3 2 \ Potential at which W_inf = 0.5.(mV) + V4 30 \ Reciprocal of slope of voltage dependence of W_inf.(mV) + phi 0.04 \ A temperature factor. (1/s) + V_th 10 mV The spike threshold. + ============= ============== ======== ======================================================= + + References + ---------- + + .. [4] Lecar, Harold. "Morris-lecar model." Scholarpedia 2.10 (2007): 1333. + .. [5] http://www.scholarpedia.org/article/Morris-Lecar_model + .. 
[6] https://en.wikipedia.org/wiki/Morris%E2%80%93Lecar_model + """ + def __init__( + self, + size: Union[int, Sequence[int]], + sharding: Any = None, + keep_size: bool = False, + mode: bm.Mode = None, + name: str = None, + method: str = 'exp_auto', + init_var: bool = True, + + # neuron parameters + V_Ca: Union[float, ArrayType, Callable] = 130., + g_Ca: Union[float, ArrayType, Callable] = 4.4, + V_K: Union[float, ArrayType, Callable] = -84., + g_K: Union[float, ArrayType, Callable] = 8., + V_leak: Union[float, ArrayType, Callable] = -60., + g_leak: Union[float, ArrayType, Callable] = 2., + C: Union[float, ArrayType, Callable] = 20., + V1: Union[float, ArrayType, Callable] = -1.2, + V2: Union[float, ArrayType, Callable] = 18., + V3: Union[float, ArrayType, Callable] = 2., + V4: Union[float, ArrayType, Callable] = 30., + phi: Union[float, ArrayType, Callable] = 0.04, + V_th: Union[float, ArrayType, Callable] = 10., + W_initializer: Union[Callable, ArrayType] = OneInit(0.02), + V_initializer: Union[Callable, ArrayType] = Uniform(-70., -60.), + ): + # initialization + super().__init__(size=size, + sharding=sharding, + keep_size=keep_size, + mode=mode, + name=name, + method=method) + + # parameters + self.V_Ca = self.init_param(V_Ca) + self.g_Ca = self.init_param(g_Ca) + self.V_K = self.init_param(V_K) + self.g_K = self.init_param(g_K) + self.V_leak = self.init_param(V_leak) + self.g_leak = self.init_param(g_leak) + self.C = self.init_param(C) + self.V1 = self.init_param(V1) + self.V2 = self.init_param(V2) + self.V3 = self.init_param(V3) + self.V4 = self.init_param(V4) + self.phi = self.init_param(phi) + self.V_th = self.init_param(V_th) + + # initializers + self._W_initializer = is_initializer(W_initializer) + self._V_initializer = is_initializer(V_initializer) + + # integral + self.integral = odeint(method=method, f=self.derivative) + + # model + if init_var: + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + self.V = self.init_variable(self._V_initializer, batch_size) + self.W = self.init_variable(self._W_initializer, batch_size) + self.spike = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) + + def dV(self, V, t, W, I): + for out in self.cur_inputs.values(): + I += out(V) + M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2)) + I_Ca = self.g_Ca * M_inf * (V - self.V_Ca) + I_K = self.g_K * W * (V - self.V_K) + I_Leak = self.g_leak * (V - self.V_leak) + dVdt = (- I_Ca - I_K - I_Leak + I) / self.C + return dVdt + + def dW(self, W, t, V): + tau_W = 1 / (self.phi * bm.cosh((V - self.V3) / (2 * self.V4))) + W_inf = (1 / 2) * (1 + bm.tanh((V - self.V3) / self.V4)) + dWdt = (W_inf - W) / tau_W + return dWdt + + @property + def derivative(self): + return JointEq(self.dV, self.dW) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. 
if x is None else x + + V, W = self.integral(self.V, self.W, t, x, dt) + + spike = bm.logical_and(self.V < self.V_th, V >= self.V_th) + self.V.value = V + self.W.value = W + self.spike.value = spike + return spike + + def return_for_delay(self): + return self.spike + + +class MorrisLecar(MorrisLecarLTC): + def dV(self, V, t, W, I): + M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2)) + I_Ca = self.g_Ca * M_inf * (V - self.V_Ca) + I_K = self.g_K * W * (V - self.V_K) + I_Leak = self.g_leak * (V - self.V_leak) + dVdt = (- I_Ca - I_K - I_Leak + I) / self.C + return dVdt + + def dW(self, W, t, V): + tau_W = 1 / (self.phi * bm.cosh((V - self.V3) / (2 * self.V4))) + W_inf = (1 / 2) * (1 + bm.tanh((V - self.V3) / self.V4)) + dWdt = (W_inf - W) / tau_W + return dWdt + + @property + def derivative(self): + return JointEq(self.dV, self.dW) + + def update(self, x=None): + x = 0. if x is None else x + for out in self.cur_inputs.values(): + x += out(self.V.value) + super().update(x) + + +class WangBuzsakiModelLTC(HHTypeNeuLTC): + r"""Wang-Buzsaki model [9]_, an implementation of a modified Hodgkin-Huxley model. + + Each model is described by a single compartment and obeys the current balance equation: + + .. math:: + + C_{m} \frac{d V}{d t}=-I_{\mathrm{Na}}-I_{\mathrm{K}}-I_{\mathrm{L}}-I_{\mathrm{syn}}+I_{\mathrm{app}} + + where :math:`C_{m}=1 \mu \mathrm{F} / \mathrm{cm}^{2}` and :math:`I_{\mathrm{app}}` is the + injected current (in :math:`\mu \mathrm{A} / \mathrm{cm}^{2}` ). The leak current + :math:`I_{\mathrm{L}}=g_{\mathrm{L}}\left(V-E_{\mathrm{L}}\right)` has a conductance + :math:`g_{\mathrm{L}}=0.1 \mathrm{mS} / \mathrm{cm}^{2}`, so that the passive time constant + :math:`\tau_{0}=C_{m} / g_{\mathrm{L}}=10 \mathrm{msec} ; E_{\mathrm{L}}=-65 \mathrm{mV}`. + + The spike-generating :math:`\mathrm{Na}^{+}` and :math:`\mathrm{K}^{+}` voltage-dependent ion + currents :math:`\left(I_{\mathrm{Na}}\right.` and :math:`I_{\mathrm{K}}` ) are of the + Hodgkin-Huxley type (Hodgkin and Huxley, 1952). The transient sodium current + :math:`I_{\mathrm{Na}}=g_{\mathrm{Na}} m_{\infty}^{3} h\left(V-E_{\mathrm{Na}}\right)`, + where the activation variable :math:`m` is assumed fast and substituted by its steady-state + function :math:`m_{\infty}=\alpha_{m} /\left(\alpha_{m}+\beta_{m}\right)` ; + :math:`\alpha_{m}(V)=-0.1(V+35) /(\exp (-0.1(V+35))-1), \beta_{m}(V)=4 \exp (-(V+60) / 18)`. + The inactivation variable :math:`h` obeys a first-order kinetics: + + .. math:: + + \frac{d h}{d t}=\phi\left(\alpha_{h}(1-h)-\beta_{h} h\right) + + where :math:`\alpha_{h}(V)=0.07 \exp (-(V+58) / 20)` and + :math:`\beta_{h}(V)=1 /(\exp (-0.1(V+28)) +1) \cdot g_{\mathrm{Na}}=35 \mathrm{mS} / \mathrm{cm}^{2}` ; + :math:`E_{\mathrm{Na}}=55 \mathrm{mV}, \phi=5 .` + + The delayed rectifier :math:`I_{\mathrm{K}}=g_{\mathrm{K}} n^{4}\left(V-E_{\mathrm{K}}\right)`, + where the activation variable :math:`n` obeys the following equation: + + .. math:: + + \frac{d n}{d t}=\phi\left(\alpha_{n}(1-n)-\beta_{n} n\right) + + with :math:`\alpha_{n}(V)=-0.01(V+34) /(\exp (-0.1(V+34))-1)` and + :math:`\beta_{n}(V)=0.125\exp (-(V+44) / 80)` ; :math:`g_{\mathrm{K}}=9 \mathrm{mS} / \mathrm{cm}^{2}`, and + :math:`E_{\mathrm{K}}=-90 \mathrm{mV}`. + + + Parameters + ---------- + size: sequence of int, int + The size of the neuron group. + ENa: float, ArrayType, Initializer, callable + The reversal potential of sodium. Default is 50 mV. + gNa: float, ArrayType, Initializer, callable + The maximum conductance of sodium channel. 
Default is 35 msiemens.
+  EK: float, ArrayType, Initializer, callable
+    The reversal potential of potassium. Default is -90 mV.
+  gK: float, ArrayType, Initializer, callable
+    The maximum conductance of potassium channel. Default is 9 msiemens.
+  EL: float, ArrayType, Initializer, callable
+    The reversal potential of the leaky channel. Default is -65 mV.
+  gL: float, ArrayType, Initializer, callable
+    The conductance of the leaky channel. Default is 0.1 msiemens.
+  V_th: float, ArrayType, Initializer, callable
+    The threshold of the membrane spike. Default is 20 mV.
+  C: float, ArrayType, Initializer, callable
+    The membrane capacitance. Default is 1 ufarad.
+  phi: float, ArrayType, Initializer, callable
+    The temperature regulator constant.
+  V_initializer: ArrayType, Initializer, callable
+    The initializer of membrane potential.
+  h_initializer: ArrayType, Initializer, callable
+    The initializer of h channel.
+  n_initializer: ArrayType, Initializer, callable
+    The initializer of n channel.
+  method: str
+    The numerical integration method.
+  name: str
+    The group name.
+
+  References
+  ----------
+  .. [9] Wang, X.J. and Buzsaki, G., (1996) Gamma oscillation by synaptic
+         inhibition in a hippocampal interneuronal network model. Journal of
+         neuroscience, 16(20), pp.6402-6413.
+
+  """
+  def __init__(
+      self,
+      size: Union[int, Sequence[int]],
+      sharding: Any = None,
+      keep_size: bool = False,
+      mode: bm.Mode = None,
+      name: str = None,
+      method: str = 'exp_auto',
+      init_var: bool = True,
+
+      # neuron parameters
+      ENa: Union[float, ArrayType, Callable] = 55.,
+      gNa: Union[float, ArrayType, Callable] = 35.,
+      EK: Union[float, ArrayType, Callable] = -90.,
+      gK: Union[float, ArrayType, Callable] = 9.,
+      EL: Union[float, ArrayType, Callable] = -65,
+      gL: Union[float, ArrayType, Callable] = 0.1,
+      V_th: Union[float, ArrayType, Callable] = 20.,
+      phi: Union[float, ArrayType, Callable] = 5.0,
+      C: Union[float, ArrayType, Callable] = 1.0,
+      V_initializer: Union[Callable, ArrayType] = OneInit(-65.),
+      h_initializer: Union[Callable, ArrayType] = OneInit(0.6),
+      n_initializer: Union[Callable, ArrayType] = OneInit(0.32),
+  ):
+    # initialization
+    super().__init__(size=size,
+                     sharding=sharding,
+                     keep_size=keep_size,
+                     mode=mode,
+                     name=name,
+                     method=method)
+
+    # parameters
+    self.ENa = self.init_param(ENa)
+    self.EK = self.init_param(EK)
+    self.EL = self.init_param(EL)
+    self.gNa = self.init_param(gNa)
+    self.gK = self.init_param(gK)
+    self.gL = self.init_param(gL)
+    self.phi = self.init_param(phi)
+    self.C = self.init_param(C)
+    self.V_th = self.init_param(V_th)
+
+    # initializers
+    self._h_initializer = is_initializer(h_initializer)
+    self._n_initializer = is_initializer(n_initializer)
+    self._V_initializer = is_initializer(V_initializer)
+
+    # integral
+    self.integral = odeint(method=method, f=self.derivative)
+
+    # model
+    if init_var:
+      self.reset_state(self.mode)
+
+  def reset_state(self, batch_size=None):
+    self.V = self.init_variable(self._V_initializer, batch_size)
+    self.h = self.init_variable(self._h_initializer, batch_size)
+    self.n = self.init_variable(self._n_initializer, batch_size)
+    self.spike = self.init_variable(partial(bm.zeros, dtype=bool), batch_size)
+
+  def m_inf(self, V):
+    alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1)
+    beta = 4. * bm.exp(-(V + 60.) / 18.)
+ return alpha / (alpha + beta) + + def dh(self, h, t, V): + alpha = 0.07 * bm.exp(-(V + 58) / 20) + beta = 1 / (bm.exp(-0.1 * (V + 28)) + 1) + dhdt = alpha * (1 - h) - beta * h + return self.phi * dhdt + + def dn(self, n, t, V): + alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1) + beta = 0.125 * bm.exp(-(V + 44) / 80) + dndt = alpha * (1 - n) - beta * n + return self.phi * dndt + + def dV(self, V, t, h, n, I): + for out in self.cur_inputs.values(): + I += out(V) + INa = self.gNa * self.m_inf(V) ** 3 * h * (V - self.ENa) + IK = self.gK * n ** 4 * (V - self.EK) + IL = self.gL * (V - self.EL) + dVdt = (- INa - IK - IL + I) / self.C + return dVdt + + @property + def derivative(self): + return JointEq(self.dV, self.dh, self.dn) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + + V, h, n = self.integral(self.V, self.h, self.n, t, x, dt) + self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) + self.V.value = V + self.h.value = h + self.n.value = n + return self.spike.value + + def return_for_delay(self): + return self.spike + +class WangBuzsakiModel(WangBuzsakiModelLTC): + def m_inf(self, V): + alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1) + beta = 4. * bm.exp(-(V + 60.) / 18.) + return alpha / (alpha + beta) + + def dh(self, h, t, V): + alpha = 0.07 * bm.exp(-(V + 58) / 20) + beta = 1 / (bm.exp(-0.1 * (V + 28)) + 1) + dhdt = alpha * (1 - h) - beta * h + return self.phi * dhdt + + def dn(self, n, t, V): + alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1) + beta = 0.125 * bm.exp(-(V + 44) / 80) + dndt = alpha * (1 - n) - beta * n + return self.phi * dndt + + def dV(self, V, t, h, n, I): + INa = self.gNa * self.m_inf(V) ** 3 * h * (V - self.ENa) + IK = self.gK * n ** 4 * (V - self.EK) + IL = self.gL * (V - self.EL) + dVdt = (- INa - IK - IL + I) / self.C + return dVdt + + @property + def derivative(self): + return JointEq(self.dV, self.dh, self.dn) + + def update(self, x=None): + x = 0. if x is None else x + for out in self.cur_inputs.values(): + x += out(self.V.value) + super().update(x) \ No newline at end of file diff --git a/brainpy/_src/dyn/neurons/input.py b/brainpy/_src/dyn/neurons/input.py new file mode 100644 index 000000000..68bba78a9 --- /dev/null +++ b/brainpy/_src/dyn/neurons/input.py @@ -0,0 +1,210 @@ +# -*- coding: utf-8 -*- + +from typing import Union, Sequence, Any + +import jax.numpy as jnp +from brainpy._src.context import share +import brainpy.math as bm +from brainpy._src.initialize import Initializer, parameter, variable_ +from brainpy.types import Shape, ArrayType +from brainpy._src.dyn.base import NeuDyn + +__all__ = [ + 'InputGroup', + 'OutputGroup', + 'SpikeTimeGroup', + 'PoissonGroup', +] + + +class InputGroup(NeuDyn): + """Input neuron group for place holder. + + Parameters + ---------- + size: int, tuple of int + keep_size: bool + mode: Mode + name: str + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + sharding: Any = None, + keep_size: bool = False, + mode: bm.Mode = None, + name: str = None, + ): + super(InputGroup, self).__init__(name=name, + sharding=sharding, + size=size, + keep_size=keep_size, + mode=mode) + self.spike = None + + def update(self, x): + return x + + def reset_state(self, batch_size=None): + pass + + +class OutputGroup(NeuDyn): + """Output neuron group for place holder. 
+ + Parameters + ---------- + size: int, tuple of int + keep_size: bool + mode: Mode + name: str + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + sharding: Any = None, + keep_size: bool = False, + mode: bm.Mode = None, + name: str = None, + ): + super(OutputGroup, self).__init__(name=name, + sharding=sharding, + size=size, + keep_size=keep_size, + mode=mode) + self.spike = None + + def update(self, x): + return x + + def reset_state(self, batch_size=None): + pass + + +class SpikeTimeGroup(NeuDyn): + """The input neuron group characterized by spikes emitting at given times. + + >>> # Get 2 neurons, firing spikes at 10 ms and 20 ms. + >>> SpikeTimeGroup(2, times=[10, 20]) + >>> # or + >>> # Get 2 neurons, the neuron 0 fires spikes at 10 ms and 20 ms. + >>> SpikeTimeGroup(2, times=[10, 20], indices=[0, 0]) + >>> # or + >>> # Get 2 neurons, neuron 0 fires at 10 ms and 30 ms, neuron 1 fires at 20 ms. + >>> SpikeTimeGroup(2, times=[10, 20, 30], indices=[0, 1, 0]) + >>> # or + >>> # Get 2 neurons; at 10 ms, neuron 0 fires; at 20 ms, neuron 0 and 1 fire; + >>> # at 30 ms, neuron 1 fires. + >>> SpikeTimeGroup(2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1]) + + Parameters + ---------- + size : int, tuple, list + The neuron group geometry. + indices : list, tuple, ArrayType + The neuron indices at each time point to emit spikes. + times : list, tuple, ArrayType + The time points which generate the spikes. + name : str, optional + The name of the dynamic system. + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + indices: Union[Sequence, ArrayType], + times: Union[Sequence, ArrayType], + name: str = None, + sharding: Any = None, + keep_size: bool = False, + mode: bm.Mode = None, + need_sort: bool = True, + ): + super(SpikeTimeGroup, self).__init__(size=size, + sharding=sharding, + name=name, + keep_size=keep_size, + mode=mode) + + # parameters + if keep_size: + raise NotImplementedError(f'Do not support keep_size=True in {self.__class__.__name__}') + if len(indices) != len(times): + raise ValueError(f'The length of "indices" and "times" must be the same. ' + f'However, we got {len(indices)} != {len(times)}.') + self.num_times = len(times) + + # data about times and indices + self.times = bm.asarray(times) + self.indices = bm.asarray(indices, dtype=bm.int_) + if need_sort: + sort_idx = bm.argsort(self.times) + self.indices.value = self.indices[sort_idx] + self.times.value = self.times[sort_idx] + + # variables + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + self.i = bm.Variable(bm.asarray(0)) + self.spike = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, batch_size) + + def update(self): + self.spike.value = bm.zeros_like(self.spike) + bm.while_loop(self._body_fun, self._cond_fun, share.load('t')) + return self.spike.value + + # functions + def _cond_fun(self, t): + i = self.i.value + return bm.logical_and(i < self.num_times, t >= self.times[i]) + + def _body_fun(self, t): + i = self.i.value + if isinstance(self.mode, bm.BatchingMode): + self.spike[:, self.indices[i]] = True + else: + self.spike[self.indices[i]] = True + self.i += 1 + + +class PoissonGroup(NeuDyn): + """Poisson Neuron Group. 
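+
+  Each neuron in this group fires as an independent Poisson process with rate
+  ``freqs`` (in Hz): at every simulation step a spike is emitted with probability
+  ``freqs * dt / 1000``, as implemented in ``update()`` below.
+
+  Parameters
+  ----------
+  size: int, tuple of int
+    The neuron group geometry.
+  freqs: float, ArrayType, Initializer
+    The mean firing rate of each neuron (Hz).
+  seed: int, optional
+    The random seed.
+  name: str, optional
+    The name of the dynamic system.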
+ """ + + def __init__( + self, + size: Shape, + freqs: Union[int, float, jnp.ndarray, bm.Array, Initializer], + seed: int = None, + name: str = None, + sharding: Any = None, + keep_size: bool = False, + mode: bm.Mode = None, + ): + super(PoissonGroup, self).__init__(size=size, + sharding=sharding, + name=name, + keep_size=keep_size, + mode=mode) + + # parameters + self.keep_size = keep_size + self.seed = seed + self.freqs = parameter(freqs, self.num, allow_none=False) + + # variables + self.reset_state(self.mode) + + def update(self): + spikes = bm.random.rand_like(self.spike) <= (self.freqs * share.dt / 1000.) + self.spike.value = spikes + return spikes + + def reset_state(self, batch_size=None): + self.spike = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, batch_size) + + + diff --git a/brainpy/_src/dyn/neurons/lif.py b/brainpy/_src/dyn/neurons/lif.py index 3b975689c..ad999cd7f 100644 --- a/brainpy/_src/dyn/neurons/lif.py +++ b/brainpy/_src/dyn/neurons/lif.py @@ -5,23 +5,143 @@ import brainpy.math as bm from brainpy._src.context import share -from brainpy._src.initialize import ZeroInit -from brainpy._src.integrators import odeint +from brainpy._src.initialize import ZeroInit, OneInit +from brainpy._src.integrators import odeint, JointEq from brainpy.check import is_initializer from brainpy.types import Shape, ArrayType, Sharding -from brainpy._src.dyn._docs import ref_doc, lif_doc, pneu_doc, dpneu_doc, ltc_doc +from brainpy._src.dyn._docs import ref_doc, lif_doc, pneu_doc, dpneu_doc, ltc_doc, if_doc from brainpy._src.dyn.base import GradNeuDyn __all__ = [ + 'IF', + 'IFLTC', 'Lif', 'LifLTC', 'LifRef', 'LifRefLTC', + 'ExpIF', + 'ExpIFLTC', + 'ExpIFRef', + 'ExpIFRefLTC', + 'AdExIF', + 'AdExIFLTC', + 'AdExIFRef', + 'AdExIFRefLTC', + 'QuaIF', + 'QuaIFLTC', + 'QuaIFRef', + 'QuaIFRefLTC', + 'AdQuaIF', + 'AdQuaIFLTC', + 'AdQuaIFRef', + 'AdQuaIFRefLTC', + 'Gif', + 'GifLTC', + 'GifRef', + 'GifRefLTC', ] -class PIF(GradNeuDyn): - pass +class IFLTC(GradNeuDyn): + r"""Leaky Integrator Model %s. + + **Model Descriptions** + + This class implements a leaky integrator model, in which its dynamics is + given by: + + .. math:: + + \tau \frac{dV}{dt} = - (V(t) - V_{rest}) + RI(t) + + where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting + membrane potential, :math:`\tau` is the time constant, and :math:`R` is the + resistance. 
+ + Args: + %s + %s + %s + """ + def __init__( + self, + size: Shape, + sharding: Optional[Sequence[str]] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_type: Any = None, + detach_spk: bool = False, + method: str = 'exp_auto', + init_var: bool = True, + + # neuron parameters + V_rest: Union[float, ArrayType, Callable] = 0., + R: Union[float, ArrayType, Callable] = 1., + tau: Union[float, ArrayType, Callable] = 10., + V_initializer: Union[Callable, ArrayType] = ZeroInit(), + ): + # initialization + super().__init__(size=size, + name=name, + keep_size=keep_size, + mode=mode, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + method=method, + spk_type=spk_type) + + # parameters + self.V_rest = self.init_param(V_rest) + self.tau = self.init_param(tau) + self.R = self.init_param(R) + + # initializers + self._V_initializer = is_initializer(V_initializer) + + # integral + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def derivative(self, V, t, I): + for out in self.cur_inputs.values(): + I += out(V) + return (-V + self.V_rest + self.R * I) / self.tau + + def reset_state(self, batch_size=None): + self.V = self.init_variable(self._V_initializer, batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + + # integrate membrane potential + self.V.value = self.integral(self.V.value, t, x, dt) + return self.V.value + + def return_for_delay(self): + return self.V + + +class IF(IFLTC): + def derivative(self, V, t, I): + return (-V + self.V_rest + self.R * I) / self.tau + + def update(self, x=None): + x = 0. if x is None else x + for out in self.cur_inputs.values(): + x += out(self.V.value) + super().update(x) + + +IF.__doc__ = IFLTC.__doc__ % ('', if_doc, pneu_doc, dpneu_doc) +IFLTC.__doc__ = IFLTC.__doc__ % (ltc_doc, if_doc, pneu_doc, dpneu_doc) class LifLTC(GradNeuDyn): @@ -198,7 +318,7 @@ def __init__( V_initializer: Union[Callable, ArrayType] = ZeroInit(), # new neuron parameter - tau_ref: Optional[Union[float, ArrayType, Callable]] = None, + tau_ref: Union[float, ArrayType, Callable] = 0., ref_var: bool = False, ): # initialization @@ -296,21 +416,1672 @@ def update(self, x=None): LifRefLTC.__doc__ = LifRefLTC.__doc__ % (ltc_doc, lif_doc, pneu_doc, dpneu_doc, ref_doc) -class PExpIF(GradNeuDyn): - pass +class ExpIFLTC(GradNeuDyn): + r"""Exponential integrate-and-fire neuron model %s. + + **Model Descriptions** + + In the exponential integrate-and-fire model [1]_, the differential + equation for the membrane potential is given by + + .. math:: + + \tau\frac{d V}{d t}= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} + RI(t), \\ + \text{after} \, V(t) \gt V_{th}, V(t) = V_{reset} \, \text{last} \, \tau_{ref} \, \text{ms} + + This equation has an exponential nonlinearity with "sharpness" parameter :math:`\Delta_{T}` + and "threshold" :math:`\vartheta_{rh}`. + + The moment when the membrane potential reaches the numerical threshold :math:`V_{th}` + defines the firing time :math:`t^{(f)}`. After firing, the membrane potential is reset to + :math:`V_{rest}` and integration restarts at time :math:`t^{(f)}+\tau_{\rm ref}`, + where :math:`\tau_{\rm ref}` is an absolute refractory time. + If the numerical threshold is chosen sufficiently high, :math:`V_{th}\gg v+\Delta_T`, + its exact value does not play any role. 
The reason is that the upswing of the action + potential for :math:`v\gg v +\Delta_{T}` is so rapid, that it goes to infinity in + an incredibly short time. The threshold :math:`V_{th}` is introduced mainly for numerical + convenience. For a formal mathematical analysis of the model, the threshold can be pushed + to infinity. + + The model was first introduced by Nicolas Fourcaud-Trocmé, David Hansel, Carl van Vreeswijk + and Nicolas Brunel [1]_. The exponential nonlinearity was later confirmed by Badel et al. [3]_. + It is one of the prominent examples of a precise theoretical prediction in computational + neuroscience that was later confirmed by experimental neuroscience. + + Two important remarks: + + - (i) The right-hand side of the above equation contains a nonlinearity + that can be directly extracted from experimental data [3]_. In this sense the exponential + nonlinearity is not an arbitrary choice but directly supported by experimental evidence. + - (ii) Even though it is a nonlinear model, it is simple enough to calculate the firing + rate for constant input, and the linear response to fluctuations, even in the presence + of input noise [4]_. + + **Model Examples** + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> group = bp.neurons.ExpIF(1) + >>> runner = bp.DSRunner(group, monitors=['V'], inputs=('input', 10.)) + >>> runner.run(300., ) + >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', show=True) + + + **Model Parameters** + + ============= ============== ======== =================================================== + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- --------------------------------------------------- + V_rest -65 mV Resting potential. + V_reset -68 mV Reset potential after spike. + V_th -30 mV Threshold potential of spike. + V_T -59.9 mV Threshold potential of generating action potential. + delta_T 3.48 \ Spike slope factor. + R 1 \ Membrane resistance. + tau 10 \ Membrane time constant. Compute by R * C. + tau_ref 1.7 \ Refractory period length. + ============= ============== ======== =================================================== + + **Model Variables** + + ================== ================= ========================================================= + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- --------------------------------------------------------- + V 0 Membrane potential. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + refractory False Flag to mark whether the neuron is in refractory period. + t_last_spike -1e7 Last spike time stamp. + ================== ================= ========================================================= + + **References** + + .. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation + mechanisms determine the neuronal response to fluctuating + inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. + .. [2] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014). + Neuronal dynamics: From single neurons to networks and models + of cognition. Cambridge University Press. + .. [3] Badel, Laurent, Sandrine Lefort, Romain Brette, Carl CH Petersen, + Wulfram Gerstner, and Magnus JE Richardson. "Dynamic IV curves + are reliable predictors of naturalistic pyramidal-neuron voltage + traces." Journal of Neurophysiology 99, no. 2 (2008): 656-666. + .. [4] Richardson, Magnus JE. 
"Firing-rate response of linear and nonlinear + integrate-and-fire neurons to modulated current-based and + conductance-based synaptic drive." Physical Review E 76, no. 2 (2007): 021919. + .. [5] https://en.wikipedia.org/wiki/Exponential_integrate-and-fire + """ + def __init__( + self, + size: Shape, + sharding: Optional[Sequence[str]] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_type: Any = None, + detach_spk: bool = False, + method: str = 'exp_auto', + init_var: bool = True, + + # neuron parameters + V_rest: Union[float, ArrayType, Callable] = -65., + V_reset: Union[float, ArrayType, Callable] = -68., + V_th: Union[float, ArrayType, Callable] = -30., + V_T: Union[float, ArrayType, Callable] = -59.9, + delta_T: Union[float, ArrayType, Callable] = 3.48, + R: Union[float, ArrayType, Callable] = 1., + tau: Union[float, ArrayType, Callable] = 10., + V_initializer: Union[Callable, ArrayType] = ZeroInit(), + ): + # initialization + super().__init__(size=size, + name=name, + keep_size=keep_size, + mode=mode, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + method=method, + spk_type=spk_type) + # parameters + self.V_rest = self.init_param(V_rest) + self.V_reset = self.init_param(V_reset) + self.V_th = self.init_param(V_th) + self.V_T = self.init_param(V_T) + self.delta_T = self.init_param(delta_T) + self.tau = self.init_param(tau) + self.R = self.init_param(R) + + # initializers + self._V_initializer = is_initializer(V_initializer) + + # integral + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def derivative(self, V, t, I): + for out in self.cur_inputs.values(): + I += out(V) + exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) + dvdt = (- (V - self.V_rest) + exp_v + self.R * I) / self.tau + return dvdt + + def reset_state(self, batch_size=None): + self.V = self.init_variable(self._V_initializer, batch_size) + self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + + # integrate membrane potential + V = self.integral(self.V.value, t, x, dt) + + # spike, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike = stop_gradient(spike) if self.detach_spk else spike + V += (self.V_reset - V) * spike + + else: + spike = V >= self.V_th + V = bm.where(spike, self.V_reset, V) + + self.V.value = V + self.spike.value = spike + return spike + + def return_for_delay(self): + return self.spike + + +class ExpIF(ExpIFLTC): + def derivative(self, V, t, I): + exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) + dvdt = (- (V - self.V_rest) + exp_v + self.R * I) / self.tau + return dvdt + + def update(self, x=None): + x = 0. 
if x is None else x + for out in self.cur_inputs.values(): + x += out(self.V.value) + super().update(x) + +ExpIF.__doc__ = ExpIFLTC.__doc__ % ('') +ExpIFLTC.__doc__ = ExpIFLTC.__doc__ % (ltc_doc) + +class ExpIFRefLTC(ExpIFLTC): + def __init__( + self, + size: Shape, + sharding: Optional[Sharding] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_type: Any = None, + detach_spk: bool = False, + method: str = 'exp_auto', + name: Optional[str] = None, + init_var: bool = True, + + # old neuron parameter + V_rest: Union[float, ArrayType, Callable] = -65., + V_reset: Union[float, ArrayType, Callable] = -68., + V_th: Union[float, ArrayType, Callable] = -30., + V_T: Union[float, ArrayType, Callable] = -59.9, + delta_T: Union[float, ArrayType, Callable] = 3.48, + R: Union[float, ArrayType, Callable] = 1., + tau: Union[float, ArrayType, Callable] = 10., + V_initializer: Union[Callable, ArrayType] = ZeroInit(), + + # new neuron parameter + tau_ref: Union[float, ArrayType, Callable] = 0., + ref_var: bool = False, + ): + # initialization + super().__init__( + size=size, + name=name, + keep_size=keep_size, + mode=mode, + method=method, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + spk_type=spk_type, + + init_var=False, + V_rest=V_rest, + V_reset=V_reset, + V_th=V_th, + V_T=V_T, + delta_T=delta_T, + R=R, + tau=tau, + V_initializer=V_initializer, + ) + + # parameters + self.ref_var = ref_var + self.tau_ref = self.init_param(tau_ref) + + # initializers + self._V_initializer = is_initializer(V_initializer) + + # integral + self.integral = odeint(method=method, f=self.derivative) -class PAdExIF(GradNeuDyn): - pass + # variables + if init_var: + self.reset_state(self.mode) + def reset_state(self, batch_size=None): + super().reset_state(batch_size) + self.t_last_spike = self.init_variable(bm.ones, batch_size) + self.t_last_spike.fill_(-1e7) + if self.ref_var: + self.refractory = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) -class PQuaIF(GradNeuDyn): - pass + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + # integrate membrane potential + V = self.integral(self.V.value, t, x, dt) -class PAdQuaIF(GradNeuDyn): - pass + # refractory + refractory = (t - self.t_last_spike) <= self.tau_ref + if isinstance(self.mode, bm.TrainingMode): + refractory = stop_gradient(refractory) + V = bm.where(refractory, self.V.value, V) + # spike, refractory, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike_no_grad = stop_gradient(spike) if self.detach_spk else spike + V += (self.V_reset - V) * spike_no_grad + spike_ = spike_no_grad > 0. 
+ # will be used in other place, like Delta Synapse, so stop its gradient + if self.ref_var: + self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value) + t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) -class PGIF(GradNeuDyn): - pass + else: + spike = V >= self.V_th + V = bm.where(spike, self.V_reset, V) + if self.ref_var: + self.refractory.value = bm.logical_or(refractory, spike) + t_last_spike = bm.where(spike, t, self.t_last_spike.value) + self.V.value = V + self.spike.value = spike + self.t_last_spike.value = t_last_spike + return spike + + +class ExpIFRef(ExpIFRefLTC): + def derivative(self, V, t, I): + exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) + dvdt = (- (V - self.V_rest) + exp_v + self.R * I) / self.tau + return dvdt + + def update(self, x=None): + x = 0. if x is None else x + for out in self.cur_inputs.values(): + x += out(self.V.value) + super().update(x) + + +class AdExIFLTC(GradNeuDyn): + def __init__( + self, + size: Shape, + sharding: Optional[Sequence[str]] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_type: Any = None, + detach_spk: bool = False, + method: str = 'exp_auto', + init_var: bool = True, + + # neuron parameters + V_rest: Union[float, ArrayType, Callable] = -65., + V_reset: Union[float, ArrayType, Callable] = -68., + V_th: Union[float, ArrayType, Callable] = -30., + V_T: Union[float, ArrayType, Callable] = -59.9, + delta_T: Union[float, ArrayType, Callable] = 3.48, + a: Union[float, ArrayType, Callable] = 1., + b: Union[float, ArrayType, Callable] = 1., + tau: Union[float, ArrayType, Callable] = 10., + tau_w: Union[float, ArrayType, Callable] = 30., + R: Union[float, ArrayType, Callable] = 1., + V_initializer: Union[Callable, ArrayType] = ZeroInit(), + w_initializer: Union[Callable, ArrayType] = ZeroInit(), + ): + # initialization + super().__init__(size=size, + name=name, + keep_size=keep_size, + mode=mode, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + method=method, + spk_type=spk_type) + # parameters + self.V_rest = self.init_param(V_rest) + self.V_reset = self.init_param(V_reset) + self.V_th = self.init_param(V_th) + self.V_T = self.init_param(V_T) + self.a = self.init_param(a) + self.b = self.init_param(b) + self.R = self.init_param(R) + self.delta_T = self.init_param(delta_T) + self.tau = self.init_param(tau) + self.tau_w = self.init_param(tau_w) + + # initializers + self._V_initializer = is_initializer(V_initializer) + self._w_initializer = is_initializer(w_initializer) + + # integral + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def dV(self, V, t, w, I): + for out in self.cur_inputs.values(): + I += out(V) + exp = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) + dVdt = (- V + self.V_rest + exp - self.R * w + self.R * I) / self.tau + return dVdt + + def dw(self, w, t, V): + dwdt = (self.a * (V - self.V_rest) - w) / self.tau_w + return dwdt + + @property + def derivative(self): + return JointEq([self.dV, self.dw]) + + def reset_state(self, batch_size=None): + self.V = self.init_variable(self._V_initializer, batch_size) + self.w = self.init_variable(self._w_initializer, batch_size) + self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. 
if x is None else x + + # integrate membrane potential + V, w = self.integral(self.V.value, self.w.value, t, x, dt) + + # spike, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike = stop_gradient(spike) if self.detach_spk else spike + V += (self.V_reset - V) * spike + w += self.b * spike + + else: + spike = V >= self.V_th + V = bm.where(spike, self.V_reset, V) + w = bm.where(spike, w + self.b, w) + + self.V.value = V + self.w.value = w + self.spike.value = spike + return spike + + def return_for_delay(self): + return self.spike + + +class AdExIF(AdExIFLTC): + def dV(self, V, t, w, I): + exp = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) + dVdt = (- V + self.V_rest + exp - self.R * w + self.R * I) / self.tau + return dVdt + + def dw(self, w, t, V): + dwdt = (self.a * (V - self.V_rest) - w) / self.tau_w + return dwdt + + @property + def derivative(self): + return JointEq([self.dV, self.dw]) + + def update(self, x=None): + x = 0. if x is None else x + for out in self.cur_inputs.values(): + x += out(self.V.value) + super().update(x) + + +class AdExIFRefLTC(AdExIFLTC): + def __init__( + self, + size: Shape, + sharding: Optional[Sharding] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_type: Any = None, + detach_spk: bool = False, + method: str = 'exp_auto', + name: Optional[str] = None, + init_var: bool = True, + + # old neuron parameter + V_rest: Union[float, ArrayType, Callable] = -65., + V_reset: Union[float, ArrayType, Callable] = -68., + V_th: Union[float, ArrayType, Callable] = -30., + V_T: Union[float, ArrayType, Callable] = -59.9, + delta_T: Union[float, ArrayType, Callable] = 3.48, + a: Union[float, ArrayType, Callable] = 1., + b: Union[float, ArrayType, Callable] = 1., + tau: Union[float, ArrayType, Callable] = 10., + tau_w: Union[float, ArrayType, Callable] = 30., + R: Union[float, ArrayType, Callable] = 1., + V_initializer: Union[Callable, ArrayType] = ZeroInit(), + w_initializer: Union[Callable, ArrayType] = ZeroInit(), + + # new neuron parameter + tau_ref: Union[float, ArrayType, Callable] = 0., + ref_var: bool = False, + ): + # initialization + super().__init__( + size=size, + name=name, + keep_size=keep_size, + mode=mode, + method=method, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + spk_type=spk_type, + + init_var=False, + + V_rest=V_rest, + V_reset=V_reset, + V_th=V_th, + V_T=V_T, + delta_T=delta_T, + a=a, + b=b, + R=R, + tau=tau, + tau_w=tau_w, + V_initializer=V_initializer, + w_initializer=w_initializer + ) + + # parameters + self.ref_var = ref_var + self.tau_ref = self.init_param(tau_ref) + + # initializers + self._V_initializer = is_initializer(V_initializer) + self._w_initializer = is_initializer(w_initializer) + + # integral + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + super().reset_state(batch_size) + self.t_last_spike = self.init_variable(bm.ones, batch_size) + self.t_last_spike.fill_(-1e8) + if self.ref_var: + self.refractory = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. 
if x is None else x + + # integrate membrane potential + V, w = self.integral(self.V.value, self.w.value, t, x, dt) + + # refractory + refractory = (t - self.t_last_spike) <= self.tau_ref + if isinstance(self.mode, bm.TrainingMode): + refractory = stop_gradient(refractory) + V = bm.where(refractory, self.V.value, V) + + # spike, refractory, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike_no_grad = stop_gradient(spike) if self.detach_spk else spike + V += (self.V_reset - V) * spike_no_grad + w += self.b * spike_no_grad + spike_ = spike_no_grad > 0. + # will be used in other place, like Delta Synapse, so stop its gradient + if self.ref_var: + self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value) + t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) + + else: + spike = V >= self.V_th + V = bm.where(spike, self.V_reset, V) + w = bm.where(spike, w + self.b, w) + if self.ref_var: + self.refractory.value = bm.logical_or(refractory, spike) + t_last_spike = bm.where(spike, t, self.t_last_spike.value) + self.V.value = V + self.w.value = w + self.spike.value = spike + self.t_last_spike.value = t_last_spike + return spike + + +class AdExIFRef(AdExIFRefLTC): + def dV(self, V, t, w, I): + exp = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) + dVdt = (- V + self.V_rest + exp - self.R * w + self.R * I) / self.tau + return dVdt + + def dw(self, w, t, V): + dwdt = (self.a * (V - self.V_rest) - w) / self.tau_w + return dwdt + + @property + def derivative(self): + return JointEq([self.dV, self.dw]) + + def update(self, x=None): + x = 0. if x is None else x + for out in self.cur_inputs.values(): + x += out(self.V.value) + super().update(x) + + +class QuaIFLTC(GradNeuDyn): + def __init__( + self, + size: Shape, + sharding: Optional[Sequence[str]] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_type: Any = None, + detach_spk: bool = False, + method: str = 'exp_auto', + init_var: bool = True, + + # neuron parameters + V_rest: Union[float, ArrayType, Callable] = -65., + V_reset: Union[float, ArrayType, Callable] = -68., + V_th: Union[float, ArrayType, Callable] = -30., + V_c: Union[float, ArrayType, Callable] = -50.0, + c: Union[float, ArrayType, Callable] = 0.07, + R: Union[float, ArrayType, Callable] = 1., + tau: Union[float, ArrayType, Callable] = 10., + V_initializer: Union[Callable, ArrayType] = ZeroInit(), + ): + # initialization + super().__init__(size=size, + name=name, + keep_size=keep_size, + mode=mode, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + method=method, + spk_type=spk_type) + # parameters + self.V_rest = self.init_param(V_rest) + self.V_reset = self.init_param(V_reset) + self.V_th = self.init_param(V_th) + self.V_c = self.init_param(V_c) + self.c = self.init_param(c) + self.R = self.init_param(R) + self.tau = self.init_param(tau) + + # initializers + self._V_initializer = is_initializer(V_initializer) + + # integral + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def derivative(self, V, t, I): + for out in self.cur_inputs.values(): + I += out(V) + dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I) / self.tau + return dVdt + + def reset_state(self, batch_size=None): + self.V = self.init_variable(self._V_initializer, batch_size) + self.spike 
= self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + + # integrate membrane potential + V = self.integral(self.V.value, t, x, dt) + + # spike, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike = stop_gradient(spike) if self.detach_spk else spike + V += (self.V_reset - V) * spike + + else: + spike = V >= self.V_th + V = bm.where(spike, self.V_reset, V) + + self.V.value = V + self.spike.value = spike + return spike + + def return_for_delay(self): + return self.spike + + +class QuaIF(QuaIFLTC): + def derivative(self, V, t, I): + dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I) / self.tau + return dVdt + + def update(self, x=None): + x = 0. if x is None else x + for out in self.cur_inputs.values(): + x += out(self.V.value) + super().update(x) + + +class QuaIFRefLTC(QuaIFLTC): + def __init__( + self, + size: Shape, + sharding: Optional[Sharding] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_type: Any = None, + detach_spk: bool = False, + method: str = 'exp_auto', + name: Optional[str] = None, + init_var: bool = True, + + # old neuron parameter + V_rest: Union[float, ArrayType, Callable] = -65., + V_reset: Union[float, ArrayType, Callable] = -68., + V_th: Union[float, ArrayType, Callable] = -30., + V_c: Union[float, ArrayType, Callable] = -50.0, + c: Union[float, ArrayType, Callable] = 0.07, + R: Union[float, ArrayType, Callable] = 1., + tau: Union[float, ArrayType, Callable] = 10., + V_initializer: Union[Callable, ArrayType] = ZeroInit(), + + # new neuron parameter + tau_ref: Union[float, ArrayType, Callable] = 0., + ref_var: bool = False, + ): + # initialization + super().__init__( + size=size, + name=name, + keep_size=keep_size, + mode=mode, + method=method, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + spk_type=spk_type, + + init_var=False, + + V_rest=V_rest, + V_reset=V_reset, + V_th=V_th, + V_c=V_c, + c=c, + R=R, + tau=tau, + V_initializer=V_initializer, + ) + + # parameters + self.ref_var = ref_var + self.tau_ref = self.init_param(tau_ref) + + # initializers + self._V_initializer = is_initializer(V_initializer) + + # integral + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + super().reset_state(batch_size) + self.t_last_spike = self.init_variable(bm.ones, batch_size) + self.t_last_spike.fill_(-1e7) + if self.ref_var: + self.refractory = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + + # integrate membrane potential + V = self.integral(self.V.value, t, x, dt) + + # refractory + refractory = (t - self.t_last_spike) <= self.tau_ref + if isinstance(self.mode, bm.TrainingMode): + refractory = stop_gradient(refractory) + V = bm.where(refractory, self.V.value, V) + + # spike, refractory, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike_no_grad = stop_gradient(spike) if self.detach_spk else spike + V += (self.V_reset - V) * spike_no_grad + spike_ = spike_no_grad > 0. 
+ # will be used in other place, like Delta Synapse, so stop its gradient + if self.ref_var: + self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value) + t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) + + else: + spike = V >= self.V_th + V = bm.where(spike, self.V_reset, V) + if self.ref_var: + self.refractory.value = bm.logical_or(refractory, spike) + t_last_spike = bm.where(spike, t, self.t_last_spike.value) + self.V.value = V + self.spike.value = spike + self.t_last_spike.value = t_last_spike + return spike + + +class QuaIFRef(QuaIFRefLTC): + def derivative(self, V, t, I): + dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I) / self.tau + return dVdt + + def update(self, x=None): + x = 0. if x is None else x + for out in self.cur_inputs.values(): + x += out(self.V.value) + super().update(x) + + +class AdQuaIFLTC(GradNeuDyn): + def __init__( + self, + size: Shape, + sharding: Optional[Sequence[str]] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_type: Any = None, + detach_spk: bool = False, + method: str = 'exp_auto', + init_var: bool = True, + + # neuron parameters + V_rest: Union[float, ArrayType, Callable] = -65., + V_reset: Union[float, ArrayType, Callable] = -68., + V_th: Union[float, ArrayType, Callable] = -30., + V_c: Union[float, ArrayType, Callable] = -50.0, + a: Union[float, ArrayType, Callable] = 1., + b: Union[float, ArrayType, Callable] = .1, + c: Union[float, ArrayType, Callable] = .07, + tau: Union[float, ArrayType, Callable] = 10., + tau_w: Union[float, ArrayType, Callable] = 10., + V_initializer: Union[Callable, ArrayType] = ZeroInit(), + w_initializer: Union[Callable, ArrayType] = ZeroInit(), + ): + # initialization + super().__init__(size=size, + name=name, + keep_size=keep_size, + mode=mode, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + method=method, + spk_type=spk_type) + # parameters + self.V_rest = self.init_param(V_rest) + self.V_reset = self.init_param(V_reset) + self.V_th = self.init_param(V_th) + self.V_c = self.init_param(V_c) + self.a = self.init_param(a) + self.b = self.init_param(b) + self.c = self.init_param(c) + self.tau = self.init_param(tau) + self.tau_w = self.init_param(tau_w) + + # initializers + self._V_initializer = is_initializer(V_initializer) + self._w_initializer = is_initializer(w_initializer) + + # integral + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def dV(self, V, t, w, I): + for out in self.cur_inputs.values(): + I += out(V) + dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I) / self.tau + return dVdt + + def dw(self, w, t, V): + dwdt = (self.a * (V - self.V_rest) - w) / self.tau_w + return dwdt + + @property + def derivative(self): + return JointEq([self.dV, self.dw]) + + def reset_state(self, batch_size=None): + self.V = self.init_variable(self._V_initializer, batch_size) + self.w = self.init_variable(self._w_initializer, batch_size) + self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. 
if x is None else x + + # integrate membrane potential + V, w = self.integral(self.V.value, self.w.value, t, x, dt) + + # spike, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike = stop_gradient(spike) if self.detach_spk else spike + V += (self.V_reset - V) * spike + w += self.b * spike + + else: + spike = V >= self.V_th + V = bm.where(spike, self.V_reset, V) + w = bm.where(spike, w + self.b, w) + + self.V.value = V + self.w.value = w + self.spike.value = spike + return spike + + def return_for_delay(self): + return self.spike + + +class AdQuaIF(AdQuaIFLTC): + def dV(self, V, t, w, I): + dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I) / self.tau + return dVdt + + def dw(self, w, t, V): + dwdt = (self.a * (V - self.V_rest) - w) / self.tau_w + return dwdt + + @property + def derivative(self): + return JointEq([self.dV, self.dw]) + + def update(self, x=None): + x = 0. if x is None else x + for out in self.cur_inputs.values(): + x += out(self.V.value) + super().update(x) + + +class AdQuaIFRefLTC(AdQuaIFLTC): + def __init__( + self, + size: Shape, + sharding: Optional[Sharding] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_type: Any = None, + detach_spk: bool = False, + method: str = 'exp_auto', + name: Optional[str] = None, + init_var: bool = True, + + # old neuron parameter + V_rest: Union[float, ArrayType, Callable] = -65., + V_reset: Union[float, ArrayType, Callable] = -68., + V_th: Union[float, ArrayType, Callable] = -30., + V_c: Union[float, ArrayType, Callable] = -50.0, + a: Union[float, ArrayType, Callable] = 1., + b: Union[float, ArrayType, Callable] = .1, + c: Union[float, ArrayType, Callable] = .07, + tau: Union[float, ArrayType, Callable] = 10., + tau_w: Union[float, ArrayType, Callable] = 10., + V_initializer: Union[Callable, ArrayType] = ZeroInit(), + w_initializer: Union[Callable, ArrayType] = ZeroInit(), + + # new neuron parameter + tau_ref: Union[float, ArrayType, Callable] = 0., + ref_var: bool = False, + ): + # initialization + super().__init__( + size=size, + name=name, + keep_size=keep_size, + mode=mode, + method=method, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + spk_type=spk_type, + + init_var=False, + + V_rest=V_rest, + V_reset=V_reset, + V_th=V_th, + V_c=V_c, + a=a, + b=b, + c=c, + tau=tau, + tau_w=tau_w, + V_initializer=V_initializer, + w_initializer=w_initializer + ) + + # parameters + self.ref_var = ref_var + self.tau_ref = self.init_param(tau_ref) + + # initializers + self._V_initializer = is_initializer(V_initializer) + self._w_initializer = is_initializer(w_initializer) + + # integral + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + super().reset_state(batch_size) + self.t_last_spike = self.init_variable(bm.ones, batch_size) + self.t_last_spike.fill_(-1e8) + if self.ref_var: + self.refractory = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. 
if x is None else x + + # integrate membrane potential + V, w = self.integral(self.V.value, self.w.value, t, x, dt) + + # refractory + refractory = (t - self.t_last_spike) <= self.tau_ref + if isinstance(self.mode, bm.TrainingMode): + refractory = stop_gradient(refractory) + V = bm.where(refractory, self.V.value, V) + + # spike, refractory, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike_no_grad = stop_gradient(spike) if self.detach_spk else spike + V += (self.V_reset - V) * spike_no_grad + w += self.b * spike_no_grad + spike_ = spike_no_grad > 0. + # will be used in other place, like Delta Synapse, so stop its gradient + if self.ref_var: + self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value) + t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) + + else: + spike = V >= self.V_th + V = bm.where(spike, self.V_reset, V) + w = bm.where(spike, w + self.b, w) + if self.ref_var: + self.refractory.value = bm.logical_or(refractory, spike) + t_last_spike = bm.where(spike, t, self.t_last_spike.value) + self.V.value = V + self.w.value = w + self.spike.value = spike + self.t_last_spike.value = t_last_spike + return spike + + +class AdQuaIFRef(AdQuaIFRefLTC): + def dV(self, V, t, w, I): + dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I) / self.tau + return dVdt + + def dw(self, w, t, V): + dwdt = (self.a * (V - self.V_rest) - w) / self.tau_w + return dwdt + + @property + def derivative(self): + return JointEq([self.dV, self.dw]) + + def update(self, x=None): + x = 0. if x is None else x + for out in self.cur_inputs.values(): + x += out(self.V.value) + super().update(x) + + +class GifLTC(GradNeuDyn): + def __init__( + self, + size: Shape, + sharding: Optional[Sequence[str]] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_type: Any = None, + detach_spk: bool = False, + method: str = 'exp_auto', + init_var: bool = True, + + # neuron parameters + V_rest: Union[float, ArrayType, Callable] = -70., + V_reset: Union[float, ArrayType, Callable] = -70., + V_th_inf: Union[float, ArrayType, Callable] = -50., + V_th_reset: Union[float, ArrayType, Callable] = -60., + R: Union[float, ArrayType, Callable] = 20., + tau: Union[float, ArrayType, Callable] = 20., + a: Union[float, ArrayType, Callable] = 0., + b: Union[float, ArrayType, Callable] = 0.01, + k1: Union[float, ArrayType, Callable] = 0.2, + k2: Union[float, ArrayType, Callable] = 0.02, + R1: Union[float, ArrayType, Callable] = 0., + R2: Union[float, ArrayType, Callable] = 1., + A1: Union[float, ArrayType, Callable] = 0., + A2: Union[float, ArrayType, Callable] = 0., + V_initializer: Union[Callable, ArrayType] = OneInit(-70.), + I1_initializer: Union[Callable, ArrayType] = ZeroInit(), + I2_initializer: Union[Callable, ArrayType] = ZeroInit(), + Vth_initializer: Union[Callable, ArrayType] = OneInit(-50.), + ): + # initialization + super().__init__(size=size, + name=name, + keep_size=keep_size, + mode=mode, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + method=method, + spk_type=spk_type) + # parameters + self.V_rest = self.init_param(V_rest) + self.V_reset = self.init_param(V_reset) + self.V_th_inf = self.init_param(V_th_inf) + self.V_th_reset = self.init_param(V_th_reset) + self.R = self.init_param(R) + self.a = self.init_param(a) + self.b = self.init_param(b) + self.k1 = self.init_param(k1) 
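+    # k1, k2: decay rates of the internal currents I1 and I2 (cf. dI1 and dI2 below)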
+ self.k2 = self.init_param(k2) + self.R1 = self.init_param(R1) + self.R2 = self.init_param(R2) + self.A1 = self.init_param(A1) + self.A2 = self.init_param(A2) + self.tau = self.init_param(tau) + + # initializers + self._V_initializer = is_initializer(V_initializer) + self._I1_initializer = is_initializer(I1_initializer) + self._I2_initializer = is_initializer(I2_initializer) + self._Vth_initializer = is_initializer(Vth_initializer) + + # integral + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def dI1(self, I1, t): + return - self.k1 * I1 + + def dI2(self, I2, t): + return - self.k2 * I2 + + def dVth(self, V_th, t, V): + return self.a * (V - self.V_rest) - self.b * (V_th - self.V_th_inf) + + def dV(self, V, t, I1, I2, I): + for out in self.cur_inputs.values(): + I += out(V) + return (- (V - self.V_rest) + self.R * (I + I1 + I2)) / self.tau + + @property + def derivative(self): + return JointEq(self.dI1, self.dI2, self.dVth, self.dV) + + def reset_state(self, batch_size=None): + self.V = self.init_variable(self._V_initializer, batch_size) + self.I1 = self.init_variable(self._I1_initializer, batch_size) + self.I2 = self.init_variable(self._I2_initializer, batch_size) + self.V_th = self.init_variable(self._Vth_initializer, batch_size) + self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + + # integrate membrane potential + I1, I2, V_th, V = self.integral(self.I1.value, self.I2.value, self.V_th.value, self.V.value, t, x, dt) + + # spike, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike = stop_gradient(spike) if self.detach_spk else spike + V += (self.V_reset - V) * spike + I1 += spike * (self.R1 * I1 + self.A1 - I1) + I2 += spike * (self.R2 * I2 + self.A2 - I2) + reset_th = self.spk_fun(self.V_th_reset - V_th) * spike + V_th += reset_th * (self.V_th_reset - V_th) + + else: + spike = self.V_th <= V + V = bm.where(spike, self.V_reset, V) + I1 = bm.where(spike, self.R1 * I1 + self.A1, I1) + I2 = bm.where(spike, self.R2 * I2 + self.A2, I2) + V_th = bm.where(spike, bm.maximum(self.V_th_reset, V_th), V_th) + self.spike.value = spike + self.I1.value = I1 + self.I2.value = I2 + self.V_th.value = V_th + self.V.value = V + return spike + + def return_for_delay(self): + return self.spike + + +class Gif(GifLTC): + def dI1(self, I1, t): + return - self.k1 * I1 + + def dI2(self, I2, t): + return - self.k2 * I2 + + def dVth(self, V_th, t, V): + return self.a * (V - self.V_rest) - self.b * (V_th - self.V_th_inf) + + def dV(self, V, t, I1, I2, I): + return (- (V - self.V_rest) + self.R * (I + I1 + I2)) / self.tau + + @property + def derivative(self): + return JointEq(self.dI1, self.dI2, self.dVth, self.dV) + + def update(self, x=None): + x = 0. 
if x is None else x + for out in self.cur_inputs.values(): + x += out(self.V.value) + super().update(x) + + +class GifRefLTC(GifLTC): + def __init__( + self, + size: Shape, + sharding: Optional[Sharding] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_type: Any = None, + detach_spk: bool = False, + method: str = 'exp_auto', + name: Optional[str] = None, + init_var: bool = True, + + # old neuron parameter + V_rest: Union[float, ArrayType, Callable] = -70., + V_reset: Union[float, ArrayType, Callable] = -70., + V_th_inf: Union[float, ArrayType, Callable] = -50., + V_th_reset: Union[float, ArrayType, Callable] = -60., + R: Union[float, ArrayType, Callable] = 20., + tau: Union[float, ArrayType, Callable] = 20., + a: Union[float, ArrayType, Callable] = 0., + b: Union[float, ArrayType, Callable] = 0.01, + k1: Union[float, ArrayType, Callable] = 0.2, + k2: Union[float, ArrayType, Callable] = 0.02, + R1: Union[float, ArrayType, Callable] = 0., + R2: Union[float, ArrayType, Callable] = 1., + A1: Union[float, ArrayType, Callable] = 0., + A2: Union[float, ArrayType, Callable] = 0., + V_initializer: Union[Callable, ArrayType] = OneInit(-70.), + I1_initializer: Union[Callable, ArrayType] = ZeroInit(), + I2_initializer: Union[Callable, ArrayType] = ZeroInit(), + Vth_initializer: Union[Callable, ArrayType] = OneInit(-50.), + + # new neuron parameter + tau_ref: Union[float, ArrayType, Callable] = 0., + ref_var: bool = False, + ): + # initialization + super().__init__( + size=size, + name=name, + keep_size=keep_size, + mode=mode, + method=method, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + spk_type=spk_type, + + init_var=False, + + V_rest=V_rest, + V_reset=V_reset, + V_th_inf=V_th_inf, + V_th_reset=V_th_reset, + R=R, + a=a, + b=b, + k1=k1, + k2=k2, + R1=R1, + R2=R2, + A1=A1, + A2=A2, + tau=tau, + V_initializer=V_initializer, + I1_initializer=I1_initializer, + I2_initializer=I2_initializer, + Vth_initializer=Vth_initializer, + ) + + # parameters + self.ref_var = ref_var + self.tau_ref = self.init_param(tau_ref) + + # initializers + self._V_initializer = is_initializer(V_initializer) + self._I1_initializer = is_initializer(I1_initializer) + self._I2_initializer = is_initializer(I2_initializer) + self._Vth_initializer = is_initializer(Vth_initializer) + + # integral + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + super().reset_state(batch_size) + self.t_last_spike = self.init_variable(bm.ones, batch_size) + self.t_last_spike.fill_(-1e8) + if self.ref_var: + self.refractory = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. 
if x is None else x + + # integrate membrane potential + I1, I2, V_th, V = self.integral(self.I1.value, self.I2.value, self.V_th.value, self.V.value, t, x, dt) + + # refractory + refractory = (t - self.t_last_spike) <= self.tau_ref + if isinstance(self.mode, bm.TrainingMode): + refractory = stop_gradient(refractory) + V = bm.where(refractory, self.V.value, V) + + # spike, refractory, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike_no_grad = stop_gradient(spike) if self.detach_spk else spike + V += (self.V_reset - V) * spike + I1 += spike * (self.R1 * I1 + self.A1 - I1) + I2 += spike * (self.R2 * I2 + self.A2 - I2) + reset_th = self.spk_fun(self.V_th_reset - V_th) * spike + V_th += reset_th * (self.V_th_reset - V_th) + spike_ = spike_no_grad > 0. + # will be used in other place, like Delta Synapse, so stop its gradient + if self.ref_var: + self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value) + t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) + + else: + spike = V >= self.V_th + V = bm.where(spike, self.V_reset, V) + I1 = bm.where(spike, self.R1 * I1 + self.A1, I1) + I2 = bm.where(spike, self.R2 * I2 + self.A2, I2) + V_th = bm.where(spike, bm.maximum(self.V_th_reset, V_th), V_th) + if self.ref_var: + self.refractory.value = bm.logical_or(refractory, spike) + t_last_spike = bm.where(spike, t, self.t_last_spike.value) + self.V.value = V + self.I1.value = I1 + self.I2.value = I2 + self.V_th.value = V_th + self.spike.value = spike + self.t_last_spike.value = t_last_spike + return spike + + +class GifRef(GifRefLTC): + def dI1(self, I1, t): + return - self.k1 * I1 + + def dI2(self, I2, t): + return - self.k2 * I2 + + def dVth(self, V_th, t, V): + return self.a * (V - self.V_rest) - self.b * (V_th - self.V_th_inf) + + def dV(self, V, t, I1, I2, I): + return (- (V - self.V_rest) + self.R * (I + I1 + I2)) / self.tau + + @property + def derivative(self): + return JointEq(self.dI1, self.dI2, self.dVth, self.dV) + + def update(self, x=None): + x = 0. 
if x is None else x
+    for out in self.cur_inputs.values():
+      x += out(self.V.value)
+    super().update(x)
+
+
+class IzhikevichLTC(GradNeuDyn):
+  def __init__(
+      self,
+      size: Shape,
+      sharding: Optional[Sequence[str]] = None,
+      keep_size: bool = False,
+      mode: Optional[bm.Mode] = None,
+      name: Optional[str] = None,
+      spk_fun: Callable = bm.surrogate.InvSquareGrad(),
+      spk_type: Any = None,
+      detach_spk: bool = False,
+      method: str = 'exp_auto',
+      init_var: bool = True,
+
+      # neuron parameters
+      V_th: Union[float, ArrayType, Callable] = 30.,
+      a: Union[float, ArrayType, Callable] = 0.02,
+      b: Union[float, ArrayType, Callable] = 0.20,
+      c: Union[float, ArrayType, Callable] = -65.,
+      d: Union[float, ArrayType, Callable] = 8.,
+      tau: Union[float, ArrayType, Callable] = 10.,
+      R: Union[float, ArrayType, Callable] = 1.,
+      V_initializer: Union[Callable, ArrayType] = OneInit(-70.),
+      u_initializer: Union[Callable, ArrayType] = None,
+  ):
+    # initialization
+    super().__init__(size=size,
+                     name=name,
+                     keep_size=keep_size,
+                     mode=mode,
+                     sharding=sharding,
+                     spk_fun=spk_fun,
+                     detach_spk=detach_spk,
+                     method=method,
+                     spk_type=spk_type)
+    # parameters
+    self.V_th = self.init_param(V_th)
+    self.a = self.init_param(a)
+    self.b = self.init_param(b)
+    self.c = self.init_param(c)
+    self.d = self.init_param(d)
+    self.R = self.init_param(R)
+    self.tau = self.init_param(tau)
+
+    # initializers
+    self._V_initializer = is_initializer(V_initializer)
+    self._u_initializer = is_initializer(u_initializer)
+
+    # integral
+    self.integral = odeint(method=method, f=self.derivative)
+
+    # variables
+    if init_var:
+      self.reset_state(self.mode)
+
+  def dV(self, V, t, u, I):
+    for out in self.cur_inputs.values():
+      I += out(V)
+    dVdt = 0.04 * V * V + 5 * V + 140 - u + I
+    return dVdt
+
+  def du(self, u, t, V):
+    dudt = self.a * (self.b * V - u)
+    return dudt
+
+  @property
+  def derivative(self):
+    return JointEq([self.dV, self.du])
+
+  def reset_state(self, batch_size=None):
+    self.V = self.init_variable(self._V_initializer, batch_size)
+    u_initializer = OneInit(self.b * self.V) if self._u_initializer is None else self._u_initializer
+    self._u_initializer = is_initializer(u_initializer)
+    self.u = self.init_variable(self._u_initializer, batch_size)
+    self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size)
+
+  def update(self, x=None):
+    t = share.load('t')
+    dt = share.load('dt')
+    x = 0. if x is None else x
+
+    # integrate membrane potential
+    V, u = self.integral(self.V.value, self.u.value, t, x, dt)
+
+    # spike, spiking time, and membrane potential reset
+    if isinstance(self.mode, bm.TrainingMode):
+      spike = self.spk_fun(V - self.V_th)
+      spike = stop_gradient(spike) if self.detach_spk else spike
+      V += spike * (self.c - self.V_th)
+      u += spike * self.d
+
+    else:
+      spike = V >= self.V_th
+      V = bm.where(spike, self.c, V)
+      u = bm.where(spike, u + self.d, u)
+
+    self.V.value = V
+    self.u.value = u
+    self.spike.value = spike
+    return spike
+
+  def return_for_delay(self):
+    return self.spike
+
+
+class Izhikevich(IzhikevichLTC):
+  def dV(self, V, t, u, I):
+    dVdt = 0.04 * V * V + 5 * V + 140 - u + I
+    return dVdt
+
+  def du(self, u, t, V):
+    dudt = self.a * (self.b * V - u)
+    return dudt
+
+  @property
+  def derivative(self):
+    return JointEq([self.dV, self.du])
+
+  def update(self, x=None):
+    x = 0.
if x is None else x + for out in self.cur_inputs.values(): + x += out(self.V.value) + super().update(x) + + +class IzhikevichRefLTC(IzhikevichLTC): + def __init__( + self, + size: Shape, + sharding: Optional[Sharding] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_type: Any = None, + detach_spk: bool = False, + method: str = 'exp_auto', + name: Optional[str] = None, + init_var: bool = True, + + # old neuron parameter + V_th: Union[float, ArrayType, Callable] = 30., + a: Union[float, ArrayType, Callable] = 0.02, + b: Union[float, ArrayType, Callable] = 0.20, + c: Union[float, ArrayType, Callable] = -65., + d: Union[float, ArrayType, Callable] = 8., + tau: Union[float, ArrayType, Callable] = 10., + R: Union[float, ArrayType, Callable] = 1., + V_initializer: Union[Callable, ArrayType] = OneInit(-70.), + u_initializer: Union[Callable, ArrayType] = None, + + # new neuron parameter + tau_ref: Union[float, ArrayType, Callable] = 0., + ref_var: bool = False, + ): + # initialization + super().__init__( + size=size, + name=name, + keep_size=keep_size, + mode=mode, + method=method, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + spk_type=spk_type, + + init_var=False, + + V_th=V_th, + a=a, + b=b, + c=c, + d=d, + R=R, + tau=tau, + V_initializer=V_initializer, + u_initializer=u_initializer + ) + + # parameters + self.ref_var = ref_var + self.tau_ref = self.init_param(tau_ref) + + # initializers + self._V_initializer = is_initializer(V_initializer) + self._u_initializer = is_initializer(u_initializer) + + # integral + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + super().reset_state(batch_size) + self.t_last_spike = self.init_variable(bm.ones, batch_size) + self.t_last_spike.fill_(-1e7) + if self.ref_var: + self.refractory = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + + # integrate membrane potential + V, u = self.integral(self.V.value, self.u.value, t, x, dt) + + # refractory + refractory = (t - self.t_last_spike) <= self.tau_ref + if isinstance(self.mode, bm.TrainingMode): + refractory = stop_gradient(refractory) + V = bm.where(refractory, self.V.value, V) + + # spike, refractory, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike_no_grad = stop_gradient(spike) if self.detach_spk else spike + V += spike * (self.c - self.V_th) + u += spike * self.d + spike_ = spike_no_grad > 0. 
+      # will be used in other place, like Delta Synapse, so stop its gradient
+      if self.ref_var:
+        self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value)
+      t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value))
+
+    else:
+      spike = V >= self.V_th
+      V = bm.where(spike, self.c, V)
+      u = bm.where(spike, u + self.d, u)
+      if self.ref_var:
+        self.refractory.value = bm.logical_or(refractory, spike)
+      t_last_spike = bm.where(spike, t, self.t_last_spike.value)
+    self.V.value = V
+    self.u.value = u
+    self.spike.value = spike
+    self.t_last_spike.value = t_last_spike
+    return spike
+
+
+class IzhikevichRef(IzhikevichRefLTC):
+  def dV(self, V, t, u, I):
+    dVdt = 0.04 * V * V + 5 * V + 140 - u + I
+    return dVdt
+
+  def du(self, u, t, V):
+    dudt = self.a * (self.b * V - u)
+    return dudt
+
+  @property
+  def derivative(self):
+    return JointEq([self.dV, self.du])
+
+  def update(self, x=None):
+    x = 0. if x is None else x
+    for out in self.cur_inputs.values():
+      x += out(self.V.value)
+    super().update(x)
\ No newline at end of file
diff --git a/brainpy/_src/dyn/neurons/tests/test_hh.py b/brainpy/_src/dyn/neurons/tests/test_hh.py
new file mode 100644
index 000000000..2a9bd7a46
--- /dev/null
+++ b/brainpy/_src/dyn/neurons/tests/test_hh.py
@@ -0,0 +1,140 @@
+# -*- coding: utf-8 -*-
+
+
+import brainpy as bp
+import brainpy.math as bm
+from absl.testing import parameterized
+from brainpy._src.dyn.neurons import hh
+
+class Test_HH(parameterized.TestCase):
+  def test_HH(self):
+    model = hh.HH(size=1)
+    runner = bp.DSRunner(model,
+                         monitors=['V', 'm', 'n', 'h', 'spike'],
+                         progress_bar=False)
+    runner.run(10.)
+    self.assertTupleEqual(runner.mon['V'].shape, (100, 1))
+    self.assertTupleEqual(runner.mon['m'].shape, (100, 1))
+    self.assertTupleEqual(runner.mon['n'].shape, (100, 1))
+    self.assertTupleEqual(runner.mon['h'].shape, (100, 1))
+    self.assertTupleEqual(runner.mon['spike'].shape, (100, 1))
+
+  def test_HH_batching_mode(self):
+    model = hh.HH(size=10, mode=bm.batching_mode)
+    runner = bp.DSRunner(model,
+                         monitors=['V', 'm', 'n', 'h', 'spike'],
+                         progress_bar=False)
+    runner.run(10.)
+    self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10))
+    self.assertTupleEqual(runner.mon['m'].shape, (1, 100, 10))
+    self.assertTupleEqual(runner.mon['n'].shape, (1, 100, 10))
+    self.assertTupleEqual(runner.mon['h'].shape, (1, 100, 10))
+    self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10))
+
+  def test_HHLTC(self):
+    model = hh.HHLTC(size=1)
+    runner = bp.DSRunner(model,
+                         monitors=['V', 'm', 'n', 'h', 'spike'],
+                         progress_bar=False)
+    runner.run(10.)
+    self.assertTupleEqual(runner.mon['V'].shape, (100, 1))
+    self.assertTupleEqual(runner.mon['m'].shape, (100, 1))
+    self.assertTupleEqual(runner.mon['n'].shape, (100, 1))
+    self.assertTupleEqual(runner.mon['h'].shape, (100, 1))
+    self.assertTupleEqual(runner.mon['spike'].shape, (100, 1))
+
+  def test_HHLTC_batching_mode(self):
+    model = hh.HHLTC(size=10, mode=bm.batching_mode)
+    runner = bp.DSRunner(model,
+                         monitors=['V', 'm', 'n', 'h', 'spike'],
+                         progress_bar=False)
+    runner.run(10.)
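+    # 10 ms at the default dt of 0.1 ms gives 100 recorded steps; batching mode adds a
+    # leading batch axis of size 1, hence the (1, 100, 10) shapes checked below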
+ self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['m'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['n'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['h'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10)) + + def test_MorrisLecar(self): + model = hh.MorrisLecar(size=1) + runner = bp.DSRunner(model, + monitors=['V', 'W', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['W'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) + + def test_MorrisLecar_batching_mode(self): + model = hh.MorrisLecar(size=10, mode=bm.batching_mode) + runner = bp.DSRunner(model, + monitors=['V', 'W', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['W'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10)) + + def test_MorrisLecarLTC(self): + model = hh.MorrisLecarLTC(size=1) + runner = bp.DSRunner(model, + monitors=['V', 'W', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['W'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) + + def test_MorrisLecarLTC_batching_mode(self): + model = hh.MorrisLecarLTC(size=10, mode=bm.batching_mode) + runner = bp.DSRunner(model, + monitors=['V', 'W', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['W'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10)) + + def test_WangBuzsakiModel(self): + model = hh.WangBuzsakiModel(size=1) + runner = bp.DSRunner(model, + monitors=['V', 'n', 'h', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['n'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['h'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) + + def test_WangBuzsakiModel_batching_mode(self): + model = hh.WangBuzsakiModel(size=10, mode=bm.batching_mode) + runner = bp.DSRunner(model, + monitors=['V', 'n', 'h', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['n'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['h'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10)) + + def test_WangBuzsakiModelLTC(self): + model = hh.WangBuzsakiModelLTC(size=1) + runner = bp.DSRunner(model, + monitors=['V', 'n', 'h', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['n'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['h'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) + + def test_WangBuzsakiModelLTC_batching_mode(self): + model = hh.WangBuzsakiModelLTC(size=10, mode=bm.batching_mode) + runner = bp.DSRunner(model, + monitors=['V', 'n', 'h', 'spike'], + progress_bar=False) + runner.run(10.) 
+ self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['n'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['h'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10)) \ No newline at end of file diff --git a/brainpy/_src/dyn/neurons/tests/test_input.py b/brainpy/_src/dyn/neurons/tests/test_input.py new file mode 100644 index 000000000..fc05c62b8 --- /dev/null +++ b/brainpy/_src/dyn/neurons/tests/test_input.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- + + +import brainpy as bp +from absl.testing import parameterized +from brainpy._src.dyn.neurons import input + + +class Test_input(parameterized.TestCase): + def test_SpikeTimeGroup(self): + model = input.SpikeTimeGroup(size=2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1]) + runner = bp.DSRunner(model, + monitors=['spike'], + progress_bar=False) + runner.run(30.) + self.assertTupleEqual(runner.mon['spike'].shape, (300, 2)) + + def test_PoissonGroup(self): + model = input.PoissonGroup(size=2, freqs=1000, seed=0) + runner = bp.DSRunner(model, + monitors=['spike'], + progress_bar=False) + runner.run(30.) + self.assertTupleEqual(runner.mon['spike'].shape, (300, 2)) diff --git a/brainpy/_src/dyn/neurons/tests/test_lif.py b/brainpy/_src/dyn/neurons/tests/test_lif.py new file mode 100644 index 000000000..2ed50f195 --- /dev/null +++ b/brainpy/_src/dyn/neurons/tests/test_lif.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- + + +import brainpy as bp +import brainpy.math as bm +from absl.testing import parameterized +from brainpy._src.dyn.neurons import lif + +class Test_lif(parameterized.TestCase): + @parameterized.named_parameters( + {'testcase_name': f'{name}', 'neuron': name} + for name in lif.__all__ + ) + def test_run_shape(self, neuron): + model = getattr(lif, neuron)(size=1) + if neuron in ['IF', 'IFLTC']: + runner = bp.DSRunner(model, + monitors=['V'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + else: + runner = bp.DSRunner(model, + monitors=['V', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) + + + @parameterized.named_parameters( + {'testcase_name': f'{name}', 'neuron': name} + for name in lif.__all__ + ) + def test_training_shape(self, neuron): + model = getattr(lif, neuron)(size=10, mode=bm.training_mode) + runner = bp.DSRunner(model, + monitors=['V'], + progress_bar=False) + runner.run(10.) 
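+    # training mode, like batching mode, prepends a batch dimension of size 1 to the monitors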
+ self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) diff --git a/brainpy/channels.py b/brainpy/channels.py index 6a19f7f55..16769e2f1 100644 --- a/brainpy/channels.py +++ b/brainpy/channels.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -from brainpy._src.channels.base import ( +from brainpy._src.dyn.channels.base import ( Ion as Ion, IonChannel as IonChannel, Calcium as Calcium, @@ -11,7 +11,7 @@ LeakyChannel as LeakyChannel, ) -from brainpy._src.channels.Ca import ( +from brainpy._src.dyn.channels.Ca import ( CalciumFixed as CalciumFixed, CalciumDyna as CalciumDyna, CalciumDetailed as CalciumDetailed, @@ -23,12 +23,12 @@ ICaL_IS2008 as ICaL_IS2008, ) -from brainpy._src.channels.IH import ( +from brainpy._src.dyn.channels.IH import ( Ih_HM1992 as Ih_HM1992, Ih_De1996 as Ih_De1996, ) -from brainpy._src.channels.K import ( +from brainpy._src.dyn.channels.K import ( IKDR_Ba2002 as IKDR_Ba2002, IK_TM1991 as IK_TM1991, IK_HH1952 as IK_HH1952, @@ -39,16 +39,16 @@ IKNI_Ya1989 as IKNI_Ya1989, ) -from brainpy._src.channels.KCa import ( +from brainpy._src.dyn.channels.KCa import ( IAHP_De1994 as IAHP_De1994, ) -from brainpy._src.channels.leaky import ( +from brainpy._src.dyn.channels.leaky import ( IL as IL, IKL as IKL, ) -from brainpy._src.channels.Na import ( +from brainpy._src.dyn.channels.Na import ( INa_Ba2002 as INa_Ba2002, INa_TM1991 as INa_TM1991, INa_HH1952 as INa_HH1952,