Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added changeable extension variable for model checkpoints #4977

Merged
merged 20 commits into from
Dec 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
edcc225
Added changeable extension variable for model checkpoints
janhenriklambrechts Dec 4, 2020
4b6407e
Merge branch 'master' into master
janhenriklambrechts Dec 4, 2020
8779d49
Removed whitespace
janhenriklambrechts Dec 4, 2020
81471d0
Merge branch 'master' of https://github.com/janhenriklambrechts/pytor…
janhenriklambrechts Dec 4, 2020
30f6458
Removed the last bit of whitespace
janhenriklambrechts Dec 4, 2020
4c56ff7
Wrote tests for FILE_EXTENSION
janhenriklambrechts Dec 5, 2020
59910fd
Merge branch 'master' into master
janhenriklambrechts Dec 5, 2020
cb9e381
Fixed formatting issues
janhenriklambrechts Dec 5, 2020
315a5ca
Merge branch 'master' of https://github.com/janhenriklambrechts/pytor…
janhenriklambrechts Dec 5, 2020
8b1119e
More formatting issues
janhenriklambrechts Dec 5, 2020
fe3111d
Simplify test by just using defaults
janhenriklambrechts Dec 5, 2020
947a9d5
Formatting to PEP8
janhenriklambrechts Dec 5, 2020
435f9b2
Merge branch 'master' into master
janhenriklambrechts Dec 5, 2020
debc4c4
Merge branch 'master' into master
janhenriklambrechts Dec 5, 2020
410dd28
Added dummy class that inherits ModelCheckpoint; run only one batch i…
janhenriklambrechts Dec 6, 2020
85b4f4f
Merge branch 'master' of https://github.com/janhenriklambrechts/pytor…
janhenriklambrechts Dec 6, 2020
e395a37
Fixed too much whitespace formatting
janhenriklambrechts Dec 6, 2020
aa42f7c
Merge branch 'master' into master
janhenriklambrechts Dec 6, 2020
047424d
some changes
rohitgr7 Dec 6, 2020
1181a49
Merge branch 'master' into master
janhenriklambrechts Dec 6, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class ModelCheckpoint(Callback):
Example::
# custom path
# saves a file like: my/path/epoch=0.ckpt
# saves a file like: my/path/epoch=0-step=10.ckpt
>>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
By default, dirpath is ``None`` and will be set at runtime to the location
Expand Down Expand Up @@ -140,6 +140,7 @@ class ModelCheckpoint(Callback):

CHECKPOINT_JOIN_CHAR = "-"
CHECKPOINT_NAME_LAST = "last"
janhenriklambrechts marked this conversation as resolved.
Show resolved Hide resolved
FILE_EXTENSION = ".ckpt"

def __init__(
self,
Expand Down Expand Up @@ -442,7 +443,7 @@ def format_checkpoint_name(
)
if ver is not None:
filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
ckpt_name = f"{filename}.ckpt"
ckpt_name = f"{filename}{self.FILE_EXTENSION}"
return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name

def __resolve_ckpt_dir(self, trainer, pl_module):
Expand Down Expand Up @@ -545,7 +546,7 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath)
ckpt_name_metrics,
prefix=self.prefix
)
last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt")
last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")

self._save_model(last_filepath, trainer, pl_module)
if (
Expand Down
23 changes: 23 additions & 0 deletions tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,29 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
assert ckpt_name == filepath / 'test-epoch=3-step=2.ckpt'


class ModelCheckpointExtensionTest(ModelCheckpoint):
FILE_EXTENSION = '.tpkc'


def test_model_checkpoint_file_extension(tmpdir):
"""
Test ModelCheckpoint with different file extension.
"""

model = LogInTwoMethods()
model_checkpoint = ModelCheckpointExtensionTest(monitor='early_stop_on', dirpath=tmpdir, save_top_k=1, save_last=True)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[model_checkpoint],
max_steps=1,
logger=False,
)
trainer.fit(model)

expected = ['epoch=0-step=0.tpkc', 'last.tpkc']
assert set(expected) == set(os.listdir(tmpdir))


def test_model_checkpoint_save_last(tmpdir):
"""Tests that save_last produces only one last checkpoint."""
seed_everything()
Expand Down