Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ZnTrack Nodes #254

Merged
merged 25 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- name: Install package
run: |
poetry --version
poetry install
poetry install --all-extras

- name: Unit Tests
run: |
Expand Down
4 changes: 4 additions & 0 deletions apax/nodes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .md import ApaxJaxMD
from .model import Apax, ApaxEnsemble

__all__ = ["Apax", "ApaxEnsemble", "ApaxJaxMD"]
87 changes: 87 additions & 0 deletions apax/nodes/md.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import functools
import logging
import pathlib
import typing

import ase.io
import h5py
import yaml
import znh5md
import zntrack.utils

from apax.md.nvt import run_md

from .model import Apax
from .utils import check_duplicate_keys

log = logging.getLogger(__name__)


class ApaxJaxMD(zntrack.Node):
    """Run a JaxMD molecular dynamics simulation driven by an apax model.

    Attributes
    ----------
    data: list[ase.Atoms]
        Structures from which the MD starting configuration is picked.
    data_id: int, default=-1
        Index into ``data`` selecting the starting configuration.
    model: Apax
        Model node; its ``_parameter`` dict is passed to ``run_md``.
    repeat:
        Optional argument forwarded to ``ase.Atoms.repeat``. The starting
        structure is used unchanged when this is ``None``.
    config: str
        Path to the YAML file with the MD simulation parameters.
    sim_dir: pathlib.Path
        Output path injected into the MD parameters as ``sim_dir``.
    init_struc_dir: pathlib.Path
        Output file the starting structure is written to; injected into the
        MD parameters as ``initial_structure``.
    """

    data: list[ase.Atoms] = zntrack.deps()
    data_id: int = zntrack.params(-1)

    model: Apax = zntrack.deps()
    repeat = zntrack.params(None)

    config: str = zntrack.params_path(None)

    sim_dir: pathlib.Path = zntrack.outs_path(zntrack.nwd / "md")
    init_struc_dir: pathlib.Path = zntrack.outs_path(
        zntrack.nwd / "initial_structure.extxyz"
    )

    _parameter: dict = None

    def _post_load_(self) -> None:
        """Build ``_parameter`` from the config file (presumably a zntrack load hook)."""
        self._handle_parameter_file()

    def _handle_parameter_file(self):
        """Read the YAML config and overlay the node-managed output paths."""
        with self.state.use_tmp_path():
            config_text = pathlib.Path(self.config).read_text()
            self._parameter = yaml.safe_load(config_text)

            # These entries are owned by the node and always win over the file.
            custom_parameters = {
                "sim_dir": self.sim_dir.as_posix(),
                "initial_structure": self.init_struc_dir.as_posix(),
            }
            check_duplicate_keys(custom_parameters, self._parameter, log)
            self._parameter.update(custom_parameters)

    def run(self):
        """Write the starting structure to disk and launch the MD simulation."""
        start = self.data[self.data_id]
        structure = start if self.repeat is None else start.repeat(self.repeat)
        ase.io.write(self.init_struc_dir.as_posix(), structure)

        run_md(self.model._parameter, self._parameter)

    @functools.cached_property
    def atoms(self) -> typing.List[ase.Atoms]:
        """Frames read from ``md.h5`` inside ``sim_dir``, opened via ``state.fs``."""

        def _open_h5(filename):
            # Go through the node's filesystem object rather than a local path.
            return h5py.File(self.state.fs.open(filename, "rb"))

        handler_factory = functools.partial(
            znh5md.FormatHandler, file_handle=_open_h5
        )
        reader = znh5md.ASEH5MD(
            self.sim_dir / "md.h5",
            format_handler=handler_factory,
        )
        return reader.get_atoms_list()
147 changes: 147 additions & 0 deletions apax/nodes/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import logging
import pathlib
import typing as t

import ase.io
import pandas as pd
import yaml
import zntrack.utils

from apax.md import ASECalculator
from apax.md.function_transformations import available_transformations
from apax.train.run import run as apax_run

from .utils import check_duplicate_keys

log = logging.getLogger(__name__)


