Skip to content

Commit

Permalink
Modify templates
Browse files Browse the repository at this point in the history
  • Loading branch information
pnbabu committed Jun 27, 2023
1 parent 920ea74 commit 5eb7a3c
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 217 deletions.
4 changes: 2 additions & 2 deletions pynestml/codegeneration/nest_gpu_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ class NESTGPUCodeGenerator(NESTCodeGenerator):
"templates": {
"path": os.path.join(os.path.dirname(__file__), "resources_nest_gpu"),
"model_templates": {
"neuron": ["@NEURON_NAME@.cu.jinja2", "@NEURON_NAME@.h.jinja2",
"@NEURON_NAME@_kernel.h.jinja2", "@NEURON_NAME@_rk5.h.jinja2"],
"neuron": ["@NEURON_NAME@.cu.jinja2", "@NEURON_NAME@.h.jinja2"]
# "@NEURON_NAME@_kernel.h.jinja2", "@NEURON_NAME@_rk5.h.jinja2"],
},
"module_templates": [""]
},
Expand Down
210 changes: 126 additions & 84 deletions pynestml/codegeneration/resources_nest_gpu/@NEURON_NAME@.cu.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -23,106 +23,142 @@
#include <config.h>
#include <cmath>
#include <iostream>
#include "{{ neuronName }}_kernel.h"
#include "rk5.h"
#include "{{ neuronName }}.h"
#include "spike_buffer.h"

