From eb233ea12dacf010edda49470396e85707d8c00e Mon Sep 17 00:00:00 2001
From: Adrian Wälchli
Date: Tue, 2 Aug 2022 00:21:46 +0200
Subject: [PATCH] Snapshot selected globals and restore them in spawned process (#13921)

Co-authored-by: Jirka Borovec
---
 src/pytorch_lightning/CHANGELOG.md            |  3 +
 src/pytorch_lightning/strategies/ddp_spawn.py |  2 -
 .../strategies/launchers/multiprocessing.py   | 65 ++++++++++++++++++-
 .../strategies/launchers/xla.py               |  8 ++-
 src/pytorch_lightning/strategies/tpu_spawn.py |  2 -
 .../launchers/test_multiprocessing.py         | 43 +++++++++++-
 6 files changed, 113 insertions(+), 10 deletions(-)

diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index ea649a9b65236..1516b74453842 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -396,6 +396,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue that caused the learning rate finder to set the model's learning rate to None when no suggestion was possible ([#13845](https://github.com/Lightning-AI/lightning/pull/13845))
 
 
+- Fixed an issue causing deterministic algorithms and other globals to get reset in spawned processes ([#13921](https://github.com/Lightning-AI/lightning/pull/13921))
+
+
 - Fixed default `amp_level` for `DeepSpeedPrecisionPlugin` to `O2` ([#13897](https://github.com/PyTorchLightning/pytorch-lightning/pull/13897))
 
 
diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py
index 6a3460febbf07..30bcef457c44a 100644
--- a/src/pytorch_lightning/strategies/ddp_spawn.py
+++ b/src/pytorch_lightning/strategies/ddp_spawn.py
@@ -50,7 +50,6 @@
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
-from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep
 
 log = logging.getLogger(__name__)
@@ -175,7 +174,6 @@ def set_world_ranks(self, process_idx: int = 0) -> None:
         rank_zero_only.rank = self.cluster_environment.global_rank()
 
     def _worker_setup(self, process_idx: int) -> None:
-        reset_seed()
         self.set_world_ranks(process_idx)
         rank_zero_only.rank = self.global_rank
         self._process_group_backend = self._get_process_group_backend()
diff --git a/src/pytorch_lightning/strategies/launchers/multiprocessing.py b/src/pytorch_lightning/strategies/launchers/multiprocessing.py
index 37e6c8d893150..91fa92b555ae0 100644
--- a/src/pytorch_lightning/strategies/launchers/multiprocessing.py
+++ b/src/pytorch_lightning/strategies/launchers/multiprocessing.py
@@ -13,11 +13,13 @@
 # limitations under the License.
 import os
 from collections import UserList
+from dataclasses import dataclass
 from multiprocessing.queues import SimpleQueue
-from typing import Any, Callable, NamedTuple, Optional
+from typing import Any, Callable, Dict, NamedTuple, Optional
 
 import numpy as np
 import torch
+import torch.backends.cudnn
 import torch.multiprocessing as mp
 from torch import Tensor
 from typing_extensions import Literal
@@ -27,7 +29,9 @@
 from pytorch_lightning.strategies.strategy import Strategy
 from pytorch_lightning.trainer.states import TrainerFn, TrainerState
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11
 from pytorch_lightning.utilities.rank_zero import rank_zero_debug
+from pytorch_lightning.utilities.seed import _collect_rng_states, _set_rng_states
 from pytorch_lightning.utilities.types import _PATH
 
 
@@ -89,9 +93,16 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
         os.environ["MASTER_PORT"] = str(self._strategy.cluster_environment.main_port)
         context = mp.get_context(self._start_method)
         return_queue = context.SimpleQueue()
+
+        if self._start_method == "spawn":
+            global_states = _GlobalStateSnapshot.capture()
+            process_args = [trainer, function, args, kwargs, return_queue, global_states]
+        else:
+            process_args = [trainer, function, args, kwargs, return_queue]
+
         mp.start_processes(
             self._wrapping_function,
-            args=(trainer, function, args, kwargs, return_queue),
+            args=process_args,
             nprocs=self._strategy.num_processes,
             start_method=self._start_method,
         )
@@ -110,7 +121,10 @@ def _wrapping_function(
         args: Any,
         kwargs: Any,
         return_queue: SimpleQueue,
+        global_states: Optional["_GlobalStateSnapshot"] = None,
     ) -> None:
+        if global_states:
+            global_states.restore()
         self._strategy._worker_setup(process_idx)
         results = function(*args, **kwargs)
 
@@ -209,3 +223,50 @@ class _WorkerOutput(NamedTuple):
     trainer_state: TrainerState
     trainer_results: Any
     extra: _FakeQueue
+
+
+@dataclass
+class _GlobalStateSnapshot:
+    """Captures a hand-selected set of (global) variables in modules and provides a way to restore them.
+
+    It facilitates and encapsulates the transfer of globals like PyTorch's deterministic flags or random generator
+    state across process boundaries when launching processes with :func:`torch.multiprocessing.spawn`.
+
+    Example:
+
+        .. code-block:: python
+
+            # in main process
+            snapshot = _GlobalStateSnapshot.capture()
+
+            # in worker process
+            snapshot.restore()
+    """
+
+    use_deterministic_algorithms: bool
+    use_deterministic_algorithms_warn_only: bool
+    cudnn_benchmark: bool
+    rng_states: Dict[str, Any]
+
+    @classmethod
+    def capture(cls) -> "_GlobalStateSnapshot":
+        """Capture a few global states from torch, numpy, etc., that we want to restore in a spawned worker
+        process."""
+        warn_only = torch.is_deterministic_algorithms_warn_only_enabled() if _TORCH_GREATER_EQUAL_1_11 else False
+        return cls(
+            use_deterministic_algorithms=torch.are_deterministic_algorithms_enabled(),
+            use_deterministic_algorithms_warn_only=warn_only,
+            cudnn_benchmark=torch.backends.cudnn.benchmark,
+            rng_states=_collect_rng_states(),
+        )
+
+    def restore(self) -> None:
+        """Restores all globals to the values captured in the :meth:`capture` method."""
+        if _TORCH_GREATER_EQUAL_1_11:
+            torch.use_deterministic_algorithms(
+                self.use_deterministic_algorithms, warn_only=self.use_deterministic_algorithms_warn_only
+            )
+        else:
+            torch.use_deterministic_algorithms(self.use_deterministic_algorithms)
+        torch.backends.cudnn.benchmark = self.cudnn_benchmark
+        _set_rng_states(self.rng_states)
diff --git a/src/pytorch_lightning/strategies/launchers/xla.py b/src/pytorch_lightning/strategies/launchers/xla.py
index 037ec027bfd7d..064d952f71a8f 100644
--- a/src/pytorch_lightning/strategies/launchers/xla.py
+++ b/src/pytorch_lightning/strategies/launchers/xla.py
@@ -21,7 +21,12 @@
 from torch.multiprocessing import ProcessContext
 
 import pytorch_lightning as pl
-from pytorch_lightning.strategies.launchers.multiprocessing import _FakeQueue, _MultiProcessingLauncher, _WorkerOutput
+from pytorch_lightning.strategies.launchers.multiprocessing import (
+    _FakeQueue,
+    _GlobalStateSnapshot,
+    _MultiProcessingLauncher,
+    _WorkerOutput,
+)
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.apply_func import move_data_to_device
@@ -96,6 +101,7 @@ def _wrapping_function(
         args: Any,
         kwargs: Any,
         return_queue: SimpleQueue,
+        global_states: Optional[_GlobalStateSnapshot] = None,
     ) -> None:
         self._strategy._worker_setup(process_idx)
         results = function(*args, **kwargs)
diff --git a/src/pytorch_lightning/strategies/tpu_spawn.py b/src/pytorch_lightning/strategies/tpu_spawn.py
index 2d474fafe51b1..4d20e784e0d29 100644
--- a/src/pytorch_lightning/strategies/tpu_spawn.py
+++ b/src/pytorch_lightning/strategies/tpu_spawn.py
@@ -37,7 +37,6 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_only
-from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
 
 if _TPU_AVAILABLE:
@@ -206,7 +205,6 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[
 
     def _worker_setup(self, process_idx: int):
         self._launched = True
-        reset_seed()
         self.set_world_ranks(process_idx)
         rank_zero_only.rank = self.global_rank
 
diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
index 2a5fe82928a67..ad3e891ad607f 100644
--- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
+++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
@@ -15,19 +15,20 @@
 from unittest.mock import ANY, Mock
 
 import pytest
+import torch
 
-from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher
+from pytorch_lightning.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
 
 
 @mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
-def test_spawn_launcher_forking_on_unsupported_platform(_):
+def test_multiprocessing_launcher_forking_on_unsupported_platform(_):
     with pytest.raises(ValueError, match="The start method 'fork' is not available on this platform"):
         _MultiProcessingLauncher(strategy=Mock(), start_method="fork")
 
 
 @pytest.mark.parametrize("start_method", ["spawn", "fork"])
 @mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp")
-def test_spawn_launcher_start_method(mp_mock, start_method):
+def test_multiprocessing_launcher_start_method(mp_mock, start_method):
     mp_mock.get_all_start_methods.return_value = [start_method]
     launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
     launcher.launch(function=Mock())
@@ -38,3 +39,39 @@ def test_spawn_launcher_start_method(mp_mock, start_method):
         nprocs=ANY,
         start_method=start_method,
     )
+
+
+@pytest.mark.parametrize("start_method", ["spawn", "fork"])
+@mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp")
+def test_multiprocessing_launcher_restore_globals(mp_mock, start_method):
+    """Test that we pass the global state snapshot to the worker function only if we are starting with 'spawn'."""
+    mp_mock.get_all_start_methods.return_value = [start_method]
+    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
+    launcher.launch(function=Mock())
+    function_args = mp_mock.start_processes.call_args[1]["args"]
+    if start_method == "spawn":
+        assert len(function_args) == 6
+        assert isinstance(function_args[5], _GlobalStateSnapshot)
+    else:
+        assert len(function_args) == 5
+
+
+def test_global_state_snapshot():
+    """Test the capture() and restore() methods for the global state snapshot."""
+    torch.use_deterministic_algorithms(True)
+    torch.backends.cudnn.benchmark = False
+    torch.manual_seed(123)
+
+    # capture the state of globals
+    snapshot = _GlobalStateSnapshot.capture()
+
+    # simulate there is a process boundary and flags get reset here
+    torch.use_deterministic_algorithms(False)
+    torch.backends.cudnn.benchmark = True
+    torch.manual_seed(321)
+
+    # restore the state of globals
+    snapshot.restore()
+    assert torch.are_deterministic_algorithms_enabled()
+    assert not torch.backends.cudnn.benchmark
+    assert torch.initial_seed() == 123
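
Note on the pattern: with the "spawn" start method each worker begins in a fresh interpreter, so module-level settings such as torch.use_deterministic_algorithms(), torch.backends.cudnn.benchmark, and the RNG seeds revert to their defaults; the patch therefore captures these globals in the main process and restores them in the worker before any user code runs. The sketch below mirrors that idea outside of Lightning. It is a minimal illustration only: the helper names capture_globals, restore_globals, and worker are hypothetical and not part of the patch, and it only snapshots the CPU torch generator, whereas the patch's _GlobalStateSnapshot also carries the Python and NumPy generators via _collect_rng_states/_set_rng_states and handles the warn_only flag on torch >= 1.11.

    import torch
    import torch.multiprocessing as mp


    def capture_globals() -> dict:
        # Snapshot the globals that a freshly spawned interpreter would reset to defaults.
        return {
            "deterministic": torch.are_deterministic_algorithms_enabled(),
            "cudnn_benchmark": torch.backends.cudnn.benchmark,
            "rng_state": torch.get_rng_state(),
        }


    def restore_globals(snapshot: dict) -> None:
        # Re-apply the captured values inside the worker before any user code runs.
        torch.use_deterministic_algorithms(snapshot["deterministic"])
        torch.backends.cudnn.benchmark = snapshot["cudnn_benchmark"]
        torch.set_rng_state(snapshot["rng_state"])


    def worker(process_idx: int, snapshot: dict) -> None:
        restore_globals(snapshot)
        # The flag set in the parent is visible again, which would not be the case otherwise.
        assert torch.are_deterministic_algorithms_enabled()


    if __name__ == "__main__":
        torch.use_deterministic_algorithms(True)
        torch.manual_seed(123)
        # The snapshot is pickled and sent to each worker as an ordinary spawn argument.
        mp.spawn(worker, args=(capture_globals(),), nprocs=2)

Because "fork" copies the parent's memory, the snapshot is only needed (and only passed) for "spawn"; this is what test_multiprocessing_launcher_restore_globals asserts by expecting six worker arguments under spawn and five under fork.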