Implement parallelization options as explicit inputs. #554

Merged · 2 commits · Nov 23, 2020
123 changes: 121 additions & 2 deletions aiida_quantumespresso/calculations/__init__.py
@@ -2,10 +2,16 @@
"""Base `CalcJob` for implementations for pw.x and cp.x of Quantum ESPRESSO."""
import abc
import os
import copy
import numbers
import warnings
from functools import partial
from types import MappingProxyType

from aiida import orm
from aiida.common import datastructures, exceptions
from aiida.common.lang import classproperty
from aiida.common.warnings import AiidaDeprecationWarning
from aiida.plugins import DataFactory
from qe_tools.converters import get_parameters_from_cell

@@ -30,6 +36,37 @@ class BasePwCpInputGenerator(CalcJob):
_ENVIRON_INPUT_FILE_NAME = 'environ.in'
_DEFAULT_IBRAV = 0

# A mapping {flag_name: help_string} of parallelization flags
# possible in QE codes. The flags that are actually implemented in a
# given code should be specified in the '_ENABLED_PARALLELIZATION_FLAGS'
# tuple of each calculation subclass.
_PARALLELIZATION_FLAGS = MappingProxyType(
dict(
nimage="The number of 'images', each corresponding to a different self-consistent or "
'linear-response calculation.',
npool="The number of 'pools', each taking care of a group of k-points.",
nband="The number of 'band groups', each taking care of a group of Kohn-Sham orbitals.",
ntg="The number of 'task groups' across which the FFT planes are distributed.",
ndiag="The number of 'linear algebra groups' used when parallelizing the subspace "
'diagonalization / iterative orthonormalization. By default, no parameter is '
'passed to Quantum ESPRESSO, meaning it will use its default.',
nhw="The 'nmany' FFT bands parallelization option."
)
)

_ENABLED_PARALLELIZATION_FLAGS = tuple()

_PARALLELIZATION_FLAG_ALIASES = MappingProxyType(
dict(
nimage=('ni', 'nimages', 'npot'),
npool=('nk', 'npools'),
nband=('nb', 'nbgrp', 'nband_group'),
ntg=('nt', 'ntask_groups', 'nyfft'),
ndiag=('northo', 'nd', 'nproc_diag', 'nproc_ortho'),
nhw=('nh', 'n_howmany', 'howmany')
)
)

# Additional files that should always be retrieved for the specific plugin
_internal_retrieve_list = []

@@ -87,6 +124,35 @@ def define(cls, spec):
help='Optional van der Waals table contained in a `SinglefileData`.')
spec.input_namespace('pseudos', valid_type=(LegacyUpfData, UpfData), dynamic=True,
help='A mapping of `UpfData` nodes onto the kind name to which they should apply.')
spec.input(
'parallelization',
valid_type=orm.Dict,
required=False,
help=(
'Parallelization options. The following flags are allowed:\n' + '\n'.join(
f'{flag_name:<7}: {cls._PARALLELIZATION_FLAGS[flag_name]}'
for flag_name in cls._ENABLED_PARALLELIZATION_FLAGS
)
),
validator=cls._validate_parallelization
)

@classmethod
def _validate_parallelization(cls, value, port_namespace): # pylint: disable=unused-argument
if value:
value_dict = value.get_dict()
unknown_flags = set(value_dict.keys()) - set(cls._ENABLED_PARALLELIZATION_FLAGS)
if unknown_flags:
return (
f"Unknown flags in 'parallelization': {unknown_flags}, "
f'allowed flags are {cls._ENABLED_PARALLELIZATION_FLAGS}.'
)
invalid_values = [val for val in value_dict.values() if not isinstance(val, numbers.Integral)]
if invalid_values:
return (
f'Parallelization values must be integers; got invalid values {invalid_values}.'
)


def prepare_for_submission(self, folder):
"""Create the input files from the input nodes passed to this instance of the `CalcJob`.
@@ -203,8 +269,12 @@ def prepare_for_submission(self, folder):
calcinfo = datastructures.CalcInfo()

calcinfo.uuid = str(self.uuid)
# Empty command line by default
cmdline_params = settings.pop('CMDLINE', [])
# Start from an empty command line by default
cmdline_params = self._add_parallelization_flags_to_cmdline_params(
cmdline_params=settings.pop('CMDLINE', [])
)


# we commented calcinfo.stin_name and added it here in cmdline_params
# in this way the mpirun ... pw.x ... < aiida.in
# is replaced by mpirun ... pw.x ... -in aiida.in
@@ -244,6 +314,55 @@ def prepare_for_submission(self, folder):

return calcinfo

def _add_parallelization_flags_to_cmdline_params(self, cmdline_params):
"""Get the command line parameters with added parallelization flags.

Adds the parallelization flags to the given `cmdline_params` and
returns the updated list.

Raises an `InputValidationError` if multiple aliases to the same
flag are given in `cmdline_params`, or the same flag is given
both in `cmdline_params` and the explicit `parallelization`
input.
"""
cmdline_params_res = copy.deepcopy(cmdline_params)
# The `cmdline_params_normalized` are used only here to check
# for existing parallelization flags.
cmdline_params_normalized = []
for param in cmdline_params:
cmdline_params_normalized.extend(param.split())

if 'parallelization' in self.inputs:
parallelization_dict = self.inputs.parallelization.get_dict()
else:
parallelization_dict = {}
# To make the order of flags consistent and "nice", we use the
# ordering from the flag definition.
for flag_name in self._ENABLED_PARALLELIZATION_FLAGS:
all_aliases = list(self._PARALLELIZATION_FLAG_ALIASES[flag_name]) + [flag_name]
aliases_in_cmdline = [alias for alias in all_aliases if f'-{alias}' in cmdline_params_normalized]
if aliases_in_cmdline:
if len(aliases_in_cmdline) > 1:
raise exceptions.InputValidationError(
f'Conflicting parallelization flags {aliases_in_cmdline} '
"in settings['CMDLINE']"
)
if flag_name in parallelization_dict:
raise exceptions.InputValidationError(
f"Parallelization flag '{aliases_in_cmdline[0]}' specified in settings['CMDLINE'] conflicts "
f"with '{flag_name}' in the 'parallelization' input."
)
else:
warnings.warn(
"Specifying the parallelization flags through settings['CMDLINE'] is "
"deprecated, use the 'parallelization' input instead.", AiidaDeprecationWarning
)
continue
if flag_name in parallelization_dict:
flag_value = parallelization_dict[flag_name]
cmdline_params_res += [f'-{flag_name}', str(flag_value)]
return cmdline_params_res

@staticmethod
def _generate_PWCP_input_tail(*args, **kwargs):
"""Generate tail of input file.
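For orientation, the merge performed by `_add_parallelization_flags_to_cmdline_params` reduces to appending `-flag value` pairs in the fixed order of `_ENABLED_PARALLELIZATION_FLAGS`, after the alias and conflict checks. A minimal standalone sketch of that ordering logic only (the names `ENABLED_FLAGS` and `build_cmdline` are hypothetical, not the plugin's API):

    # Standalone sketch: hypothetical names, mirrors the flag ordering only.
    ENABLED_FLAGS = ('npool', 'nband', 'ntg', 'ndiag')

    def build_cmdline(parallelization, existing=()):
        """Append '-flag value' pairs, in a fixed order, to any pre-existing command line."""
        cmdline = list(existing)
        for flag in ENABLED_FLAGS:
            if flag in parallelization:
                cmdline += [f'-{flag}', str(parallelization[flag])]
        return cmdline

    print(build_cmdline({'ndiag': 12, 'npool': 4}))
    # ['-npool', '4', '-ndiag', '12']  (the plugin appends '-in aiida.in' afterwards)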
2 changes: 2 additions & 0 deletions aiida_quantumespresso/calculations/pw.py
@@ -43,6 +43,8 @@ class PwCalculation(BasePwCpInputGenerator):
# Not using symlink in pw to allow multiple nscf to run on top of the same scf
_default_symlink_usage = False

_ENABLED_PARALLELIZATION_FLAGS = ('npool', 'nband', 'ntg', 'ndiag')

@classproperty
def xml_filepaths(cls):
"""Return a list of XML output filepaths relative to the remote working directory that should be retrieved."""
Empty file.
4 changes: 4 additions & 0 deletions docs/source/user_guide/calculation_plugins/pw.rst
@@ -99,6 +99,10 @@ This can then be used directly in the process builder of for example a ``PwCalculation``
* **settings**, class :py:class:`Dict <aiida.orm.nodes.data.dict.Dict>` (optional)
An optional dictionary that activates non-default operations. For a list of possible
values to pass, see the section on the :ref:`advanced features <pw-advanced-features>`.
* **parallelization**, class :py:class:`Dict <aiida.orm.nodes.data.dict.Dict>` (optional)
An optional dictionary to specify the parallelization flags passed to `pw.x` on the
command line. The dictionary maps flag names (type `str`) to their values (type `int`).
Allowed flag names are `npool`, `nband`, `ntg`, and `ndiag`.
* **parent_folder**, class :py:class:`RemoteData <aiida.orm.nodes.data.dict.Dict>` (optional)
If specified, the scratch folder coming from a previous QE calculation is
copied in the scratch of the new calculation.
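To complement the documentation entry above, a hedged usage sketch of the new input from the user side (the code label `pw@localhost` and the omitted inputs are assumptions, not part of this diff):

    from aiida import orm
    from aiida.plugins import CalculationFactory

    PwCalculation = CalculationFactory('quantumespresso.pw')
    builder = PwCalculation.get_builder()
    builder.code = orm.load_code('pw@localhost')  # assumed code label
    builder.parallelization = orm.Dict(dict={'npool': 4, 'ndiag': 1})
    # structure, parameters, kpoints, pseudos and metadata.options omitted for brevity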
95 changes: 94 additions & 1 deletion tests/calculations/test_pw.py
@@ -5,6 +5,8 @@

from aiida import orm
from aiida.common import datastructures
from aiida.common.warnings import AiidaDeprecationWarning
from aiida.common.exceptions import InputValidationError
from aiida_quantumespresso.utils.resources import get_default_options
from aiida_quantumespresso.calculations.helpers import QEInputValidationError

@@ -25,7 +27,7 @@ def test_pw_default(fixture_sandbox, generate_calc_job, generate_inputs_pw, file
# Check the attributes of the returned `CalcInfo`
assert isinstance(calc_info, datastructures.CalcInfo)
assert isinstance(calc_info.codes_info[0], datastructures.CodeInfo)
assert sorted(calc_info.codes_info[0].cmdline_params) == sorted(cmdline_params)
assert sorted(calc_info.codes_info[0].cmdline_params) == cmdline_params
assert sorted(calc_info.local_copy_list) == sorted(local_copy_list)
assert sorted(calc_info.retrieve_list) == sorted(retrieve_list)
assert sorted(calc_info.retrieve_temporary_list) == sorted(retrieve_temporary_list)
@@ -157,3 +159,94 @@ def test_pw_ibrav_tol(fixture_sandbox, generate_calc_job, fixture_code, generate
# After adjusting the tolerance, the input validation no longer fails.
inputs['settings'] = orm.Dict(dict={'ibrav_cell_tolerance': eps})
generate_calc_job(fixture_sandbox, entry_point_name, inputs)


def test_pw_parallelization_inputs(fixture_sandbox, generate_calc_job, generate_inputs_pw):
"""Test that the parallelization settings are set correctly in the commandline params."""
entry_point_name = 'quantumespresso.pw'

inputs = generate_inputs_pw()
inputs['parallelization'] = orm.Dict(dict={'npool': 4, 'nband': 2, 'ntg': 3, 'ndiag': 12})
calc_info = generate_calc_job(fixture_sandbox, entry_point_name, inputs)

cmdline_params = ['-npool', '4', '-nband', '2', '-ntg', '3', '-ndiag', '12', '-in', 'aiida.in']

# Check that the command-line parameters are as expected.
assert calc_info.codes_info[0].cmdline_params == cmdline_params


@pytest.mark.parametrize('flag_name', ['npool', 'nk', 'nband', 'nb', 'ntg', 'nt', 'northo', 'ndiag', 'nd'])
def test_pw_parallelization_deprecation(fixture_sandbox, generate_calc_job, generate_inputs_pw, flag_name):
"""Test the deprecation warning on specifying parallelization flags manually.

    Test that passing parallelization flags in `settings['CMDLINE']`
emits an `AiidaDeprecationWarning`.
"""
entry_point_name = 'quantumespresso.pw'

inputs = generate_inputs_pw()
extra_cmdline_args = [f'-{flag_name}', '2']
inputs['settings'] = orm.Dict(dict={'CMDLINE': extra_cmdline_args})
with pytest.warns(AiidaDeprecationWarning) as captured_warnings:
calc_info = generate_calc_job(fixture_sandbox, entry_point_name, inputs)
assert calc_info.codes_info[0].cmdline_params == extra_cmdline_args + ['-in', 'aiida.in']
assert any('parallelization flags' in str(warning.message) for warning in captured_warnings.list)


def test_pw_parallelization_conflict_error(fixture_sandbox, generate_calc_job, generate_inputs_pw):
"""Test conflict between `settings['CMDLINE']` and `parallelization`.

Test that passing the same parallelization flag (modulo aliases)
manually in `settings['CMDLINE']` and in the `parallelization`
input raises an `InputValidationError`.
"""
entry_point_name = 'quantumespresso.pw'

inputs = generate_inputs_pw()
extra_cmdline_args = ['-nk', '2']
inputs['settings'] = orm.Dict(dict={'CMDLINE': extra_cmdline_args})
inputs['parallelization'] = orm.Dict(dict={'npool': 2})
with pytest.raises(InputValidationError) as exc:
generate_calc_job(fixture_sandbox, entry_point_name, inputs)
assert 'conflicts' in str(exc.value)


def test_pw_parallelization_incorrect_flag(fixture_sandbox, generate_calc_job, generate_inputs_pw):
"""Test that passing a non-existing parallelization flag raises.

    Test that specifying a non-existing parallelization flag in
the `parallelization` `Dict` raises a `ValueError`.
"""
entry_point_name = 'quantumespresso.pw'

inputs = generate_inputs_pw()
inputs['parallelization'] = orm.Dict(dict={'invalid_flag_name': 2})
with pytest.raises(ValueError) as exc:
generate_calc_job(fixture_sandbox, entry_point_name, inputs)
assert 'Unknown' in str(exc.value)


def test_pw_parallelization_incorrect_value(fixture_sandbox, generate_calc_job, generate_inputs_pw):
"""Test that passing a non-integer parallelization flag raises.

    Test that specifying a non-integer parallelization flag value in
the `parallelization` `Dict` raises a `ValueError`.
"""
entry_point_name = 'quantumespresso.pw'

inputs = generate_inputs_pw()
inputs['parallelization'] = orm.Dict(dict={'npool': 2.2})
with pytest.raises(ValueError) as exc:
generate_calc_job(fixture_sandbox, entry_point_name, inputs)
assert 'integer' in str(exc.value)


def test_pw_parallelization_duplicate_cmdline_flag(fixture_sandbox, generate_calc_job, generate_inputs_pw):
"""Test that passing two different aliases to the same parallelization flag raises."""
entry_point_name = 'quantumespresso.pw'

inputs = generate_inputs_pw()
inputs['settings'] = orm.Dict(dict={'CMDLINE': ['-nk', '2', '-npools', '2']})
with pytest.raises(InputValidationError) as exc:
generate_calc_job(fixture_sandbox, entry_point_name, inputs)
assert 'Conflicting' in str(exc.value)
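
As a closing illustration of what these tests exercise, a standalone sketch of the validator contract (the `validate` helper is hypothetical; AiiDA port validators return an error string on failure and `None` on success, which the engine surfaces as the `ValueError` expected above):

    import numbers

    # Hypothetical stand-in for cls._ENABLED_PARALLELIZATION_FLAGS of PwCalculation.
    ENABLED_FLAGS = ('npool', 'nband', 'ntg', 'ndiag')

    def validate(value_dict):
        """Return an error message for an invalid parallelization dict, or None if valid."""
        unknown = set(value_dict) - set(ENABLED_FLAGS)
        if unknown:
            return f"Unknown flags in 'parallelization': {unknown}, allowed flags are {ENABLED_FLAGS}."
        invalid = [val for val in value_dict.values() if not isinstance(val, numbers.Integral)]
        if invalid:
            return f'Parallelization values must be integers; got invalid values {invalid}.'
        return None

    assert validate({'npool': 4}) is None
    assert 'Unknown' in validate({'invalid_flag_name': 2})
    assert 'integers' in validate({'npool': 2.2})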