Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix the scatter when input is cpu #1621

Merged
merged 5 commits into from
Jan 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions mmcv/parallel/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@ def scatter(input, devices, streams=None):
if devices != [-1]:
with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
output = output.cuda(devices[0], non_blocking=True)
else:
# unsqueeze the first dimension thus the tensor's shape is the
# same as those scattered with GPU.
output = output.unsqueeze(0)

return output
else:
raise Exception(f'Unknown type {type(input)}.')
Expand Down Expand Up @@ -76,4 +73,4 @@ def forward(target_gpus, input):
if streams is not None:
synchronize_stream(outputs, target_gpus, streams)

return tuple(outputs)
return tuple(outputs) if isinstance(outputs, list) else (outputs, )
82 changes: 82 additions & 0 deletions tests/test_parallel.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from unittest.mock import MagicMock, patch

import pytest
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel

from mmcv.parallel import (MODULE_WRAPPERS, MMDataParallel,
MMDistributedDataParallel, is_module_wrapper)
from mmcv.parallel._functions import Scatter, get_input_device, scatter
from mmcv.parallel.distributed_deprecated import \
MMDistributedDataParallel as DeprecatedMMDDP

Expand Down Expand Up @@ -64,3 +66,83 @@ def forward(self, *args, **kwargs):

module_wraper = ModuleWrapper(model)
assert is_module_wrapper(module_wraper)


def test_get_input_device():
# if the device is CPU, return -1
input = torch.zeros([1, 3, 3, 3])
assert get_input_device(input) == -1
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
assert get_input_device(inputs) == -1

# if the device is GPU, return the index of device
if torch.cuda.is_available():
input = torch.zeros([1, 3, 3, 3]).cuda()
assert get_input_device(input) == 0
inputs = [
torch.zeros([1, 3, 3, 3]).cuda(),
torch.zeros([1, 4, 4, 4]).cuda()
]
assert get_input_device(inputs) == 0

# input should be a tensor or list of tensor
with pytest.raises(Exception):
get_input_device(5)


def test_scatter():
# if the device is CPU, just return the input
input = torch.zeros([1, 3, 3, 3])
output = scatter(input=input, devices=[-1])
assert torch.allclose(input, output)

inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = scatter(input=inputs, devices=[-1])
for input, output in zip(inputs, outputs):
assert torch.allclose(input, output)

# if the device is GPU, copy the input from CPU to GPU
if torch.cuda.is_available():
input = torch.zeros([1, 3, 3, 3])
output = scatter(input=input, devices=[0])
assert torch.allclose(input.cuda(), output)

inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = scatter(input=inputs, devices=[0])
for input, output in zip(inputs, outputs):
assert torch.allclose(input.cuda(), output)

# input should be a tensor or list of tensor
with pytest.raises(Exception):
scatter(5, [-1])


def test_Scatter():
# if the device is CPU, just return the input
target_gpus = [-1]
input = torch.zeros([1, 3, 3, 3])
outputs = Scatter.forward(target_gpus, input)
assert isinstance(outputs, tuple)
assert torch.allclose(input, outputs[0])

target_gpus = [-1]
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = Scatter.forward(target_gpus, inputs)
assert isinstance(outputs, tuple)
for input, output in zip(inputs, outputs):
assert torch.allclose(input, output)

# if the device is GPU, copy the input from CPU to GPU
if torch.cuda.is_available():
target_gpus = [0]
input = torch.zeros([1, 3, 3, 3])
outputs = Scatter.forward(target_gpus, input)
assert isinstance(outputs, tuple)
assert torch.allclose(input.cuda(), outputs[0])

target_gpus = [0]
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = Scatter.forward(target_gpus, inputs)
assert isinstance(outputs, tuple)
for input, output in zip(inputs, outputs):
assert torch.allclose(input.cuda(), output[0])