Skip to content

Commit

Permalink
Remove deprecated on_keyboard_interrupt (#13438)
Browse files Browse the repository at this point in the history
Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
  • Loading branch information
4 people authored Jul 5, 2022
1 parent 61473c2 commit 61c28cb
Show file tree
Hide file tree
Showing 12 changed files with 19 additions and 75 deletions.
5 changes: 0 additions & 5 deletions docs/source-pytorch/extensions/callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -353,11 +353,6 @@ on_predict_end
.. automethod:: pytorch_lightning.callbacks.Callback.on_predict_end
:noindex:

on_keyboard_interrupt
^^^^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_keyboard_interrupt
:noindex:

on_exception
^^^^^^^^^^^^
Expand Down
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed support for the `DDP2Strategy` ([#12705](https://github.com/PyTorchLightning/pytorch-lightning/pull/12705))


- Removed deprecated `Callback.on_keyboard_interrupt` ([#13438](https://github.com/Lightning-AI/lightning/pull/13438))


### Fixed


Expand Down
8 changes: 0 additions & 8 deletions src/pytorch_lightning/callbacks/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,14 +272,6 @@ def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule
def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when predict ends."""

def on_keyboard_interrupt(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
r"""
.. deprecated:: v1.5
This callback hook was deprecated in v1.5 in favor of `on_exception` and will be removed in v1.7.
Called when any trainer execution is interrupted by KeyboardInterrupt.
"""

def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
"""Called when any trainer execution is interrupted by an exception."""

Expand Down
1 change: 0 additions & 1 deletion src/pytorch_lightning/callbacks/lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ def __init__(
on_validation_end: Optional[Callable] = None,
on_test_start: Optional[Callable] = None,
on_test_end: Optional[Callable] = None,
on_keyboard_interrupt: Optional[Callable] = None,
on_exception: Optional[Callable] = None,
on_save_checkpoint: Optional[Callable] = None,
on_load_checkpoint: Optional[Callable] = None,
Expand Down
10 changes: 0 additions & 10 deletions src/pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,16 +553,6 @@ def on_predict_end(self) -> None:
for callback in self.callbacks:
callback.on_predict_end(self, self.lightning_module)

def on_keyboard_interrupt(self):
r"""
.. deprecated:: v1.5
This callback hook was deprecated in v1.5 in favor of `on_exception` and will be removed in v1.7.
Called when any trainer execution is interrupted by KeyboardInterrupt.
"""
for callback in self.callbacks:
callback.on_keyboard_interrupt(self, self.lightning_module)

def on_exception(self, exception: BaseException) -> None:
r"""
.. deprecated:: v1.6
Expand Down
5 changes: 0 additions & 5 deletions src/pytorch_lightning/trainer/configuration_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,11 +277,6 @@ def _check_on_pretrain_routine(model: "pl.LightningModule") -> None:

def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None:
for callback in trainer.callbacks:
if is_overridden(method_name="on_keyboard_interrupt", instance=callback):
rank_zero_deprecation(
"The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7."
" Please use the `on_exception` callback hook instead."
)
if is_overridden(method_name="on_init_start", instance=callback):
rank_zero_deprecation(
"The `on_init_start` callback hook was deprecated in v1.6 and will be removed in v1.8."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ class _LogOptions(TypedDict):
),
"on_predict_batch_start": None,
"on_predict_batch_end": None,
"on_keyboard_interrupt": None,
"on_exception": None,
"state_dict": None,
"on_save_checkpoint": None,
Expand Down
3 changes: 1 addition & 2 deletions src/pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,13 +653,12 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs:
return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
else:
return trainer_fn(*args, **kwargs)
# TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
# TODO(awaelchli): Unify both exceptions below, where `KeyboardError` doesn't re-raise
except KeyboardInterrupt as exception:
rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
# user could press Ctrl+c many times... only shutdown once
if not self.interrupted:
self.state.status = TrainerStatus.INTERRUPTED
self._call_callback_hooks("on_keyboard_interrupt")
self._call_callback_hooks("on_exception", exception)
except BaseException as exception:
self.state.status = TrainerStatus.INTERRUPTED
Expand Down
17 changes: 12 additions & 5 deletions tests/tests_pytorch/callbacks/test_lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ def call(hook, *_, **__):
limit_val_batches=1,
callbacks=[LambdaCallback(**hooks_args)],
)
with pytest.deprecated_call(match="on_keyboard_interrupt` callback hook was deprecated in v1.5"):
with pytest.deprecated_call(
match="`on_init_start` callback hook was deprecated in v1.6 and will be removed in v1.8."
):
trainer.fit(model)

ckpt_path = trainer.checkpoint_callback.best_model_path

# raises KeyboardInterrupt and loads from checkpoint
Expand All @@ -63,11 +64,17 @@ def call(hook, *_, **__):
limit_predict_batches=1,
callbacks=[LambdaCallback(**hooks_args)],
)
with pytest.deprecated_call(match="on_keyboard_interrupt` callback hook was deprecated in v1.5"):
with pytest.deprecated_call(
match="`on_init_start` callback hook was deprecated in v1.6 and will be removed in v1.8."
):
trainer.fit(model, ckpt_path=ckpt_path)
with pytest.deprecated_call(match="on_keyboard_interrupt` callback hook was deprecated in v1.5"):
with pytest.deprecated_call(
match="`on_init_start` callback hook was deprecated in v1.6 and will be removed in v1.8."
):
trainer.test(model)
with pytest.deprecated_call(match="on_keyboard_interrupt` callback hook was deprecated in v1.5"):
with pytest.deprecated_call(
match="`on_init_start` callback hook was deprecated in v1.6 and will be removed in v1.8."
):
trainer.predict(model)

assert checker == hooks
25 changes: 1 addition & 24 deletions tests/tests_pytorch/deprecated_api/test_remove_1-7.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pytest
import torch

from pytorch_lightning import Callback, Trainer
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.plugins.environments import (
Expand All @@ -35,29 +35,6 @@
from tests_pytorch.plugins.environments.test_lsf_environment import _make_rankfile


def test_v1_7_0_on_interrupt(tmpdir):
class HandleInterruptCallback(Callback):
def on_keyboard_interrupt(self, trainer, pl_module):
print("keyboard interrupt")

model = BoringModel()
handle_interrupt_callback = HandleInterruptCallback()

trainer = Trainer(
callbacks=[handle_interrupt_callback],
max_epochs=1,
limit_val_batches=0.1,
limit_train_batches=0.2,
enable_progress_bar=False,
logger=False,
default_root_dir=tmpdir,
)
with pytest.deprecated_call(
match="The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7"
):
trainer.fit(model)


class BoringCallbackDDPSpawnModel(BoringModel):
def add_to_queue(self, queue):
...
Expand Down
2 changes: 0 additions & 2 deletions tests/tests_pytorch/trainer/logging_/test_logger_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def test_fx_validator():
"on_fit_start",
"on_init_end",
"on_init_start",
"on_keyboard_interrupt",
"on_exception",
"on_load_checkpoint",
"load_state_dict",
Expand Down Expand Up @@ -93,7 +92,6 @@ def test_fx_validator():
"on_configure_sharded_model",
"on_init_end",
"on_init_start",
"on_keyboard_interrupt",
"on_exception",
"on_load_checkpoint",
"load_state_dict",
Expand Down
14 changes: 2 additions & 12 deletions tests/tests_pytorch/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import math
import os
import pickle
import sys
from argparse import Namespace
from contextlib import nullcontext
from copy import deepcopy
Expand Down Expand Up @@ -1013,14 +1012,10 @@ class HandleInterruptCallback(Callback):
def __init__(self):
super().__init__()
self.exception = None
self.exc_info = None

def on_exception(self, trainer, pl_module, exception):
self.exception = exception

def on_keyboard_interrupt(self, trainer, pl_module):
self.exc_info = sys.exc_info()

interrupt_callback = InterruptCallback()
handle_interrupt_callback = HandleInterruptCallback()

Expand All @@ -1035,15 +1030,10 @@ def on_keyboard_interrupt(self, trainer, pl_module):
)
assert not trainer.interrupted
assert handle_interrupt_callback.exception is None
assert handle_interrupt_callback.exc_info is None
with pytest.deprecated_call(match="on_keyboard_interrupt` callback hook was deprecated in v1.5"):
trainer.fit(model)
trainer.fit(model)
assert trainer.interrupted
assert isinstance(handle_interrupt_callback.exception, KeyboardInterrupt)
assert isinstance(handle_interrupt_callback.exc_info[1], KeyboardInterrupt)
with pytest.raises(MisconfigurationException), pytest.deprecated_call(
match="on_keyboard_interrupt` callback hook was deprecated in v1.5"
):
with pytest.raises(MisconfigurationException):
trainer.test(model)
assert trainer.interrupted
assert isinstance(handle_interrupt_callback.exception, MisconfigurationException)
Expand Down

0 comments on commit 61c28cb

Please sign in to comment.