Skip to content

Commit

Permalink
Merge branch 'dev' into sharding
Browse files Browse the repository at this point in the history
  • Loading branch information
M-R-Schaefer authored Apr 8, 2024
2 parents c119992 + 1c483d0 commit 7b9e29b
Show file tree
Hide file tree
Showing 24 changed files with 2,472 additions and 245 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- name: Install package
run: |
poetry --version
poetry install
poetry install --all-extras
- name: Unit Tests
run: |
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/psf/black
rev: 24.1.1
rev: 24.3.0
hooks:
- id: black
exclude: ^apax/utils/jax_md_reduced/
Expand Down
28 changes: 12 additions & 16 deletions apax/data/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@ def pad_nl(idx, offsets, max_neighbors):
return idx, offsets


def find_largest_system(inputs: dict[str, np.ndarray], r_max) -> tuple[int]:
def find_largest_system(inputs, r_max) -> tuple[int]:
positions, boxes = inputs["positions"], inputs["box"]
max_atoms = np.max(inputs["n_atoms"])

max_nbrs = 0
for position, box in zip(inputs["positions"], inputs["box"]):
neighbor_idxs, _ = compute_nl(position, box, r_max)
for pos, box in zip(positions, boxes):
neighbor_idxs, _ = compute_nl(pos, box, r_max)
n_neighbors = neighbor_idxs.shape[1]
max_nbrs = max(max_nbrs, n_neighbors)

Expand All @@ -38,7 +39,7 @@ def find_largest_system(inputs: dict[str, np.ndarray], r_max) -> tuple[int]:
class InMemoryDataset:
def __init__(
self,
atoms,
atoms_list,
cutoff,
bs,
n_epochs,
Expand All @@ -50,26 +51,24 @@ def __init__(
ignore_labels=False,
cache_path=".",
) -> None:

self.n_epochs = n_epochs
self.cutoff = cutoff
self.n_jit_steps = n_jit_steps
self.buffer_size = buffer_size
self.n_data = len(atoms)
self.n_data = len(atoms_list)
self.batch_size = self.validate_batch_size(bs)
self.pos_unit = pos_unit

if pre_shuffle:
shuffle(atoms)
self.sample_atoms = atoms[0]
self.inputs = atoms_to_inputs(atoms, self.pos_unit)
shuffle(atoms_list)
self.sample_atoms = atoms_list[0]
self.inputs = atoms_to_inputs(atoms_list, pos_unit)

max_atoms, max_nbrs = find_largest_system(self.inputs, self.cutoff)
self.max_atoms = max_atoms
self.max_nbrs = max_nbrs

if atoms[0].calc and not ignore_labels:
self.labels = atoms_to_labels(atoms, self.pos_unit, energy_unit)
if atoms_list[0].calc and not ignore_labels:
self.labels = atoms_to_labels(atoms_list, pos_unit, energy_unit)
else:
self.labels = None

Expand Down Expand Up @@ -109,9 +108,6 @@ def prepare_data(self, i):
inputs["numbers"] = np.pad(
inputs["numbers"], (0, zeros_to_add), "constant"
).astype(np.int16)
inputs["n_atoms"] = np.pad(
inputs["n_atoms"], (0, zeros_to_add), "constant"
).astype(np.int16)

if not self.labels:
return inputs
Expand All @@ -121,7 +117,6 @@ def prepare_data(self, i):
labels["forces"] = np.pad(
labels["forces"], ((0, zeros_to_add), (0, 0)), "constant"
)

inputs = {k: tf.constant(v) for k, v in inputs.items()}
labels = {k: tf.constant(v) for k, v in labels.items()}
return (inputs, labels)
Expand Down Expand Up @@ -170,6 +165,7 @@ def init_input(self) -> Dict[str, np.ndarray]:
"""Returns first batch of inputs and labels to init the model."""
positions = self.sample_atoms.positions * unit_dict[self.pos_unit]
box = self.sample_atoms.cell.array * unit_dict[self.pos_unit]
# For an input sample, it does not matter whether pos is fractional or cartesian
idx, offsets = compute_nl(positions, box, self.cutoff)
inputs = (
positions,
Expand Down
24 changes: 13 additions & 11 deletions apax/data/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,33 @@
log = logging.getLogger(__name__)


def compute_nl(positions, box, r_max):
    """Compute the neighbor list for a single structure.

    For periodic systems, positions are assumed to be in fractional
    coordinates and are converted to cartesian via ``positions @ box``.

    Parameters
    ----------
    positions : np.ndarray
        Atomic positions, shape (n_atoms, 3). Fractional coordinates for
        periodic cells, cartesian otherwise.
    box : np.ndarray
        Cell matrix. A box with all entries below 1e-6 is treated as a
        non-periodic system.
    r_max : float
        Neighbor cutoff radius.

    Returns
    -------
    neighbor_idxs : np.ndarray
        Pair index array ``[idxs_i, idxs_j]``, shape (2, n_pairs), dtype int16.
    offsets : np.ndarray
        Cartesian periodic-image shift vectors, shape (n_pairs, 3);
        all zeros for non-periodic systems.
    """
    if np.all(box < 1e-6):
        # Non-periodic system: build a shrink-wrapped cell around the
        # positions so neighbour_list can run without periodic boundaries.
        box, box_origin = get_shrink_wrapped_cell(positions)
        idxs_i, idxs_j = neighbour_list(
            "ij",
            positions=positions,
            cutoff=r_max,
            cell=box,
            cell_origin=box_origin,
            pbc=[False, False, False],
        )

        # NOTE(review): int16 indices overflow for systems with more than
        # 32767 atoms — confirm this limit is acceptable for intended use.
        neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int16)

        n_neighbors = neighbor_idxs.shape[1]
        # No periodic images, so all offsets are zero.
        offsets = np.full([n_neighbors, 3], 0)
    else:
        # Periodic system: convert fractional -> cartesian coordinates.
        positions = positions @ box
        idxs_i, idxs_j, offsets = neighbour_list(
            "ijS", positions=positions, cutoff=r_max, cell=box, pbc=[True, True, True]
        )
        neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int16)
        # Convert integer image shifts ("S") into cartesian offset vectors.
        offsets = np.matmul(offsets, box)
    return neighbor_idxs, offsets

