diff --git a/CHANGELOG.md b/CHANGELOG.md
index 925106c035cf7..f3884c8dc2121 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))

+- Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))
+
+
 ### Changed
diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index d9ffff1bd47e6..d53acf0f7030d 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -17,7 +17,7 @@
 """

 import abc
-from typing import Any
+from typing import Any, Dict

 from pytorch_lightning.core.lightning import LightningModule

@@ -177,12 +177,26 @@ def on_keyboard_interrupt(self, trainer, pl_module: LightningModule) -> None:
         """Called when the training is interrupted by ``KeyboardInterrupt``."""
         pass

-    def on_save_checkpoint(self, trainer, pl_module: LightningModule) -> None:
-        """Called when saving a model checkpoint, use to persist state."""
+    def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]) -> dict:
+        """
+        Called when saving a model checkpoint, use to persist state.
+
+        Args:
+            trainer: the current Trainer instance.
+            pl_module: the current LightningModule instance.
+            checkpoint: the checkpoint dictionary that will be saved.
+
+        Returns:
+            The callback state.
+        """
         pass

-    def on_load_checkpoint(self, checkpointed_state) -> None:
-        """Called when loading a model checkpoint, use to reload state."""
+    def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None:
+        """Called when loading a model checkpoint, use to reload state.
+
+        Args:
+            callback_state: the callback state returned by ``on_save_checkpoint``.
+        """
         pass

     def on_after_backward(self, trainer, pl_module: LightningModule) -> None:
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 6a2b75c7de71d..9bcd028fa44eb 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -18,6 +18,7 @@ Monitor a metric and stop training when it stops improving.

 """
+from typing import Any, Dict

 import numpy as np
 import torch

@@ -117,7 +118,7 @@ def _validate_condition_metric(self, logs):
     def monitor_op(self):
         return self.mode_dict[self.mode]

-    def on_save_checkpoint(self, trainer, pl_module):
+    def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
         return {
             'wait_count': self.wait_count,
             'stopped_epoch': self.stopped_epoch,
@@ -125,11 +126,11 @@ def on_save_checkpoint(self, trainer, pl_module):
             'patience': self.patience
         }

-    def on_load_checkpoint(self, checkpointed_state):
-        self.wait_count = checkpointed_state['wait_count']
-        self.stopped_epoch = checkpointed_state['stopped_epoch']
-        self.best_score = checkpointed_state['best_score']
-        self.patience = checkpointed_state['patience']
+    def on_load_checkpoint(self, callback_state: Dict[str, Any]):
+        self.wait_count = callback_state['wait_count']
+        self.stopped_epoch = callback_state['stopped_epoch']
+        self.best_score = callback_state['best_score']
+        self.patience = callback_state['patience']

     def on_validation_end(self, trainer, pl_module):
         if trainer.running_sanity_check:
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 8f2ad2a45a3a2..54ad16f7b686f 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -191,7 +191,7 @@ def on_validation_end(self, trainer, pl_module):
         """
         self.save_checkpoint(trainer, pl_module)

-    def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
+    def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
         return {
             "monitor": self.monitor,
             "best_model_score": self.best_model_score,
@@ -200,9 +200,9 @@ def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
             "dirpath": self.dirpath
         }

-    def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
-        self.best_model_score = checkpointed_state["best_model_score"]
-        self.best_model_path = checkpointed_state["best_model_path"]
+    def on_load_checkpoint(self, callback_state: Dict[str, Any]):
+        self.best_model_score = callback_state["best_model_score"]
+        self.best_model_path = callback_state["best_model_path"]

     def save_checkpoint(self, trainer, pl_module):
         """
diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index f292f5a78bc65..60e9183ac42f7 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -14,10 +14,12 @@
 from abc import ABC
 from copy import deepcopy
-from typing import List
+from inspect import signature
+from typing import List, Dict, Any, Type, Callable

 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.utilities import rank_zero_warn


 class TrainerCallbackHookMixin(ABC):
@@ -197,14 +199,29 @@ def on_keyboard_interrupt(self):
         for callback in self.callbacks:
             callback.on_keyboard_interrupt(self, self.lightning_module)

-    def on_save_checkpoint(self):
+    @staticmethod
+    def __is_old_signature(fn: Callable) -> bool:
+        parameters = list(signature(fn).parameters)
+        if len(parameters) == 2 and parameters[1] != "args":
+            return True
+        return False
+
+    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]:
         """Called when saving a model checkpoint."""
         callback_states = {}
         for callback in self.callbacks:
-            callback_class = type(callback)
-            state = callback.on_save_checkpoint(self, self.lightning_module)
+            if self.__is_old_signature(callback.on_save_checkpoint):
+                rank_zero_warn(
+                    "`Callback.on_save_checkpoint` signature has changed in v1.3."
+                    " A `checkpoint` parameter has been added."
+                    " Support for the old signature will be removed in v1.5",
+                    DeprecationWarning
+                )
+                state = callback.on_save_checkpoint(self, self.lightning_module)  # noqa: parameter-unfilled
+            else:
+                state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
             if state:
-                callback_states[callback_class] = state
+                callback_states[type(callback)] = state
         return callback_states

     def on_load_checkpoint(self, checkpoint):
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 3b75f406b1917..60c76b70bba50 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -270,17 +270,18 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         if not has_reached_max_steps:
             current_epoch += 1

+        model = self.trainer.lightning_module
+
         checkpoint = {
             'epoch': current_epoch,
             'global_step': global_step,
             'pytorch-lightning_version': pytorch_lightning.__version__,
+            'state_dict': model.state_dict(),
         }

         if not weights_only:
-            # dump callbacks
-            callback_states = self.trainer.on_save_checkpoint()
-            checkpoint['callbacks'] = callback_states
+            checkpoint['callbacks'] = self.trainer.on_save_checkpoint(checkpoint)

             optimizer_states = []
             for i, optimizer in enumerate(self.trainer.optimizers):
@@ -305,12 +306,7 @@
             elif self.trainer.amp_backend == AMPType.APEX:
                 checkpoint['amp_scaling_state'] = amp.state_dict()

-        # add the hyper_parameters and state_dict from the model
-        model = self.trainer.lightning_module
-
-        # dump the module_arguments and state_dict from the model
-        checkpoint['state_dict'] = model.state_dict()
-
+        # dump hyper-parameters
         if model.hparams:
             if hasattr(model, '_hparams_name'):
                 checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index 379bc79263a6e..3eaaf81ca0a1e 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -98,7 +98,7 @@ def test_trainer_callback_system(torch_save, tmpdir):
         call.on_validation_epoch_end(trainer, model),
         call.on_epoch_end(trainer, model),
         call.on_validation_end(trainer, model),
-        call.on_save_checkpoint(trainer, model),
+        call.on_save_checkpoint(trainer, model),  # should take ANY but we are inspecting signature for BC
         call.on_train_end(trainer, model),
         call.on_fit_end(trainer, model),
         call.teardown(trainer, model, 'fit'),
diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index 6470e1837d87c..9954560beed15 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -40,11 +40,11 @@ def __init__(self, expected_state, *args, **kwargs):

     def on_train_start(self, trainer, pl_module):
         if self.expected_state:
-            assert self.on_save_checkpoint(trainer, pl_module) == self.expected_state
+            assert self.on_save_checkpoint(trainer, pl_module, {}) == self.expected_state

     def on_validation_end(self, trainer, pl_module):
         super().on_validation_end(trainer, pl_module)
-        self.saved_states.append(self.on_save_checkpoint(trainer, pl_module).copy())
+        self.saved_states.append(self.on_save_checkpoint(trainer, pl_module, {}).copy())


 def test_resume_early_stopping_from_checkpoint(tmpdir):
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 9f7a41946f586..8ea6f8a600a7d 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -346,9 +346,9 @@ def __init__(self, expected_count, *args, **kwargs):
         def on_train_start(self, trainer, pl_module):
             torch.save = Mock(wraps=torch.save)

-        def on_save_checkpoint(self, trainer, pl_module):
+        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
             # expect all ranks to run but only rank 0 will actually write the checkpoint file
-            super().on_save_checkpoint(trainer, pl_module)
+            super().on_save_checkpoint(trainer, pl_module, checkpoint)
             self.on_save_checkpoint_count += 1

         def on_train_end(self, trainer, pl_module):
diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py
index 7ccd1dafe02a3..491d0bea3223c 100644
--- a/tests/deprecated_api/test_remove_1-4.py
+++ b/tests/deprecated_api/test_remove_1-4.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Test deprecated functionality which will be removed in vX.Y.Z"""
+"""Test deprecated functionality which will be removed in v1.4.0"""
 import sys

 import pytest
@@ -243,5 +243,5 @@ def training_step(self, batch, batch_idx):

     trainer = Trainer(default_root_dir=tmpdir, checkpoint_callback=True, max_epochs=1)

-    with pytest.warns(DeprecationWarning, match=r"Relying on.*is deprecated in v1.2 and will be removed in v1.4"):
+    with pytest.deprecated_call(match=r"Relying on.*is deprecated in v1.2 and will be removed in v1.4"):
         trainer.fit(TestModel())
diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py
new file mode 100644
index 0000000000000..e87fb5c2ebbb2
--- /dev/null
+++ b/tests/deprecated_api/test_remove_1-5.py
@@ -0,0 +1,56 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test deprecated functionality which will be removed in v1.5.0"""
+
+import pytest
+
+from pytorch_lightning import Trainer, Callback
+from tests.helpers import BoringModel
+from tests.helpers.utils import no_warning_call
+
+
+def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir):
+    class OldSignature(Callback):
+        def on_save_checkpoint(self, trainer, pl_module):  # noqa
+            ...
+
+    model = BoringModel()
+    trainer_kwargs = {
+        "default_root_dir": tmpdir,
+        "checkpoint_callback": False,
+        "max_epochs": 1,
+    }
+    filepath = tmpdir / "test.ckpt"
+
+    trainer = Trainer(**trainer_kwargs, callbacks=[OldSignature()])
+    trainer.fit(model)
+
+    with pytest.deprecated_call(match="old signature will be removed in v1.5"):
+        trainer.save_checkpoint(filepath)
+
+    class NewSignature(Callback):
+        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+            ...
+
+    class ValidSignature1(Callback):
+        def on_save_checkpoint(self, trainer, *args):
+            ...
+
+    class ValidSignature2(Callback):
+        def on_save_checkpoint(self, *args):
+            ...
+
+    trainer.callbacks = [NewSignature(), ValidSignature1(), ValidSignature2()]
+    with no_warning_call(DeprecationWarning):
+        trainer.save_checkpoint(filepath)
diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py
index d23f3d5540e78..5a7062829d738 100644
--- a/tests/helpers/utils.py
+++ b/tests/helpers/utils.py
@@ -14,6 +14,10 @@
 import functools
 import os
 import traceback
+from contextlib import contextmanager
+from typing import Optional
+
+import pytest

 from pytorch_lightning import seed_everything
 from pytorch_lightning.callbacks import ModelCheckpoint
@@ -111,3 +115,18 @@ def inner_f(queue, **kwargs):
             assert result == 1, 'expected 1, but returned %s' % result

     return wrapper
+
+
+@contextmanager
+def no_warning_call(warning_type, match: Optional[str] = None):
+    with pytest.warns(None) as record:
+        yield
+
+    try:
+        w = record.pop(warning_type)
+        if not ((match and match in w.text) or w):
+            return
+    except AssertionError:
+        # no warning raised
+        return
+    raise AssertionError(f"`{warning_type}` was raised: {w}")
diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py
index a472f4398c967..34149e2231bf5 100644
--- a/tests/trainer/connectors/test_callback_connector.py
+++ b/tests/trainer/connectors/test_callback_connector.py
@@ -42,13 +42,13 @@ def test_checkpoint_callbacks_are_last(tmpdir):


 class StatefulCallback0(Callback):

-    def on_save_checkpoint(self, trainer, pl_module):
+    def on_save_checkpoint(self, *args):
         return {"content0": 0}


 class StatefulCallback1(Callback):

-    def on_save_checkpoint(self, trainer, pl_module):
+    def on_save_checkpoint(self, *args):
         return {"content1": 1}
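For illustration only, not part of the patch above: a minimal sketch of a user-defined callback adopting the new hook signature. Per the diff, the `checkpoint` argument is the dictionary about to be written, the dict returned from `on_save_checkpoint` is stored by the trainer under `checkpoint['callbacks']` keyed by the callback type, and `on_load_checkpoint` later receives that same dict back. The callback class and its state below are made up for the example.

```python
from pytorch_lightning.callbacks import Callback


class CounterCallback(Callback):
    """Hypothetical stateful callback using the new on_save_checkpoint signature."""

    def __init__(self):
        self.batches_seen = 0

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # `checkpoint` holds everything about to be dumped (epoch, global_step, state_dict, ...).
        # The returned dict is persisted under checkpoint['callbacks'][type(self)].
        return {"batches_seen": self.batches_seen}

    def on_load_checkpoint(self, callback_state):
        # Receives exactly the dict returned by on_save_checkpoint when the checkpoint is restored.
        self.batches_seen = callback_state["batches_seen"]
```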