Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PhBaseWorkChain: add handler for scheduler out of walltime #754

Merged
merged 4 commits into from
Nov 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions aiida_quantumespresso/parsers/ph.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def parse(self, **kwargs):
self.emit_logs(logs)
self.out('output_parameters', orm.Dict(dict=parsed_data))

# If the scheduler already detected an out-of-walltime (OOW) error, keep that exit code by not returning anything more specific.
if self.node.exit_status == PhCalculation.exit_codes.ERROR_SCHEDULER_OUT_OF_WALLTIME:
return

if 'ERROR_OUT_OF_WALLTIME' in logs['error']:
return self.exit_codes.ERROR_OUT_OF_WALLTIME

Expand Down
26 changes: 24 additions & 2 deletions aiida_quantumespresso/workflows/ph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,28 @@ def handle_unrecoverable_failure(self, node):
self.report_error_handled(node, 'unrecoverable error, aborting...')
return ProcessHandlerReport(True, self.exit_codes.ERROR_UNRECOVERABLE_FAILURE)

@process_handler(priority=610, exit_codes=PhCalculation.exit_codes.ERROR_SCHEDULER_OUT_OF_WALLTIME)
def handle_scheduler_out_of_walltime(self, node):
"""Handle `ERROR_SCHEDULER_OUT_OF_WALLTIME` exit code: decrease `max_seconds` and restart from scratch."""

# Decrease `max_seconds` significantly in order to make sure that the calculation has the time to shut down
# neatly before reaching the scheduler wall time and one can restart from this calculation.
factor = 0.5
max_seconds = self.ctx.inputs.parameters.get('INPUTPH', {}).get('max_seconds', None)
if max_seconds is None:
max_seconds = self.ctx.inputs.metadata.options.get(
'max_wallclock_seconds', None
) * self.defaults.delta_factor_max_seconds
max_seconds_new = max_seconds * factor
unkcpz marked this conversation as resolved.
Show resolved Hide resolved

self.ctx.restart_calc = node
self.ctx.inputs.parameters.setdefault('INPUTPH', {})['recover'] = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not too familiar with ph.x, so I am not sure if this is the correct way to do it, but it could very well be. Just saying that I cannot really sign off on this as I have no idea. I would advise that we ask someone who is knowledgeable about ph.x to validate this strategy.

self.ctx.inputs.parameters.setdefault('INPUTPH', {})['max_seconds'] = max_seconds_new

action = f'reduced max_seconds from {max_seconds} to {max_seconds_new} and restarting'
self.report_error_handled(node, action)
return ProcessHandlerReport(True)

@process_handler(priority=580, exit_codes=PhCalculation.exit_codes.ERROR_OUT_OF_WALLTIME)
def handle_out_of_walltime(self, node):
"""Handle `ERROR_OUT_OF_WALLTIME` exit code: calculation shut down neatly and we can simply restart."""
Expand All @@ -117,8 +139,8 @@ def handle_out_of_walltime(self, node):
return ProcessHandlerReport(True)

@process_handler(priority=410, exit_codes=PhCalculation.exit_codes.ERROR_CONVERGENCE_NOT_REACHED)
def handle_convergence_not_achieved(self, node):
"""Handle `ERROR_CONVERGENCE_NOT_REACHED` exit code: decrease the mixing beta and restart from scratch."""
def handle_convergence_not_reached(self, node):
unkcpz marked this conversation as resolved.
Show resolved Hide resolved
"""Handle `ERROR_CONVERGENCE_NOT_REACHED` exit code: decrease the mixing beta and restart."""
factor = self.defaults.delta_factor_alpha_mix
alpha_mix = self.ctx.inputs.parameters.get('INPUTPH', {}).get('alpha_mix(1)', self.defaults.alpha_mix)
alpha_mix_new = alpha_mix * factor
Expand Down
4 changes: 2 additions & 2 deletions aiida_quantumespresso/workflows/pw/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def validate_kpoints(self):
the case of the latter, the `KpointsData` will be constructed for the input `StructureData` using the
`create_kpoints_from_distance` calculation function.
"""
if all([key not in self.inputs for key in ['kpoints', 'kpoints_distance']]):
if all(key not in self.inputs for key in ['kpoints', 'kpoints_distance']):
return self.exit_codes.ERROR_INVALID_INPUT_KPOINTS

try:
Expand Down Expand Up @@ -637,7 +637,7 @@ def handle_relax_recoverable_electronic_convergence_error(self, calculation):
@process_handler(priority=410, exit_codes=[
PwCalculation.exit_codes.ERROR_ELECTRONIC_CONVERGENCE_NOT_REACHED,
])
def handle_electronic_convergence_not_achieved(self, calculation):
def handle_electronic_convergence_not_reached(self, calculation):
"""Handle `ERROR_ELECTRONIC_CONVERGENCE_NOT_REACHED` error.

