Update Checkpoint Loader docs with an example (#6871)
Related to #6687, this adds documentation on how to resume a training run from a checkpoint. I did not find this documented anywhere in MONAI, and it is rather important in my opinion. Please briefly check the docstring syntax, as I don't have any experience with it. The code itself is from my own repository and works.

What is potentially still missing here is some information on when to use the CheckpointLoader. As far as I understood the docs, it is mostly meant to resume interrupted training runs. But what about 1) pure inference runs, where the state of the trainer does not matter, only that of the evaluator, and 2) resuming training at epoch 200 but with a learning-rate reset (e.g. DeepEdit: train without clicks for the first 200 epochs, then 200 more epochs with clicks on top)? In my experience, 1) works well, and so does 2) if you modify the `load_dict` to exclude e.g. the learning rate scheduler; both scenarios are sketched below.
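Not part of the commit itself, but a minimal sketch of the two scenarios, assuming the same placeholder names (`network`, `optimizer`, `trainer`, `evaluator`) and checkpoint path as the docstring example in the diff below:

```python
from monai.handlers import CheckpointLoader

# 1) Pure inference: only the network weights matter, so load just "net"
#    and run the handler on the evaluator instead of the trainer.
infer_loader = CheckpointLoader(load_path="/test/checkpoint.pt", load_dict={"net": network})
infer_loader(evaluator)

# 2) Resume at epoch 200 with a learning-rate reset: omit the "lr" entry
#    from load_dict so the scheduler starts fresh, while the trainer state
#    (epoch/iteration counters), network and optimizer are restored.
resume_loader = CheckpointLoader(
    load_path="/test/checkpoint.pt",
    load_dict={"trainer": trainer, "net": network, "opt": optimizer},
)
resume_loader(trainer)
```

Note that the restored optimizer state also carries the current learning-rate values in its `param_groups`, so depending on the scheduler you may want to omit `"opt"` as well when resetting the schedule.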

### Description

Adds a usage example to the `CheckpointLoader` docstring that shows how to restore the trainer, network, optimizer and LR scheduler from a checkpoint and resume an interrupted training run.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Matthias Hadlich <matthiashadlich@posteo.de>
matt3o authored Aug 15, 2023
1 parent a01a726 commit 5833b1c
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions monai/handlers/checkpoint_loader.py
@@ -36,6 +36,23 @@ class CheckpointLoader:
If saving checkpoint after `torch.nn.DataParallel`, need to save `model.module` instead
as PyTorch recommended and then use this loader to load the model.
Usage example::

    trainer = SupervisedTrainer(...)
    save_dict = {
        "trainer": trainer,
        "net": network,
        "opt": optimizer,
        "lr": lr_scheduler,
    }
    map_location = "cuda:0"
    # the checkpoint needs to have been created with the same save_dict for this to work
    handler = CheckpointLoader(load_path="/test/checkpoint.pt", load_dict=save_dict, map_location=map_location, strict=True)
    handler(trainer)
    # the trainer now has the same state as when the checkpoint was stored, including the
    # number of epochs and iterations completed, so you can resume an interrupted
    # training run at the point where it left off

Args:
load_path: the file path of checkpoint, it should be a PyTorch `pth` file.
load_dict: target objects that load checkpoint to. examples::
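Not part of this diff, but for context: the comment about needing the same `save_dict` refers to the keys used when the checkpoint was written. A minimal sketch of the matching save side, using MONAI's `CheckpointSaver` with illustrative `save_dir`, `save_interval` and `n_saved` values:

```python
from monai.handlers import CheckpointSaver

# Write checkpoints containing the same objects under the same keys,
# so that CheckpointLoader can restore them later.
saver = CheckpointSaver(
    save_dir="/test",
    save_dict=save_dict,  # same dict as passed to CheckpointLoader above
    save_interval=1,      # checkpoint at the end of every epoch
    n_saved=1,            # keep only the most recent checkpoint
)
saver.attach(trainer)
```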
