
Commit 4065f42

Merge 81385f2 into 55dd3a4
2 parents: 55dd3a4 + 81385f2

File tree: 4 files changed, +200 / -21 lines

- CHANGELOG.md
- pytorch_lightning/callbacks/model_checkpoint.py
- tests/checkpointing/test_model_checkpoint.py
- tests/deprecated_api/test_remove_1-5.py

CHANGELOG.md

Lines changed: 4 additions & 0 deletions

@@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))


+- Added support to checkpoint after training batches in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
+
 - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))


@@ -46,6 +48,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Deprecated

+- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
+
 - Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
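Taken together, these entries mean a `ModelCheckpoint` can now fire on a training-step interval as well as a validation-epoch interval. A minimal usage sketch of the new arguments (`MyModel` is a hypothetical `LightningModule`, not part of this commit):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Save a checkpoint every 100 training batches, independent of validation.
step_ckpt = ModelCheckpoint(
    dirpath="checkpoints/",
    filename="{step}",
    every_n_train_steps=100,  # new in this commit; 0 disables the step trigger
    every_n_val_epochs=0,     # disable the validation-epoch trigger for this callback
)

trainer = Trainer(callbacks=[step_ckpt], max_epochs=5)
# trainer.fit(MyModel())  # MyModel is a placeholder LightningModule
```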

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 60 additions & 18 deletions

@@ -93,8 +93,21 @@ class ModelCheckpoint(Callback):
         save_weights_only: if ``True``, then only the model's weights will be
             saved (``model.save_weights(filepath)``), else the full model
             is saved (``model.save(filepath)``).
+        every_n_train_steps: Number of training steps between checkpoints.
+            To disable, set ``every_n_train_steps = 0``. This value must be non-negative.
+        every_n_val_epochs: Number of validation epochs between checkpoints.
+            To disable, set ``every_n_val_epochs = 0``. This value must be non-negative.
+            This is not mutually exclusive with ``every_n_train_steps``.
+            If both are set, use extreme caution when also setting ``monitor``,
+            as the ``monitor`` value must be available in both training and validation.
+            This can have unintended consequences when tracking the top k models.
         period: Interval (number of epochs) between checkpoints.

+            .. warning::
+                This argument has been deprecated in v1.3 and will be removed in v1.5.
+
+                Use ``every_n_val_epochs`` instead.
+
     Note:
         For extra customization, ModelCheckpoint includes the following attributes:
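The docstring's caution about ``monitor`` is easiest to see in code: if a callback checkpoints on both triggers, the monitored key has to be logged in the training loop and the validation loop. A hedged sketch of such a module (the class, layer sizes, and metric name are illustrative, not from this commit):

```python
import torch
from torch import nn
from pytorch_lightning import LightningModule


class BothLoopsModule(LightningModule):
    """Hypothetical module that logs the monitored key in both loops."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        # Logged in training so step-triggered saves can resolve monitor="my_loss".
        self.log("my_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # Logged in validation too, so epoch-triggered saves see the same key.
        self.log("my_loss", self(batch).sum())

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```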
@@ -155,15 +168,19 @@ def __init__(
         save_top_k: Optional[int] = None,
         save_weights_only: bool = False,
         mode: str = "min",
-        period: int = 1,
+        every_n_train_steps: int = 0,
+        every_n_val_epochs: int = 1,
+        period: Optional[int] = None,
     ):
         super().__init__()
         self.monitor = monitor
         self.verbose = verbose
         self.save_last = save_last
         self.save_top_k = save_top_k
         self.save_weights_only = save_weights_only
-        self.period = period
+        self.every_n_val_epochs = period if period is not None else every_n_val_epochs
+        self.period = self.every_n_val_epochs
+        self.every_n_train_steps = every_n_train_steps
         self._last_global_step_saved = -1
         self.current_score = None
         self.best_k_models = {}

@@ -174,6 +191,12 @@ def __init__(
         self.save_function = None
         self.warned_result_obj = False

+        if period is not None:
+            rank_zero_warn(
+                'Argument `period` is deprecated in v1.3 and will be removed in v1.5.'
+                ' Please use `every_n_val_epochs` instead.', DeprecationWarning
+            )
+
         self.__init_monitor_mode(monitor, mode)
         self.__init_ckpt_dir(dirpath, filename, save_top_k)
         self.__validate_init_configuration()
@@ -185,11 +208,27 @@ def on_pretrain_routine_start(self, trainer, pl_module):
         self.__resolve_ckpt_dir(trainer)
         self.save_function = trainer.save_checkpoint

+    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs) -> None:
+        """ Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps` """
+        if self._should_skip_saving_checkpoint(trainer):
+            return
+        step = trainer.global_step
+        skip_batch = self.every_n_train_steps < 1 or ((step + 1) % self.every_n_train_steps != 0)
+        if skip_batch:
+            return
+        self.save_checkpoint(trainer, pl_module)
+
     def on_validation_end(self, trainer, pl_module):
         """
         checkpoints can be saved at the end of the val loop
         """
-        self.save_checkpoint(trainer)
+        skip = (
+            self._should_skip_saving_checkpoint(trainer) or self.every_n_val_epochs < 1
+            or (trainer.current_epoch + 1) % self.every_n_val_epochs != 0
+        )
+        if skip:
+            return
+        self.save_checkpoint(trainer, pl_module)

     def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
         return {
@@ -216,20 +255,8 @@ def save_checkpoint(self, trainer, unused: Optional = None):
                 " has been removed. Support for the old signature will be removed in v1.5", DeprecationWarning
             )

-        epoch = trainer.current_epoch
         global_step = trainer.global_step

-        from pytorch_lightning.trainer.states import TrainerState
-        if (
-            trainer.fast_dev_run  # disable checkpointing with fast_dev_run
-            or trainer.state != TrainerState.FITTING  # don't save anything during non-fit
-            or trainer.sanity_checking  # don't save anything during sanity check
-            or self.period < 1  # no models are saved
-            or (epoch + 1) % self.period  # skip epoch
-            or self._last_global_step_saved == global_step  # already saved at the last step
-        ):
-            return
-
         self._add_backward_monitor_support(trainer)
         self._validate_monitor_key(trainer)
@@ -248,9 +275,26 @@ def save_checkpoint(self, trainer, unused: Optional = None):
         # Mode 3: save last checkpoints
         self._save_last_checkpoint(trainer, monitor_candidates)

+    def _should_skip_saving_checkpoint(self, trainer) -> bool:
+        from pytorch_lightning.trainer.states import TrainerState
+        return (
+            trainer.fast_dev_run  # disable checkpointing with fast_dev_run
+            or trainer.state != TrainerState.FITTING  # don't save anything during non-fit
+            or trainer.sanity_checking  # don't save anything during sanity check
+            or self._last_global_step_saved == trainer.global_step  # already saved at the last step
+        )
+
     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
             raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
+        if self.every_n_train_steps < 0:
+            raise MisconfigurationException(
+                f'Invalid value for every_n_train_steps={self.every_n_train_steps}. Must be >= 0'
+            )
+        if self.every_n_val_epochs < 0:
+            raise MisconfigurationException(
+                f'Invalid value for every_n_val_epochs={self.every_n_val_epochs}. Must be >= 0'
+            )
         if self.monitor is None:
             # None: save last epoch, -1: save all epochs, 0: nothing is saved
             if self.save_top_k not in (None, -1, 0):
@@ -554,9 +598,7 @@ def _save_none_monitor_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
         self._save_model(trainer, filepath)

         if (
-            self.save_top_k is None
-            and self.best_model_path
-            and self.best_model_path != filepath
+            self.save_top_k is None and self.best_model_path and self.best_model_path != filepath
             and trainer.is_global_zero
         ):
             self._del_model(self.best_model_path)
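Both new hooks share the same arithmetic: a trigger fires when its interval is at least 1 and the 1-based count is divisible by it. A standalone sketch of that decision in plain Python (not part of the commit itself):

```python
def should_save(count: int, interval: int) -> bool:
    """Mirror of the modulo check used by both hooks.

    `count` is zero-based (trainer.global_step or trainer.current_epoch),
    so `count + 1` converts it to a 1-based "how many completed" value.
    """
    return interval >= 1 and (count + 1) % interval == 0


# With every_n_train_steps=16, saves land on zero-based steps 15, 31, 47, ...
assert [s for s in range(64) if should_save(s, 16)] == [15, 31, 47, 63]
# An interval of 0 disables the trigger entirely.
assert not any(should_save(s, 0) for s in range(64))
```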

tests/checkpointing/test_model_checkpoint.py

Lines changed: 129 additions & 3 deletions

@@ -515,6 +515,26 @@ def test_none_monitor_save_last(tmpdir):
         ModelCheckpoint(dirpath=tmpdir, save_last=False)


+def test_invalid_every_n_val_epochs(tmpdir):
+    """ Make sure that a MisconfigurationException is raised for a negative every_n_val_epochs argument. """
+    with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'):
+        ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=-3)
+    # These should not fail
+    ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=0)
+    ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=1)
+    ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2)
+
+
+def test_invalid_every_n_train_steps(tmpdir):
+    """ Make sure that a MisconfigurationException is raised for a negative every_n_train_steps argument. """
+    with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'):
+        ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=-3)
+    # These should not fail
+    ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0)
+    ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1)
+    ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=2)
+
+
 def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog):
     """ Test that it is possible to save all checkpoints when monitor=None. """
     seed_everything()
@@ -558,9 +578,8 @@ def test_model_checkpoint_period(tmpdir, period: int):
         default_root_dir=tmpdir,
         callbacks=[checkpoint_callback],
         max_epochs=epochs,
-        limit_train_batches=0.1,
-        limit_val_batches=0.1,
-        val_check_interval=1.0,
+        limit_train_batches=1,
+        limit_val_batches=1,
         logger=False,
     )
     trainer.fit(model)
@@ -570,6 +589,113 @@ def test_model_checkpoint_period(tmpdir, period: int):
     assert set(os.listdir(tmpdir)) == set(expected)


+@pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
+def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
+    model = LogInTwoMethods()
+    epochs = 5
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_val_epochs=every_n_val_epochs
+    )
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[checkpoint_callback],
+        max_epochs=epochs,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        logger=False,
+    )
+    trainer.fit(model)
+
+    # check that the correct ckpts were created
+    expected = [f'epoch={e}.ckpt' for e in range(epochs)
+                if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
+    assert set(os.listdir(tmpdir)) == set(expected)
+
+
+@pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
+def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epochs):
+    """ Tests that if period is set, it takes precedence over every_n_val_epochs for backwards compatibility. """
+    model = LogInTwoMethods()
+    epochs = 5
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=tmpdir,
+        filename='{epoch}',
+        save_top_k=-1,
+        every_n_val_epochs=(2 * every_n_val_epochs),
+        period=every_n_val_epochs
+    )
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[checkpoint_callback],
+        max_epochs=epochs,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        logger=False,
+    )
+    trainer.fit(model)
+
+    # check that the correct ckpts were created
+    expected = [f'epoch={e}.ckpt' for e in range(epochs)
+                if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
+    assert set(os.listdir(tmpdir)) == set(expected)
+
+
+def test_ckpt_every_n_train_steps(tmpdir):
+    """ Tests that the checkpoints are saved every n training steps. """
+
+    model = LogInTwoMethods()
+    every_n_train_steps = 16
+    checkpoint_callback = ModelCheckpoint(
+        filename="{step}",
+        every_n_val_epochs=0,
+        every_n_train_steps=every_n_train_steps,
+        dirpath=tmpdir,
+        save_top_k=-1,
+        save_last=False,
+    )
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        progress_bar_refresh_rate=0,
+        callbacks=[checkpoint_callback],
+        logger=False,
+    )
+
+    trainer.fit(model)
+    expected = [f"step={i}.ckpt" for i in range(15, 128, every_n_train_steps)]
+    assert set(os.listdir(tmpdir)) == set(expected)
+
+
+@pytest.mark.parametrize("every_n_val_epochs", [1, 3])
+def test_ckpt_every_n_train_steps_and_every_n_val_epochs(tmpdir, every_n_val_epochs):
+    """ Tests that checkpoints are taken every 30 train steps and every n val epochs. """
+    model = LogInTwoMethods()
+    every_n_train_steps = 30
+    checkpoint_callback = ModelCheckpoint(
+        every_n_val_epochs=every_n_val_epochs,
+        every_n_train_steps=every_n_train_steps,
+        dirpath=tmpdir,
+        save_top_k=-1,
+        save_last=False,
+        filename="{step}",
+    )
+    max_epochs = 3
+    epoch_step_length = 64
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=max_epochs,
+        callbacks=[checkpoint_callback],
+        logger=False,
+    )
+    trainer.fit(model)
+    expected_steps_for_ckpt = [
+        i for i in range(epoch_step_length * max_epochs)
+        if ((i + 1) % every_n_train_steps) == 0 or (i + 1) % (every_n_val_epochs * epoch_step_length) == 0
+    ]
+    expected_ckpt_files = [f"step={step}.ckpt" for step in expected_steps_for_ckpt]
+    assert set(os.listdir(tmpdir)) == set(expected_ckpt_files)
+
+
 def test_model_checkpoint_topk_zero(tmpdir):
     """ Test that no checkpoints are saved when save_top_k=0. """
     model = LogInTwoMethods()
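The expectation in `test_ckpt_every_n_train_steps_and_every_n_val_epochs` is just the union of the two triggers. A small reproduction of that arithmetic for the `every_n_val_epochs=3` case (assuming, as the test does, 64 steps per epoch):

```python
every_n_train_steps = 30
every_n_val_epochs = 3
epoch_step_length = 64  # steps per epoch assumed by the test
max_epochs = 3

expected = {
    f"step={i}.ckpt"
    for i in range(epoch_step_length * max_epochs)
    if (i + 1) % every_n_train_steps == 0                       # step trigger
    or (i + 1) % (every_n_val_epochs * epoch_step_length) == 0  # val-epoch trigger
}
# Step trigger fires at zero-based steps 29, 59, 89, 119, 149, 179;
# the val-epoch trigger adds step 191 (end of the third epoch).
print(sorted(expected))
```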

tests/deprecated_api/test_remove_1-5.py

Lines changed: 7 additions & 0 deletions

@@ -104,3 +104,10 @@ def configure_optimizers(self):

     with pytest.deprecated_call(match="`training_step` .* `optimizer_idx` .* manual .* will be removed in v1.5"):
         trainer.fit(model)
+
+
+def test_v1_5_0_model_checkpoint_period(tmpdir):
+    with no_warning_call(DeprecationWarning):
+        ModelCheckpoint(dirpath=tmpdir)
+    with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"):
+        ModelCheckpoint(dirpath=tmpdir, period=1)
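For downstream users, the deprecation tested above implies a one-line migration. A hedged before/after sketch (the dirpath is illustrative):

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# Before: emits a DeprecationWarning as of v1.3; removal planned for v1.5.
old_ckpt = ModelCheckpoint(dirpath="checkpoints/", period=2)

# After: the same behaviour, expressed with the new argument.
new_ckpt = ModelCheckpoint(dirpath="checkpoints/", every_n_val_epochs=2)
```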
