[Feat] Add graceful detection of signal to exit + SignalConnector and merge SlurmConnector. (#9566)

Co-authored-by: Sean Naren <sean@grid.ai>
tchaton and Sean Naren authored Sep 17, 2021
1 parent 856ed10 commit c7451b3
Showing 7 changed files with 172 additions and 66 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -70,6 +70,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added partial support for global random state fault-tolerance in map-style datasets ([#8950](https://github.com/PyTorchLightning/pytorch-lightning/pull/8950))
* Converted state to tuple explicitly when setting Python random state ([#9401](https://github.com/PyTorchLightning/pytorch-lightning/pull/9401))
* Added support for restarting an optimizer loop (multiple optimizers) ([#9537](https://github.com/PyTorchLightning/pytorch-lightning/pull/9537))
* Added mechanism to detect a signal has been sent so the Trainer can gracefully exit ([#9566](https://github.com/PyTorchLightning/pytorch-lightning/pull/9566))


- Checkpoint saving & loading extensibility:
1 change: 1 addition & 0 deletions pyproject.toml
@@ -68,6 +68,7 @@ module = [
"pytorch_lightning.loops.evaluation_loop",
"pytorch_lightning.trainer.connectors.checkpoint_connector",
"pytorch_lightning.trainer.connectors.logger_connector.*",
"pytorch_lightning.trainer.connectors.signal_connector",
"pytorch_lightning.trainer.progress",
"pytorch_lightning.tuner.auto_gpu_select",
"pytorch_lightning.utilities.apply_func",
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -274,14 +274,15 @@ def restore_lr_schedulers(self) -> None:
    # PRIVATE OPS
    # ----------------------------------

-    def hpc_save(self, folderpath: str, logger: LightningLoggerBase) -> str:
+    def hpc_save(self, folderpath: str, logger: Optional[LightningLoggerBase]) -> str:
        # make sure the checkpoint folder exists
        folderpath = str(folderpath)  # because the tests pass a path object
        fs = get_filesystem(folderpath)
        fs.makedirs(folderpath, exist_ok=True)

        # save logger to make sure we get all the metrics
-        logger.save()
+        if logger:
+            logger.finalize("finished")

        max_suffix = self.max_ckpt_version_in_folder(folderpath)
        ckpt_number = (max_suffix if max_suffix is not None else 0) + 1
110 changes: 110 additions & 0 deletions pytorch_lightning/trainer/connectors/signal_connector.py
@@ -0,0 +1,110 @@
import logging
import os
import signal
import sys
from signal import Signals
from subprocess import call
from types import FrameType, FunctionType
from typing import Callable, List, Union

import pytorch_lightning as pl
from pytorch_lightning.utilities.imports import _fault_tolerant_training

log = logging.getLogger(__name__)


class HandlersCompose:
    def __init__(self, signal_handlers: Union[List[Callable], Callable]):
        if not isinstance(signal_handlers, list):
            signal_handlers = [signal_handlers]
        self.signal_handlers = signal_handlers

    def __call__(self, signum: Signals, frame: FrameType) -> None:
        for signal_handler in self.signal_handlers:
            signal_handler(signum, frame)


class SignalConnector:
    def __init__(self, trainer: "pl.Trainer"):
        self.trainer = trainer
        self.trainer._terminate_gracefully = False

    def register_signal_handlers(self) -> None:
        sigusr1_handlers: List[Callable] = []
        sigterm_handlers: List[Callable] = []

        if _fault_tolerant_training():
            sigusr1_handlers.append(self.fault_tolerant_sigusr1_handler_fn)

        if self._is_on_slurm():
            log.info("Set SLURM handle signals.")
            sigusr1_handlers.append(self.slurm_sigusr1_handler_fn)

        sigterm_handlers.append(self.sigterm_handler_fn)

        # signal.SIGUSR1 is not available on Windows
        if not self._is_on_windows():
            if not self._has_already_handler(signal.SIGUSR1):
                signal.signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers))

            if not self._has_already_handler(signal.SIGTERM):
                signal.signal(signal.SIGTERM, HandlersCompose(sigterm_handlers))

    def slurm_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None:
        if self.trainer.is_global_zero:
            # save weights
            log.info("handling SIGUSR1")
            self.trainer.checkpoint_connector.hpc_save(self.trainer.weights_save_path, self.trainer.logger)

            # find job id
            job_id = os.environ["SLURM_JOB_ID"]
            cmd = ["scontrol", "requeue", job_id]

            # requeue job
            log.info(f"requeuing job {job_id}...")
            try:
                result = call(cmd)
            except FileNotFoundError:
                # This can occur if a subprocess call to `scontrol` is run outside a shell context.
                # Re-attempt the call (now with shell context). If any error is raised, propagate it to the user.
                # When running a shell command, it should be passed as a single string.
                joint_cmd = [str(x) for x in cmd]
                result = call(" ".join(joint_cmd), shell=True)

            # print result text
            if result == 0:
                log.info(f"requeued exp {job_id}")
            else:
                log.warning("requeue failed...")

            # close experiment to avoid issues
            if self.trainer.logger:
                self.trainer.logger.finalize("finished")

    def fault_tolerant_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None:
        self.trainer._terminate_gracefully = True

    def sigterm_handler_fn(self, signum: Signals, frame: FrameType) -> None:
        log.info("bypassing sigterm")

    def _is_on_slurm(self) -> bool:
        # see if we're using SLURM (not interactive)
        on_slurm = False
        try:
            job_name = os.environ["SLURM_JOB_NAME"]
            if job_name != "bash":
                on_slurm = True
        # todo: specify the possible exception
        except Exception:
            pass

        return on_slurm

    def _is_on_windows(self) -> bool:
        return sys.platform == "win32"

    def _has_already_handler(self, signum: Signals) -> bool:
        try:
            return isinstance(signal.getsignal(signum), FunctionType)
        except AttributeError:
            return False
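
Note: the connector above only sets `trainer._terminate_gracefully`; this diff does not show where the Trainer consumes that flag. As a rough, self-contained sketch of the same pattern (the names and the polling loop below are illustrative stand-ins, not code from this commit), a composed SIGUSR1 handler plus a flag check might look like this:

# Illustrative sketch only: mimics HandlersCompose and the graceful-exit flag with the
# standard library. Not part of this commit; SIGUSR1 is unavailable on Windows.
import os
import signal
import time

_terminate_gracefully = False  # stand-in for Trainer._terminate_gracefully


def fault_tolerant_handler(signum, frame):
    # like fault_tolerant_sigusr1_handler_fn: do no real work inside the handler,
    # just record that a signal arrived
    global _terminate_gracefully
    _terminate_gracefully = True


def compose(handlers):
    # same idea as HandlersCompose: call every registered handler in order
    def _composed(signum, frame):
        for h in handlers:
            h(signum, frame)

    return _composed


signal.signal(signal.SIGUSR1, compose([fault_tolerant_handler]))

for step in range(50):
    if _terminate_gracefully:
        print(f"signal seen, stopping cleanly after step {step}")
        break
    if step == 3:
        os.kill(os.getpid(), signal.SIGUSR1)  # simulate a pre-emption notice
    time.sleep(0.01)  # stand-in for one training step

This mirrors the split in the connector: the fault-tolerant handler only records that a signal arrived and defers the real work to the training loop, while the SLURM handler does its checkpointing and requeueing inside the handler itself.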
60 changes: 0 additions & 60 deletions pytorch_lightning/trainer/connectors/slurm_connector.py

This file was deleted.

8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -52,7 +52,7 @@
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
-from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector
+from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.deprecated_api import DeprecatedTrainerAttributes
@@ -383,7 +383,7 @@ def __init__(
        self.debugging_connector = DebuggingConnector(self)
        self.training_tricks_connector = TrainingTricksConnector(self)
        self.checkpoint_connector = CheckpointConnector(self, resume_from_checkpoint)
-        self.slurm_connector = SLURMConnector(self)
+        self.signal_connector = SignalConnector(self)
        self.tuner = Tuner(self)

        # max_epochs won't default to 1000 if max_steps/max_time are specified (including being set to -1).
@@ -1104,8 +1104,8 @@ def _pre_training_routine(self):
        # wait for all to join if on distributed
        self.accelerator.barrier("setup_training")

-        # register auto-resubmit when on SLURM
-        self.slurm_connector.register_slurm_signal_handlers()
+        # register signals
+        self.signal_connector.register_signal_handlers()

        self.checkpoint_connector.resume_end()

53 changes: 53 additions & 0 deletions tests/trainer/connectors/test_signal_connector.py
@@ -0,0 +1,53 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import signal
from time import sleep
from unittest import mock

import pytest

from pytorch_lightning import Trainer
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf


@pytest.mark.parametrize("register_handler", [False, True])
@pytest.mark.parametrize("terminate_gracefully", [False, True])
@RunIf(min_torch="1.7.0", skip_windows=True)
def test_fault_tolerant_sig_handler(register_handler, terminate_gracefully, tmpdir):

    # hack to reset the signal
    signal.signal(signal.SIGUSR1, 0)

    if register_handler:

        def handler(*_):
            pass

        signal.signal(signal.SIGUSR1, handler)

    with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": str(int(terminate_gracefully))}):

        class TestModel(BoringModel):
            def training_step(self, batch, batch_idx):
                if terminate_gracefully or register_handler:
                    os.kill(os.getpid(), signal.SIGUSR1)
                    sleep(0.1)
                return super().training_step(batch, batch_idx)

        model = TestModel()
        trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, limit_val_batches=0)
        trainer.fit(model)
        assert trainer._terminate_gracefully == (False if register_handler else terminate_gracefully)
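
The test leans on standard-library behavior that the connector's `_has_already_handler` check and the `signal.signal(signal.SIGUSR1, 0)` reset hack also rely on: `signal.getsignal` returns whichever Python handler is currently installed. A minimal standalone illustration (assumes a POSIX platform where SIGUSR1 exists; not part of this commit):

# Minimal illustration of the stdlib behavior behind `_has_already_handler`
# and the test's reset hack.
import signal
from types import FunctionType


def user_handler(*_):
    pass


signal.signal(signal.SIGUSR1, user_handler)
# A user-registered function is returned as-is, so the connector backs off.
assert signal.getsignal(signal.SIGUSR1) is user_handler
assert isinstance(signal.getsignal(signal.SIGUSR1), FunctionType)

# 0 is signal.SIG_DFL, the default disposition, which is what the test resets to.
signal.signal(signal.SIGUSR1, 0)
assert not isinstance(signal.getsignal(signal.SIGUSR1), FunctionType)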
