diff --git a/aiida_quantumespresso/workflows/ph/base.py b/aiida_quantumespresso/workflows/ph/base.py index 119b683b5..95bd54117 100644 --- a/aiida_quantumespresso/workflows/ph/base.py +++ b/aiida_quantumespresso/workflows/ph/base.py @@ -30,7 +30,6 @@ def define(cls, spec): spec.outline( cls.setup, cls.validate_parameters, - cls.validate_resources, while_(cls.should_run_process)( cls.prepare_process, cls.run_process, @@ -40,7 +39,8 @@ def define(cls, spec): ) spec.expose_outputs(PwCalculation, exclude=('retrieved_folder',)) spec.exit_code(204, 'ERROR_INVALID_INPUT_RESOURCES_UNDERSPECIFIED', - message='The `metadata.options` did not specify both `resources.num_machines` and `max_wallclock_seconds`.') + message='The `metadata.options` did not specify both `resources.num_machines` and `max_wallclock_seconds`. ' + 'This exit status has been deprecated as the check it corresponded to was incorrect.') spec.exit_code(300, 'ERROR_UNRECOVERABLE_FAILURE', message='The calculation failed with an unrecoverable error.') # yapf: enable @@ -66,21 +66,7 @@ def validate_parameters(self): if self.inputs.only_initialization.value: self.ctx.inputs.settings['ONLY_INITIALIZATION'] = True - def validate_resources(self): - """Validate the inputs related to the resources. - - The `metadata.options` should at least contain the options `resources` and `max_wallclock_seconds`, where - `resources` should define the `num_machines`. - """ - num_machines = self.ctx.inputs.metadata.options.get('resources', {}).get('num_machines', None) - max_wallclock_seconds = self.ctx.inputs.metadata.options.get('max_wallclock_seconds', None) - - if num_machines is None or max_wallclock_seconds is None: - return self.exit_codes.ERROR_INVALID_INPUT_RESOURCES_UNDERSPECIFIED - - self.set_max_seconds(max_wallclock_seconds) - - def set_max_seconds(self, max_wallclock_seconds): + def set_max_seconds(self, max_wallclock_seconds: None): """Set the `max_seconds` to a fraction of `max_wallclock_seconds` option to prevent out-of-walltime problems. :param max_wallclock_seconds: the maximum wallclock time that will be set in the scheduler settings. @@ -95,6 +81,11 @@ def prepare_process(self): If a `restart_calc` has been set in the context, its `remote_folder` will be used as the `parent_folder` input for the next calculation and the `restart_mode` is set to `restart`. """ + max_wallclock_seconds = self.ctx.inputs.metadata.options.get('max_wallclock_seconds', None) + + if max_wallclock_seconds is not None and 'max_seconds' not in self.ctx.inputs.parameters['INPUTPH']: + self.set_max_seconds(max_wallclock_seconds) + if self.ctx.restart_calc: self.ctx.inputs.parameters['INPUTPH']['recover'] = True self.ctx.inputs.parent_folder = self.ctx.restart_calc.outputs.remote_folder diff --git a/aiida_quantumespresso/workflows/pw/base.py b/aiida_quantumespresso/workflows/pw/base.py index 46d113f46..ad974d171 100644 --- a/aiida_quantumespresso/workflows/pw/base.py +++ b/aiida_quantumespresso/workflows/pw/base.py @@ -76,7 +76,6 @@ def define(cls, spec): cls.validate_parameters, cls.validate_kpoints, cls.validate_pseudos, - cls.validate_resources, if_(cls.should_run_init)( cls.validate_init_inputs, cls.run_init, @@ -99,9 +98,11 @@ def define(cls, spec): spec.exit_code(202, 'ERROR_INVALID_INPUT_KPOINTS', message='Neither the `kpoints` nor the `kpoints_distance` input was specified.') spec.exit_code(203, 'ERROR_INVALID_INPUT_RESOURCES', - message='Neither the `options` nor `automatic_parallelization` input was specified.') + message='Neither the `options` nor `automatic_parallelization` input was specified. ' + 'This exit status has been deprecated as the check it corresponded to was incorrect.') spec.exit_code(204, 'ERROR_INVALID_INPUT_RESOURCES_UNDERSPECIFIED', - message='The `metadata.options` did not specify both `resources.num_machines` and `max_wallclock_seconds`.') + message='The `metadata.options` did not specify both `resources.num_machines` and `max_wallclock_seconds`. ' + 'This exit status has been deprecated as the check it corresponded to was incorrect.') spec.exit_code(210, 'ERROR_INVALID_INPUT_AUTOMATIC_PARALLELIZATION_MISSING_KEY', message='Required key for `automatic_parallelization` was not specified.') spec.exit_code(211, 'ERROR_INVALID_INPUT_AUTOMATIC_PARALLELIZATION_UNRECOGNIZED_KEY', @@ -282,26 +283,6 @@ def validate_pseudos(self): self.report(f'{exception}') return self.exit_codes.ERROR_INVALID_INPUT_PSEUDO_POTENTIALS - def validate_resources(self): - """Validate the inputs related to the resources. - - One can omit the normally required `options.resources` input for the `PwCalculation`, as long as the input - `automatic_parallelization` is specified. If this is not the case, the `metadata.options` should at least - contain the options `resources` and `max_wallclock_seconds`, where `resources` should define the `num_machines`. - """ - if 'automatic_parallelization' not in self.inputs and 'options' not in self.ctx.inputs.metadata: - return self.exit_codes.ERROR_INVALID_INPUT_RESOURCES - - # If automatic parallelization is not enabled, we better make sure that the options satisfy minimum requirements - if 'automatic_parallelization' not in self.inputs: - num_machines = self.ctx.inputs.metadata.options.get('resources', {}).get('num_machines', None) - max_wallclock_seconds = self.ctx.inputs.metadata.options.get('max_wallclock_seconds', None) - - if num_machines is None or max_wallclock_seconds is None: - return self.exit_codes.ERROR_INVALID_INPUT_RESOURCES_UNDERSPECIFIED - - self.set_max_seconds(max_wallclock_seconds) - def set_max_seconds(self, max_wallclock_seconds): """Set the `max_seconds` to a fraction of `max_wallclock_seconds` option to prevent out-of-walltime problems. @@ -419,6 +400,11 @@ def prepare_process(self): for the next calculation and the `restart_mode` is set to `restart`. Otherwise, no `parent_folder` is used and `restart_mode` is set to `from_scratch`. """ + max_wallclock_seconds = self.ctx.inputs.metadata.options.get('max_wallclock_seconds', None) + + if max_wallclock_seconds is not None and 'max_seconds' not in self.ctx.inputs.parameters['CONTROL']: + self.set_max_seconds(max_wallclock_seconds) + if self.ctx.restart_calc: self.ctx.inputs.parameters['CONTROL']['restart_mode'] = 'restart' self.ctx.inputs.parent_folder = self.ctx.restart_calc.outputs.remote_folder diff --git a/tests/conftest.py b/tests/conftest.py index 6b860329d..042cbef4e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -591,16 +591,20 @@ def _generate_inputs_cp(autopilot=False): def generate_workchain_pw(generate_workchain, generate_inputs_pw, generate_calc_job_node): """Generate an instance of a `PwBaseWorkChain`.""" - def _generate_workchain_pw(exit_code=None, inputs=None): + def _generate_workchain_pw(exit_code=None, inputs=None, return_inputs=False): from plumpy import ProcessState from aiida.orm import Dict entry_point = 'quantumespresso.pw.base' + if inputs is None: pw_inputs = generate_inputs_pw() kpoints = pw_inputs.pop('kpoints') inputs = {'pw': pw_inputs, 'kpoints': kpoints} + if return_inputs: + return inputs + process = generate_workchain(entry_point, inputs) if exit_code is not None: @@ -616,6 +620,36 @@ def _generate_workchain_pw(exit_code=None, inputs=None): return _generate_workchain_pw +@pytest.fixture +def generate_workchain_ph(generate_workchain, generate_inputs_ph, generate_calc_job_node): + """Generate an instance of a `PhBaseWorkChain`.""" + + def _generate_workchain_ph(exit_code=None, inputs=None, return_inputs=False): + from plumpy import ProcessState + + entry_point = 'quantumespresso.ph.base' + + if inputs is None: + inputs = {'ph': generate_inputs_ph()} + + if return_inputs: + return inputs + + process = generate_workchain(entry_point, inputs) + + if exit_code is not None: + node = generate_calc_job_node() + node.set_process_state(ProcessState.FINISHED) + node.set_exit_status(exit_code.status) + + process.ctx.iteration = 1 + process.ctx.children = [node] + + return process + + return _generate_workchain_ph + + @pytest.fixture def generate_workchain_pdos(generate_workchain, generate_inputs_pw, fixture_code): """Generate an instance of a `PdosWorkChain`.""" diff --git a/tests/workflows/ph/test_base.py b/tests/workflows/ph/test_base.py index 2372d78e3..a4b76a351 100644 --- a/tests/workflows/ph/test_base.py +++ b/tests/workflows/ph/test_base.py @@ -1,10 +1,6 @@ # -*- coding: utf-8 -*- # pylint: disable=no-member,redefined-outer-name """Tests for the `PhBaseWorkChain` class.""" -import pytest - -from plumpy import ProcessState - from aiida.common import AttributeDict from aiida.engine import ProcessHandlerReport @@ -12,27 +8,6 @@ from aiida_quantumespresso.workflows.ph.base import PhBaseWorkChain -@pytest.fixture -def generate_workchain_ph(generate_workchain, generate_inputs_ph, generate_calc_job_node): - """Generate an instance of a `PhBaseWorkChain`.""" - - def _generate_workchain_ph(exit_code=None): - entry_point = 'quantumespresso.ph.base' - process = generate_workchain(entry_point, {'ph': generate_inputs_ph()}) - - if exit_code is not None: - node = generate_calc_job_node() - node.set_process_state(ProcessState.FINISHED) - node.set_exit_status(exit_code.status) - - process.ctx.iteration = 1 - process.ctx.children = [node] - - return process - - return _generate_workchain_ph - - def test_setup(generate_workchain_ph): """Test `PhBaseWorkChain.setup`.""" process = generate_workchain_ph() @@ -84,3 +59,32 @@ def test_handle_convergence_not_achieved(generate_workchain_ph): result = process.inspect_process() assert result.status == 0 + + +def test_set_max_seconds(generate_workchain_ph): + """Test that `max_seconds` gets set in the parameters based on `max_wallclock_seconds` unless already set.""" + inputs = generate_workchain_ph(return_inputs=True) + max_wallclock_seconds = inputs['ph']['metadata']['options']['max_wallclock_seconds'] + + process = generate_workchain_ph(inputs=inputs) + process.setup() + process.validate_parameters() + process.prepare_process() + + expected_max_seconds = max_wallclock_seconds * process.defaults.delta_factor_max_seconds + assert 'max_seconds' in process.ctx.inputs['parameters']['INPUTPH'] + assert process.ctx.inputs['parameters']['INPUTPH']['max_seconds'] == expected_max_seconds + + # Now check that if `max_seconds` is already explicitly set in the parameters, it is not overwritten. + inputs = generate_workchain_ph(return_inputs=True) + max_seconds = 1 + max_wallclock_seconds = inputs['ph']['metadata']['options']['max_wallclock_seconds'] + inputs['ph']['parameters']['INPUTPH']['max_seconds'] = max_seconds + + process = generate_workchain_ph(inputs=inputs) + process.setup() + process.validate_parameters() + process.prepare_process() + + assert 'max_seconds' in process.ctx.inputs['parameters']['INPUTPH'] + assert process.ctx.inputs['parameters']['INPUTPH']['max_seconds'] == max_seconds diff --git a/tests/workflows/pw/test_base.py b/tests/workflows/pw/test_base.py index b50dc39f3..8cd89a52c 100644 --- a/tests/workflows/pw/test_base.py +++ b/tests/workflows/pw/test_base.py @@ -137,3 +137,32 @@ def test_sanity_check_no_bands(generate_workchain_pw): calculation = process.ctx.children[-1] assert process.sanity_check_insufficient_bands(calculation) is None + + +def test_set_max_seconds(generate_workchain_pw): + """Test that `max_seconds` gets set in the parameters based on `max_wallclock_seconds` unless already set.""" + inputs = generate_workchain_pw(return_inputs=True) + max_wallclock_seconds = inputs['pw']['metadata']['options']['max_wallclock_seconds'] + + process = generate_workchain_pw(inputs=inputs) + process.setup() + process.validate_parameters() + process.prepare_process() + + expected_max_seconds = max_wallclock_seconds * process.defaults.delta_factor_max_seconds + assert 'max_seconds' in process.ctx.inputs['parameters']['CONTROL'] + assert process.ctx.inputs['parameters']['CONTROL']['max_seconds'] == expected_max_seconds + + # Now check that if `max_seconds` is already explicitly set in the parameters, it is not overwritten. + inputs = generate_workchain_pw(return_inputs=True) + max_seconds = 1 + max_wallclock_seconds = inputs['pw']['metadata']['options']['max_wallclock_seconds'] + inputs['pw']['parameters']['CONTROL']['max_seconds'] = max_seconds + + process = generate_workchain_pw(inputs=inputs) + process.setup() + process.validate_parameters() + process.prepare_process() + + assert 'max_seconds' in process.ctx.inputs['parameters']['CONTROL'] + assert process.ctx.inputs['parameters']['CONTROL']['max_seconds'] == max_seconds