Skip to content

Commit 2848bdc

Browse files
committed
add tests
1 parent b88b888 commit 2848bdc

File tree

3 files changed

+32
-5
lines changed

3 files changed

+32
-5
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1212
- Added support to checkpoint after training batches in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
1313
- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))
1414

15+
- Added support to checkpoint after training batches in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
1516

1617
### Changed
1718

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,9 @@ def __init__(
165165
every_n_epochs: int = 1,
166166
every_n_batches: int = -1,
167167
mode: str = "min",
168-
<<<<<<< HEAD
169-
period: Optional[int] = None,
170-
=======
171-
period: int = 1,
172168
every_n_epochs: int = 1,
173169
every_n_batches: int = -1,
174-
>>>>>>> Update model_checkpoint.py
170+
period: Optional[int] = None,
175171
):
176172
super().__init__()
177173
self.monitor = monitor

tests/checkpointing/test_model_checkpoint.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,36 @@ def test_none_monitor_top_k(tmpdir):
499499
ModelCheckpoint(dirpath=tmpdir, save_top_k=-1)
500500
ModelCheckpoint(dirpath=tmpdir, save_top_k=0)
501501

502+
def test_invalid_every_n_epoch(tmpdir):
503+
""" Test that an exception is raised for every_n_epochs = 0 or < -1. """
504+
with pytest.raises(
505+
MisconfigurationException, match=r'Invalid value for every_n_epochs=0*'
506+
):
507+
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0)
508+
with pytest.raises(
509+
MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*'
510+
):
511+
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-2)
512+
513+
# These should not fail
514+
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1)
515+
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3)
516+
517+
def test_invalid_every_n_batches(tmpdir):
518+
""" Test that an exception is raised for every_n_batches = 0 or < -1. """
519+
with pytest.raises(
520+
MisconfigurationException, match=r'Invalid value for every_n_batches=0*'
521+
):
522+
ModelCheckpoint(dirpath=tmpdir, every_n_batches=0)
523+
with pytest.raises(
524+
MisconfigurationException, match=r'Invalid value for every_n_batches=-2*'
525+
):
526+
ModelCheckpoint(dirpath=tmpdir, every_n_batches=-2)
527+
528+
# These should not fail
529+
ModelCheckpoint(dirpath=tmpdir, every_n_batches=-1)
530+
ModelCheckpoint(dirpath=tmpdir, every_n_batches=3)
531+
502532

503533
def test_none_monitor_save_last(tmpdir):
504534
""" Test that a warning appears for save_last=True with monitor=None. """

0 commit comments

Comments
 (0)