Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip reconciliate_processes if used within a cluster environment that creates processes externally #9389

Merged
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed collision of user argument when using ShardedDDP ([#9512](https://github.com/PyTorchLightning/pytorch-lightning/pull/9512))


- Fixed error reporting in DDP process reconciliation when processes are launched by an external agent ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389))
ananthsub marked this conversation as resolved.
Show resolved Hide resolved


## [1.4.5] - 2021-08-31

- Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))
Expand Down
25 changes: 21 additions & 4 deletions pytorch_lightning/plugins/training_type/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def __init__(
self._model_averaging_period = model_averaging_period
self._pids: Optional[List[int]] = None
self._sync_dir: Optional[str] = None
self._rank_0_has_called_call_children_scripts: bool = False
self.set_world_ranks()

@property
Expand Down Expand Up @@ -252,6 +253,8 @@ def _call_children_scripts(self):
delay = np.random.uniform(1, 5, 1)[0]
sleep(delay)

self._rank_0_has_called_call_children_scripts = True

def setup_distributed(self):
reset_seed()

Expand All @@ -269,6 +272,7 @@ def setup_distributed(self):
# set the ranks and devices
self.dist.rank = self.global_rank
self.dist.device = self.root_device
self._rank_0_has_called_call_children_scripts = self.broadcast(self._rank_0_has_called_call_children_scripts)
ananthsub marked this conversation as resolved.
Show resolved Hide resolved

def _check_can_spawn_children(self):
if self.local_rank != 0:
Expand Down Expand Up @@ -373,7 +377,8 @@ def determine_ddp_device_ids(self):

def pre_dispatch(self):
# share ddp pids to all processes
self._share_information_to_prevent_deadlock()
if self._should_run_deadlock_detection():
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
ananthsub marked this conversation as resolved.
Show resolved Hide resolved
self._share_information_to_prevent_deadlock()

# move the model to the correct device
self.model_to_device()
Expand Down Expand Up @@ -454,7 +459,16 @@ def register_plugins(cls, plugin_registry: Dict) -> None:
find_unused_parameters=False,
)

def _share_information_to_prevent_deadlock(self):
def _should_run_deadlock_detection(self) -> bool:
"""Determines whether the plugin will perform process reconciliation in case of errors.

If the environment variable `PL_RECONCILE_PROCESS` is set to `1`, run detection regardless of the cluster environment
(the variable defaults to `0`, i.e. disabled). Otherwise, run detection only when Lightning itself launched the child
processes; if the cluster environment creates the processes, the scheduler / parent process is left to perform the
process termination, external to Lightning.
"""
return os.getenv("PL_RECONCILE_PROCESS", "0") == "1" or self._rank_0_has_called_call_children_scripts

def _share_information_to_prevent_deadlock(self) -> None:
self._share_pids()

# there should be a unique sync_dir per node.
Expand All @@ -470,17 +484,20 @@ def _share_information_to_prevent_deadlock(self):

self._sync_dir = sync_dirs[self.node_rank]

def _share_pids(self):
def _share_pids(self) -> None:
"""Make all DDP processes aware of all processes' pids."""
self.barrier()
pids = self.all_gather(torch.tensor(os.getpid(), device=self.root_device))
pids = pids.cpu().numpy().tolist()
self._pids = pids if isinstance(pids, list) else [pids]

def reconciliate_processes(self, trace: str):
def reconciliate_processes(self, trace: str) -> None:
if self.world_size < 2:
return

if not self._should_run_deadlock_detection():
return

sync_dir = self._sync_dir

if not sync_dir:
Expand Down
2 changes: 1 addition & 1 deletion tests/plugins/environments/torch_elastic_deadlock.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException
from tests.helpers.boring_model import BoringModel

if os.getenv("PL_RUNNING_SPECIAL_TESTS", "0") == "1":
if os.getenv("PL_RUNNING_SPECIAL_TESTS", "0") == "1" and os.getenv("PL_RECONCILE_PROCESS", "0") == "1":

class CustomException(Exception):
pass
Expand Down
2 changes: 1 addition & 1 deletion tests/special_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ fi

# TODO: enable when CI uses torch>=1.9
# test deadlock is properly handled with TorchElastic.
# LOGS=$(PL_RUNNING_SPECIAL_TESTS=1 python -m torch.distributed.run --nproc_per_node=2 --max_restarts 0 -m coverage run --source pytorch_lightning -a tests/plugins/environments/torch_elastic_deadlock.py | grep "SUCCEEDED")
# LOGS=$(PL_RUNNING_SPECIAL_TESTS=1 PL_RECONCILE_PROCESS=1 python -m torch.distributed.run --nproc_per_node=2 --max_restarts 0 -m coverage run --source pytorch_lightning -a tests/plugins/environments/torch_elastic_deadlock.py | grep "SUCCEEDED")
# if [ -z "$LOGS" ]; then
# exit 1
# fi
Expand Down
7 changes: 4 additions & 3 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1823,13 +1823,14 @@ def test_exception_when_lightning_module_is_not_set_on_trainer():
trainer.predict()


class CustomException(Exception):
pass


@RunIf(min_gpus=2, special=True)
def test_ddp_terminate_when_deadlock_is_detected(tmpdir):
"""Test that DDP kills the remaining processes when only one rank is throwing an exception."""

class CustomException(Exception):
pass

class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
if batch_idx == 1 and self.trainer.is_global_zero:
Expand Down