Skip to content

Commit

Permalink
Merge branch 'dev' into nl_fix
Browse files Browse the repository at this point in the history
  • Loading branch information
M-R-Schaefer authored Apr 8, 2024
2 parents 0a9ec1c + 286f80e commit 631e34a
Show file tree
Hide file tree
Showing 20 changed files with 2,425 additions and 209 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- name: Install package
run: |
poetry --version
poetry install
poetry install --all-extras
- name: Unit Tests
run: |
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/psf/black
rev: 24.1.1
rev: 24.3.0
hooks:
- id: black
exclude: ^apax/utils/jax_md_reduced/
Expand Down
4 changes: 4 additions & 0 deletions apax/nodes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .md import ApaxJaxMD
from .model import Apax, ApaxEnsemble

__all__ = ["Apax", "ApaxEnsemble", "ApaxJaxMD"]
87 changes: 87 additions & 0 deletions apax/nodes/md.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import functools
import logging
import pathlib
import typing

import ase.io
import h5py
import yaml
import znh5md
import zntrack.utils

from apax.md.nvt import run_md

from .model import Apax
from .utils import check_duplicate_keys

log = logging.getLogger(__name__)


class ApaxJaxMD(zntrack.Node):
"""Class to run a more performant JaxMD simulation with a apax Model.
Attributes
----------
data: list[ase.Atoms]
MD starting structure
data_id: int, default=-1
index of the configuration from the data list to use
model: ApaxModel
model to use for the simulation
repeat: float
number of repeats
config: str
path to the MD simulation parameter file
"""

data: list[ase.Atoms] = zntrack.deps()
data_id: int = zntrack.params(-1)

model: Apax = zntrack.deps()
repeat = zntrack.params(None)

config: str = zntrack.params_path(None)

sim_dir: pathlib.Path = zntrack.outs_path(zntrack.nwd / "md")
init_struc_dir: pathlib.Path = zntrack.outs_path(
zntrack.nwd / "initial_structure.extxyz"
)

_parameter: dict = None

def _post_load_(self) -> None:
self._handle_parameter_file()

def _handle_parameter_file(self):
with self.state.use_tmp_path():
self._parameter = yaml.safe_load(pathlib.Path(self.config).read_text())

custom_parameters = {
"sim_dir": self.sim_dir.as_posix(),
"initial_structure": self.init_struc_dir.as_posix(),
}
check_duplicate_keys(custom_parameters, self._parameter, log)
self._parameter.update(custom_parameters)

def run(self):
"""Primary method to run which executes all steps of the model training"""

atoms = self.data[self.data_id]
if self.repeat is not None:
atoms = atoms.repeat(self.repeat)
ase.io.write(self.init_struc_dir.as_posix(), atoms)

run_md(self.model._parameter, self._parameter)

@functools.cached_property
def atoms(self) -> typing.List[ase.Atoms]:
def file_handle(filename):
file = self.state.fs.open(filename, "rb")
return h5py.File(file)

return znh5md.ASEH5MD(
self.sim_dir / "md.h5",
format_handler=functools.partial(
znh5md.FormatHandler, file_handle=file_handle
),
).get_atoms_list()
147 changes: 147 additions & 0 deletions apax/nodes/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import logging
import pathlib
import typing as t

import ase.io
import pandas as pd
import yaml
import zntrack.utils

from apax.md import ASECalculator
from apax.md.function_transformations import available_transformations
from apax.train.run import run as apax_run

from .utils import check_duplicate_keys

log = logging.getLogger(__name__)


class Apax(zntrack.Node):
"""Class for the implementation of the apax model
Attributes
----------
config: str
path to the apax config file
data: list[ase.Atoms]
the training data set
validation_data: list[ase.Atoms]
atoms object with the validation data set
model: t.Optional[Apax]
model to be used as a base model
model_directory: pathlib.Path
model directory
train_data_file: pathlib.Path
output path to the training data
validation_data_file: pathlib.Path
output path to the validation data
"""

data: list = zntrack.deps()
config: str = zntrack.params_path()
validation_data = zntrack.deps()
model: t.Optional[t.Any] = zntrack.deps(None)

model_directory: pathlib.Path = zntrack.outs_path(zntrack.nwd / "apax_model")

train_data_file: pathlib.Path = zntrack.outs_path(zntrack.nwd / "train_atoms.extxyz")
validation_data_file: pathlib.Path = zntrack.outs_path(
zntrack.nwd / "val_atoms.extxyz"
)

