Skip to content

Commit

Permalink
Implement support for the REFTRAJ simpulation (#207)
Browse files Browse the repository at this point in the history
- Add optional `trajectory` input of type`TrajectoryData` to the inputs to the cp2k calculation,
which will further be transformed into `aiida-reftraj.xyz` and `aiida-reftraj.cell`.
- Update the restart handler that specifies the `EXT_RESTART` sections explicitly.
- Update the restart handler to make it understand that the MD simulation produced some steps.
- Add an example of a reftraj calculation that also does a restart.

---------
Co-authored-by: Aliaksandr Yakutovich <yakutovicha@gmail.com>
  • Loading branch information
cpignedoli authored Mar 6, 2024
1 parent 01564a7 commit 33fd994
Show file tree
Hide file tree
Showing 6 changed files with 320 additions and 27 deletions.
71 changes: 68 additions & 3 deletions aiida_cp2k/calculations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from operator import add

import numpy as np
from aiida.common import CalcInfo, CodeInfo, InputValidationError
from aiida.engine import CalcJob
from aiida.orm import Dict, RemoteData, SinglefileData
Expand All @@ -25,6 +26,7 @@

BandsData = DataFactory("core.array.bands")
StructureData = DataFactory("core.structure")
TrajectoryData = DataFactory("core.array.trajectory")
KpointsData = DataFactory("core.array.kpoints")


Expand All @@ -44,7 +46,9 @@ class Cp2kCalculation(CalcJob):
_DEFAULT_TRAJECT_FORCES_FILE_NAME = _DEFAULT_PROJECT_NAME + "-frc-1.xyz"
_DEFAULT_TRAJECT_CELL_FILE_NAME = _DEFAULT_PROJECT_NAME + "-1.cell"
_DEFAULT_PARENT_CALC_FLDR_NAME = "parent_calc/"
_DEFAULT_COORDS_FILE_NAME = "aiida.coords.xyz"
_DEFAULT_COORDS_FILE_NAME = _DEFAULT_PROJECT_NAME + ".coords.xyz"
_DEFAULT_INPUT_TRAJECT_XYZ_FILE_NAME = _DEFAULT_PROJECT_NAME + "-reftraj.xyz"
_DEFAULT_INPUT_CELL_FILE_NAME = _DEFAULT_PROJECT_NAME + "-reftraj.cell"
_DEFAULT_PARSER = "cp2k_base_parser"

@classmethod
Expand All @@ -59,6 +63,12 @@ def define(cls, spec):
required=False,
help="The main input structure.",
)
spec.input(
"trajectory",
valid_type=TrajectoryData,
required=False,
help="Input trajectory for a REFTRAJ simulation.",
)
spec.input(
"settings",
valid_type=Dict,
Expand Down Expand Up @@ -219,6 +229,12 @@ def define(cls, spec):
required=False,
help="The relaxed output structure.",
)
spec.output(
"output_trajectory",
valid_type=TrajectoryData,
required=False,
help="The output trajectory.",
)
spec.output(
"output_bands",
valid_type=BandsData,
Expand Down Expand Up @@ -270,6 +286,15 @@ def prepare_for_submission(self, folder):
conflicting_keys=["COORDINATE"],
)

# Create input trajectory files
if "trajectory" in self.inputs:
self._write_trajectories(
self.inputs.trajectory,
folder,
self._DEFAULT_INPUT_TRAJECT_XYZ_FILE_NAME,
self._DEFAULT_INPUT_CELL_FILE_NAME,
)

