Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Node fix #289

Merged
merged 10 commits into from
Jun 27, 2024
1 change: 0 additions & 1 deletion apax/model/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def __init__(self, model_config: ModelConfig, n_species: int = 119):
self.n_species = n_species

def build_basis_function(self):

basis_config = self.config["basis"]
name = basis_config["name"]

Expand Down
35 changes: 21 additions & 14 deletions apax/nodes/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import typing as t

import ase.io
import numpy as np
import pandas as pd
import yaml
import zntrack.utils
Expand All @@ -21,7 +22,7 @@ class ApaxBase(zntrack.Node):


class Apax(ApaxBase):
"""Class for the implementation of the apax model
"""Class for training Apax models

Parameters
----------
Expand All @@ -32,19 +33,16 @@ class Apax(ApaxBase):
validation_data: list[ase.Atoms]
atoms object with the validation data set
model: t.Optional[Apax]
model to be used as a base model
model_directory: pathlib.Path
model directory
train_data_file: pathlib.Path
output path to the training data
validation_data_file: pathlib.Path
output path to the validation data
model to be used as a base model for transfer learning
log_level: str
verbosity of logging during training
"""

data: list = zntrack.deps()
config: str = zntrack.params_path()
validation_data = zntrack.deps()
model: t.Optional[t.Any] = zntrack.deps(None)
log_level: str = zntrack.meta.Text("info")

model_directory: pathlib.Path = zntrack.outs_path(zntrack.nwd / "apax_model")

Expand Down Expand Up @@ -84,20 +82,29 @@ def _handle_parameter_file(self):

def train_model(self):
    """Train the model by invoking `apax.train.run` on this node's parameters.

    Forwards ``self.log_level`` so the training run's logging verbosity
    matches the node configuration.
    """
    # The diff artifact left both the old and new call; only the
    # log_level-forwarding form is kept.
    apax_run(self._parameter, log_level=self.log_level)

def get_metrics(self):
    """Read the training log and store the best epoch's metrics.

    Loads ``log.csv`` from the model directory, picks the row with the
    lowest validation loss (rather than simply the last row), and stores
    it as a plain dict in ``self.metrics``.
    """
    metrics_df = pd.read_csv(self.model_directory / "log.csv")
    # Report the best epoch (minimum val_loss), not merely the final one.
    best_epoch = np.argmin(metrics_df["val_loss"])
    self.metrics = metrics_df.iloc[best_epoch].to_dict()

def run(self):
    """Primary method which executes all steps of the model training.

    On a fresh run, writes the training and validation structures to
    their output files. On a restarted run, skips both the data writes
    and the training itself when the existing ``log.csv`` shows the
    configured number of epochs has already been reached.
    """
    if not self.state.restarted:
        # as_posix() because ase.io.write is given a plain string path here.
        ase.io.write(self.train_data_file.as_posix(), self.data)
        ase.io.write(self.validation_data_file.as_posix(), self.validation_data)

    csv_path = self.model_directory / "log.csv"
    if self.state.restarted and csv_path.is_file():
        # Reuse csv_path instead of rebuilding the same path a second time.
        metrics_df = pd.read_csv(csv_path)
        # epochs are 0-indexed in the log; nothing left to do if the
        # previous run already completed all configured epochs.
        if metrics_df["epoch"].iloc[-1] >= self._parameter["n_epochs"] - 1:
            return

    self.train_model()
    self.get_metrics()

def get_calculator(self, **kwargs):
"""Get an apax ase calculator"""
Expand Down
52 changes: 49 additions & 3 deletions apax/train/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,67 @@
log = logging.getLogger(__name__)


def format_str(k):
    """Render a numeric value with five decimal places for the CSV log."""
    return "{:.5f}".format(k)


class CSVLoggerApax(CSVLogger):
def __init__(self, filename, separator=",", append=False):
    """CSV logger with customizable separator and append behavior.

    Forwards ``separator`` and ``append`` to the base CSVLogger instead
    of hard-coding them, so caller-supplied values actually take effect.
    """
    super().__init__(filename, separator=separator, append=append)

def on_test_batch_end(self, batch, logs=None):
def on_epoch_end(self, epoch, logs=None):
    """Write one CSV row of formatted metric values at the end of an epoch.

    Parameters
    ----------
    epoch : int
        Index of the finished epoch; written to the ``epoch`` column.
    logs : dict, optional
        Mapping of metric name to value. Keys missing from ``logs`` are
        written as the string "NA".
    """
    logs = logs or {}

    def handle_value(k):
        # A 0-d ndarray is iterable-typed but must be treated as a scalar.
        is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
        if isinstance(k, str):
            return k
        elif isinstance(k, collections.abc.Iterable) and not is_zero_dim_ndarray:
            # Iterable metrics are written as one quoted, bracketed cell,
            # each element formatted to five decimals via format_str.
            return f"\"[{', '.join(map(format_str, k))}]\""
        else:
            return format_str(k)

    if self.keys is None:
        self.keys = sorted(logs.keys())
        # When validation_freq > 1, the `val_` keys are absent from the
        # first epoch's logs. Add them up front so they are part of the
        # writer's fieldnames from the start.
        val_keys_found = any(key.startswith("val_") for key in self.keys)
        if not val_keys_found:
            self.keys.extend(["val_" + k for k in self.keys])

    if not self.writer:

        class CustomDialect(csv.excel):
            delimiter = self.sep

        fieldnames = ["epoch"] + self.keys

        self.writer = csv.DictWriter(
            self.csv_file, fieldnames=fieldnames, dialect=CustomDialect
        )
        if self.append_header:
            self.writer.writeheader()

    row_dict = collections.OrderedDict({"epoch": epoch})
    row_dict.update((key, handle_value(logs.get(key, "NA"))) for key in self.keys)
    self.writer.writerow(row_dict)
    self.csv_file.flush()

def on_test_batch_end(self, batch, logs=None):
logs = logs or {}

def handle_value(k):
is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
if isinstance(k, str):
return k
elif isinstance(k, collections.abc.Iterable) and not is_zero_dim_ndarray:
return f"\"[{', '.join(map(format_str, k))}]\""
else:
return format_str(k)

if self.keys is None:
self.keys = sorted(logs.keys())
Expand Down
Loading