-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
model_checkpoint.py
255 lines (209 loc) · 9.69 KB
/
model_checkpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
"""
Model Checkpointing
===================
Automatically save model checkpoints during training.
"""
import os
import re
import numpy as np
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
import torch
class ModelCheckpoint(Callback):
    r"""
    Save the model after every epoch.

    Args:
        filepath: path to save the model file.
            Can contain named formatting options to be auto-filled.

            Example::

                # custom path
                # saves a file like: my/path/epoch=0.ckpt when my/path/ already
                # exists as a directory, otherwise my/path/_ckpt_epoch_0.ckpt
                >>> checkpoint_callback = ModelCheckpoint('my/path/')

                # save any arbitrary metrics like `val_loss`, etc. in name
                # saves a file like: my/path/epoch=2-val_loss=0.20-other_metric=0.30.ckpt
                >>> checkpoint_callback = ModelCheckpoint(
                ...     filepath='my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}'
                ... )

        monitor: quantity to monitor.
        verbose: verbosity mode. Default: ``False``.
        save_top_k: if `save_top_k == k`,
            the best k models according to
            the quantity monitored will be saved.
            if ``save_top_k == 0``, no models are saved.
            if ``save_top_k == -1``, all models are saved.
            Please note that the monitors are checked every `period` epochs.
            if ``save_top_k >= 2`` and the callback is called multiple
            times inside an epoch, the name of the saved file will be
            appended with a version count starting with `v0`.
        mode: one of {auto, min, max}.
            If ``save_top_k != 0``, the decision
            to overwrite the current save file is made
            based on either the maximization or the
            minimization of the monitored quantity. For `val_acc`,
            this should be `max`, for `val_loss` this should
            be `min`, etc. In `auto` mode, the direction is
            automatically inferred from the name of the monitored quantity
            (`max` when the name contains ``acc`` or starts with
            ``fmeasure``, `min` otherwise). An unknown mode triggers a
            warning and falls back to `auto`.
        save_weights_only: if ``True``, only the model's weights are meant to
            be saved instead of the full checkpoint. This callback merely
            stores the flag; it is presumably honoured by whoever assigns
            ``save_function`` (not visible in this class).
        period: Interval (number of epochs) between checkpoints.
        prefix: string prepended to every generated checkpoint file name.

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import ModelCheckpoint

        # saves checkpoints to 'my/path/' whenever 'val_loss' has a new min
        >>> checkpoint_callback = ModelCheckpoint(filepath='my/path/')
        >>> trainer = Trainer(checkpoint_callback=checkpoint_callback)

        # save epoch and val_loss in name
        # saves a file like: my/path/sample-mnist_epoch=02-val_loss=0.32.ckpt
        >>> checkpoint_callback = ModelCheckpoint(
        ...     filepath='my/path/sample-mnist_{epoch:02d}-{val_loss:.2f}'
        ... )

    """

    def __init__(self, filepath: str, monitor: str = 'val_loss', verbose: bool = False,
                 save_top_k: int = 1, save_weights_only: bool = False,
                 mode: str = 'auto', period: int = 1, prefix: str = ''):
        super().__init__()
        if save_top_k > 0 and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
            rank_zero_warn(
                f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
                "All files in this directory will be deleted when a checkpoint is saved!"
            )

        self.monitor = monitor
        self.verbose = verbose
        if os.path.isdir(filepath):
            # An existing directory gets the default '{epoch}' file template.
            self.dirpath, self.filename = filepath, '{epoch}'
        else:
            self.dirpath, self.filename = os.path.split(filepath)

        # os.makedirs('') raises, so only create the directory when the
        # given path actually has a directory component.
        if self.dirpath:
            os.makedirs(self.dirpath, exist_ok=True)
        self.save_top_k = save_top_k
        self.save_weights_only = save_weights_only
        self.period = period
        # Epoch at which the monitor was last checked (None until first check).
        self.epoch_last_check = None
        self.prefix = prefix
        # {filepath: monitored value} for the currently kept checkpoints.
        self.best_k_models = {}
        # Path of the worst checkpoint among the kept ones (next to be evicted).
        self.kth_best_model = ''
        self.best = 0
        # Callable taking a file path; must be assigned from outside before saving.
        self.save_function = None

        # np.inf instead of the np.Inf alias, which NumPy 2.0 removed.
        torch_inf = torch.tensor(np.inf)
        mode_dict = {
            'min': (torch.lt, torch_inf, 'min'),
            'max': (torch.gt, -torch_inf, 'max'),
            'auto': (torch.gt, -torch_inf, 'max') if 'acc' in self.monitor or self.monitor.startswith('fmeasure')
            else (torch.lt, torch_inf, 'min'),
        }

        if mode not in mode_dict:
            rank_zero_warn(f'ModelCheckpoint mode {mode} is unknown, '
                           f'fallback to auto mode.', RuntimeWarning)
            mode = 'auto'

        self.monitor_op, self.kth_value, self.mode = mode_dict[mode]

    def _del_model(self, filepath):
        """Delete the checkpoint file at ``filepath`` if it still exists."""
        # The file may already have been removed externally; an evicted
        # checkpoint that is gone needs no further action.
        if os.path.isfile(filepath):
            os.remove(filepath)

    def _save_model(self, filepath):
        """Create the target directory and delegate saving to ``save_function``.

        Raises:
            ValueError: if ``save_function`` has not been set.
        """
        # make paths (skipped for bare file names, where dirname is '')
        target_dir = os.path.dirname(filepath)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        # delegate the saving to the model
        if self.save_function is not None:
            self.save_function(filepath)
        else:
            raise ValueError(".save_function() not set")

    def check_monitor_top_k(self, current):
        """Tell whether ``current`` deserves a slot among the kept checkpoints.

        Returns ``True`` while fewer than ``save_top_k`` checkpoints are
        tracked; otherwise the result of comparing ``current`` with the value
        of the worst tracked checkpoint via ``monitor_op`` (``torch.lt`` or
        ``torch.gt``).
        """
        less_than_k_models = len(self.best_k_models) < self.save_top_k
        if less_than_k_models:
            return True

        if not isinstance(current, torch.Tensor):
            current = torch.tensor(current)

        return self.monitor_op(current, self.best_k_models[self.kth_best_model])

    def format_checkpoint_name(self, epoch, metrics, ver=None):
        """Generate a filename according to the defined template.

        Every ``{name}`` / ``{name:spec}`` field of the template is rendered
        as ``name=<value>``; names absent from ``metrics`` are rendered with
        the value ``0``. The ``metrics`` mapping passed in is not modified.

        Args:
            epoch: epoch number, available to the template as ``epoch``.
            metrics: mapping of metric names to values used to fill the template.
            ver: optional version number, appended as ``_v<ver>``.

        Returns:
            The checkpoint path inside ``dirpath``, ending in ``.ckpt``.

        Example::

            >>> tmpdir = os.path.dirname(__file__)
            >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}'))
            >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
            'epoch=0.ckpt'
            >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch:03d}'))
            >>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
            'epoch=005.ckpt'
            >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}-{val_loss:.2f}'))
            >>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
            'epoch=2-val_loss=0.12.ckpt'
            >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{missing:d}'))
            >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
            'missing=0.ckpt'
        """
        # check if user passed in keys to the string
        groups = re.findall(r'(\{.*?)[:\}]', self.filename)

        if len(groups) == 0:
            # default name
            # NOTE(review): prefix is also prepended below, so it appears twice
            # in the default name; kept as-is because callers may rely on it.
            filename = f'{self.prefix}_ckpt_epoch_{epoch}'
        else:
            # Work on a copy: the caller's dict (e.g. the trainer's callback
            # metrics) must not gain an 'epoch' entry or zero-filled
            # placeholders, which would hide a genuinely missing monitor.
            fill_values = dict(metrics)
            fill_values['epoch'] = epoch
            filename = self.filename
            for tmp in groups:
                name = tmp[1:]
                filename = filename.replace(tmp, name + '={' + name)
                if name not in fill_values:
                    fill_values[name] = 0
            filename = filename.format(**fill_values)
        str_ver = f'_v{ver}' if ver is not None else ''
        filepath = os.path.join(self.dirpath, self.prefix + filename + str_ver + '.ckpt')
        return filepath

    def on_validation_end(self, trainer, pl_module):
        """Decide after validation whether a checkpoint has to be written.

        Reads ``proc_rank``, ``callback_metrics`` and ``current_epoch`` from
        ``trainer``. Nothing happens on non-zero ranks, when
        ``save_top_k == 0``, or when fewer than ``period`` epochs have passed
        since the last check.
        """
        # only run on main process
        if trainer.proc_rank != 0:
            return

        metrics = trainer.callback_metrics
        epoch = trainer.current_epoch
        if self.save_top_k == 0:
            # no models are saved
            return
        if self.epoch_last_check is not None and (epoch - self.epoch_last_check) < self.period:
            # skipping in this term
            return

        self.epoch_last_check = epoch

        filepath = self.format_checkpoint_name(epoch, metrics)
        version_cnt = 0
        # If a file with this name already exists, try _v0, _v1, ... until free.
        while os.path.isfile(filepath):
            filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
            # this epoch called before
            version_cnt += 1

        if self.save_top_k != -1:
            current = metrics.get(self.monitor)

            if current is None:
                rank_zero_warn(
                    f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning
                )
            elif self.check_monitor_top_k(current):
                self._do_check_save(filepath, current, epoch)
            elif self.verbose > 0:
                log.info(f'\nEpoch {epoch:05d}: {self.monitor} was not in top {self.save_top_k}')

        else:
            # save_top_k == -1: every checkpoint is kept
            if self.verbose > 0:
                log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
            self._save_model(filepath)

    def _do_check_save(self, filepath, current, epoch):
        """Track ``filepath`` with value ``current``, save it, evict the worst.

        When ``save_top_k`` checkpoints are already tracked, the worst one is
        dropped from the bookkeeping and its file is deleted after the new
        checkpoint has been written.
        """
        # remove kth
        del_list = []
        if len(self.best_k_models) == self.save_top_k and self.save_top_k > 0:
            delpath = self.kth_best_model
            self.best_k_models.pop(self.kth_best_model)
            del_list.append(delpath)

        self.best_k_models[filepath] = current
        if len(self.best_k_models) == self.save_top_k:
            # monitor dict has reached k elements
            # the "kth best" is the worst kept one: largest value in 'min' mode
            _op = max if self.mode == 'min' else min
            self.kth_best_model = _op(self.best_k_models,
                                      key=self.best_k_models.get)
            self.kth_value = self.best_k_models[self.kth_best_model]

        _op = min if self.mode == 'min' else max
        self.best = _op(self.best_k_models.values())

        if self.verbose > 0:
            log.info(
                f'\nEpoch {epoch:05d}: {self.monitor} reached'
                f' {current:0.5f} (best {self.best:0.5f}), saving model to'
                f' {filepath} as top {self.save_top_k}')
        self._save_model(filepath)

        for cur_path in del_list:
            # never delete the file that was just written
            if cur_path != filepath:
                self._del_model(cur_path)