diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0f96664db72b0..df7e1bdd8188f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -240,6 +240,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `AttributeError for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))
 
 
+- Fixed `self.device` not returning the correct device in replicas of data-parallel ([#6414](https://github.com/PyTorchLightning/pytorch-lightning/pull/6414))
+
+
 ## [1.2.7] - 2021-04-06
 
 ### Fixed
diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py
index b027502f99e8a..a9d312e9f6417 100644
--- a/pytorch_lightning/overrides/data_parallel.py
+++ b/pytorch_lightning/overrides/data_parallel.py
@@ -22,6 +22,7 @@
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
 from pytorch_lightning.overrides.distributed import LightningDistributedModule
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 
 
@@ -71,6 +72,8 @@ def __init__(self, pl_module: LightningModule):
         super().__init__(pl_module)
 
     def forward(self, *inputs, **kwargs):
+        self.update_replica_device_attributes(inputs)
+
         # forward call will redirect to training_step, validation_step, etc.
         output = super().forward(*inputs, **kwargs)
         def output_transform(data: Any):
@@ -85,6 +88,37 @@ def output_transform(data: Any):
         )
         return output
+    def update_replica_device_attributes(self, inputs: Any) -> None:
+        """
+        Updates the device information of LightningModule by reading the device from the inputs.
+        In :class:`~torch.nn.data_parallel.DataParallel` changes to the state during the `forward` pass
+        are lost when the replicas get discarded. The only way to know the current device is from the
+        inputs passed into the model.
+
+        Args:
+            inputs: A collection of inputs (typically a tuple). If the inputs don't contain tensors,
+                a warning is shown that accessing ``self.device`` will not return the correct device.
+        """
+        replica_device = None
+
+        def find_tensor_with_device(tensor: torch.Tensor) -> torch.Tensor:
+            nonlocal replica_device
+            if replica_device is None and tensor.device != torch.device("cpu"):
+                replica_device = tensor.device
+            return tensor
+
+        apply_to_collection(inputs, dtype=torch.Tensor, function=find_tensor_with_device)
+
+        if replica_device is not None:
+            # by calling .to() we force the update to the self.device property
+            self.module.to(device=replica_device)
+        else:
+            rank_zero_warn(
+                "Could not determine on which device the inputs are."
+                " When using DataParallel (accelerator='dp'), be aware that in case you are using self.device"
+                " in your code, it will reference only the root device."
+            )
+
 def python_scalar_to_tensor(data: Any, device: torch.device = torch.device("cpu")) -> Any:
     """
     Converts a Python scalar number to a torch tensor and places it on the given device.
     """
diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py
index aaf47c82d5f08..f9440afc59767 100644
--- a/tests/overrides/test_data_parallel.py
+++ b/tests/overrides/test_data_parallel.py
@@ -2,8 +2,11 @@
 
 import pytest
 import torch
+import torch.nn as nn
 from torch.nn import DataParallel
 
+from pytorch_lightning import LightningModule
+from pytorch_lightning.core.decorators import auto_move_data
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.data_parallel import (
     LightningParallelModule,
@@ -123,3 +126,68 @@ def training_step(self, batch, batch_idx):
     wrapped_model = LightningParallelModule(model)
     output = wrapped_model(batch, batch_idx)
     assert output["python scalar"] == torch.tensor([12.3], device=device)
+
+
+@RunIf(min_gpus=2)
+@pytest.mark.parametrize(
+    "nest, unnest", [
+        (lambda x: x, lambda x: x),
+        (lambda x: dict(data=x), lambda x: x["data"]),
+        (lambda x: [x, (x, x)], lambda x: x[1][0]),
+    ]
+)
+def test_lightning_parallel_module_device_access(nest, unnest):
+    """ Test that self.device returns the correct value in replicas of DataParallel. """
+
+    class DeviceAccessModel(LightningModule):
+
+        def __init__(self):
+            super().__init__()
+            self.layer = nn.Linear(2, 3)
+
+        @auto_move_data
+        def training_step(self, batch, batch_idx):
+            batch = unnest(batch)
+            assert batch.shape == torch.Size([1, 1])
+            assert self.device.index == batch.item()
+            assert self.device == self.layer.weight.device
+            return torch.tensor(1, device=self.device)
+
+    pl_module = DeviceAccessModel()
+    # required for redirecting the forward call to training_step
+    pl_module.trainer = Mock()
+    pl_module.trainer._running_stage = RunningStage.TRAINING
+
+    root_device = torch.device("cuda", 0)
+    wrapped_module = LightningParallelModule(pl_module).to(root_device)
+    model = DataParallel(wrapped_module, device_ids=[0, 1])
+
+    data = torch.tensor([0.0, 1.0], device=root_device).view(2, 1)  # one value per gpu
+    data = data.to(root_device)
+    data = nest(data)
+    output = model(data, 0)
+    assert output.device == root_device
+    assert pl_module.device == root_device
+    assert torch.all(output.cpu().eq(torch.tensor([1, 1])))
+
+
+@RunIf(min_gpus=2)
+def test_lightning_parallel_module_device_access_warning():
+    """ Test that we show a warning when the device can't be inferred from the input. """
+
+    class DeviceAccessModel(LightningModule):
+
+        def training_step(self, batch, batch_idx):
+            pass
+
+    pl_module = DeviceAccessModel()
+    # required for redirecting the forward call to training_step
+    pl_module.trainer = Mock()
+    pl_module.trainer._running_stage = RunningStage.TRAINING
+
+    wrapped_module = LightningParallelModule(pl_module).cuda()
+    model = DataParallel(wrapped_module, device_ids=[0, 1])
+
+    data = dict(x=1)  # contains no tensors
+    with pytest.warns(UserWarning, match="Could not determine on which device the inputs are."):
+        _ = model(data, 0)