diff --git a/aiida_quantumespresso/workflows/pdos.py b/aiida_quantumespresso/workflows/pdos.py index 231cabe3c..f23daf8b6 100644 --- a/aiida_quantumespresso/workflows/pdos.py +++ b/aiida_quantumespresso/workflows/pdos.py @@ -307,6 +307,13 @@ def define(cls, spec): spec.expose_outputs(DosCalculation, namespace='dos') spec.expose_outputs(ProjwfcCalculation, namespace='projwfc') + @classmethod + def get_protocol_filepath(cls): + """Return ``pathlib.Path`` to the ``.yaml`` file that defines the protocols.""" + from importlib_resources import files + from . import protocols + return files(protocols) / 'pdos.yaml' + @classmethod def get_builder_from_protocol( cls, pw_code, dos_code, projwfc_code, structure, protocol=None, overrides=None, **kwargs @@ -324,11 +331,10 @@ def get_builder_from_protocol( :return: a process builder instance with all inputs defined ready for launch. """ - args = (pw_code, structure, protocol) inputs = cls.get_protocol_inputs(protocol, overrides) - builder = cls.get_builder() + args = (pw_code, structure, protocol) scf = PwBaseWorkChain.get_builder_from_protocol(*args, overrides=inputs.get('scf', None), **kwargs) scf['pw'].pop('structure', None) scf.pop('clean_workdir', None) @@ -338,6 +344,7 @@ def get_builder_from_protocol( nscf['pw']['parameters']['SYSTEM'].pop('degauss', None) nscf.pop('clean_workdir', None) + builder = cls.get_builder() builder.structure = structure builder.clean_workdir = orm.Bool(inputs['clean_workdir']) builder.scf = scf diff --git a/aiida_quantumespresso/workflows/protocols/pw/__init__.py b/aiida_quantumespresso/workflows/protocols/pw/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/aiida_quantumespresso/workflows/protocols/utils.py b/aiida_quantumespresso/workflows/protocols/utils.py index 709c01cef..2bfe3b509 100644 --- a/aiida_quantumespresso/workflows/protocols/utils.py +++ b/aiida_quantumespresso/workflows/protocols/utils.py @@ -1,36 +1,49 @@ # -*- coding: utf-8 -*- """Utilities to manipulate the workflow input protocols.""" -import functools -import os import pathlib +from typing import Optional, Union import yaml +from aiida.plugins import DataFactory, GroupFactory + +StructureData = DataFactory('structure') +PseudoPotentialFamily = GroupFactory('pseudo.family') + class ProtocolMixin: """Utility class for processes to build input mappings for a given protocol based on a YAML configuration file.""" @classmethod - def get_default_protocol(cls): + def get_protocol_filepath(cls) -> pathlib.Path: + """Return the ``pathlib.Path`` to the ``.yaml`` file that defines the protocols.""" + raise NotImplementedError + + @classmethod + def get_default_protocol(cls) -> str: """Return the default protocol for a given workflow class. :param cls: the workflow class. :return: the default protocol. """ - return load_protocol_file(cls)['default_protocol'] + return cls._load_protocol_file()['default_protocol'] @classmethod - def get_available_protocols(cls): + def get_available_protocols(cls) -> dict: """Return the available protocols for a given workflow class. :param cls: the workflow class. :return: dictionary of available protocols, where each key is a protocol and value is another dictionary that contains at least the key `description` and optionally other keys with supplementary information. """ - data = load_protocol_file(cls) + data = cls._load_protocol_file() return {protocol: {'description': values['description']} for protocol, values in data['protocols'].items()} @classmethod - def get_protocol_inputs(cls, protocol=None, overrides=None): + def get_protocol_inputs( + cls, + protocol: Optional[dict] = None, + overrides: Union[dict, pathlib.Path, None] = None, + ) -> dict: """Return the inputs for the given workflow class and protocol. :param cls: the workflow class. @@ -39,7 +52,7 @@ def get_protocol_inputs(cls, protocol=None, overrides=None): maintain the exact same nesting structure as the input port namespace of the corresponding workflow class. :return: mapping of inputs to be used for the workflow class. """ - data = load_protocol_file(cls) + data = cls._load_protocol_file() protocol = protocol or data['default_protocol'] try: @@ -48,17 +61,26 @@ def get_protocol_inputs(cls, protocol=None, overrides=None): raise ValueError( f'`{protocol}` is not a valid protocol. Call ``get_available_protocols`` to show available protocols.' ) from exception - inputs = recursive_merge(data['default_inputs'], protocol_inputs) inputs.pop('description') + if isinstance(overrides, pathlib.Path): + with overrides.open() as file: + overrides = yaml.safe_load(file) + if overrides: return recursive_merge(inputs, overrides) return inputs + @classmethod + def _load_protocol_file(cls) -> dict: + """Return the contents of the protocol file for workflow class.""" + with cls.get_protocol_filepath().open() as file: + return yaml.safe_load(file) -def recursive_merge(left, right): + +def recursive_merge(left: dict, right: dict) -> dict: """Recursively merge two dictionaries into a single dictionary. If any key is present in both ``left`` and ``right`` dictionaries, the value from the ``right`` dictionary is @@ -84,28 +106,6 @@ def recursive_merge(left, right): return merged -def load_protocol_file(cls): - """Load the protocol file for the given workflow class. - - :param cls: the workflow class. - :return: the contents of the protocol file. - """ - from aiida.plugins.entry_point import get_entry_point_from_class - - _, entry_point = get_entry_point_from_class(cls.__module__, cls.__name__) - entry_point_name = entry_point.name - parts = entry_point_name.split('.') - parts.pop(0) - filename = f'{parts.pop()}.yaml' - try: - basepath = functools.reduce(os.path.join, parts) - except TypeError: - basepath = '.' - - with (pathlib.Path(__file__).resolve().parent / basepath / filename).open() as handle: - return yaml.safe_load(handle) - - def get_magnetization_parameters() -> dict: """Return the mapping of suggested initial magnetic moments for each element. @@ -115,7 +115,11 @@ def get_magnetization_parameters() -> dict: return yaml.safe_load(handle) -def get_starting_magnetization(structure, pseudo_family, initial_magnetic_moments=None): +def get_starting_magnetization( + structure: StructureData, + pseudo_family: PseudoPotentialFamily, + initial_magnetic_moments: Optional[dict] = None +) -> dict: """Return the dictionary with starting magnetization for each kind in the structure. :param structure: the structure. diff --git a/aiida_quantumespresso/workflows/pw/bands.py b/aiida_quantumespresso/workflows/pw/bands.py index 5940f60ac..1a68f46ce 100644 --- a/aiida_quantumespresso/workflows/pw/bands.py +++ b/aiida_quantumespresso/workflows/pw/bands.py @@ -112,6 +112,13 @@ def define(cls, spec): help='The computed band structure.') # yapf: enable + @classmethod + def get_protocol_filepath(cls): + """Return ``pathlib.Path`` to the ``.yaml`` file that defines the protocols.""" + from importlib_resources import files + from ..protocols import pw as pw_protocols + return files(pw_protocols) / 'bands.yaml' + @classmethod def get_builder_from_protocol(cls, code, structure, protocol=None, overrides=None, **kwargs): """Return a builder prepopulated with inputs selected according to the chosen protocol. @@ -124,10 +131,9 @@ def get_builder_from_protocol(cls, code, structure, protocol=None, overrides=Non sub processes that are called by this workchain. :return: a process builder instance with all inputs defined ready for launch. """ - args = (code, structure, protocol) inputs = cls.get_protocol_inputs(protocol, overrides) - builder = cls.get_builder() + args = (code, structure, protocol) relax = PwRelaxWorkChain.get_builder_from_protocol(*args, overrides=inputs.get('relax', None), **kwargs) scf = PwBaseWorkChain.get_builder_from_protocol(*args, overrides=inputs.get('scf', None), **kwargs) bands = PwBaseWorkChain.get_builder_from_protocol(*args, overrides=inputs.get('bands', None), **kwargs) @@ -142,6 +148,7 @@ def get_builder_from_protocol(cls, code, structure, protocol=None, overrides=Non bands.pop('kpoints_distance', None) bands.pop('kpoints_force_parity', None) + builder = cls.get_builder() builder.structure = structure builder.relax = relax builder.scf = scf diff --git a/aiida_quantumespresso/workflows/pw/base.py b/aiida_quantumespresso/workflows/pw/base.py index ad974d171..4ecbf35f4 100644 --- a/aiida_quantumespresso/workflows/pw/base.py +++ b/aiida_quantumespresso/workflows/pw/base.py @@ -117,6 +117,13 @@ def define(cls, spec): message='Then ionic minimization cycle converged but the thresholds are exceeded in the final SCF.') # yapf: enable + @classmethod + def get_protocol_filepath(cls): + """Return ``pathlib.Path`` to the ``.yaml`` file that defines the protocols.""" + from importlib_resources import files + from ..protocols import pw as pw_protocols + return files(pw_protocols) / 'base.yaml' + @classmethod def get_builder_from_protocol( cls, @@ -160,7 +167,6 @@ def get_builder_from_protocol( if initial_magnetic_moments is not None and spin_type is not SpinType.COLLINEAR: raise ValueError(f'`initial_magnetic_moments` is specified but spin type `{spin_type}` is incompatible.') - builder = cls.get_builder() inputs = cls.get_protocol_inputs(protocol, overrides) meta_parameters = inputs.pop('meta_parameters') @@ -202,6 +208,7 @@ def get_builder_from_protocol( parameters['SYSTEM']['starting_magnetization'] = starting_magnetization # pylint: disable=no-member + builder = cls.get_builder() builder.pw['code'] = code builder.pw['pseudos'] = pseudo_family.get_pseudos(structure=structure) builder.pw['structure'] = structure diff --git a/aiida_quantumespresso/workflows/pw/relax.py b/aiida_quantumespresso/workflows/pw/relax.py index 6428b1acf..561d329b0 100644 --- a/aiida_quantumespresso/workflows/pw/relax.py +++ b/aiida_quantumespresso/workflows/pw/relax.py @@ -95,6 +95,13 @@ def define(cls, spec): help='The successfully relaxed structure.') # yapf: enable + @classmethod + def get_protocol_filepath(cls): + """Return ``pathlib.Path`` to the ``.yaml`` file that defines the protocols.""" + from importlib_resources import files + from ..protocols import pw as pw_protocols + return files(pw_protocols) / 'relax.yaml' + @classmethod def get_builder_from_protocol( cls, code, structure, protocol=None, overrides=None, relax_type=RelaxType.POSITIONS_CELL, **kwargs @@ -112,10 +119,9 @@ def get_builder_from_protocol( """ type_check(relax_type, RelaxType) - args = (code, structure, protocol) inputs = cls.get_protocol_inputs(protocol, overrides) - builder = cls.get_builder() + args = (code, structure, protocol) base = PwBaseWorkChain.get_builder_from_protocol(*args, overrides=inputs.get('base', None), **kwargs) base_final_scf = PwBaseWorkChain.get_builder_from_protocol( *args, overrides=inputs.get('base_final_scf', None), **kwargs @@ -153,6 +159,7 @@ def get_builder_from_protocol( if relax_type in (RelaxType.CELL, RelaxType.POSITIONS_CELL): base.pw.parameters['CELL']['cell_dofree'] = 'all' + builder = cls.get_builder() builder.base = base builder.base_final_scf = base_final_scf builder.structure = structure diff --git a/setup.json b/setup.json index 92efdf62f..fbbe8cba2 100644 --- a/setup.json +++ b/setup.json @@ -97,7 +97,8 @@ "packaging", "qe-tools~=2.0rc1", "xmlschema~=1.2,>=1.2.5", - "numpy" + "numpy", + "importlib_resources" ], "license": "MIT License", "name": "aiida_quantumespresso",