Skip to content

Commit 6fffdec

Browse files
committed
fix tests
1 parent d91636d commit 6fffdec

File tree

3 files changed

+12
-26
lines changed

3 files changed

+12
-26
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3535

3636
### Deprecated
3737

38+
- `period` has been deprecated in favor of `every_n_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
39+
3840

3941
### Removed
4042

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,8 @@ def __init__(
174174
self.save_last = save_last
175175
self.save_top_k = save_top_k
176176
self.save_weights_only = save_weights_only
177-
self.period = every_n_epochs
178-
self.every_n_epochs = every_n_epochs
177+
self.every_n_epochs = period or every_n_epochs
178+
self.period = self.every_n_epochs
179179
self.every_n_batches = every_n_batches
180180
self._last_global_step_saved = -1
181181
self.current_score = None
@@ -192,8 +192,6 @@ def __init__(
192192
'Argument `period` is deprecated in v1.3 and will be removed in v1.5.'
193193
' Please use `every_n_epochs` instead.', DeprecationWarning
194194
)
195-
self.every_n_epochs = period
196-
self.period = period
197195

198196
self.__init_monitor_mode(monitor, mode)
199197
self.__init_ckpt_dir(dirpath, filename, save_top_k)
@@ -209,7 +207,7 @@ def on_pretrain_routine_start(self, trainer, pl_module):
209207
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None:
210208
if self._should_skip_saving_checkpoint(trainer):
211209
return
212-
step = trainer.global_step
210+
step = trainer.total_batch_idx
213211
skip_batch = self.every_n_steps < 1 or ((step + 1) % self.every_n_steps != 0)
214212
if skip_batch:
215213
return

tests/checkpointing/test_model_checkpoint.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -592,9 +592,8 @@ def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs):
592592
default_root_dir=tmpdir,
593593
callbacks=[checkpoint_callback],
594594
max_epochs=epochs,
595-
limit_train_batches=0.1,
596-
limit_val_batches=0.1,
597-
val_check_interval=1.0,
595+
limit_train_batches=1,
596+
limit_val_batches=1,
598597
logger=False,
599598
)
600599
trainer.fit(model)
@@ -615,9 +614,8 @@ def test_model_checkpoint_every_n_epochs_and_no_period(tmpdir, every_n_epochs):
615614
default_root_dir=tmpdir,
616615
callbacks=[checkpoint_callback],
617616
max_epochs=epochs,
618-
limit_train_batches=0.1,
619-
limit_val_batches=0.1,
620-
val_check_interval=1.0,
617+
limit_train_batches=1,
618+
limit_val_batches=1,
621619
logger=False,
622620
)
623621
trainer.fit(model)
@@ -631,16 +629,15 @@ def test_ckpt_every_n_batches(tmpdir):
631629
""" Tests that the checkpoints are saved every n training steps. """
632630

633631
model = LogInTwoMethods()
634-
632+
every_n_batches = 16
635633
trainer = Trainer(
636634
default_root_dir=tmpdir,
637-
min_epochs=2,
638635
max_epochs=2,
639636
progress_bar_refresh_rate=0,
640637
checkpoint_callback=ModelCheckpoint(
641638
filename="{step}",
642639
every_n_epochs=0,
643-
every_n_batches=16,
640+
every_n_batches=every_n_batches,
644641
dirpath=tmpdir,
645642
save_top_k=-1,
646643
save_last=False,
@@ -649,16 +646,7 @@ def test_ckpt_every_n_batches(tmpdir):
649646
)
650647

651648
trainer.fit(model)
652-
expected = [
653-
"step=15.ckpt",
654-
"step=31.ckpt",
655-
"step=47.ckpt",
656-
"step=63.ckpt",
657-
"step=79.ckpt",
658-
"step=95.ckpt",
659-
"step=111.ckpt",
660-
"step=127.ckpt",
661-
]
649+
expected=[f"step={i}.ckpt" for i in range(15, 128, every_n_batches)]
662650
assert set(os.listdir(tmpdir)) == set(expected)
663651

664652

@@ -667,9 +655,7 @@ def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir):
667655
model = LogInTwoMethods()
668656
trainer = Trainer(
669657
default_root_dir=tmpdir,
670-
min_epochs=2,
671658
max_epochs=2,
672-
progress_bar_refresh_rate=0,
673659
checkpoint_callback=ModelCheckpoint(
674660
every_n_epochs=1,
675661
every_n_batches=30,

0 commit comments

Comments
 (0)