diff --git a/pynestml/visitors/ast_comparison_operator_visitor.py b/pynestml/visitors/ast_comparison_operator_visitor.py index 02ea6a49e..98e8030c4 100644 --- a/pynestml/visitors/ast_comparison_operator_visitor.py +++ b/pynestml/visitors/ast_comparison_operator_visitor.py @@ -19,17 +19,15 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -""" -rhs : left=rhs comparisonOperator right=rhs -""" +from pynestml.symbols.boolean_type_symbol import BooleanTypeSymbol +from pynestml.symbols.error_type_symbol import ErrorTypeSymbol from pynestml.symbols.predefined_types import PredefinedTypes +from pynestml.symbols.string_type_symbol import StringTypeSymbol from pynestml.symbols.unit_type_symbol import UnitTypeSymbol from pynestml.utils.error_strings import ErrorStrings from pynestml.utils.logger import Logger, LoggingLevel from pynestml.utils.messages import MessageCode from pynestml.visitors.ast_visitor import ASTVisitor -from pynestml.symbols.boolean_type_symbol import BooleanTypeSymbol -from pynestml.symbols.error_type_symbol import ErrorTypeSymbol class ASTComparisonOperatorVisitor(ASTVisitor): @@ -49,6 +47,11 @@ def visit_expression(self, expr): lhs_type.referenced_object = expr.get_lhs() rhs_type.referenced_object = expr.get_rhs() + # both are string types + if lhs_type.is_primitive() and rhs_type.is_primitive() and isinstance(lhs_type, StringTypeSymbol) and isinstance(rhs_type, StringTypeSymbol): + expr.type = PredefinedTypes.get_boolean_type() + return + if (lhs_type.is_numeric_primitive() and rhs_type.is_numeric_primitive()) \ or (lhs_type.equals(rhs_type) and lhs_type.is_numeric()) or ( isinstance(lhs_type, BooleanTypeSymbol) and isinstance(rhs_type, BooleanTypeSymbol)): @@ -65,11 +68,10 @@ def visit_expression(self, expr): error_position=expr.get_source_position(), log_level=LoggingLevel.WARNING) return - else: - # hard incompatibility, cannot recover in c++, ERROR - error_msg = ErrorStrings.message_comparison(self, expr.get_source_position()) - expr.type = ErrorTypeSymbol() - Logger.log_message(code=MessageCode.HARD_INCOMPATIBILITY, - error_position=expr.get_source_position(), - message=error_msg, log_level=LoggingLevel.ERROR) - return + + # hard incompatibility, cannot recover in c++, ERROR + error_msg = ErrorStrings.message_comparison(self, expr.get_source_position()) + expr.type = ErrorTypeSymbol() + Logger.log_message(code=MessageCode.HARD_INCOMPATIBILITY, + error_position=expr.get_source_position(), + message=error_msg, log_level=LoggingLevel.ERROR) diff --git a/tests/nest_tests/resources/StringHandlingTest.nestml b/tests/nest_tests/resources/StringHandlingTest.nestml new file mode 100644 index 000000000..33d3feff5 --- /dev/null +++ b/tests/nest_tests/resources/StringHandlingTest.nestml @@ -0,0 +1,46 @@ +""" +StringHandlingTest.nestml +######################### + + +Copyright statement ++++++++++++++++++++ + +This file is part of NEST. + +Copyright (C) 2004 The NEST Initiative + +NEST is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 2 of the License, or +(at your option) any later version. + +NEST is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with NEST. If not, see . +""" +neuron string_handling_test: + state: + s1 string = "abc" + s2 string = "def" + b1 boolean = false + b2 boolean = false + + parameters: + s3 string = "ghi" + s4 string = "klm" + + internals: + s5 string = "ghi" + + update: + s7 string = s1 + s2 + if s7 == "abcdef": + b1 = true + + if s3 + s4 == s5 + "klm": + b2 = true diff --git a/tests/nest_tests/test_string_handling.py b/tests/nest_tests/test_string_handling.py new file mode 100644 index 000000000..cafa5802c --- /dev/null +++ b/tests/nest_tests/test_string_handling.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# +# test_string_handling.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +import nest +import numpy as np +import scipy as sp +import os + +from pynestml.frontend.pynestml_frontend import generate_nest_target +from pynestml.codegeneration.nest_tools import NESTTools + + +class TestStringHandling: + """Test string handling""" + + def test_string_handling(self): + input_path = os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), "resources", "StringHandlingTest.nestml"))) + target_path = "target" + logging_level = "INFO" + module_name = "nestmlmodule" + suffix = "_nestml" + + nest_version = NESTTools.detect_nest_version() + + nest.set_verbosity("M_ALL") + generate_nest_target(input_path, + target_path=target_path, + logging_level=logging_level, + module_name=module_name, + suffix=suffix) + nest.ResetKernel() + nest.Install("nestmlmodule") + + nrn = nest.Create("string_handling_test_nestml") + + nest.Simulate(100.) + + assert nrn.b1 + assert nrn.b2