diff --git a/mmcv/parallel/_functions.py b/mmcv/parallel/_functions.py index 9b5a8a4448..95c58bf1a8 100644 --- a/mmcv/parallel/_functions.py +++ b/mmcv/parallel/_functions.py @@ -22,10 +22,7 @@ def scatter(input, devices, streams=None): if devices != [-1]: with torch.cuda.device(devices[0]), torch.cuda.stream(stream): output = output.cuda(devices[0], non_blocking=True) - else: - # unsqueeze the first dimension thus the tensor's shape is the - # same as those scattered with GPU. - output = output.unsqueeze(0) + return output else: raise Exception(f'Unknown type {type(input)}.') @@ -76,4 +73,4 @@ def forward(target_gpus, input): if streams is not None: synchronize_stream(outputs, target_gpus, streams) - return tuple(outputs) + return tuple(outputs) if isinstance(outputs, list) else (outputs, ) diff --git a/tests/test_parallel.py b/tests/test_parallel.py index e8e5456828..f551c4025c 100644 --- a/tests/test_parallel.py +++ b/tests/test_parallel.py @@ -1,11 +1,13 @@ from unittest.mock import MagicMock, patch +import pytest import torch import torch.nn as nn from torch.nn.parallel import DataParallel, DistributedDataParallel from mmcv.parallel import (MODULE_WRAPPERS, MMDataParallel, MMDistributedDataParallel, is_module_wrapper) +from mmcv.parallel._functions import Scatter, get_input_device, scatter from mmcv.parallel.distributed_deprecated import \ MMDistributedDataParallel as DeprecatedMMDDP @@ -64,3 +66,83 @@ def forward(self, *args, **kwargs): module_wraper = ModuleWrapper(model) assert is_module_wrapper(module_wraper) + + +def test_get_input_device(): + # if the device is CPU, return -1 + input = torch.zeros([1, 3, 3, 3]) + assert get_input_device(input) == -1 + inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])] + assert get_input_device(inputs) == -1 + + # if the device is GPU, return the index of device + if torch.cuda.is_available(): + input = torch.zeros([1, 3, 3, 3]).cuda() + assert get_input_device(input) == 0 + inputs = [ + torch.zeros([1, 3, 3, 3]).cuda(), + torch.zeros([1, 4, 4, 4]).cuda() + ] + assert get_input_device(inputs) == 0 + + # input should be a tensor or list of tensor + with pytest.raises(Exception): + get_input_device(5) + + +def test_scatter(): + # if the device is CPU, just return the input + input = torch.zeros([1, 3, 3, 3]) + output = scatter(input=input, devices=[-1]) + assert torch.allclose(input, output) + + inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])] + outputs = scatter(input=inputs, devices=[-1]) + for input, output in zip(inputs, outputs): + assert torch.allclose(input, output) + + # if the device is GPU, copy the input from CPU to GPU + if torch.cuda.is_available(): + input = torch.zeros([1, 3, 3, 3]) + output = scatter(input=input, devices=[0]) + assert torch.allclose(input.cuda(), output) + + inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])] + outputs = scatter(input=inputs, devices=[0]) + for input, output in zip(inputs, outputs): + assert torch.allclose(input.cuda(), output) + + # input should be a tensor or list of tensor + with pytest.raises(Exception): + scatter(5, [-1]) + + +def test_Scatter(): + # if the device is CPU, just return the input + target_gpus = [-1] + input = torch.zeros([1, 3, 3, 3]) + outputs = Scatter.forward(target_gpus, input) + assert isinstance(outputs, tuple) + assert torch.allclose(input, outputs[0]) + + target_gpus = [-1] + inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])] + outputs = Scatter.forward(target_gpus, inputs) + assert isinstance(outputs, tuple) + for input, output in zip(inputs, outputs): + assert torch.allclose(input, output) + + # if the device is GPU, copy the input from CPU to GPU + if torch.cuda.is_available(): + target_gpus = [0] + input = torch.zeros([1, 3, 3, 3]) + outputs = Scatter.forward(target_gpus, input) + assert isinstance(outputs, tuple) + assert torch.allclose(input.cuda(), outputs[0]) + + target_gpus = [0] + inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])] + outputs = Scatter.forward(target_gpus, inputs) + assert isinstance(outputs, tuple) + for input, output in zip(inputs, outputs): + assert torch.allclose(input.cuda(), output[0])