diff --git a/apax/model/builder.py b/apax/model/builder.py index 7a19ed65..95c48a27 100644 --- a/apax/model/builder.py +++ b/apax/model/builder.py @@ -19,7 +19,6 @@ def __init__(self, model_config: ModelConfig, n_species: int = 119): self.n_species = n_species def build_basis_function(self): - basis_config = self.config["basis"] name = basis_config["name"] diff --git a/apax/nodes/model.py b/apax/nodes/model.py index 160cfc41..e32b77fb 100644 --- a/apax/nodes/model.py +++ b/apax/nodes/model.py @@ -3,6 +3,7 @@ import typing as t import ase.io +import numpy as np import pandas as pd import yaml import zntrack.utils @@ -21,7 +22,7 @@ class ApaxBase(zntrack.Node): class Apax(ApaxBase): - """Class for the implementation of the apax model + """Class for training Apax models Parameters ---------- @@ -32,19 +33,16 @@ class Apax(ApaxBase): validation_data: list[ase.Atoms] atoms object with the validation data set model: t.Optional[Apax] - model to be used as a base model - model_directory: pathlib.Path - model directory - train_data_file: pathlib.Path - output path to the training data - validation_data_file: pathlib.Path - output path to the validation data + model to be used as a base model for transfer learning + log_level: str + verbosity of logging during training """ data: list = zntrack.deps() config: str = zntrack.params_path() validation_data = zntrack.deps() model: t.Optional[t.Any] = zntrack.deps(None) + log_level: str = zntrack.meta.Text("info") model_directory: pathlib.Path = zntrack.outs_path(zntrack.nwd / "apax_model") @@ -84,20 +82,29 @@ def _handle_parameter_file(self): def train_model(self): """Train the model using `apax.train.run`""" - apax_run(self._parameter) + apax_run(self._parameter, log_level=self.log_level) - def get_metrics_from_plots(self): + def get_metrics(self): """In addition to the plots write a model metric""" metrics_df = pd.read_csv(self.model_directory / "log.csv") - self.metrics = metrics_df.iloc[-1].to_dict() + best_epoch = 
np.argmin(metrics_df["val_loss"]) + self.metrics = metrics_df.iloc[best_epoch].to_dict() def run(self): """Primary method to run which executes all steps of the model training""" - ase.io.write(self.train_data_file, self.data) - ase.io.write(self.validation_data_file, self.validation_data) + if not self.state.restarted: + ase.io.write(self.train_data_file.as_posix(), self.data) + ase.io.write(self.validation_data_file.as_posix(), self.validation_data) + + csv_path = self.model_directory / "log.csv" + if self.state.restarted and csv_path.is_file(): + metrics_df = pd.read_csv(self.model_directory / "log.csv") + + if metrics_df["epoch"].iloc[-1] >= self._parameter["n_epochs"] - 1: + return self.train_model() - self.get_metrics_from_plots() + self.get_metrics() def get_calculator(self, **kwargs): """Get an apax ase calculator""" diff --git a/apax/train/callbacks.py b/apax/train/callbacks.py index 1c99fc8f..bac4a428 100644 --- a/apax/train/callbacks.py +++ b/apax/train/callbacks.py @@ -13,11 +13,15 @@ log = logging.getLogger(__name__) +def format_str(k): + return f"{k:.5f}" + + class CSVLoggerApax(CSVLogger): def __init__(self, filename, separator=",", append=False): - super().__init__(filename, separator=",", append=False) + super().__init__(filename, separator=separator, append=append) - def on_test_batch_end(self, batch, logs=None): + def on_epoch_end(self, epoch, logs=None): logs = logs or {} def handle_value(k): @@ -25,9 +29,51 @@ def handle_value(k): if isinstance(k, str): return k elif isinstance(k, collections.abc.Iterable) and not is_zero_dim_ndarray: - return f"\"[{', '.join(map(str, k))}]\"" + return f"\"[{', '.join(map(format_str, k))}]\"" else: + return format_str(k) + + if self.keys is None: + self.keys = sorted(logs.keys()) + # When validation_freq > 1, `val_` keys are not in first epoch logs + # Add the `val_` keys so that it's part of the fieldnames of writer. 
+ val_keys_found = False + for key in self.keys: + if key.startswith("val_"): + val_keys_found = True + break + if not val_keys_found: + self.keys.extend(["val_" + k for k in self.keys]) + + if not self.writer: + + class CustomDialect(csv.excel): + delimiter = self.sep + + fieldnames = ["epoch"] + self.keys + + self.writer = csv.DictWriter( + self.csv_file, fieldnames=fieldnames, dialect=CustomDialect + ) + if self.append_header: + self.writer.writeheader() + + row_dict = collections.OrderedDict({"epoch": epoch}) + row_dict.update((key, handle_value(logs.get(key, "NA"))) for key in self.keys) + self.writer.writerow(row_dict) + self.csv_file.flush() + + def on_test_batch_end(self, batch, logs=None): + logs = logs or {} + + def handle_value(k): + is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0 + if isinstance(k, str): return k + elif isinstance(k, collections.abc.Iterable) and not is_zero_dim_ndarray: + return f"\"[{', '.join(map(format_str, k))}]\"" + else: + return format_str(k) if self.keys is None: self.keys = sorted(logs.keys())