Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable memory sharing on model parameters in ddp-spawn #18238

Merged
merged 29 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
52feb38
disable
awaelchli Aug 6, 2023
31beadf
test
awaelchli Aug 6, 2023
c7f6ec2
update
awaelchli Aug 6, 2023
0210dca
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2023
5078ce0
update
awaelchli Aug 6, 2023
a3c0e3f
update test
awaelchli Aug 6, 2023
b8bb080
Merge remote-tracking branch 'origin/bugfix/tensor-memory-sharing' in…
awaelchli Aug 6, 2023
42a2172
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2023
0401dc8
update
awaelchli Aug 6, 2023
0592595
handle tied weights WIP
awaelchli Aug 7, 2023
d10e1c4
test
awaelchli Aug 8, 2023
b181fa7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 8, 2023
a9aec68
Merge branch 'master' into bugfix/tensor-memory-sharing
awaelchli Aug 9, 2023
7fee676
typo
awaelchli Aug 9, 2023
a43e852
barrier
awaelchli Aug 9, 2023
4753f62
update
awaelchli Aug 10, 2023
a0e9555
update
awaelchli Aug 10, 2023
be2f424
update test
awaelchli Aug 10, 2023
7753cbe
fixes
awaelchli Aug 13, 2023
8ae4074
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2023
2123297
add test
awaelchli Aug 13, 2023
0899bcc
chlog
awaelchli Aug 13, 2023
e1217bf
reset
awaelchli Aug 13, 2023
1a5f402
simplify
awaelchli Aug 13, 2023
c88c74b
Apply suggestions from code review
awaelchli Aug 13, 2023
68587d3
Merge branch 'master' into bugfix/tensor-memory-sharing
awaelchli Aug 13, 2023
cebedd2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2023
3dc048b
Apply suggestions from code review
awaelchli Aug 14, 2023
75f1396
add note about memory
awaelchli Aug 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with `Fabric.all_reduce()` not performing an inplace operation for all backends consistently ([#18235](https://github.com/Lightning-AI/lightning/pull/18235))


- Fixed model parameters getting shared between processes when running with `strategy="ddp_spawn` and `accelerator="cpu"` ([#18238](https://github.com/Lightning-AI/lightning/pull/18238))
awaelchli marked this conversation as resolved.
Show resolved Hide resolved


## [2.0.5] - 2023-07-07

### Added
Expand Down
26 changes: 26 additions & 0 deletions src/lightning/fabric/strategies/launchers/multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import os
from dataclasses import dataclass
from multiprocessing.queues import SimpleQueue
Expand All @@ -19,7 +20,10 @@
import torch
import torch.backends.cudnn
import torch.multiprocessing as mp
from lightning_utilities import apply_to_collection
from torch.nn import Module

from lightning.fabric.accelerators.cpu import CPUAccelerator
from lightning.fabric.strategies.launchers.launcher import _Launcher
from lightning.fabric.utilities.apply_func import move_data_to_device
from lightning.fabric.utilities.imports import _IS_INTERACTIVE
Expand Down Expand Up @@ -122,6 +126,10 @@ def _wrapping_function(
) -> None:
if global_states:
global_states.restore()

if self._start_method == "spawn" and isinstance(self._strategy.accelerator, CPUAccelerator):
args, kwargs = _disable_module_memory_sharing((args, kwargs))

os.environ["LOCAL_RANK"] = str(process_idx)
results = function(*args, **kwargs)

Expand Down Expand Up @@ -190,3 +198,21 @@ def _check_bad_cuda_fork() -> None:
if _IS_INTERACTIVE:
message += " You will have to restart the Python kernel."
raise RuntimeError(message)


def _disable_module_memory_sharing(data: Any) -> Any:
"""Disables memory sharing on parameters and buffers of `nn.Module`s contained in the given collection.

Note: This is only required when running on CPU.
"""
# PyTorch enables memory sharing automatically on all tensors that are passed through `mp.spawn`.
# For model weights and buffers, this is undesired and can lead to race conditions between processes.
# Hence, we copy the tensors in the entire module to ensure it doesn't share memory with other processes.

@torch.no_grad()
def _unshare(module: Module) -> Module:
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
for tensor in itertools.chain(module.parameters(), module.buffers()):
tensor.data = tensor.data.clone()
return module

return apply_to_collection(data, function=_unshare, dtype=Module)
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue that would prevent the user to set the multiprocessing start method after importing lightning ([#18177](https://github.com/Lightning-AI/lightning/pull/18177))


- Fixed model parameters getting shared between processes when running with `strategy="ddp_spawn` and `accelerator="cpu"` ([#18238](https://github.com/Lightning-AI/lightning/pull/18238))
awaelchli marked this conversation as resolved.
Show resolved Hide resolved


## [2.0.5] - 2023-07-07

### Fixed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@
from torch import Tensor

import lightning.pytorch as pl
from lightning.fabric.strategies.launchers.multiprocessing import _check_bad_cuda_fork
from lightning.fabric.strategies.launchers.multiprocessing import _check_bad_cuda_fork, _disable_module_memory_sharing
from lightning.fabric.utilities import move_data_to_device
from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.accelerators import CPUAccelerator
from lightning.pytorch.strategies.launchers.launcher import _Launcher
from lightning.pytorch.trainer.connectors.signal_connector import _SIGNUM
from lightning.pytorch.trainer.states import TrainerFn, TrainerState
Expand Down Expand Up @@ -144,6 +145,9 @@ def _wrapping_function(
) -> None:
if global_states:
global_states.restore()
if self._start_method == "spawn" and isinstance(self._strategy.accelerator, CPUAccelerator):
args, kwargs = _disable_module_memory_sharing((args, kwargs))

os.environ["LOCAL_RANK"] = str(process_idx)
results = function(*args, **kwargs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,20 @@

@RunIf(skip_windows=True)
@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
def test_multiprocessing_launcher_interactive_compatible(start_method):
def test_interactive_compatible(start_method):
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
assert launcher.is_interactive_compatible == (start_method == "fork")


@mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
def test_multiprocessing_launcher_forking_on_unsupported_platform(_):
def test_forking_on_unsupported_platform(_):
with pytest.raises(ValueError, match="The start method 'fork' is not available on this platform"):
_MultiProcessingLauncher(strategy=Mock(), start_method="fork")


@pytest.mark.parametrize("start_method", ["spawn", pytest.param("fork", marks=RunIf(standalone=True))])
@mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
def test_multiprocessing_launcher_start_method(mp_mock, start_method):
def test_start_method(mp_mock, start_method):
mp_mock.get_all_start_methods.return_value = [start_method]
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
launcher.launch(function=Mock())
Expand All @@ -51,7 +51,7 @@ def test_multiprocessing_launcher_start_method(mp_mock, start_method):

@pytest.mark.parametrize("start_method", ["spawn", pytest.param("fork", marks=RunIf(standalone=True))])
@mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
def test_multiprocessing_launcher_restore_globals(mp_mock, start_method):
def test_restore_globals(mp_mock, start_method):
"""Test that we pass the global state snapshot to the worker function only if we are starting with 'spawn'."""
mp_mock.get_all_start_methods.return_value = [start_method]
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
Expand Down Expand Up @@ -89,7 +89,7 @@ def test_global_state_snapshot():
@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
@mock.patch("torch.cuda.is_initialized", return_value=True)
@mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
def test_multiprocessing_launcher_check_for_bad_cuda_fork(mp_mock, _, start_method):
def test_check_for_bad_cuda_fork(mp_mock, _, start_method):
mp_mock.get_all_start_methods.return_value = [start_method]
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
with pytest.raises(RuntimeError, match="Lightning can't create new processes if CUDA is already initialized"):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
import torch.nn as nn

from lightning.fabric import Fabric
from tests_fabric.helpers.runif import RunIf


class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.layer = nn.Linear(2, 2)
self.tied_layer = nn.Linear(2, 2)
self.tied_layer.weight = self.layer.weight
self.register_buffer("buffer", torch.ones(3))


@pytest.mark.parametrize("strategy", ["ddp_spawn", pytest.param("ddp_fork", marks=RunIf(skip_windows=True))])
def test_memory_sharing_disabled(strategy):
"""Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race
conditions on model updates."""
tensor = torch.rand(4)
model = SimpleModel()
assert not tensor.is_shared()
assert not model.layer.weight.is_shared()
assert model.layer.weight.data_ptr() == model.tied_layer.weight.data_ptr()

fabric = Fabric(accelerator="cpu", devices=2, strategy=strategy)
fabric.launch(_test_memory_sharing_disabled, tensor, model)


def _test_memory_sharing_disabled(fabric, tensor, model):
is_spawn = fabric.strategy.launcher._start_method == "spawn"
assert not is_spawn or tensor.is_shared()
assert not model.layer.weight.is_shared()
assert not model.tied_layer.weight.is_shared()
assert not model.buffer.is_shared()

# weights remain tied
assert model.layer.weight.data_ptr() == model.tied_layer.weight.data_ptr()
assert torch.equal(model.layer.weight.data, model.tied_layer.weight.data)
fabric.barrier()
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@
from lightning.fabric.plugins import ClusterEnvironment
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.accelerators import CPUAccelerator
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
from lightning.pytorch.trainer.states import TrainerFn
from tests_pytorch.helpers.runif import RunIf



@mock.patch("lightning.pytorch.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
def test_multiprocessing_launcher_forking_on_unsupported_platform(_):
with pytest.raises(ValueError, match="The start method 'fork' is not available on this platform"):
Expand Down Expand Up @@ -175,3 +177,31 @@ def test_kill():
with patch("os.kill") as kill_patch:
launcher.kill(15)
assert kill_patch.mock_calls == [call(proc0.pid, 15), call(proc1.pid, 15)]


class SimpleModel(BoringModel):
def __init__(self):
super().__init__()
self.tied_layer = torch.nn.Linear(32, 2)
self.tied_layer.weight = self.layer.weight
self.register_buffer("buffer", torch.ones(3))

def on_fit_start(self) -> None:
assert not self.layer.weight.is_shared()
assert not self.tied_layer.weight.is_shared()
assert not self.buffer.is_shared()

# weights remain tied
assert self.layer.weight.data_ptr() == self.tied_layer.weight.data_ptr()
assert torch.equal(self.layer.weight.data, self.tied_layer.weight.data)


def test_memory_sharing_disabled():
"""Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race
conditions on model updates."""
model = SimpleModel()
assert not model.layer.weight.is_shared()
assert model.layer.weight.data_ptr() == model.tied_layer.weight.data_ptr()

trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp_spawn", max_steps=0)
trainer.fit(model)