Skip to content

Commit

Permalink
Apply reviewer suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
mbercx committed Jan 13, 2021
1 parent 8a745bb commit 50713cc
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 34 deletions.
43 changes: 21 additions & 22 deletions aiida_quantumespresso/workflows/pw/relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ def setup(self):
self.ctx.relax_inputs = AttributeDict(self.exposed_inputs(PwBaseWorkChain, namespace='base'))
self.ctx.relax_inputs.pw.parameters = self.ctx.relax_inputs.pw.parameters.get_dict()

self.ctx.relax_inputs.pw.parameters.setdefault('CONTROL', {})
self.ctx.relax_inputs.pw.parameters['CONTROL']['restart_mode'] = 'from_scratch'

# Adjust the inputs for the chosen relaxation scheme
if 'relaxation_scheme' in self.inputs:
if self.inputs.relaxation_scheme.value in ('relax', 'vc-relax'):
Expand Down Expand Up @@ -187,12 +190,22 @@ def setup(self):
elif 'base_final_scf' in self.inputs:
self.ctx.final_scf_inputs = AttributeDict(self.exposed_inputs(PwBaseWorkChain, namespace='base_final_scf'))

if 'final_scf_inputs' in self.ctx and self.ctx.relax_inputs.pw.parameters['CONTROL']['calculation'] == 'scf':
self.report(
'Work chain will not run final SCF when `calculation` is set to `scf` for the relaxation '
'`PwBaseWorkChain`.'
)
self.ctx.pop('final_scf_inputs')
if 'final_scf_inputs' in self.ctx:
if self.ctx.relax_inputs.pw.parameters['CONTROL']['calculation'] == 'scf':
self.report(
'Work chain will not run final SCF when `calculation` is set to `scf` for the relaxation '
'`PwBaseWorkChain`.'
)
self.ctx.pop('final_scf_inputs')

else:
self.ctx.final_scf_inputs.pw.parameters = self.ctx.final_scf_inputs.pw.parameters.get_dict()

self.ctx.final_scf_inputs.pw.parameters.setdefault('CONTROL', {})
self.ctx.final_scf_inputs.pw.parameters['CONTROL']['calculation'] = 'scf'
self.ctx.final_scf_inputs.pw.parameters['CONTROL']['restart_mode'] = 'from_scratch'
self.ctx.final_scf_inputs.pw.parameters.pop('CELL', None)
self.ctx.final_scf_inputs.metadata.call_link_label = 'final_scf'

