
ModelCheckpoint: save_top_k > 1 cannot recognize ordering of models from ckpt names #9944

@breznak

Description

🐛 Bug

Short: the ModelCheckpoint callback with save_top_k does not encode any semantic meaning (the ranking of the models) in the names of the saved files.

For example, with save_top_k=5 it saves the 5 best models into 5 files, but the user cannot tell which is which! (The ModelCheckpoint knows, but only within the training session, not if we want to come back to the checkpoints later.)

My motivation: I want to access the "2nd best" model.
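
For context, while the training session is still alive the ranking can be recovered from the callback's best_k_models dict; a minimal sketch, assuming a fitted trainer with the ModelCheckpoint attached:

# inside the same session, after trainer.fit(...)
ckpt_cb = trainer.checkpoint_callback  # the ModelCheckpoint instance

# best_k_models maps checkpoint path -> monitored value (val/loss here)
ranked = sorted(ckpt_cb.best_k_models.items(), key=lambda kv: float(kv[1]))  # mode=min -> ascending

best_path, best_score = ranked[0]                # same as ckpt_cb.best_model_path
second_best_path, second_best_score = ranked[1]  # the "2nd best" model
print(second_best_path, float(second_best_score))

Once the session is gone, this mapping is gone too, which is exactly the problem reported below.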

To Reproduce

Reproducible code: #9944 (comment)

Using this ModelCheckpoint config (Hydra-style) for the example:


model_checkpoint:
  _target_: pytorch_lightning.callbacks.ModelCheckpoint
  monitor: 'val/loss'       # name of the logged metric
  save_top_k: 5             # save k best models (-1 save all, 0 don't save)
  save_last: True           # always save model from last epoch
  verbose: True             # show more detailed info during training
  mode: min                 # can be "max" or "min"
  dirpath: 'xxx'
  filename: 'best'          # fixed filename stem; the epoch number or metrics could also be used
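
For completeness, the same callback configured directly in Python would look roughly like this (a sketch mirroring the Hydra config above; the values are unchanged):

from pytorch_lightning.callbacks import ModelCheckpoint

model_checkpoint = ModelCheckpoint(
    monitor="val/loss",   # name of the logged metric
    save_top_k=5,         # keep the 5 best checkpoints (-1 = all, 0 = none)
    save_last=True,       # also save the model from the last epoch
    verbose=True,
    mode="min",
    dirpath="xxx",
    filename="best",      # fixed filename stem
)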

The ModelCheckpoint just cyclically saves each new best model as best-v{current_epoch % k}.ckpt (with best.ckpt for index 0), as seen in the following training log:

Training: -1it [00:00, ?it/s]
EPOCH 0
Epoch 0: 100%|██████████| 2100/2100 [00:21<00:00, 96.04it/s]
        train error:            0.3088467121
        validation error:       0.1557712555
Epoch 0, global step 2055: val/loss reached 0.15577 (best 0.15577), saving model to "xxx/best.ckpt" as top 5
EPOCH 1
Epoch 1: 100%|██████████| 2100/2100 [00:20<00:00, 100.72it/s]
        train error:            0.3018136621
        validation error:       0.1548990458
Epoch 1, global step 4111: val/loss reached 0.15490 (best 0.15490), saving model to "xxx/best-v1.ckpt" as top 5
EPOCH 2
Epoch 2: 100%|██████████| 2100/2100 [00:21<00:00, 97.58it/s]
        train error:            0.2986578345
        validation error:       0.1544228494
Epoch 2, global step 6167: val/loss reached 0.15442 (best 0.15442), saving model to "xxx/best-v2.ckpt" as top 5
EPOCH 3
Epoch 3: 100%|██████████| 2100/2100 [00:21<00:00, 98.33it/s]
        train error:            0.2965526581
        validation error:       0.1539662182
Epoch 3, global step 8223: val/loss reached 0.15397 (best 0.15397), saving model to "xxx/best-v3.ckpt" as top 5
EPOCH 4
Epoch 4: 100%|██████████| 2100/2100 [00:21<00:00, 99.14it/s]
        train error:            0.2950037122
        validation error:       0.1536256075
Epoch 4, global step 10279: val/loss reached 0.15363 (best 0.15363), saving model to "xxx/best-v4.ckpt" as top 5
EPOCH 5
Epoch 5: 100%|██████████| 2100/2100 [00:21<00:00, 97.18it/s]
        train error:            0.2937672734
        validation error:       0.1534034163
Epoch 5, global step 12335: val/loss reached 0.15340 (best 0.15340), saving model to "xxx/best.ckpt" as top 5
EPOCH 6
Epoch 6: 100%|██████████| 2100/2100 [00:21<00:00, 98.24it/s]
        train error:            0.2926785052
        validation error:       0.1531589478
Epoch 6, global step 14391: val/loss reached 0.15316 (best 0.15316), saving model to "xxx/best-v1.ckpt" as top 5
EPOCH 7
Epoch 7: 100%|██████████| 2100/2100 [00:21<00:00, 96.33it/s]
        train error:            0.2916747034
        validation error:       0.1529426873
Epoch 7, global step 16447: val/loss reached 0.15294 (best 0.15294), saving model to "xxx/best-v2.ckpt" as top 5
EPOCH 8
Epoch 8: 100%|██████████| 2100/2100 [00:22<00:00, 92.43it/s]
        train error:            0.2907347977
        validation error:       0.1527983993
Epoch 8, global step 18503: val/loss reached 0.15280 (best 0.15280), saving model to "xxx/best-v3.ckpt" as top 5
EPOCH 9
Epoch 9: 100%|██████████| 2100/2100 [00:20<00:00, 101.61it/s]
        train error:            0.2898018062
        validation error:       0.1526378989
Epoch 9, global step 20559: val/loss reached 0.15264 (best 0.15264), saving model to "xxx/best-v4.ckpt" as top 5
Epoch 9: 100%|██████████| 2100/2100 [00:20<00:00, 101.42it/s, loss=0.289, v_num=gdir]
Saving latest checkpoint...
Setting period from checkpoint test_set

In this example, the real best model is:

Epoch 9, global step 20559: val/loss reached 0.15264 (best 0.15264), saving model to "xxx/best-v4.ckpt"

Now the catch is that in trainer.test(..., ckpt_path="best") (and likewise for validate() and predict()),
"best" does not mean the file "best.ckpt" but whatever filename the callback currently considers best, which only the callback knows, as seen from the following code used by the above methods:


    def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]:
        if ckpt_path is None:
            return

        fn = self.state.fn.value

        if ckpt_path == "best":
            # if user requests the best checkpoint but we don't have it, error
            if not self.checkpoint_callback.best_model_path:
                if self.fast_dev_run:
                    raise MisconfigurationException(
                        f"You cannot execute `.{fn}()` with `fast_dev_run=True` unless you do"
                        f" `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting."
                    )
                raise MisconfigurationException(
                    f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
                )
            # load best weights
            ckpt_path = self.checkpoint_callback.best_model_path

...
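
Concretely, the resolution looks like this (a hypothetical call, assuming the ModelCheckpoint above is attached to the trainer):

# "best" is resolved from the in-memory callback state, not from the filename
trainer.test(model, datamodule=datamodule, ckpt_path="best")
# internally: ckpt_path = trainer.checkpoint_callback.best_model_path
# which in the log above is "xxx/best-v4.ckpt", not "xxx/best.ckpt"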

Problems:

  • ckpt_path="best" may not work after the callback is finished (i.e. after the training session ends); the information about which saved checkpoint file is really the best is then lost (see the sketch after this list)
  • the filenames carry no semantics (no ordering)
  • best.ckpt is not always the best model, which is confusing
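
One possible workaround for the first point, sketched here (this is not part of the callback; the JSON path is an arbitrary choice), is to persist the ranking to disk at the end of fitting:

import json

trainer.fit(model, datamodule=datamodule)

# dump "checkpoint path -> monitored value" so the ranking survives the session
ckpt_cb = trainer.checkpoint_callback
ranking = {path: float(score) for path, score in ckpt_cb.best_k_models.items()}
with open("xxx/best_k_models.json", "w") as f:
    json.dump(ranking, f, indent=2)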

Expected behavior

The ModelCheckpoint callback would keep an ordering of the top_k best models and re-save (or rename) them accordingly on each change.
So best.ckpt would always be the best model, best-v4.ckpt the 5th best. The filenames (and their semantics) would stay fixed during training; only the contents (the models) of the files would update.

Benefits:

  • best.ckpt is the best model, so users can load it manually for use-cases other than test(ckpt_path="best")
  • we can access the Nth best model (see the sketch below)
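
Until such ordering is built in, a rough sketch of how the proposed naming could be emulated after training (assuming the same ModelCheckpoint instance; best-rank{i}.ckpt is an arbitrary naming choice, not a Lightning convention):

import shutil

ckpt_cb = trainer.checkpoint_callback
# sort the saved checkpoints by the monitored value (mode=min -> ascending)
ranked = sorted(ckpt_cb.best_k_models.items(), key=lambda kv: float(kv[1]))

# copy to rank-encoded filenames: best-rank0.ckpt is the best, best-rank1.ckpt the 2nd best, ...
for rank, (path, score) in enumerate(ranked):
    shutil.copyfile(path, f"xxx/best-rank{rank}.ckpt")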

Environment

PyTorch Lightning 1.4.9 (latest at the time of writing)


Labels: checkpointing (Related to checkpointing), feature (Is an improvement or enhancement), help wanted (Open to be worked on)
