Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Add graceful detection of signal to exit + SignalConnector and merge SlurmConnector. #9566

Merged
merged 20 commits into from
Sep 17, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added partial support for global random state fault-tolerance in map-style datasets ([#8950](https://github.com/PyTorchLightning/pytorch-lightning/pull/8950))
* Converted state to tuple explicitly when setting Python random state ([#9401](https://github.com/PyTorchLightning/pytorch-lightning/pull/9401))
* Added support for restarting an optimizer loop (multiple optimizers) ([#9537](https://github.com/PyTorchLightning/pytorch-lightning/pull/9537))
* Added mechanism to detect a signal has been sent so the Trainer can gracefully exit ([#9566](https://github.com/PyTorchLightning/pytorch-lightning/pull/9566))


- Checkpoint saving & loading extensibility:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ module = [
"pytorch_lightning.loops.evaluation_loop",
"pytorch_lightning.trainer.connectors.checkpoint_connector",
"pytorch_lightning.trainer.connectors.logger_connector.*",
"pytorch_lightning.trainer.connectors.signal_connector",
"pytorch_lightning.trainer.progress",
"pytorch_lightning.tuner.auto_gpu_select",
"pytorch_lightning.utilities.apply_func",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,14 +274,15 @@ def restore_lr_schedulers(self) -> None:
# PRIVATE OPS
# ----------------------------------

def hpc_save(self, folderpath: str, logger: LightningLoggerBase) -> str:
def hpc_save(self, folderpath: str, logger: Optional[LightningLoggerBase]) -> str:
# make sure the checkpoint folder exists
folderpath = str(folderpath) # because the tests pass a path object
fs = get_filesystem(folderpath)
fs.makedirs(folderpath, exist_ok=True)

# save logger to make sure we get all the metrics
logger.save()
if logger:
logger.finalize("finished")

max_suffix = self.max_ckpt_version_in_folder(folderpath)
ckpt_number = (max_suffix if max_suffix is not None else 0) + 1
Expand Down
110 changes: 110 additions & 0 deletions pytorch_lightning/trainer/connectors/signal_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import logging
tchaton marked this conversation as resolved.
Show resolved Hide resolved
import os
import signal
import sys
from signal import Signals
from subprocess import call
from types import FrameType, FunctionType
from typing import Callable, List, Union

import pytorch_lightning as pl
from pytorch_lightning.utilities.imports import _fault_tolerant_training

log = logging.getLogger(__name__)


class HandlersCompose:
def __init__(self, signal_handlers: Union[List[Callable], Callable]):
if not isinstance(signal_handlers, list):
signal_handlers = [signal_handlers]
self.signal_handlers = signal_handlers

def __call__(self, signum: Signals, frame: FrameType) -> None:
for signal_handler in self.signal_handlers:
signal_handler(signum, frame)


class SignalConnector:
def __init__(self, trainer: "pl.Trainer"):
self.trainer = trainer
self.trainer._terminate_gracefully = False

def register_signal_handlers(self) -> None:
sigusr1_handlers: List[Callable] = []
sigterm_handlers: List[Callable] = []

if _fault_tolerant_training():
sigusr1_handlers.append(self.fault_tolerant_sigusr1_handler_fn)
tchaton marked this conversation as resolved.
Show resolved Hide resolved

if self._is_on_slurm():
log.info("Set SLURM handle signals.")
sigusr1_handlers.append(self.slurm_sigusr1_handler_fn)

sigterm_handlers.append(self.sigterm_handler_fn)
Copy link
Contributor

@awaelchli awaelchli Oct 27, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this line got unindented, it was previously under the slurm check.
it means that Lightning can't be killed by sigterm.
As best as I can judge, this is not required for fault tolerant training.

See #10154 for context.


# signal.SIGUSR1 doesn't seem available on windows
if not self._is_on_windows():
tchaton marked this conversation as resolved.
Show resolved Hide resolved
if not self._has_already_handler(signal.SIGUSR1):
signal.signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers))

if not self._has_already_handler(signal.SIGTERM):
signal.signal(signal.SIGTERM, HandlersCompose(sigterm_handlers))

def slurm_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None:
if self.trainer.is_global_zero:
# save weights
log.info("handling SIGUSR1")
self.trainer.checkpoint_connector.hpc_save(self.trainer.weights_save_path, self.trainer.logger)
tchaton marked this conversation as resolved.
Show resolved Hide resolved

# find job id
job_id = os.environ["SLURM_JOB_ID"]
cmd = ["scontrol", "requeue", job_id]

# requeue job
log.info(f"requeing job {job_id}...")
try:
result = call(cmd)
except FileNotFoundError:
# This can occur if a subprocess call to `scontrol` is run outside a shell context
# Re-attempt call (now with shell context). If any error is raised, propagate to user.
# When running a shell command, it should be passed as a single string.
joint_cmd = [str(x) for x in cmd]
result = call(" ".join(joint_cmd), shell=True)

# print result text
if result == 0:
log.info(f"requeued exp {job_id}")
else:
log.warning("requeue failed...")

# close experiment to avoid issues
if self.trainer.logger:
self.trainer.logger.finalize("finished")

def fault_tolerant_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None:
self.trainer._terminate_gracefully = True
tchaton marked this conversation as resolved.
Show resolved Hide resolved

def sigterm_handler_fn(self, signum: Signals, frame: FrameType) -> None:
log.info("bypassing sigterm")
tchaton marked this conversation as resolved.
Show resolved Hide resolved

def _is_on_slurm(self) -> bool:
# see if we're using slurm (not interactive)
on_slurm = False
try:
job_name = os.environ["SLURM_JOB_NAME"]
if job_name != "bash":
on_slurm = True
tchaton marked this conversation as resolved.
Show resolved Hide resolved
# todo: specify the possible exception
except Exception:
pass

return on_slurm

def _is_on_windows(self) -> bool:
return sys.platform == "win32"

def _has_already_handler(self, signum: Signals) -> bool:
try:
return isinstance(signal.getsignal(signum), FunctionType)
except AttributeError:
return False
60 changes: 0 additions & 60 deletions pytorch_lightning/trainer/connectors/slurm_connector.py

This file was deleted.

8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.deprecated_api import DeprecatedTrainerAttributes
Expand Down Expand Up @@ -383,7 +383,7 @@ def __init__(
self.debugging_connector = DebuggingConnector(self)
self.training_tricks_connector = TrainingTricksConnector(self)
self.checkpoint_connector = CheckpointConnector(self, resume_from_checkpoint)
self.slurm_connector = SLURMConnector(self)
self.signal_connector = SignalConnector(self)
self.tuner = Tuner(self)

# max_epochs won't default to 1000 if max_steps/max_time are specified (including being set to -1).
Expand Down Expand Up @@ -1096,8 +1096,8 @@ def _pre_training_routine(self):
# wait for all to join if on distributed
self.accelerator.barrier("setup_training")

# register auto-resubmit when on SLURM
self.slurm_connector.register_slurm_signal_handlers()
# register signals
self.signal_connector.register_signal_handlers()

self.checkpoint_connector.resume_end()

Expand Down
53 changes: 53 additions & 0 deletions tests/trainer/connectors/test_signal_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import signal
from time import sleep
from unittest import mock

import pytest

from pytorch_lightning import Trainer
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf


@pytest.mark.parametrize("register_handler", [False, True])
@pytest.mark.parametrize("terminate_gracefully", [False, True])
@RunIf(min_torch="1.7.0", skip_windows=True)
def test_fault_tolerant_sig_handler(register_handler, terminate_gracefully, tmpdir):

# hack to reset the signal
signal.signal(signal.SIGUSR1, 0)

if register_handler:

def handler(*_):
pass

signal.signal(signal.SIGUSR1, handler)

with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": str(int(terminate_gracefully))}):

class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
if terminate_gracefully or register_handler:
os.kill(os.getpid(), signal.SIGUSR1)
sleep(0.1)
return super().training_step(batch, batch_idx)

model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, limit_val_batches=0)
trainer.fit(model)
assert trainer._terminate_gracefully == (False if register_handler else terminate_gracefully)