diff --git a/aiida_quantumespresso/calculations/__init__.py b/aiida_quantumespresso/calculations/__init__.py index 5bae32f4a..f146e3031 100644 --- a/aiida_quantumespresso/calculations/__init__.py +++ b/aiida_quantumespresso/calculations/__init__.py @@ -2,6 +2,9 @@ """Base `CalcJob` for implementations for pw.x and cp.x of Quantum ESPRESSO.""" import abc import os +import warnings +from functools import partial +from types import MappingProxyType from aiida import orm from aiida.common import datastructures, exceptions @@ -27,6 +30,39 @@ class BasePwCpInputGenerator(CalcJob): _ENVIRON_INPUT_FILE_NAME = 'environ.in' _DEFAULT_IBRAV = 0 + # A mapping {flag_name: (default, help)} of parallelization flags + # possible in QE codes. The flags that are actually implemented in a + # given code should be specified in the '_ENABLED_PARALLELIZATION_FLAGS' + # tuple of each calculation subclass. + _PARALLELIZATION_FLAGS = MappingProxyType( + dict( + nimage=( + None, "The number of 'images', each corresponding to a different self-consistent or " + 'linear-response calculation.' + ), + npool=(None, "The number of 'pools', each taking care of a group of k-points."), + nband=(None, "The number of 'band groups', each taking care of a group of Kohn-Sham orbitals."), + ntg=(None, "The number of 'task groups' across which the FFT planes are distributed."), + ndiag=( + None, "The number of 'linear algebra groups' used when parallelizing the subspace " + 'diagonalization / iterative orthonormalization. By default, no parameter is ' + 'passed to Quantum ESPRESSO, meaning it will use its default.' 
+ ) + ) + ) + + _ENABLED_PARALLELIZATION_FLAGS = tuple() + + _PARALLELIZATION_FLAG_ALIASES = MappingProxyType( + dict( + nimage=('ni', 'nimages', 'npot'), + npool=('nk', 'npools'), + nband=('nb', 'nbgrp', 'nband_group'), + ntg=('nt', 'ntask_groups', 'nyfft'), + ndiag=('northo', 'nd', 'nproc_diag', 'nproc_ortho') + ) + ) + # Additional files that should always be retrieved for the specific plugin _internal_retrieve_list = [] @@ -85,6 +121,26 @@ def define(cls, spec): spec.input_namespace('pseudos', valid_type=orm.UpfData, dynamic=True, help='A mapping of `UpfData` nodes onto the kind name to which they should apply.') + for flag_name in cls._ENABLED_PARALLELIZATION_FLAGS: + try: + default, help_ = cls._PARALLELIZATION_FLAGS[flag_name] + except KeyError as exc: + raise KeyError( + f"The parallelization flag '{flag_name}' specified in _ENABLED_PARALLELIZATION_FLAGS " + 'does not exist in _PARALLELIZATION_FLAGS. Please report this issue to the calculation ' + 'plugin developer.' + ) from exc + + if default is None: + extra_kwargs = {'required': False} + else: + # We use 'functools.partial' here because lambda: `orm.Int(default)` would + # bind to the name `default`, and look up its value only when the lambda is + # actually called. Because `default` changes throughout the for-loop, this + # would give the wrong value. + extra_kwargs = {'default': partial(orm.Int, default)} + spec.input(f'parallelization.{flag_name}', help=help_, **extra_kwargs) + def prepare_for_submission(self, folder): """Create the input files from the input nodes passed to this instance of the `CalcJob`. @@ -202,6 +258,35 @@ def prepare_for_submission(self, folder): calcinfo.uuid = str(self.uuid) # Empty command line by default cmdline_params = settings.pop('CMDLINE', []) + + # Add the parallelization flags. + # The `cmdline_params_normalized` are used only here to check + # for existing parallelization flags. 
+ cmdline_params_normalized = [] + for param in cmdline_params: + cmdline_params_normalized.extend(param.split()) + # To make the order of flags consistent and "nice", we use the + # ordering from the flag definition. + for flag_name in self._ENABLED_PARALLELIZATION_FLAGS: + # We check for existing inputs for all possible flag names, + # to make sure a `DeprecationWarning` is emitted whenever + # a flag is specified in the `settings['CMDLINE']`. + all_aliases = list(self._PARALLELIZATION_FLAG_ALIASES[flag_name]) + [flag_name] + if any(f'-{alias}' in cmdline_params_normalized for alias in all_aliases): + # To preserve backwards compatibility, we ignore the + # `parallelization.flag_name` input if the same flag is + # already present in `cmdline_params`. This is necessary + # because the new inputs can be set by their default, + # in user code that doesn't explicitly specify them. + warnings.warn( + "Specifying the parallelization flags through settings['CMDLINE'] is " + "deprecated, use the 'parallelization' input namespace instead.", DeprecationWarning + ) + continue + if flag_name in self.inputs.parallelization: + flag_value = self.inputs.parallelization[flag_name].value + cmdline_params += [f'-{flag_name}', str(flag_value)] + # we commented calcinfo.stin_name and added it here in cmdline_params # in this way the mpirun ... pw.x ... < aiida.in # is replaced by mpirun ... pw.x ... 
-in aiida.in diff --git a/aiida_quantumespresso/calculations/pw.py b/aiida_quantumespresso/calculations/pw.py index 1bfe93dfe..a27503551 100644 --- a/aiida_quantumespresso/calculations/pw.py +++ b/aiida_quantumespresso/calculations/pw.py @@ -43,6 +43,8 @@ class PwCalculation(BasePwCpInputGenerator): # Not using symlink in pw to allow multiple nscf to run on top of the same scf _default_symlink_usage = False + _ENABLED_PARALLELIZATION_FLAGS = ('nimage', 'npool', 'nband', 'ntg', 'ndiag') + @classproperty def xml_filepaths(cls): """Return a list of XML output filepaths relative to the remote working directory that should be retrieved.""" diff --git a/tests/calculations/test_pw.py b/tests/calculations/test_pw.py index 604a544ba..0f3606137 100644 --- a/tests/calculations/test_pw.py +++ b/tests/calculations/test_pw.py @@ -24,7 +24,7 @@ def test_pw_default(fixture_sandbox, generate_calc_job, generate_inputs_pw, file # Check the attributes of the returned `CalcInfo` assert isinstance(calc_info, datastructures.CalcInfo) - assert sorted(calc_info.cmdline_params) == sorted(cmdline_params) + assert calc_info.cmdline_params == cmdline_params assert sorted(calc_info.local_copy_list) == sorted(local_copy_list) assert sorted(calc_info.retrieve_list) == sorted(retrieve_list) assert sorted(calc_info.retrieve_temporary_list) == sorted(retrieve_temporary_list) @@ -155,3 +155,42 @@ def test_pw_ibrav_tol(fixture_sandbox, generate_calc_job, fixture_code, generate # After adjusting the tolerance, the input validation no longer fails. 
     inputs['settings'] = orm.Dict(dict={'ibrav_cell_tolerance': eps})
     generate_calc_job(fixture_sandbox, entry_point_name, inputs)
