Skip to content

Commit

Permalink
Fix type derivation of kernel buffers (nest#936)
Browse files Browse the repository at this point in the history
  • Loading branch information
clinssen authored Aug 10, 2023
1 parent e3ad132 commit 8b48853
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 10 deletions.
6 changes: 2 additions & 4 deletions pynestml/codegeneration/nest_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,10 +846,8 @@ def get_spike_update_expressions(self, neuron: ASTNeuron, kernel_buffers, solver

for kernel_var in kernel.get_variables():
for var_order in range(ASTUtils.get_kernel_var_order_from_ode_toolbox_result(kernel_var.get_name(), solver_dicts)):
kernel_spike_buf_name = ASTUtils.construct_kernel_X_spike_buf_name(
kernel_var.get_name(), spike_input_port, var_order)
expr = ASTUtils.get_initial_value_from_ode_toolbox_result(
kernel_spike_buf_name, solver_dicts)
kernel_spike_buf_name = ASTUtils.construct_kernel_X_spike_buf_name(kernel_var.get_name(), spike_input_port, var_order)
expr = ASTUtils.get_initial_value_from_ode_toolbox_result(kernel_spike_buf_name, solver_dicts)
assert expr is not None, "Initial value not found for kernel " + kernel_var
expr = str(expr)
if expr in ["0", "0.", "0.0"]:
Expand Down
22 changes: 17 additions & 5 deletions pynestml/utils/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from pynestml.codegeneration.printers.ast_printer import ASTPrinter
from pynestml.codegeneration.printers.cpp_variable_printer import CppVariablePrinter
from pynestml.codegeneration.printers.nestml_printer import NESTMLPrinter
from pynestml.generated.PyNestMLLexer import PyNestMLLexer
from pynestml.meta_model.ast_assignment import ASTAssignment
from pynestml.meta_model.ast_block import ASTBlock
Expand Down Expand Up @@ -994,7 +995,7 @@ def add_declarations_to_state_block(cls, neuron: ASTNeuron, variables: List, ini
return neuron

@classmethod
def add_declaration_to_state_block(cls, neuron: ASTNeuron, variable: str, initial_value: str) -> ASTNeuron:
def add_declaration_to_state_block(cls, neuron: ASTNeuron, variable: str, initial_value: str, type_str: str = "real") -> ASTNeuron:
"""
Adds a single declaration to an arbitrary state block of the neuron. The declared variable is of type real.
:param neuron: a neuron
Expand All @@ -1007,7 +1008,7 @@ def add_declaration_to_state_block(cls, neuron: ASTNeuron, variable: str, initia

tmp = ModelParser.parse_expression(initial_value)
vector_variable = ASTUtils.get_vectorized_variable(tmp, neuron.get_scope())
declaration_string = variable + ' real' + (
declaration_string = variable + " " + type_str + (
'[' + vector_variable.get_vector_parameter() + ']'
if vector_variable is not None and vector_variable.has_vector_parameter() else '') + ' = ' + initial_value
ast_declaration = ModelParser.parse_declaration(declaration_string)
Expand Down Expand Up @@ -1619,12 +1620,13 @@ def update_initial_values_for_odes(cls, neuron: ASTNeuron, solver_dicts: List[di

@classmethod
def create_initial_values_for_kernels(cls, neuron: ASTNeuron, solver_dicts: List[dict], kernels: List[ASTKernel]) -> None:
"""
r"""
Add the variables used in kernels from the ode-toolbox result dictionary as ODEs in NESTML AST
"""
for solver_dict in solver_dicts:
if solver_dict is None:
continue

for var_name in solver_dict["initial_values"].keys():
if cls.variable_in_kernels(var_name, kernels):
# original initial value expressions should have been removed to make place for ode-toolbox results
Expand All @@ -1637,9 +1639,19 @@ def create_initial_values_for_kernels(cls, neuron: ASTNeuron, solver_dicts: List
for var_name, expr in solver_dict["initial_values"].items():
# overwrite is allowed because initial values might be repeated between numeric and analytic solver
if cls.variable_in_kernels(var_name, kernels):
expr = "0" # for kernels, "initial value" returned by ode-toolbox is actually the increment value; the actual initial value is assumed to be 0
spike_in_port_name = var_name.split("__X__")[1]
spike_in_port_name = spike_in_port_name.split("__d")[0]
spike_in_port = ASTUtils.get_input_port_by_name(neuron.get_input_blocks(), spike_in_port_name)
if spike_in_port:
type_str = NESTMLPrinter().print_data_type(spike_in_port.data_type)
differential_order: int = len(re.findall("__d", var_name))
if differential_order:
type_str += "*s**-" + str(differential_order)
else:
type_str = "real"
expr = "0 " + type_str # for kernels, "initial value" returned by ode-toolbox is actually the increment value; the actual initial value is assumed to be 0
if not cls.declaration_in_state_block(neuron, var_name):
cls.add_declaration_to_state_block(neuron, var_name, expr)
cls.add_declaration_to_state_block(neuron, var_name, expr, type_str)

@classmethod
def transform_ode_and_kernels_to_json(cls, neuron: ASTNeuron, parameters_blocks: Sequence[ASTBlockWithVariables],
Expand Down
2 changes: 1 addition & 1 deletion pynestml/utils/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,7 +1050,7 @@ def templated_arg_types_inconsistent(cls, function_name, failing_arg_idx, other_
"""
message = 'In function \'' + function_name + '\': actual derived type of templated parameter ' + \
str(failing_arg_idx + 1) + ' is \'' + failing_arg_type_str + '\', which is inconsistent with that of parameter(s) ' + \
', '.join([str(_ + 1) for _ in other_args_idx]) + ', which have type \'' + other_type_str + '\''
', '.join([str(_ + 1) for _ in other_args_idx]) + ', which has/have type \'' + other_type_str + '\''
return MessageCode.TEMPLATED_ARG_TYPES_INCONSISTENT, message

@classmethod
Expand Down

0 comments on commit 8b48853

Please sign in to comment.