enable None model checkpoint default #3669

Merged
merged 22 commits on Sep 27, 2020
Changes from 5 commits
Commits (22)
34ec8ce
enable None model checkpoint default
williamFalcon Sep 26, 2020
430c656
enable None model checkpoint default
williamFalcon Sep 26, 2020
919e506
enable None model checkpoint default
williamFalcon Sep 26, 2020
a0a354e
enable None model checkpoint default
williamFalcon Sep 26, 2020
fd9513f
enable None model checkpoint default
williamFalcon Sep 26, 2020
6db6315
enable None model checkpoint default
williamFalcon Sep 26, 2020
219fdc4
enable None model checkpoint default
williamFalcon Sep 26, 2020
aec27c0
enable None model checkpoint default
williamFalcon Sep 26, 2020
237846f
enable None model checkpoint default
williamFalcon Sep 26, 2020
ccf6d70
enable None model checkpoint default
williamFalcon Sep 26, 2020
94eb9bb
enable None model checkpoint default
williamFalcon Sep 27, 2020
cc88049
enable None model checkpoint default
williamFalcon Sep 27, 2020
eeb367d
enable None model checkpoint default
williamFalcon Sep 27, 2020
4391f53
enable None model checkpoint default
williamFalcon Sep 27, 2020
96fca9b
enable None model checkpoint default
williamFalcon Sep 27, 2020
397cbe9
enable None model checkpoint default
williamFalcon Sep 27, 2020
4142ab8
enable None model checkpoint default
williamFalcon Sep 27, 2020
a2e7722
enable None model checkpoint default
williamFalcon Sep 27, 2020
038471e
enable None model checkpoint default
williamFalcon Sep 27, 2020
6f47bbd
enable None model checkpoint default
williamFalcon Sep 27, 2020
dcf53d2
enable None model checkpoint default
williamFalcon Sep 27, 2020
38fa422
enable None model checkpoint default
williamFalcon Sep 27, 2020
53 changes: 35 additions & 18 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -120,7 +120,7 @@ class ModelCheckpoint(Callback):
def __init__(
self,
filepath: Optional[str] = None,
monitor: Optional[str] = "checkpoint_on",
monitor: Optional[str] = None,
verbose: bool = False,
save_last: bool = False,
save_top_k: int = 1,
@@ -169,6 +169,10 @@ def __init__(
self.save_function = None
self.warned_result_obj = False

if self.save_top_k != 1 and self.monitor is None:
raise MisconfigurationException('To save checkpoints for a top_k metric, '
'ModelCheckpoint(monitor) cannot be None')

torch_inf = torch.tensor(np.Inf)
mode_dict = {
"min": (torch_inf, "min"),
@@ -336,6 +340,13 @@ def on_validation_end(self, trainer, pl_module):
metrics = trainer.logger_connector.callback_metrics
epoch = trainer.current_epoch

# backward compatibility... need to deprecate
Contributor: these need to check if self.monitor is None before resetting it to the older defaults right?

Contributor Author: yup! sorry, i think i removed the WIP too soon haha. but good catch

Contributor Author: ended up just doing a much needed clean up to maintain sanity haha

if 'val_loss' in metrics:
self.monitor = 'val_loss'

if 'checkpoint_on' in metrics:
self.monitor = 'checkpoint_on'
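
The thread above asks for this fallback to apply only when no monitor was given. A minimal sketch of that suggestion (illustrative only, not the code at this commit; it keeps the precedence shown above, where 'checkpoint_on' wins):

if self.monitor is None:
    # fall back to the old default keys only when the user did not choose a monitor
    if 'checkpoint_on' in metrics:
        self.monitor = 'checkpoint_on'
    elif 'val_loss' in metrics:
        self.monitor = 'val_loss'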

# validate metric
if not (self.monitor is None or self._is_valid_monitor_key(metrics)):
m = (
@@ -367,15 +378,31 @@ def on_validation_end(self, trainer, pl_module):
# this epoch called before
version_cnt += 1

if self.save_top_k != -1:
current = metrics.get(self.monitor)
save_all_models = self.save_top_k == -1
should_save_last = self.monitor is None or self.save_last

if not isinstance(current, torch.Tensor):
rank_zero_warn(
f"The metric you returned {self.monitor}={current} must be a `torch.Tensor` "
f"instance, checkpoint not saved HINT: what is the value of {self.monitor}?",
RuntimeWarning,
if should_save_last:
last_filepath = filepath

# when user ALSO asked for the 'last.ckpt' change the name
if self.save_last:
filename = self._format_checkpoint_name(
self.CHECKPOINT_NAME_LAST, epoch, ckpt_name_metrics, prefix=self.prefix
)
last_filepath = os.path.join(self.dirpath, f"{filename}.ckpt")

self._save_model(last_filepath, trainer, pl_module)
if self.last_model_path and self.last_model_path != last_filepath:
self._del_model(self.last_model_path)
self.last_model_path = last_filepath

if self.monitor is None:
self.best_model_path = self.last_model_path

if not save_all_models:
current = metrics.get(self.monitor)

if not isinstance(current, torch.Tensor) and current is not None:
if current is not None:
current = torch.tensor(current).to(pl_module.device)

@@ -401,16 +428,6 @@ def on_validation_end(self, trainer, pl_module):
), "tried to make a checkpoint from non global_rank=0"
self._save_model(filepath, trainer, pl_module)

if self.save_last:
filename = self._format_checkpoint_name(
self.CHECKPOINT_NAME_LAST, epoch, ckpt_name_metrics, prefix=self.prefix
)
filepath = os.path.join(self.dirpath, f"{filename}.ckpt")
self._save_model(filepath, trainer, pl_module)
if self.last_model_path and self.last_model_path != filepath:
self._del_model(self.last_model_path)
self.last_model_path = filepath

def _is_valid_monitor_key(self, metrics):
return self.monitor in metrics or len(metrics) == 0

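For context, a minimal sketch of how the updated callback is meant to behave after this change (hypothetical usage; the directory path is made up):

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# monitor now defaults to None: no metric is tracked and the callback simply keeps
# a checkpoint of the latest epoch (best_model_path then points at that last checkpoint)
ckpt = ModelCheckpoint(filepath='my_checkpoints/')

# asking for a top-k other than 1 without a monitored quantity is rejected at construction
try:
    ModelCheckpoint(filepath='my_checkpoints/', save_top_k=3)
except MisconfigurationException as err:
    print(err)  # To save checkpoints for a top_k metric, ModelCheckpoint(monitor) cannot be None
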
1 change: 0 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -513,7 +513,6 @@ def run_sanity_check(self, ref_model):
self.on_sanity_check_end()
self.running_sanity_check = False

@trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
Contributor Author: @Borda let's please get rid of these... makes readability impossible (and IDE inspection)

Member: ok, we can remove wrappers, just keep the state assignment..

def test(
self,
model: Optional[LightningModule] = None,
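A rough sketch of the direction agreed above for the removed @trainer_state decorator, i.e. keep the state assignment but do it inline (attribute and helper names here are assumptions for illustration, not the merged code):

from pytorch_lightning.trainer.states import TrainerState

class Trainer:
    # ...
    def test(self, model=None, test_dataloaders=None):
        self._state = TrainerState.RUNNING      # assumed attribute previously set by the decorator
        try:
            results = self._run_test_loop(model, test_dataloaders)  # hypothetical stand-in for the real body
        finally:
            self._state = TrainerState.FINISHED
        return results
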
23 changes: 23 additions & 0 deletions tests/base/model_valid_epoch_ends.py
@@ -7,6 +7,29 @@ class ValidationEpochEndVariations(ABC):
"""
Houses all variations of validation_epoch_end steps
"""
def validation_epoch_end_no_monitor(self, outputs):
"""
Called at the end of validation to aggregate outputs

Args:
outputs: list of individual outputs of each validation step
"""
# if returned a scalar from validation_step, outputs is a list of tensor scalars
# we return just the average in this case (if we want)
def _mean(res, key):
# recursive mean for multilevel dicts
return torch.stack([x[key] if isinstance(x, dict) else _mean(x, key) for x in res]).mean()

val_acc_mean = _mean(outputs, 'val_acc')

# alternate between tensor and scalar
if self.current_epoch % 2 == 0:
val_acc_mean = val_acc_mean.item()

metrics_dict = {'val_acc': val_acc_mean}
results = {'progress_bar': metrics_dict, 'log': metrics_dict}
return results


def validation_epoch_end(self, outputs):
"""
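The recursive _mean helper above averages a monitored key across possibly nested validation outputs; a small self-contained illustration (values are made up):

import torch

def _mean(res, key):
    # recursive mean over dicts and nested lists of dicts, same idea as the helper above
    return torch.stack([x[key] if isinstance(x, dict) else _mean(x, key) for x in res]).mean()

outputs = [
    {'val_acc': torch.tensor(0.50)},
    [{'val_acc': torch.tensor(0.70)}, {'val_acc': torch.tensor(0.90)}],  # one nested level
]
print(_mean(outputs, 'val_acc'))  # tensor(0.6500): mean of 0.50 and mean(0.70, 0.90)
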
24 changes: 24 additions & 0 deletions tests/base/model_valid_steps.py
@@ -34,6 +34,30 @@ def validation_step(self, batch, batch_idx, *args, **kwargs):
})
return output

def validation_step_no_monitor(self, batch, batch_idx, *args, **kwargs):
"""
Lightning calls this inside the validation loop
:param batch:
:return:
"""
self.validation_step_called = True
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self(x)

loss_val = self.loss(y, y_hat)

# acc
labels_hat = torch.argmax(y_hat, dim=1)
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
val_acc = torch.tensor(val_acc).type_as(x)

output = OrderedDict({
'val_acc': val_acc,
'test_dic': {'val_loss_a': loss_val}
})
return output

def validation_step_result_obj(self, batch, batch_idx, *args, **kwargs):
x, y = batch
x = x.view(x.size(0), -1)
72 changes: 55 additions & 17 deletions tests/callbacks/test_model_checkpoint.py
@@ -13,6 +13,7 @@
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from tests.base import EvalModelTemplate
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
Expand All @@ -21,7 +22,7 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
tutils.reset_seed()
model = EvalModelTemplate()

checkpoint = ModelCheckpoint(filepath=None, save_top_k=save_top_k)
checkpoint = ModelCheckpoint(monitor='val_loss', filepath=None, save_top_k=save_top_k)

trainer = Trainer(
default_root_dir=tmpdir,
@@ -97,7 +98,7 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir):
"""Test to ensure that the model callback saves the checkpoints only once in distributed mode."""
model = EvalModelTemplate()
num_epochs = 4
model_checkpoint = ModelCheckpointTestInvocations(
model_checkpoint = ModelCheckpointTestInvocations(monitor='val_loss',
expected_count=num_epochs, save_top_k=-1
)
trainer = Trainer(
@@ -131,23 +132,24 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
assert ckpt_name == 'test@epoch=3,acc=0.03000'
ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org
# no filepath set
ckpt_name = ModelCheckpoint(filepath=None).format_checkpoint_name(3, {})
ckpt_name = ModelCheckpoint(monitor='val_loss', filepath=None).format_checkpoint_name(3, {})
assert ckpt_name == 'epoch=3.ckpt'
ckpt_name = ModelCheckpoint(filepath='').format_checkpoint_name(5, {})
ckpt_name = ModelCheckpoint(monitor='val_loss', filepath='').format_checkpoint_name(5, {})
assert ckpt_name == 'epoch=5.ckpt'
# CWD
ckpt_name = ModelCheckpoint(filepath='.').format_checkpoint_name(3, {})
ckpt_name = ModelCheckpoint(monitor='val_loss', filepath='.').format_checkpoint_name(3, {})
assert Path(ckpt_name) == Path('.') / 'epoch=3.ckpt'
# dir does not exist so it is used as filename
filepath = tmpdir / 'dir'
ckpt_name = ModelCheckpoint(filepath=filepath, prefix='test').format_checkpoint_name(3, {})
ckpt_name = ModelCheckpoint(monitor='val_loss', filepath=filepath, prefix='test').format_checkpoint_name(3, {})
assert ckpt_name == tmpdir / 'test-dir.ckpt'
# now, dir exists
os.mkdir(filepath)
ckpt_name = ModelCheckpoint(filepath=filepath, prefix='test').format_checkpoint_name(3, {})
ckpt_name = ModelCheckpoint(monitor='val_loss', filepath=filepath, prefix='test').format_checkpoint_name(3, {})
assert ckpt_name == filepath / 'test-epoch=3.ckpt'
# with ver
ckpt_name = ModelCheckpoint(filepath=tmpdir / 'name', prefix='test').format_checkpoint_name(3, {}, ver=3)
ckpt_name = ModelCheckpoint(monitor='val_loss',
filepath=tmpdir / 'name', prefix='test').format_checkpoint_name(3, {}, ver=3)
assert ckpt_name == tmpdir / 'test-name-v3.ckpt'


@@ -156,7 +158,7 @@ def test_model_checkpoint_save_last(tmpdir):
model = EvalModelTemplate()
epochs = 3
ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=-1, save_last=True)
model_checkpoint = ModelCheckpoint(monitor='val_loss', filepath=tmpdir, save_top_k=-1, save_last=True)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=False,
@@ -167,18 +169,27 @@ def test_model_checkpoint_save_last(tmpdir):
last_filename = model_checkpoint._format_checkpoint_name(ModelCheckpoint.CHECKPOINT_NAME_LAST, epochs - 1, {})
last_filename = last_filename + '.ckpt'
assert str(tmpdir / last_filename) == model_checkpoint.last_model_path
assert set(os.listdir(tmpdir)) == set(
[f'epoch={i}.ckpt' for i in range(epochs)] + [last_filename, 'lightning_logs']
)
assert set(os.listdir(tmpdir)) == \
set([f'epoch={i}.ckpt' for i in range(epochs)] + [last_filename, 'lightning_logs'])
ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last'


def test_none_monitor_top_k(tmpdir):
"""
Make sure that when saving top k of anything (if it's not 1), then monitor cannot be none
"""
seed_everything(100)
num_epochs = 3
with pytest.raises(MisconfigurationException, match=r'To save checkpoints for a top_k metric.*'):
ModelCheckpoint(filepath=tmpdir, save_top_k=num_epochs, save_last=True)


def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
"""Tests that the save_last checkpoint contains the latest information."""
seed_everything(100)
model = EvalModelTemplate()
num_epochs = 3
model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=num_epochs, save_last=True)
model_checkpoint = ModelCheckpoint(monitor='val_loss', filepath=tmpdir, save_top_k=num_epochs, save_last=True)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=False,
Expand Down Expand Up @@ -208,7 +219,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
def test_model_checkpoint_none_monitor(tmpdir):
model = EvalModelTemplate()
epochs = 2
checkpoint_callback = ModelCheckpoint(filepath=tmpdir, monitor=None, save_top_k=-1)
checkpoint_callback = ModelCheckpoint(monitor='val_loss', filepath=tmpdir, save_top_k=-1)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=False,
@@ -240,7 +251,7 @@ def test_ckpt_metric_names(tmpdir):
progress_bar_refresh_rate=0,
limit_train_batches=0.01,
limit_val_batches=0.01,
checkpoint_callback=ModelCheckpoint(filepath=tmpdir + "/{val_loss:.2f}"),
checkpoint_callback=ModelCheckpoint(monitor='val_loss', filepath=tmpdir + "/{val_loss:.2f}"),
)

trainer.fit(model)
@@ -253,6 +264,33 @@
assert len(val) > 3


def test_default_checkpoint_behavior(tmpdir):
os.environ['PL_DEV_DEBUG'] = '1'
model = EvalModelTemplate()
model.validation_step = model.validation_step_no_monitor
model.validation_epoch_end = model.validation_epoch_end_no_monitor

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=3,
progress_bar_refresh_rate=0,
limit_train_batches=5,
limit_val_batches=5,
)

trainer.fit(model)
results = trainer.test()

assert len(results) == 1
assert results[0]['test_acc'] >= 0.90
assert len(trainer.dev_debugger.checkpoint_callback_history) == 3

# make sure the checkpoint we saved has the metric in the name
ckpts = os.listdir(os.path.join(tmpdir, 'lightning_logs', 'version_0', 'checkpoints'))
assert len(ckpts) == 1
assert ckpts[0] == 'epoch=2.ckpt'


def test_ckpt_metric_names_results(tmpdir):
model = EvalModelTemplate()
model.training_step = model.training_step_result_obj
@@ -271,7 +309,7 @@ def test_ckpt_metric_names_results(tmpdir):
progress_bar_refresh_rate=0,
limit_train_batches=0.01,
limit_val_batches=0.01,
checkpoint_callback=ModelCheckpoint(filepath=tmpdir + "/{val_loss:.2f}"),
checkpoint_callback=ModelCheckpoint(monitor='val_loss', filepath=tmpdir + "/{val_loss:.2f}"),
)

trainer.fit(model)
@@ -294,7 +332,7 @@ def test_model_checkpoint_save_last_warning(tmpdir, caplog, max_epochs, should_v
model.validation_step = None
trainer = Trainer(
default_root_dir=tmpdir,
checkpoint_callback=ModelCheckpoint(filepath=tmpdir, save_top_k=0, save_last=save_last),
checkpoint_callback=ModelCheckpoint(monitor='val_loss', filepath=tmpdir, save_top_k=0, save_last=save_last),
max_epochs=max_epochs,
)
trainer.fit(model)
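To exercise the behavior added here, the new tests can be run on their own, e.g. (assuming a checkout of this branch with pytest installed):

python -m pytest tests/callbacks/test_model_checkpoint.py -k "none_monitor or default_checkpoint_behavior" -v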