Fix ModelCheckpoint period #3630

Merged · 5 commits · Sep 29, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -77,6 +77,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed determinism in `DDPSpawnBackend` when using `seed_everything` in main process ([#3335](https://github.com/PyTorchLightning/pytorch-lightning/pull/3335))

- Fixed `ModelCheckpoint` `period` to actually save every `period` epochs ([#3630](https://github.com/PyTorchLightning/pytorch-lightning/pull/3630))

## [0.9.0] - YYYY-MM-DD

### Added
41 changes: 15 additions & 26 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -181,27 +181,21 @@ def save_checkpoint(self, trainer, pl_module):
"""
Performs the main logic around saving a checkpoint
"""
# only run on main process
if trainer.global_rank != 0:
return

# no models are saved
if self.save_top_k == 0:
return

# don't save anything during sanity check
if trainer.running_sanity_check:
return
epoch = trainer.current_epoch

# skip this epoch
if self._should_skip_epoch(trainer):
if (
trainer.global_rank != 0 # only run on main process
or self.save_top_k == 0 # no models are saved
or self.period < 1 # no models are saved
or (epoch + 1) % self.period # skip epoch
or trainer.running_sanity_check # don't save anything during sanity check
or self.epoch_last_check == epoch # already saved
):
return

self._add_backward_monitor_support(trainer)
self._validate_monitor_key(trainer)

epoch = trainer.current_epoch

# track epoch when ckpt was last checked
self.epoch_last_check = trainer.current_epoch
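
The net effect of this hunk: the removed `_should_skip_epoch` helper compared the gap since the last save (`epoch - epoch_last_check < period`), so the schedule was anchored to whenever the first save happened, while the new guard checks `(epoch + 1) % self.period`, which aligns saves to every `period`-th epoch and also refuses to save twice in the same epoch. A standalone sketch of the difference (the helper names and the loop are illustrative only, not library code):

```python
# Compare the old relative-gap rule with the new modulo rule for period=2 over 5 epochs.
def old_should_save(epoch, epoch_last_check, period):
    # inverse of the removed _should_skip_epoch: save unless the gap since the
    # last save is still smaller than `period`
    skip = (epoch_last_check is not None) and (epoch - epoch_last_check) < period
    return not skip

def new_should_save(epoch, epoch_last_check, period):
    # inverse of the new guard in save_checkpoint: save only on every
    # `period`-th epoch, and never twice for the same epoch
    skip = period < 1 or (epoch + 1) % period or epoch_last_check == epoch
    return not skip

for rule in (old_should_save, new_should_save):
    saved, last = [], None
    for epoch in range(5):
        if rule(epoch, last, 2):
            saved.append(epoch)
            last = epoch
    print(rule.__name__, saved)
# old_should_save [0, 2, 4]  (anchored to the first save)
# new_should_save [1, 3]     (every 2nd epoch, matching the new test below)
```
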

Expand Down Expand Up @@ -278,7 +272,7 @@ def __init_monitor_mode(self, monitor, mode):

if mode not in mode_dict:
rank_zero_warn(
f"ModelCheckpoint mode {mode} is unknown, " f"fallback to auto mode.",
f"ModelCheckpoint mode {mode} is unknown, fallback to auto mode",
RuntimeWarning,
)
mode = "auto"
@@ -290,7 +284,6 @@ def _del_model(self, filepath: str):
self._fs.rm(filepath)

def _save_model(self, filepath: str, trainer, pl_module):

# in debugging, track when we save checkpoints
trainer.dev_debugger.track_checkpointing_history(filepath)

@@ -317,9 +310,7 @@ def check_monitor_top_k(self, current) -> bool:
current = torch.tensor(current)

monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode]

val = monitor_op(current, self.best_k_models[self.kth_best_model_path])
return val
return monitor_op(current, self.best_k_models[self.kth_best_model_path]).item()

@classmethod
def _format_checkpoint_name(
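
A side effect of the `check_monitor_top_k` cleanup above: comparing two tensors with `torch.lt`/`torch.gt` yields a 0-dim bool tensor, and the added `.item()` unwraps it so the method returns a plain Python bool. A minimal illustration (only `torch` assumed, values are made up):

```python
import torch

current, best_kth = torch.tensor(0.25), torch.tensor(0.40)
print(torch.lt(current, best_kth))         # tensor(True), a 0-dim bool tensor
print(torch.lt(current, best_kth).item())  # True, a plain Python bool
```
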
@@ -443,10 +434,6 @@ def _validate_monitor_key(self, trainer):
)
raise MisconfigurationException(m)

def _should_skip_epoch(self, trainer):
epoch = trainer.current_epoch
return (self.epoch_last_check is not None) and (epoch - self.epoch_last_check) < self.period

def _get_metric_interpolated_filepath_name(self, epoch, ckpt_name_metrics):
filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
version_cnt = 0
@@ -496,8 +483,10 @@ def _save_top_k_checkpoints(self, metrics, trainer, pl_module, epoch, filepath):
if current is None:
m = f"Can save best model only with {self.monitor} available, skipping."
if self.monitor == 'checkpoint_on':
m = 'No checkpoint_on found. Hint: Did you set it in EvalResult(checkpoint_on=tensor) or ' \
'TrainResult(checkpoint_on=tensor)?'
m = (
'No checkpoint_on found. HINT: Did you set it in '
'EvalResult(checkpoint_on=tensor) or TrainResult(checkpoint_on=tensor)?'
)
rank_zero_warn(m, RuntimeWarning)
elif self.check_monitor_top_k(current):
self._do_check_save(filepath, current, epoch, trainer, pl_module)
31 changes: 26 additions & 5 deletions tests/callbacks/test_model_checkpoint.py
@@ -181,13 +181,13 @@ def test_model_checkpoint_save_last(tmpdir):
early_stop_callback=False,
checkpoint_callback=model_checkpoint,
max_epochs=epochs,
logger=False,
)
trainer.fit(model)
last_filename = model_checkpoint._format_checkpoint_name(ModelCheckpoint.CHECKPOINT_NAME_LAST, epochs - 1, {})
last_filename = last_filename + '.ckpt'
assert str(tmpdir / last_filename) == model_checkpoint.last_model_path
assert set(os.listdir(tmpdir)) == \
set([f'epoch={i}.ckpt' for i in range(epochs)] + [last_filename, 'lightning_logs'])
assert set(os.listdir(tmpdir)) == set([f'epoch={i}.ckpt' for i in range(epochs)] + [last_filename])
ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last'


@@ -261,6 +261,7 @@ def test_model_checkpoint_none_monitor(tmpdir):
early_stop_callback=False,
checkpoint_callback=checkpoint_callback,
max_epochs=epochs,
logger=False,
)
trainer.fit(model)

@@ -272,8 +273,28 @@ def test_model_checkpoint_none_monitor(tmpdir):
assert checkpoint_callback.kth_best_model_path == ''

# check that the correct ckpts were created
expected = ['lightning_logs']
expected.extend(f'epoch={e}.ckpt' for e in range(epochs))
expected = [f'epoch={e}.ckpt' for e in range(epochs)]
assert set(os.listdir(tmpdir)) == set(expected)


@pytest.mark.parametrize("period", list(range(4)))
def test_model_checkpoint_period(tmpdir, period):
model = EvalModelTemplate()
epochs = 5
checkpoint_callback = ModelCheckpoint(filepath=tmpdir, save_top_k=-1, period=period)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=False,
checkpoint_callback=checkpoint_callback,
max_epochs=epochs,
limit_train_batches=0.1,
limit_val_batches=0.1,
logger=False,
)
trainer.fit(model)

# check that the correct ckpts were created
expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else []
assert set(os.listdir(tmpdir)) == set(expected)
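
For reference, the expected-files expression above means that with `period=2` and 5 epochs only `epoch=1.ckpt` and `epoch=3.ckpt` should exist, and with `period=0` no checkpoints are written at all. At the user level, a sketch against the 0.9.x-era argument names used in this test (later releases renamed some of them) would drive the fixed behaviour the same way:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Save every 2nd epoch, keeping all checkpoints; argument names follow the
# 0.9.x-era API exercised by this test and may differ in later releases.
checkpoint_callback = ModelCheckpoint(filepath='checkpoints/', save_top_k=-1, period=2)
trainer = Trainer(max_epochs=6, checkpoint_callback=checkpoint_callback)
# trainer.fit(model)  # with period=2, checkpoints land after epochs 1, 3 and 5 (0-indexed)
```
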


@@ -296,7 +317,7 @@ def test_model_checkpoint_topk_zero(tmpdir):
assert checkpoint_callback.best_k_models == {}
assert checkpoint_callback.kth_best_model_path == ''
# check that no ckpts were created
assert len(set(os.listdir(tmpdir))) == 0
assert len(os.listdir(tmpdir)) == 0


def test_ckpt_metric_names(tmpdir):