Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

XspectraCrystalWorkChain: Enable Symmetry Data Inputs #1028

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ def get_xspectra_structures(structure, **kwargs): # pylint: disable=too-many-st
new_supercell = get_supercell_result['new_supercell']
output_params['supercell_factors'] = multiples

result['supercell'] = new_supercell
output_params['supercell_num_sites'] = len(new_supercell.sites)
output_params['supercell_cell_matrix'] = new_supercell.cell
output_params['supercell_cell_lengths'] = new_supercell.cell_lengths
Expand Down
105 changes: 81 additions & 24 deletions src/aiida_quantumespresso/workflows/xspectra/crystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Uses QuantumESPRESSO pw.x and xspectra.x.
"""
from aiida import orm
from aiida.common import AttributeDict, ValidationError
from aiida.common import AttributeDict
from aiida.engine import ToContext, WorkChain, if_
from aiida.orm import UpfData as aiida_core_upf
from aiida.plugins import CalculationFactory, DataFactory, WorkflowFactory
Expand Down Expand Up @@ -173,6 +173,19 @@ def define(cls, spec):
help=('Input namespace to provide core wavefunction inputs for each element. Must follow the format: '
'``core_wfc_data__{symbol} = {node}``')
)
spec.input_namespace(
'symmetry_data',
valid_type=(orm.Dict, orm.Int),
dynamic=True,
required=False,
help=(
'Input namespace to define equivalent sites and spacegroup number for the system. If defined, will '
'skip symmetry analysis and structure standardization. Use *only* if symmetry data are known '
'for certain. Requires ``spacegroup_number`` (Int) and ``equivalent_sites_data`` (Dict) to be '
'defined separately. All keys in `equivalent_sites_data` must be formatted as "site_<site_index>". '
'See docstring of `get_xspectra_structures` for more information about inputs.'
)
)
spec.inputs.validator = cls.validate_inputs
spec.outline(
cls.setup,
Expand Down Expand Up @@ -370,7 +383,7 @@ def get_builder_from_protocol( # pylint: disable=too-many-statements


@staticmethod
def validate_inputs(inputs, _):
def validate_inputs(inputs, _): # pylint: disable=too-many-return-statements
"""Validate the inputs before launching the WorkChain."""
structure = inputs['structure']
kinds_present = [kind.name for kind in structure.kinds]
Expand All @@ -382,54 +395,92 @@ def validate_inputs(inputs, _):
if element not in elements_present:
extra_elements.append(element)
if len(extra_elements) > 0:
raise ValidationError(
return (
f'Some elements in ``elements_list`` {extra_elements} do not exist in the'
f' structure provided {elements_present}.'
)

abs_atom_marker = inputs['abs_atom_marker'].value
if abs_atom_marker in kinds_present:
raise ValidationError(
return (
f'The marker given for the absorbing atom ("{abs_atom_marker}") matches an existing Kind in the '
f'input structure ({kinds_present}).'
)

if not inputs['core']['get_powder_spectrum'].value:
raise ValidationError(
return (
'The ``get_powder_spectrum`` input for the XspectraCoreWorkChain namespace must be ``True``.'
)

if 'upf2plotcore_code' not in inputs and 'core_wfc_data' not in inputs:
raise ValidationError(
return (
'Neither a ``Code`` node for upf2plotcore.sh or a set of ``core_wfc_data`` were provided.'
)

if 'core_wfc_data' in inputs:
core_wfc_data_list = sorted(inputs['core_wfc_data'].keys())
if core_wfc_data_list != absorbing_elements_list:
raise ValidationError(
return (
f'The ``core_wfc_data`` provided ({core_wfc_data_list}) does not match the list of'
f' absorbing elements ({absorbing_elements_list})'
)
else:
empty_core_wfc_data = []
for key, value in inputs['core_wfc_data'].items():
header_line = value.get_content()[:40]
try:
num_core_states = int(header_line.split(' ')[5])
except Exception as exc:
raise ValidationError(
'The core wavefunction data file is not of the correct format'
) from exc
if num_core_states == 0:
empty_core_wfc_data.append(key)
if len(empty_core_wfc_data) > 0:
raise ValidationError(
f'The ``core_wfc_data`` provided for elements {empty_core_wfc_data} do not contain '
'any wavefunction data.'
)
empty_core_wfc_data = []
for key, value in inputs['core_wfc_data'].items():
header_line = value.get_content()[:40]
try:
num_core_states = int(header_line.split(' ')[5])
except: # pylint: disable=bare-except
return (
'The core wavefunction data file is not of the correct format'
) # pylint: enable=bare-except
if num_core_states == 0:
empty_core_wfc_data.append(key)
if len(empty_core_wfc_data) > 0:
return (
f'The ``core_wfc_data`` provided for elements {empty_core_wfc_data} do not contain '
'any wavefunction data.'
)

if 'symmetry_data' in inputs:
spacegroup_number = inputs['symmetry_data']['spacegroup_number'].value
equivalent_sites_data = inputs['symmetry_data']['equivalent_sites_data'].get_dict()
if spacegroup_number <= 0 or spacegroup_number >= 231:
return (
f'Input spacegroup number ({spacegroup_number}) outside of valid range (1-230).'
)

input_elements = []
required_keys = sorted(['symbol', 'multiplicity', 'kind_name', 'site_index'])
invalid_entries = []
# We check three things here: (1) are there any site indices which are outside of the possible
# range of site indices (2) do we have all the required keys for each entry,
# and (3) is there a mismatch between `absorbing_elements_list` and the elements specified
# in the entries of `equivalent_sites_data`. These checks are intended only to avoid a crash.
# We assume otherwise that the user knows what they're doing and has set everything else
# to their preferences correctly.
for site_label, value in equivalent_sites_data.items():
if not set(required_keys).issubset(set(value.keys())) :
invalid_entries.append(site_label)
elif value['symbol'] not in input_elements:
input_elements.append(value['symbol'])
if value['site_index'] < 0 or value['site_index'] >= len(structure.sites):
return (
f'The site index for {site_label} ({value["site_index"]}) is outside the range of '
+ f'sites within the structure (0-{len(structure.sites) -1}).'
)

if len(invalid_entries) != 0:
return (
f'The required keys ({required_keys}) were not found in the following entries: {invalid_entries}'
)

sorted_input_elements = sorted(input_elements)
if sorted_input_elements != absorbing_elements_list:
return (f'Elements defined for sites in `equivalent_sites_data` ({sorted_input_elements}) '
f'do not match the list of absorbing elements ({absorbing_elements_list})')


# pylint: enable=too-many-return-statements
def setup(self):
"""Set required context variables."""
if 'core_wfc_data' in self.inputs.keys():
Expand Down Expand Up @@ -489,6 +540,12 @@ def get_xspectra_structures(self):
if 'spglib_settings' in self.inputs:
inputs['spglib_settings'] = self.inputs.spglib_settings

if 'symmetry_data' in self.inputs:
inputs['parse_symmetry'] = orm.Bool(False)
input_sym_data = self.inputs.symmetry_data
inputs['equivalent_sites_data'] = input_sym_data['equivalent_sites_data']
inputs['spacegroup_number'] = input_sym_data['spacegroup_number']

if 'relax' in self.inputs:
result = get_xspectra_structures(self.ctx.optimized_structure, **inputs)
else:
Expand Down
Loading