From d807b67746a81b4243682416834834d305c3fb62 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Sat, 6 Jun 2020 09:33:23 +0200 Subject: [PATCH] Fix the unhandled failure mechanism of the `BaseRestartWorkChain` Any `BaseRestartWorkChain` will fail if the sub process fails twice in a row without being handled by a registered process handler. To monitor this the `unhandled_failure` context variable is used, which is set to `True` once a failed process was not handled. It should be unset as soon as the next process finishes successfully, or the failed process is handled. The reset was there, but was incorrectly setting the context var to `True` instead of `False`, which went unnoticed due to a lack of tests. Tests are now added, as well as for the case where the sub process excepted or was killed. --- aiida/engine/processes/workchains/restart.py | 2 +- tests/conftest.py | 53 ++++++++ .../processes/workchains/test_restart.py | 113 +++++++++++++----- 3 files changed, 138 insertions(+), 30 deletions(-) diff --git a/aiida/engine/processes/workchains/restart.py b/aiida/engine/processes/workchains/restart.py index 481f056755..2be440d530 100644 --- a/aiida/engine/processes/workchains/restart.py +++ b/aiida/engine/processes/workchains/restart.py @@ -240,7 +240,7 @@ def inspect_process(self): # pylint: disable=inconsistent-return-statements,too # Here either the process finished successful or at least one handler returned a report so it can no longer be # considered to be an unhandled failed process and therefore we reset the flag - self.ctx.unhandled_failure = True + self.ctx.unhandled_failure = False # If at least one handler returned a report, the action depends on its exit code and that of the process itself if last_report: diff --git a/tests/conftest.py b/tests/conftest.py index 84afd04efd..4f0b661230 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -91,3 +91,56 @@ def _generate_calc_job(folder, entry_point_name, inputs=None): return calc_info return _generate_calc_job + + +@pytest.fixture +def generate_work_chain(): + """Generate an instance of a `WorkChain`.""" + + def _generate_work_chain(entry_point, inputs=None): + """Generate an instance of a `WorkChain` with the given entry point and inputs. + + :param entry_point: entry point name of the work chain subclass. + :param inputs: inputs to be passed to process construction. + :return: a `WorkChain` instance. + """ + from aiida.engine.utils import instantiate_process + from aiida.manage.manager import get_manager + from aiida.plugins import WorkflowFactory + + inputs = inputs or {} + process_class = WorkflowFactory(entry_point) if isinstance(entry_point, str) else entry_point + runner = get_manager().get_runner() + process = instantiate_process(runner, process_class, **inputs) + + return process + + return _generate_work_chain + + +@pytest.fixture +def generate_calculation_node(): + """Generate an instance of a `CalculationNode`.""" + from aiida.engine import ProcessState + + def _generate_calculation_node(process_state=ProcessState.FINISHED, exit_status=None): + """Generate an instance of a `CalculationNode`.. + + :param process_state: state to set + :param exit_status: optional exit status, will be set to `0` if `process_state` is `ProcessState.FINISHED` + :return: a `CalculationNode` instance. + """ + from aiida.orm import CalculationNode + + if process_state is ProcessState.FINISHED and exit_status is None: + exit_status = 0 + + node = CalculationNode() + node.set_process_state(process_state) + + if exit_status is not None: + node.set_exit_status(exit_status) + + return node + + return _generate_calculation_node diff --git a/tests/engine/processes/workchains/test_restart.py b/tests/engine/processes/workchains/test_restart.py index 5cbabaa73b..77de70517d 100644 --- a/tests/engine/processes/workchains/test_restart.py +++ b/tests/engine/processes/workchains/test_restart.py @@ -8,44 +8,99 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for `aiida.engine.processes.workchains.restart` module.""" -from aiida.backends.testbase import AiidaTestCase -from aiida.engine.processes.workchains.restart import BaseRestartWorkChain -from aiida.engine.processes.workchains.utils import process_handler +# pylint: disable=invalid-name +import pytest +from aiida.engine import CalcJob, BaseRestartWorkChain, process_handler, ProcessState, ProcessHandlerReport, ExitCode -class TestBaseRestartWorkChain(AiidaTestCase): - """Tests for the `BaseRestartWorkChain` class.""" - @staticmethod - def test_is_process_handler(): - """Test the `BaseRestartWorkChain.is_process_handler` class method.""" +class SomeWorkChain(BaseRestartWorkChain): + """Dummy class.""" - class SomeWorkChain(BaseRestartWorkChain): - """Dummy class.""" + _process_class = CalcJob - @process_handler() - def handler_a(self, node): - pass + @process_handler() + def handler_a(self, node): # pylint: disable=inconsistent-return-statements,no-self-use + if node.exit_status == 400: + return ProcessHandlerReport() - def not_a_handler(self, node): - pass + def not_a_handler(self, node): + pass - assert SomeWorkChain.is_process_handler('handler_a') - assert not SomeWorkChain.is_process_handler('not_a_handler') - assert not SomeWorkChain.is_process_handler('unexisting_method') - @staticmethod - def test_get_process_handler(): - """Test the `BaseRestartWorkChain.get_process_handlers` class method.""" +def test_is_process_handler(): + """Test the `BaseRestartWorkChain.is_process_handler` class method.""" + assert SomeWorkChain.is_process_handler('handler_a') + assert not SomeWorkChain.is_process_handler('not_a_handler') + assert not SomeWorkChain.is_process_handler('unexisting_method') - class SomeWorkChain(BaseRestartWorkChain): - """Dummy class.""" - @process_handler - def handler_a(self, node): - pass +def test_get_process_handler(): + """Test the `BaseRestartWorkChain.get_process_handlers` class method.""" + assert [handler.__name__ for handler in SomeWorkChain.get_process_handlers()] == ['handler_a'] - def not_a_handler(self, node): - pass - assert [handler.__name__ for handler in SomeWorkChain.get_process_handlers()] == ['handler_a'] +@pytest.mark.usefixtures('clear_database_before_test') +def test_excepted_process(generate_work_chain, generate_calculation_node): + """Test that the workchain aborts if the sub process was excepted.""" + process = generate_work_chain(SomeWorkChain, {}) + process.setup() + process.ctx.children = [generate_calculation_node(ProcessState.EXCEPTED)] + assert process.inspect_process() == BaseRestartWorkChain.exit_codes.ERROR_SUB_PROCESS_EXCEPTED # pylint: disable=no-member + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_killed_process(generate_work_chain, generate_calculation_node): + """Test that the workchain aborts if the sub process was killed.""" + process = generate_work_chain(SomeWorkChain, {}) + process.setup() + process.ctx.children = [generate_calculation_node(ProcessState.KILLED)] + assert process.inspect_process() == BaseRestartWorkChain.exit_codes.ERROR_SUB_PROCESS_KILLED # pylint: disable=no-member + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_unhandled_failure(generate_work_chain, generate_calculation_node): + """Test the unhandled failure mechanism. + + The workchain should be aborted if there are two consecutive failed sub processes that went unhandled. We simulate + it by setting `ctx.unhandled_failure` to True and append two failed process nodes in `ctx.children`. + """ + process = generate_work_chain(SomeWorkChain, {}) + process.setup() + process.ctx.children = [generate_calculation_node(exit_status=100)] + assert process.inspect_process() is None + assert process.ctx.unhandled_failure is True + + process.ctx.children.append(generate_calculation_node(exit_status=100)) + assert process.inspect_process() == BaseRestartWorkChain.exit_codes.ERROR_SECOND_CONSECUTIVE_UNHANDLED_FAILURE # pylint: disable=no-member + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_unhandled_reset_after_success(generate_work_chain, generate_calculation_node): + """Test `ctx.unhandled_failure` is reset to `False` in `inspect_process` after a successful process.""" + process = generate_work_chain(SomeWorkChain, {}) + process.setup() + process.ctx.children = [generate_calculation_node(exit_status=100)] + assert process.inspect_process() is None + assert process.ctx.unhandled_failure is True + + process.ctx.children.append(generate_calculation_node(exit_status=0)) + assert process.inspect_process() is None + assert process.ctx.unhandled_failure is False + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_unhandled_reset_after_handled(generate_work_chain, generate_calculation_node): + """Test `ctx.unhandled_failure` is reset to `False` in `inspect_process` after a handled failed process.""" + process = generate_work_chain(SomeWorkChain, {}) + process.setup() + process.ctx.children = [generate_calculation_node(exit_status=0)] + assert process.inspect_process() is None + assert process.ctx.unhandled_failure is False + + # Exit status 400 of the last calculation job will be handled and so should reset the flag + process.ctx.children.append(generate_calculation_node(exit_status=400)) + result = process.inspect_process() + assert isinstance(result, ExitCode) + assert result.status == 0 + assert process.ctx.unhandled_failure is False