if "basissets" in self.inputs:
validate_basissets(
inp,
Expand Down Expand Up @@ -388,6 +413,19 @@ def _write_structure(structure, folder, name):
with open(folder.get_abs_path(name), mode="w", encoding="utf-8") as fobj:
fobj.write(xyz)

@staticmethod
def _write_trajectories(trajectory, folder, name_pos, name_cell):
"""Function that writes a structure and takes care of element tags."""

(xyz, cell) = _trajectory_to_xyz_and_cell(trajectory)
with open(folder.get_abs_path(name_pos), mode="w", encoding="utf-8") as fobj:
fobj.write(xyz)
if cell is not None:
with open(
folder.get_abs_path(name_cell), mode="w", encoding="utf-8"
) as fobj:
fobj.write(cell)


def kind_names(atoms):
"""Get atom kind names from ASE atoms based on tags.
Expand All @@ -402,7 +440,7 @@ def kind_names(atoms):
return list(map(add, atoms.get_chemical_symbols(), elem_tags))


def _atoms_to_xyz(atoms):
def _atoms_to_xyz(atoms, infoline="No info"):
"""Converts ASE atoms to string, taking care of element tags.
:param atoms: ASE Atoms instance
Expand All @@ -412,6 +450,33 @@ def _atoms_to_xyz(atoms):
elem_coords = [
f"{p[0]:25.16f} {p[1]:25.16f} {p[2]:25.16f}" for p in atoms.get_positions()
]
xyz = f"{len(elem_coords)}\n\n"
xyz = f"{len(elem_coords)}\n"
xyz += f"{infoline}\n"
xyz += "\n".join(map(add, elem_symbols, elem_coords))
return xyz


def _trajectory_to_xyz_and_cell(trajectory):
"""Converts postions and cell from a TrajectoryData to string, taking care of element tags from ASE atoms.
:param atoms: ASE Atoms instance
:param trajectory: TrajectoryData instance
:returns: positions str (in xyz format) and cell str
"""
cell = None
xyz = ""
stepids = trajectory.get_stepids()
for i, step in enumerate(stepids):
xyz += _atoms_to_xyz(
trajectory.get_step_structure(i).get_ase(),
infoline=f"i = {step+1} , time = {(step+1)*0.5}", # reftraj trajectories cannot start from STEP 0
)
xyz += "\n"
if "cells" in trajectory.get_arraynames():
cell = "# Step Time [fs] Ax [Angstrom] Ay [Angstrom] Az [Angstrom] Bx [Angstrom] By [Angstrom] Bz [Angstrom] Cx [Angstrom] Cy [Angstrom] Cz [Angstrom] Volume [Angstrom^3]\n"
cell_vecs = [
f"{stepid+1} {(stepid+1)*0.5:6.3f} {cellvec[0][0]:25.16f} {cellvec[0][1]:25.16f} {cellvec[0][2]:25.16f} {cellvec[1][0]:25.16f} {cellvec[1][1]:25.16f} {cellvec[1][2]:25.16f} {cellvec[2][0]:25.16f} {cellvec[2][1]:25.16f} {cellvec[2][2]:25.16f} {np.dot(cellvec[0],np.cross(cellvec[1],cellvec[2]))}"
for (stepid, cellvec) in zip(stepids, trajectory.get_array("cells"))
]
cell += "\n".join(cell_vecs)
return xyz, cell
8 changes: 7 additions & 1 deletion aiida_cp2k/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
###############################################################################
"""AiiDA-CP2K utils"""

from .input_generator import Cp2kInput, add_ext_restart_section, add_wfn_restart_section
from .input_generator import (
Cp2kInput,
add_ext_restart_section,
add_first_snapshot_in_reftraj_section,
add_wfn_restart_section,
)
from .parser import parse_cp2k_output, parse_cp2k_output_advanced, parse_cp2k_trajectory
from .workchains import (
HARTREE2EV,
Expand All @@ -23,6 +28,7 @@
__all__ = [
"Cp2kInput",
"add_ext_restart_section",
"add_first_snapshot_in_reftraj_section",
"add_wfn_restart_section",
"parse_cp2k_output",
"parse_cp2k_output_advanced",
Expand Down
19 changes: 18 additions & 1 deletion aiida_cp2k/utils/input_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,5 +211,22 @@ def add_ext_restart_section(input_dict):
"""Add external restart section to the input dictionary."""
params = input_dict.get_dict()
# overwrite the complete EXT_RESTART section if present
params["EXT_RESTART"] = {"RESTART_FILE_NAME": "./parent_calc/aiida-1.restart"}
params["EXT_RESTART"] = {
"RESTART_FILE_NAME": "./parent_calc/aiida-1.restart",
"RESTART_DEFAULT": ".TRUE.",
"RESTART_COUNTERS": ".TRUE.",
"RESTART_POS": ".TRUE.",
"RESTART_VEL": ".TRUE.",
"RESTART_CELL": ".TRUE.",
"RESTART_THERMOSTAT": ".TRUE.",
"RESTART_CONSTRAINT": ".FALSE.",
}
return Dict(params)


@calcfunction
def add_first_snapshot_in_reftraj_section(input_dict, first_snapshot):
"""Add first_snapshot in REFTRAJ section to the input dictionary."""
params = input_dict.get_dict()
params["MOTION"]["MD"]["REFTRAJ"]["FIRST_SNAPSHOT"] = first_snapshot
return Dict(params)
2 changes: 1 addition & 1 deletion aiida_cp2k/utils/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def parse_cp2k_output_advanced(

# If a tag has been detected, now read the following line knowing what they are
if line_is in ["eigen_spin1_au", "eigen_spin2_au"]:
if "------" in line:
if "------" in line or "*** WARNING" in line:
continue
splitted_line = line.split()
try:
Expand Down
42 changes: 21 additions & 21 deletions aiida_cp2k/workchains/base.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,13 @@
"""Base work chain to run a CP2K calculation."""

from aiida.common import AttributeDict
from aiida.engine import (
BaseRestartWorkChain,
ProcessHandlerReport,
process_handler,
while_,
)
from aiida.orm import Bool, Dict
from aiida.plugins import CalculationFactory
from aiida import common, engine, orm, plugins

from ..utils import add_ext_restart_section, add_wfn_restart_section
from .. import utils

Cp2kCalculation = CalculationFactory('cp2k')
Cp2kCalculation = plugins.CalculationFactory('cp2k')


class Cp2kBaseWorkChain(BaseRestartWorkChain):
class Cp2kBaseWorkChain(engine.BaseRestartWorkChain):
"""Workchain to run a CP2K calculation with automated error handling and restarts."""

_process_class = Cp2kCalculation
Expand All @@ -28,7 +20,7 @@ def define(cls, spec):

spec.outline(
cls.setup,
while_(cls.should_run_process)(
engine.while_(cls.should_run_process)(
cls.run_process,
cls.inspect_process,
cls.overwrite_input_structure,
Expand All @@ -37,7 +29,7 @@ def define(cls, spec):
)

spec.expose_outputs(Cp2kCalculation)
spec.output('final_input_parameters', valid_type=Dict, required=False,
spec.output('final_input_parameters', valid_type=orm.Dict, required=False,
help='The input parameters used for the final calculation.')
spec.exit_code(400, 'NO_RESTART_DATA', message="The calculation didn't produce any data to restart from.")
spec.exit_code(300, 'ERROR_UNRECOVERABLE_FAILURE',
Expand All @@ -52,7 +44,7 @@ def setup(self):
internal loop.
"""
super().setup()
self.ctx.inputs = AttributeDict(self.exposed_inputs(Cp2kCalculation, 'cp2k'))
self.ctx.inputs = common.AttributeDict(self.exposed_inputs(Cp2kCalculation, 'cp2k'))

def results(self):
super().results()
Expand All @@ -63,7 +55,7 @@ def overwrite_input_structure(self):
if "output_structure" in self.ctx.children[self.ctx.iteration-1].outputs:
self.ctx.inputs.structure = self.ctx.children[self.ctx.iteration-1].outputs.output_structure

@process_handler(priority=401, exit_codes=[
@engine.process_handler(priority=401, exit_codes=[
Cp2kCalculation.exit_codes.ERROR_OUT_OF_WALLTIME,
Cp2kCalculation.exit_codes.ERROR_OUTPUT_INCOMPLETE,
], enabled=False)
Expand All @@ -72,7 +64,7 @@ def restart_incomplete_calculation(self, calc):
content_string = calc.outputs.retrieved.base.repository.get_object_content(calc.base.attributes.get('output_filename'))

# CP2K was updating geometry - continue with that.
restart_geometry_transformation = "Max. gradient =" in content_string
restart_geometry_transformation = "Max. gradient =" in content_string or "MD| Step number" in content_string
end_inner_scf_loop = "Total energy: " in content_string
# The message is written in the log file when the CP2K input parameter `LOG_PRINT_KEY` is set to True.
if not (restart_geometry_transformation or end_inner_scf_loop or "Writing RESTART" in content_string):
Expand All @@ -81,18 +73,26 @@ def restart_incomplete_calculation(self, calc):
"Sending a signal to stop the Base work chain.")

# Signaling to the base work chain that the problem could not be recovered.
return ProcessHandlerReport(True, self.exit_codes.NO_RESTART_DATA)
return engine.ProcessHandlerReport(True, self.exit_codes.NO_RESTART_DATA)

self.ctx.inputs.parent_calc_folder = calc.outputs.remote_folder
params = self.ctx.inputs.parameters

params = add_wfn_restart_section(params, Bool('kpoints' in self.ctx.inputs))
params = utils.add_wfn_restart_section(params, orm.Bool('kpoints' in self.ctx.inputs))

if restart_geometry_transformation:
params = add_ext_restart_section(params)
# Check if we need to fix restart snapshot in REFTRAJ MD
first_snapshot = None
try:
first_snapshot = int(params['MOTION']['MD']['REFTRAJ']['FIRST_SNAPSHOT']) + calc.outputs.output_trajectory.get_shape('positions')[0]
if first_snapshot:
params = utils.add_first_snapshot_in_reftraj_section(params, first_snapshot)
except KeyError:
pass
params = utils.add_ext_restart_section(params)

self.ctx.inputs.parameters = params # params (new or old ones) that include the necessary restart information.
self.report(
"The CP2K calculation wasn't completed. The restart of the calculation might be able to "
"fix the problem.")
return ProcessHandlerReport(False)
return engine.ProcessHandlerReport(False)
Loading

0 comments on commit 33fd994

Please sign in to comment.