1578 Add support to resume previous best metrics in CheckpointSaver #1608

Merged: 6 commits, Feb 23, 2021
33 changes: 33 additions & 0 deletions monai/handlers/checkpoint_saver.py
@@ -10,6 +10,7 @@
# limitations under the License.

import logging
import warnings
from typing import TYPE_CHECKING, Dict, Optional

from monai.utils import exact_version, optional_import
@@ -55,6 +56,10 @@ class CheckpointSaver:
metric in descending order.
key_metric_filename: set a fixed filename to set the best metric model, if not None,
`key_metric_n_saved` should be 1 and only keep the best metric model.
key_metric_save_state: whether to save the tracking list of key metric values in the checkpoint file.
if `True`, an object is saved in the checkpoint file under the key `checkpointer` to be consistent
with ignite: https://github.com/pytorch/ignite/blob/master/ignite/handlers/checkpoint.py#L99.
typically, it's used to resume training and compare the current metric with the previous N values.
epoch_level: save checkpoint during training for every N epochs or every N iterations.
`True` is epoch level, `False` is iteration level.
save_interval: save checkpoint every N epochs, default is 0 to save no checkpoint.
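
To make the new flag concrete, here is a minimal sketch of what the saved file then contains. The checkpoint path is a placeholder, and it assumes ignite's `Checkpoint` with `include_self=True` stores its own state under the `checkpointer` key alongside the other saved objects, as the docstring above describes:

import torch

# hypothetical path to a best-metric checkpoint written with key_metric_save_state=True
checkpoint = torch.load("./runs/net_key_metric=6.pt")

# besides the saved objects (e.g. "net"), the file should also carry the saver's own
# tracking state under "checkpointer"
print(sorted(checkpoint.keys()))   # e.g. ['checkpointer', 'net']
print(checkpoint["checkpointer"])  # e.g. {'saved': [(3, 'net_key_metric=3.pt'), (6, 'net_key_metric=6.pt')]}
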
@@ -84,6 +89,7 @@ def __init__(
key_metric_name: Optional[str] = None,
key_metric_n_saved: int = 1,
key_metric_filename: Optional[str] = None,
key_metric_save_state: bool = False,
epoch_level: bool = True,
save_interval: int = 0,
n_saved: Optional[int] = None,
@@ -156,6 +162,7 @@ def _score_func(engine: Engine):
score_function=_score_func,
score_name="key_metric",
n_saved=key_metric_n_saved,
include_self=key_metric_save_state,
)

if save_interval > 0:
@@ -172,6 +179,32 @@ def _interval_func(engine: Engine):
n_saved=n_saved,
)

def load_state_dict(self, state_dict: Dict) -> None:
"""
Utility to resume the internal state of the key metric tracking list if configured to save
checkpoints based on the key metric value.
Note: `key_metric_save_state` must have been `True` when saving the previous checkpoint.

Example::

CheckpointSaver(
...
save_key_metric=True,
key_metric_save_state=True, # config to also save the state of this saver
).attach(engine)
engine.run(...)

# resumed training with a new CheckpointSaver
saver = CheckpointSaver(save_key_metric=True, ...)
# load the previous key metric tracking list into saver
CheckpointLoader("/test/model.pt", {"checkpointer": saver}).attach(engine)

"""
if self._key_metric_checkpoint is not None:
self._key_metric_checkpoint.load_state_dict(state_dict)
else:
warnings.warn("no key metric checkpoint saver to resume the key metric tracking list.")

def attach(self, engine: Engine) -> None:
"""
Args:
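Putting the new handler code together, the following end-to-end sketch mirrors the docstring example above and the unit test below; the `./runs` directory is a placeholder, while the `PReLU` network, the dummy `_train_func`, and the `net_key_metric=6.pt` filename follow the test case:

import torch
from ignite.engine import Engine
from monai.handlers import CheckpointLoader, CheckpointSaver

net = torch.nn.PReLU()

# dummy training step: use the iteration count as the "val_loss" metric
def _train_func(engine, batch):
    engine.state.metrics["val_loss"] = engine.state.iteration

# first run: keep the 2 best "val_loss" checkpoints and also persist the tracking list
engine = Engine(_train_func)
CheckpointSaver(
    save_dir="./runs",
    save_dict={"net": net},
    save_key_metric=True,
    key_metric_name="val_loss",
    key_metric_n_saved=2,
    key_metric_save_state=True,  # also save this saver's state into the checkpoint file
).attach(engine)
engine.run(range(3), max_epochs=2)  # writes net_key_metric=3.pt and net_key_metric=6.pt

# resumed run: a fresh saver whose key metric tracking list is restored on engine start
saver = CheckpointSaver(
    save_dir="./runs",
    save_dict={"net": net},
    save_key_metric=True,
    key_metric_name="val_loss",
    key_metric_n_saved=2,
)
engine = Engine(_train_func)
CheckpointLoader("./runs/net_key_metric=6.pt", {"checkpointer": saver}).attach(engine)
engine.run(range(1), max_epochs=1)

# the restored list holds the best entries from the first run
# (_key_metric_checkpoint._saved is internal, inspected the same way in the test below)
print([item.filename for item in saver._key_metric_checkpoint._saved])

In a real resumed training run the new saver would also be attached to the engine, so that subsequent key metric values are compared against the restored list rather than an empty one.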
54 changes: 49 additions & 5 deletions tests/test_handler_checkpoint_saver.py
@@ -20,9 +20,9 @@
from ignite.engine import Engine
from parameterized import parameterized

from monai.handlers import CheckpointSaver
from monai.handlers import CheckpointLoader, CheckpointSaver

TEST_CASE_1 = [True, None, False, None, 1, None, True, 0, None, ["test_checkpoint_final_iteration=40.pt"]]
TEST_CASE_1 = [True, None, False, None, 1, None, False, True, 0, None, ["test_checkpoint_final_iteration=40.pt"]]

TEST_CASE_2 = [
False,
@@ -31,6 +31,7 @@
"val_loss",
2,
None,
False,
True,
0,
None,
@@ -44,6 +45,7 @@
None,
1,
None,
False,
True,
2,
2,
@@ -58,16 +60,17 @@
1,
None,
False,
False,
10,
2,
["test_checkpoint_iteration=30.pt", "test_checkpoint_iteration=40.pt"],
]

TEST_CASE_5 = [True, None, False, None, 1, None, True, 0, None, ["test_checkpoint_final_iteration=40.pt"], True]
TEST_CASE_5 = [True, None, False, None, 1, None, False, True, 0, None, ["test_checkpoint_final_iteration=40.pt"], True]

TEST_CASE_6 = [True, "final_model.pt", False, None, 1, None, True, 0, None, ["final_model.pt"]]
TEST_CASE_6 = [True, "final_model.pt", False, None, 1, None, False, True, 0, None, ["final_model.pt"]]

TEST_CASE_7 = [False, None, True, "val_loss", 1, "model.pt", True, 0, None, ["model.pt"]]
TEST_CASE_7 = [False, None, True, "val_loss", 1, "model.pt", False, True, 0, None, ["model.pt"]]


class TestHandlerCheckpointSaver(unittest.TestCase):
@@ -80,6 +83,7 @@ def test_file(
key_metric_name,
key_metric_n_saved,
key_metric_filename,
key_metric_save_state,
epoch_level,
save_interval,
n_saved,
@@ -112,6 +116,7 @@ def _train_func(engine, batch):
key_metric_name,
key_metric_n_saved,
key_metric_filename,
key_metric_save_state,
epoch_level,
save_interval,
n_saved,
@@ -141,6 +146,45 @@ def _train_func(engine, batch):
engine.run(range(3), max_epochs=2)
self.assertTrue(os.path.exists(os.path.join(tempdir, "net_final_iteration=1.pt")))

def test_load_state_dict(self):
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
net = torch.nn.PReLU()

# set up engine
def _train_func(engine, batch):
engine.state.metrics["val_loss"] = engine.state.iteration

engine = Engine(_train_func)

# set up testing handler
with tempfile.TemporaryDirectory() as tempdir:
engine = Engine(_train_func)
CheckpointSaver(
save_dir=tempdir,
save_dict={"net": net},
save_key_metric=True,
key_metric_name="val_loss",
key_metric_n_saved=2,
key_metric_save_state=True,
).attach(engine)
engine.run(range(3), max_epochs=2)

saver = CheckpointSaver(
save_dir=tempdir,
save_dict={"net": net},
save_key_metric=True,
key_metric_name="val_loss",
key_metric_n_saved=2,
)
engine = Engine(_train_func)
CheckpointLoader(os.path.join(tempdir, "net_key_metric=6.pt"), {"checkpointer": saver}).attach(engine)
engine.run(range(1), max_epochs=1)

resumed = saver._key_metric_checkpoint._saved
for i in range(2):
self.assertEqual(resumed[i].priority, 3 * (i + 1))
self.assertEqual(resumed[i].filename, f"net_key_metric={3 * (i + 1)}.pt")


if __name__ == "__main__":
unittest.main()