From 5833b1cb64b483ae7817cddffe9ac155ad71aa4a Mon Sep 17 00:00:00 2001
From: Matthias Hadlich
Date: Tue, 15 Aug 2023 16:25:53 +0200
Subject: [PATCH] Update Checkpoint Loader docs with an example (#6871)

Related to #6687, this adds some documentation on how to resume training of a network. I did
not find this documented anywhere in MONAI, and it is rather important imo. Please briefly
check the doc syntax; I don't have any experience with it. The code itself is from my own
repository and is working.

What is potentially still missing here is some information on when to use the CheckpointLoader.
As far as I understood the docs, it is mostly meant to resume interrupted training runs. But what
about 1) pure inference runs, where the state of the trainer does not matter, only the evaluator,
and 2) resuming the training at epoch 200 but with a learning rate reset (e.g. DeepEdit: train
without clicks first for 200 epochs, then 200 epochs with clicks on top)? In my experience 1)
works well, and 2) does too if you modify the state_dict to exclude e.g. the learning rate
scheduler (a short sketch of both cases is appended after the patch).

### Description

A few sentences describing the changes proposed in this pull request.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Matthias Hadlich
---
 monai/handlers/checkpoint_loader.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py
index 5b05e7055c..9a867534a3 100644
--- a/monai/handlers/checkpoint_loader.py
+++ b/monai/handlers/checkpoint_loader.py
@@ -36,6 +36,23 @@ class CheckpointLoader:
     If saving checkpoint after `torch.nn.DataParallel`, need to save `model.module` instead
     as PyTorch recommended and then use this loader to load the model.
 
+    Usage example::
+
+        trainer = SupervisedTrainer(...)
+        save_dict = {
+            "trainer": trainer,
+            "net": network,
+            "opt": optimizer,
+            "lr": lr_scheduler,
+        }
+
+        map_location = "cuda:0"
+        # the checkpoint needs to have been saved with the same save_dict for this to work
+        handler = CheckpointLoader(load_path="/test/checkpoint.pt", load_dict=save_dict, map_location=map_location, strict=True)
+        handler(trainer)
+        # the trainer now has the same state as stored, including the number of epochs and iterations completed,
+        # so you can resume an interrupted training run at the point where it left off
+
     Args:
         load_path: the file path of checkpoint, it should be a PyTorch `pth` file.
         load_dict: target objects that load checkpoint to. examples::
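
For the two cases discussed in the description above, a minimal sketch could look like the
following. It reuses the names from the docstring example (`network`, `optimizer`, `trainer`)
plus an assumed `evaluator` engine and a placeholder checkpoint path; whether a partial
`load_dict` needs further tweaks (e.g. `strict=False`) may depend on the MONAI version, so
treat this as an illustration rather than a tested recipe.

    from monai.handlers import CheckpointLoader

    # Case 1) pure inference: restore only the network weights before running an evaluator.
    infer_loader = CheckpointLoader(
        load_path="/test/checkpoint.pt", load_dict={"net": network}, map_location="cuda:0"
    )
    infer_loader(evaluator)

    # Case 2) resume training but reset the learning rate schedule: leave "lr" out of the
    # dict passed to CheckpointLoader so the stored scheduler state is never restored.
    resume_dict = {
        "trainer": trainer,
        "net": network,
        "opt": optimizer,
        # "lr": lr_scheduler,  # intentionally omitted so the schedule starts fresh
    }
    resume_loader = CheckpointLoader(
        load_path="/test/checkpoint.pt", load_dict=resume_dict, map_location="cuda:0"
    )
    resume_loader(trainer)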