Skip to content

Commit

Permalink
Protocols: add usage of metadata/parallelization overrides (#652)
Browse files Browse the repository at this point in the history
The `metadata` and `parallelization` inputs of the `PwCalculation`,
stored in the `pw` namespace of the `PwBaseWorkChain`, are not
(properly) used in the `get_builder_from_protocol()` method. Here we
make the following changes to correct this issue:

* Move the `metadata` inputs under `pw` in the base.yaml file.
* Correctly obtain the `metadata` from the inputs in the
  `get_builder_from_protocol()` method of the `PwBaseWorkChain`.
* Check if there is `parallelization` input in the `pw` namespace of the
  inputs, and if so add it to the builder in the
  `get_builder_from_protocol()` method of the `PwBaseWorkChain`.
  • Loading branch information
mbercx authored May 6, 2021
1 parent 52c53b7 commit bedbd57
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 11 deletions.
12 changes: 6 additions & 6 deletions aiida_quantumespresso/workflows/protocols/pw/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ default_inputs:
meta_parameters:
conv_thr_per_atom: 0.2e-9
etot_conv_thr_per_atom: 1.e-5
metadata:
options:
resources:
num_machines: 1
max_wallclock_seconds: 43200 # Twelve hours
withmpi: True
pseudo_family: 'SSSP/1.1/PBE/efficiency'
pw:
metadata:
options:
resources:
num_machines: 1
max_wallclock_seconds: 43200 # Twelve hours
withmpi: True
parameters:
CONTROL:
calculation: scf
Expand Down
14 changes: 9 additions & 5 deletions aiida_quantumespresso/workflows/pw/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,18 @@ def get_builder_from_protocol(
parameters['SYSTEM']['nspin'] = 2
parameters['SYSTEM']['starting_magnetization'] = starting_magnetization

builder.pw['code'] = code # pylint: disable=no-member
builder.pw['pseudos'] = pseudo_family.get_pseudos(structure=structure) # pylint: disable=no-member
builder.pw['structure'] = structure # pylint: disable=no-member
builder.pw['parameters'] = orm.Dict(dict=parameters) # pylint: disable=no-member
builder.pw['metadata'] = inputs['metadata'] # pylint: disable=no-member
# pylint: disable=no-member
builder.pw['code'] = code
builder.pw['pseudos'] = pseudo_family.get_pseudos(structure=structure)
builder.pw['structure'] = structure
builder.pw['parameters'] = orm.Dict(dict=parameters)
builder.pw['metadata'] = inputs['pw']['metadata']
if 'parallelization' in inputs['pw']:
builder.pw['parallelization'] = orm.Dict(dict=inputs['pw']['parallelization'])
builder.clean_workdir = orm.Bool(inputs['clean_workdir'])
builder.kpoints_distance = orm.Float(inputs['kpoints_distance'])
builder.kpoints_force_parity = orm.Bool(inputs['kpoints_force_parity'])
# pylint: enable=no-member

return builder

Expand Down
34 changes: 34 additions & 0 deletions tests/workflows/protocols/pw/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,37 @@ def test_initial_magnetic_moments(fixture_code, generate_structure):

assert parameters['SYSTEM']['nspin'] == 2
assert parameters['SYSTEM']['starting_magnetization'] == {'Si': 0.25}


def test_metadata_overrides(fixture_code, generate_structure):
"""Test that pw metadata is correctly passed through overrides."""
code = fixture_code('quantumespresso.pw')
structure = generate_structure()

overrides = {'pw': {'metadata': {'options': {'resources': {'num_machines': 1e90}, 'max_wallclock_seconds': 1}}}}
builder = PwBaseWorkChain.get_builder_from_protocol(
code,
structure,
overrides=overrides,
)
metadata = builder.pw.metadata

assert metadata['options']['resources']['num_machines'] == 1e90
assert metadata['options']['max_wallclock_seconds'] == 1


def test_parallelization_overrides(fixture_code, generate_structure):
"""Test that pw parallelization settings are correctly passed through overrides."""
code = fixture_code('quantumespresso.pw')
structure = generate_structure()

overrides = {'pw': {'parallelization': {'npool': 4, 'ndiag': 12}}}
builder = PwBaseWorkChain.get_builder_from_protocol(
code,
structure,
overrides=overrides,
)
parallelization = builder.pw.parallelization

assert parallelization['npool'] == 4
assert parallelization['ndiag'] == 12

0 comments on commit bedbd57

Please sign in to comment.