
Commit

1578 Add support to resume previous best metrics in CheckpointSaver (#1608)

* [DLMED] add support to resume metrics in CheckpointSaver

Signed-off-by: Nic Ma <nma@nvidia.com>

* [MONAI] python code formatting

Signed-off-by: monai-bot <monai.miccai2019@gmail.com>

* [DLMED] fix doc-string

Signed-off-by: Nic Ma <nma@nvidia.com>

Co-authored-by: monai-bot <monai.miccai2019@gmail.com>
Nic-Ma and monai-bot authored Feb 23, 2021
1 parent 75b78cd commit cb5f899
Showing 2 changed files with 82 additions and 5 deletions.
33 changes: 33 additions & 0 deletions monai/handlers/checkpoint_saver.py
@@ -10,6 +10,7 @@
# limitations under the License.

import logging
import warnings
from typing import TYPE_CHECKING, Dict, Optional

from monai.utils import exact_version, optional_import
@@ -55,6 +56,10 @@ class CheckpointSaver:
metric in descending order.
key_metric_filename: set a fixed filename to set the best metric model, if not None,
`key_metric_n_saved` should be 1 and only keep the best metric model.
key_metric_save_state: whether to save the tracking list of the key metric in the checkpoint file.
if `True`, an object will be saved in the checkpoint file with key `checkpointer` to be consistent
with ignite: https://github.com/pytorch/ignite/blob/master/ignite/handlers/checkpoint.py#L99.
typically, it's used to resume training and compare the current metric with the previous N values.
epoch_level: save checkpoint during training for every N epochs or every N iterations.
`True` is epoch level, `False` is iteration level.
save_interval: save checkpoint every N epochs, default is 0 to save no checkpoint.
@@ -84,6 +89,7 @@ def __init__(
key_metric_name: Optional[str] = None,
key_metric_n_saved: int = 1,
key_metric_filename: Optional[str] = None,
key_metric_save_state: bool = False,
epoch_level: bool = True,
save_interval: int = 0,
n_saved: Optional[int] = None,
@@ -156,6 +162,7 @@ def _score_func(engine: Engine):
score_function=_score_func,
score_name="key_metric",
n_saved=key_metric_n_saved,
include_self=key_metric_save_state,
)

if save_interval > 0:
@@ -172,6 +179,32 @@ def _interval_func(engine: Engine):
n_saved=n_saved,
)

def load_state_dict(self, state_dict: Dict) -> None:
"""
Utility to resume the internal state of the key metric tracking list if configured to save
checkpoints based on the key metric value.
Note that `key_metric_save_state` must be set to `True` when saving the previous checkpoint.
Example::
CheckpointSaver(
...
save_key_metric=True,
key_metric_save_state=True, # config to also save the state of this saver
).attach(engine)
engine.run(...)
# resumed training with a new CheckpointSaver
saver = CheckpointSaver(save_key_metric=True, ...)
# load the previous key metric tracking list into saver
CheckpointLoader("/test/model.pt"), {"checkpointer": saver}).attach(engine)
"""
if self._key_metric_checkpoint is not None:
self._key_metric_checkpoint.load_state_dict(state_dict)
else:
warnings.warn("no key metric checkpoint saver to resume the key metric tracking list.")

def attach(self, engine: Engine) -> None:
"""
Args:
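Putting the handler changes together, the intended resume flow is: run once with `key_metric_save_state=True` so the saver's tracking list is written into the best-metric checkpoint, then on restart map the `checkpointer` entry of that file back onto a fresh `CheckpointSaver` via `CheckpointLoader`. The sketch below is essentially a condensed version of the unit test added in this commit (shown next); the toy `PReLU` network, the iteration-count `val_loss` metric, and the temporary directory are illustrative choices, not requirements.

import os
import tempfile

import torch
from ignite.engine import Engine

from monai.handlers import CheckpointLoader, CheckpointSaver

net = torch.nn.PReLU()

def _train_func(engine, batch):
    # toy metric: report the iteration count as "val_loss"
    engine.state.metrics["val_loss"] = engine.state.iteration

with tempfile.TemporaryDirectory() as tempdir:
    # first run: save best-metric checkpoints, including the saver's own state
    engine = Engine(_train_func)
    CheckpointSaver(
        save_dir=tempdir,
        save_dict={"net": net},
        save_key_metric=True,
        key_metric_name="val_loss",
        key_metric_n_saved=2,
        key_metric_save_state=True,  # store the tracking list under "checkpointer"
    ).attach(engine)
    engine.run(range(3), max_epochs=2)

    # resumed run: restore the previous best-metric tracking list into a new saver
    saver = CheckpointSaver(
        save_dir=tempdir,
        save_dict={"net": net},
        save_key_metric=True,
        key_metric_name="val_loss",
        key_metric_n_saved=2,
    )
    engine = Engine(_train_func)
    CheckpointLoader(os.path.join(tempdir, "net_key_metric=6.pt"), {"checkpointer": saver}).attach(engine)
    engine.run(range(1), max_epochs=1)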
54 changes: 49 additions & 5 deletions tests/test_handler_checkpoint_saver.py
@@ -20,9 +20,9 @@
from ignite.engine import Engine
from parameterized import parameterized

from monai.handlers import CheckpointSaver
from monai.handlers import CheckpointLoader, CheckpointSaver

TEST_CASE_1 = [True, None, False, None, 1, None, True, 0, None, ["test_checkpoint_final_iteration=40.pt"]]
TEST_CASE_1 = [True, None, False, None, 1, None, False, True, 0, None, ["test_checkpoint_final_iteration=40.pt"]]

TEST_CASE_2 = [
False,
@@ -31,6 +31,7 @@
"val_loss",
2,
None,
False,
True,
0,
None,
@@ -44,6 +45,7 @@
None,
1,
None,
False,
True,
2,
2,
@@ -58,16 +60,17 @@
1,
None,
False,
False,
10,
2,
["test_checkpoint_iteration=30.pt", "test_checkpoint_iteration=40.pt"],
]

TEST_CASE_5 = [True, None, False, None, 1, None, True, 0, None, ["test_checkpoint_final_iteration=40.pt"], True]
TEST_CASE_5 = [True, None, False, None, 1, None, False, True, 0, None, ["test_checkpoint_final_iteration=40.pt"], True]

TEST_CASE_6 = [True, "final_model.pt", False, None, 1, None, True, 0, None, ["final_model.pt"]]
TEST_CASE_6 = [True, "final_model.pt", False, None, 1, None, False, True, 0, None, ["final_model.pt"]]

TEST_CASE_7 = [False, None, True, "val_loss", 1, "model.pt", True, 0, None, ["model.pt"]]
TEST_CASE_7 = [False, None, True, "val_loss", 1, "model.pt", False, True, 0, None, ["model.pt"]]


class TestHandlerCheckpointSaver(unittest.TestCase):
@@ -80,6 +83,7 @@ def test_file(
key_metric_name,
key_metric_n_saved,
key_metric_filename,
key_metric_save_state,
epoch_level,
save_interval,
n_saved,
@@ -112,6 +116,7 @@ def _train_func(engine, batch):
key_metric_name,
key_metric_n_saved,
key_metric_filename,
key_metric_save_state,
epoch_level,
save_interval,
n_saved,
@@ -141,6 +146,45 @@ def _train_func(engine, batch):
engine.run(range(3), max_epochs=2)
self.assertTrue(os.path.exists(os.path.join(tempdir, "net_final_iteration=1.pt")))

def test_load_state_dict(self):
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
net = torch.nn.PReLU()

# set up engine
def _train_func(engine, batch):
engine.state.metrics["val_loss"] = engine.state.iteration

engine = Engine(_train_func)

# set up testing handler
with tempfile.TemporaryDirectory() as tempdir:
engine = Engine(_train_func)
CheckpointSaver(
save_dir=tempdir,
save_dict={"net": net},
save_key_metric=True,
key_metric_name="val_loss",
key_metric_n_saved=2,
key_metric_save_state=True,
).attach(engine)
engine.run(range(3), max_epochs=2)

saver = CheckpointSaver(
save_dir=tempdir,
save_dict={"net": net},
save_key_metric=True,
key_metric_name="val_loss",
key_metric_n_saved=2,
)
engine = Engine(_train_func)
CheckpointLoader(os.path.join(tempdir, "net_key_metric=6.pt"), {"checkpointer": saver}).attach(engine)
engine.run(range(1), max_epochs=1)

resumed = saver._key_metric_checkpoint._saved
for i in range(2):
self.assertEqual(resumed[i].priority, 3 * (i + 1))
self.assertEqual(resumed[i].filename, f"net_key_metric={3 * (i + 1)}.pt")


if __name__ == "__main__":
unittest.main()
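
A brief note on the saved file itself: per the docstring added above, `key_metric_save_state=True` stores the saver's state in the checkpoint under the key `checkpointer` (consistent with ignite's `Checkpoint` when `include_self=True`), and that is exactly the entry the test maps back with `CheckpointLoader(..., {"checkpointer": saver})`. A minimal inspection sketch, assuming the file produced by the first run in the test above:

import torch

# best-metric checkpoint written by the first run; the name encodes the metric value
ckpt = torch.load("net_key_metric=6.pt", map_location="cpu")

# saver state stored because key_metric_save_state=True; CheckpointLoader hands this
# entry to CheckpointSaver.load_state_dict() to restore the tracking list
print(ckpt["checkpointer"])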
