From 333ce3c31ac3c2f640fc101decba6b4d97503c62 Mon Sep 17 00:00:00 2001 From: Dominik Gresch Date: Fri, 18 Sep 2020 17:52:37 +0200 Subject: [PATCH] Implement parallelization options as explicit inputs. If any parallelization option is already specified in the `settings['CMDLINE']`, a `AiidaDeprecationWarning` is issued. However, the value from `settings['CMDLINE']` still takes precedence. This allows us to implement the defaults at the level of the input port, without having to care about it later on. The `BasePwCpInputGenerator` is given three class attributes: - `_PARALLELIZATION_FLAGS` is a mapping `{flag_name: (default, help)}` of all possible flags. - `_ENABLED_PARALLELIZATION_FLAGS` is a tuple of flag names that are implemented in a particular code. - `_PARALLELIZATION_FLAG_ALIASES` is used to detect all possible variations on a flag name. These are taken from the QE source. When checking for existing parallelization flags in the manually passed cmdline parameters, these are normalized by splitting on whitespace. QE ignores flags that are capitalized differently, so we do not have to normalize capitalization here. TODO: - Documentation - Set `_ENABLED_PARALLELIZATION_FLAGS` for codes other than pw.x. - `pw.x` accepts the `-nimage` flag - but I'm not sure if it makes any sense. --- .../calculations/__init__.py | 86 +++++++++++++++++++ aiida_quantumespresso/calculations/pw.py | 2 + tests/calculations/test_pw.py | 43 +++++++++- 3 files changed, 130 insertions(+), 1 deletion(-) diff --git a/aiida_quantumespresso/calculations/__init__.py b/aiida_quantumespresso/calculations/__init__.py index 5bae32f4a..e05911bc6 100644 --- a/aiida_quantumespresso/calculations/__init__.py +++ b/aiida_quantumespresso/calculations/__init__.py @@ -2,10 +2,14 @@ """Base `CalcJob` for implementations for pw.x and cp.x of Quantum ESPRESSO.""" import abc import os +import warnings +from functools import partial +from types import MappingProxyType from aiida import orm from aiida.common import datastructures, exceptions from aiida.common.lang import classproperty +from aiida.common.warnings import AiidaDeprecationWarning from qe_tools.converters import get_parameters_from_cell from aiida_quantumespresso.utils.convert import convert_input_to_namelist_entry @@ -27,6 +31,39 @@ class BasePwCpInputGenerator(CalcJob): _ENVIRON_INPUT_FILE_NAME = 'environ.in' _DEFAULT_IBRAV = 0 + # A mapping {flag_name: (default, help)} of parallelization flags + # possible in QE codes. The flags that are actually implemented in a + # given code should be specified in the '_ENABLED_PARALLELIZATION_FLAGS' + # tuple of each calculation subclass. + _PARALLELIZATION_FLAGS = MappingProxyType( + dict( + nimage=( + None, "The number of 'images', each corresponding to a different self-consistent or " + 'linear-response calculation.' + ), + npool=(None, "The number of 'pools', each taking care of a group of k-points."), + nband=(None, "The number of 'band groups', each taking care of a group of Kohn-Sham orbitals."), + ntg=(None, "The number of 'task groups' across which the FFT planes are distributed."), + ndiag=( + None, "The number of 'linear algebra groups' used when parallelizing the subspace " + 'diagonalization / iterative orthonormalization. By default, no parameter is ' + 'passed to Quantum ESPRESSO, meaning it will use its default.' + ) + ) + ) + + _ENABLED_PARALLELIZATION_FLAGS = tuple() + + _PARALLELIZATION_FLAG_ALIASES = MappingProxyType( + dict( + nimage=('ni', 'nimages', 'npot'), + npool=('nk', 'npools'), + nband=('nb', 'nbgrp', 'nband_group'), + ntg=('nt', 'ntask_groups', 'nyfft'), + ndiag=('northo', 'nd', 'nproc_diag', 'nproc_ortho') + ) + ) + # Additional files that should always be retrieved for the specific plugin _internal_retrieve_list = [] @@ -85,6 +122,26 @@ def define(cls, spec): spec.input_namespace('pseudos', valid_type=orm.UpfData, dynamic=True, help='A mapping of `UpfData` nodes onto the kind name to which they should apply.') + for flag_name in cls._ENABLED_PARALLELIZATION_FLAGS: + try: + default, help_ = cls._PARALLELIZATION_FLAGS[flag_name] + except KeyError as exc: + raise KeyError( + f"The parallelization flag '{flag_name}' specified in _ENABLED_PARALLELIZATION_FLAGS " + 'does not exist in _PARALLELIZATION_FLAGS. Please report this issue to the calculation ' + 'plugin developer.' + ) from exc + + if default is None: + extra_kwargs = {'required': False} + else: + # We use 'functools.partial' here because lambda: `orm.Int(default)` would + # bind to the name `default`, and look up its value only when the lambda is + # actually called. Because `default` changes throughout the for-loop, this + # would give the wrong value. + extra_kwargs = {'default': partial(orm.Int, default)} + spec.input(f'parallelization.{flag_name}', help=help_, **extra_kwargs) + def prepare_for_submission(self, folder): """Create the input files from the input nodes passed to this instance of the `CalcJob`. @@ -202,6 +259,35 @@ def prepare_for_submission(self, folder): calcinfo.uuid = str(self.uuid) # Empty command line by default cmdline_params = settings.pop('CMDLINE', []) + + # Add the parallelization flags. + # The `cmdline_params_normalized` are used only here to check + # for existing parallelization flags. + cmdline_params_normalized = [] + for param in cmdline_params: + cmdline_params_normalized.extend(param.split()) + # To make the order of flags consistent and "nice", we use the + # ordering from the flag definition. + for flag_name in self._ENABLED_PARALLELIZATION_FLAGS: + # We check for existing inputs for all possible flag names, + # to make sure a `DeprecationWarning` is emitted whenever + # a flag is specified in the `settings['CMDLINE']`. + all_aliases = list(self._PARALLELIZATION_FLAG_ALIASES[flag_name]) + [flag_name] + if any(f'-{alias}' in cmdline_params_normalized for alias in all_aliases): + # To preserve backwards compatibility, we ignore the + # `parallelization.flag_name` input if the same flag is + # already present in `cmdline_params`. This is necessary + # because the new inputs can be set by their default, + # in user code that doesn't explicitly specify them. + warnings.warn( + "Specifying the parallelization flags through settings['CMDLINE'] is " + "deprecated, use the 'parallelization' input namespace instead.", AiidaDeprecationWarning + ) + continue + if flag_name in self.inputs.parallelization: + flag_value = self.inputs.parallelization[flag_name].value + cmdline_params += [f'-{flag_name}', str(flag_value)] + # we commented calcinfo.stin_name and added it here in cmdline_params # in this way the mpirun ... pw.x ... < aiida.in # is replaced by mpirun ... pw.x ... -in aiida.in diff --git a/aiida_quantumespresso/calculations/pw.py b/aiida_quantumespresso/calculations/pw.py index 1bfe93dfe..a27503551 100644 --- a/aiida_quantumespresso/calculations/pw.py +++ b/aiida_quantumespresso/calculations/pw.py @@ -43,6 +43,8 @@ class PwCalculation(BasePwCpInputGenerator): # Not using symlink in pw to allow multiple nscf to run on top of the same scf _default_symlink_usage = False + _ENABLED_PARALLELIZATION_FLAGS = ('nimage', 'npool', 'nband', 'ntg', 'ndiag') + @classproperty def xml_filepaths(cls): """Return a list of XML output filepaths relative to the remote working directory that should be retrieved.""" diff --git a/tests/calculations/test_pw.py b/tests/calculations/test_pw.py index 604a544ba..a0b2f5101 100644 --- a/tests/calculations/test_pw.py +++ b/tests/calculations/test_pw.py @@ -5,6 +5,7 @@ from aiida import orm from aiida.common import datastructures +from aiida.common.warnings import AiidaDeprecationWarning from aiida_quantumespresso.utils.resources import get_default_options from aiida_quantumespresso.calculations.helpers import QEInputValidationError @@ -24,7 +25,7 @@ def test_pw_default(fixture_sandbox, generate_calc_job, generate_inputs_pw, file # Check the attributes of the returned `CalcInfo` assert isinstance(calc_info, datastructures.CalcInfo) - assert sorted(calc_info.cmdline_params) == sorted(cmdline_params) + assert calc_info.cmdline_params == cmdline_params assert sorted(calc_info.local_copy_list) == sorted(local_copy_list) assert sorted(calc_info.retrieve_list) == sorted(retrieve_list) assert sorted(calc_info.retrieve_temporary_list) == sorted(retrieve_temporary_list) @@ -155,3 +156,43 @@ def test_pw_ibrav_tol(fixture_sandbox, generate_calc_job, fixture_code, generate # After adjusting the tolerance, the input validation no longer fails. inputs['settings'] = orm.Dict(dict={'ibrav_cell_tolerance': eps}) generate_calc_job(fixture_sandbox, entry_point_name, inputs) + + +def test_pw_parallelization_inputs(fixture_sandbox, generate_calc_job, generate_inputs_pw): + """Test that the parallelization settings are set correctly in the commandline params.""" + entry_point_name = 'quantumespresso.pw' + + inputs = generate_inputs_pw() + inputs['parallelization'] = { + 'nimage': orm.Int(1), + 'npool': orm.Int(4), + 'nband': orm.Int(2), + 'ntg': orm.Int(3), + 'ndiag': orm.Int(12) + } + calc_info = generate_calc_job(fixture_sandbox, entry_point_name, inputs) + + cmdline_params = ['-nimage', '1', '-npool', '4', '-nband', '2', '-ntg', '3', '-ndiag', '12', '-in', 'aiida.in'] + + # Check that the command-line parameters are as expected. + assert calc_info.cmdline_params == cmdline_params + + +@pytest.mark.parametrize( + 'flag_name', ['nimage', 'ni', 'npool', 'nk', 'nband', 'nb', 'ntg', 'nt', 'northo', 'ndiag', 'nd'] +) +def test_pw_parallelization_deprecation(fixture_sandbox, generate_calc_job, generate_inputs_pw, flag_name): + """Test the deprecation warning on specifying parallelization flags manually. + + Test that passing parallelization flags in the `settings['CMDLINE'] + emits an `AiidaDeprecationWarning`. + """ + entry_point_name = 'quantumespresso.pw' + + inputs = generate_inputs_pw() + extra_cmdline_args = [f'-{flag_name}', '2'] + inputs['settings'] = orm.Dict(dict={'CMDLINE': extra_cmdline_args}) + with pytest.warns(AiidaDeprecationWarning) as captured_warnings: + calc_info = generate_calc_job(fixture_sandbox, entry_point_name, inputs) + assert calc_info.cmdline_params == extra_cmdline_args + ['-in', 'aiida.in'] + assert any('parallelization flags' in str(warning.message) for warning in captured_warnings.list)