Expand Down
4 changes: 4 additions & 0 deletions apax/nodes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .md import ApaxJaxMD
from .model import Apax, ApaxEnsemble

__all__ = ["Apax", "ApaxEnsemble", "ApaxJaxMD"]
87 changes: 87 additions & 0 deletions apax/nodes/md.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import functools
import logging
import pathlib
import typing

import ase.io
import h5py
import yaml
import znh5md
import zntrack.utils

from apax.md.nvt import run_md

from .model import Apax
from .utils import check_duplicate_keys

log = logging.getLogger(__name__)


class ApaxJaxMD(zntrack.Node):
    """Run a more performant JaxMD simulation with an apax model.

    Attributes
    ----------
    data: list[ase.Atoms]
        MD starting structure
    data_id: int, default=-1
        index of the configuration from the data list to use
    model: Apax
        model to use for the simulation
    repeat: float
        number of repeats
    config: str
        path to the MD simulation parameter file
    """

    # zntrack field declarations: their kind (deps/params/outs_path) defines
    # this node's position in the DVC dependency graph.
    data: list[ase.Atoms] = zntrack.deps()
    data_id: int = zntrack.params(-1)

    model: Apax = zntrack.deps()
    repeat = zntrack.params(None)

    config: str = zntrack.params_path(None)

    # Output locations: the simulation directory and the (possibly repeated)
    # starting structure actually fed to the MD run.
    sim_dir: pathlib.Path = zntrack.outs_path(zntrack.nwd / "md")
    init_struc_dir: pathlib.Path = zntrack.outs_path(
        zntrack.nwd / "initial_structure.extxyz"
    )

    # Merged YAML parameters; populated by _handle_parameter_file.
    _parameter: dict = None

    def _post_load_(self) -> None:
        # Re-create the merged parameter dict whenever the node is loaded.
        self._handle_parameter_file()

    def _handle_parameter_file(self):
        """Load the YAML config and overwrite its output paths with this node's."""
        with self.state.use_tmp_path():
            self._parameter = yaml.safe_load(pathlib.Path(self.config).read_text())

        custom_parameters = {
            "sim_dir": self.sim_dir.as_posix(),
            "initial_structure": self.init_struc_dir.as_posix(),
        }
        # Warn (via log) if the user config already set keys we overwrite.
        check_duplicate_keys(custom_parameters, self._parameter, log)
        self._parameter.update(custom_parameters)

    def run(self):
        """Write the starting structure and execute the JaxMD simulation."""

        atoms = self.data[self.data_id]
        if self.repeat is not None:
            atoms = atoms.repeat(self.repeat)
        ase.io.write(self.init_struc_dir.as_posix(), atoms)

        # The model node supplies its own merged training parameters.
        run_md(self.model._parameter, self._parameter)

    @functools.cached_property
    def atoms(self) -> typing.List[ase.Atoms]:
        """Return the MD trajectory stored in ``sim_dir/md.h5`` as a list of Atoms."""

        def file_handle(filename):
            # Open through zntrack's filesystem abstraction (works for remotes).
            file = self.state.fs.open(filename, "rb")
            return h5py.File(file)

        return znh5md.ASEH5MD(
            self.sim_dir / "md.h5",
            format_handler=functools.partial(
                znh5md.FormatHandler, file_handle=file_handle
            ),
        ).get_atoms_list()
147 changes: 147 additions & 0 deletions apax/nodes/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import logging
import pathlib
import typing as t

import ase.io
import pandas as pd
import yaml
import zntrack.utils

from apax.md import ASECalculator
from apax.md.function_transformations import available_transformations
from apax.train.run import run as apax_run

from .utils import check_duplicate_keys

log = logging.getLogger(__name__)