namespace {{ neuronName }}_ns
{
using namespace {{ neuronName }}_ns;

extern __constant__ float NESTGPUTimeResolution;

{%- for variable_symbol in neuron.get_state_symbols() %}
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
#define {{ printer_no_origin.print(variable) }} var[i_{{ printer_no_origin.print(variable) }}]
{%- endfor %}

{%- for variable_symbol in neuron.get_parameter_symbols() %}
{%- set variable = utils.get_parameter_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
#define {{ printer_no_origin.print(variable) }} param[i_{{ printer_no_origin.print(variable) }}]
{%- endfor %}

{%- for variable_symbol in neuron.get_internal_symbols() %}
{%- set variable = utils.get_internal_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
#define {{ printer_no_origin.print(variable) }} param[i_{{ printer_no_origin.print(variable) }}]
{%- endfor %}

__device__
void NodeInit(int n_var, int n_param, double x, float *y, float *param,
{{ neuronName }}_rk5 data_struct)
double propagator_32( double tau_syn, double tau, double C, double h )
{
//int array_idx = threadIdx.x + blockIdx.x * blockDim.x;
int n_port = (n_var-N_SCAL_VAR)/N_PORT_VAR;

V_th = -50.4;
Delta_T = 2.0;
g_L = 30.0;
E_L = -70.6;
C_m = 281.0;
a = 4.0;
b = 80.5;
tau_w = 144.0;
I_e = 0.0;
V_peak = 0.0;
V_reset = -60.0;
t_ref = 0.0;
den_delay = 0.0;

V_m = E_L;
w = 0;
refractory_step = 0;
for (int i = 0; i<n_port; i++) {
tau_syn(i) = 0.2;
I_syn(i) = 0;
const double P32_linear = 1.0 / ( 2.0 * C * tau * tau ) * h * h
* ( tau_syn - tau ) * exp( -h / tau );
const double P32_singular = h / C * exp( -h / tau );
const double P32 =
-tau / ( C * ( 1.0 - tau / tau_syn ) ) * exp( -h / tau_syn )
* expm1( h * ( 1.0 / tau_syn - 1.0 / tau ) );

const double dev_P32 = fabs( P32 - P32_singular );

if ( tau == tau_syn || ( fabs( tau - tau_syn ) < 0.1 && dev_P32 > 2.0
* fabs( P32_linear ) ) )
{
return P32_singular;
}
else
{
return P32;
}
}

__device__
void NodeCalibrate(int n_var, int n_param, double x, float *y,
float *param, {{ neuronName }}p_rk5 data_struct)
{
//int array_idx = threadIdx.x + blockIdx.x * blockDim.x;
//int n_port = (n_var-N_SCAL_VAR)/N_PORT_VAR;

refractory_step = 0;
// set the right threshold depending on Delta_T
if (Delta_T <= 0.0) {
V_peak = V_th; // same as IAF dynamics for spikes if Delta_T == 0.
__global__ void {{ neuronName }}_Calibrate(int n_node, float *param_arr,
int n_param, float h)
{
int i_neuron = threadIdx.x + blockIdx.x * blockDim.x;
if (i_neuron < n_node) {
float *param = param_arr + n_param*i_neuron;

P11ex = exp( -h / tau_ex );
P11in = exp( -h / tau_in );
P22 = exp( -h / tau_m );
P21ex = (float)propagator_32( tau_ex, tau_m, C_m, h );
P21in = (float)propagator_32( tau_in, tau_m, C_m, h );
P20 = tau_m / C_m * ( 1.0 - P22 );
}
}

}

__device__
void NodeInit(int n_var, int n_param, double x, float *y,
float *param, {{ neuronName }}_rk5 data_struct)
__global__ void {{ neuronName }}_Update(int n_node, int i_node_0, float *var_arr,
float *param_arr, int n_var, int n_param)
{
{{ neuronName }}_ns::NodeInit(n_var, n_param, x, y, param, data_struct);
int i_neuron = threadIdx.x + blockIdx.x * blockDim.x;
if (i_neuron < n_node) {
float *var = var_arr + n_var*i_neuron;
float *param = param_arr + n_param*i_neuron;

if ( refractory_step > 0.0 ) {
// neuron is absolute refractory
refractory_step -= 1.0;
}
else { // neuron is not refractory, so evolve V
V_m_rel = V_m_rel * P22 + I_syn_ex * P21ex + I_syn_in * P21in + I_e * P20;
}
// exponential decaying PSCs
I_syn_ex *= P11ex;
I_syn_in *= P11in;

if (V_m_rel >= Theta_rel ) { // threshold crossing
PushSpike(i_node_0 + i_neuron, 1.0);
V_m_rel = V_reset_rel;
refractory_step = (int)round(t_ref/NESTGPUTimeResolution);
}
}
}

__device__
void NodeCalibrate(int n_var, int n_param, double x, float *y,
float *param, {{ neuronName }}_rk5 data_struct)

{{ neuronName }}::~{{ neuronName }}()
{
{{ neuronName }}_ns::NodeCalibrate(n_var, n_param, x, y, param, data_struct);
FreeVarArr();
FreeParamArr();
}

using namespace {{ neuronName }}_ns;

int {{ neuronName }}::Init(int i_node_0, int n_node, int n_port,
int i_group, unsigned long long *seed) {
BaseNeuron::Init(i_node_0, n_node, n_port, i_group, seed);
int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
int i_group, unsigned long long *seed)
{
BaseNeuron::Init(i_node_0, n_node, 2 /*n_port*/, i_group, seed);
node_type_ = i_{{ neuronName }}_model;

n_scal_var_ = N_SCAL_VAR;
n_port_var_ = N_PORT_VAR;
n_var_ = n_scal_var_;
n_scal_param_ = N_SCAL_PARAM;
n_port_param_ = N_PORT_PARAM;
n_group_param_ = N_GROUP_PARAM;
n_param_ = n_scal_param_;

n_var_ = n_scal_var_ + n_port_var_*n_port;
n_param_ = n_scal_param_ + n_port_param_*n_port;

group_param_ = new float[N_GROUP_PARAM];
AllocParamArr();
AllocVarArr();

scal_var_name_ = {{ neuronName }}_scal_var_name;
port_var_name_= {{ neuronName }}_port_var_name;
scal_param_name_ = {{ neuronName }}_scal_param_name;
port_param_name_ = {{ neuronName }}_port_param_name;
group_param_name_ = {{ neuronName }}_group_param_name;
//rk5_data_struct_.node_type_ = i_{{ neuronName }}_model;
rk5_data_struct_.i_node_0_ = i_node_0_;

SetGroupParam("h_min_rel", 1.0e-3);
SetGroupParam("h0_rel", 1.0e-2);
h_ = h0_rel_* 0.1;

rk5_.Init(n_node, n_var_, n_param_, 0.0, h_, rk5_data_struct_);
var_arr_ = rk5_.GetYArr();
param_arr_ = rk5_.GetParamArr();
SetScalParam(0, n_node, "tau_m", 10.0 ); // in ms
SetScalParam(0, n_node, "C_m", 250.0 ); // in pF
SetScalParam(0, n_node, "E_L", -70.0 ); // in mV
SetScalParam(0, n_node, "I_e", 0.0 ); // in pA
SetScalParam(0, n_node, "Theta_rel", -55.0 - (-70.0) ); // relative to E_L_
SetScalParam(0, n_node, "V_reset_rel", -70.0 - (-70.0) ); // relative to E_L_
SetScalParam(0, n_node, "tau_ex", 2.0 ); // in ms
SetScalParam(0, n_node, "tau_in", 2.0 ); // in ms
// SetScalParam(0, n_node, "rho", 0.01 ); // in 1/s
// SetScalParam(0, n_node, "delta", 0.0 ); // in mV
SetScalParam(0, n_node, "t_ref", 2.0 ); // in ms
SetScalParam(0, n_node, "den_delay", 0.0); // in ms
SetScalParam(0, n_node, "P20", 0.0);
SetScalParam(0, n_node, "P11ex", 0.0);
SetScalParam(0, n_node, "P11in", 0.0);
SetScalParam(0, n_node, "P21ex", 0.0);
SetScalParam(0, n_node, "P21in", 0.0);
SetScalParam(0, n_node, "P22", 0.0);

SetScalVar(0, n_node, "I_syn_ex", 0.0 );
SetScalVar(0, n_node, "I_syn_in", 0.0 );
SetScalVar(0, n_node, "V_m_rel", -70.0 - (-70.0) ); // in mV, relative to E_L
SetScalVar(0, n_node, "refractory_step", 0 );

// multiplication factor of input signal is always 1 for all nodes
float input_weight = 1.0;
Expand All @@ -132,32 +168,38 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int n_port,
port_weight_arr_step_ = 0;
port_weight_port_step_ = 0;

port_input_arr_ = GetVarArr() + n_scal_var_
+ GetPortVarIdx("I_syn");
// input spike signal is stored in I_syn_ex, I_syn_in
port_input_arr_ = GetVarArr() + GetScalVarIdx("I_syn_ex");
port_input_arr_step_ = n_var_;
port_input_port_step_ = n_port_var_;
port_input_port_step_ = 1;

den_delay_arr_ = GetParamArr() + GetScalParamIdx("den_delay");

return 0;
}

int {{ neuronName }}::Calibrate(double time_min, float time_resolution)
int {{ neuronName }}::Update(long long it, double t1)
{
h_min_ = h_min_rel_* time_resolution;
h_ = h0_rel_* time_resolution;
rk5_.Calibrate(time_min, h_, rk5_data_struct_);
{{ neuronName }}_Update<<<(n_node_+1023)/1024, 1024>>>
(n_node_, i_node_0_, var_arr_, param_arr_, n_var_, n_param_);
// gpuErrchk( cudaDeviceSynchronize() );

return 0;
}

template <>
int {{ neuronName }}::UpdateNR<0>(long long it, double t1)
int {{ neuronName }}::Free()
{
FreeVarArr();
FreeParamArr();

return 0;
}

int {{ neuronName }}::Update(long long it, double t1) {
UpdateNR<MAX_PORT_NUM>(it, t1);
int {{ neuronName }}::Calibrate(double, float time_resolution)
{
{{ neuronName }}_Calibrate<<<(n_node_+1023)/1024, 1024>>>
(n_node_, param_arr_, n_param_, time_resolution);

return 0;
}

Loading

0 comments on commit 5eb7a3c

Please sign in to comment.