Skip to content

Commit

Permalink
fix python_standalone code generator (#914)
Browse files Browse the repository at this point in the history
Co-authored-by: C.A.P. Linssen <charl@turingbirds.com>
  • Loading branch information
clinssen and C.A.P. Linssen authored Jun 21, 2023
1 parent f517c2f commit 4f39c75
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 11 deletions.
67 changes: 67 additions & 0 deletions .github/workflows/nestml-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,73 @@ jobs:
run: |
python3 extras/codeanalysis/check_copyright_headers.py && python3 -m pycodestyle $GITHUB_WORKSPACE -v --ignore=E241,E501,E714,E713,E714,E252,W503 --exclude=$GITHUB_WORKSPACE/doc,$GITHUB_WORKSPACE/.git,$GITHUB_WORKSPACE/NESTML.egg-info,$GITHUB_WORKSPACE/pynestml/generated,$GITHUB_WORKSPACE/extras,$GITHUB_WORKSPACE/build,$GITHUB_WORKSPACE/.github
build_and_test_py_standalone:
needs: [static_checks]
runs-on: ubuntu-latest
steps:
# Checkout the repository contents
- name: Checkout NESTML code
uses: actions/checkout@v3

# Setup Python version
- name: Setup Python 3.8
uses: actions/setup-python@v4
with:
python-version: 3.8

# Install dependencies
- name: Install apt dependencies
run: |
sudo apt-get update
sudo apt-get install libgsl0-dev libncurses5-dev pkg-config
sudo apt-get install python3-all-dev python3-matplotlib python3-numpy python3-scipy ipython3
# Install Python dependencies
- name: Python dependencies
run: |
python -m pip install --upgrade pip pytest jupyterlab matplotlib pycodestyle scipy
python -m pip install -r requirements.txt
# Install Java
- name: Install Java 11
uses: actions/setup-java@v2
with:
distribution: 'zulu'
java-version: '11.0.x'
java-package: jre

# Install Antlr4
- name: Install Antlr4
run: |
wget http://www.antlr.org/download/antlr-4.10-complete.jar
echo \#\!/bin/bash > antlr4
echo java -cp \"`pwd`/antlr-4.10-complete.jar:$CLASSPATH\" org.antlr.v4.Tool \"\$@\" >> antlr4
echo >> antlr4
chmod +x antlr4
echo PATH=$PATH:`pwd` >> $GITHUB_ENV
# Install NESTML
- name: Install NESTML
run: |
export PYTHONPATH=${{ env.PYTHONPATH }}:${{ env.NEST_INSTALL }}/lib/python3.8/site-packages
#echo PYTHONPATH=`pwd` >> $GITHUB_ENV
echo "PYTHONPATH=$PYTHONPATH" >> $GITHUB_ENV
python setup.py install
- name: Generate Lexer and Parser using Antlr4
run: |
cd pynestml/grammars
./generate_lexer_parser
# Run integration tests
- name: Run integration tests
run: |
rc=0
for fn in $GITHUB_WORKSPACE/tests/python_standalone_tests/*.py; do
pytest -s -o log_cli=true -o log_cli_level="DEBUG" ${fn} || rc=1
done;
exit $rc
build_and_test:
needs: [static_checks]
runs-on: ubuntu-latest
Expand Down
3 changes: 0 additions & 3 deletions pynestml/codegeneration/printers/nest_variable_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def __init__(self, expression_printer: ExpressionPrinter, with_origin: bool = Tr
super().__init__(expression_printer)
self.with_origin = with_origin
self.with_vector_parameter = with_vector_parameter
self._state_symbols = []

def print_variable(self, variable: ASTVariable) -> str:
"""
Expand Down Expand Up @@ -159,8 +158,6 @@ def _print_buffer_value(self, variable: ASTVariable) -> str:
return variable_symbol.get_symbol_name() + '_grid_sum_'

def _print(self, variable: ASTVariable, symbol, with_origin: bool = True) -> str:
assert all([type(s) == str for s in self._state_symbols])

variable_name = CppVariablePrinter._print_cpp_name(variable.get_complete_name())

if symbol.is_local():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def print_variable(self, node: ASTVariable) -> str:
symbol = node.get_scope().resolve_to_symbol(node.get_complete_name(), SymbolKind.VARIABLE)

if symbol.is_state() and not symbol.is_inline_expression:
if node.get_complete_name() in self._state_symbols:
if "_is_numeric" in dir(node) and node._is_numeric:
# ode_state[] here is---and must be---the state vector supplied by the integrator, not the state vector in the node, node.S_.ode_state[].
return "ode_state[node.S_.ode_state_variable_name_to_index[\"" + CppVariablePrinter._print_cpp_name(node.get_complete_name()) + "\"]]"

Expand Down
3 changes: 0 additions & 3 deletions pynestml/codegeneration/printers/python_variable_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def __init__(self, expression_printer: ExpressionPrinter, with_origin: bool = Tr
super().__init__(expression_printer)
self.with_origin = with_origin
self.with_vector_parameter = with_vector_parameter
self._state_symbols = []

@classmethod
def _print_python_name(cls, variable_name: str) -> str:
Expand Down Expand Up @@ -146,8 +145,6 @@ def _print_vector_parameter_name_reference(self, variable: ASTVariable) -> str:
return self._expression_printer.print(vector_parameter)

def _print(self, variable, symbol, with_origin: bool = True) -> str:
assert all([type(s) == str for s in self._state_symbols])

variable_name = PythonVariablePrinter._print_python_name(variable.get_complete_name())

if symbol.is_local():
Expand Down
5 changes: 3 additions & 2 deletions pynestml/codegeneration/python_code_generator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,22 @@
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

from pynestml.meta_model.ast_variable import ASTVariable
from pynestml.symbols.variable_symbol import VariableSymbol
from pynestml.symbols.variable_symbol import BlockType


class PythonCodeGeneratorUtils:

@classmethod
def print_symbol_origin(cls, variable_symbol: VariableSymbol) -> str:
def print_symbol_origin(cls, variable_symbol: VariableSymbol, variable: ASTVariable) -> str:
"""
Returns a prefix corresponding to the origin of the variable symbol.
:param variable_symbol: a single variable symbol.
:return: the corresponding prefix
"""
if variable_symbol.block_type in [BlockType.STATE, BlockType.EQUATION]:
if numerical_state_symbols and variable_symbol.get_symbol_name() in numerical_state_symbols:
if "_is_numeric" in dir(variable) and variable._is_numeric:
return 'self.S_.ode_state[self.S_.ode_state_variable_name_to_index["%s"]]'

return 'self.S_.%s'
Expand Down
3 changes: 2 additions & 1 deletion pynestml/codegeneration/python_standalone_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ class PythonStandaloneCodeGenerator(NESTCodeGenerator):
"neuron": ["@NEURON_NAME@.py.jinja2"]
},
"module_templates": ["simulator.py.jinja2", "test_python_standalone_module.py.jinja2", "neuron.py.jinja2", "spike_generator.py.jinja2", "utils.py.jinja2"]
}
},
"solver": "analytic"
}

def __init__(self, options: Optional[Mapping[str, Any]] = None):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,4 @@ def test_python_standalone_neuron_build_and_sim_analytic(self):

from nestmlmodule.test_python_standalone_module import TestSimulator
neuron_log = TestSimulator().test_simulator()
np.testing.assert_allclose(neuron_log["V_abs"][-1], 11.192718053106296)
np.testing.assert_allclose(neuron_log["V_m"][-1], -58.80728194689356)

0 comments on commit 4f39c75

Please sign in to comment.