signal handling and teardown #3632

Closed · wants to merge 14 commits
37 changes: 37 additions & 0 deletions pytorch_lightning/trainer/connectors/signal_connector.py
@@ -0,0 +1,37 @@
import signal
from typing import Callable, Any

from pytorch_lightning.trainer.states import TrainerState


class SignalConnector:
"""
Takes care of registering and restoring signal handlers for the
:class:`~pytorch_lightning.trainer.trainer.Trainer`. By default, it handles
SIGTERM, SIGINT and SIGSEGV by raising KeyboardInterrupt and letting Trainer do graceful shutdown.
"""

def __init__(self):
self.original_handlers = {}

def setup(self):
""" Registers the default signal handlers for the Trainer. """
self.register_signal(signal.SIGTERM, self.default_teardown)
self.register_signal(signal.SIGSEGV, self.default_teardown)
self.register_signal(signal.SIGINT, self.default_teardown)

def restore_signals(self):
""" Restores the original signal handlers (e.g. the Python or user defaults) """
for signum, handler in self.original_handlers.items():
signal.signal(signum, handler)

def register_signal(self, signum: int, handler: Callable):
""" Registers a signal handler and saves a reference to the original handler. """
self.original_handlers.update({signum: signal.getsignal(signum)})
signal.signal(signum, handler)

def default_teardown(self, signum: int, frame: Any): # pragma: no-cover
""" This default teardown raises KeyboardInterrupt and lets Trainer handle the graceful shutdown. """
# self.trainer.interrupted = True
# self.trainer._state = TrainerState.INTERRUPTED
raise KeyboardInterrupt
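
For readers less familiar with the stdlib signal module, the following standalone sketch (not part of the diff; the names _raise_interrupt and _original_handlers are made up for illustration) shows the same register-then-restore pattern the connector implements:

import signal

_original_handlers = {}

def _raise_interrupt(signum, frame):
    # same idea as SignalConnector.default_teardown: translate the signal
    # into a KeyboardInterrupt so ordinary shutdown logic can run
    raise KeyboardInterrupt

for sig in (signal.SIGINT, signal.SIGTERM):
    _original_handlers[sig] = signal.getsignal(sig)  # remember the previous handler
    signal.signal(sig, _raise_interrupt)             # install the replacement

# ... interruptible work runs here ...

for sig, previous in _original_handlers.items():
    signal.signal(sig, previous)                     # restore what was there before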
6 changes: 4 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -28,6 +28,7 @@
from pytorch_lightning.profiler import BaseProfiler
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.trainer.configuration_validator import ConfigValidator
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
@@ -277,6 +278,7 @@ def __init__(
        self.profile_connector = ProfilerConnector(self)
        self.checkpoint_connector = CheckpointConnector(self)
        self.slurm_connector = SLURMConnector(self)
        self.signal_connector = SignalConnector()
        self.tuner = Tuner(self)
        self.accelerator_backend = None
        self.evaluation_loop = EvaluationLoop(self)
@@ -524,13 +526,13 @@ def train(self):

        except KeyboardInterrupt:
            rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')

            print('here')
            # user could press ctrl+c many times... only shutdown once
            if not self.interrupted:
                self.interrupted = True
                self._state = TrainerState.INTERRUPTED
                self.on_keyboard_interrupt()

            print('here2')
            # hook
            self.train_loop.on_train_end()

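For context on why raising KeyboardInterrupt from a signal handler lands in the existing except KeyboardInterrupt block of train(): CPython runs Python-level signal handlers in the main thread, so the exception propagates from whatever statement the main thread is executing at that moment. A tiny self-contained illustration (POSIX only; every name here is invented for the example):

import os
import signal
import time

def _handler(signum, frame):
    raise KeyboardInterrupt

signal.signal(signal.SIGTERM, _handler)

try:
    os.kill(os.getpid(), signal.SIGTERM)  # simulate an external `kill <pid>`
    time.sleep(5)                         # the handler fires in the main thread here
except KeyboardInterrupt:
    print("graceful shutdown path taken")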
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
@@ -71,6 +71,8 @@ def num_optimizers(self):
        return num_optimizers

    def on_train_start(self):
        self.trainer.signal_connector.setup()

        # clear cache before training
        if self.trainer.on_gpu and self.trainer.root_gpu is not None:
            # use context because of:
@@ -186,8 +188,6 @@ def on_train_end(self):

        # clear mem
        if self.trainer.on_gpu:
            model = self.trainer.get_model()
            model.cpu()
            torch.cuda.empty_cache()

    def check_checkpoint_callback(self, should_save, is_last=False):
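The hunk above wires setup() into on_train_start; restore_signals() is not called anywhere in these commits. Purely as an assumption about where the matching teardown could live (not something this diff adds), the counterpart might look like:

def on_train_end(self):
    ...
    # hypothetical: hand the original handlers back once training finishes,
    # so handlers registered before fit() are active again afterwards
    self.trainer.signal_connector.restore_signals()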
66 changes: 66 additions & 0 deletions tests/trainer/test_signal_handling.py
@@ -0,0 +1,66 @@
import os
import platform
import signal

import pytest
import torch.distributed

from pytorch_lightning import Trainer, Callback
from pytorch_lightning.trainer.states import TrainerState
from tests.base import EvalModelTemplate


class KillCallback(Callback):
""" A callback that simulates a terminal signal (e.g. SIGINT). """

def __init__(self, signal_code):
self._signal = signal_code

def on_batch_end(self, trainer, pl_module):
# send the signal after the first batch
assert trainer.global_step == 0, "did not interrupt training after first batch"
pid = os.getpid()
os.kill(pid, self._signal)

def on_train_end(self, trainer, pl_module):
assert trainer.train_loop._teardown_already_run

def on_keyboard_interrupt(self, trainer, pl_module):
assert trainer.state == TrainerState.INTERRUPTED
assert trainer.interrupted


def _get_available_signal_codes():
    codes = [signal.SIGINT]
    if platform.system() != "Windows":
        codes += [signal.SIGTERM, signal.SIGSEGV]
    codes = [pytest.param(c) for c in codes]
    return codes


@pytest.mark.skipif(not torch.distributed.is_available(), reason="test requires torch.distributed module")
@pytest.mark.parametrize(["signal_code"], _get_available_signal_codes())
def test_graceful_training_shutdown(tmpdir, signal_code):
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=100,
        distributed_backend="ddp_cpu",
        callbacks=[KillCallback(signal_code)],
        num_processes=4,
    )
    model = EvalModelTemplate()
    trainer.fit(model)


@pytest.mark.parametrize(["signal_code"], _get_available_signal_codes())
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_graceful_training_shutdown_gpu(tmpdir, signal_code):
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=100,
        distributed_backend="ddp_spawn",
        gpus=2,
        callbacks=[KillCallback(signal_code)],
    )
    model = EvalModelTemplate()
    trainer.fit(model)
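
For reviewers trying this out locally, the new tests should run with a plain pytest invocation such as python -m pytest tests/trainer/test_signal_handling.py -v; the first test uses the ddp_cpu backend with 4 processes, and the GPU variant is skipped unless at least two CUDA devices are visible.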