diff --git a/pyproject.toml b/pyproject.toml index 8f93d8275..3636822a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,6 +115,7 @@ aiida-quantumespresso = 'aiida_quantumespresso.cli:cmd_root' 'quantumespresso.matdyn.base' = 'aiida_quantumespresso.workflows.matdyn.base:MatdynBaseWorkChain' 'quantumespresso.pdos' = 'aiida_quantumespresso.workflows.pdos:PdosWorkChain' 'quantumespresso.xspectra.base' = 'aiida_quantumespresso.workflows.xspectra.base:XspectraBaseWorkChain' +'quantumespresso.xspectra.core' = 'aiida_quantumespresso.workflows.xspectra.core:XspectraCoreWorkChain' [tool.flit.module] name = 'aiida_quantumespresso' diff --git a/src/aiida_quantumespresso/calculations/functions/xspectra/__init__.py b/src/aiida_quantumespresso/calculations/functions/xspectra/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/aiida_quantumespresso/calculations/functions/xspectra/get_powder_spectrum.py b/src/aiida_quantumespresso/calculations/functions/xspectra/get_powder_spectrum.py new file mode 100644 index 000000000..c152ffd30 --- /dev/null +++ b/src/aiida_quantumespresso/calculations/functions/xspectra/get_powder_spectrum.py @@ -0,0 +1,102 @@ +# -*- coding: utf-8 -*- +"""CalcFunction to compute the powder spectrum of a set of XANES spectra from ``XspectraCalculation``.""" +from aiida.common import ValidationError +from aiida.engine import calcfunction +from aiida.orm import XyData + + +@calcfunction +def get_powder_spectrum(**kwargs): # pylint: disable=too-many-statements + """Combine the given spectra into a single "Powder" spectrum, representing the XAS of a powder sample. + + The function expects between 1 and 3 XyData nodes from ``XspectraCalculation`` whose + polarisation vectors are the basis vectors of the original crystal structure (100, 010, 001). + """ + spectra = [node for node in kwargs.values() if isinstance(node, XyData)] + vectors = [node.creator.res['xepsilon'] for node in spectra] + + if len(vectors) > 3: + raise ValidationError(f'Expected between 1 and 3 XyData nodes as input, but {len(spectra)} were given.') + + # If the system is isochoric (e.g. a cubic system) then the three basis vectors are + # equal to each other, thus we simply return the + if len(vectors) == 1: + vector = vectors[0] + vector_string = f'{float(vector[0])} {float(vector[1])} {float(vector[2])}' + if vector not in [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]: + raise ValidationError( + f'Polarisation vector ({vector_string}) does not correspond to a crystal basis vector (100, 010, 001)' + ) + + powder_spectrum = spectra[0] + powder_x = powder_spectrum.get_x()[1] + powder_y = powder_spectrum.get_y()[0][1] + + powder_data = XyData() + powder_data.set_x(powder_x, 'energy', 'eV') + powder_data.set_y(powder_y, 'sigma', 'n/a') + + # if the system is dichoric (e.g. a hexagonal system) then the A and B periodic + # dimensions are equal to each other by symmetry, thus the powder spectrum is simply + # the average of 2x the 1 0 0 eps vector and 1x the 0 0 1 eps vector + if len(vectors) == 2: + # Since the individual vectors are labelled, we can extract just the spectra needed + # to produce the powder and leave the rest + + spectrum_a = None + spectrum_c = None + for vector, spectrum in zip(vectors, spectra): + vector_string = f'{float(vector[0])} {float(vector[1])} {float(vector[2])}' + if vector in [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]: + spectrum_a = spectrum + elif vector == [0.0, 0.0, 1.0]: + spectrum_c = spectrum + else: + raise ValidationError( + f'Polarisation vector ({vector_string}) does not correspond to a crystal basis vector' + ' (100, 010, 001)' + ) + + if spectrum_a and not spectrum_c: + raise ValidationError(f'Found no polarisation vector for the C-axis ([0, 0, 1]), found instead: {vectors}.') + powder_x = spectrum_a.get_x()[1] + yvals_a = spectrum_a.get_y()[0][1] + yvals_c = spectrum_c.get_y()[0][1] + + powder_y = ((yvals_a * 2) + yvals_c) / 3 + powder_data = XyData() + powder_data.set_x(powder_x, 'energy', 'eV') + powder_data.set_y(powder_y, 'sigma', 'n/a') + + # if the system is trichoric (e.g. a monoclinic system) then no periodic dimensions + # are equal by symmetry, thus the powder spectrum is the average of the three basis + # dipole vectors (1.0 0.0 0.0, 0.0 1.0 0.0, 0.0 0.0 1.0) + if len(vectors) == 3: + # Since the individual vectors are labelled, we can extract just the spectra needed to + # produce the powder and leave the rest + for vector, spectra in zip(vectors, spectra): + vector_string = f'{float(vector[0])} {float(vector[1])} {float(vector[2])}' + if vector == [1.0, 0.0, 0.0]: + spectrum_a = spectra + elif vector == [0.0, 1.0, 0.0]: + spectrum_b = spectra + elif vector == [0.0, 0.0, 1.0]: + spectrum_c = spectra + else: + raise ValidationError( + f'Polarisation vector ({vector_string}) does not correspond to a crystal basis vector' + ' (100, 010, 001)' + ) + + powder_x = spectrum_a.get_x()[1] + yvals_a = spectrum_a.get_y()[0][1] + yvals_b = spectrum_b.get_y()[0][1] + yvals_c = spectrum_c.get_y()[0][1] + + powder_y = (yvals_a + yvals_b + yvals_c) / 3 + + powder_data = XyData() + powder_data.set_x(powder_x, 'energy', 'eV') + powder_data.set_y(powder_y, 'sigma', 'n/a') + + return powder_data diff --git a/src/aiida_quantumespresso/calculations/functions/xspectra/merge_spectra.py b/src/aiida_quantumespresso/calculations/functions/xspectra/merge_spectra.py new file mode 100644 index 000000000..4e9d7116f --- /dev/null +++ b/src/aiida_quantumespresso/calculations/functions/xspectra/merge_spectra.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +"""CalcFunction to merge multiple ``XyData`` nodes of calculated XANES spectra into a new ``XyData`` node.""" +from aiida.engine import calcfunction +from aiida.orm import XyData + + +@calcfunction +def merge_spectra(**kwargs): + """Compile all calculated spectra into a single ``XyData`` node for easier plotting. + + The keyword arguments must be an arbitrary number of ``XyData`` nodes from + the `output_spectra` of `XspectraCalculation`s, all other `kwargs` will be discarded at + runtime. + + Returns a single ``XyData`` node where each set of y values is labelled + according to the polarisation vector used for the `XspectraCalculation`. + """ + output_spectra = XyData() + y_arrays_list = [] + y_units_list = [] + y_labels_list = [] + + spectra = [node for label, node in kwargs.items() if isinstance(node, XyData)] + + for spectrum_node in spectra: + calc_node = spectrum_node.creator + calc_out_params = calc_node.res + eps_vector = calc_out_params['xepsilon'] + + old_y_component = spectrum_node.get_y() + if len(old_y_component) == 1: + y_array = old_y_component[0][1] + y_units = old_y_component[0][2] + y_arrays_list.append(y_array) + y_units_list.append(y_units) + y_labels_list.append(f'sigma_{eps_vector[0]}_{eps_vector[1]}_{eps_vector[2]}') + elif len(old_y_component) == 3: + y_tot = old_y_component[0][1] + y_tot_units = old_y_component[0][2] + y_tot_label = f'sigma_tot_{eps_vector[0]}_{eps_vector[1]}_{eps_vector[2]}' + y_arrays_list.append(y_tot) + y_units_list.append(y_tot_units) + y_labels_list.append(y_tot_label) + + y_up = old_y_component[1][1] + y_up_units = old_y_component[1][2] + y_up_label = f'sigma_up_{eps_vector[0]}_{eps_vector[1]}_{eps_vector[2]}' + y_arrays_list.append(y_up) + y_units_list.append(y_up_units) + y_labels_list.append(y_up_label) + + y_down = old_y_component[2][1] + y_down_units = old_y_component[2][2] + y_down_label = f'sigma_down_{eps_vector[0]}_{eps_vector[1]}_{eps_vector[2]}' + y_arrays_list.append(y_down) + y_units_list.append(y_down_units) + y_labels_list.append(y_down_label) + + x_array = spectrum_node.get_x()[1] + x_label = spectrum_node.get_x()[0] + x_units = spectrum_node.get_x()[2] + + output_spectra.set_x(x_array, x_label, x_units) + output_spectra.set_y(y_arrays_list, y_labels_list, y_units_list) + + return output_spectra diff --git a/src/aiida_quantumespresso/workflows/protocols/core_hole_treatments.yaml b/src/aiida_quantumespresso/workflows/protocols/core_hole_treatments.yaml new file mode 100644 index 000000000..2559287c7 --- /dev/null +++ b/src/aiida_quantumespresso/workflows/protocols/core_hole_treatments.yaml @@ -0,0 +1,29 @@ +default_inputs: + SYSTEM: + tot_charge: 1 +default_treatment: full +treatments: + full: + description: 'Core-hole treatment using a formal countercharge of +1, equivalent to removing one electron from the system.' + half: + description: 'Core-hole treatment using a formal countercharge of +0.5, equivalent to removing half an electron from the system.' + SYSTEM: + tot_charge: 0.5 + xch_fixed: + description: 'Core-hole treatment which places the excited electron into the conduction band (fixed occupations).' + SYSTEM: + occupations: fixed + tot_charge: 0 + nspin: 2 + tot_magnetization: 1 + xch_smear: + description: 'Core-hole treatment which places the excited electron into the conduction band (smeared occupations).' + SYSTEM: + occupations: smearing + tot_charge: 0 + nspin: 2 + starting_magnetization(1): 0 + none: + description: 'Applies no core-hole treatment (overrides the default tot_charge and changes it to 0).' + SYSTEM: + tot_charge: 0 diff --git a/src/aiida_quantumespresso/workflows/protocols/xspectra/core.yaml b/src/aiida_quantumespresso/workflows/protocols/xspectra/core.yaml new file mode 100644 index 000000000..d066e5023 --- /dev/null +++ b/src/aiida_quantumespresso/workflows/protocols/xspectra/core.yaml @@ -0,0 +1,17 @@ +default_inputs: + abs_atom_marker: X + clean_workdir: True + get_powder_spectrum: False + run_replot: False + eps_vectors: + - [1., 0., 0.] + - [0., 1., 0.] + - [0., 0., 1.] +default_protocol: moderate +protocols: + moderate: + description: 'Protocol to perform XANES dipole calculations at normal precision and moderate computational cost.' + precise: + description: 'Protocol to perform XANES dipole calculations at high precision and higher computational cost.' + fast: + description: 'Protocol to perform XANES dipole calculations at low precision and minimal computational cost for testing purposes.' diff --git a/src/aiida_quantumespresso/workflows/xspectra/core.py b/src/aiida_quantumespresso/workflows/xspectra/core.py new file mode 100644 index 000000000..474db806e --- /dev/null +++ b/src/aiida_quantumespresso/workflows/xspectra/core.py @@ -0,0 +1,760 @@ +# -*- coding: utf-8 -*- +"""Workchain to compute the X-ray absorption spectrum for a given structure. + +Uses QuantumESPRESSO pw.x and xspectra.x, requires ``aiida-shell`` to run ``upf2plotcore.sh``. +""" +import pathlib +from typing import Optional, Union + +from aiida import orm +from aiida.common import AttributeDict +from aiida.engine import ToContext, WorkChain, append_, if_ +from aiida.orm.nodes.data.base import to_aiida_type +from aiida.plugins import CalculationFactory, WorkflowFactory +import yaml + +from aiida_quantumespresso.calculations.functions.xspectra.get_powder_spectrum import get_powder_spectrum +from aiida_quantumespresso.calculations.functions.xspectra.merge_spectra import merge_spectra +from aiida_quantumespresso.utils.mapping import prepare_process_inputs +from aiida_quantumespresso.workflows.protocols.utils import ProtocolMixin, recursive_merge + +PwCalculation = CalculationFactory('quantumespresso.pw') +PwBaseWorkChain = WorkflowFactory('quantumespresso.pw.base') +XspectraBaseWorkChain = WorkflowFactory('quantumespresso.xspectra.base') + + +class XspectraCoreWorkChain(ProtocolMixin, WorkChain): + """Workchain to compute X-ray absorption spectra for a given structure using Quantum ESPRESSO. + + The workflow follows the process required to compute the XAS of an input structure: an SCF calculation is performed + using the provided structure, which is then followed by the calculation of the XAS itself by XSpectra. The + calculations performed by the WorkChain in a typical run will be: + + - PwSCF calculation with pw.x of the input structure with a core-hole present. + - Generation of core-wavefunction data with upf2plotcore.sh (if requested). + - XAS calculation with xspectra.x to compute the Lanczos coefficients and print the XANES spectra for the + polarisation vectors requested in the input. + - Collation of output data from pw.x and xspectra.x calculations, including a combination of XANES dipole spectra + based on polarisation vectors to represent the powder spectrum of the structure (if requested). + + If ``run_replot = True`` is set in the inputs (defaults to False), the WorkChain will run a second xspectra.x + calculation which replots the spectra produced from the ``xs_prod`` step. This option can be very useful for + obtaining a final spectrum at low levels of broadening (relative to the default of 0.5 eV), particularly as higher + levels of broadening significantly speed up the convergence of the Lanczos procedure. Inputs for the replot + calculation are found in the ``xs_plot`` namespace. + + The core-wavefunction plot derived from the ground-state of the absorbing element can be provided as a top-level + input or produced by the WorkChain. If left to the WorkChain, the ground-state pseudopotential assigned to the + absorbing element will be used to generate this data using the upf2plotcore.sh utility script (via the + ``aiida-shell`` plugin). + + In its current stage of development, the workflow requires the following: + + - An input structure where the desired absorbing atom in the system is marked as a separate Kind. The default + behaviour for the WorkChain is to set the Kind name as 'X', however this can be changed via the `overrides` + dictionary. + - A code node for ``upf2plotcore``, configured for the ``aiida-shell`` plugin + (https://github.com/sphuber/aiida-shell). Alternatively, a ``SinglefileData`` node from a previous ``ShellJob`` + run can be supplied under ``inputs.core_wfc_data``. + - A suitable pair of pseudopotentials for the element type of the absorbing atom, one for the ground-state occupancy + which contains GIPAW informtation for the core level of interest for the XAS (e.g. 1s in the case of a K-edge + calculation) and the other containing a core hole. (For the moment this can be passed either via the + ``core_hole_pseudos`` field in ``get_builder_from_protocol`` or via the overrides, but will be changed later once + full families of core-hole pseudopotentials become available). + """ + + # pylint: disable=too-many-public-methods, too-many-statements + + @classmethod + def define(cls, spec): + """Define the process specification.""" + + super().define(spec) + spec.expose_inputs( + PwBaseWorkChain, + namespace='scf', + exclude=('pw.parent_folder', 'pw.structure', 'clean_workdir'), + namespace_options={ + 'help': ('Input parameters for the `pw.x` calculation.'), + 'validator': cls.validate_scf, + } + ) + spec.expose_inputs( + XspectraBaseWorkChain, + namespace='xs_prod', + exclude=('clean_workdir', 'xspectra.parent_folder', 'xspectra.core_wfc_data'), + namespace_options={ + 'help': ('Input parameters for the `xspectra.x` calculation' + ' to compute the Lanczos.') + } + ) + spec.expose_inputs( + XspectraBaseWorkChain, + namespace='xs_plot', + exclude=('clean_workdir', 'xspectra.parent_folder', 'xspectra.core_wfc_data'), + namespace_options={ + 'help': ('Input parameters for the re-plot `xspectra.x` calculation of the Lanczos.'), + 'required': False, + 'populate_defaults': False + } + ) + spec.inputs.validator = cls.validate_inputs + spec.input( + 'structure', + valid_type=orm.StructureData, + help=( + 'Structure to be used for calculation, with at least one site containing the `abs_atom_marker` ' + 'as the kind label.' + ) + ) + spec.input( + 'eps_vectors', + valid_type=orm.List, + help=( + 'The list of 3-vectors to use in XSpectra sub-processes. ' + 'The number of sub-lists will subsequently define the number of XSpectra calculations to perform' + ), + ) + spec.input( + 'abs_atom_marker', + valid_type=orm.Str, + required=False, + help=( + 'The name for the Kind representing the absorbing atom in the structure. ' + 'Must corespond to a Kind within the StructureData node supplied to the calculation.' + ), + ) + spec.input( + 'get_powder_spectrum', + valid_type=orm.Bool, + default=lambda: orm.Bool(False), + help=( + 'If `True`, the WorkChain will combine XANES dipole spectra computed using the XAS basis vectors' + ' defined according to the `get_powder_spectrum` CalcFunction.' + ), + ) + spec.input( + 'core_wfc_data', + valid_type=orm.SinglefileData, + required=False, + help='The core wavefunction data file extracted from the ground-state pseudo for the absorbing atom.' + ) + spec.input( + 'run_replot', + valid_type=orm.Bool, + serializer=to_aiida_type, + default=lambda: orm.Bool(False), + ) + spec.input( + 'upf2plotcore_code', + valid_type=orm.Code, + required=False, + help='The code node required for upf2plotcore.sh configured for ``aiida-shell``. ' + 'Must be provided if `core_wfc_data` is not provided.' + ) + spec.input( + 'clean_workdir', + valid_type=orm.Bool, + serializer=to_aiida_type, + default=lambda: orm.Bool(False), + help=('If `True`, work directories of all called calculation will be cleaned at the end of execution.'), + ) + spec.input( + 'dry_run', + valid_type=orm.Bool, + serializer=to_aiida_type, + required=False, + help='Terminate workchain steps before submitting calculations (test purposes only).' + ) + spec.outline( + cls.setup, + cls.run_scf, + cls.inspect_scf, + if_(cls.should_run_upf2plotcore)( + cls.run_upf2plotcore, + cls.inspect_upf2plotcore, + ), + cls.run_all_xspectra_prod, + cls.inspect_all_xspectra_prod, + if_(cls.should_run_replot)( + cls.run_all_xspectra_plot, + cls.inspect_all_xspectra_plot, + ), + cls.results, + ) + + spec.exit_code(401, 'ERROR_SUB_PROCESS_FAILED_SCF', message='The SCF sub process failed') + spec.exit_code(402, 'ERROR_SUB_PROCESS_FAILED_XSPECTRA', message='One or more XSpectra sub processes failed') + spec.exit_code( + 403, + 'ERROR_NO_GIPAW_INFO_FOUND', + message='The pseudo for the absorbing element contains no' + ' GIPAW information.' + ) + spec.output( + 'parameters_scf', valid_type=orm.Dict, help='The output parameters of the SCF' + ' `PwBaseWorkChain`.' + ) + spec.output_namespace( + 'parameters_xspectra', + valid_type=orm.Dict, + help='The output dictionaries of each `XspectraBaseWorkChain` performed', + dynamic=True + ) + spec.output( + 'spectra', + valid_type=orm.XyData, + help='An XyData node containing all the final spectra produced by the WorkChain.' + ) + spec.output('powder_spectrum', valid_type=orm.XyData, required=False, help='The simulated powder spectrum') + + @classmethod + def get_protocol_filepath(cls): + """Return ``pathlib.Path`` to the ``.yaml`` file that defines the protocols.""" + from importlib_resources import files + + from ..protocols import xspectra as protocols + return files(protocols) / 'core.yaml' + + @classmethod + def get_treatment_filepath(cls) -> pathlib.Path: + """Return ``pathlib.Path`` to the ``.yaml`` file that defines the core-hole treatments for the SCF step.""" + from importlib_resources import files + + from .. import protocols + return files(protocols) / 'core_hole_treatments.yaml' + + @classmethod + def get_default_treatment(cls) -> str: + """Return the default core-hole treatment. + + :param cls: the workflow class. + :return: the default core-hole treatment + """ + + return cls._load_treatment_file()['default_treatment'] + + @classmethod + def get_available_treatments(cls) -> dict: + """Return the available core-hole treatments. + + :param cls: the workflow class. + :return: dictionary of available treatments, where each key is a treatment and value + is another dictionary that contains at least the key `description` and + optionally other keys with supplimentary information. + """ + data = cls._load_treatment_file() + return {treatment: {'description': values['description']} for treatment, values in data['treatments'].items()} + + @classmethod + def get_treatment_inputs( + cls, + treatment: Optional[dict] = None, + overrides: Union[dict, pathlib.Path, None] = None, + ) -> dict: + """Return the inputs for the given workflow class and core-hole treatment. + + :param cls: the workflow class. + :param treatment: optional specific treatment, if not specified, the default will be used + :param overrides: dictionary of inputs that should override those specified by the treatment. The mapping should + maintain the exact same nesting structure as the input port namespace of the corresponding workflow class. + :return: mapping of inputs to be used for the workflow class. + """ + data = cls._load_treatment_file() + treatment = treatment or data['default_treatment'] + + try: + treatment_inputs = data['treatments'][treatment] + except KeyError as exception: + raise ValueError( + f'`{treatment}` is not a valid treatment. ' + 'Call ``get_available_treatments`` to show available treatments.' + ) from exception + inputs = recursive_merge(data['default_inputs'], treatment_inputs) + inputs.pop('description') + + if isinstance(overrides, pathlib.Path): + with overrides.open() as file: + overrides = yaml.safe_load(file) + + if overrides: + return recursive_merge(inputs, overrides) + + return inputs + + @classmethod + def _load_treatment_file(cls) -> dict: + """Return the contents of the core-hole treatment file.""" + with cls.get_treatment_filepath().open() as file: + return yaml.safe_load(file) + + @classmethod + def get_builder_from_protocol( + cls, + pw_code, + xs_code, + structure, + upf2plotcore_code=None, + core_wfc_data=None, + core_hole_pseudos=None, + core_hole_treatment=None, + protocol=None, + overrides=None, + options=None, + **kwargs + ): + """Return a builder prepopulated with inputs selected according to the chosen protocol. + + :param pw_code: the ``Code`` instance configured for the ``quantumespresso.pw`` + plugin. + :param xs_code: the ``Code`` instance configured for the + ``quantumespresso.xspectra`` plugin. + :param structure: the ``StructureData`` instance to use. + :param upf2plotcore_code: the aiida-shell ``Code`` instance configured for the + upf2plotcore.sh shell script, used to generate the core + wavefunction data. + :param core_wfc_data: + :param core_hole_pseudos: the core-hole pseudopotential pair (ground-state and + excited-state) for the chosen absorbing atom. + :param protocol: the protocol to use. If not specified, the default will be used. + :param core_hole_treatment: the core-hole treatment desired for the SCF calculation, + using presets found in ``core_hole_treatments.yaml``. + Defaults to "full". Overrides the settings derived from + the ``PwBaseWorkChain`` protocol, but is itself overriden + by the ``overrides`` dictionary. + :param overrides: optional dictionary of inputs to override the defaults of the + XspectraBaseWorkChain itself. + :param options: a dictionary of options that will be recursively set for the # + ``metadata.options`` input of all the ``CalcJobs`` that are nested in + this work chain. + :param run_replot: a bool parameter to request inputs for the re-plot step. + :param kwargs: additional keyword arguments that will be passed to the + ``get_builder_from_protocol`` of all the sub processes that are called by this + workchain. + :return: a process builder instance with all inputs defined ready for launch. + """ + + inputs = cls.get_protocol_inputs(protocol, overrides) + pw_inputs = PwBaseWorkChain.get_protocol_inputs(protocol=protocol) + pw_params = pw_inputs['pw']['parameters'] + kinds_present = sorted([kind.name for kind in structure.kinds]) + # Get the default inputs from the PwBaseWorkChain and override them with those + # required for the chosen core-hole treatment + pw_params = recursive_merge( + left=pw_params, + right=cls.get_treatment_inputs( + treatment=core_hole_treatment, overrides=inputs.get('scf', {}).get('pw', {}).get('parameters', None) + ) + ) + + pw_inputs['pw']['parameters'] = pw_params + + pw_args = (pw_code, structure, protocol) + scf = PwBaseWorkChain.get_builder_from_protocol(*pw_args, overrides=pw_inputs, options=options, **kwargs) + + scf.pop('clean_workdir', None) + scf['pw'].pop('structure', None) + + # pylint: disable=no-member + builder = cls.get_builder() + builder.scf = scf + + xs_prod_inputs = XspectraBaseWorkChain.get_protocol_inputs(protocol, inputs.get('xs_prod')) + xs_prod_parameters = xs_prod_inputs['xspectra']['parameters'] + xs_prod_metadata = xs_prod_inputs['xspectra']['metadata'] + if options: + xs_prod_metadata['options'] = recursive_merge(xs_prod_metadata['options'], options) + + abs_atom_marker = inputs['abs_atom_marker'] + xs_prod_parameters['INPUT_XSPECTRA']['xiabs'] = kinds_present.index(abs_atom_marker) + 1 + if core_hole_pseudos: + for kind in structure.kinds: + if kind.name == abs_atom_marker: + abs_element = kind.symbol + + builder.scf.pw.pseudos[abs_atom_marker] = core_hole_pseudos[abs_atom_marker] + builder.scf.pw.pseudos[abs_element] = core_hole_pseudos[abs_element] + + builder.xs_prod.xspectra.code = xs_code + builder.xs_prod.xspectra.parameters = orm.Dict(xs_prod_parameters) + builder.xs_prod.xspectra.metadata = xs_prod_metadata + if xs_prod_inputs['kpoints_distance']: + builder.xs_prod.kpoints_distance = orm.Float(xs_prod_inputs['kpoints_distance']) + elif xs_prod_inputs['kpoints']: + builder.xs_prod.kpoints = xs_prod_inputs['kpoints'] + + if upf2plotcore_code: + builder.upf2plotcore_code = upf2plotcore_code + elif core_wfc_data: + builder.core_wfc_data = core_wfc_data + else: + raise ValueError( + 'Either a code node for upf2plotcore.sh or an already-generated core-wavefunction' + ' file must be given.' + ) + + builder.structure = structure + builder.eps_vectors = orm.List(list=inputs['eps_vectors']) + builder.clean_workdir = orm.Bool(inputs['clean_workdir']) + builder.get_powder_spectrum = orm.Bool(inputs['get_powder_spectrum']) + builder.abs_atom_marker = orm.Str(abs_atom_marker) + if inputs['run_replot']: + builder.run_replot = orm.Bool(inputs['run_replot']) + xs_plot_inputs = XspectraBaseWorkChain.get_protocol_inputs('replot') + xs_plot_parameters = xs_plot_inputs['xspectra']['parameters'] + xs_plot_metadata = xs_plot_inputs['xspectra']['metadata'] + if options: + xs_plot_metadata['options'] = recursive_merge(xs_plot_metadata['options'], options) + builder.xs_plot.xspectra.code = xs_code + builder.xs_plot.xspectra.parameters = orm.Dict(xs_plot_parameters) + builder.xs_plot.xspectra.metadata = xs_plot_metadata + else: + builder.pop('run_replot', None) + # pylint: enable=no-member + return builder + + @staticmethod + def validate_scf(inputs, _): + """Validate the scf parameters.""" + parameters = inputs['pw']['parameters'].get_dict() + if parameters.get('CONTROL', {}).get('calculation', 'scf') != 'scf': + return '`CONTROL.calculation` in `scf.pw.parameters` is not set to `scf`.' + + @staticmethod + def validate_inputs(inputs, _): + """Validate the inputs before launching the WorkChain.""" + + eps_vector_list = inputs['eps_vectors'].get_list() + if len(eps_vector_list) == 0: + return '`eps_vectors` list empty.' + if 'core_wfc_data' not in inputs and 'upf2plotcore_code' not in inputs: + return 'Either a core wavefunction file or a code node for upf2plotcore.sh must be provided.' + structure = inputs['structure'] + kinds_present = structure.kinds + abs_atom_found = False + for kind in kinds_present: + if kind.name == inputs['abs_atom_marker'].value: + abs_atom_found = True + if not abs_atom_found: + return ( + f'Error: the marker given for the absorbing atom ("{inputs["abs_atom_marker"].value}") ' + + 'does not appear in the structure provided.' + ) + + def setup(self): + """Initialize context variables that are used during the logical flow of the workchain.""" + + self.ctx.dry_run = 'dry_run' in self.inputs and self.inputs.dry_run.value + self.ctx.all_lanczos_computed = False + self.ctx.finished_lanczos = [] + self.ctx.finished_replots = [] + abs_atom_marker = self.inputs.abs_atom_marker + structure = self.inputs.structure + for kind in structure.kinds: + if kind.name == abs_atom_marker: + abs_kind = kind + + self.ctx.abs_kind = abs_kind + if 'core_wfc_data' in self.inputs: + self.ctx.core_wfc_data = self.inputs.core_wfc_data + + def run_scf(self): + """Run an SCF calculation as a first step.""" + + inputs = AttributeDict(self.exposed_inputs(PwBaseWorkChain, 'scf')) + + inputs.metadata.call_link_label = 'scf' + inputs.pw.structure = self.inputs.structure + inputs = prepare_process_inputs(PwBaseWorkChain, inputs) + + if self.ctx.dry_run: + return inputs + + future = self.submit(PwBaseWorkChain, **inputs) + self.report(f'launching SCF PwBaseWorkChain<{future.pk}>') + + return ToContext(scf_workchain=future) + + def inspect_scf(self): + """Verify that the PwBaseWorkChain finished successfully.""" + + workchain = self.ctx.scf_workchain + if not workchain.is_finished_ok: + self.report(f'SCF PwBaseWorkChain failed with exit status {workchain.exit_status}') + return self.exit_codes.ERROR_SUB_PROCESS_FAILED_SCF + + def should_run_upf2plotcore(self): + """Don't calculate the core wavefunction data if one has already been provided.""" + + return 'core_wfc_data' not in self.inputs + + def should_run_replot(self): + """Run the WorkChain as a two-step production + replot process if requested.""" + + return self.inputs.run_replot.value + + def run_upf2plotcore(self): + """Generate the core-wavefunction data on-the-fly, if no data is given in the inputs. + + This will determine which pseudopotential is assigned to the atomic species of the + same element as the absorbing atom, though not the absorbing atom itself, thus the + corresponding species must use a pseudopotential which contains the correct GIPAW + information required by the upf2plotcore.sh helper script. + + As this uses the AiiDA-Shell plugin, we assume that this is already installed. + """ + + ShellJob = CalculationFactory('core.shell') # pylint: disable=invalid-name + + pw_inputs = self.exposed_inputs(PwBaseWorkChain, 'scf') + pseudo_dict = pw_inputs['pw']['pseudos'] + abs_kind = self.ctx.abs_kind + + upf = pseudo_dict[abs_kind.symbol] + + shell_inputs = {} + + shell_inputs['code'] = self.inputs.upf2plotcore_code + shell_inputs['nodes'] = {'upf': upf} + shell_inputs['arguments'] = orm.List(list=['{upf}']) + shell_inputs['metadata'] = {'call_link_label': 'upf2plotcore'} + + shelljob_node = self.submit(ShellJob, **shell_inputs) + self.report(f'Launching ShellJob for upf2plotcore.sh<{shelljob_node.pk}>') + + return ToContext(upf2plotcore_node=shelljob_node) + + def inspect_upf2plotcore(self): + """Check that the output from the upf2plotcore step has yielded a meaningful result. + + This will simply check that the core wavefunction data returned contains at least + one core state and return an error if this is not the case. + """ + + shelljob_node = self.ctx.upf2plotcore_node + core_wfc_data = shelljob_node.outputs.stdout + header_line = shelljob_node.outputs.stdout.get_content()[:40] + num_core_states = int(header_line.split(' ')[5]) + if num_core_states == 0: + return self.exit_codes.ERROR_NO_GIPAW_INFO_FOUND + self.ctx.core_wfc_data = core_wfc_data + + def run_all_xspectra_prod(self): + """Run an `XspectraBaseWorkChain` for each 3-vector given for epsilon.""" + + eps_vectors = self.inputs.eps_vectors.get_list() + parent_folder = self.ctx.scf_workchain.outputs.remote_folder + core_wfc_data = self.ctx.core_wfc_data + + calc_number = 0 + for calc_number, vector in enumerate(eps_vectors): + xspectra_inputs = AttributeDict(self.exposed_inputs(XspectraBaseWorkChain, 'xs_prod')) + xspectra_parameters = xspectra_inputs.xspectra.parameters.get_dict() + + parent_folder = self.ctx.scf_workchain.outputs.remote_folder + xspectra_inputs.xspectra.parent_folder = parent_folder + xspectra_inputs.xspectra.core_wfc_data = core_wfc_data + xspectra_inputs.metadata.call_link_label = f'xas_{calc_number}_prod' + + for index in [0, 1, 2]: + xspectra_parameters['INPUT_XSPECTRA'][f'xepsilon({index + 1})'] = vector[index] + xspectra_inputs.xspectra.parameters = orm.Dict(xspectra_parameters) + + if self.ctx.dry_run: + return xspectra_inputs + + future_xspectra = self.submit(XspectraBaseWorkChain, **xspectra_inputs) + self.to_context(xspectra_prod_calculations=append_(future_xspectra)) + self.report( + f'launching XspectraWorkChain<{future_xspectra.pk}> for epsilon vector {vector}' + ' (Lanczos production)' + ) + + def inspect_all_xspectra_prod(self): + """Verify that the `XspectraBaseWorkChain` Lanczos production sub-processes finished successfully.""" + + calculations = self.ctx.xspectra_prod_calculations + unrecoverable_failures = False # pylint: disable=unused-variable + + for calculation in calculations: + vector = calculation.outputs.output_parameters.get_dict()['xepsilon'] + if not calculation.is_finished_ok: + self.report(f'XspectraBaseWorkChain <{vector}>' + ' failed with exit status {calculation.exit_status}.') + unrecoverable_failures = True + else: + self.report(f'XspectraBaseWorkChain <{vector}> finished successfully.') + self.ctx['finished_lanczos'].append(calculation) + if unrecoverable_failures: + return self.exit_codes.ERROR_SUB_PROCESS_FAILED_XSPECTRA + + self.ctx.all_lanczos_computed = True + + def run_all_xspectra_plot(self): + """Run an `XspectraBaseWorkChain` for each 3-vector given for epsilon to plot the final spectra. + + This part simply convolutes and plots the spectra from the already-computed Lanczos + of ``run_all_xspectra_plot``. Only run if requested via ``run_replot`` in the inputs. + """ + + finished_calculations = self.ctx.finished_lanczos + + core_wfc_data = self.ctx.core_wfc_data + + for calc_number, parent_xspectra in enumerate(finished_calculations): + xspectra_inputs = AttributeDict(self.exposed_inputs(XspectraBaseWorkChain, 'xs_plot')) + # The epsilon vectors are not needed in the case of a replot, however they + # will be needed by the Parser at the end + xspectra_parameters = xspectra_inputs.xspectra.parameters.get_dict() + + parent_output_dict = parent_xspectra.outputs.output_parameters.get_dict() + parent_calc_job = parent_xspectra.outputs.output_parameters.creator + eps_vector = parent_output_dict['xepsilon'] + xspectra_parameters['INPUT_XSPECTRA']['xepsilon(1)'] = eps_vector[0] + xspectra_parameters['INPUT_XSPECTRA']['xepsilon(2)'] = eps_vector[1] + xspectra_parameters['INPUT_XSPECTRA']['xepsilon(3)'] = eps_vector[2] + xspectra_inputs.xspectra.parent_folder = parent_xspectra.outputs.remote_folder + xspectra_inputs.kpoints = parent_calc_job.inputs.kpoints + xspectra_inputs.xspectra.core_wfc_data = core_wfc_data + xspectra_inputs.metadata.call_link_label = f'xas_{calc_number}_plot' + + xspectra_inputs.xspectra.parameters = orm.Dict(xspectra_parameters) + + if self.ctx.dry_run: + return xspectra_inputs + + future_xspectra = self.submit(XspectraBaseWorkChain, **xspectra_inputs) + self.report( + f'launching XspectraBaseWorkChain<{future_xspectra.pk}> for epsilon vector {eps_vector} (Replot)' + ) + self.to_context(xspectra_plot_calculations=append_(future_xspectra)) + + def inspect_all_xspectra_plot(self): + """Verify that the `XspectraBaseWorkChain` re-plot sub-processes finished successfully.""" + + calculations = self.ctx.xspectra_plot_calculations + + finished_replots = [] + unrecoverable_failures = False + for calculation in calculations: + if not calculation.is_finished_ok: + self.report(f'XspectraBaseWorkChain failed with exit status {calculation.exit_status}') + unrecoverable_failures = True + else: + finished_replots.append(calculation) + if unrecoverable_failures: + return self.exit_codes.ERROR_SUB_PROCESS_FAILED_XSPECTRA + self.ctx.finished_replots = finished_replots + + def results(self): + """Attach the important output nodes to the outputs of the WorkChain. + + This will collect the SCF and XSpectra output parameters, as well as the + powder spectrum (if requested) + """ + + xspectra_prod_calcs = self.ctx.finished_lanczos + if self.inputs.run_replot.value: + final_calcs = self.ctx.finished_replots + else: + final_calcs = self.ctx.finished_lanczos + + eps_powder_vectors = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] + basis_vectors_present = False + for calc in final_calcs: + out_params = calc.outputs.output_parameters + in_params = calc.inputs.xspectra.parameters.get_dict() + eps_vectors = out_params['xepsilon'] + xcoordcrys = out_params['xcoordcrys'] + calc_type = in_params['INPUT_XSPECTRA']['calculation'] + if xcoordcrys is False: + self.report( + 'WARNING: calculations were set to use a cartesian basis instead of a ' + 'crystallographic one. Please use ``"xcoordcrys" : True`` to compute the ' + 'powder spectrum using this WorkChain.' + ) + break + if eps_vectors in eps_powder_vectors and calc_type == 'xanes_dipole': + basis_vectors_present = True + + eps_basis_calcs = {} + if self.inputs.get_powder_spectrum and basis_vectors_present: + a_vector_present = False + b_vector_present = False + c_vector_present = False + for plot_calc in final_calcs: + out_params = plot_calc.outputs.output_parameters + plot_vector = out_params['xepsilon'] + spectrum_node = plot_calc.outputs.spectra + if plot_vector in eps_powder_vectors: + if plot_vector == [1., 0., 0.]: + eps_basis_calcs['eps_100'] = spectrum_node + a_vector_present = True + if plot_vector == [0., 1., 0.]: + eps_basis_calcs['eps_010'] = spectrum_node + b_vector_present = True + if plot_vector == [0., 0., 1.]: + eps_basis_calcs['eps_001'] = spectrum_node + c_vector_present = True + + # Here, we control for the case where the A and B vectors are given, but C is + # missing, which would cause a problem for ``get_powder_spectrum`` + if a_vector_present and b_vector_present and not c_vector_present: + self.report( + 'WARNING: epsilon vectors for [1.0 0.0 0.0] and [0.0 1.0 0.0] were ' + 'found, but not for [0.0 0.0 1.0]. Please ensure that the vectors ' + 'perpendicular and parallel to the C-axis are defined in the case ' + 'of a system with dichorism.' + ) + else: + eps_basis_calcs['metadata'] = {'call_link_label': 'get_powder_spectrum'} + powder_spectrum = get_powder_spectrum(**eps_basis_calcs) + self.out('powder_spectrum', powder_spectrum) + elif self.inputs.get_powder_spectrum and not basis_vectors_present: + self.report( + 'WARNING: A powder spectrum was requested, but none of the epsilon vectors ' + 'given are suitable to compute this.' + ) + + self.out('parameters_scf', self.ctx.scf_workchain.outputs.output_parameters) + + all_xspectra_prod_calcs = {} + for index, calc in enumerate(xspectra_prod_calcs): + all_xspectra_prod_calcs[f'xas_{index}'] = calc + + xspectra_prod_params = {} + for key, node in all_xspectra_prod_calcs.items(): + output_params = node.outputs.output_parameters + xspectra_prod_params[key] = output_params + self.out('parameters_xspectra', xspectra_prod_params) + + all_final_spectra = {} + for index, calc in enumerate(final_calcs): + all_final_spectra[f'xas_{index}'] = calc.outputs.spectra + + all_final_spectra['metadata'] = {'call_link_label': 'merge_spectra'} + output_spectra = merge_spectra(**all_final_spectra) + + self.out('spectra', output_spectra) + + def on_terminated(self): + """Clean the working directories of all child calculations if ``clean_workdir=True`` in the inputs.""" + + super().on_terminated() + + if self.inputs.clean_workdir.value is False: + self.report('remote folders will not be cleaned') + return + + cleaned_calcs = [] + + for called_descendant in self.node.called_descendants: + if isinstance(called_descendant, orm.CalcJobNode): + try: + called_descendant.outputs.remote_folder._clean() # pylint: disable=protected-access + cleaned_calcs.append(called_descendant.pk) + except (IOError, OSError, KeyError): + pass + + if cleaned_calcs: + self.report(f"cleaned remote folders of calculations: {' '.join(map(str, cleaned_calcs))}") diff --git a/tests/conftest.py b/tests/conftest.py index 73a1538dd..dc32a61ce 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -79,7 +79,7 @@ def serialize_builder(): def serialize_data(data): # pylint: disable=too-many-return-statements - from aiida.orm import AbstractCode, BaseType, Data, Dict, KpointsData, RemoteData, SinglefileData + from aiida.orm import AbstractCode, BaseType, Data, Dict, KpointsData, List, RemoteData, SinglefileData from aiida.plugins import DataFactory StructureData = DataFactory('core.structure') @@ -97,6 +97,9 @@ def serialize_data(data): if isinstance(data, Dict): return data.get_dict() + if isinstance(data, List): + return data.get_list() + if isinstance(data, StructureData): return data.get_formula() diff --git a/tests/workflows/protocols/xspectra/test_core.py b/tests/workflows/protocols/xspectra/test_core.py new file mode 100644 index 000000000..25810d7e5 --- /dev/null +++ b/tests/workflows/protocols/xspectra/test_core.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +"""Tests for the ``XspectraCoreWorkChain.get_builder_from_protocol`` method.""" +import io + +from aiida.engine import ProcessBuilder +from aiida.orm import SinglefileData + +from aiida_quantumespresso.workflows.xspectra.core import XspectraCoreWorkChain + + +def test_get_available_protocols(): + """Test ``XspectraCoreWorkChain.get_available_protocols``.""" + protocols = XspectraCoreWorkChain.get_available_protocols() + assert sorted(protocols.keys()) == ['fast', 'moderate', 'precise'] + assert all('description' in protocol for protocol in protocols.values()) + + +def test_get_default_protocol(): + """Test ``XspectraCoreWorkChain.get_default_protocol``.""" + assert XspectraCoreWorkChain.get_default_protocol() == 'moderate' + + +def test_get_available_treatments(): + """Test ``XspectraCoreWorkChain.get_available_treatments``.""" + treatments = XspectraCoreWorkChain.get_available_treatments() + assert sorted(treatments.keys()) == ['full', 'half', 'none', 'xch_fixed', 'xch_smear'] + assert all('description' in treatment for treatment in treatments.values()) + + +def test_get_default_treatment(): + """Test ``XspectraCoreWorkChain.get_default_treatment``.""" + assert XspectraCoreWorkChain.get_default_treatment() == 'full' + + +def test_default(fixture_code, generate_structure, data_regression, serialize_builder): + """Test ``XspectraCoreWorkChain.get_builder_from_protocol`` for the default protocol.""" + pw_code = fixture_code('quantumespresso.pw') + xs_code = fixture_code('quantumespresso.xspectra') + structure = generate_structure('silicon') + core_wfc_data = SinglefileData( + io.StringIO( + '# number of core states 3 = 1 0; 2 0;' + '\n6.51344e-05 6.615743462459999e-3' + '\n6.59537e-05 6.698882211449999e-3' + ) + ) + overrides = {'abs_atom_marker': 'Si'} + builder = XspectraCoreWorkChain.get_builder_from_protocol( + pw_code=pw_code, xs_code=xs_code, core_wfc_data=core_wfc_data, structure=structure, overrides=overrides + ) + + assert isinstance(builder, ProcessBuilder) + data_regression.check(serialize_builder(builder)) diff --git a/tests/workflows/protocols/xspectra/test_core/test_default.yml b/tests/workflows/protocols/xspectra/test_core/test_default.yml new file mode 100644 index 000000000..c9619e326 --- /dev/null +++ b/tests/workflows/protocols/xspectra/test_core/test_default.yml @@ -0,0 +1,78 @@ +abs_atom_marker: Si +clean_workdir: true +core_wfc_data: '# number of core states 3 = 1 0; 2 0; + + 6.51344e-05 6.615743462459999e-3 + + 6.59537e-05 6.698882211449999e-3' +eps_vectors: +- - 1.0 + - 0.0 + - 0.0 +- - 0.0 + - 1.0 + - 0.0 +- - 0.0 + - 0.0 + - 1.0 +get_powder_spectrum: false +scf: + kpoints_distance: 0.15 + kpoints_force_parity: false + pw: + code: test.quantumespresso.pw@localhost + metadata: + options: + max_wallclock_seconds: 43200 + resources: + num_machines: 1 + withmpi: true + parameters: + CONTROL: + calculation: scf + etot_conv_thr: 2.0e-05 + forc_conv_thr: 0.0001 + tprnfor: true + tstress: true + ELECTRONS: + conv_thr: 4.0e-10 + electron_maxstep: 80 + mixing_beta: 0.4 + SYSTEM: + degauss: 0.01 + ecutrho: 240.0 + ecutwfc: 30.0 + nosym: false + occupations: smearing + smearing: cold + tot_charge: 1 + pseudos: + Si: Si +structure: Si2 +xs_prod: + kpoints_distance: 0.15 + xspectra: + code: test.quantumespresso.xspectra@localhost + metadata: + options: + max_wallclock_seconds: 43200 + resources: + num_machines: 1 + withmpi: true + parameters: + CUT_OCC: + cut_desmooth: 0.1 + INPUT_XSPECTRA: + calculation: xanes_dipole + xcheck_conv: 10 + xerror: 0.001 + xiabs: 1 + xniter: 2000 + xonly_plot: false + PLOT: + cut_occ_states: true + terminator: true + xemax: 30 + xemin: -10 + xgamma: 0.5 + xnepoint: 2000 diff --git a/tests/workflows/xspectra/test_core.py b/tests/workflows/xspectra/test_core.py new file mode 100644 index 000000000..1615eef5e --- /dev/null +++ b/tests/workflows/xspectra/test_core.py @@ -0,0 +1,150 @@ +# -*- coding: utf-8 -*- +"""Tests for the `XspectraCoreWorkChain` class.""" +import io + +from aiida import engine, orm +from aiida.common import LinkType +from aiida.manage.manager import get_manager +from plumpy import ProcessState +import pytest + + +def instantiate_process(cls_or_builder, inputs=None): + """Instantiate a process, from a ``Process`` class or ``ProcessBuilder`` instance.""" + from aiida.engine.utils import instantiate_process as _instantiate_process + manager = get_manager() + runner = manager.get_runner() + return _instantiate_process(runner, cls_or_builder, **(inputs or {})) + + +@pytest.fixture +def generate_workchain_xspectra_core(generate_inputs_pw, generate_workchain, generate_inputs_xspectra): + """Generate an instance of a `XspectraCoreWorkChain`.""" + + def _generate_workchain_xspectra_core(): + from aiida.orm import Bool, List, SinglefileData, Str + + entry_point = 'quantumespresso.xspectra.core' + + scf_pw_inputs = generate_inputs_pw() + xs_prod_inputs = generate_inputs_xspectra() + + xs_prod = {'xspectra': xs_prod_inputs} + xs_prod_inputs.pop('parent_folder') + xs_prod_inputs.pop('kpoints') + + kpoints = scf_pw_inputs.pop('kpoints') + structure = scf_pw_inputs.pop('structure') + scf = {'pw': scf_pw_inputs, 'kpoints': kpoints} + + inputs = { + 'structure': + structure, + 'scf': + scf, + 'xs_prod': + xs_prod, + 'xs_plot': + xs_prod, + 'run_replot': + Bool(False), + 'dry_run': + Bool(True), + 'abs_atom_marker': + Str('Si'), + 'core_wfc_data': + SinglefileData( + io.StringIO( + '# number of core states 3 = 1 0; 2 0;' + '\n6.51344e-05 6.615743462459999e-3' + '\n6.59537e-05 6.698882211449999e-3' + ) + ), + 'eps_vectors': + List(list=[[1., 0., 0.]]) + } + + return generate_workchain(entry_point, inputs) + + return _generate_workchain_xspectra_core + + +def test_default( + generate_inputs_xspectra, + generate_workchain_pw, + generate_workchain_xspectra, + generate_workchain_xspectra_core, #pylint: disable=redefined-outer-name + fixture_localhost, + generate_remote_data, + generate_calc_job_node, + generate_xy_data, +): + """Test instantiating the WorkChain, then mock its process, by calling methods in the ``spec.outline``.""" + + wkchain = generate_workchain_xspectra_core() + + assert wkchain.setup() is None + assert wkchain.should_run_upf2plotcore() is False + assert wkchain.should_run_replot() is False + + # run scf + scf_inputs = wkchain.run_scf() + + scf_wkchain = generate_workchain_pw(inputs=scf_inputs) + scf_wkchain.node.set_process_state(ProcessState.FINISHED) + scf_wkchain.node.set_exit_status(0) + + remote = generate_remote_data(computer=fixture_localhost, remote_path='/path/on/remote') + remote.store() + remote.base.links.add_incoming(scf_wkchain.node, link_type=LinkType.RETURN, link_label='remote_folder') + + result = orm.Dict() + result.store() + result.base.links.add_incoming(scf_wkchain.node, link_type=LinkType.RETURN, link_label='output_parameters') + + wkchain.ctx.scf_workchain = scf_wkchain.node + wkchain.ctx.scf_parent_folder = remote + + assert wkchain.inspect_scf() is None + + # mock run the xs_prod step + xs_prod_inputs = wkchain.run_all_xspectra_prod() + # mock xs_prod outputs + xs_prod_wc = generate_workchain_xspectra(inputs=xs_prod_inputs) + xs_prod_node = xs_prod_wc.node + xs_prod_node.label = 'xas_0_prod' + xs_prod_node.store() + xs_prod_node.set_exit_status(0) + xs_prod_node.set_process_state(engine.ProcessState.FINISHED) + + # mock an XspectraCalculation node, as the WorkChain needs one, since it's looking for a + # "creator" node in the final steps + xspectra_node = generate_calc_job_node( + entry_point_name='quantumespresso.xspectra', inputs=generate_inputs_xspectra() + ) + xspectra_node.store() + + result = orm.Dict(dict={'xepsilon': [1., 0., 0.], 'xcoordcrys': False}) + result.store() + result.base.links.add_incoming(xs_prod_node, link_type=LinkType.RETURN, link_label='output_parameters') + result.base.links.add_incoming(xspectra_node, link_type=LinkType.CREATE, link_label='output_parameters') + + spectra = generate_xy_data() + spectra.store() + spectra.base.links.add_incoming(xs_prod_node, link_type=LinkType.RETURN, link_label='spectra') + spectra.base.links.add_incoming(xspectra_node, link_type=LinkType.CREATE, link_label='spectra') + + wkchain.ctx.xspectra_prod_calculations = [ + xs_prod_node, + ] + + assert wkchain.inspect_all_xspectra_prod() is None + assert wkchain.ctx.all_lanczos_computed is True + + # process results + wkchain.results() + + wkchain.update_outputs() + + assert set(wkchain.node.base.links.get_outgoing().all_link_labels() + ) == {'parameters_scf', 'parameters_xspectra__xas_0', 'spectra'}