Skip to content

Commit cb5f899

Browse files
Nic Ma and monai-bot authored
1578 Add support to resume previous best metrics in CheckpointSaver (#1608)
* [DLMED] add support to resume metrics in CheckpointSaver Signed-off-by: Nic Ma <nma@nvidia.com> * [MONAI] python code formatting Signed-off-by: monai-bot <monai.miccai2019@gmail.com> * [DLMED] fix doc-string Signed-off-by: Nic Ma <nma@nvidia.com> Co-authored-by: monai-bot <monai.miccai2019@gmail.com>
1 parent 75b78cd commit cb5f899

File tree

2 files changed

+82
-5
lines changed

2 files changed

+82
-5
lines changed

monai/handlers/checkpoint_saver.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# limitations under the License.
1111

1212
import logging
13+
import warnings
1314
from typing import TYPE_CHECKING, Dict, Optional
1415

1516
from monai.utils import exact_version, optional_import
@@ -55,6 +56,10 @@ class CheckpointSaver:
5556
metric in descending order.
5657
key_metric_filename: set a fixed filename to set the best metric model, if not None,
5758
`key_metric_n_saved` should be 1 and only keep the best metric model.
59+
key_metric_save_state: whether to save the tracking list of key metric in the checkpoint file.
60+
if `True`, then will save an object in the checkpoint file with key `checkpointer` to be consistent
61+
with ignite: https://github.com/pytorch/ignite/blob/master/ignite/handlers/checkpoint.py#L99.
62+
typically, it's used to resume training and compare current metric with previous N values.
5863
epoch_level: save checkpoint during training for every N epochs or every N iterations.
5964
`True` is epoch level, `False` is iteration level.
6065
save_interval: save checkpoint every N epochs, default is 0 to save no checkpoint.
@@ -84,6 +89,7 @@ def __init__(
8489
key_metric_name: Optional[str] = None,
8590
key_metric_n_saved: int = 1,
8691
key_metric_filename: Optional[str] = None,
92+
key_metric_save_state: bool = False,
8793
epoch_level: bool = True,
8894
save_interval: int = 0,
8995
n_saved: Optional[int] = None,
@@ -156,6 +162,7 @@ def _score_func(engine: Engine):
156162
score_function=_score_func,
157163
score_name="key_metric",
158164
n_saved=key_metric_n_saved,
165+
include_self=key_metric_save_state,
159166
)
160167

161168
if save_interval > 0:
@@ -172,6 +179,32 @@ def _interval_func(engine: Engine):
172179
n_saved=n_saved,
173180
)
174181

