Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PwBaseWorkChain: fix bug in validate_resources validator #683

Merged
merged 1 commit into from
May 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 8 additions & 17 deletions aiida_quantumespresso/workflows/ph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def define(cls, spec):
spec.outline(
cls.setup,
cls.validate_parameters,
cls.validate_resources,
while_(cls.should_run_process)(
cls.prepare_process,
cls.run_process,
Expand All @@ -40,7 +39,8 @@ def define(cls, spec):
)
spec.expose_outputs(PwCalculation, exclude=('retrieved_folder',))
spec.exit_code(204, 'ERROR_INVALID_INPUT_RESOURCES_UNDERSPECIFIED',
message='The `metadata.options` did not specify both `resources.num_machines` and `max_wallclock_seconds`.')
message='The `metadata.options` did not specify both `resources.num_machines` and `max_wallclock_seconds`. '
'This exit status has been deprecated as the check it corresponded to was incorrect.')
spec.exit_code(300, 'ERROR_UNRECOVERABLE_FAILURE',
message='The calculation failed with an unrecoverable error.')
# yapf: enable
Expand All @@ -66,21 +66,7 @@ def validate_parameters(self):
if self.inputs.only_initialization.value:
self.ctx.inputs.settings['ONLY_INITIALIZATION'] = True

def validate_resources(self):
"""Validate the inputs related to the resources.

The `metadata.options` should at least contain the options `resources` and `max_wallclock_seconds`, where
`resources` should define the `num_machines`.
"""
num_machines = self.ctx.inputs.metadata.options.get('resources', {}).get('num_machines', None)
max_wallclock_seconds = self.ctx.inputs.metadata.options.get('max_wallclock_seconds', None)

if num_machines is None or max_wallclock_seconds is None:
return self.exit_codes.ERROR_INVALID_INPUT_RESOURCES_UNDERSPECIFIED

self.set_max_seconds(max_wallclock_seconds)

def set_max_seconds(self, max_wallclock_seconds):
def set_max_seconds(self, max_wallclock_seconds: None):
"""Set the `max_seconds` to a fraction of `max_wallclock_seconds` option to prevent out-of-walltime problems.

:param max_wallclock_seconds: the maximum wallclock time that will be set in the scheduler settings.
Expand All @@ -95,6 +81,11 @@ def prepare_process(self):
If a `restart_calc` has been set in the context, its `remote_folder` will be used as the `parent_folder` input
for the next calculation and the `restart_mode` is set to `restart`.
"""
max_wallclock_seconds = self.ctx.inputs.metadata.options.get('max_wallclock_seconds', None)

if max_wallclock_seconds is not None and 'max_seconds' not in self.ctx.inputs.parameters['INPUTPH']:
self.set_max_seconds(max_wallclock_seconds)

