Skip to content

Commit

Permalink
Implement parallelization options as explicit inputs.
Browse files Browse the repository at this point in the history
If any parallelization option is already specified in the
`settings['CMDLINE']`, a `DeprecationWarning` is issued. However,
the value from `settings['CMDLINE']` still takes precedence. This
allows us to implement the defaults at the level of the input
port, without having to care about it later on.

The `BasePwCpInputGenerator` is given three class attributes:
- `_PARALLELIZATION_FLAGS` is a mapping
  `{flag_name: (default, help)}` of all possible flags.
- `_ENABLED_PARALLELIZATION_FLAGS` is a tuple of flag names that
  are implemented in a particular code.
- `_PARALLELIZATION_FLAG_ALIASES` is used to detect all possible
  variations on a flag name. These are taken from the QE source.

When checking for existing parallelization flags in the manually
passed cmdline parameters, these are normalized by splitting on
whitespace.
QE ignores flags that are capitalized differently, so we do not
have to normalize capitalization here.

TODO:
- Documentation
- Set `_ENABLED_PARALLELIZATION_FLAGS` for codes other than pw.x.
- `pw.x` accepts the `-nimage` flag — but it is unclear whether it makes
  any sense for this code.
  • Loading branch information
Dominik Gresch committed Oct 13, 2020
1 parent 0db7978 commit 0c1e2ee
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 1 deletion.
85 changes: 85 additions & 0 deletions aiida_quantumespresso/calculations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
"""Base `CalcJob` for implementations for pw.x and cp.x of Quantum ESPRESSO."""
import abc
import os
import warnings
from functools import partial
from types import MappingProxyType

from aiida import orm
from aiida.common import datastructures, exceptions
Expand All @@ -27,6 +30,39 @@ class BasePwCpInputGenerator(CalcJob):
_ENVIRON_INPUT_FILE_NAME = 'environ.in'
_DEFAULT_IBRAV = 0

# A mapping {flag_name: (default, help)} of parallelization flags
# possible in QE codes. The flags that are actually implemented in a
# given code should be specified in the '_ENABLED_PARALLELIZATION_FLAGS'
# tuple of each calculation subclass.
# NOTE: wrapped in `MappingProxyType` so the shared class-level mapping
# cannot be mutated by accident.
_PARALLELIZATION_FLAGS = MappingProxyType(
    dict(
        nimage=(
            None, "The number of 'images', each corresponding to a different self-consistent or "
            'linear-response calculation.'
        ),
        npool=(None, "The number of 'pools', each taking care of a group of k-points."),
        nband=(None, "The number of 'band groups', each taking care of a group of Kohn-Sham orbitals."),
        ntg=(None, "The number of 'task groups' across which the FFT planes are distributed."),
        ndiag=(
            None, "The number of 'linear algebra groups' used when parallelizing the subspace "
            'diagonalization / iterative orthonormalization. By default, no parameter is '
            'passed to Quantum ESPRESSO, meaning it will use its default.'
        )
    )
)

# Names from `_PARALLELIZATION_FLAGS` that a concrete code actually
# supports; subclasses override this tuple (empty by default).
_ENABLED_PARALLELIZATION_FLAGS = tuple()

# Alternate spellings QE accepts for each flag (taken from the QE
# source), used to detect manually-specified parallelization flags in
# `settings['CMDLINE']`.
_PARALLELIZATION_FLAG_ALIASES = MappingProxyType(
    dict(
        nimage=('ni', 'nimages', 'npot'),
        npool=('nk', 'npools'),
        nband=('nb', 'nbgrp', 'nband_group'),
        ntg=('nt', 'ntask_groups', 'nyfft'),
        ndiag=('northo', 'nd', 'nproc_diag', 'nproc_ortho')
    )
)

# Additional files that should always be retrieved for the specific plugin
_internal_retrieve_list = []

Expand Down Expand Up @@ -85,6 +121,26 @@ def define(cls, spec):
spec.input_namespace('pseudos', valid_type=orm.UpfData, dynamic=True,
help='A mapping of `UpfData` nodes onto the kind name to which they should apply.')

# Declare one input port per enabled parallelization flag, so that
# defaults are handled at the level of the input port.
for flag_name in cls._ENABLED_PARALLELIZATION_FLAGS:
    try:
        default, help_ = cls._PARALLELIZATION_FLAGS[flag_name]
    except KeyError as exc:
        # An enabled flag must exist in the master mapping; this is a
        # plugin programming error, not a user error.
        raise KeyError(
            f"The parallelization flag '{flag_name}' specified in _ENABLED_PARALLELIZATION_FLAGS "
            'does not exist in _PARALLELIZATION_FLAGS. Please report this issue to the calculation '
            'plugin developer.'
        ) from exc

    if default is None:
        # No default value: the port is simply optional.
        extra_kwargs = {'required': False}
    else:
        # We use 'functools.partial' here because lambda: `orm.Int(default)` would
        # bind to the name `default`, and look up its value only when the lambda is
        # actually called. Because `default` changes throughout the for-loop, this
        # would give the wrong value.
        extra_kwargs = {'default': partial(orm.Int, default)}
    spec.input(f'parallelization.{flag_name}', help=help_, **extra_kwargs)

