Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Oct 3, 2020
1 parent 1ea367e commit 8a20989
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 56 deletions.
49 changes: 27 additions & 22 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,10 @@ def save_checkpoint(self, trainer, pl_module):
This method runs on all ranks, it is the responsibility of `self.save_function`
to handle correct behaviour in distributed training, i.e., saving only on rank 0.
"""
epoch = trainer.current_epoch

if (
self.save_top_k == 0 # no models are saved
or self.period < 1 # no models are saved
or (epoch + 1) % self.period # skip epoch
or (trainer.current_epoch + 1) % self.period # skip epoch
or trainer.running_sanity_check # don't save anything during sanity check
):
return
Expand All @@ -196,22 +194,24 @@ def save_checkpoint(self, trainer, pl_module):
self._validate_monitor_key(trainer)

# track epoch when ckpt was last checked
self.epoch_last_check = epoch
self.epoch_last_check = trainer.current_epoch

# what can be monitored
monitor_candidates = self._monitor_candidates(trainer)

# ie: path/val_loss=0.5.ckpt
filepath = self._get_metric_interpolated_filepath_name(epoch, monitor_candidates)
filepath = self._get_metric_interpolated_filepath_name(
trainer.current_epoch, trainer.global_step, monitor_candidates
)

# callback supports multiple simultaneous modes
# here we call each mode sequentially
# Mode 1: save all checkpoints OR only the top k
if self.save_top_k:
self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, epoch, filepath)
self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, trainer.current_epoch, filepath)

# Mode 2: save the last checkpoint
self._save_last_checkpoint(trainer, pl_module, epoch, monitor_candidates, filepath)
self._save_last_checkpoint(trainer, pl_module, monitor_candidates, filepath)

def __validate_init_configuration(self):
if self.save_top_k is not None and self.save_top_k < -1:
Expand Down Expand Up @@ -321,16 +321,17 @@ def _format_checkpoint_name(
cls,
filename: Optional[str],
epoch: int,
step: int,
metrics: Dict[str, Any],
prefix: str = "",
) -> str:
if not filename:
# filename is not set, use default name
filename = "{epoch}"
filename = "{epoch}-{step}"
# check and parse user passed keys in the string
groups = re.findall(r"(\{.*?)[:\}]", filename)
if len(groups) >= 0:
metrics["epoch"] = epoch
metrics.update({"epoch": epoch, 'step': step})
for group in groups:
name = group[1:]
filename = filename.replace(group, name + "={" + name)
Expand All @@ -340,28 +341,28 @@ def _format_checkpoint_name(
return cls.CHECKPOINT_JOIN_CHAR.join([txt for txt in (prefix, filename) if txt])

def format_checkpoint_name(
self, epoch: int, metrics: Dict[str, Any], ver: Optional[int] = None
self, epoch: int, step: int, metrics: Dict[str, Any], ver: Optional[int] = None
) -> str:
"""Generate a filename according to the defined template.
Example::
>>> tmpdir = os.path.dirname(__file__)
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}'))
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
>>> os.path.basename(ckpt.format_checkpoint_name(0, 1, metrics={}))
'epoch=0.ckpt'
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch:03d}'))
>>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
>>> os.path.basename(ckpt.format_checkpoint_name(5, 2, metrics={}))
'epoch=005.ckpt'
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}-{val_loss:.2f}'))
>>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
>>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456)))
'epoch=2-val_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{missing:d}'))
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
>>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={}))
'missing=0.ckpt'
"""
filename = self._format_checkpoint_name(
self.filename, epoch, metrics, prefix=self.prefix
self.filename, epoch, step, metrics, prefix=self.prefix
)
if ver is not None:
filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
Expand Down Expand Up @@ -438,12 +439,12 @@ def _validate_monitor_key(self, trainer):
)
raise MisconfigurationException(m)

def _get_metric_interpolated_filepath_name(self, epoch, ckpt_name_metrics):
filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
def _get_metric_interpolated_filepath_name(self, epoch: int, step: int, ckpt_name_metrics: Dict[str, Any]):
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
version_cnt = 0
while self._fs.exists(filepath):
filepath = self.format_checkpoint_name(
epoch, ckpt_name_metrics, ver=version_cnt
epoch, step, ckpt_name_metrics, ver=version_cnt
)
# this epoch called before
version_cnt += 1
Expand All @@ -455,7 +456,7 @@ def _monitor_candidates(self, trainer):
ckpt_name_metrics.update(trainer.logger_connector.progress_bar_metrics)
return ckpt_name_metrics

def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath):
should_save_last = self.monitor is None or self.save_last
if not should_save_last:
return
Expand All @@ -465,7 +466,11 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, fi
# when user ALSO asked for the 'last.ckpt' change the name
if self.save_last:
last_filepath = self._format_checkpoint_name(
self.CHECKPOINT_NAME_LAST, epoch, ckpt_name_metrics, prefix=self.prefix
self.CHECKPOINT_NAME_LAST,
trainer.current_epoch,
trainer.global_step,
ckpt_name_metrics,
prefix=self.prefix
)
last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt")

Expand Down Expand Up @@ -514,7 +519,7 @@ def _update_best_and_save(
self.best_k_models.pop(self.kth_best_model_path)
del_list.append(delpath)

self.best_k_models[filepath] = current
self.best_k_models[filepath] = float(current)
if len(self.best_k_models) == k:
# monitor dict has reached k elements
_op = max if self.mode == "min" else min
Expand Down Expand Up @@ -543,7 +548,7 @@ def to_yaml(self, filepath: Optional[Union[str, Path]] = None):
Saves the `best_k_models` dict containing the checkpoint
paths with the corresponding scores to a YAML file.
"""
best_k = {k: v.item() for k, v in self.best_k_models.items()}
best_k = {k: v for k, v in self.best_k_models.items()}
if filepath is None:
filepath = os.path.join(self.dirpath, "best_k_models.yaml")
with self._fs.open(filepath, "w") as fp:
Expand Down
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,10 @@ def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
# depre warning
if eval_results is not None:
step = 'testing_epoch_end' if self.testing else 'validation_epoch_end'
m = f'The {step} should not return anything as of 9.1.' \
self.warning_cache.warn(
f'The {step} should not return anything as of 9.1.'
f'to log, use self.log(...) or self.write(...) directly in the LightningModule'
self.warning_cache.warn(m)
)

