Better graceful shutdown for KeyboardInterrupt #19976

Merged · 16 commits · Jun 16, 2024
34 changes: 34 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -5,6 +5,40 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [unreleased] - YYYY-MM-DD

### Added

-

-

### Changed

-

-

### Deprecated

-

-

### Removed

-

-

### Fixed

-

-



## [2.3.0] - 2024-06-13

### Added
4 changes: 4 additions & 0 deletions src/lightning/fabric/utilities/distributed.py
@@ -2,6 +2,7 @@
import contextlib
import logging
import os
+import signal
import time
from contextlib import nullcontext
from datetime import timedelta
@@ -306,8 +307,11 @@ def _init_dist_connection(


def _destroy_dist_connection() -> None:
+    # Don't allow Ctrl+C to interrupt this handler
+    signal.signal(signal.SIGINT, signal.SIG_IGN)
    if _distributed_is_initialized():
        torch.distributed.destroy_process_group()
+    signal.signal(signal.SIGINT, signal.SIG_DFL)


def _get_default_process_group_backend_for_device(device: torch.device) -> str:
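For reference, `_destroy_dist_connection` now wraps the process-group teardown in the standard `signal` idiom: ignore Ctrl+C while the critical section runs, then hand SIGINT handling back. A minimal standalone sketch of that pattern (the `slow_cleanup` helper and the `time.sleep` stand-in are illustrative, not part of this PR):

```python
import signal
import time


def slow_cleanup() -> None:
    # Ignore Ctrl+C so a second interrupt cannot abort the cleanup midway.
    previous = signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        time.sleep(2)  # stand-in for torch.distributed.destroy_process_group()
    finally:
        # Hand SIGINT handling back once the critical section is done.
        signal.signal(signal.SIGINT, previous)


if __name__ == "__main__":
    slow_cleanup()  # Ctrl+C pressed during the sleep is ignored
```

The hunk restores `SIG_DFL` rather than the previous handler, which seems acceptable here since this function runs while the process is shutting down anyway.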
35 changes: 35 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -4,6 +4,41 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [unreleased] - YYYY-MM-DD

### Added

-

-

### Changed

- Triggering KeyboardInterrupt (Ctrl+C) during `.fit()`, `.validate()`, `.test()` or `.predict()` now terminates all processes launched by the Trainer and exits the program ([#19976](https://github.com/Lightning-AI/pytorch-lightning/pull/19976))

-

### Deprecated

-

-

### Removed

-

-

### Fixed

-

-



## [2.3.0] - 2024-06-13

### Added
@@ -259,7 +259,7 @@ def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, An
    def kill(self, signum: _SIGNUM) -> None:
        for proc in self.procs:
            if proc.is_alive() and proc.pid is not None:
-                log.info(f"pid {os.getpid()} killing {proc.pid} with {signum}")
+                log.debug(f"Process {os.getpid()} is terminating {proc.pid} with {signum}")
                with suppress(ProcessLookupError):
                    os.kill(proc.pid, signum)

@@ -107,7 +107,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
    @override
    def kill(self, signum: _SIGNUM) -> None:
        for proc in self.procs:
-            log.info(f"pid {os.getpid()} killing {proc.pid} with {signum}")
+            log.debug(f"Process {os.getpid()} is terminating {proc.pid} with {signum}")
            # this skips subprocesses already terminated
            proc.send_signal(signum)

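Both `kill` implementations share the same shape: iterate over the launched processes and forward the signal, skipping children that already exited. A self-contained sketch of the subprocess variant (the `kill_children` helper and the sleeping child are hypothetical, for illustration only):

```python
import signal
import subprocess
import sys
from typing import List


def kill_children(procs: List[subprocess.Popen], signum: signal.Signals) -> None:
    for proc in procs:
        # Popen.send_signal does nothing if the process already completed,
        # so no explicit liveness check is needed here.
        proc.send_signal(signum)


if __name__ == "__main__":
    child = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(60)"])
    kill_children([child], signal.SIGTERM)
    child.wait()  # reaps the terminated child
```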
20 changes: 14 additions & 6 deletions src/lightning/pytorch/trainer/call.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import signal
from copy import deepcopy
from typing import Any, Callable, Dict, Optional, Type, Union

@@ -20,10 +21,12 @@
import lightning.pytorch as pl
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.pytorch.callbacks import Checkpoint, EarlyStopping
+from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher
+from lightning.pytorch.trainer.connectors.signal_connector import _get_sigkill_signal
from lightning.pytorch.trainer.states import TrainerStatus
from lightning.pytorch.utilities.exceptions import _TunerExitException
from lightning.pytorch.utilities.model_helpers import is_overridden
-from lightning.pytorch.utilities.rank_zero import rank_zero_warn
+from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn

log = logging.getLogger(__name__)

@@ -49,12 +52,17 @@ def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *arg
        trainer.state.status = TrainerStatus.FINISHED
        trainer.state.stage = None

+    # TODO: Unify both exceptions below, where `KeyboardError` doesn't re-raise
    except KeyboardInterrupt as exception:
-        rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
-        # user could press Ctrl+c many times... only shutdown once
-        if not trainer.interrupted:
-            _interrupt(trainer, exception)
+        rank_zero_info("\nDetected KeyboardInterrupt, attempting graceful shutdown ...")
+        # user could press Ctrl+C many times, disable KeyboardInterrupt for shutdown
+        signal.signal(signal.SIGINT, signal.SIG_IGN)
+        _interrupt(trainer, exception)
+        trainer._teardown()
+        launcher = trainer.strategy.launcher
+        if isinstance(launcher, _SubprocessScriptLauncher):
+            launcher.kill(_get_sigkill_signal())
+        exit(1)

    except BaseException as exception:
        _interrupt(trainer, exception)
        trainer._teardown()
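Reduced to a runnable sketch outside the Trainer, the new control flow looks roughly like this (the `run_with_graceful_shutdown` wrapper and `boom` are illustrative; the hooks, teardown, and launcher kill are elided as a comment):

```python
import signal
import sys
from typing import Callable


def run_with_graceful_shutdown(fn: Callable[[], None]) -> None:
    try:
        fn()
    except KeyboardInterrupt:
        # Block further Ctrl+C presses so the shutdown sequence runs to completion.
        signal.signal(signal.SIGINT, signal.SIG_IGN)
        print("Detected KeyboardInterrupt, attempting graceful shutdown ...")
        # ... on_exception hooks, trainer teardown, launcher.kill(...) ...
        sys.exit(1)  # non-zero exit code marks the run as interrupted


def boom() -> None:
    raise KeyboardInterrupt


if __name__ == "__main__":
    run_with_graceful_shutdown(boom)  # exits with status 1
```

Note that `exit(1)` in the hunk raises `SystemExit(1)`, which is what the updated tests below assert with `pytest.raises(SystemExit)`.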
11 changes: 5 additions & 6 deletions src/lightning/pytorch/trainer/connectors/signal_connector.py
@@ -2,7 +2,6 @@
import os
import re
import signal
-import sys
import threading
from subprocess import call
from types import FrameType
@@ -54,7 +53,7 @@ def register_signal_handlers(self) -> None:
            sigterm_handlers.append(self._sigterm_handler_fn)

        # Windows seems to have signal incompatibilities
-        if not self._is_on_windows():
+        if not _IS_WINDOWS:
            sigusr = environment.requeue_signal if isinstance(environment, SLURMEnvironment) else signal.SIGUSR1
            assert sigusr is not None
            if sigusr_handlers and not self._has_already_handler(sigusr):
@@ -155,10 +154,6 @@ def _valid_signals() -> Set[signal.Signals]:
        }
        return set(signal.Signals)

-    @staticmethod
-    def _is_on_windows() -> bool:
-        return sys.platform == "win32"
-
    @staticmethod
    def _has_already_handler(signum: _SIGNUM) -> bool:
        return signal.getsignal(signum) not in (None, signal.SIG_DFL)
@@ -172,3 +167,7 @@ def __getstate__(self) -> Dict:
        state = self.__dict__.copy()
        state["_original_handlers"] = {}
        return state
+
+
+def _get_sigkill_signal() -> _SIGNUM:
+    return signal.SIGTERM if _IS_WINDOWS else signal.SIGKILL
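`signal.SIGKILL` does not exist on Windows, hence the `SIGTERM` fallback. On POSIX the numeric values are fixed, which is what the new `test_keyboard_interrupt` below relies on; a quick sanity check (a sketch, not part of the PR):

```python
import signal

# Matches the `15 if _IS_WINDOWS else 9` assertion in
# tests/tests_pytorch/trainer/test_trainer.py below (POSIX values).
assert int(signal.SIGTERM) == 15
assert int(signal.SIGKILL) == 9
```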
@@ -143,7 +143,7 @@ def on_train_start(self) -> None:

    with mock.patch(
        "lightning.pytorch.callbacks.progress.rich_progress.Progress.stop", autospec=True
-    ) as mock_progress_stop:
+    ) as mock_progress_stop, pytest.raises(SystemExit):
        progress_bar = RichProgressBar()
        trainer = Trainer(
            default_root_dir=tmp_path,
9 changes: 7 additions & 2 deletions tests/tests_pytorch/callbacks/test_lambda_function.py
@@ -13,6 +13,7 @@
# limitations under the License.
from functools import partial

+import pytest
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import Callback, LambdaCallback
from lightning.pytorch.demos.boring_classes import BoringModel
@@ -23,10 +24,13 @@
def test_lambda_call(tmp_path):
    seed_everything(42)

+    class CustomException(Exception):
+        pass
+
    class CustomModel(BoringModel):
        def on_train_epoch_start(self):
            if self.current_epoch > 1:
-                raise KeyboardInterrupt
+                raise CustomException("Custom exception to trigger `on_exception` hooks")

    checker = set()

@@ -59,7 +63,8 @@ def call(hook, *_, **__):
        limit_predict_batches=1,
        callbacks=[LambdaCallback(**hooks_args)],
    )
-    trainer.fit(model, ckpt_path=ckpt_path)
+    with pytest.raises(CustomException):
+        trainer.fit(model, ckpt_path=ckpt_path)
    trainer.test(model)
    trainer.predict(model)

3 changes: 2 additions & 1 deletion tests/tests_pytorch/trainer/test_states.py
@@ -84,5 +84,6 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):

    trainer = Trainer(callbacks=[InterruptCallback()], default_root_dir=tmp_path, **extra_params)

-    trainer.fit(model)
+    with pytest.raises(SystemExit):
+        trainer.fit(model)
    assert trainer.interrupted
34 changes: 30 additions & 4 deletions tests/tests_pytorch/trainer/test_trainer.py
@@ -28,6 +28,7 @@
import torch
import torch.nn as nn
from lightning.fabric.utilities.cloud_io import _load as pl_load
+from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.fabric.utilities.seed import seed_everything
from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.accelerators import CPUAccelerator, CUDAAccelerator
@@ -45,7 +46,7 @@
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSampler, _IndexBatchSamplerWrapper
from lightning.pytorch.strategies import DDPStrategy, SingleDeviceStrategy
-from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher
+from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher, _SubprocessScriptLauncher
from lightning.pytorch.trainer.states import RunningStage, TrainerFn
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
@@ -1007,7 +1008,8 @@ def on_exception(self, trainer, pl_module, exception):
    )
    assert not trainer.interrupted
    assert handle_interrupt_callback.exception is None
-    trainer.fit(model)
+    with pytest.raises(SystemExit):
+        trainer.fit(model)
    assert trainer.interrupted
    assert isinstance(handle_interrupt_callback.exception, KeyboardInterrupt)
    with pytest.raises(MisconfigurationException):
@@ -1016,6 +1018,30 @@
    assert isinstance(handle_interrupt_callback.exception, MisconfigurationException)


+def test_keyboard_interrupt(tmp_path):
+    class InterruptCallback(Callback):
+        def __init__(self):
+            super().__init__()
+
+        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
+            raise KeyboardInterrupt
+
+    model = BoringModel()
+    trainer = Trainer(
+        callbacks=[InterruptCallback()],
+        barebones=True,
+        default_root_dir=tmp_path,
+    )
+
+    trainer.strategy._launcher = Mock(spec=_SubprocessScriptLauncher)
+    trainer.strategy._launcher.launch = lambda function, *args, trainer, **kwargs: function(*args, **kwargs)
+
+    with pytest.raises(SystemExit) as exc_info:
+        trainer.fit(model)
+    assert exc_info.value.args[0] == 1
+    trainer.strategy._launcher.kill.assert_called_once_with(15 if _IS_WINDOWS else 9)


@pytest.mark.parametrize("precision", ["32-true", pytest.param("16-mixed", marks=RunIf(min_cuda_gpus=1))])
@RunIf(sklearn=True)
def test_gradient_clipping_by_norm(tmp_path, precision):
@@ -2042,7 +2068,7 @@ def on_fit_start(self):

    trainer = Trainer(default_root_dir=tmp_path)
    with mock.patch("lightning.pytorch.strategies.strategy.Strategy.on_exception") as on_exception_mock, suppress(
-        Exception
+        Exception, SystemExit
    ):
        trainer.fit(ExceptionModel())
    on_exception_mock.assert_called_once_with(exception)
@@ -2061,7 +2087,7 @@ def on_fit_start(self):
    datamodule.on_exception = Mock()
    trainer = Trainer(default_root_dir=tmp_path)

-    with suppress(Exception):
+    with suppress(Exception, SystemExit):
        trainer.fit(ExceptionModel(), datamodule=datamodule)
    datamodule.on_exception.assert_called_once_with(exception)
