ref: device to gpus #3405

Merged · 5 commits · Sep 9, 2020
17 changes: 0 additions & 17 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -130,23 +130,6 @@ def copy_trainer_model_properties(self, model):
             m.global_rank = self.global_rank
             m.local_rank = self.local_rank

-    def transfer_batch_to_gpu(self, batch: Any, gpu_id: Optional[int] = None):
-        """
-        Transfers the data to the GPU.
-
-        Args:
-            batch: A tensor or collection of tensors.
-            gpu_id: The id of the GPU device. If omitted, the first available GPU is chosen.
-
-        Return:
-            the tensor on the GPU device.
-
-        See Also:
-            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
-        """
-        device = torch.device('cuda', gpu_id)
-        return self.__transfer_batch_to_device(batch, device)
-
     def __transfer_batch_to_device(self, batch: Any, device: torch.device):
         model = self.get_model()
         if model is not None:
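For downstream code that still calls the removed helper, the migration is mechanical: build an explicit `torch.device` and route the transfer through the trainer's accelerator backend, exactly as the updated tests below do. A minimal before/after sketch, assuming a CUDA-capable machine; note that outside of tests, `Trainer.fit()` normally sets up `accelerator_backend` itself, so attaching it by hand mirrors the test setup rather than typical usage:

```python
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.gpu_backend import GPUBackend

trainer = Trainer(gpus=1)

# Before this PR:
#   batch_gpu = trainer.transfer_batch_to_gpu(batch, 0)

# After this PR: attach a backend (as the tests do) and pass an
# explicit torch.device instead of a bare GPU index.
trainer.accelerator_backend = GPUBackend(trainer)
batch = torch.rand(2, 3)
batch_gpu = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda', 0))
```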
4 changes: 0 additions & 4 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -203,10 +203,6 @@ def copy_trainer_model_properties(self, *args):
     def get_model(self) -> LightningModule:
         """Warning: this is just empty shell for code implemented in other class."""

-    @abstractmethod
-    def transfer_batch_to_gpu(self, *args):
-        """Warning: this is just empty shell for code implemented in other class."""
-
     @abstractmethod
     def reset_test_dataloader(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
6 changes: 4 additions & 2 deletions tests/core/test_datamodules.py
@@ -10,6 +10,7 @@
 from tests.base.datamodules import TrialMNISTDataModule
 from tests.base.develop_utils import reset_seed
 from pytorch_lightning.utilities.model_utils import is_overridden
+from pytorch_lightning.accelerators.gpu_backend import GPUBackend


 def test_can_prepare_data(tmpdir):
@@ -346,13 +347,14 @@ def transfer_batch_to_device(self, data, device):
     dm = CurrentTestDM()
     batch = CustomBatch((torch.zeros(5, 28), torch.ones(5, 1, dtype=torch.long)))

-    trainer = Trainer()
+    trainer = Trainer(gpus=1)
     # running .fit() would require us to implement custom data loaders, we mock the model reference instead
     trainer.get_model = MagicMock(return_value=model)
     if is_overridden('transfer_batch_to_device', dm):
         model.transfer_batch_to_device = dm.transfer_batch_to_device

-    batch_gpu = trainer.transfer_batch_to_gpu(batch, 0)
+    trainer.accelerator_backend = GPUBackend(trainer)
+    batch_gpu = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0'))
     expected = torch.device('cuda', 0)
     assert dm.hook_called
     assert batch_gpu.samples.device == batch_gpu.targets.device == expected
27 changes: 16 additions & 11 deletions tests/models/test_gpu.py
@@ -18,6 +18,9 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
 from tests.models.data.ddp import train_test_variations
+from pytorch_lightning.accelerators.gpu_backend import GPUBackend
+from pytorch_lightning.accelerators.cpu_backend import CPUBackend
+


 PRETEND_N_OF_GPUS = 16
@@ -335,35 +338,36 @@ def test_parse_gpu_returns_none_when_no_devices_are_available(mocked_device_coun

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_single_gpu_batch_parse():
-    trainer = Trainer()
+    trainer = Trainer(gpus=1)
+    trainer.accelerator_backend = GPUBackend(trainer)

     # batch is just a tensor
     batch = torch.rand(2, 3)
-    batch = trainer.transfer_batch_to_gpu(batch, 0)
+    batch = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0'))
     assert batch.device.index == 0 and batch.type() == 'torch.cuda.FloatTensor'

     # tensor list
     batch = [torch.rand(2, 3), torch.rand(2, 3)]
-    batch = trainer.transfer_batch_to_gpu(batch, 0)
+    batch = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0'))
     assert batch[0].device.index == 0 and batch[0].type() == 'torch.cuda.FloatTensor'
     assert batch[1].device.index == 0 and batch[1].type() == 'torch.cuda.FloatTensor'

     # tensor list of lists
     batch = [[torch.rand(2, 3), torch.rand(2, 3)]]
-    batch = trainer.transfer_batch_to_gpu(batch, 0)
+    batch = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0'))
     assert batch[0][0].device.index == 0 and batch[0][0].type() == 'torch.cuda.FloatTensor'
     assert batch[0][1].device.index == 0 and batch[0][1].type() == 'torch.cuda.FloatTensor'

     # tensor dict
     batch = [{'a': torch.rand(2, 3), 'b': torch.rand(2, 3)}]
-    batch = trainer.transfer_batch_to_gpu(batch, 0)
+    batch = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0'))
     assert batch[0]['a'].device.index == 0 and batch[0]['a'].type() == 'torch.cuda.FloatTensor'
     assert batch[0]['b'].device.index == 0 and batch[0]['b'].type() == 'torch.cuda.FloatTensor'

     # tuple of tensor list and list of tensor dict
     batch = ([torch.rand(2, 3) for _ in range(2)],
              [{'a': torch.rand(2, 3), 'b': torch.rand(2, 3)} for _ in range(2)])
-    batch = trainer.transfer_batch_to_gpu(batch, 0)
+    batch = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0'))
     assert batch[0][0].device.index == 0 and batch[0][0].type() == 'torch.cuda.FloatTensor'

     assert batch[1][0]['a'].device.index == 0
@@ -375,7 +379,7 @@ def test_single_gpu_batch_parse():
     # namedtuple of tensor
     BatchType = namedtuple('BatchType', ['a', 'b'])
     batch = [BatchType(a=torch.rand(2, 3), b=torch.rand(2, 3)) for _ in range(2)]
-    batch = trainer.transfer_batch_to_gpu(batch, 0)
+    batch = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0'))
     assert batch[0].a.device.index == 0
     assert batch[0].a.type() == 'torch.cuda.FloatTensor'
@@ -388,7 +392,7 @@ def to(self, *args, **kwargs):
             self.a = self.a.to(*args, **kwargs)
             return self

-    batch = trainer.transfer_batch_to_gpu(CustomBatchType())
+    batch = trainer.accelerator_backend.batch_to_device(CustomBatchType(), torch.device('cuda:0'))
     assert batch.a.type() == 'torch.cuda.FloatTensor'

     # torchtext.data.Batch
@@ -415,7 +419,7 @@ def to(self, *args, **kwargs):
     label_field.build_vocab(dataset)

     batch = Batch(data=examples, dataset=dataset)
-    batch = trainer.transfer_batch_to_gpu(batch, 0)
+    batch = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0'))

     assert batch.text.type() == 'torch.cuda.LongTensor'
     assert batch.label.type() == 'torch.cuda.LongTensor'
@@ -425,10 +429,11 @@ def to(self, *args, **kwargs):
 def test_non_blocking():
     """ Tests that non_blocking=True only gets passed on torch.Tensor.to, but not on other objects. """
     trainer = Trainer()
+    trainer.accelerator_backend = GPUBackend(trainer)

     batch = torch.zeros(2, 3)
     with patch.object(batch, 'to', wraps=batch.to) as mocked:
-        trainer.transfer_batch_to_gpu(batch, 0)
+        batch = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0'))
     mocked.assert_called_with(torch.device('cuda', 0), non_blocking=True)

     class BatchObject(object):
@@ -438,5 +443,5 @@ def to(self, *args, **kwargs):

     batch = BatchObject()
     with patch.object(batch, 'to', wraps=batch.to) as mocked:
-        trainer.transfer_batch_to_gpu(batch, 0)
+        batch = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0'))
     mocked.assert_called_with(torch.device('cuda', 0))
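The `test_non_blocking` update keeps the behavioral contract the old helper had: `non_blocking=True` is forwarded only to `torch.Tensor.to`, because arbitrary objects that merely expose a `to()` method may not accept that keyword. A rough sketch of that dispatch rule, written to match what the test asserts rather than the library's internal implementation:

```python
import torch
from typing import Any

def to_device_sketch(obj: Any, device: torch.device) -> Any:
    """Illustrative only: mirrors the behavior test_non_blocking checks."""
    if isinstance(obj, torch.Tensor):
        # Tensors can overlap host-to-device copies with compute,
        # so the extra keyword is worth passing.
        return obj.to(device, non_blocking=True)
    if hasattr(obj, 'to'):
        # Custom objects get only the device; their .to() signature
        # may not know about non_blocking.
        return obj.to(device)
    return obj
```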
6 changes: 4 additions & 2 deletions tests/models/test_hooks.py
@@ -4,6 +4,7 @@
 import torch

 from pytorch_lightning import Trainer
+from pytorch_lightning.accelerators.gpu_backend import GPUBackend
 from tests.base import EvalModelTemplate


@@ -99,10 +100,11 @@ def transfer_batch_to_device(self, data, device):
     model = CurrentTestModel()
     batch = CustomBatch((torch.zeros(5, 28), torch.ones(5, 1, dtype=torch.long)))

-    trainer = Trainer()
+    trainer = Trainer(gpus=1)
+    trainer.accelerator_backend = GPUBackend(trainer)
     # running .fit() would require us to implement custom data loaders, we mock the model reference instead
     trainer.get_model = MagicMock(return_value=model)
-    batch_gpu = trainer.transfer_batch_to_gpu(batch, 0)
+    batch_gpu = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0'))
     expected = torch.device('cuda', 0)
     assert model.hook_called
     assert batch_gpu.samples.device == batch_gpu.targets.device == expected
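Both `tests/core/test_datamodules.py` and `tests/models/test_hooks.py` drive a user-overridden `transfer_batch_to_device` hook through the new backend path. For reference, a minimal sketch of what such an override looks like on the model side; the `CustomBatch` container and class name are illustrative, and only the hook signature comes from the tests:

```python
import torch
from pytorch_lightning import LightningModule

class CustomBatch:
    """Container the default move logic does not know how to transfer."""
    def __init__(self, data):
        self.samples, self.targets = data

class HookedModel(LightningModule):
    hook_called = False

    def transfer_batch_to_device(self, data, device):
        # Record the call (the tests assert hook_called), then move the
        # custom container's tensors by hand.
        self.hook_called = True
        if isinstance(data, CustomBatch):
            data.samples = data.samples.to(device)
            data.targets = data.targets.to(device)
            return data
        # Fall back to the default behavior for everything else.
        return super().transfer_batch_to_device(data, device)
```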