if using_eval_result and not user_reduced:
eval_results = self.__auto_reduce_result_objs(outputs)
Expand Down
70 changes: 40 additions & 30 deletions tests/callbacks/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_model_checkpoint_to_yaml(tmpdir, save_top_k):
path_yaml = os.path.join(tmpdir, 'best_k_models.yaml')
checkpoint.to_yaml(path_yaml)
d = yaml.full_load(open(path_yaml, 'r'))
best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
best_k = {k: v for k, v in checkpoint.best_k_models.items()}
assert d == best_k


Expand Down Expand Up @@ -124,7 +124,9 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir):
"""Test to ensure that the model callback saves the checkpoints only once in distributed mode."""
model = EvalModelTemplate()
num_epochs = 4
model_checkpoint = ModelCheckpointTestInvocations(monitor='early_stop_on', expected_count=num_epochs, save_top_k=-1)
model_checkpoint = ModelCheckpointTestInvocations(
filepath=tmpdir, monitor='early_stop_on', expected_count=num_epochs, save_top_k=-1
)
trainer = Trainer(
distributed_backend="ddp_cpu",
num_processes=2,
Expand All @@ -139,50 +141,51 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir):

def test_model_checkpoint_format_checkpoint_name(tmpdir):
# empty filename:
ckpt_name = ModelCheckpoint._format_checkpoint_name('', 3, {})
assert ckpt_name == 'epoch=3'
ckpt_name = ModelCheckpoint._format_checkpoint_name(None, 3, {}, prefix='test')
assert ckpt_name == 'test-epoch=3'
ckpt_name = ModelCheckpoint._format_checkpoint_name('', 3, 2, {})
assert ckpt_name == 'epoch=3-step=2'
ckpt_name = ModelCheckpoint._format_checkpoint_name(None, 3, 2, {}, prefix='test')
assert ckpt_name == 'test-epoch=3-step=2'
# no groups case:
ckpt_name = ModelCheckpoint._format_checkpoint_name('ckpt', 3, {}, prefix='test')
ckpt_name = ModelCheckpoint._format_checkpoint_name('ckpt', 3, 2, {}, prefix='test')
assert ckpt_name == 'test-ckpt'
# no prefix
ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch:03d}-{acc}', 3, {'acc': 0.03})
ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch:03d}-{acc}', 3, 2, {'acc': 0.03})
assert ckpt_name == 'epoch=003-acc=0.03'
# prefix
char_org = ModelCheckpoint.CHECKPOINT_JOIN_CHAR
ModelCheckpoint.CHECKPOINT_JOIN_CHAR = '@'
ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch},{acc:.5f}', 3, {'acc': 0.03}, prefix='test')
ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch},{acc:.5f}', 3, 2, {'acc': 0.03}, prefix='test')
assert ckpt_name == 'test@epoch=3,acc=0.03000'
ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org
# no filepath set
ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=None).format_checkpoint_name(3, {})
assert ckpt_name == 'epoch=3.ckpt'
ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='').format_checkpoint_name(5, {})
assert ckpt_name == 'epoch=5.ckpt'
ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=None).format_checkpoint_name(3, 4, {})
assert ckpt_name == 'epoch=3-step=4.ckpt'
ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='').format_checkpoint_name(5, 4, {})
assert ckpt_name == 'epoch=5-step=4.ckpt'
# CWD
ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='.').format_checkpoint_name(3, {})
assert Path(ckpt_name) == Path('.') / 'epoch=3.ckpt'
ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='.').format_checkpoint_name(3, 4, {})
assert Path(ckpt_name) == Path('.') / 'epoch=3-step=4.ckpt'
# dir does not exist so it is used as filename
filepath = tmpdir / 'dir'
ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, {})
ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, 4, {})
assert ckpt_name == tmpdir / 'test-dir.ckpt'
# now, dir exists
os.mkdir(filepath)
ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, {})
assert ckpt_name == filepath / 'test-epoch=3.ckpt'
ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, 4, {})
assert ckpt_name == filepath / 'test-epoch=3-step=4.ckpt'
# with ver
ckpt_name = ModelCheckpoint(monitor='early_stop_on',
filepath=tmpdir / 'name', prefix='test').format_checkpoint_name(3, {}, ver=3)
filepath=tmpdir / 'name', prefix='test').format_checkpoint_name(3, 4, {}, ver=3)
assert ckpt_name == tmpdir / 'test-name-v3.ckpt'


