Snapshot selected globals and restore them in spawned process (#13921)
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
awaelchli and Borda authored Aug 1, 2022
1 parent 91bdacf commit eb233ea
Showing 6 changed files with 113 additions and 10 deletions.
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -396,6 +396,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue that caused the learning rate finder to set the model's learning rate to None when no suggestion was possible ([#13845](https://github.com/Lightning-AI/lightning/pull/13845))


- Fixed an issue causing deterministic algorithms and other globals to get reset in spawned processes ([#13921](https://github.com/Lightning-AI/lightning/pull/13921))


- Fixed default `amp_level` for `DeepSpeedPrecisionPlugin` to `O2` ([#13897](https://github.com/PyTorchLightning/pytorch-lightning/pull/13897))


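For context, the changelog entry above refers to the fact that a process created with the "spawn" start method begins from a fresh Python interpreter, so module-level state configured in the main process (the deterministic-algorithms flag, cuDNN benchmark mode, RNG seeds) is not carried over. A minimal standalone sketch of that behaviour, assuming a CPU-only run; this script is illustrative and not part of the commit:

```python
import torch
import torch.multiprocessing as mp


def worker(rank: int) -> None:
    # In a freshly spawned interpreter this prints False, even though the parent enabled it.
    print(f"rank {rank}: deterministic={torch.are_deterministic_algorithms_enabled()}")


if __name__ == "__main__":
    torch.use_deterministic_algorithms(True)  # global flag set in the main process
    mp.spawn(worker, nprocs=1)  # torch.multiprocessing.spawn defaults to the "spawn" method
```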
2 changes: 0 additions & 2 deletions src/pytorch_lightning/strategies/ddp_spawn.py
@@ -50,7 +50,6 @@
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep

log = logging.getLogger(__name__)
@@ -175,7 +174,6 @@ def set_world_ranks(self, process_idx: int = 0) -> None:
rank_zero_only.rank = self.cluster_environment.global_rank()

def _worker_setup(self, process_idx: int) -> None:
reset_seed()
self.set_world_ranks(process_idx)
rank_zero_only.rank = self.global_rank
self._process_group_backend = self._get_process_group_backend()
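The two deletions above (the `reset_seed` import and the `reset_seed()` call in `_worker_setup`) are possible because the launcher now restores the parent's exact RNG states in each worker. Roughly, and as an assumption rather than the verbatim Lightning helper, `reset_seed` re-applied `seed_everything` from environment variables exported by the main process, which only had an effect when `seed_everything` had been called at all:

```python
import os

from pytorch_lightning.utilities.seed import seed_everything


def reset_seed_sketch() -> None:
    """Approximate stand-in for the removed `reset_seed` call (assumption, simplified)."""
    seed = os.environ.get("PL_GLOBAL_SEED")
    if seed is None:
        return  # seed_everything() was never used, so there is nothing to re-apply
    workers = bool(int(os.environ.get("PL_SEED_WORKERS", "0")))
    seed_everything(int(seed), workers=workers)
```

Restoring the captured RNG states instead covers both seeded and unseeded runs.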
65 changes: 63 additions & 2 deletions src/pytorch_lightning/strategies/launchers/multiprocessing.py
@@ -13,11 +13,13 @@
# limitations under the License.
import os
from collections import UserList
from dataclasses import dataclass
from multiprocessing.queues import SimpleQueue
from typing import Any, Callable, NamedTuple, Optional
from typing import Any, Callable, Dict, NamedTuple, Optional

import numpy as np
import torch
import torch.backends.cudnn
import torch.multiprocessing as mp
from torch import Tensor
from typing_extensions import Literal
@@ -27,7 +29,9 @@
from pytorch_lightning.strategies.strategy import Strategy
from pytorch_lightning.trainer.states import TrainerFn, TrainerState
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11
from pytorch_lightning.utilities.rank_zero import rank_zero_debug
from pytorch_lightning.utilities.seed import _collect_rng_states, _set_rng_states
from pytorch_lightning.utilities.types import _PATH


@@ -89,9 +93,16 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
os.environ["MASTER_PORT"] = str(self._strategy.cluster_environment.main_port)
context = mp.get_context(self._start_method)
return_queue = context.SimpleQueue()

if self._start_method == "spawn":
global_states = _GlobalStateSnapshot.capture()
process_args = [trainer, function, args, kwargs, return_queue, global_states]
else:
process_args = [trainer, function, args, kwargs, return_queue]

mp.start_processes(
self._wrapping_function,
args=(trainer, function, args, kwargs, return_queue),
args=process_args,
nprocs=self._strategy.num_processes,
start_method=self._start_method,
)
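The snapshot is taken only for the "spawn" start method because a forked child inherits the parent's memory, so module-level flags are already in place there. A hedged, standalone illustration of the difference ("fork" is POSIX only; not part of the commit):

```python
import torch
import torch.multiprocessing as mp


def worker(rank: int) -> None:
    print(f"rank {rank}: deterministic={torch.are_deterministic_algorithms_enabled()}")


if __name__ == "__main__":
    torch.use_deterministic_algorithms(True)
    # Prints True under "fork"; under "spawn" it would print False without a snapshot/restore.
    mp.start_processes(worker, nprocs=1, start_method="fork")
```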
@@ -110,7 +121,10 @@ def _wrapping_function(
args: Any,
kwargs: Any,
return_queue: SimpleQueue,
global_states: Optional["_GlobalStateSnapshot"] = None,
) -> None:
if global_states:
global_states.restore()
self._strategy._worker_setup(process_idx)
results = function(*args, **kwargs)

@@ -209,3 +223,50 @@ class _WorkerOutput(NamedTuple):
trainer_state: TrainerState
trainer_results: Any
extra: _FakeQueue


@dataclass
class _GlobalStateSnapshot:
"""Captures a hand-selected set of (global) variables in modules and provides a way to restore them.
It facilitates and encapsulates the transfer of globals like PyTorch's deterministic flags or random generator state
across process boundaries when launching processes with :func:`torch.multiprocessing.spawn`.
Example:
.. code-block:: python
# in main process
snapshot = _GlobalStateSnapshot.capture()
# in worker process
snapshot.restore()
"""

use_deterministic_algorithms: bool
use_deterministic_algorithms_warn_only: bool
cudnn_benchmark: bool
rng_states: Dict[str, Any]

@classmethod
def capture(cls) -> "_GlobalStateSnapshot":
"""Capture a few global states from torch, numpy, etc., that we want to restore in a spawned worker
process."""
warn_only = torch.is_deterministic_algorithms_warn_only_enabled() if _TORCH_GREATER_EQUAL_1_11 else False
return cls(
use_deterministic_algorithms=torch.are_deterministic_algorithms_enabled(),
use_deterministic_algorithms_warn_only=warn_only,
cudnn_benchmark=torch.backends.cudnn.benchmark,
rng_states=_collect_rng_states(),
)

def restore(self) -> None:
"""Restores all globals to the values captured in the :meth:`capture` method."""
if _TORCH_GREATER_EQUAL_1_11:
torch.use_deterministic_algorithms(
self.use_deterministic_algorithms, warn_only=self.use_deterministic_algorithms_warn_only
)
else:
torch.use_deterministic_algorithms(self.use_deterministic_algorithms)
torch.backends.cudnn.benchmark = self.cudnn_benchmark
_set_rng_states(self.rng_states)
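The `rng_states` field round-trips through two helpers in `pytorch_lightning.utilities.seed`. As a rough approximation (assumption: simplified stand-ins for `_collect_rng_states` and `_set_rng_states`, which may also handle additional generators), they pack the torch, NumPy, and Python random states into a plain, picklable dict and unpack it again on the other side:

```python
import random
from typing import Any, Dict

import numpy as np
import torch


def collect_rng_states_sketch() -> Dict[str, Any]:
    """Approximate stand-in for `_collect_rng_states` (assumption, simplified)."""
    return {
        "torch": torch.get_rng_state(),  # CPU default generator state
        "numpy": np.random.get_state(),  # legacy global NumPy generator
        "python": random.getstate(),     # builtin `random` module
    }


def set_rng_states_sketch(rng_states: Dict[str, Any]) -> None:
    """Approximate stand-in for `_set_rng_states` (assumption, simplified)."""
    torch.set_rng_state(rng_states["torch"])
    np.random.set_state(rng_states["numpy"])
    random.setstate(rng_states["python"])
```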
8 changes: 7 additions & 1 deletion src/pytorch_lightning/strategies/launchers/xla.py
@@ -21,7 +21,12 @@
from torch.multiprocessing import ProcessContext

import pytorch_lightning as pl
from pytorch_lightning.strategies.launchers.multiprocessing import _FakeQueue, _MultiProcessingLauncher, _WorkerOutput
from pytorch_lightning.strategies.launchers.multiprocessing import (
_FakeQueue,
_GlobalStateSnapshot,
_MultiProcessingLauncher,
_WorkerOutput,
)
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.apply_func import move_data_to_device
@@ -96,6 +101,7 @@ def _wrapping_function(
args: Any,
kwargs: Any,
return_queue: SimpleQueue,
global_states: Optional[_GlobalStateSnapshot] = None,
) -> None:
self._strategy._worker_setup(process_idx)
results = function(*args, **kwargs)
2 changes: 0 additions & 2 deletions src/pytorch_lightning/strategies/tpu_spawn.py
@@ -37,7 +37,6 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT

if _TPU_AVAILABLE:
@@ -206,7 +205,6 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[

def _worker_setup(self, process_idx: int):
self._launched = True
reset_seed()
self.set_world_ranks(process_idx)
rank_zero_only.rank = self.global_rank

43 changes: 40 additions & 3 deletions tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
@@ -15,19 +15,20 @@
from unittest.mock import ANY, Mock

import pytest
import torch

from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from pytorch_lightning.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher


@mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
def test_spawn_launcher_forking_on_unsupported_platform(_):
def test_multiprocessing_launcher_forking_on_unsupported_platform(_):
with pytest.raises(ValueError, match="The start method 'fork' is not available on this platform"):
_MultiProcessingLauncher(strategy=Mock(), start_method="fork")


@pytest.mark.parametrize("start_method", ["spawn", "fork"])
@mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp")
def test_spawn_launcher_start_method(mp_mock, start_method):
def test_multiprocessing_launcher_start_method(mp_mock, start_method):
mp_mock.get_all_start_methods.return_value = [start_method]
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
launcher.launch(function=Mock())
@@ -38,3 +39,39 @@ def test_spawn_launcher_start_method(mp_mock, start_method):
nprocs=ANY,
start_method=start_method,
)


@pytest.mark.parametrize("start_method", ["spawn", "fork"])
@mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp")
def test_multiprocessing_launcher_restore_globals(mp_mock, start_method):
"""Test that we pass the global state snapshot to the worker function only if we are starting with 'spawn'."""
mp_mock.get_all_start_methods.return_value = [start_method]
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
launcher.launch(function=Mock())
function_args = mp_mock.start_processes.call_args[1]["args"]
if start_method == "spawn":
assert len(function_args) == 6
assert isinstance(function_args[5], _GlobalStateSnapshot)
else:
assert len(function_args) == 5


def test_global_state_snapshot():
"""Test the capture() and restore() methods for the global state snapshot."""
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
torch.manual_seed(123)

# capture the state of globals
snapshot = _GlobalStateSnapshot.capture()

# simulate there is a process boundary and flags get reset here
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.benchmark = True
torch.manual_seed(321)

# restore the state of globals
snapshot.restore()
assert torch.are_deterministic_algorithms_enabled()
assert not torch.backends.cudnn.benchmark
assert torch.initial_seed() == 123
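Beyond these unit tests, the capture-in-parent / restore-in-worker pattern can be exercised end to end. A hedged sketch that uses the private `_GlobalStateSnapshot` API directly, mirroring what `_MultiProcessingLauncher` does internally (illustrative only; not part of the commit):

```python
import torch
import torch.multiprocessing as mp

from pytorch_lightning.strategies.launchers.multiprocessing import _GlobalStateSnapshot


def worker(rank: int, snapshot: _GlobalStateSnapshot, expected: torch.Tensor) -> None:
    snapshot.restore()  # re-applies deterministic flags, cudnn.benchmark and the RNG states
    assert torch.are_deterministic_algorithms_enabled()
    # The default generator continues exactly where the parent's snapshot left off.
    assert torch.equal(torch.rand(3), expected)


if __name__ == "__main__":
    torch.use_deterministic_algorithms(True)
    torch.manual_seed(123)
    snapshot = _GlobalStateSnapshot.capture()
    expected = torch.rand(3)  # what the worker should draw after restoring the same state
    mp.spawn(worker, args=(snapshot, expected), nprocs=1)
```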