Decrease the mixing beta and fully restart from the previous calculation.
Expand Down
29 changes: 26 additions & 3 deletions tests/workflows/ph/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,38 @@ def test_handle_out_of_walltime(generate_workchain_ph):
assert result.status == 0


def test_handle_convergence_not_achieved(generate_workchain_ph):
"""Test `PhBaseWorkChain.handle_convergence_not_achieved`."""
def test_handle_scheduler_out_of_walltime(generate_workchain_ph):
"""Test `PhBaseWorkChain.handle_scheduler_out_of_walltime`."""
inputs = generate_workchain_ph(return_inputs=True)
max_wallclock_seconds = inputs['ph']['metadata']['options']['max_wallclock_seconds']
max_seconds = max_wallclock_seconds * PhBaseWorkChain.defaults.delta_factor_max_seconds

process = generate_workchain_ph(exit_code=PhCalculation.exit_codes.ERROR_SCHEDULER_OUT_OF_WALLTIME)
process.setup()
process.validate_parameters()
process.prepare_process()

max_seconds_new = max_seconds * 0.5

result = process.handle_scheduler_out_of_walltime(process.ctx.children[-1])
assert isinstance(result, ProcessHandlerReport)
assert result.do_break
assert process.ctx.inputs.parameters['INPUTPH']['max_seconds'] == max_seconds_new
assert not process.ctx.inputs.parameters['INPUTPH']['recover']

result = process.inspect_process()
assert result.status == 0


def test_handle_convergence_not_reached(generate_workchain_ph):
"""Test `PhBaseWorkChain.handle_convergence_not_reached`."""
process = generate_workchain_ph(exit_code=PhCalculation.exit_codes.ERROR_CONVERGENCE_NOT_REACHED)
process.setup()
process.validate_parameters()

alpha_new = PhBaseWorkChain.defaults.alpha_mix * PhBaseWorkChain.defaults.delta_factor_alpha_mix

result = process.handle_convergence_not_achieved(process.ctx.children[-1])
result = process.handle_convergence_not_reached(process.ctx.children[-1])
assert isinstance(result, ProcessHandlerReport)
assert result.do_break
assert process.ctx.inputs.parameters['INPUTPH']['alpha_mix(1)'] == alpha_new
Expand Down
8 changes: 4 additions & 4 deletions tests/workflows/pw/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_handle_out_of_walltime(generate_workchain_pw, fixture_localhost, genera
)
process.setup()

result = process.handle_electronic_convergence_not_achieved(process.ctx.children[-1])
result = process.handle_electronic_convergence_not_reached(process.ctx.children[-1])
result = process.handle_out_of_walltime(process.ctx.children[-1])
assert isinstance(result, ProcessHandlerReport)
assert process.ctx.inputs.parameters['CONTROL']['restart_mode'] == 'restart'
Expand All @@ -74,8 +74,8 @@ def test_handle_out_of_walltime_structure_changed(generate_workchain_pw, generat
assert result.status == 0


def test_handle_electronic_convergence_not_achieved(generate_workchain_pw, fixture_localhost, generate_remote_data):
"""Test `PwBaseWorkChain.handle_electronic_convergence_not_achieved`."""
def test_handle_electronic_convergence_not_reached(generate_workchain_pw, fixture_localhost, generate_remote_data):
"""Test `PwBaseWorkChain.handle_electronic_convergence_not_reached`."""
remote_data = generate_remote_data(computer=fixture_localhost, remote_path='/path/to/remote')

process = generate_workchain_pw(
Expand All @@ -86,7 +86,7 @@ def test_handle_electronic_convergence_not_achieved(generate_workchain_pw, fixtu

process.ctx.inputs.parameters['ELECTRONS']['mixing_beta'] = 0.5

result = process.handle_electronic_convergence_not_achieved(process.ctx.children[-1])
result = process.handle_electronic_convergence_not_reached(process.ctx.children[-1])
assert isinstance(result, ProcessHandlerReport)
assert process.ctx.inputs.parameters['ELECTRONS']['mixing_beta'] == \
process.defaults.delta_factor_mixing_beta * 0.5
Expand Down