+
+
+def test_pw_parallelization_inputs(fixture_sandbox, generate_calc_job, generate_inputs_pw):
+    """Test that the parallelization settings are set correctly in the commandline params."""
+    entry_point_name = 'quantumespresso.pw'
+
+    inputs = generate_inputs_pw()
+    inputs['parallelization'] = {
+        'nimage': orm.Int(1),
+        'npool': orm.Int(4),
+        'nband': orm.Int(2),
+        'ntg': orm.Int(3),
+        'ndiag': orm.Int(12)
+    }
+    calc_info = generate_calc_job(fixture_sandbox, entry_point_name, inputs)
+
+    cmdline_params = ['-nimage', '1', '-npool', '4', '-nband', '2', '-ntg', '3', '-ndiag', '12', '-in', 'aiida.in']
+
+    # Check that the command-line parameters are as expected.
+    assert calc_info.cmdline_params == cmdline_params
+
+
+@pytest.mark.parametrize(
+    'flag_name', ['nimage', 'ni', 'npool', 'nk', 'nband', 'nb', 'ntg', 'nt', 'northo', 'ndiag', 'nd']
+)
+def test_pw_parallelization_deprecation(fixture_sandbox, generate_calc_job, generate_inputs_pw, flag_name):
+    """Test the deprecation warning on specifying parallelization flags manually.
+
+    Test that passing parallelization flags in `settings['CMDLINE']`
+    emits a `DeprecationWarning`.
+    """
+    entry_point_name = 'quantumespresso.pw'
+
+    inputs = generate_inputs_pw()
+    extra_cmdline_args = [f'-{flag_name}', '2']
+    inputs['settings'] = orm.Dict(dict={'CMDLINE': extra_cmdline_args})
+    with pytest.warns(DeprecationWarning):
+        calc_info = generate_calc_job(fixture_sandbox, entry_point_name, inputs)
+    assert calc_info.cmdline_params == extra_cmdline_args + ['-in', 'aiida.in']