diff --git a/aiida/engine/processes/workchains/restart.py b/aiida/engine/processes/workchains/restart.py index 481f056755..2be440d530 100644 --- a/aiida/engine/processes/workchains/restart.py +++ b/aiida/engine/processes/workchains/restart.py @@ -240,7 +240,7 @@ def inspect_process(self): # pylint: disable=inconsistent-return-statements,too # Here either the process finished successful or at least one handler returned a report so it can no longer be # considered to be an unhandled failed process and therefore we reset the flag - self.ctx.unhandled_failure = True + self.ctx.unhandled_failure = False # If at least one handler returned a report, the action depends on its exit code and that of the process itself if last_report: diff --git a/tests/conftest.py b/tests/conftest.py index 84afd04efd..4f0b661230 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -91,3 +91,56 @@ def _generate_calc_job(folder, entry_point_name, inputs=None): return calc_info return _generate_calc_job + + +@pytest.fixture +def generate_work_chain(): + """Generate an instance of a `WorkChain`.""" + + def _generate_work_chain(entry_point, inputs=None): + """Generate an instance of a `WorkChain` with the given entry point and inputs. + + :param entry_point: entry point name of the work chain subclass. + :param inputs: inputs to be passed to process construction. + :return: a `WorkChain` instance. + """ + from aiida.engine.utils import instantiate_process + from aiida.manage.manager import get_manager + from aiida.plugins import WorkflowFactory + + inputs = inputs or {} + process_class = WorkflowFactory(entry_point) if isinstance(entry_point, str) else entry_point + runner = get_manager().get_runner() + process = instantiate_process(runner, process_class, **inputs) + + return process + + return _generate_work_chain + + +@pytest.fixture +def generate_calculation_node(): + """Generate an instance of a `CalculationNode`.""" + from aiida.engine import ProcessState + + def _generate_calculation_node(process_state=ProcessState.FINISHED, exit_status=None): + """Generate an instance of a `CalculationNode`.. + + :param process_state: state to set + :param exit_status: optional exit status, will be set to `0` if `process_state` is `ProcessState.FINISHED` + :return: a `CalculationNode` instance. + """ + from aiida.orm import CalculationNode + + if process_state is ProcessState.FINISHED and exit_status is None: + exit_status = 0 + + node = CalculationNode() + node.set_process_state(process_state) + + if exit_status is not None: + node.set_exit_status(exit_status) + + return node + + return _generate_calculation_node diff --git a/tests/engine/processes/workchains/test_restart.py b/tests/engine/processes/workchains/test_restart.py index 5cbabaa73b..77de70517d 100644 --- a/tests/engine/processes/workchains/test_restart.py +++ b/tests/engine/processes/workchains/test_restart.py @@ -8,44 +8,99 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for `aiida.engine.processes.workchains.restart` module.""" -from aiida.backends.testbase import AiidaTestCase -from aiida.engine.processes.workchains.restart import BaseRestartWorkChain -from aiida.engine.processes.workchains.utils import process_handler +# pylint: disable=invalid-name +import pytest +from aiida.engine import CalcJob, BaseRestartWorkChain, process_handler, ProcessState, ProcessHandlerReport, ExitCode -class TestBaseRestartWorkChain(AiidaTestCase): - """Tests for the `BaseRestartWorkChain` class.""" - @staticmethod - def test_is_process_handler(): - """Test the `BaseRestartWorkChain.is_process_handler` class method.""" +class SomeWorkChain(BaseRestartWorkChain): + """Dummy class.""" - class SomeWorkChain(BaseRestartWorkChain): - """Dummy class.""" + _process_class = CalcJob - @process_handler() - def handler_a(self, node): - pass + @process_handler() + def handler_a(self, node): # pylint: disable=inconsistent-return-statements,no-self-use + if node.exit_status == 400: + return ProcessHandlerReport() - def not_a_handler(self, node): - pass + def not_a_handler(self, node): + pass - assert SomeWorkChain.is_process_handler('handler_a') - assert not SomeWorkChain.is_process_handler('not_a_handler') - assert not SomeWorkChain.is_process_handler('unexisting_method') - @staticmethod - def test_get_process_handler(): - """Test the `BaseRestartWorkChain.get_process_handlers` class method.""" +def test_is_process_handler(): + """Test the `BaseRestartWorkChain.is_process_handler` class method.""" + assert SomeWorkChain.is_process_handler('handler_a') + assert not SomeWorkChain.is_process_handler('not_a_handler') + assert not SomeWorkChain.is_process_handler('unexisting_method') - class SomeWorkChain(BaseRestartWorkChain): - """Dummy class.""" - @process_handler - def handler_a(self, node): - pass +def test_get_process_handler(): + """Test the `BaseRestartWorkChain.get_process_handlers` class method.""" + assert [handler.__name__ for handler in SomeWorkChain.get_process_handlers()] == ['handler_a'] - def not_a_handler(self, node): - pass - assert [handler.__name__ for handler in SomeWorkChain.get_process_handlers()] == ['handler_a'] +@pytest.mark.usefixtures('clear_database_before_test') +def test_excepted_process(generate_work_chain, generate_calculation_node): + """Test that the workchain aborts if the sub process was excepted.""" + process = generate_work_chain(SomeWorkChain, {}) + process.setup() + process.ctx.children = [generate_calculation_node(ProcessState.EXCEPTED)] + assert process.inspect_process() == BaseRestartWorkChain.exit_codes.ERROR_SUB_PROCESS_EXCEPTED # pylint: disable=no-member + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_killed_process(generate_work_chain, generate_calculation_node): + """Test that the workchain aborts if the sub process was killed.""" + process = generate_work_chain(SomeWorkChain, {}) + process.setup() + process.ctx.children = [generate_calculation_node(ProcessState.KILLED)] + assert process.inspect_process() == BaseRestartWorkChain.exit_codes.ERROR_SUB_PROCESS_KILLED # pylint: disable=no-member + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_unhandled_failure(generate_work_chain, generate_calculation_node): + """Test the unhandled failure mechanism. + + The workchain should be aborted if there are two consecutive failed sub processes that went unhandled. We simulate + it by setting `ctx.unhandled_failure` to True and append two failed process nodes in `ctx.children`. + """ + process = generate_work_chain(SomeWorkChain, {}) + process.setup() + process.ctx.children = [generate_calculation_node(exit_status=100)] + assert process.inspect_process() is None + assert process.ctx.unhandled_failure is True + + process.ctx.children.append(generate_calculation_node(exit_status=100)) + assert process.inspect_process() == BaseRestartWorkChain.exit_codes.ERROR_SECOND_CONSECUTIVE_UNHANDLED_FAILURE # pylint: disable=no-member + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_unhandled_reset_after_success(generate_work_chain, generate_calculation_node): + """Test `ctx.unhandled_failure` is reset to `False` in `inspect_process` after a successful process.""" + process = generate_work_chain(SomeWorkChain, {}) + process.setup() + process.ctx.children = [generate_calculation_node(exit_status=100)] + assert process.inspect_process() is None + assert process.ctx.unhandled_failure is True + + process.ctx.children.append(generate_calculation_node(exit_status=0)) + assert process.inspect_process() is None + assert process.ctx.unhandled_failure is False + + +@pytest.mark.usefixtures('clear_database_before_test') +def test_unhandled_reset_after_handled(generate_work_chain, generate_calculation_node): + """Test `ctx.unhandled_failure` is reset to `False` in `inspect_process` after a handled failed process.""" + process = generate_work_chain(SomeWorkChain, {}) + process.setup() + process.ctx.children = [generate_calculation_node(exit_status=0)] + assert process.inspect_process() is None + assert process.ctx.unhandled_failure is False + + # Exit status 400 of the last calculation job will be handled and so should reset the flag + process.ctx.children.append(generate_calculation_node(exit_status=400)) + result = process.inspect_process() + assert isinstance(result, ExitCode) + assert result.status == 0 + assert process.ctx.unhandled_failure is False