def should_run_relax(self):
"""Return whether a relaxation workchain should be run.
Expand All @@ -217,13 +230,6 @@ def run_relax(self):
inputs = self.ctx.relax_inputs
inputs.pw.structure = self.ctx.current_structure

inputs.pw.parameters.setdefault('CONTROL', {})
inputs.pw.parameters['CONTROL']['restart_mode'] = 'from_scratch'

# If the relaxation uses 'vc-relax make sure there is a 'CELL' namelist
if inputs.pw.parameters['CONTROL']['calculation'] == 'vc-relax':
inputs.pw.parameters.setdefault('CELL', {})

# If one of the nested `PwBaseWorkChains` changed the number of bands, apply it here
if self.ctx.current_number_of_bands is not None:
inputs.pw.parameters.setdefault('SYSTEM', {})['nbnd'] = self.ctx.current_number_of_bands
Expand Down Expand Up @@ -310,13 +316,6 @@ def run_final_scf(self):
"""Run the `PwBaseWorkChain` to run a final scf `PwCalculation` for the relaxed structure."""
inputs = self.ctx.final_scf_inputs
inputs.pw.structure = self.ctx.current_structure
inputs.pw.parameters = inputs.pw.parameters.get_dict()

inputs.pw.parameters.setdefault('CONTROL', {})
inputs.pw.parameters['CONTROL']['calculation'] = 'scf'
inputs.pw.parameters['CONTROL']['restart_mode'] = 'from_scratch'
inputs.pw.parameters.pop('CELL', None)
inputs.metadata.call_link_label = 'final_scf'

if self.ctx.current_number_of_bands is not None:
inputs.pw.parameters.setdefault('SYSTEM', {})['nbnd'] = self.ctx.current_number_of_bands
Expand Down Expand Up @@ -349,9 +348,9 @@ def results(self):
if self.inputs.base.pw.parameters['CONTROL']['calculation'] != 'scf':
self.out('output_structure', final_relax_workchain.outputs.output_structure)

if 'final_scf_inputs' in self.ctx:
try:
self.out_many(self.exposed_outputs(self.ctx.workchain_scf, PwBaseWorkChain))
else:
except AttributeError:
self.out_many(self.exposed_outputs(final_relax_workchain, PwBaseWorkChain))

def on_terminated(self):
Expand Down
3 changes: 1 addition & 2 deletions tests/workflows/protocols/pw/test_bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,4 @@ def test_relax_type(fixture_code, generate_structure):

builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, relax_type=RelaxType.NONE)
assert builder.relax['base']['pw']['parameters']['CONTROL']['calculation'] == 'scf'
with pytest.raises(KeyError):
builder.relax['base']['pw']['parameters']['CELL'] # pylint: disable=pointless-statement
assert 'CELL' not in builder.relax['base']['pw']['parameters'].get_dict()
15 changes: 5 additions & 10 deletions tests/workflows/protocols/pw/test_relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,11 @@ def test_relax_type(fixture_code, generate_structure):

builder = PwRelaxWorkChain.get_builder_from_protocol(code, structure, relax_type=RelaxType.NONE)
assert builder.base['pw']['parameters']['CONTROL']['calculation'] == 'scf'
with pytest.raises(KeyError):
builder.base['pw']['parameters']['CELL'] # pylint: disable=pointless-statement
assert 'CELL' not in builder.base['pw']['parameters'].get_dict()

builder = PwRelaxWorkChain.get_builder_from_protocol(code, structure, relax_type=RelaxType.ATOMS)
assert builder.base['pw']['parameters']['CONTROL']['calculation'] == 'relax'
with pytest.raises(KeyError):
builder.base['pw']['parameters']['CELL'] # pylint: disable=pointless-statement
assert 'CELL' not in builder.base['pw']['parameters'].get_dict()

builder = PwRelaxWorkChain.get_builder_from_protocol(code, structure, relax_type=RelaxType.VOLUME)
assert builder.base['pw']['parameters']['CONTROL']['calculation'] == 'vc-relax'
Expand All @@ -98,17 +96,14 @@ def test_relax_type(fixture_code, generate_structure):
builder = PwRelaxWorkChain.get_builder_from_protocol(code, structure, relax_type=RelaxType.ATOMS_VOLUME)
assert builder.base['pw']['parameters']['CONTROL']['calculation'] == 'vc-relax'
assert builder.base['pw']['parameters']['CELL']['cell_dofree'] == 'volume'
with pytest.raises(KeyError):
builder.base['pw']['settings'] # pylint: disable=pointless-statement
assert 'settings' not in builder.base['pw']

builder = PwRelaxWorkChain.get_builder_from_protocol(code, structure, relax_type=RelaxType.ATOMS_SHAPE)
assert builder.base['pw']['parameters']['CONTROL']['calculation'] == 'vc-relax'
assert builder.base['pw']['parameters']['CELL']['cell_dofree'] == 'shape'
with pytest.raises(KeyError):
builder.base['pw']['settings'] # pylint: disable=pointless-statement
assert 'settings' not in builder.base['pw']

builder = PwRelaxWorkChain.get_builder_from_protocol(code, structure, relax_type=RelaxType.ATOMS_CELL)
assert builder.base['pw']['parameters']['CONTROL']['calculation'] == 'vc-relax'
assert builder.base['pw']['parameters']['CELL']['cell_dofree'] == 'all'
with pytest.raises(KeyError):
builder.base['pw']['settings'] # pylint: disable=pointless-statement
assert 'settings' not in builder.base['pw']

0 comments on commit 50713cc

Please sign in to comment.