3
3
import typing as t
4
4
5
5
import ase .io
6
+ import numpy as np
6
7
import pandas as pd
7
8
import yaml
8
9
import zntrack .utils
@@ -21,7 +22,7 @@ class ApaxBase(zntrack.Node):
21
22
22
23
23
24
class Apax (ApaxBase ):
24
- """Class for the implementation of the apax model
25
+ """Class for training Apax models
25
26
26
27
Parameters
27
28
----------
@@ -32,19 +33,16 @@ class Apax(ApaxBase):
32
33
validation_data: list[ase.Atoms]
33
34
atoms object with the validation data set
34
35
model: t.Optional[Apax]
35
- model to be used as a base model
36
- model_directory: pathlib.Path
37
- model directory
38
- train_data_file: pathlib.Path
39
- output path to the training data
40
- validation_data_file: pathlib.Path
41
- output path to the validation data
36
+ model to be used as a base model for transfer learning
37
+ log_level: str
38
+ verbosity of logging during training
42
39
"""
43
40
44
41
data : list = zntrack .deps ()
45
42
config : str = zntrack .params_path ()
46
43
validation_data = zntrack .deps ()
47
44
model : t .Optional [t .Any ] = zntrack .deps (None )
45
+ log_level : str = zntrack .meta .Text ("info" )
48
46
49
47
model_directory : pathlib .Path = zntrack .outs_path (zntrack .nwd / "apax_model" )
50
48
@@ -84,20 +82,29 @@ def _handle_parameter_file(self):
84
82
85
83
def train_model (self ):
86
84
"""Train the model using `apax.train.run`"""
87
- apax_run (self ._parameter )
85
+ apax_run (self ._parameter , log_level = self . log_level )
88
86
89
- def get_metrics_from_plots (self ):
87
+ def get_metrics (self ):
90
88
"""In addition to the plots write a model metric"""
91
89
metrics_df = pd .read_csv (self .model_directory / "log.csv" )
92
- self .metrics = metrics_df .iloc [- 1 ].to_dict ()
90
+ best_epoch = np .argmin (metrics_df ["val_loss" ])
91
+ self .metrics = metrics_df .iloc [best_epoch ].to_dict ()
93
92
94
93
def run (self ):
95
94
"""Primary method to run which executes all steps of the model training"""
96
- ase .io .write (self .train_data_file , self .data )
97
- ase .io .write (self .validation_data_file , self .validation_data )
95
+ if not self .state .restarted :
96
+ ase .io .write (self .train_data_file .as_posix (), self .data )
97
+ ase .io .write (self .validation_data_file .as_posix (), self .validation_data )
98
+
99
+ csv_path = self .model_directory / "log.csv"
100
+ if self .state .restarted and csv_path .is_file ():
101
+ metrics_df = pd .read_csv (self .model_directory / "log.csv" )
102
+
103
+ if metrics_df ["epoch" ].iloc [- 1 ] >= self ._parameter ["n_epochs" ] - 1 :
104
+ return
98
105
99
106
self .train_model ()
100
- self .get_metrics_from_plots ()
107
+ self .get_metrics ()
101
108
102
109
def get_calculator (self , ** kwargs ):
103
110
"""Get an apax ase calculator"""
0 commit comments