Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Add graceful detection of signal to exit + SignalConnector and merge SlurmConnector. #9566

Merged
merged 20 commits into from
Sep 17, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added partial support for global random state fault-tolerance in map-style datasets ([#8950](https://github.com/PyTorchLightning/pytorch-lightning/pull/8950))
* Converted state to tuple explicitly when setting Python random state ([#9401](https://github.com/PyTorchLightning/pytorch-lightning/pull/9401))
* Added support for restarting an optimizer loop (multiple optimizers) ([#9537](https://github.com/PyTorchLightning/pytorch-lightning/pull/9537))
* Added a mechanism to detect a signal as been sent so the Trainer can gracefully exit ([#9566](https://github.com/PyTorchLightning/pytorch-lightning/pull/9566))
tchaton marked this conversation as resolved.
Show resolved Hide resolved
tchaton marked this conversation as resolved.
Show resolved Hide resolved


- Checkpoint saving & loading extensibility:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,27 @@
import signal
from subprocess import call

from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.utilities.imports import _fault_tolerant_training

log = logging.getLogger(__name__)


class SLURMConnector:
class SignalConnector:
def __init__(self, trainer):
self.trainer = trainer
self.trainer._should_gracefully_terminate = False
tchaton marked this conversation as resolved.
Show resolved Hide resolved

def register_signal_handlers(self):
cluster_env = getattr(self.trainer.training_type_plugin, "cluster_environment", None)
tchaton marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(cluster_env, SLURMEnvironment):
self.register_slurm_signal_handlers()
elif _fault_tolerant_training():
tchaton marked this conversation as resolved.
Show resolved Hide resolved
self.register_fault_tolerant_handlers()

def register_fault_tolerant_handlers(self):
signal.signal(signal.SIGUSR1, self.sig_fault_tolerant_handler)
signal.signal(signal.SIGTERM, self.term_handler)
tchaton marked this conversation as resolved.
Show resolved Hide resolved

def register_slurm_signal_handlers(self):
# see if we're using slurm (not interactive)
Expand All @@ -23,10 +38,10 @@ def register_slurm_signal_handlers(self):

if on_slurm:
log.info("Set SLURM handle signals.")
signal.signal(signal.SIGUSR1, self.sig_handler)
signal.signal(signal.SIGUSR1, self.sig_slurm_handler)
signal.signal(signal.SIGTERM, self.term_handler)

def sig_handler(self, signum, frame): # pragma: no-cover
def sig_slurm_handler(self, signum, frame): # pragma: no-cover
tchaton marked this conversation as resolved.
Show resolved Hide resolved
if self.trainer.is_global_zero:
# save weights
log.info("handling SIGUSR1")
Expand Down Expand Up @@ -56,5 +71,8 @@ def sig_handler(self, signum, frame): # pragma: no-cover
# close experiment to avoid issues
self.trainer.logger.close()

def sig_fault_tolerant_handler(self, signum, frame): # pragma: no-cover
self.trainer._should_gracefully_terminate = True

def term_handler(self, signum, frame): # pragma: no-cover
log.info("bypassing sigterm")
tchaton marked this conversation as resolved.
Show resolved Hide resolved
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.deprecated_api import DeprecatedTrainerAttributes
Expand Down Expand Up @@ -383,7 +383,7 @@ def __init__(
self.debugging_connector = DebuggingConnector(self)
self.training_tricks_connector = TrainingTricksConnector(self)
self.checkpoint_connector = CheckpointConnector(self, resume_from_checkpoint)
self.slurm_connector = SLURMConnector(self)
self.signal_connector = SignalConnector(self)
self.tuner = Tuner(self)

# max_epochs won't default to 1000 if max_steps/max_time are specified (including being set to -1).
Expand Down Expand Up @@ -1096,8 +1096,8 @@ def _pre_training_routine(self):
# wait for all to join if on distributed
self.accelerator.barrier("setup_training")

# register auto-resubmit when on SLURM
self.slurm_connector.register_slurm_signal_handlers()
# register signals
self.signal_connector.register_signal_handlers()

self.checkpoint_connector.resume_end()

Expand Down
42 changes: 42 additions & 0 deletions tests/trainer/connectors/test_signal_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import signal
from time import sleep
from unittest import mock

import pytest

from pytorch_lightning import Trainer
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf


@pytest.mark.parametrize("should_gracefully_terminate", [False, True])
@RunIf(min_torch="1.7.0", special=True)
def test_fault_tolerant_sig_handler(should_gracefully_terminate, tmpdir):

with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": str(int(should_gracefully_terminate))}):

class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
if should_gracefully_terminate and self.trainer.current_epoch == 1 and batch_idx == 1:
os.kill(os.getpid(), signal.SIGUSR1)
sleep(0.1)
return super().training_step(batch, batch_idx)

model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, limit_train_batches=2, limit_val_batches=2)
trainer.fit(model)
assert trainer._should_gracefully_terminate == should_gracefully_terminate