
Commit 848288c

[warning] Add a warning for missing callbacks when using resume_from_checkpoint (#7254)
* add a warning
* add changelog
1 parent e272bea commit 848288c
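
For context, a rough sketch of the behavior this commit introduces: resuming from a checkpoint that holds state for a callback no longer passed to the `Trainer` now emits a `UserWarning` naming the missing callback types instead of silently dropping their state. In this sketch, `model` is assumed to be some `LightningModule` and the checkpoint path is hypothetical:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# `model`: any LightningModule, assumed defined elsewhere.
# First run: EarlyStopping stores its state in the checkpoint,
# keyed by the callback's type under checkpoint['callbacks'].
trainer = Trainer(max_epochs=3, callbacks=[EarlyStopping(monitor='val_loss')])
trainer.fit(model)

# Resuming without EarlyStopping now warns (message from the hunk below):
#   UserWarning: Be aware that when using ``resume_from_checkpoint``,
#   callbacks used to create the checkpoint need to be provided.
#   Please, add the following callbacks: [<class '...EarlyStopping'>].
trainer = Trainer(max_epochs=5, resume_from_checkpoint='path/to/last.ckpt')  # hypothetical path
trainer.fit(model)
```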

File tree: 3 files changed (+39, -2 lines)

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -136,6 +136,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `tpu_distributed` check for TPU Spawn barrier ([#7241](https://github.com/PyTorchLightning/pytorch-lightning/pull/7241))
 
 
+- Added warning when missing `Callback` and using `resume_from_checkpoint` ([#7254](https://github.com/PyTorchLightning/pytorch-lightning/pull/7254))
+
+
 ### Changed
 
 - Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))

pytorch_lightning/trainer/callback_hook.py

Lines changed: 14 additions & 2 deletions

@@ -19,7 +19,7 @@
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import rank_zero_deprecation
+from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache

@@ -293,10 +293,22 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]:
 
     def on_load_checkpoint(self, checkpoint):
         """Called when loading a model checkpoint."""
-        callback_states = checkpoint.get('callbacks')
+
         # Todo: the `callback_states` are dropped with TPUSpawn as they
         # can't be saved using `xm.save`
         # https://github.com/pytorch/xla/issues/2773
+        callback_states = checkpoint.get('callbacks')
+
+        current_callbacks_type = {type(cb) for cb in self.callbacks}
+        saved_callbacks_type = set(callback_states.keys())
+        difference = saved_callbacks_type.difference(current_callbacks_type)
+        if difference:
+            rank_zero_warn(
+                "Be aware that when using ``resume_from_checkpoint``, "
+                "callbacks used to create the checkpoint need to be provided. "
+                f"Please, add the following callbacks: {list(difference)}. ", UserWarning
+            )
+
         if callback_states is not None:
             for callback in self.callbacks:
                 state = callback_states.get(type(callback))
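
For reference, the detection logic above as a minimal standalone sketch, with the standard-library `warnings.warn` standing in for Lightning's `rank_zero_warn`. Unlike the committed hunk, which checks `callback_states is not None` only after calling `.keys()`, this sketch returns early when the checkpoint carries no callback states (e.g., checkpoints written via TPUSpawn):

```python
import warnings
from typing import Any, Dict, List


def warn_on_missing_callbacks(checkpoint: Dict[str, Any], callbacks: List[Any]) -> None:
    # Callback states are stored keyed by callback type; they may be
    # absent entirely (TPUSpawn checkpoints can't save them via xm.save).
    callback_states = checkpoint.get('callbacks')
    if callback_states is None:
        return

    # Warn about any callback type present in the checkpoint but not
    # passed to the current Trainer.
    current_callbacks_type = {type(cb) for cb in callbacks}
    saved_callbacks_type = set(callback_states.keys())
    difference = saved_callbacks_type.difference(current_callbacks_type)
    if difference:
        warnings.warn(
            "Be aware that when using ``resume_from_checkpoint``, "
            "callbacks used to create the checkpoint need to be provided. "
            f"Please, add the following callbacks: {list(difference)}. ",
            UserWarning,
        )
```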

tests/trainer/test_trainer.py

Lines changed: 22 additions & 0 deletions

@@ -2043,6 +2043,28 @@ def test_fit_test_synchronization(tmpdir):
     trainer.test()
 
 
+class CustomCallbackOnLoadCheckpoint(Callback):
+
+    def on_save_checkpoint(self, trainer, pl_module, checkpoint) -> dict:
+        return {"a": None}
+
+
+def test_on_load_checkpoint_missing_callbacks(tmpdir):
+    """ Test a warning appears when callbacks in the checkpoint don't match callbacks provided when resuming. """
+
+    model = BoringModel()
+    chk = ModelCheckpoint(dirpath=tmpdir, save_last=True)
+
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=3, callbacks=[chk, CustomCallbackOnLoadCheckpoint()])
+    trainer.fit(model)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir, max_epochs=5, resume_from_checkpoint=chk.last_model_path, progress_bar_refresh_rate=1
+    )
+    with pytest.warns(UserWarning, match="CustomCallbackOnLoadCheckpoint"):
+        trainer.fit(model)
+
+
 def test_module_current_fx_attributes_reset(tmpdir):
     """ Ensure that lightning module's attributes related to current hook fx are reset at the end of execution. """
     model = BoringModel()
