Implement parallelization options as explicit inputs.
If any parallelization option is already specified in
`settings['CMDLINE']`, a `DeprecationWarning` is issued; however,
the value from `settings['CMDLINE']` still takes precedence. This
allows us to implement the defaults at the level of the input
port, without having to handle them later on.
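As a usage sketch (the builder setup shown is illustrative, not part of this commit), the new explicit input replaces the deprecated `settings['CMDLINE']` route:

    from aiida import orm
    from aiida.plugins import CalculationFactory

    PwCalculation = CalculationFactory('quantumespresso.pw')
    builder = PwCalculation.get_builder()

    # New, explicit input port (ignored only if the same flag also appears
    # in settings['CMDLINE']):
    builder.parallelization.npool = orm.Int(4)

    # Old route, still honoured but now emitting a DeprecationWarning:
    # builder.settings = orm.Dict(dict={'CMDLINE': ['-npool', '4']})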

The `BasePwCpInputGenerator` is given three class attributes (see
the sketch after this list):
- `_ALLOWED_PARALLELIZATION_FLAGS` is a list of
  `(flag_name, default, help)` tuples describing all possible flags.
- `_ENABLED_PARALLELIZATION_FLAGS` is a list of the flag names that
  are implemented in a particular code.
- `_PARALLELIZATION_FLAG_ALIASES` is used to detect all possible
  variations of a flag name. These are taken from the QE source.
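A minimal sketch of how a subclass enables a subset of the flags (the class name is hypothetical; the actual `PwCalculation` change below enables all five):

    class HypotheticalCalculation(BasePwCpInputGenerator):
        # Only the flags listed here become input ports; their defaults and
        # help strings are looked up in `_ALLOWED_PARALLELIZATION_FLAGS`.
        _ENABLED_PARALLELIZATION_FLAGS = ['npool', 'ntg']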

When checking for existing parallelization flags in the manually
passed cmdline parameters, the parameters are normalized by
splitting on whitespace. QE ignores flags whose capitalization
differs, so we do not have to normalize capitalization here.
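For illustration (a sketch of the normalization step, closely following the code added below), a parameter containing both flag and value in one string is split so the flag is still detected:

    cmdline_params = ['-nk 4', '-in', 'aiida.in']
    cmdline_params_normalized = []
    for param in cmdline_params:
        cmdline_params_normalized.extend(param.split())
    # ['-nk', '4', '-in', 'aiida.in'] -> '-nk', an alias of 'npool', is found.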

TODO:
- Documentation
- Set `_ENABLED_PARALLELIZATION_FLAGS` for codes other than pw.x.
- `pw.x` accepts the `-nimage` flag - but I'm not sure if it makes
  any sense.
Dominik Gresch committed Oct 8, 2020
1 parent 7427398 commit 5e43d1e
Showing 3 changed files with 133 additions and 2 deletions.
92 changes: 92 additions & 0 deletions aiida_quantumespresso/calculations/__init__.py
@@ -2,6 +2,8 @@
"""Base `CalcJob` for implementations for pw.x and cp.x of Quantum ESPRESSO."""
import abc
import os
import warnings
from functools import partial

from aiida import orm
from aiida.common import datastructures, exceptions
@@ -24,6 +26,44 @@ class BasePwCpInputGenerator(CalcJob):
_DATAFILE_XML_POST_6_2 = 'data-file-schema.xml'
_ENVIRON_INPUT_FILE_NAME = 'environ.in'

_ENABLED_PARALLELIZATION_FLAGS = []
# Tuples (flag_name, default, help) of parallelization flags possible
# in QE codes. The flags that are actually implemented in a given
# code should be specified in the '_ENABLED_PARALLELIZATION_FLAGS'
# list of each calculation subclass.
_ALLOWED_PARALLELIZATION_FLAGS = ( # yapf: disable
(
'nimage', 1,
"The number of 'images', each corresponding to a different self-consistent or "
'linear-response calculation.'
),
(
'npool', 1,
"The number of 'pools', each taking care of a group of k-points."
),
(
'nband', 1,
"The number of 'band groups', each taking care of a group of Kohn-Sham orbitals."
),
(
'ntg', 1,
"The number of 'task groups' across which the FFT planes are distributed."
),
(
'ndiag', None,
"The number of 'linear algebra groups' used when parallelizing the subspace "
'diagonalization / iterative orthonormalization. By default, no parameter is '
'passed to Quantum ESPRESSO, meaning it will use its default.'
)
)
_PARALLELIZATION_FLAG_ALIASES = {
'nimage': ['ni', 'nimages', 'npot'],
'npool': ['nk', 'npools'],
'nband': ['nb', 'nbgrp', 'nband_group'],
'ntg': ['nt', 'ntask_groups', 'nyfft'],
'ndiag': ['northo', 'nd', 'nproc_diag', 'nproc_ortho']
}

# Additional files that should always be retrieved for the specific plugin
_internal_retrieve_list = []

@@ -82,6 +122,29 @@ def define(cls, spec):
spec.input_namespace('pseudos', valid_type=orm.UpfData, dynamic=True,
help='A mapping of `UpfData` nodes onto the kind name to which they should apply.')

_allowed_parallelization_mapping = {
flag_name: (default, help_) for flag_name, default, help_ in cls._ALLOWED_PARALLELIZATION_FLAGS
}
for flag_name in cls._ENABLED_PARALLELIZATION_FLAGS:
try:
default, help_ = _allowed_parallelization_mapping[flag_name]
except KeyError as exc:
raise KeyError(
"The parallelization flag '{}' specified in _ENABLED_PARALLELIZATION_FLAGS ".format(flag_name) +
'does not exist in _ALLOWED_PARALLELIZATION_FLAGS. Please report this issue '
'to the calculation plugin developer.'
) from exc

if default is None:
extra_kwargs = {'required': False}
else:
# We use 'functools.partial' here because lambda: `orm.Int(default)` would
# bind to the name `default`, and look up its value only when the lambda is
# actually called. Because `default` changes throughout the for-loop, this
# would give the wrong value.
extra_kwargs = {'default': partial(orm.Int, default)}
spec.input('parallelization.{}'.format(flag_name), help=help_, **extra_kwargs)

def prepare_for_submission(self, folder):
"""Create the input files from the input nodes passed to this instance of the `CalcJob`.
@@ -199,6 +262,35 @@ def prepare_for_submission(self, folder):
calcinfo.uuid = str(self.uuid)
# Empty command line by default
cmdline_params = settings.pop('CMDLINE', [])

# Add the parallelization flags.
# The `cmdline_params_normalized` are used only here to check
# for existing parallelization flags.
cmdline_params_normalized = []
for param in cmdline_params:
cmdline_params_normalized.extend(param.split())
# To make the order of flags consistent and "nice", we use the
# ordering from the flag definition.
for flag_name in self._ENABLED_PARALLELIZATION_FLAGS:
# We check for existing inputs for all possible flag names,
# to make sure a `DeprecationWarning` is emitted whenever
# a flag is specified in the `settings['CMDLINE']`.
all_aliases = self._PARALLELIZATION_FLAG_ALIASES[flag_name] + [flag_name]
if any('-{}'.format(alias) in cmdline_params_normalized for alias in all_aliases):
# To preserve backwards compatibility, we ignore the
# `parallelization.flag_name` input if the same flag is
# already present in `cmdline_params`. This is necessary
# because the new inputs can be set by their default,
# in user code that doesn't explicitly specify them.
warnings.warn(
"Specifying the parallelization flags through settings['CMDLINE'] is "
"deprecated, use the 'parallelization' input namespace instead.", DeprecationWarning
)
continue
if flag_name in self.inputs.parallelization:
flag_value = self.inputs.parallelization[flag_name].value
cmdline_params += ['-{}'.format(flag_name), str(flag_value)]

# we commented calcinfo.stin_name and added it here in cmdline_params
# in this way the mpirun ... pw.x ... < aiida.in
# is replaced by mpirun ... pw.x ... -in aiida.in
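A standalone sketch (not part of the diff) of the late-binding pitfall that the `functools.partial` comment in `define` refers to, using plain integers instead of `orm.Int` for brevity:

    from functools import partial

    defaults = [1, 4]

    # A lambda looks up `value` only when it is called, so after the loop
    # every lambda returns the last value.
    lambdas = [lambda: value for value in defaults]
    print([f() for f in lambdas])  # [4, 4]

    # functools.partial captures the argument at definition time instead.
    partials = [partial(int, value) for value in defaults]
    print([f() for f in partials])  # [1, 4]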
2 changes: 2 additions & 0 deletions aiida_quantumespresso/calculations/pw.py
@@ -44,6 +44,8 @@ class PwCalculation(BasePwCpInputGenerator):
# Not using symlink in pw to allow multiple nscf to run on top of the same scf
_default_symlink_usage = False

_ENABLED_PARALLELIZATION_FLAGS = ['nimage', 'npool', 'nband', 'ntg', 'ndiag']

@classproperty
def xml_filepaths(cls):
"""Return a list of XML output filepaths relative to the remote working directory that should be retrieved."""
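With these flags enabled and the defaults defined on the base class, a default `PwCalculation` now produces the command line checked in the test below; `ndiag` has no default, so it only appears when set explicitly:

    # Expected cmdline_params for a PwCalculation without explicit
    # parallelization inputs; '-in aiida.in' replaces stdin redirection.
    ['-nimage', '1', '-npool', '1', '-nband', '1', '-ntg', '1', '-in', 'aiida.in']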
41 changes: 39 additions & 2 deletions tests/calculations/test_pw.py
@@ -1,6 +1,10 @@
# -*- coding: utf-8 -*-
"""Tests for the `PwCalculation` class."""

import pytest

from aiida.common import datastructures
from aiida.plugins import DataFactory


def test_pw_default(fixture_sandbox, generate_calc_job, generate_inputs_pw, file_regression):
@@ -11,14 +15,14 @@ def test_pw_default(fixture_sandbox, generate_calc_job, generate_inputs_pw, file_regression):
calc_info = generate_calc_job(fixture_sandbox, entry_point_name, inputs)
upf = inputs['pseudos']['Si']

cmdline_params = ['-in', 'aiida.in']
cmdline_params = ['-nimage', '1', '-npool', '1', '-nband', '1', '-ntg', '1', '-in', 'aiida.in']
local_copy_list = [(upf.uuid, upf.filename, './pseudo/Si.upf')]
retrieve_list = ['aiida.out', './out/aiida.save/data-file-schema.xml', './out/aiida.save/data-file.xml']
retrieve_temporary_list = [['./out/aiida.save/K*[0-9]/eigenval*.xml', '.', 2]]

# Check the attributes of the returned `CalcInfo`
assert isinstance(calc_info, datastructures.CalcInfo)
assert sorted(calc_info.cmdline_params) == sorted(cmdline_params)
assert calc_info.cmdline_params == cmdline_params
assert sorted(calc_info.local_copy_list) == sorted(local_copy_list)
assert sorted(calc_info.retrieve_list) == sorted(retrieve_list)
assert sorted(calc_info.retrieve_temporary_list) == sorted(retrieve_temporary_list)
@@ -30,3 +34,36 @@ def test_pw_default(fixture_sandbox, generate_calc_job, generate_inputs_pw, file_regression):
# Checks on the files written to the sandbox folder as raw input
assert sorted(fixture_sandbox.get_content_list()) == sorted(['aiida.in', 'pseudo', 'out'])
file_regression.check(input_written, encoding='utf-8', extension='.in')


def test_pw_parallelization_inputs(fixture_sandbox, generate_calc_job, generate_inputs_pw):
"""Test that the parallelization settings are set correctly in the commandline params."""
Int = DataFactory('int') # pylint: disable=invalid-name
entry_point_name = 'quantumespresso.pw'

inputs = generate_inputs_pw()
inputs['parallelization'] = {'nimage': Int(1), 'npool': Int(4), 'nband': Int(2), 'ntg': Int(3), 'ndiag': Int(12)}
calc_info = generate_calc_job(fixture_sandbox, entry_point_name, inputs)

cmdline_params = ['-nimage', '1', '-npool', '4', '-nband', '2', '-ntg', '3', '-ndiag', '12', '-in', 'aiida.in']

# Check that the command-line parameters are as expected.
assert calc_info.cmdline_params == cmdline_params


@pytest.mark.parametrize(
'flag_name', ['nimage', 'ni', 'npool', 'nk', 'nband', 'nb', 'ntg', 'nt', 'northo', 'ndiag', 'nd']
)
def test_pw_parallelization_deprecation(fixture_sandbox, generate_calc_job, generate_inputs_pw, flag_name):
"""Test the deprecation warning on specifying parallelization flags manually.
Test that passing parallelization flags in `settings['CMDLINE']`
emits a `DeprecationWarning`.
"""
Dict = DataFactory('dict') # pylint: disable=invalid-name
entry_point_name = 'quantumespresso.pw'

inputs = generate_inputs_pw()
inputs['settings'] = Dict(dict={'CMDLINE': ['-{}'.format(flag_name), '2']})
with pytest.warns(DeprecationWarning):
generate_calc_job(fixture_sandbox, entry_point_name, inputs)