def test_model_checkpoint_save_last(tmpdir):
"""Tests that save_last produces only one last checkpoint."""
seed_everything()
model = EvalModelTemplate()
epochs = 3
ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
model_checkpoint = ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir, save_top_k=-1, save_last=True)
model_checkpoint = ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir / '{step}', save_top_k=-1, save_last=True)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=False,
Expand All @@ -191,10 +194,12 @@ def test_model_checkpoint_save_last(tmpdir):
logger=False,
)
trainer.fit(model)
last_filename = model_checkpoint._format_checkpoint_name(ModelCheckpoint.CHECKPOINT_NAME_LAST, epochs - 1, {})
last_filename = model_checkpoint._format_checkpoint_name(
ModelCheckpoint.CHECKPOINT_NAME_LAST, trainer.current_epoch, trainer.global_step, {}
)
last_filename = last_filename + '.ckpt'
assert str(tmpdir / last_filename) == model_checkpoint.last_model_path
assert set(os.listdir(tmpdir)) == set([f'epoch={i}.ckpt' for i in range(epochs)] + [last_filename])
assert set(os.listdir(tmpdir)) == set([f'step={i}.ckpt' for i in [19, 29, 30]] + [last_filename])
ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last'


Expand Down Expand Up @@ -229,12 +234,13 @@ def test_none_monitor_save_last(tmpdir):

def test_model_checkpoint_none_monitor(tmpdir):
""" Test that it is possible to save all checkpoints when monitor=None. """
seed_everything()
model = EvalModelTemplate()
model.validation_step = model.validation_step_no_monitor
model.validation_epoch_end = model.validation_epoch_end_no_monitor

epochs = 2
checkpoint_callback = ModelCheckpoint(monitor=None, filepath=tmpdir, save_top_k=-1)
checkpoint_callback = ModelCheckpoint(monitor=None, filepath=tmpdir / '{step}', save_top_k=-1)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=False,
Expand All @@ -246,28 +252,29 @@ def test_model_checkpoint_none_monitor(tmpdir):

