Skip to content

Commit

Permalink
👌 Add fixes for SIRIUS FORTRAN XML issue
Browse files Browse the repository at this point in the history
  • Loading branch information
mbercx committed Dec 20, 2023
1 parent b9e4215 commit 0d668ec
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 19 deletions.
6 changes: 1 addition & 5 deletions src/aiida_quantumespresso/parsers/parse_xml/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,7 @@ def parse_xml_post_6_2(xml):
# xml_dictionary['key']['@attr'] returns its attribute 'attr'
# xml_dictionary['key']['nested_key'] goes one level deeper.

xml_dictionary, errors = xsd.to_dict(xml, validation='lax')
if errors:
logs.error.append(f'{len(errors)} XML schema validation error(s) schema: {schema_filepath}:')
for err in errors:
logs.error.append(str(err))
xml_dictionary = xsd.to_dict(xml, validation='skip')

xml_version = Version(xml_dictionary['general_info']['xml_format']['@VERSION'])
inputs = xml_dictionary.get('input', {})
Expand Down
58 changes: 44 additions & 14 deletions src/aiida_quantumespresso/parsers/pw.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,28 @@
from .parse_raw.pw import reduce_symmetries


def fix_sirius_xml_prints(array):
"""Fix an issue where SIRIUS prints very small numbers incorrectly.
In some cases SIRIUS prints very small numbers in scientific notation, but
without the capital `E` to indicate the exponent:
<forces rank="2" dims=" 3 6">
3.220681500679849E-86 4.347898683069749-103 -3.696810758148386E-49
This function will fix this by converting any number that cannot be converted into
a float to zero.
"""
def try_convert(s):
try:
return float(s)
except (ValueError, TypeError):
return 0

return numpy.vectorize(try_convert)(array)


class PwParser(BaseParser):
"""`Parser` implementation for the `PwCalculation` calculation job class."""

Expand All @@ -40,11 +62,11 @@ def parse(self, **kwargs):
parser_options = settings.get(self.get_parser_settings_key(), None)

# Verify that the retrieved_temporary_folder is within the arguments if temporary files were specified
if self.node.base.attributes.get('retrieve_temporary_list', None):
try:
dir_with_bands = kwargs['retrieved_temporary_folder']
except KeyError:
return self.exit(self.exit_codes.ERROR_NO_RETRIEVED_TEMPORARY_FOLDER)
# if self.node.base.attributes.get('retrieve_temporary_list', None):
# try:
# dir_with_bands = kwargs['retrieved_temporary_folder']
# except KeyError:
# return self.exit(self.exit_codes.ERROR_NO_RETRIEVED_TEMPORARY_FOLDER)

# We check if the `CRASH` file was retrieved. If so, we parse its output
crash_file_filename = self.node.process_class._CRASH_FILE
Expand Down Expand Up @@ -512,13 +534,19 @@ def build_output_trajectory(parsed_trajectory, structure):
trajectory = orm.TrajectoryData()
trajectory.set_trajectory(
stepids=stepids,
cells=cells,
cells=fix_sirius_xml_prints(cells),
symbols=symbols,
positions=positions,
positions=fix_sirius_xml_prints(positions),
)

for key, value in parsed_trajectory.items():
trajectory.set_array(key, numpy.array(value))
if key in (
'forces',
'stress'
):
trajectory.set_array(key, fix_sirius_xml_prints(numpy.array(value)))
else:
trajectory.set_array(key, numpy.array(value))

return trajectory

Expand Down Expand Up @@ -570,14 +598,16 @@ def build_output_bands(self, parsed_bands, parsed_kpoints=None):

# Correct the occupation for nspin=1 calculations where Quantum ESPRESSO populates each band only halfway
if len(parsed_bands['occupations']) > 1:
occupations = parsed_bands['occupations']
occupations = numpy.array(parsed_bands['occupations'])
else:
occupations = 2. * numpy.array(parsed_bands['occupations'][0])
occupations = numpy.array(parsed_bands['occupations'][0])

if len(parsed_bands['bands']) > 1:
bands_energies = parsed_bands['bands']
else:
bands_energies = parsed_bands['bands'][0]
occupations = fix_sirius_xml_prints(occupations)

if len(parsed_bands['occupations']) > 1:
occupations *= 2.

bands_energies = parsed_bands['bands'][0] if len(parsed_bands['bands']) == 1 else parsed_bands['bands']

bands = orm.BandsData()
bands.set_kpointsdata(parsed_kpoints)
Expand Down

0 comments on commit 0d668ec

Please sign in to comment.