Skip to content

Commit d99ec5a

Browse files
authored
Merge pull request #289 from apax-hub/node_fix
Node fix
2 parents ebc07ff + 6129612 commit d99ec5a

File tree

3 files changed

+70
-18
lines changed

3 files changed

+70
-18
lines changed

apax/model/builder.py

-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ def __init__(self, model_config: ModelConfig, n_species: int = 119):
1919
self.n_species = n_species
2020

2121
def build_basis_function(self):
22-
2322
basis_config = self.config["basis"]
2423
name = basis_config["name"]
2524

apax/nodes/model.py

+21-14
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import typing as t
44

55
import ase.io
6+
import numpy as np
67
import pandas as pd
78
import yaml
89
import zntrack.utils
@@ -21,7 +22,7 @@ class ApaxBase(zntrack.Node):
2122

2223

2324
class Apax(ApaxBase):
24-
"""Class for the implementation of the apax model
25+
"""Class for training Apax models
2526
2627
Parameters
2728
----------
@@ -32,19 +33,16 @@ class Apax(ApaxBase):
3233
validation_data: list[ase.Atoms]
3334
atoms object with the validation data set
3435
model: t.Optional[Apax]
35-
model to be used as a base model
36-
model_directory: pathlib.Path
37-
model directory
38-
train_data_file: pathlib.Path
39-
output path to the training data
40-
validation_data_file: pathlib.Path
41-
output path to the validation data
36+
model to be used as a base model for transfer learning
37+
log_level: str
38+
verbosity of logging during training
4239
"""
4340

4441
data: list = zntrack.deps()
4542
config: str = zntrack.params_path()
4643
validation_data = zntrack.deps()
4744
model: t.Optional[t.Any] = zntrack.deps(None)
45+
log_level: str = zntrack.meta.Text("info")
4846

4947
model_directory: pathlib.Path = zntrack.outs_path(zntrack.nwd / "apax_model")
5048

@@ -84,20 +82,29 @@ def _handle_parameter_file(self):
8482

8583
def train_model(self):
8684
"""Train the model using `apax.train.run`"""
87-
apax_run(self._parameter)
85+
apax_run(self._parameter, log_level=self.log_level)
8886

89-
def get_metrics_from_plots(self):
87+
def get_metrics(self):
9088
"""In addition to the plots write a model metric"""
9189
metrics_df = pd.read_csv(self.model_directory / "log.csv")
92-
self.metrics = metrics_df.iloc[-1].to_dict()
90+
best_epoch = np.argmin(metrics_df["val_loss"])
91+
self.metrics = metrics_df.iloc[best_epoch].to_dict()
9392

9493
def run(self):
9594
"""Primary method to run which executes all steps of the model training"""
96-
ase.io.write(self.train_data_file, self.data)
97-
ase.io.write(self.validation_data_file, self.validation_data)
95+
if not self.state.restarted:
96+
ase.io.write(self.train_data_file.as_posix(), self.data)
97+
ase.io.write(self.validation_data_file.as_posix(), self.validation_data)
98+
99+
csv_path = self.model_directory / "log.csv"
100+
if self.state.restarted and csv_path.is_file():
101+
metrics_df = pd.read_csv(self.model_directory / "log.csv")
102+
103+
if metrics_df["epoch"].iloc[-1] >= self._parameter["n_epochs"] - 1:
104+
return
98105

99106
self.train_model()
100-
self.get_metrics_from_plots()
107+
self.get_metrics()
101108

102109
def get_calculator(self, **kwargs):
103110
"""Get an apax ase calculator"""

apax/train/callbacks.py

+49-3
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,67 @@
1313
log = logging.getLogger(__name__)
1414

1515

16+
def format_str(k):
17+
return f"{k:.5f}"
18+
19+
1620
class CSVLoggerApax(CSVLogger):
1721
def __init__(self, filename, separator=",", append=False):
18-
super().__init__(filename, separator=",", append=False)
22+
super().__init__(filename, separator=separator, append=append)
1923

20-
def on_test_batch_end(self, batch, logs=None):
24+
def on_epoch_end(self, epoch, logs=None):
2125
logs = logs or {}
2226

2327
def handle_value(k):
2428
is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
2529
if isinstance(k, str):
2630
return k
2731
elif isinstance(k, collections.abc.Iterable) and not is_zero_dim_ndarray:
28-
return f"\"[{', '.join(map(str, k))}]\""
32+
return f"\"[{', '.join(map(format_str, k))}]\""
2933
else:
34+
return format_str(k)
35+
36+
if self.keys is None:
37+
self.keys = sorted(logs.keys())
38+
# When validation_freq > 1, `val_` keys are not in first epoch logs
39+
# Add the `val_` keys so that its part of the fieldnames of writer.
40+
val_keys_found = False
41+
for key in self.keys:
42+
if key.startswith("val_"):
43+
val_keys_found = True
44+
break
45+
if not val_keys_found:
46+
self.keys.extend(["val_" + k for k in self.keys])
47+
48+
if not self.writer:
49+
50+
class CustomDialect(csv.excel):
51+
delimiter = self.sep
52+
53+
fieldnames = ["epoch"] + self.keys
54+
55+
self.writer = csv.DictWriter(
56+
self.csv_file, fieldnames=fieldnames, dialect=CustomDialect
57+
)
58+
if self.append_header:
59+
self.writer.writeheader()
60+
61+
row_dict = collections.OrderedDict({"epoch": epoch})
62+
row_dict.update((key, handle_value(logs.get(key, "NA"))) for key in self.keys)
63+
self.writer.writerow(row_dict)
64+
self.csv_file.flush()
65+
66+
def on_test_batch_end(self, batch, logs=None):
67+
logs = logs or {}
68+
69+
def handle_value(k):
70+
is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
71+
if isinstance(k, str):
3072
return k
73+
elif isinstance(k, collections.abc.Iterable) and not is_zero_dim_ndarray:
74+
return f"\"[{', '.join(map(format_str, k))}]\""
75+
else:
76+
return format_str(k)
3177

3278
if self.keys is None:
3379
self.keys = sorted(logs.keys())

0 commit comments

Comments
 (0)