Skip to content

Commit

Permalink
Disable memory sharing on model parameters in ddp-spawn (#18238)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 15, 2023
1 parent 0d1932c commit a0ca2c8
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 6 deletions.
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with `Fabric.all_reduce()` not performing an inplace operation for all backends consistently ([#18235](https://github.com/Lightning-AI/lightning/pull/18235))


- Fixed model parameters getting shared between processes when running with `strategy="ddp_spawn"` and `accelerator="cpu"`; this necessarily increases memory usage, since the parameters are now replicated in each process ([#18238](https://github.com/Lightning-AI/lightning/pull/18238))


## [2.0.5] - 2023-07-07

### Added
Expand Down
26 changes: 26 additions & 0 deletions src/lightning/fabric/strategies/launchers/multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import os
from dataclasses import dataclass
from multiprocessing.queues import SimpleQueue
Expand All @@ -19,7 +20,10 @@
import torch
import torch.backends.cudnn
import torch.multiprocessing as mp
from lightning_utilities import apply_to_collection
from torch.nn import Module

from lightning.fabric.accelerators.cpu import CPUAccelerator
from lightning.fabric.strategies.launchers.launcher import _Launcher
from lightning.fabric.utilities.apply_func import move_data_to_device
from lightning.fabric.utilities.imports import _IS_INTERACTIVE
Expand Down Expand Up @@ -122,6 +126,10 @@ def _wrapping_function(
) -> None:
if global_states:
global_states.restore()

if self._start_method == "spawn" and isinstance(self._strategy.accelerator, CPUAccelerator):
args, kwargs = _disable_module_memory_sharing((args, kwargs))

os.environ["LOCAL_RANK"] = str(process_idx)
results = function(*args, **kwargs)

Expand Down Expand Up @@ -190,3 +198,21 @@ def _check_bad_cuda_fork() -> None:
if _IS_INTERACTIVE:
message += " You will have to restart the Python kernel."
raise RuntimeError(message)


def _disable_module_memory_sharing(data: Any) -> Any:
    """Give every `nn.Module` contained in the given collection private copies of its parameters and buffers.

    PyTorch automatically enables memory sharing on all tensors that are passed through `mp.spawn`. For model
    weights and buffers this is undesired, because it can lead to race conditions between processes. Cloning the
    data of each tensor ensures the module no longer shares memory with other processes.

    Note: This is only required when running on CPU.

    Args:
        data: The collection to process. Modules are selected through `apply_to_collection(..., dtype=Module)`.

    Returns:
        The result of `apply_to_collection`; each visited module is modified in place and handed back as-is.
    """

    def _clone_tensor_data(module: Module) -> Module:
        # Only `.data` is rebound, so the parameter/buffer objects themselves (and therefore any weight tying
        # that relies on two attributes referring to the same object) are kept.
        with torch.no_grad():
            for tensor in itertools.chain(module.parameters(), module.buffers()):
                tensor.data = tensor.data.clone()
        return module

    return apply_to_collection(data, dtype=Module, function=_clone_tensor_data)
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue that could cause the `LightningOptimizer` wrapper returned by `LightningModule.optimizers()` to have a different internal state than the optimizer it wraps ([#18280](https://github.com/Lightning-AI/lightning/pull/18280))


- Fixed model parameters getting shared between processes when running with `strategy="ddp_spawn"` and `accelerator="cpu"`; this necessarily increases memory usage, since the parameters are now replicated in each process ([#18238](https://github.com/Lightning-AI/lightning/pull/18238))


## [2.0.5] - 2023-07-07

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@
from torch import Tensor

import lightning.pytorch as pl
from lightning.fabric.strategies.launchers.multiprocessing import _check_bad_cuda_fork
from lightning.fabric.strategies.launchers.multiprocessing import _check_bad_cuda_fork, _disable_module_memory_sharing
from lightning.fabric.utilities import move_data_to_device
from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.accelerators import CPUAccelerator
from lightning.pytorch.strategies.launchers.launcher import _Launcher
from lightning.pytorch.trainer.connectors.signal_connector import _SIGNUM
from lightning.pytorch.trainer.states import TrainerFn, TrainerState
Expand Down Expand Up @@ -144,6 +145,9 @@ def _wrapping_function(
) -> None:
if global_states:
global_states.restore()
if self._start_method == "spawn" and isinstance(self._strategy.accelerator, CPUAccelerator):
args, kwargs = _disable_module_memory_sharing((args, kwargs))

os.environ["LOCAL_RANK"] = str(process_idx)
results = function(*args, **kwargs)

Expand Down
10 changes: 5 additions & 5 deletions tests/tests_fabric/strategies/launchers/test_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,20 @@

@RunIf(skip_windows=True)
@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
def test_multiprocessing_launcher_interactive_compatible(start_method):
def test_interactive_compatible(start_method):
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
assert launcher.is_interactive_compatible == (start_method == "fork")


@mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
def test_multiprocessing_launcher_forking_on_unsupported_platform(_):
def test_forking_on_unsupported_platform(_):
with pytest.raises(ValueError, match="The start method 'fork' is not available on this platform"):
_MultiProcessingLauncher(strategy=Mock(), start_method="fork")


@pytest.mark.parametrize("start_method", ["spawn", pytest.param("fork", marks=RunIf(standalone=True))])
@mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
def test_multiprocessing_launcher_start_method(mp_mock, start_method):
def test_start_method(mp_mock, start_method):
mp_mock.get_all_start_methods.return_value = [start_method]
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
launcher.launch(function=Mock())
Expand All @@ -51,7 +51,7 @@ def test_multiprocessing_launcher_start_method(mp_mock, start_method):

@pytest.mark.parametrize("start_method", ["spawn", pytest.param("fork", marks=RunIf(standalone=True))])
@mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
def test_multiprocessing_launcher_restore_globals(mp_mock, start_method):
def test_restore_globals(mp_mock, start_method):
"""Test that we pass the global state snapshot to the worker function only if we are starting with 'spawn'."""
mp_mock.get_all_start_methods.return_value = [start_method]
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
Expand Down Expand Up @@ -89,7 +89,7 @@ def test_global_state_snapshot():
@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
@mock.patch("torch.cuda.is_initialized", return_value=True)
@mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
def test_multiprocessing_launcher_check_for_bad_cuda_fork(mp_mock, _, start_method):
def test_check_for_bad_cuda_fork(mp_mock, _, start_method):
mp_mock.get_all_start_methods.return_value = [start_method]
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
with pytest.raises(RuntimeError, match="Lightning can't create new processes if CUDA is already initialized"):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
import torch.nn as nn

from lightning.fabric import Fabric
from tests_fabric.helpers.runif import RunIf


class SimpleModel(nn.Module):
    """Minimal module with two linear layers that share one weight parameter, plus a registered buffer."""

    def __init__(self):
        super().__init__()
        primary = nn.Linear(2, 2)
        secondary = nn.Linear(2, 2)
        # Tie the weights: both layers refer to the very same parameter object.
        secondary.weight = primary.weight

        self.layer = primary
        self.tied_layer = secondary
        self.register_buffer("buffer", torch.ones(3))


@pytest.mark.parametrize(
    "strategy",
    [
        "ddp_spawn",
        pytest.param("ddp_fork", marks=RunIf(skip_windows=True)),
    ],
)
def test_memory_sharing_disabled(strategy):
    """Check that the multiprocessing launcher un-shares model parameters and buffers.

    Sharing them between processes would allow race conditions on model updates. The per-process checks are
    done by the function that gets launched with two CPU processes.
    """
    tensor = torch.rand(4)
    model = SimpleModel()

    # Before launching, nothing lives in shared memory and the two layers use the same weight storage.
    assert not tensor.is_shared()
    primary_weight = model.layer.weight
    assert not primary_weight.is_shared()
    assert primary_weight.data_ptr() == model.tied_layer.weight.data_ptr()

    Fabric(accelerator="cpu", devices=2, strategy=strategy).launch(_test_memory_sharing_disabled, tensor, model)


def _test_memory_sharing_disabled(fabric, tensor, model):
    """Per-process checks: the plain tensor may be shared, the model's parameters and buffers must not be."""
    launched_with_spawn = fabric.strategy.launcher._start_method == "spawn"
    if launched_with_spawn:
        # A plain tensor argument is expected to be in shared memory when the start method is "spawn".
        assert tensor.is_shared()

    weight = model.layer.weight
    tied_weight = model.tied_layer.weight
    assert not weight.is_shared()
    assert not tied_weight.is_shared()
    assert not model.buffer.is_shared()

    # weights remain tied
    assert weight.data_ptr() == tied_weight.data_ptr()
    assert torch.equal(weight.data, tied_weight.data)
    fabric.barrier()
28 changes: 28 additions & 0 deletions tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,31 @@ def test_kill():
with patch("os.kill") as kill_patch:
launcher.kill(15)
assert kill_patch.mock_calls == [call(proc0.pid, 15), call(proc1.pid, 15)]


class SimpleModel(BoringModel):
    """`BoringModel` extended with a weight-tied linear layer and a buffer; verifies un-sharing in `on_fit_start`."""

    def __init__(self):
        super().__init__()
        tied = torch.nn.Linear(32, 2)
        # Tie the new layer's weight to the weight of the existing `layer`.
        tied.weight = self.layer.weight
        self.tied_layer = tied
        self.register_buffer("buffer", torch.ones(3))

    def on_fit_start(self) -> None:
        weight = self.layer.weight
        tied_weight = self.tied_layer.weight

        # None of the parameters or buffers may be in shared memory at this point.
        assert not weight.is_shared()
        assert not tied_weight.is_shared()
        assert not self.buffer.is_shared()

        # weights remain tied
        assert weight.data_ptr() == tied_weight.data_ptr()
        assert torch.equal(weight.data, tied_weight.data)


def test_memory_sharing_disabled():
    """Check that the multiprocessing launcher un-shares model parameters and buffers.

    Sharing them between processes would allow race conditions on model updates. The checks inside the worker
    processes are performed by the model's `on_fit_start` hook.
    """
    model = SimpleModel()
    layer_weight = model.layer.weight
    assert not layer_weight.is_shared()
    assert layer_weight.data_ptr() == model.tied_layer.weight.data_ptr()

    trainer_kwargs = {"accelerator": "cpu", "devices": 2, "strategy": "ddp_spawn", "max_steps": 0}
    Trainer(**trainer_kwargs).fit(model)

0 comments on commit a0ca2c8

Please sign in to comment.