Skip to content

Commit

Permalink
Modify neuron templates
Browse files Browse the repository at this point in the history
  • Loading branch information
pnbabu committed Jun 30, 2023
1 parent de4305c commit 1a0dabb
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pynestml/codegeneration/nest_gpu_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class NESTGPUCodeGenerator(NESTCodeGenerator):
"preserve_expressions": False,
"simplify_expression": "sympy.logcombine(sympy.powsimp(sympy.expand(expr)))",
"templates": {
"path": os.path.join(os.path.dirname(__file__), "resources_nest_gpu"),
"path": "point_neuron",
"model_templates": {
"neuron": ["@NEURON_NAME@.cu.jinja2", "@NEURON_NAME@.h.jinja2"]
# "@NEURON_NAME@_kernel.h.jinja2", "@NEURON_NAME@_rk5.h.jinja2"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ __global__ void {{ neuronName }}_Calibrate(int n_node, float *param_arr,
{%- for internals_block in neuron.get_internals_blocks() %}
{%- for decl in internals_block.get_declarations() %}
{%- for variable in decl.get_variables() %}
{%- set variable_symbol = variable.get_scope().resolve_to_symbol(variable.get_complete_name(), SymbolKind.VARIABLE) %}
{%- include "directives/MemberInitialization.jinja2" %}
{%- endfor %}
{%- endfor %}
Expand Down Expand Up @@ -184,14 +183,16 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
{# SetScalVar(0, n_node, "refractory_step", 0 );#}

{%- filter indent(2) %}
{%- for variable in neuron.get_parameter_symbols() %}
SetScalParam(0, n_node, {{ printer_no_origin.print(variable) }}, {{printer.print(variable.get_declaring_expression())}}); // as {{variable.get_type_symbol().print_symbol()}}
{%- for variable_symbol in neuron.get_parameter_symbols() %}
{%- set variable = utils.get_parameter_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
SetScalParam(0, n_node, {{ printer_no_origin.print(variable) }}, {{printer.print(variable_symbol.get_declaring_expression())}}); // as {{variable_symbol.get_type_symbol().print_symbol()}}
{%- endfor %}
{%- endfilter %}


{%- filter indent(2) %}
{%- for variable in neuron.get_internal_symbols() %}
{%- for variable_symbol in neuron.get_internal_symbols() %}
{%- set variable = utils.get_internal_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
SetScalParam(0, n_node, {{ printer_no_origin.print(variable) }}, 0.0);
{%- endfor %}
{%- endfilter %}
Expand Down

0 comments on commit 1a0dabb

Please sign in to comment.