# these should not be set if monitor is None
assert checkpoint_callback.monitor is None
assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'epoch=1.ckpt'
assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'step=20.ckpt'
assert checkpoint_callback.best_model_score == 0
assert checkpoint_callback.best_k_models == {}
assert checkpoint_callback.kth_best_model_path == ''

# check that the correct ckpts were created
expected = [f'epoch={e}.ckpt' for e in range(epochs)]
expected = [f'step={i}.ckpt' for i in [9, 19, 20]]
assert set(os.listdir(tmpdir)) == set(expected)


@pytest.mark.parametrize("period", list(range(4)))
def test_model_checkpoint_period(tmpdir, period):
model = EvalModelTemplate()
epochs = 5
checkpoint_callback = ModelCheckpoint(filepath=tmpdir, save_top_k=-1, period=period)
checkpoint_callback = ModelCheckpoint(filepath=tmpdir / '{epoch}', save_top_k=-1, period=period)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=False,
checkpoint_callback=checkpoint_callback,
max_epochs=epochs,
limit_train_batches=0.1,
limit_val_batches=0.1,
val_check_interval=1.0,
logger=False,
)
trainer.fit(model)
Expand Down Expand Up @@ -304,13 +311,14 @@ def test_model_checkpoint_topk_all(tmpdir):
seed_everything(1000)
epochs = 2
model = EvalModelTemplate()
checkpoint_callback = ModelCheckpoint(filepath=tmpdir, monitor="early_stop_on", save_top_k=-1)
checkpoint_callback = ModelCheckpoint(filepath=tmpdir / '{epoch}', monitor="early_stop_on", save_top_k=-1)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=False,
checkpoint_callback=checkpoint_callback,
max_epochs=epochs,
logger=False,
val_check_interval=1.0,
)
trainer.fit(model)
assert checkpoint_callback.best_model_path == tmpdir / "epoch=1.ckpt"
Expand Down Expand Up @@ -364,12 +372,12 @@ def test_default_checkpoint_behavior(tmpdir):

assert len(results) == 1
assert results[0]['test_acc'] >= 0.80
assert len(trainer.dev_debugger.checkpoint_callback_history) == 3
assert len(trainer.dev_debugger.checkpoint_callback_history) == 4

# make sure the checkpoint we saved has the metric in the name
ckpts = os.listdir(os.path.join(tmpdir, 'lightning_logs', 'version_0', 'checkpoints'))
assert len(ckpts) == 1
assert ckpts[0] == 'epoch=2.ckpt'
assert ckpts[0] == 'epoch=2-step=15.ckpt'


def test_ckpt_metric_names_results(tmpdir):
Expand Down Expand Up @@ -426,19 +434,21 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
model = EvalModelTemplate()
num_epochs = 3
model_checkpoint = ModelCheckpoint(
monitor='early_stop_on', filepath=tmpdir, save_top_k=num_epochs, save_last=True
monitor='early_stop_on', filepath=tmpdir / '{epoch}', save_top_k=num_epochs, save_last=True
)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=False,
checkpoint_callback=model_checkpoint,
max_epochs=num_epochs,
val_check_interval=1.0,
)
trainer.fit(model)

path_last_epoch = str(tmpdir / f"epoch={num_epochs - 1}.ckpt")
path_last = str(tmpdir / "last.ckpt")
assert path_last == model_checkpoint.last_model_path
assert os.path.isfile(path_last_epoch)

ckpt_last_epoch = torch.load(path_last_epoch)
ckpt_last = torch.load(path_last)
Expand Down
10 changes: 8 additions & 2 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,8 +401,14 @@ def mock_save_function(filepath, *args):
# simulated losses
losses = [10, 9, 2.8, 5, 2.5]

checkpoint_callback = ModelCheckpoint(tmpdir, monitor='checkpoint_on', save_top_k=save_top_k, save_last=save_last,
prefix=file_prefix, verbose=1)
checkpoint_callback = ModelCheckpoint(
tmpdir / '{epoch}',
monitor='checkpoint_on',
save_top_k=save_top_k,
save_last=save_last,
prefix=file_prefix,
verbose=1,
)
checkpoint_callback.save_function = mock_save_function
trainer = Trainer()

Expand Down

0 comments on commit 8a20989

Please sign in to comment.