Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed Apr 5, 2024
1 parent e8b9af9 commit 2ce3c06
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 22 deletions.
38 changes: 20 additions & 18 deletions apax/nodes/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import yaml
import znh5md
import zntrack.utils
from zntrack import dvc, zn

from apax.md.nvt import run_md

from .model import Apax
from .utils import check_duplicate_keys
Expand All @@ -21,55 +22,56 @@ class ApaxJaxMD(zntrack.Node):
Attributes
----------
data: list[ase.Atoms]
MD starting structure
data_id: int, default=-1
index of the configuration from the data list to use
model: ApaxModel
model to use for the simulation
repeat: float
number of repeats
md_parameter: dict
parameter for the MD simulation
md_parameter_file: str
config: str
path to the MD simulation parameter file
"""

data: list[ase.Atoms] = zntrack.deps()
data_id: int = zntrack.params(-1)

model: Apax = zntrack.deps()
repeat = zn.params(None)
repeat = zntrack.params(None)

config: str = zntrack.params_path(None)

md_parameter: dict = zn.params(None)
md_parameter_file: str = dvc.params(None)
sim_dir: pathlib.Path = zntrack.outs_path(zntrack.nwd / "md")
init_struc_dir: pathlib.Path = zntrack.outs_path(
zntrack.nwd / "initial_structure.extxyz"
)

sim_dir: pathlib.Path = dvc.outs(zntrack.nwd / "md")
init_struc_dir: pathlib.Path = dvc.outs(zntrack.nwd / "initial_structure.extxyz")
_parameter: dict = None

def _post_load_(self) -> None:
self._handle_parameter_file()

def _handle_parameter_file(self):
if self.md_parameter_file:
md_parameter_file_content = pathlib.Path(self.md_parameter_file).read_text()
self.md_parameter = yaml.safe_load(md_parameter_file_content)
with self.state.use_tmp_path():
self._parameter = yaml.safe_load(pathlib.Path(self.config).read_text())

custom_parameters = {
"sim_dir": self.sim_dir.as_posix(),
"initial_structure": self.init_struc_dir.as_posix(),
}
check_duplicate_keys(custom_parameters, self.md_parameter, log)
self.md_parameter.update(custom_parameters)
check_duplicate_keys(custom_parameters, self._parameter, log)
self._parameter.update(custom_parameters)

def run(self):
"""Primary method to run which executes all steps of the model training"""
from apax.md.nvt import run_md

# self._handle_parameter_file()
atoms = self.data[self.data_id]
if self.repeat is not None:
atoms = atoms.repeat(self.repeat)
ase.io.write(self.init_struc_dir.as_posix(), atoms)

self.model._handle_parameter_file()
run_md(self.model._parameter, self.md_parameter)
run_md(self.model._parameter, self._parameter)

@functools.cached_property
def atoms(self) -> typing.List[ase.Atoms]:
Expand Down
10 changes: 7 additions & 3 deletions apax/nodes/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,18 @@ class Apax(zntrack.Node):
----------
config: str
path to the apax config file
validation_data: ase.Atoms
data: list[ase.Atoms]
the training data set
validation_data: list[ase.Atoms]
atoms object with the validation data set
model: t.Optional[Apax]
model to be used as a base model
model_directory: pathlib.Path
model directory
train_data_file: pathlib.Path
path to the training data
output path to the training data
validation_data_file: pathlib.Path
path to the valdidation data
output path to the validation data
"""

data: list = zntrack.deps()
Expand Down
2 changes: 1 addition & 1 deletion tests/nodes/test_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_n_jax_md(tmp_path, get_md22_stachyose):
with proj:
data = AddData(file=get_md22_stachyose)
model = Apax(data=data.atoms, validation_data=data.atoms, config="example.yaml")
md = ApaxJaxMD(model=model, md_parameter_file="md.yaml", data=data.atoms)
md = ApaxJaxMD(model=model, config="md.yaml", data=data.atoms)

proj.run()

Expand Down

0 comments on commit 2ce3c06

Please sign in to comment.