Description
🐛 Bug
Short: with `save_top_k`, the `ModelCheckpoint` callback does not encode the models' ranking in the checkpoint filenames.
Say for `save_top_k=5` it saves the 5 best models into 5 files, but the user cannot tell which is which! (The `ModelCheckpoint` instance knows, but only within the training session, not if we want to resume later.)
My motivation: I want to access the "2nd best" model.
To Reproduce
Reproducible code: #9944 (comment)
Using this ModelCheckpoint config for the example:

```yaml
model_checkpoint:
  _target_: pytorch_lightning.callbacks.ModelCheckpoint
  monitor: 'val/loss'  # name of the logged metric
  save_top_k: 5        # save k best models (-1 saves all, 0 saves none)
  save_last: True      # always save the model from the last epoch
  verbose: True        # show more detailed info during training
  mode: min            # can be "max" or "min"
  dirpath: 'xxx'
  filename: 'best'     # the current epoch number and metrics may also be used for naming
```
The `ModelCheckpoint` just cyclically saves the new best model as `best-v{current_epoch % k}.ckpt` (with slot 0 being plain `best.ckpt`), as seen from the following training log:
```
Training: -1it [00:00, ?it/s]EPOCH 0
Epoch 0: 100%|██████████| 2100/2100 [00:21<00:00, 96.04it/s, lostrain error: 0.3088467121
validation error: 0.1557712555
Epoch 0, global step 2055: val/loss reached 0.15577 (best 0.15577), saving model to "xxx/best.ckpt" as top 5
EPOCH 1
Epoch 1: 100%|██████████| 2100/2100 [00:20<00:00, 100.72it/s, lotrain error: 0.3018136621
validation error: 0.1548990458
Epoch 1, global step 4111: val/loss reached 0.15490 (best 0.15490), saving model to "xxx/best-v1.ckpt" as top 5
EPOCH 2
Epoch 2: 100%|██████████| 2100/2100 [00:21<00:00, 97.58it/s, lostrain error: 0.2986578345
validation error: 0.1544228494
Epoch 2, global step 6167: val/loss reached 0.15442 (best 0.15442), saving model to "xxx/best-v2.ckpt" as top 5
EPOCH 3
Epoch 3: 100%|██████████| 2100/2100 [00:21<00:00, 98.33it/s, lostrain error: 0.2965526581
validation error: 0.1539662182
Epoch 3, global step 8223: val/loss reached 0.15397 (best 0.15397), saving model to "xxx/best-v3.ckpt" as top 5
EPOCH 4
Epoch 4: 100%|██████████| 2100/2100 [00:21<00:00, 99.14it/s, lostrain error: 0.2950037122
validation error: 0.1536256075
Epoch 4, global step 10279: val/loss reached 0.15363 (best 0.15363), saving model to "xxx/best-v4.ckpt" as top 5
EPOCH 5
Epoch 5: 100%|██████████| 2100/2100 [00:21<00:00, 97.18it/s, lostrain error: 0.2937672734
validation error: 0.1534034163
Epoch 5, global step 12335: val/loss reached 0.15340 (best 0.15340), saving model to "xxx/best.ckpt" as top 5
EPOCH 6
Epoch 6: 100%|██████████| 2100/2100 [00:21<00:00, 98.24it/s, lostrain error: 0.2926785052
validation error: 0.1531589478
Epoch 6, global step 14391: val/loss reached 0.15316 (best 0.15316), saving model to "xxx/best-v1.ckpt" as top 5
EPOCH 7
Epoch 7: 100%|██████████| 2100/2100 [00:21<00:00, 96.33it/s, lostrain error: 0.2916747034
validation error: 0.1529426873
Epoch 7, global step 16447: val/loss reached 0.15294 (best 0.15294), saving model to "xxx/best-v2.ckpt" as top 5
EPOCH 8
Epoch 8: 100%|██████████| 2100/2100 [00:22<00:00, 92.43it/s, lostrain error: 0.2907347977
validation error: 0.1527983993
Epoch 8, global step 18503: val/loss reached 0.15280 (best 0.15280), saving model to "xxx/best-v3.ckpt" as top 5
EPOCH 9
Epoch 9: 100%|██████████| 2100/2100 [00:20<00:00, 101.61it/s, lotrain error: ] 0.2898018062
validation error: 0.1526378989
Epoch 9, global step 20559: val/loss reached 0.15264 (best 0.15264), saving model to "xxx/best-v4.ckpt" as top 5
Epoch 9: 100%|██████████| 2100/2100 [00:20<00:00, 101.42it/s, loss=0.289, v_num=gdir]Saving latest checkpoint...
Setting period from checkpoint test_set
```
In this example, the real best model is:
Epoch 9, global step 20559: val/loss reached 0.15264 (best 0.15264), saving model to "xxx/best-v4.ckpt"
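The rotation seen in the log can be sketched as follows; this is a minimal illustration of the observed behavior, not Lightning's actual code, and `observed_checkpoint_name` is a hypothetical helper:

```python
# Minimal illustration of the observed rotation (not Lightning's real code):
# with filename="best" and save_top_k=k, successive "top" checkpoints land in
# best.ckpt, best-v1.ckpt, ..., best-v{k-1}.ckpt, then the names wrap around.
def observed_checkpoint_name(n_saves: int, k: int, stem: str = "best") -> str:
    """Filename that the n-th saved checkpoint (0-indexed) is written to."""
    slot = n_saves % k
    return f"{stem}.ckpt" if slot == 0 else f"{stem}-v{slot}.ckpt"

# Epochs 0-9 with save_top_k=5 reproduce the filenames in the log above:
names = [observed_checkpoint_name(epoch, k=5) for epoch in range(10)]
```

Note how after epoch 4 the names wrap, so epoch 9 (the true best here) overwrites `best-v4.ckpt` while `best.ckpt` still holds the epoch-5 model.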
Now the trick is that in `trainer.test(..., ckpt_path="best")` (same for `validate()` and `predict()`), "best" does not mean `best.ckpt` but the best filename that only the callback instance knows, as seen from the following code, which is used by the above methods:
```python
def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]:
    if ckpt_path is None:
        return
    fn = self.state.fn.value
    if ckpt_path == "best":
        # if user requests the best checkpoint but we don't have it, error
        if not self.checkpoint_callback.best_model_path:
            if self.fast_dev_run:
                raise MisconfigurationException(
                    f"You cannot execute `.{fn}()` with `fast_dev_run=True` unless you do"
                    f" `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting."
                )
            raise MisconfigurationException(
                f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
            )
        # load best weights
        ckpt_path = self.checkpoint_callback.best_model_path
    ...
```
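Within a live session, the full ordering (not just the single best path) is recoverable from the callback's `best_k_models` attribute, a dict mapping filepath to monitored score; during training the values are tensors, but plain floats sort the same way. A minimal sketch using the scores from the log above:

```python
# Sketch: recover the ranking from ModelCheckpoint.best_k_models, the
# (filepath -> monitored score) dict that only the live callback holds.
def rank_checkpoints(best_k_models: dict, mode: str = "min") -> list:
    """Return checkpoint paths ordered best-first."""
    return sorted(best_k_models, key=best_k_models.get, reverse=(mode == "max"))

# Scores taken from the training log above:
scores = {
    "xxx/best.ckpt": 0.15340,     # epoch 5
    "xxx/best-v1.ckpt": 0.15316,  # epoch 6
    "xxx/best-v2.ckpt": 0.15294,  # epoch 7
    "xxx/best-v3.ckpt": 0.15280,  # epoch 8
    "xxx/best-v4.ckpt": 0.15264,  # epoch 9
}
ranked = rank_checkpoints(scores)
second_best = ranked[1]  # the "2nd best" model this issue asks for
```

But this only works while the callback instance is alive, which is exactly the limitation described below.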
Problems:
- This might not work after the callback is finished (end of the training session). Then the information about which saved checkpoint file is really the best is lost.
- The filenames carry no semantics (no ordering).
- best.ckpt is not always the best model, which is confusing.
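As a workaround for the first problem, the path-to-score mapping can be persisted manually after `trainer.fit()`; `save_checkpoint_ranking` is a hypothetical helper sketched here, not part of Lightning:

```python
import json

def save_checkpoint_ranking(checkpoint_callback, out_path: str) -> None:
    """Persist ModelCheckpoint's (filepath -> score) mapping so it survives
    the training session. `best_k_models` values are tensors during training;
    float() makes them JSON-serializable. Call after trainer.fit()."""
    ranking = {
        path: float(score)
        for path, score in checkpoint_callback.best_k_models.items()
    }
    with open(out_path, "w") as f:
        json.dump(ranking, f, indent=2)
```

Reloading that JSON in a later session is then enough to know which saved file is the Nth best.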
Expected behavior
The ModelCheckpoint callback would keep an ordering of the top-k best models and re-save them accordingly on each change. So `best.ckpt` = always the best model, `best-v4.ckpt` = the 5th best model. The filenames (and their semantics) would remain the same during training; only the contents (the models) of the files would update.
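The requested naming could be sketched as a pure function over the tracked scores (an illustration of the proposal, not an implementation in Lightning; `rank_stable_names` is a hypothetical name):

```python
# Sketch of the proposed rank-stable naming: map each saved checkpoint to a
# target name derived from its rank, so best.ckpt is always the true best.
def rank_stable_names(best_k_models: dict, stem: str = "best",
                      mode: str = "min") -> dict:
    """Map each checkpoint path to its rank-based filename:
    best.ckpt for the best, best-v1.ckpt for the 2nd best, and so on."""
    ordered = sorted(best_k_models, key=best_k_models.get,
                     reverse=(mode == "max"))
    return {
        path: f"{stem}.ckpt" if rank == 0 else f"{stem}-v{rank}.ckpt"
        for rank, path in enumerate(ordered)
    }
```

On each change of the top-k set, the callback would re-save (or rename) files according to this mapping.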
Benefits:
- `best.ckpt` is the best model, so users can manually load it for use-cases other than `test(ckpt_path="best")`
- we can access the Nth best model
Environment
PyTorch Lightning 1.4.9 (latest at the time of writing)