if self.ctx.restart_calc:
self.ctx.inputs.parameters['INPUTPH']['recover'] = True
self.ctx.inputs.parent_folder = self.ctx.restart_calc.outputs.remote_folder
Expand Down
32 changes: 9 additions & 23 deletions aiida_quantumespresso/workflows/pw/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def define(cls, spec):
cls.validate_parameters,
cls.validate_kpoints,
cls.validate_pseudos,
cls.validate_resources,
if_(cls.should_run_init)(
cls.validate_init_inputs,
cls.run_init,
Expand All @@ -99,9 +98,11 @@ def define(cls, spec):
spec.exit_code(202, 'ERROR_INVALID_INPUT_KPOINTS',
message='Neither the `kpoints` nor the `kpoints_distance` input was specified.')
spec.exit_code(203, 'ERROR_INVALID_INPUT_RESOURCES',
message='Neither the `options` nor `automatic_parallelization` input was specified.')
message='Neither the `options` nor `automatic_parallelization` input was specified. '
'This exit status has been deprecated as the check it corresponded to was incorrect.')
spec.exit_code(204, 'ERROR_INVALID_INPUT_RESOURCES_UNDERSPECIFIED',
message='The `metadata.options` did not specify both `resources.num_machines` and `max_wallclock_seconds`.')
message='The `metadata.options` did not specify both `resources.num_machines` and `max_wallclock_seconds`. '
'This exit status has been deprecated as the check it corresponded to was incorrect.')
spec.exit_code(210, 'ERROR_INVALID_INPUT_AUTOMATIC_PARALLELIZATION_MISSING_KEY',
message='Required key for `automatic_parallelization` was not specified.')
spec.exit_code(211, 'ERROR_INVALID_INPUT_AUTOMATIC_PARALLELIZATION_UNRECOGNIZED_KEY',
Expand Down Expand Up @@ -282,26 +283,6 @@ def validate_pseudos(self):
self.report(f'{exception}')
return self.exit_codes.ERROR_INVALID_INPUT_PSEUDO_POTENTIALS

def validate_resources(self):
"""Validate the inputs related to the resources.

One can omit the normally required `options.resources` input for the `PwCalculation`, as long as the input
`automatic_parallelization` is specified. If this is not the case, the `metadata.options` should at least
contain the options `resources` and `max_wallclock_seconds`, where `resources` should define the `num_machines`.
"""
if 'automatic_parallelization' not in self.inputs and 'options' not in self.ctx.inputs.metadata:
return self.exit_codes.ERROR_INVALID_INPUT_RESOURCES
mbercx marked this conversation as resolved.
Show resolved Hide resolved

# If automatic parallelization is not enabled, we better make sure that the options satisfy minimum requirements
if 'automatic_parallelization' not in self.inputs:
mbercx marked this conversation as resolved.
Show resolved Hide resolved
num_machines = self.ctx.inputs.metadata.options.get('resources', {}).get('num_machines', None)
max_wallclock_seconds = self.ctx.inputs.metadata.options.get('max_wallclock_seconds', None)

if num_machines is None or max_wallclock_seconds is None:
return self.exit_codes.ERROR_INVALID_INPUT_RESOURCES_UNDERSPECIFIED

self.set_max_seconds(max_wallclock_seconds)

def set_max_seconds(self, max_wallclock_seconds):
"""Set the `max_seconds` to a fraction of `max_wallclock_seconds` option to prevent out-of-walltime problems.

Expand Down Expand Up @@ -419,6 +400,11 @@ def prepare_process(self):
for the next calculation and the `restart_mode` is set to `restart`. Otherwise, no `parent_folder` is used and
`restart_mode` is set to `from_scratch`.
"""
max_wallclock_seconds = self.ctx.inputs.metadata.options.get('max_wallclock_seconds', None)

if max_wallclock_seconds is not None and 'max_seconds' not in self.ctx.inputs.parameters['CONTROL']:
self.set_max_seconds(max_wallclock_seconds)

if self.ctx.restart_calc:
self.ctx.inputs.parameters['CONTROL']['restart_mode'] = 'restart'
self.ctx.inputs.parent_folder = self.ctx.restart_calc.outputs.remote_folder
Expand Down
36 changes: 35 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,16 +591,20 @@ def _generate_inputs_cp(autopilot=False):
def generate_workchain_pw(generate_workchain, generate_inputs_pw, generate_calc_job_node):
"""Generate an instance of a `PwBaseWorkChain`."""

def _generate_workchain_pw(exit_code=None, inputs=None):
def _generate_workchain_pw(exit_code=None, inputs=None, return_inputs=False):
from plumpy import ProcessState
from aiida.orm import Dict

entry_point = 'quantumespresso.pw.base'

if inputs is None:
pw_inputs = generate_inputs_pw()
kpoints = pw_inputs.pop('kpoints')
inputs = {'pw': pw_inputs, 'kpoints': kpoints}

if return_inputs:
return inputs

process = generate_workchain(entry_point, inputs)

if exit_code is not None:
Expand All @@ -616,6 +620,36 @@ def _generate_workchain_pw(exit_code=None, inputs=None):
return _generate_workchain_pw


@pytest.fixture
def generate_workchain_ph(generate_workchain, generate_inputs_ph, generate_calc_job_node):
"""Generate an instance of a `PhBaseWorkChain`."""

def _generate_workchain_ph(exit_code=None, inputs=None, return_inputs=False):
from plumpy import ProcessState

entry_point = 'quantumespresso.ph.base'

if inputs is None:
inputs = {'ph': generate_inputs_ph()}

if return_inputs:
return inputs

process = generate_workchain(entry_point, inputs)

if exit_code is not None:
node = generate_calc_job_node()
node.set_process_state(ProcessState.FINISHED)
node.set_exit_status(exit_code.status)

process.ctx.iteration = 1
process.ctx.children = [node]

return process

return _generate_workchain_ph


@pytest.fixture
def generate_workchain_pdos(generate_workchain, generate_inputs_pw, fixture_code):
"""Generate an instance of a `PdosWorkChain`."""
Expand Down
54 changes: 29 additions & 25 deletions tests/workflows/ph/test_base.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,13 @@
# -*- coding: utf-8 -*-
# pylint: disable=no-member,redefined-outer-name
"""Tests for the `PhBaseWorkChain` class."""
import pytest

from plumpy import ProcessState

from aiida.common import AttributeDict
from aiida.engine import ProcessHandlerReport

from aiida_quantumespresso.calculations.ph import PhCalculation
from aiida_quantumespresso.workflows.ph.base import PhBaseWorkChain


@pytest.fixture
def generate_workchain_ph(generate_workchain, generate_inputs_ph, generate_calc_job_node):
"""Generate an instance of a `PhBaseWorkChain`."""

def _generate_workchain_ph(exit_code=None):
entry_point = 'quantumespresso.ph.base'
process = generate_workchain(entry_point, {'ph': generate_inputs_ph()})

if exit_code is not None:
node = generate_calc_job_node()
node.set_process_state(ProcessState.FINISHED)
node.set_exit_status(exit_code.status)

process.ctx.iteration = 1
process.ctx.children = [node]

return process

return _generate_workchain_ph


def test_setup(generate_workchain_ph):
"""Test `PhBaseWorkChain.setup`."""
process = generate_workchain_ph()
Expand Down Expand Up @@ -84,3 +59,32 @@ def test_handle_convergence_not_achieved(generate_workchain_ph):

result = process.inspect_process()
assert result.status == 0


def test_set_max_seconds(generate_workchain_ph):
"""Test that `max_seconds` gets set in the parameters based on `max_wallclock_seconds` unless already set."""
inputs = generate_workchain_ph(return_inputs=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not generate the inputs with the generate_inputs_ph fixture instead of adding the return_inputs input to the generate_workchain_ph fixture and then calling it twice, once to get the inputs and once to actually get the work chain?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem described in my reply to the comment on the pw test is less pronounced for ph, but I want to keep the two consistent, and the input definition for the `PhBaseWorkChain` may become more complicated in the future.

max_wallclock_seconds = inputs['ph']['metadata']['options']['max_wallclock_seconds']

process = generate_workchain_ph(inputs=inputs)
process.setup()
process.validate_parameters()
process.prepare_process()

expected_max_seconds = max_wallclock_seconds * process.defaults.delta_factor_max_seconds
assert 'max_seconds' in process.ctx.inputs['parameters']['INPUTPH']
assert process.ctx.inputs['parameters']['INPUTPH']['max_seconds'] == expected_max_seconds

# Now check that if `max_seconds` is already explicitly set in the parameters, it is not overwritten.
inputs = generate_workchain_ph(return_inputs=True)
max_seconds = 1
max_wallclock_seconds = inputs['ph']['metadata']['options']['max_wallclock_seconds']
inputs['ph']['parameters']['INPUTPH']['max_seconds'] = max_seconds

process = generate_workchain_ph(inputs=inputs)
process.setup()
process.validate_parameters()
process.prepare_process()

assert 'max_seconds' in process.ctx.inputs['parameters']['INPUTPH']
assert process.ctx.inputs['parameters']['INPUTPH']['max_seconds'] == max_seconds
29 changes: 29 additions & 0 deletions tests/workflows/pw/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,32 @@ def test_sanity_check_no_bands(generate_workchain_pw):

calculation = process.ctx.children[-1]
assert process.sanity_check_insufficient_bands(calculation) is None


def test_set_max_seconds(generate_workchain_pw):
"""Test that `max_seconds` gets set in the parameters based on `max_wallclock_seconds` unless already set."""
inputs = generate_workchain_pw(return_inputs=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as for the ph.x test.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because `generate_inputs_pw` does not return all the minimally required inputs needed for the work chain. If you look at `generate_workchain_pw`, it needs to add the inputs to a namespace and pop an input. I don't want to have to add these lines in every test.

max_wallclock_seconds = inputs['pw']['metadata']['options']['max_wallclock_seconds']

process = generate_workchain_pw(inputs=inputs)
process.setup()
process.validate_parameters()
process.prepare_process()

expected_max_seconds = max_wallclock_seconds * process.defaults.delta_factor_max_seconds
assert 'max_seconds' in process.ctx.inputs['parameters']['CONTROL']
assert process.ctx.inputs['parameters']['CONTROL']['max_seconds'] == expected_max_seconds

# Now check that if `max_seconds` is already explicitly set in the parameters, it is not overwritten.
inputs = generate_workchain_pw(return_inputs=True)
max_seconds = 1
max_wallclock_seconds = inputs['pw']['metadata']['options']['max_wallclock_seconds']
inputs['pw']['parameters']['CONTROL']['max_seconds'] = max_seconds

process = generate_workchain_pw(inputs=inputs)
process.setup()
process.validate_parameters()
process.prepare_process()

assert 'max_seconds' in process.ctx.inputs['parameters']['CONTROL']
assert process.ctx.inputs['parameters']['CONTROL']['max_seconds'] == max_seconds