Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 14, 2024
1 parent a64aae1 commit 9fb92c5
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 25 deletions.
14 changes: 10 additions & 4 deletions aiida_nanotech_empa/workflows/cp2k/cp2k_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ def determine_kinds(structure, magnetization_per_site=None, ghost_per_site=None)
combined[symbol + "_0_0"] = 0
else:
tag += 1
combined[symbol + "_" + str(mag_ghost[0]) + "_" + str(mag_ghost[1])] = (
tag
)
combined[
symbol + "_" + str(mag_ghost[0]) + "_" + str(mag_ghost[1])
] = tag

# Assigning correct tags to every atom.
tags1 = [combined[key] for key in complex_symbols]
Expand Down Expand Up @@ -205,7 +205,13 @@ def get_dft_inputs(dft_params, structure, template, protocol):
# number of atoms
if isinstance(structure, orm.TrajectoryData):
natoms = structure.get_shape("positions")[1]
structure = orm.StructureData(ase=ase.Atoms(structure.symbols, positions=structure.get_array('positions')[0],cell=(structure.get_array('cells')[0])))
structure = orm.StructureData(
ase=ase.Atoms(
structure.symbols,
positions=structure.get_array("positions")[0],
cell=(structure.get_array("cells")[0]),
)
)
natoms = len(structure.sites)

# Load input template.
Expand Down
39 changes: 21 additions & 18 deletions aiida_nanotech_empa/workflows/cp2k/reftraj_md_workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,17 @@
from . import cp2k_utils

Cp2kBaseWorkChain = plugins.WorkflowFactory("cp2k.base")
#Cp2kRefTrajWorkChain = plugins.WorkflowFactory("cp2k.reftraj")
# Cp2kRefTrajWorkChain = plugins.WorkflowFactory("cp2k.reftraj")
TrajectoryData = plugins.DataFactory("array.trajectory")


@engine.calcfunction
def create_batches(trajectory, num_batches,steps_completed):
"""Create lists of consecutive integers. Counting start from 1 for CP2K input.
"""
lst=[i+1 for i in range(trajectory.get_shape('positions')[0])]
def create_batches(trajectory, num_batches, steps_completed):
"""Create lists of consecutive integers. Counting start from 1 for CP2K input."""
lst = [i + 1 for i in range(trajectory.get_shape("positions")[0])]
for i in steps_completed:
lst.remove(i)
max_batch_size = int(len(lst)/num_batches)
max_batch_size = int(len(lst) / num_batches)
consecutive_lists = []
current_list = []
for num in lst:
Expand All @@ -45,7 +44,7 @@ def define(cls, spec):

# Define the inputs of the workflow.
spec.input("code", valid_type=orm.Code)
#spec.input("structure", valid_type=orm.StructureData)
# spec.input("structure", valid_type=orm.StructureData)
spec.input("trajectory", valid_type=TrajectoryData)
spec.input("num_batches", valid_type=orm.Int, default=lambda: orm.Int(10))
spec.input("parent_calc_folder", valid_type=orm.RemoteData, required=False)
Expand Down Expand Up @@ -94,10 +93,12 @@ def setup(self):
)
self.ctx.input_dict["GLOBAL"]["WALLTIME"] = max(
600, self.inputs.options["max_wallclock_seconds"] - 600
)
)
self.ctx.steps_completed = []
# create batches avoiding steps already completed.
self.ctx.batches = create_batches(self.inputs.trajectory, self.inputs.num_batches,self.ctx.steps_completed).get_list()
self.ctx.batches = create_batches(
self.inputs.trajectory, self.inputs.num_batches, self.ctx.steps_completed
).get_list()
return engine.ExitCode(0)

def first_structure(self):
Expand All @@ -106,12 +107,12 @@ def first_structure(self):
batches = self.ctx.batches
first_snapshot = batches[0].pop(0)
self.ctx.batches = batches

self.report(f"Running structure {first_snapshot} to {first_snapshot} ")

input_dict["MOTION"]["MD"]["REFTRAJ"]["FIRST_SNAPSHOT"] = first_snapshot
input_dict["MOTION"]["MD"]["REFTRAJ"]["LAST_SNAPSHOT"] = first_snapshot

# create the input for the reftraj workchain
builder = Cp2kBaseWorkChain.get_builder()
builder.cp2k.structure = orm.StructureData(ase=self.ctx.structure_with_tags)
Expand All @@ -127,7 +128,9 @@ def first_structure(self):
builder.cp2k.parameters = orm.Dict(dict=input_dict)

future = self.submit(builder)
self.report(f"Submitted structures {first_snapshot} to {first_snapshot}: {future.pk}")
self.report(
f"Submitted structures {first_snapshot} to {first_snapshot}: {future.pk}"
)
self.to_context(first_structure=future)

def run_reftraj_batches(self):
Expand All @@ -154,18 +157,18 @@ def run_reftraj_batches(self):
builder.cp2k.parent_calc_folder = (
self.ctx.first_structure.outputs.remote_folder
)

future = self.submit(builder)

key = f"reftraj_batch_{batch[0]}_to_{batch[-1]}"
self.report(f"Submitted reftraj batch: {key} with pk: {future.pk}")

self.to_context(**{key: engine.append_(future)})

def merge_batches_output(self):
"""Merge the output of the succefull batches only."""
self.report("done")
#merged_traj = []
#for i_batch in range(self.ctx.n_batches):
# merged_traj = []
# for i_batch in range(self.ctx.n_batches):
# merged_traj.extend(self.ctx[f"reftraj_batch_{i_batch}"].outputs.trajectory)
return engine.ExitCode(0)
6 changes: 3 additions & 3 deletions examples/workflows/example_cp2k_md_reftraj.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@


def _example_cp2k_reftraj(cp2k_code):
thisdir = os.path.dirname(os.path.realpath(__file__))
os.path.dirname(os.path.realpath(__file__))

# Structure.
#structure = StructureData(ase=ase.io.read(os.path.join(thisdir, ".", "h2.xyz")))
# structure = StructureData(ase=ase.io.read(os.path.join(thisdir, ".", "h2.xyz")))

# Trajectory.
steps = 20
Expand Down Expand Up @@ -54,7 +54,7 @@ def _example_cp2k_reftraj(cp2k_code):
},
}

#builder.structure = structure
# builder.structure = structure
builder.trajectory = trajectory
builder.num_batches = orm.Int(2)
builder.protocol = orm.Str("debug")
Expand Down

0 comments on commit 9fb92c5

Please sign in to comment.