182+
def load_state_dict(self, state_dict: Dict) -> None:
183+
"""
184+
Utility to resume the internal state of key metric tracking list if configured to save
185+
checkpoints based on the key metric value.
186+
Note to set `key_metric_save_state=True` when saving the previous checkpoint.
187+
188+
Example::
189+
190+
CheckpointSaver(
191+
...
192+
save_key_metric=True,
193+
key_metric_save_state=True, # config to also save the state of this saver
194+
).attach(engine)
195+
engine.run(...)
196+
197+
# resumed training with a new CheckpointSaver
198+
saver = CheckpointSaver(save_key_metric=True, ...)
199+
# load the previous key metric tracking list into saver
200+
CheckpointLoader("/test/model.pt", {"checkpointer": saver}).attach(engine)
201+
202+
"""
203+
if self._key_metric_checkpoint is not None:
204+
self._key_metric_checkpoint.load_state_dict(state_dict)
205+
else:
206+
warnings.warn("no key metric checkpoint saver to resume the key metric tracking list.")
207+
175208
def attach(self, engine: Engine) -> None:
176209
"""
177210
Args:

tests/test_handler_checkpoint_saver.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
from ignite.engine import Engine
2121
from parameterized import parameterized
2222

23-
from monai.handlers import CheckpointSaver
23+
from monai.handlers import CheckpointLoader, CheckpointSaver
2424

25-
TEST_CASE_1 = [True, None, False, None, 1, None, True, 0, None, ["test_checkpoint_final_iteration=40.pt"]]
25+
TEST_CASE_1 = [True, None, False, None, 1, None, False, True, 0, None, ["test_checkpoint_final_iteration=40.pt"]]
2626

2727
TEST_CASE_2 = [
2828
False,
@@ -31,6 +31,7 @@
3131
"val_loss",
3232
2,
3333
None,
34+
False,
3435
True,
3536
0,
3637
None,
@@ -44,6 +45,7 @@
4445
None,
4546
1,
4647
None,
48+
False,
4749
True,
4850
2,
4951
2,
@@ -58,16 +60,17 @@
5860
1,
5961
None,
6062
False,
63+
False,
6164
10,
6265
2,
6366
["test_checkpoint_iteration=30.pt", "test_checkpoint_iteration=40.pt"],
6467
]
6568

66-
TEST_CASE_5 = [True, None, False, None, 1, None, True, 0, None, ["test_checkpoint_final_iteration=40.pt"], True]
69+
TEST_CASE_5 = [True, None, False, None, 1, None, False, True, 0, None, ["test_checkpoint_final_iteration=40.pt"], True]
6770

68-
TEST_CASE_6 = [True, "final_model.pt", False, None, 1, None, True, 0, None, ["final_model.pt"]]
71+
TEST_CASE_6 = [True, "final_model.pt", False, None, 1, None, False, True, 0, None, ["final_model.pt"]]
6972

70-
TEST_CASE_7 = [False, None, True, "val_loss", 1, "model.pt", True, 0, None, ["model.pt"]]
73+
TEST_CASE_7 = [False, None, True, "val_loss", 1, "model.pt", False, True, 0, None, ["model.pt"]]
7174

7275

7376
class TestHandlerCheckpointSaver(unittest.TestCase):
@@ -80,6 +83,7 @@ def test_file(
8083
key_metric_name,
8184
key_metric_n_saved,
8285
key_metric_filename,
86+
key_metric_save_state,
8387
epoch_level,
8488
save_interval,
8589
n_saved,
@@ -112,6 +116,7 @@ def _train_func(engine, batch):
112116
key_metric_name,
113117
key_metric_n_saved,
114118
key_metric_filename,
119+
key_metric_save_state,
115120
epoch_level,
116121
save_interval,
117122
n_saved,
@@ -141,6 +146,45 @@ def _train_func(engine, batch):
141146
engine.run(range(3), max_epochs=2)
142147
self.assertTrue(os.path.exists(os.path.join(tempdir, "net_final_iteration=1.pt")))
143148

149+
def test_load_state_dict(self):
150+
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
151+
net = torch.nn.PReLU()
152+
153+
# set up engine
154+
def _train_func(engine, batch):
155+
engine.state.metrics["val_loss"] = engine.state.iteration
156+
157+
engine = Engine(_train_func)
158+
159+
# set up testing handler
160+
with tempfile.TemporaryDirectory() as tempdir:
161+
engine = Engine(_train_func)
162+
CheckpointSaver(
163+
save_dir=tempdir,
164+
save_dict={"net": net},
165+
save_key_metric=True,
166+
key_metric_name="val_loss",
167+
key_metric_n_saved=2,
168+
key_metric_save_state=True,
169+
).attach(engine)
170+
engine.run(range(3), max_epochs=2)
171+
172+
saver = CheckpointSaver(
173+
save_dir=tempdir,
174+
save_dict={"net": net},
175+
save_key_metric=True,
176+
key_metric_name="val_loss",
177+
key_metric_n_saved=2,
178+
)
179+
engine = Engine(_train_func)
180+
CheckpointLoader(os.path.join(tempdir, "net_key_metric=6.pt"), {"checkpointer": saver}).attach(engine)
181+
engine.run(range(1), max_epochs=1)
182+
183+
resumed = saver._key_metric_checkpoint._saved
184+
for i in range(2):
185+
self.assertEqual(resumed[i].priority, 3 * (i + 1))
186+
self.assertEqual(resumed[i].filename, f"net_key_metric={3 * (i + 1)}.pt")
187+
144188

145189
if __name__ == "__main__":
146190
unittest.main()

0 commit comments

Comments
 (0)