Add checkpoint parameter to on_save_checkpoint #6072

Merged: 11 commits, Feb 25, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))


- Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))


### Changed


24 changes: 19 additions & 5 deletions pytorch_lightning/callbacks/base.py
@@ -17,7 +17,7 @@
"""

import abc
from typing import Any
from typing import Any, Dict

from pytorch_lightning.core.lightning import LightningModule

@@ -177,12 +177,26 @@ def on_keyboard_interrupt(self, trainer, pl_module: LightningModule) -> None:
"""Called when the training is interrupted by ``KeyboardInterrupt``."""
pass

def on_save_checkpoint(self, trainer, pl_module: LightningModule) -> None:
"""Called when saving a model checkpoint, use to persist state."""
def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]) -> dict:
"""
Called when saving a model checkpoint, use to persist state.

Args:
trainer: the current Trainer instance.
pl_module: the current LightningModule instance.
checkpoint: the checkpoint dictionary that will be saved.

Returns:
The callback state.
"""
pass

def on_load_checkpoint(self, checkpointed_state) -> None:
"""Called when loading a model checkpoint, use to reload state."""
def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None:
"""Called when loading a model checkpoint, use to reload state.

Args:
callback_state: the callback state returned by ``on_save_checkpoint``.
"""
pass

def on_after_backward(self, trainer, pl_module: LightningModule) -> None:
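To illustrate the new hook pair end to end, here is a minimal user callback written against the updated signatures (a sketch only; the class name, counter, and state keys are illustrative and not part of this PR):

```python
from typing import Any, Dict

from pytorch_lightning.callbacks import Callback


class CounterCallback(Callback):
    """Counts training batches and persists the count across checkpoints."""

    def __init__(self):
        self.batches_seen = 0

    def on_train_batch_end(self, trainer, pl_module, *args):
        self.batches_seen += 1

    def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
        # `checkpoint` is the full dictionary about to be written; the returned dict
        # is stored under checkpoint['callbacks'][CounterCallback] by the Trainer
        return {"batches_seen": self.batches_seen}

    def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None:
        self.batches_seen = callback_state["batches_seen"]
```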
13 changes: 7 additions & 6 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -18,6 +18,7 @@
Monitor a metric and stop training when it stops improving.

"""
from typing import Any, Dict

import numpy as np
import torch
@@ -140,19 +141,19 @@ def _validate_condition_metric(self, logs):
def monitor_op(self):
return self.mode_dict[self.mode]

def on_save_checkpoint(self, trainer, pl_module):
def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
return {
'wait_count': self.wait_count,
'stopped_epoch': self.stopped_epoch,
'best_score': self.best_score,
'patience': self.patience
}

def on_load_checkpoint(self, checkpointed_state):
self.wait_count = checkpointed_state['wait_count']
self.stopped_epoch = checkpointed_state['stopped_epoch']
self.best_score = checkpointed_state['best_score']
self.patience = checkpointed_state['patience']
def on_load_checkpoint(self, callback_state: Dict[str, Any]):
self.wait_count = callback_state['wait_count']
self.stopped_epoch = callback_state['stopped_epoch']
self.best_score = callback_state['best_score']
self.patience = callback_state['patience']

def on_validation_end(self, trainer, pl_module):
if trainer.running_sanity_check:
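For completeness, a sketch of the restore side that consumes the state returned above (the module, checkpoint path, and monitored metric are hypothetical; it assumes the v1.2-era `resume_from_checkpoint` Trainer argument):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

from my_project import LitModel  # hypothetical LightningModule that logs "val_loss"

early_stop = EarlyStopping(monitor="val_loss", patience=3)
trainer = Trainer(callbacks=[early_stop], resume_from_checkpoint="example.ckpt")
trainer.fit(LitModel())
# During restore the Trainer feeds the saved dict back through
# EarlyStopping.on_load_checkpoint, so wait_count, stopped_epoch,
# best_score and patience continue from the previous run.
```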
8 changes: 4 additions & 4 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -191,7 +191,7 @@ def on_validation_end(self, trainer, pl_module):
"""
self.save_checkpoint(trainer, pl_module)

def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
return {
"monitor": self.monitor,
"best_model_score": self.best_model_score,
@@ -200,9 +200,9 @@ def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
"dirpath": self.dirpath
}

def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
self.best_model_score = checkpointed_state["best_model_score"]
self.best_model_path = checkpointed_state["best_model_path"]
def on_load_checkpoint(self, callback_state: Dict[str, Any]):
self.best_model_score = callback_state["best_model_score"]
self.best_model_path = callback_state["best_model_path"]

def save_checkpoint(self, trainer, pl_module):
"""
27 changes: 22 additions & 5 deletions pytorch_lightning/trainer/callback_hook.py
@@ -14,10 +14,12 @@

from abc import ABC
from copy import deepcopy
from typing import List
from inspect import signature
from typing import List, Dict, Any, Type, Callable

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn


class TrainerCallbackHookMixin(ABC):
@@ -197,14 +199,29 @@ def on_keyboard_interrupt(self):
for callback in self.callbacks:
callback.on_keyboard_interrupt(self, self.lightning_module)

def on_save_checkpoint(self):
@staticmethod
def __is_old_signature(fn: Callable) -> bool:
parameters = list(signature(fn).parameters)
if len(parameters) == 2 and parameters[1] != "args":
return True
return False

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]:
"""Called when saving a model checkpoint."""
callback_states = {}
for callback in self.callbacks:
callback_class = type(callback)
state = callback.on_save_checkpoint(self, self.lightning_module)
if self.__is_old_signature(callback.on_save_checkpoint):
rank_zero_warn(
"`Callback.on_save_checkpoint` signature has changed in v1.3."
" A `checkpoint` parameter has been added."
" Support for the old signature will be removed in v1.5",
DeprecationWarning
)
state = callback.on_save_checkpoint(self, self.lightning_module) # noqa: parameter-unfilled
else:
state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
if state:
callback_states[callback_class] = state
callback_states[type(callback)] = state
return callback_states

def on_load_checkpoint(self, checkpoint):
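To make the backward-compatibility check concrete, the following standalone sketch mirrors `__is_old_signature` and shows which user signatures still trigger the deprecation path (class names are illustrative):

```python
from inspect import signature


def is_old_signature(fn) -> bool:
    # mirrors TrainerCallbackHookMixin.__is_old_signature above
    parameters = list(signature(fn).parameters)
    return len(parameters) == 2 and parameters[1] != "args"


class Old:
    def on_save_checkpoint(self, trainer, pl_module): ...


class New:
    def on_save_checkpoint(self, trainer, pl_module, checkpoint): ...


class Variadic:
    def on_save_checkpoint(self, trainer, *args): ...


print(is_old_signature(Old().on_save_checkpoint))       # True  -> warns, called without `checkpoint`
print(is_old_signature(New().on_save_checkpoint))       # False -> called with `checkpoint`
print(is_old_signature(Variadic().on_save_checkpoint))  # False -> *args absorbs the new parameter
```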
14 changes: 5 additions & 9 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -270,17 +270,18 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
if not has_reached_max_steps:
current_epoch += 1

model = self.trainer.lightning_module

checkpoint = {
'epoch': current_epoch,
'global_step': global_step,
'pytorch-lightning_version': pytorch_lightning.__version__,
'state_dict': model.state_dict(),
}

if not weights_only:

# dump callbacks
callback_states = self.trainer.on_save_checkpoint()
checkpoint['callbacks'] = callback_states
checkpoint['callbacks'] = self.trainer.on_save_checkpoint(checkpoint)

optimizer_states = []
for i, optimizer in enumerate(self.trainer.optimizers):
@@ -305,12 +306,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
elif self.trainer.amp_backend == AMPType.APEX:
checkpoint['amp_scaling_state'] = amp.state_dict()

# add the hyper_parameters and state_dict from the model
model = self.trainer.lightning_module

# dump the module_arguments and state_dict from the model
checkpoint['state_dict'] = model.state_dict()

# dump hyper-parameters
if model.hparams:
if hasattr(model, '_hparams_name'):
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
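Note that `state_dict` is now written into the dict before the callback hook runs, so callbacks receive a populated checkpoint. A quick sketch of inspecting the result (the path is illustrative; assumes a `ModelCheckpoint` callback was active):

```python
import torch

from pytorch_lightning.callbacks import ModelCheckpoint

ckpt = torch.load("example.ckpt", map_location="cpu")
print(ckpt["epoch"], ckpt["global_step"])   # training progress
print(list(ckpt["state_dict"]))             # model weights, present before the callback hook ran
print(ckpt["callbacks"][ModelCheckpoint])   # e.g. {'monitor': ..., 'best_model_score': ..., 'dirpath': ...}
```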
2 changes: 1 addition & 1 deletion tests/callbacks/test_callbacks.py
@@ -98,7 +98,7 @@ def test_trainer_callback_system(torch_save, tmpdir):
call.on_validation_epoch_end(trainer, model),
call.on_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.on_save_checkpoint(trainer, model),
call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC
call.on_train_end(trainer, model),
call.on_fit_end(trainer, model),
call.teardown(trainer, model, 'fit'),
4 changes: 2 additions & 2 deletions tests/callbacks/test_early_stopping.py
@@ -39,11 +39,11 @@ def __init__(self, expected_state, *args, **kwargs):

def on_train_start(self, trainer, pl_module):
if self.expected_state:
assert self.on_save_checkpoint(trainer, pl_module) == self.expected_state
assert self.on_save_checkpoint(trainer, pl_module, {}) == self.expected_state

def on_validation_end(self, trainer, pl_module):
super().on_validation_end(trainer, pl_module)
self.saved_states.append(self.on_save_checkpoint(trainer, pl_module).copy())
self.saved_states.append(self.on_save_checkpoint(trainer, pl_module, {}).copy())


def test_resume_early_stopping_from_checkpoint(tmpdir):
4 changes: 2 additions & 2 deletions tests/checkpointing/test_model_checkpoint.py
@@ -346,9 +346,9 @@ def __init__(self, expected_count, *args, **kwargs):
def on_train_start(self, trainer, pl_module):
torch.save = Mock(wraps=torch.save)

def on_save_checkpoint(self, trainer, pl_module):
def on_save_checkpoint(self, trainer, pl_module, checkpoint):
# expect all ranks to run but only rank 0 will actually write the checkpoint file
super().on_save_checkpoint(trainer, pl_module)
super().on_save_checkpoint(trainer, pl_module, checkpoint)
self.on_save_checkpoint_count += 1

def on_train_end(self, trainer, pl_module):
4 changes: 2 additions & 2 deletions tests/deprecated_api/test_remove_1-4.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test deprecated functionality which will be removed in vX.Y.Z"""
"""Test deprecated functionality which will be removed in v1.4.0"""
import sys

import pytest
@@ -243,5 +243,5 @@ def training_step(self, batch, batch_idx):

trainer = Trainer(default_root_dir=tmpdir, checkpoint_callback=True, max_epochs=1)

with pytest.warns(DeprecationWarning, match=r"Relying on.*is deprecated in v1.2 and will be removed in v1.4"):
with pytest.deprecated_call(match=r"Relying on.*is deprecated in v1.2 and will be removed in v1.4"):
trainer.fit(TestModel())
56 changes: 56 additions & 0 deletions tests/deprecated_api/test_remove_1-5.py
@@ -0,0 +1,56 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test deprecated functionality which will be removed in v1.5.0"""

import pytest

from pytorch_lightning import Trainer, Callback
from tests.helpers import BoringModel
from tests.helpers.utils import no_warning_call


def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir):
class OldSignature(Callback):
def on_save_checkpoint(self, trainer, pl_module): # noqa
...

model = BoringModel()
trainer_kwargs = {
"default_root_dir": tmpdir,
"checkpoint_callback": False,
"max_epochs": 1,
}
filepath = tmpdir / "test.ckpt"

trainer = Trainer(**trainer_kwargs, callbacks=[OldSignature()])
trainer.fit(model)

with pytest.deprecated_call(match="old signature will be removed in v1.5"):
trainer.save_checkpoint(filepath)

class NewSignature(Callback):
def on_save_checkpoint(self, trainer, pl_module, checkpoint):
...

class ValidSignature1(Callback):
def on_save_checkpoint(self, trainer, *args):
...

class ValidSignature2(Callback):
def on_save_checkpoint(self, *args):
...

trainer.callbacks = [NewSignature(), ValidSignature1(), ValidSignature2()]
with no_warning_call(DeprecationWarning):
trainer.save_checkpoint(filepath)
19 changes: 19 additions & 0 deletions tests/helpers/utils.py
@@ -14,6 +14,10 @@
import functools
import os
import traceback
from contextlib import contextmanager
from typing import Optional

import pytest

from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
@@ -111,3 +115,18 @@ def inner_f(queue, **kwargs):
assert result == 1, 'expected 1, but returned %s' % result

return wrapper


@contextmanager
def no_warning_call(warning_type, match: Optional[str] = None):
with pytest.warns(None) as record:
yield

try:
w = record.pop(warning_type)
if match is not None and match not in str(w.message):
return
except AssertionError:
# no warning raised
return
raise AssertionError(f"`{warning_type}` was raised: {w}")
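A brief usage sketch of the helper, mirroring the v1.5 deprecation test above (`trainer` and `filepath` are assumed to exist):

```python
# passes only if saving the checkpoint emits no matching DeprecationWarning
with no_warning_call(DeprecationWarning, match="old signature will be removed in v1.5"):
    trainer.save_checkpoint(filepath)
```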
4 changes: 2 additions & 2 deletions tests/trainer/connectors/test_callback_connector.py
@@ -42,13 +42,13 @@ def test_checkpoint_callbacks_are_last(tmpdir):

class StatefulCallback0(Callback):

def on_save_checkpoint(self, trainer, pl_module):
def on_save_checkpoint(self, *args):
return {"content0": 0}


class StatefulCallback1(Callback):

def on_save_checkpoint(self, trainer, pl_module):
def on_save_checkpoint(self, *args):
return {"content1": 1}