class Apax(zntrack.Node):
    """Class for the implementation of the apax model

    Attributes
    ----------
    config: str
        path to the apax config file
    data: list[ase.Atoms]
        the training data set
    validation_data: list[ase.Atoms]
        atoms object with the validation data set
    model: t.Optional[Apax]
        model to be used as a base model (transfer learning / fine-tuning)
    model_directory: pathlib.Path
        model directory
    train_data_file: pathlib.Path
        output path to the training data
    validation_data_file: pathlib.Path
        output path to the validation data
    """

    # zntrack field declarations: their kind (deps/params/outs_path) defines
    # this node's position in the DVC dependency graph.
    data: list = zntrack.deps()
    config: str = zntrack.params_path()
    validation_data = zntrack.deps()
    model: t.Optional[t.Any] = zntrack.deps(None)

    model_directory: pathlib.Path = zntrack.outs_path(zntrack.nwd / "apax_model")

    train_data_file: pathlib.Path = zntrack.outs_path(zntrack.nwd / "train_atoms.extxyz")
    validation_data_file: pathlib.Path = zntrack.outs_path(
        zntrack.nwd / "val_atoms.extxyz"
    )

    metrics = zntrack.metrics()

    # Merged YAML parameters; populated by _handle_parameter_file.
    _parameter: dict = None

    def _post_load_(self) -> None:
        # Re-create the merged parameter dict whenever the node is loaded.
        self._handle_parameter_file()

    def _handle_parameter_file(self):
        """Load the YAML config and inject node-managed paths and checkpoints."""
        with self.state.use_tmp_path():
            self._parameter = yaml.safe_load(pathlib.Path(self.config).read_text())

        custom_parameters = {
            "directory": self.model_directory.resolve().as_posix(),
            "experiment": "",
            "train_data_path": self.train_data_file.as_posix(),
            "val_data_path": self.validation_data_file.as_posix(),
        }

        if self.model is not None:
            # Point the training run at the base model's directory so it can
            # start from that checkpoint.
            param_files = self.model._parameter["data"]["directory"]
            base_path = {"base_model_checkpoint": param_files}
            try:
                self._parameter["checkpoints"].update(base_path)
            except KeyError:
                # Config had no "checkpoints" section yet.
                self._parameter["checkpoints"] = base_path

        # Warn (via log) if the user config already set keys we overwrite.
        check_duplicate_keys(custom_parameters, self._parameter["data"], log)
        self._parameter["data"].update(custom_parameters)

    def train_model(self):
        """Train the model using `apax.train.run`"""
        apax_run(self._parameter)

    def get_metrics_from_plots(self):
        """In addition to the plots write a model metric"""
        # Last row of the training log = metrics of the final epoch.
        metrics_df = pd.read_csv(self.model_directory / "log.csv")
        self.metrics = metrics_df.iloc[-1].to_dict()

    def run(self):
        """Primary method to run which executes all steps of the model training"""
        ase.io.write(self.train_data_file, self.data)
        ase.io.write(self.validation_data_file, self.validation_data)

        self.train_model()
        self.get_metrics_from_plots()

    def get_calculator(self, **kwargs):
        """Get an apax ase calculator

        Note: **kwargs is accepted for interface compatibility but is
        currently unused.
        """
        with self.state.use_tmp_path():
            return ASECalculator(model_dir=self.model_directory)


class ApaxEnsemble(zntrack.Node):
    """Parallel apax model ensemble in ASE.

    Attributes
    ----------
    models: list
        List of `ApaxModel` nodes to ensemble.
    nl_skin: float
        Neighborlist skin.
    transformations: dict
        Key-parameter dict with function transformations applied
        to the model function within the ASE calculator.
        See the apax documentation for available methods.
    """

    models: list[Apax] = zntrack.deps()
    nl_skin: float = zntrack.params(0.5)
    transformations: dict[str, dict] = zntrack.params(None)

    def run(self) -> None:
        # Nothing to compute here; the ensemble is assembled lazily in
        # get_calculator from the already-trained member models.
        pass

    # NOTE(review): this annotation needs ase.calculators to be importable as
    # an attribute of ase; the file only does `import ase.io` — verify this
    # resolves at class-definition time.
    def get_calculator(self, **kwargs) -> ase.calculators.calculator.Calculator:
        """Property to return a model specific ase calculator object.

        Note: **kwargs is accepted for interface compatibility but is
        currently unused.

        Returns
        -------
        calc:
            ase calculator object
        """

        # Model directories of all ensemble members, read from each member's
        # merged training parameters.
        param_files = [m._parameter["data"]["directory"] for m in self.models]

        transformations = []
        if self.transformations:
            # Map transformation names to their constructors and instantiate
            # each with its parameter dict.
            for transform, params in self.transformations.items():
                transformations.append(available_transformations[transform](**params))

        calc = ASECalculator(
            param_files,
            dr=self.nl_skin,
            transformations=transformations,
        )
        return calc
Loading

0 comments on commit 7b9e29b

Please sign in to comment.