diff --git a/pynestml/codegeneration/nest_code_generator.py b/pynestml/codegeneration/nest_code_generator.py index c4e492ee4..141c0eeda 100644 --- a/pynestml/codegeneration/nest_code_generator.py +++ b/pynestml/codegeneration/nest_code_generator.py @@ -846,10 +846,8 @@ def get_spike_update_expressions(self, neuron: ASTNeuron, kernel_buffers, solver for kernel_var in kernel.get_variables(): for var_order in range(ASTUtils.get_kernel_var_order_from_ode_toolbox_result(kernel_var.get_name(), solver_dicts)): - kernel_spike_buf_name = ASTUtils.construct_kernel_X_spike_buf_name( - kernel_var.get_name(), spike_input_port, var_order) - expr = ASTUtils.get_initial_value_from_ode_toolbox_result( - kernel_spike_buf_name, solver_dicts) + kernel_spike_buf_name = ASTUtils.construct_kernel_X_spike_buf_name(kernel_var.get_name(), spike_input_port, var_order) + expr = ASTUtils.get_initial_value_from_ode_toolbox_result(kernel_spike_buf_name, solver_dicts) assert expr is not None, "Initial value not found for kernel " + kernel_var expr = str(expr) if expr in ["0", "0.", "0.0"]: diff --git a/pynestml/utils/ast_utils.py b/pynestml/utils/ast_utils.py index 92aee031a..6fba498a7 100644 --- a/pynestml/utils/ast_utils.py +++ b/pynestml/utils/ast_utils.py @@ -28,6 +28,7 @@ from pynestml.codegeneration.printers.ast_printer import ASTPrinter from pynestml.codegeneration.printers.cpp_variable_printer import CppVariablePrinter +from pynestml.codegeneration.printers.nestml_printer import NESTMLPrinter from pynestml.generated.PyNestMLLexer import PyNestMLLexer from pynestml.meta_model.ast_assignment import ASTAssignment from pynestml.meta_model.ast_block import ASTBlock @@ -994,7 +995,7 @@ def add_declarations_to_state_block(cls, neuron: ASTNeuron, variables: List, ini return neuron @classmethod - def add_declaration_to_state_block(cls, neuron: ASTNeuron, variable: str, initial_value: str) -> ASTNeuron: + def add_declaration_to_state_block(cls, neuron: ASTNeuron, variable: str, initial_value: str, type_str: str = "real") -> ASTNeuron: """ Adds a single declaration to an arbitrary state block of the neuron. The declared variable is of type real. :param neuron: a neuron @@ -1007,7 +1008,7 @@ def add_declaration_to_state_block(cls, neuron: ASTNeuron, variable: str, initia tmp = ModelParser.parse_expression(initial_value) vector_variable = ASTUtils.get_vectorized_variable(tmp, neuron.get_scope()) - declaration_string = variable + ' real' + ( + declaration_string = variable + " " + type_str + ( '[' + vector_variable.get_vector_parameter() + ']' if vector_variable is not None and vector_variable.has_vector_parameter() else '') + ' = ' + initial_value ast_declaration = ModelParser.parse_declaration(declaration_string) @@ -1619,12 +1620,13 @@ def update_initial_values_for_odes(cls, neuron: ASTNeuron, solver_dicts: List[di @classmethod def create_initial_values_for_kernels(cls, neuron: ASTNeuron, solver_dicts: List[dict], kernels: List[ASTKernel]) -> None: - """ + r""" Add the variables used in kernels from the ode-toolbox result dictionary as ODEs in NESTML AST """ for solver_dict in solver_dicts: if solver_dict is None: continue + for var_name in solver_dict["initial_values"].keys(): if cls.variable_in_kernels(var_name, kernels): # original initial value expressions should have been removed to make place for ode-toolbox results @@ -1637,9 +1639,19 @@ def create_initial_values_for_kernels(cls, neuron: ASTNeuron, solver_dicts: List for var_name, expr in solver_dict["initial_values"].items(): # overwrite is allowed because initial values might be repeated between numeric and analytic solver if cls.variable_in_kernels(var_name, kernels): - expr = "0" # for kernels, "initial value" returned by ode-toolbox is actually the increment value; the actual initial value is assumed to be 0 + spike_in_port_name = var_name.split("__X__")[1] + spike_in_port_name = spike_in_port_name.split("__d")[0] + spike_in_port = ASTUtils.get_input_port_by_name(neuron.get_input_blocks(), spike_in_port_name) + if spike_in_port: + type_str = NESTMLPrinter().print_data_type(spike_in_port.data_type) + differential_order: int = len(re.findall("__d", var_name)) + if differential_order: + type_str += "*s**-" + str(differential_order) + else: + type_str = "real" + expr = "0 " + type_str # for kernels, "initial value" returned by ode-toolbox is actually the increment value; the actual initial value is assumed to be 0 if not cls.declaration_in_state_block(neuron, var_name): - cls.add_declaration_to_state_block(neuron, var_name, expr) + cls.add_declaration_to_state_block(neuron, var_name, expr, type_str) @classmethod def transform_ode_and_kernels_to_json(cls, neuron: ASTNeuron, parameters_blocks: Sequence[ASTBlockWithVariables], diff --git a/pynestml/utils/messages.py b/pynestml/utils/messages.py index 0529911d6..693fcaee7 100644 --- a/pynestml/utils/messages.py +++ b/pynestml/utils/messages.py @@ -1050,7 +1050,7 @@ def templated_arg_types_inconsistent(cls, function_name, failing_arg_idx, other_ """ message = 'In function \'' + function_name + '\': actual derived type of templated parameter ' + \ str(failing_arg_idx + 1) + ' is \'' + failing_arg_type_str + '\', which is inconsistent with that of parameter(s) ' + \ - ', '.join([str(_ + 1) for _ in other_args_idx]) + ', which have type \'' + other_type_str + '\'' + ', '.join([str(_ + 1) for _ in other_args_idx]) + ', which has/have type \'' + other_type_str + '\'' return MessageCode.TEMPLATED_ARG_TYPES_INCONSISTENT, message @classmethod