-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feat] Add graceful detection of an exit signal, add SignalConnector, and merge SlurmConnector into it. (#9566) Co-authored-by: Sean Naren <sean@grid.ai>
- Loading branch information
Showing
7 changed files
with
172 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
110 changes: 110 additions & 0 deletions
110
pytorch_lightning/trainer/connectors/signal_connector.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import logging | ||
import os | ||
import signal | ||
import sys | ||
from signal import Signals | ||
from subprocess import call | ||
from types import FrameType, FunctionType | ||
from typing import Callable, List, Union | ||
|
||
import pytorch_lightning as pl | ||
from pytorch_lightning.utilities.imports import _fault_tolerant_training | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
class HandlersCompose:
    """Callable that fans one received signal out to several handlers, in order.

    Lets multiple independent signal handlers share a single OS signal slot,
    since ``signal.signal`` accepts only one callable per signal.
    """

    def __init__(self, signal_handlers: Union[List[Callable], Callable]):
        # Accept either a single handler or a list of them; normalize to a list.
        self.signal_handlers = signal_handlers if isinstance(signal_handlers, list) else [signal_handlers]

    def __call__(self, signum: Signals, frame: FrameType) -> None:
        """Invoke every registered handler with the received signal and frame."""
        for handler in self.signal_handlers:
            handler(signum, frame)
|
||
|
||
class SignalConnector:
    """Registers process signal handlers on behalf of the Trainer.

    Composes SIGUSR1 handlers (SLURM checkpoint-and-requeue, fault-tolerant
    graceful exit) and a SIGTERM handler, taking care not to clobber any
    handler that user code installed beforehand.
    """

    def __init__(self, trainer: "pl.Trainer"):
        self.trainer = trainer
        # Flag polled by the training loop; flipped to True by the SIGUSR1
        # handler when fault-tolerant training is enabled.
        self.trainer._terminate_gracefully = False

    def register_signal_handlers(self) -> None:
        """Install SIGUSR1/SIGTERM handlers, unless the user already set one."""
        sigusr1_handlers: List[Callable] = []
        sigterm_handlers: List[Callable] = []

        if _fault_tolerant_training():
            sigusr1_handlers.append(self.fault_tolerant_sigusr1_handler_fn)

        if self._is_on_slurm():
            log.info("Set SLURM handle signals.")
            sigusr1_handlers.append(self.slurm_sigusr1_handler_fn)

        sigterm_handlers.append(self.sigterm_handler_fn)

        # signal.SIGUSR1 doesn't seem available on windows
        if not self._is_on_windows():
            if not self._has_already_handler(signal.SIGUSR1):
                signal.signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers))

            if not self._has_already_handler(signal.SIGTERM):
                signal.signal(signal.SIGTERM, HandlersCompose(sigterm_handlers))

    def slurm_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None:
        """Save an HPC checkpoint and requeue the SLURM job ahead of preemption."""
        if self.trainer.is_global_zero:
            # save weights
            log.info("handling SIGUSR1")
            self.trainer.checkpoint_connector.hpc_save(self.trainer.weights_save_path, self.trainer.logger)

            # find job id
            job_id = os.environ["SLURM_JOB_ID"]
            cmd = ["scontrol", "requeue", job_id]

            # requeue job
            log.info(f"requeueing job {job_id}...")
            try:
                result = call(cmd)
            except FileNotFoundError:
                # This can occur if a subprocess call to `scontrol` is run outside a shell context
                # Re-attempt call (now with shell context). If any error is raised, propagate to user.
                # When running a shell command, it should be passed as a single string.
                result = call(" ".join(cmd), shell=True)

            # print result text
            if result == 0:
                log.info(f"requeued exp {job_id}")
            else:
                log.warning("requeue failed...")

            # close experiment to avoid issues
            if self.trainer.logger:
                self.trainer.logger.finalize("finished")

    def fault_tolerant_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None:
        """Mark the trainer for a graceful, fault-tolerant exit at the next checkpointable point."""
        self.trainer._terminate_gracefully = True

    def sigterm_handler_fn(self, signum: Signals, frame: FrameType) -> None:
        # Intentionally a no-op beyond logging: swallowing SIGTERM lets training
        # shut down through Lightning's own teardown path instead of dying mid-step.
        log.info("bypassing sigterm")

    def _is_on_slurm(self) -> bool:
        """Return True when running under SLURM in batch (non-interactive) mode.

        An interactive SLURM session exports ``SLURM_JOB_NAME="bash"``; only a
        real batch job counts as "on SLURM". Using ``os.environ.get`` replaces
        the previous broad ``except Exception`` around a plain env lookup,
        which could only ever raise ``KeyError``.
        """
        job_name = os.environ.get("SLURM_JOB_NAME")
        return job_name is not None and job_name != "bash"

    def _is_on_windows(self) -> bool:
        """Return True on native Windows, where SIGUSR1 is unavailable."""
        return sys.platform == "win32"

    def _has_already_handler(self, signum: Signals) -> bool:
        """Return True if user code installed a Python-level handler for ``signum``.

        Default/ignored dispositions (``SIG_DFL``/``SIG_IGN``) are not plain
        functions, so they report False and Lightning may install its own.
        """
        try:
            return isinstance(signal.getsignal(signum), FunctionType)
        except AttributeError:
            return False
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import os | ||
import signal | ||
from time import sleep | ||
from unittest import mock | ||
|
||
import pytest | ||
|
||
from pytorch_lightning import Trainer | ||
from tests.helpers import BoringModel | ||
from tests.helpers.runif import RunIf | ||
|
||
|
||
@pytest.mark.parametrize("register_handler", [False, True])
@pytest.mark.parametrize("terminate_gracefully", [False, True])
@RunIf(min_torch="1.7.0", skip_windows=True)
def test_fault_tolerant_sig_handler(register_handler, terminate_gracefully, tmpdir):
    """SIGUSR1 flips ``_terminate_gracefully`` only when fault-tolerant training
    is enabled and no user handler was registered beforehand."""

    # hack to reset the signal
    signal.signal(signal.SIGUSR1, 0)

    if register_handler:
        # Pre-install a user-level handler so Lightning must leave SIGUSR1 alone.
        signal.signal(signal.SIGUSR1, lambda *_: None)

    with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": str(int(terminate_gracefully))}):

        class TestModel(BoringModel):
            def training_step(self, batch, batch_idx):
                if terminate_gracefully or register_handler:
                    # Fire SIGUSR1 at ourselves mid-step and give the handler time to run.
                    os.kill(os.getpid(), signal.SIGUSR1)
                    sleep(0.1)
                return super().training_step(batch, batch_idx)

        trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, limit_val_batches=0)
        trainer.fit(TestModel())
        assert trainer._terminate_gracefully == (terminate_gracefully and not register_handler)