def prepare_for_submission(self, folder):
"""Create the input files from the input nodes passed to this instance of the `CalcJob`.
Expand Down Expand Up @@ -202,6 +258,35 @@ def prepare_for_submission(self, folder):
calcinfo.uuid = str(self.uuid)
# Empty command line by default
cmdline_params = settings.pop('CMDLINE', [])

# Add the parallelization flags.
# The `cmdline_params_normalized` are used only here to check
# for existing parallelization flags.
cmdline_params_normalized = []
for param in cmdline_params:
    # Split on whitespace so a flag and its value passed as a single
    # string (e.g. '-nk 4') are still detected as separate tokens.
    cmdline_params_normalized.extend(param.split())
# To make the order of flags consistent and "nice", we use the
# ordering from the flag definition.
for flag_name in self._ENABLED_PARALLELIZATION_FLAGS:
    # We check for existing inputs for all possible flag names,
    # to make sure a `DeprecationWarning` is emitted whenever
    # a flag is specified in the `settings['CMDLINE']`.
    all_aliases = list(self._PARALLELIZATION_FLAG_ALIASES[flag_name]) + [flag_name]
    if any(f'-{alias}' in cmdline_params_normalized for alias in all_aliases):
        # To preserve backwards compatibility, we ignore the
        # `parallelization.flag_name` input if the same flag is
        # already present in `cmdline_params`. This is necessary
        # because the new inputs can be set by their default,
        # in user code that doesn't explicitly specify them.
        warnings.warn(
            "Specifying the parallelization flags through settings['CMDLINE'] is "
            "deprecated, use the 'parallelization' input namespace instead.", DeprecationWarning
        )
        continue
    if flag_name in self.inputs.parallelization:
        flag_value = self.inputs.parallelization[flag_name].value
        cmdline_params += [f'-{flag_name}', str(flag_value)]

# we commented calcinfo.stin_name and added it here in cmdline_params
# in this way the mpirun ... pw.x ... < aiida.in
# is replaced by mpirun ... pw.x ... -in aiida.in
Expand Down
2 changes: 2 additions & 0 deletions aiida_quantumespresso/calculations/pw.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class PwCalculation(BasePwCpInputGenerator):
# Not using symlink in pw to allow multiple nscf to run on top of the same scf
_default_symlink_usage = False

# pw.x supports all five parallelization flags defined on the base class.
_ENABLED_PARALLELIZATION_FLAGS = ('nimage', 'npool', 'nband', 'ntg', 'ndiag')

@classproperty
def xml_filepaths(cls):
"""Return a list of XML output filepaths relative to the remote working directory that should be retrieved."""
Expand Down
41 changes: 40 additions & 1 deletion tests/calculations/test_pw.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_pw_default(fixture_sandbox, generate_calc_job, generate_inputs_pw, file

# Check the attributes of the returned `CalcInfo`
assert isinstance(calc_info, datastructures.CalcInfo)
assert sorted(calc_info.cmdline_params) == sorted(cmdline_params)
assert calc_info.cmdline_params == cmdline_params
assert sorted(calc_info.local_copy_list) == sorted(local_copy_list)
assert sorted(calc_info.retrieve_list) == sorted(retrieve_list)
assert sorted(calc_info.retrieve_temporary_list) == sorted(retrieve_temporary_list)
Expand Down Expand Up @@ -155,3 +155,42 @@ def test_pw_ibrav_tol(fixture_sandbox, generate_calc_job, fixture_code, generate
# After adjusting the tolerance, the input validation no longer fails.
inputs['settings'] = orm.Dict(dict={'ibrav_cell_tolerance': eps})
generate_calc_job(fixture_sandbox, entry_point_name, inputs)


def test_pw_parallelization_inputs(fixture_sandbox, generate_calc_job, generate_inputs_pw):
    """Test that the parallelization settings are set correctly in the commandline params."""
    entry_point_name = 'quantumespresso.pw'

    # Flag values to pass through the `parallelization` input namespace.
    flag_values = {'nimage': 1, 'npool': 4, 'nband': 2, 'ntg': 3, 'ndiag': 12}

    inputs = generate_inputs_pw()
    inputs['parallelization'] = {name: orm.Int(value) for name, value in flag_values.items()}

    calc_info = generate_calc_job(fixture_sandbox, entry_point_name, inputs)

    # The flags must appear in definition order, followed by the input-file flag.
    expected_cmdline = []
    for name, value in flag_values.items():
        expected_cmdline += [f'-{name}', str(value)]
    expected_cmdline += ['-in', 'aiida.in']

    assert calc_info.cmdline_params == expected_cmdline


@pytest.mark.parametrize(
    'flag_name', ['nimage', 'ni', 'npool', 'nk', 'nband', 'nb', 'ntg', 'nt', 'northo', 'ndiag', 'nd']
)
def test_pw_parallelization_deprecation(fixture_sandbox, generate_calc_job, generate_inputs_pw, flag_name):
    """Test the deprecation warning on specifying parallelization flags manually.

    Test that passing parallelization flags in the ``settings['CMDLINE']``
    emits a `DeprecationWarning`.
    """
    entry_point_name = 'quantumespresso.pw'

    inputs = generate_inputs_pw()
    extra_cmdline_args = [f'-{flag_name}', '2']
    inputs['settings'] = orm.Dict(dict={'CMDLINE': extra_cmdline_args})
    with pytest.warns(DeprecationWarning):
        calc_info = generate_calc_job(fixture_sandbox, entry_point_name, inputs)
    # The manually-given flag takes precedence: the final command line is
    # exactly the user's flags plus the input-file flag.
    assert calc_info.cmdline_params == extra_cmdline_args + ['-in', 'aiida.in']

0 comments on commit 0c1e2ee

Please sign in to comment.