
Commit c53edce

rohitgr7, carmocca, and tchaton authored
Disable batch transfer in DP mode (#6098)
* add exceptions and test
* hook
* fix
* clean up
* clean up
* regex
* regex
* docs
* rev
* comment and docs
* chlog
* Apply suggestions from code review
  Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
* Apply suggestions from code review
  Co-authored-by: chaton <thomas@grid.ai>
* Monkey-patch device count
* docs
* pep
* api_change

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: chaton <thomas@grid.ai>
1 parent e886d55 commit c53edce

File tree

5 files changed (+99, -20 lines)

CHANGELOG.md
pytorch_lightning/accelerators/gpu.py
pytorch_lightning/core/hooks.py
pytorch_lightning/trainer/connectors/data_connector.py
tests/accelerators/test_dp.py

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -184,6 +184,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed error message for AMP + CPU incompatibility ([#6107](https://github.com/PyTorchLightning/pytorch-lightning/pull/6107))
 
 
+- Disabled batch transfer in DP mode ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093))
+
+
 ## [1.2.0] - 2021-02-18
 
 ### Added

pytorch_lightning/accelerators/gpu.py

Lines changed: 10 additions & 1 deletion

@@ -1,10 +1,11 @@
 import logging
 import os
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 import torch
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.plugins import DataParallelPlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if TYPE_CHECKING:
@@ -48,3 +49,11 @@ def set_nvidia_flags() -> None:
         all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())])
         devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
         _log.info(f"LOCAL_RANK: {os.getenv('LOCAL_RANK', 0)} - CUDA_VISIBLE_DEVICES: [{devices}]")
+
+    def to_device(self, batch: Any) -> Any:
+        # no need to transfer batch to device in DP mode
+        # TODO: Add support to allow batch transfer to device in Lightning for DP mode.
+        if not isinstance(self.training_type_plugin, DataParallelPlugin):
+            batch = super().to_device(batch)
+
+        return batch
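
For orientation, here is a minimal, self-contained sketch of the control flow the new `GPUAccelerator.to_device` override introduces. This is not the Lightning source; the plugin and accelerator classes below are simplified stand-ins. The point is the guard: the batch transfer is skipped whenever the training type plugin is the DP plugin, because `torch.nn.DataParallel` scatters inputs to the replicas itself.

# Minimal stand-in sketch of the guard added above -- not the Lightning source.
class TrainingTypePlugin:
    """Stand-in for the base training type plugin."""


class DataParallelPlugin(TrainingTypePlugin):
    """Stand-in for the DP plugin."""


class Accelerator:
    def __init__(self, training_type_plugin: TrainingTypePlugin) -> None:
        self.training_type_plugin = training_type_plugin

    def to_device(self, batch):
        # Base behaviour: move the batch to the accelerator's root device.
        return f"{batch} (moved to device)"


class GPUAccelerator(Accelerator):
    def to_device(self, batch):
        # Skip the transfer in DP mode; nn.DataParallel scatters inputs itself.
        if not isinstance(self.training_type_plugin, DataParallelPlugin):
            batch = super().to_device(batch)
        return batch


print(GPUAccelerator(DataParallelPlugin()).to_device("batch"))   # batch
print(GPUAccelerator(TrainingTypePlugin()).to_device("batch"))   # batch (moved to device)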

pytorch_lightning/core/hooks.py

Lines changed: 16 additions & 5 deletions

@@ -615,10 +615,7 @@ def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] =
 
         Note:
             This hook only runs on single GPU training and DDP (no data-parallel).
-            If you need multi-GPU support for your custom batch objects, you need to define your custom
-            :class:`~torch.nn.parallel.DistributedDataParallel` or
-            :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
-            override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.
+            Data-Parallel support will come in near future.
 
         Args:
             batch: A batch of data that needs to be transferred to a new device.
@@ -638,6 +635,10 @@ def transfer_batch_to_device(self, batch, device):
                    batch = super().transfer_batch_to_device(data, device)
                return batch
 
+        Raises:
+            MisconfigurationException:
+                If using data-parallel, ``Trainer(accelerator='dp')``.
+
         See Also:
             - :meth:`move_data_to_device`
             - :meth:`apply_to_collection`
@@ -649,10 +650,11 @@ def on_before_batch_transfer(self, batch, dataloader_idx):
         """
         Override to alter or apply batch augmentations to your batch before it is transferred to the device.
 
-        .. warning:: dataloader_idx always returns 0, and will be updated to support the true idx in the future.
+        .. warning:: ``dataloader_idx`` always returns 0, and will be updated to support the true index in the future.
 
         Note:
             This hook only runs on single GPU training and DDP (no data-parallel).
+            Data-Parallel support will come in near future.
 
         Args:
             batch: A batch of data that needs to be altered or augmented.
@@ -667,6 +669,10 @@ def on_before_batch_transfer(self, batch, dataloader_idx):
                batch['x'] = transforms(batch['x'])
                return batch
 
+        Raises:
+            MisconfigurationException:
+                If using data-parallel, ``Trainer(accelerator='dp')``.
+
         See Also:
             - :meth:`on_after_batch_transfer`
             - :meth:`transfer_batch_to_device`
@@ -681,6 +687,7 @@ def on_after_batch_transfer(self, batch, dataloader_idx):
 
         Note:
             This hook only runs on single GPU training and DDP (no data-parallel).
+            Data-Parallel support will come in near future.
 
         Args:
             batch: A batch of data that needs to be altered or augmented.
@@ -695,6 +702,10 @@ def on_after_batch_transfer(self, batch, dataloader_idx):
                batch['x'] = gpu_transforms(batch['x'])
                return batch
 
+        Raises:
+            MisconfigurationException:
+                If using data-parallel, ``Trainer(accelerator='dp')``.
+
         See Also:
             - :meth:`on_before_batch_transfer`
             - :meth:`transfer_batch_to_device`
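
As a rough orientation for the three hooks documented above, here is a stand-in sketch of the order in which they wrap the device transfer. The helper name `apply_batch_hooks` is hypothetical and not part of Lightning's API; it only illustrates the call order implied by the docstrings.

# Hypothetical helper illustrating the hook order; not part of Lightning's API.
def apply_batch_hooks(module, batch, device, dataloader_idx=0):
    batch = module.on_before_batch_transfer(batch, dataloader_idx)  # runs before the move (e.g. CPU augmentations)
    batch = module.transfer_batch_to_device(batch, device)          # moves the batch to the target device
    batch = module.on_after_batch_transfer(batch, dataloader_idx)   # runs after the move (e.g. GPU transforms)
    return batch

With this commit, overriding any of the three hooks while training with ``Trainer(accelerator='dp')`` raises ``MisconfigurationException``, as the added ``Raises`` sections state.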

pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 17 additions & 14 deletions

@@ -89,6 +89,7 @@ def attach_data(self, model, train_dataloader, val_dataloaders, datamodule):
         # set up the passed in dataloaders (if needed)
         self.attach_dataloaders(model, train_dataloader, val_dataloaders)
         self.attach_datamodule(model, datamodule)
+        self._validate_data_hooks(model)
 
     def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule):
         # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
@@ -97,6 +98,14 @@ def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloa
                 'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
             )
 
+    def _validate_data_hooks(self, model):
+        # Raise Misconfiguration exception since these hooks are not supported in DP mode
+        # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode.
+        batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
+        for hook in batch_transfer_hooks:
+            if self.trainer.accelerator_connector.use_dp and is_overridden(hook, model):
+                raise MisconfigurationException(f'Overriding `{hook}` is not supported in DP mode.')
+
     def attach_dataloaders(
         self,
         model,
@@ -127,22 +136,16 @@ def attach_datamodule(self, model, datamodule: Optional[LightningDataModule] = N
         if datamodule:
 
             # Override loader hooks
-            if is_overridden('train_dataloader', datamodule):
-                model.train_dataloader = datamodule.train_dataloader
-            if is_overridden('val_dataloader', datamodule):
-                model.val_dataloader = datamodule.val_dataloader
-            if is_overridden('test_dataloader', datamodule):
-                model.test_dataloader = datamodule.test_dataloader
-            if is_overridden('predict_dataloader', datamodule):
-                model.predict_dataloader = datamodule.predict_dataloader
+            dl_methods = ('train_dataloader', 'val_dataloader', 'test_dataloader', 'predict_dataloader')
+            for method in dl_methods:
+                if is_overridden(method, datamodule):
+                    setattr(model, method, getattr(datamodule, method))
 
             # Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule
-            if is_overridden('on_before_batch_transfer', datamodule):
-                model.on_before_batch_transfer = datamodule.on_before_batch_transfer
-            if is_overridden('transfer_batch_to_device', datamodule):
-                model.transfer_batch_to_device = datamodule.transfer_batch_to_device
-            if is_overridden('on_after_batch_transfer', datamodule):
-                model.on_after_batch_transfer = datamodule.on_after_batch_transfer
+            batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
+            for hook in batch_transfer_hooks:
+                if is_overridden(hook, datamodule):
+                    setattr(model, hook, getattr(datamodule, hook))
 
         self.trainer.datamodule = datamodule
         datamodule.trainer = self.trainer
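
The refactor above collapses the repeated `if is_overridden(...)` blocks into loops over tuples of hook names, patching each overridden datamodule method onto the model with `setattr`/`getattr`. Below is a minimal stand-in sketch of that pattern; the `is_overridden` here is a simplified substitute for Lightning's utility (which compares a method against the base-class implementation), and the classes are hypothetical.

# Simplified stand-in for Lightning's is_overridden utility.
def is_overridden(method_name: str, obj) -> bool:
    base = type(obj).__mro__[-2]  # assumes the relevant base class sits just above `object`
    return getattr(type(obj), method_name, None) is not getattr(base, method_name, None)


class BaseModule:
    def train_dataloader(self):
        return None


class MyDataModule(BaseModule):
    def train_dataloader(self):
        return ["batch-1", "batch-2"]


class Model(BaseModule):
    pass


datamodule, model = MyDataModule(), Model()
for method in ('train_dataloader',):
    if is_overridden(method, datamodule):
        # Same patching pattern as attach_datamodule: copy the bound method onto the model.
        setattr(model, method, getattr(datamodule, method))

print(model.train_dataloader())  # ['batch-1', 'batch-2']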

tests/accelerators/test_dp.py

Lines changed: 53 additions & 0 deletions

@@ -11,15 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import pytest
 import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
 import tests.helpers.pipelines as tpipes
 import tests.helpers.utils as tutils
+from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import EarlyStopping
 from pytorch_lightning.core import memory
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
 from tests.helpers.datamodules import ClassifDataModule
 from tests.helpers.runif import RunIf
@@ -132,6 +135,56 @@ def training_epoch_end(self, outputs):
         assert outputs[0]["reduce_float"].item() == 0.5  # mean([0., 1.]) = 0.5
 
 
+def test_dp_raise_exception_with_batch_transfer_hooks(tmpdir, monkeypatch):
+    """
+    Test that an exception is raised when overriding batch_transfer_hooks in DP model.
+    """
+    monkeypatch.setattr("torch.cuda.device_count", lambda: 2)
+
+    class CustomModel(BoringModel):
+
+        def transfer_batch_to_device(self, batch, device):
+            batch = batch.to(device)
+            return batch
+
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        max_steps=7,
+        gpus=[0, 1],
+        accelerator='dp',
+    )
+
+    trainer = Trainer(**trainer_options)
+    model = CustomModel()
+
+    with pytest.raises(MisconfigurationException, match=r'Overriding `transfer_batch_to_device` is not .* in DP'):
+        trainer.fit(model)
+
+    class CustomModel(BoringModel):
+
+        def on_before_batch_transfer(self, batch, dataloader_idx):
+            batch += 1
+            return batch
+
+    trainer = Trainer(**trainer_options)
+    model = CustomModel()
+
+    with pytest.raises(MisconfigurationException, match=r'Overriding `on_before_batch_transfer` is not .* in DP'):
+        trainer.fit(model)
+
+    class CustomModel(BoringModel):
+
+        def on_after_batch_transfer(self, batch, dataloader_idx):
+            batch += 1
+            return batch
+
+    trainer = Trainer(**trainer_options)
+    model = CustomModel()
+
+    with pytest.raises(MisconfigurationException, match=r'Overriding `on_after_batch_transfer` is not .* in DP'):
+        trainer.fit(model)
+
+
 @RunIf(min_gpus=2)
 def test_dp_training_step_dict(tmpdir):
     """ This test verifies that dp properly reduces dictionaries """