metrics = zntrack.metrics()

_parameter: dict = None

def _post_load_(self) -> None:
self._handle_parameter_file()

def _handle_parameter_file(self):
with self.state.use_tmp_path():
self._parameter = yaml.safe_load(pathlib.Path(self.config).read_text())

custom_parameters = {
"directory": self.model_directory.resolve().as_posix(),
"experiment": "",
"train_data_path": self.train_data_file.as_posix(),
"val_data_path": self.validation_data_file.as_posix(),
}

if self.model is not None:
param_files = self.model._parameter["data"]["directory"]
base_path = {"base_model_checkpoint": param_files}
try:
self._parameter["checkpoints"].update(base_path)
except KeyError:
self._parameter["checkpoints"] = base_path

check_duplicate_keys(custom_parameters, self._parameter["data"], log)
self._parameter["data"].update(custom_parameters)

def train_model(self):
"""Train the model using `apax.train.run`"""
apax_run(self._parameter)

def get_metrics_from_plots(self):
"""In addition to the plots write a model metric"""
metrics_df = pd.read_csv(self.model_directory / "log.csv")
self.metrics = metrics_df.iloc[-1].to_dict()

def run(self):
"""Primary method to run which executes all steps of the model training"""
ase.io.write(self.train_data_file, self.data)
ase.io.write(self.validation_data_file, self.validation_data)

self.train_model()
self.get_metrics_from_plots()

def get_calculator(self, **kwargs):
"""Get an apax ase calculator"""
with self.state.use_tmp_path():
return ASECalculator(model_dir=self.model_directory)


class ApaxEnsemble(zntrack.Node):
"""Parallel apax model ensemble in ASE.
Attributes
----------
models: list
List of `ApaxModel` nodes to ensemble.
nl_skin: float
Neighborlist skin.
transformations: dict
Key-parameter dict with function transformations applied
to the model function within the ASE calculator.
See the apax documentation for available methods.
"""

models: list[Apax] = zntrack.deps()
nl_skin: float = zntrack.params(0.5)
transformations: dict[str, dict] = zntrack.params(None)

def run(self) -> None:
pass

def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator:
"""Property to return a model specific ase calculator object.
Returns
-------
calc:
ase calculator object
"""

param_files = [m._parameter["data"]["directory"] for m in self.models]

transformations = []
if self.transformations:
for transform, params in self.transformations.items():
transformations.append(available_transformations[transform](**params))

calc = ASECalculator(
param_files,
dr=self.nl_skin,
transformations=transformations,
)
return calc
30 changes: 30 additions & 0 deletions apax/nodes/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import functools
import logging

import ase.io
import zntrack


class AddData(zntrack.Node):
file: str = zntrack.deps_path()

def run(self):
pass

@functools.cached_property
def atoms(self) -> list[ase.Atoms]:
data = []
for atoms in ase.io.iread(self.file):
data.append(atoms)
if len(data) == 50:
return data


def check_duplicate_keys(dict_a: dict, dict_b: dict, log: logging.Logger) -> None:
"""Check if a key of dict_a is present in dict_b and then log a warning."""
for key in dict_a:
if key in dict_b:
log.warning(
f"Found <{key}> in given config file. Please be aware that <{key}>"
" will be overwritten by MLSuite!"
)
10 changes: 6 additions & 4 deletions apax/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,12 @@ def fit(
epoch_loss["val_loss"] /= val_steps_per_epoch
epoch_loss["val_loss"] = float(epoch_loss["val_loss"])

epoch_metrics.update({
f"val_{key}": float(val)
for key, val in val_batch_metrics.compute().items()
})
epoch_metrics.update(
{
f"val_{key}": float(val)
for key, val in val_batch_metrics.compute().items()
}
)

epoch_metrics.update({**epoch_loss})

Expand Down
2 changes: 1 addition & 1 deletion apax/utils/jax_md_reduced/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ def neighbor_fn(position_and_error, max_occupancy=None):
if not is_sparse(format):
capacity_limit = N - 1 if mask_self else N
elif format is NeighborListFormat.Sparse:
capacity_limit = N * (N - 1) if mask_self else N ** 2
capacity_limit = N * (N - 1) if mask_self else N**2
else:
capacity_limit = N * (N - 1) // 2
if max_occupancy > capacity_limit:
Expand Down
Loading

0 comments on commit 631e34a

Please sign in to comment.