class Apax(zntrack.Node):
    """Node that trains an apax model and exposes it as an ASE calculator.

    Attributes
    ----------
    data: list
        Training structures, written to ``train_data_file`` before training.
    config: str
        Path to the apax YAML config file.
    validation_data:
        Validation structures, written to ``validation_data_file``.
    model: t.Optional[Apax]
        Optional base model; its model directory is registered under
        ``checkpoints.base_model_checkpoint``.
    model_directory: pathlib.Path
        Output directory of the trained model.
    train_data_file: pathlib.Path
        Output path of the training data.
    validation_data_file: pathlib.Path
        Output path of the validation data.
    metrics:
        Last row of ``log.csv`` from the model directory, as a dict.
    """

    data: list = zntrack.deps()
    config: str = zntrack.params_path()
    validation_data = zntrack.deps()
    model: t.Optional[t.Any] = zntrack.deps(None)

    model_directory: pathlib.Path = zntrack.outs_path(zntrack.nwd / "apax_model")

    train_data_file: pathlib.Path = zntrack.outs_path(zntrack.nwd / "train_atoms.extxyz")
    validation_data_file: pathlib.Path = zntrack.outs_path(
        zntrack.nwd / "val_atoms.extxyz"
    )

    metrics = zntrack.metrics()

    _parameter: dict = None

    def _post_load_(self) -> None:
        """Build ``_parameter`` from the config file (presumably a zntrack load hook)."""
        self._handle_parameter_file()

    def _handle_parameter_file(self):
        """Read the YAML config and overlay the node-managed ``data`` entries."""
        with self.state.use_tmp_path():
            config_file = pathlib.Path(self.config)
            self._parameter = yaml.safe_load(config_file.read_text())

            # Entries owned by the node; they replace whatever the file says.
            custom_parameters = {
                "directory": self.model_directory.resolve().as_posix(),
                "experiment": "",
                "train_data_path": self.train_data_file.as_posix(),
                "val_data_path": self.validation_data_file.as_posix(),
            }

            if self.model is not None:
                base_path = {
                    "base_model_checkpoint": self.model._parameter["data"]["directory"]
                }
                try:
                    self._parameter["checkpoints"].update(base_path)
                except KeyError:
                    # The config has no ``checkpoints`` section yet.
                    self._parameter["checkpoints"] = base_path

            data_section = self._parameter["data"]
            check_duplicate_keys(custom_parameters, data_section, log)
            data_section.update(custom_parameters)

    def train_model(self):
        """Hand the assembled parameters to ``apax.train.run.run``."""
        apax_run(self._parameter)

    def get_metrics_from_plots(self):
        """Store the final row of the training ``log.csv`` as the node metrics."""
        log_file = self.model_directory / "log.csv"
        final_row = pd.read_csv(log_file).iloc[-1]
        self.metrics = final_row.to_dict()

    def run(self):
        """Write both data sets to disk, train the model, then collect metrics."""
        ase.io.write(self.train_data_file, self.data)
        ase.io.write(self.validation_data_file, self.validation_data)

        self.train_model()
        self.get_metrics_from_plots()

    def get_calculator(self, **kwargs):
        """Return an ``ASECalculator`` for the model directory (``kwargs`` are unused)."""
        with self.state.use_tmp_path():
            calculator = ASECalculator(model_dir=self.model_directory)
        return calculator


class ApaxEnsemble(zntrack.Node):
    """Combine several apax models into a single ASE calculator.

    Attributes
    ----------
    models: list
        ``Apax`` nodes whose model directories form the ensemble.
    nl_skin: float
        Neighborlist skin, passed to the calculator as ``dr``.
    transformations: dict
        Maps a transformation name (a key of ``available_transformations``)
        to the keyword arguments used to construct it.
        See the apax documentation for available methods.
    """

    models: list[Apax] = zntrack.deps()
    nl_skin: float = zntrack.params(0.5)
    transformations: dict[str, dict] = zntrack.params(None)

    def run(self) -> None:
        """Nothing to compute; the node only bundles the models."""
        pass

    def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator:
        """Build the ensemble ASE calculator (``kwargs`` are unused).

        Returns
        -------
        calc:
            ase calculator object
        """
        model_dirs = [node._parameter["data"]["directory"] for node in self.models]

        # ``None`` or an empty mapping yields no transformations.
        requested = self.transformations or {}
        transformations = [
            available_transformations[name](**options)
            for name, options in requested.items()
        ]

        return ASECalculator(
            model_dirs,
            dr=self.nl_skin,
            transformations=transformations,
        )
30 changes: 30 additions & 0 deletions apax/nodes/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import functools
import logging

import ase.io
import zntrack


class AddData(zntrack.Node):
    """Node exposing the leading frames of a structure file as ``ase.Atoms``.

    Attributes
    ----------
    file: str
        Path of the file read with ``ase.io.iread``.
    """

    file: str = zntrack.deps_path()

    def run(self):
        """Nothing to compute; the frames are read lazily by ``atoms``."""
        pass

    @functools.cached_property
    def atoms(self) -> list[ase.Atoms]:
        """Return at most the first 50 frames of ``file``.

        Files holding fewer than 50 frames yield all of their frames. The
        loop previously only returned once the 50th frame was reached, so
        shorter files produced ``None`` instead of a list.
        """
        data = []
        for atoms in ase.io.iread(self.file):
            data.append(atoms)
            if len(data) == 50:
                # Enough frames collected; stop reading the file.
                break
        return data


def check_duplicate_keys(dict_a: dict, dict_b: dict, log: logging.Logger) -> None:
    """Log a warning for every key of ``dict_a`` that also exists in ``dict_b``.

    Neither dictionary is modified; the function only reports the overlap so
    the caller can announce which config entries it is about to overwrite.

    Parameters
    ----------
    dict_a: dict
        Entries that are going to be written (the overriding values).
    dict_b: dict
        Entries loaded from the user's config file.
    log: logging.Logger
        Logger that receives one warning per duplicated key.
    """
    for key in dict_a:
        if key in dict_b:
            # The message used to name "MLSuite", a different project.
            # Lazy %-arguments defer formatting to the logger.
            log.warning(
                "Found <%s> in given config file. Please be aware that <%s>"
                " will be overwritten by apax!",
                key,
                key,
            )
Loading
Loading