Skip to content

Commit 548e938

Browse files
committed
fix-default-0
1 parent 3ccf8d1 commit 548e938

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -356,13 +356,15 @@ def __init_triggers(
356356

357357
# Default to running once after each validation epoch if neither
358358
# every_n_train_steps nor every_n_val_epochs is set
359-
self.every_n_val_epochs = every_n_val_epochs or 0
360-
self.every_n_train_steps = every_n_train_steps or 0
361-
if self.every_n_train_steps == 0 and self.every_n_val_epochs == 0:
359+
if every_n_train_steps is None and every_n_val_epochs is None:
362360
self.every_n_val_epochs = 1
361+
self.every_n_train_steps = 0
363362
log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1")
363+
else:
364+
self.every_n_val_epochs = every_n_val_epochs or 0
365+
self.every_n_train_steps = every_n_train_steps or 0
364366

365-
# period takes precedence for every_n_val_epochs for backwards compatibility
367+
# period takes precedence over every_n_val_epochs for backwards compatibility
366368
if period is not None:
367369
rank_zero_warn(
368370
'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'

tests/checkpointing/test_model_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,7 @@ def test_model_checkpoint_period(tmpdir, period: int):
614614
assert set(os.listdir(tmpdir)) == set(expected)
615615

616616

617-
@pytest.mark.parametrize("every_n_val_epochs", list(range(1, 4)))
617+
@pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
618618
def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
619619
model = LogInTwoMethods()
620620
epochs = 5

0 commit comments

Comments
 (0)