Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ProtocolMixin: increase flexibility and usability for other packages #678

Merged
merged 2 commits into from
May 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions aiida_quantumespresso/workflows/pdos.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,13 @@ def define(cls, spec):
spec.expose_outputs(DosCalculation, namespace='dos')
spec.expose_outputs(ProjwfcCalculation, namespace='projwfc')

@classmethod
def get_protocol_filepath(cls):
"""Return ``pathlib.Path`` to the ``.yaml`` file that defines the protocols."""
from importlib_resources import files
from . import protocols
return files(protocols) / 'pdos.yaml'

@classmethod
def get_builder_from_protocol(
cls, pw_code, dos_code, projwfc_code, structure, protocol=None, overrides=None, **kwargs
Expand All @@ -324,11 +331,10 @@ def get_builder_from_protocol(
:return: a process builder instance with all inputs defined ready for launch.
"""

args = (pw_code, structure, protocol)

inputs = cls.get_protocol_inputs(protocol, overrides)
builder = cls.get_builder()

args = (pw_code, structure, protocol)
scf = PwBaseWorkChain.get_builder_from_protocol(*args, overrides=inputs.get('scf', None), **kwargs)
scf['pw'].pop('structure', None)
scf.pop('clean_workdir', None)
Expand All @@ -338,6 +344,7 @@ def get_builder_from_protocol(
nscf['pw']['parameters']['SYSTEM'].pop('degauss', None)
nscf.pop('clean_workdir', None)

builder = cls.get_builder()
builder.structure = structure
builder.clean_workdir = orm.Bool(inputs['clean_workdir'])
builder.scf = scf
Expand Down
Empty file.
70 changes: 37 additions & 33 deletions aiida_quantumespresso/workflows/protocols/utils.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,49 @@
# -*- coding: utf-8 -*-
"""Utilities to manipulate the workflow input protocols."""
import functools
import os
import pathlib
from typing import Optional, Union
import yaml

from aiida.plugins import DataFactory, GroupFactory

StructureData = DataFactory('structure')
PseudoPotentialFamily = GroupFactory('pseudo.family')


class ProtocolMixin:
"""Utility class for processes to build input mappings for a given protocol based on a YAML configuration file."""

@classmethod
def get_default_protocol(cls):
def get_protocol_filepath(cls) -> pathlib.Path:
"""Return the ``pathlib.Path`` to the ``.yaml`` file that defines the protocols."""
raise NotImplementedError

@classmethod
def get_default_protocol(cls) -> str:
"""Return the default protocol for a given workflow class.

:param cls: the workflow class.
:return: the default protocol.
"""
return load_protocol_file(cls)['default_protocol']
return cls._load_protocol_file()['default_protocol']

@classmethod
def get_available_protocols(cls):
def get_available_protocols(cls) -> dict:
"""Return the available protocols for a given workflow class.

:param cls: the workflow class.
:return: dictionary of available protocols, where each key is a protocol and value is another dictionary that
contains at least the key `description` and optionally other keys with supplementary information.
"""
data = load_protocol_file(cls)
data = cls._load_protocol_file()
return {protocol: {'description': values['description']} for protocol, values in data['protocols'].items()}

@classmethod
def get_protocol_inputs(cls, protocol=None, overrides=None):
def get_protocol_inputs(
cls,
protocol: Optional[dict] = None,
overrides: Union[dict, pathlib.Path, None] = None,
) -> dict:
"""Return the inputs for the given workflow class and protocol.

:param cls: the workflow class.
Expand All @@ -39,7 +52,7 @@ def get_protocol_inputs(cls, protocol=None, overrides=None):
maintain the exact same nesting structure as the input port namespace of the corresponding workflow class.
:return: mapping of inputs to be used for the workflow class.
"""
data = load_protocol_file(cls)
data = cls._load_protocol_file()
protocol = protocol or data['default_protocol']

try:
Expand All @@ -48,17 +61,26 @@ def get_protocol_inputs(cls, protocol=None, overrides=None):
raise ValueError(
f'`{protocol}` is not a valid protocol. Call ``get_available_protocols`` to show available protocols.'
) from exception

inputs = recursive_merge(data['default_inputs'], protocol_inputs)
inputs.pop('description')

if isinstance(overrides, pathlib.Path):
with overrides.open() as file:
overrides = yaml.safe_load(file)

if overrides:
return recursive_merge(inputs, overrides)

return inputs

@classmethod
def _load_protocol_file(cls) -> dict:
"""Return the contents of the protocol file for workflow class."""
with cls.get_protocol_filepath().open() as file:
return yaml.safe_load(file)

def recursive_merge(left, right):

def recursive_merge(left: dict, right: dict) -> dict:
"""Recursively merge two dictionaries into a single dictionary.

If any key is present in both ``left`` and ``right`` dictionaries, the value from the ``right`` dictionary is
Expand All @@ -84,28 +106,6 @@ def recursive_merge(left, right):
return merged


def load_protocol_file(cls):
"""Load the protocol file for the given workflow class.

:param cls: the workflow class.
:return: the contents of the protocol file.
"""
from aiida.plugins.entry_point import get_entry_point_from_class

_, entry_point = get_entry_point_from_class(cls.__module__, cls.__name__)
entry_point_name = entry_point.name
parts = entry_point_name.split('.')
parts.pop(0)
filename = f'{parts.pop()}.yaml'
try:
basepath = functools.reduce(os.path.join, parts)
except TypeError:
basepath = '.'

with (pathlib.Path(__file__).resolve().parent / basepath / filename).open() as handle:
return yaml.safe_load(handle)


def get_magnetization_parameters() -> dict:
"""Return the mapping of suggested initial magnetic moments for each element.

Expand All @@ -115,7 +115,11 @@ def get_magnetization_parameters() -> dict:
return yaml.safe_load(handle)


def get_starting_magnetization(structure, pseudo_family, initial_magnetic_moments=None):
def get_starting_magnetization(
structure: StructureData,
pseudo_family: PseudoPotentialFamily,
initial_magnetic_moments: Optional[dict] = None
) -> dict:
"""Return the dictionary with starting magnetization for each kind in the structure.

:param structure: the structure.
Expand Down
11 changes: 9 additions & 2 deletions aiida_quantumespresso/workflows/pw/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ def define(cls, spec):
help='The computed band structure.')
# yapf: enable

@classmethod
def get_protocol_filepath(cls):
"""Return ``pathlib.Path`` to the ``.yaml`` file that defines the protocols."""
from importlib_resources import files
from ..protocols import pw as pw_protocols
return files(pw_protocols) / 'bands.yaml'

@classmethod
def get_builder_from_protocol(cls, code, structure, protocol=None, overrides=None, **kwargs):
"""Return a builder prepopulated with inputs selected according to the chosen protocol.
Expand All @@ -124,10 +131,9 @@ def get_builder_from_protocol(cls, code, structure, protocol=None, overrides=Non
sub processes that are called by this workchain.
:return: a process builder instance with all inputs defined ready for launch.
"""
args = (code, structure, protocol)
inputs = cls.get_protocol_inputs(protocol, overrides)
builder = cls.get_builder()

args = (code, structure, protocol)
relax = PwRelaxWorkChain.get_builder_from_protocol(*args, overrides=inputs.get('relax', None), **kwargs)
scf = PwBaseWorkChain.get_builder_from_protocol(*args, overrides=inputs.get('scf', None), **kwargs)
bands = PwBaseWorkChain.get_builder_from_protocol(*args, overrides=inputs.get('bands', None), **kwargs)
Expand All @@ -142,6 +148,7 @@ def get_builder_from_protocol(cls, code, structure, protocol=None, overrides=Non
bands.pop('kpoints_distance', None)
bands.pop('kpoints_force_parity', None)

builder = cls.get_builder()
builder.structure = structure
builder.relax = relax
builder.scf = scf
Expand Down
9 changes: 8 additions & 1 deletion aiida_quantumespresso/workflows/pw/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,13 @@ def define(cls, spec):
message='Then ionic minimization cycle converged but the thresholds are exceeded in the final SCF.')
# yapf: enable

@classmethod
def get_protocol_filepath(cls):
"""Return ``pathlib.Path`` to the ``.yaml`` file that defines the protocols."""
from importlib_resources import files
from ..protocols import pw as pw_protocols
return files(pw_protocols) / 'base.yaml'

@classmethod
def get_builder_from_protocol(
cls,
Expand Down Expand Up @@ -160,7 +167,6 @@ def get_builder_from_protocol(
if initial_magnetic_moments is not None and spin_type is not SpinType.COLLINEAR:
raise ValueError(f'`initial_magnetic_moments` is specified but spin type `{spin_type}` is incompatible.')

builder = cls.get_builder()
inputs = cls.get_protocol_inputs(protocol, overrides)

meta_parameters = inputs.pop('meta_parameters')
Expand Down Expand Up @@ -202,6 +208,7 @@ def get_builder_from_protocol(
parameters['SYSTEM']['starting_magnetization'] = starting_magnetization

# pylint: disable=no-member
builder = cls.get_builder()
builder.pw['code'] = code
builder.pw['pseudos'] = pseudo_family.get_pseudos(structure=structure)
builder.pw['structure'] = structure
Expand Down
11 changes: 9 additions & 2 deletions aiida_quantumespresso/workflows/pw/relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ def define(cls, spec):
help='The successfully relaxed structure.')
# yapf: enable

@classmethod
def get_protocol_filepath(cls):
"""Return ``pathlib.Path`` to the ``.yaml`` file that defines the protocols."""
from importlib_resources import files
from ..protocols import pw as pw_protocols
return files(pw_protocols) / 'relax.yaml'

@classmethod
def get_builder_from_protocol(
cls, code, structure, protocol=None, overrides=None, relax_type=RelaxType.POSITIONS_CELL, **kwargs
Expand All @@ -112,10 +119,9 @@ def get_builder_from_protocol(
"""
type_check(relax_type, RelaxType)

args = (code, structure, protocol)
inputs = cls.get_protocol_inputs(protocol, overrides)
builder = cls.get_builder()

args = (code, structure, protocol)
base = PwBaseWorkChain.get_builder_from_protocol(*args, overrides=inputs.get('base', None), **kwargs)
base_final_scf = PwBaseWorkChain.get_builder_from_protocol(
*args, overrides=inputs.get('base_final_scf', None), **kwargs
Expand Down Expand Up @@ -153,6 +159,7 @@ def get_builder_from_protocol(
if relax_type in (RelaxType.CELL, RelaxType.POSITIONS_CELL):
base.pw.parameters['CELL']['cell_dofree'] = 'all'

builder = cls.get_builder()
builder.base = base
builder.base_final_scf = base_final_scf
builder.structure = structure
Expand Down
3 changes: 2 additions & 1 deletion setup.json
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@
"packaging",
"qe-tools~=2.0rc1",
"xmlschema~=1.2,>=1.2.5",
"numpy"
"numpy",
"importlib_resources"
],
"license": "MIT License",
"name": "aiida_quantumespresso",
Expand Down