fix the scatter when input is cpu #1621

Merged
merged 5 commits on Jan 24, 2022

Conversation

@Bilibilee (Contributor) commented Dec 27, 2021

Motivation

According to issue #792, there is an inconsistency between CPU and GPU when scattering inputs:

from mmcv.parallel import scatter_kwargs, DataContainer
import torch 
inputs = (torch.zeros([20, 3, 128, 128]), )
output, _ = scatter_kwargs(inputs, {}, [-1], 0)
print('CPU', output[0][0].size())
output, _ = scatter_kwargs(inputs, {}, [0], 0)
print('GPU', output[0][0].size())

print('------------')
inputs = (DataContainer([torch.zeros([20, 3, 128, 128])]),)
output, _ = scatter_kwargs(inputs, {}, [-1], 0)
print('CPU', output[0][0].size())
output, _ = scatter_kwargs(inputs, {}, [0], 0)
print('GPU', output[0][0].size())

The output is:

CPU torch.Size([20, 3, 128, 128])
GPU torch.Size([20, 3, 128, 128])
------------
CPU torch.Size([1, 20, 3, 128, 128])
GPU torch.Size([20, 3, 128, 128])

The inconsistency is caused by this line at https://github.com/open-mmlab/mmcv/blob/master/mmcv/parallel/_functions.py#L28:

output = output.unsqueeze(0)
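
For reference, here is a simplified, paraphrased sketch of that helper (not a verbatim copy of the linked file) showing where the extra leading dimension comes from:

import torch

# Simplified sketch of the scatter helper in mmcv/parallel/_functions.py
# before this fix; see the linked file for the exact code.
def scatter(input, devices, streams=None):
    if isinstance(input, list):
        # recursively scatter every element of the list
        return [scatter(item, devices, streams) for item in input]
    elif isinstance(input, torch.Tensor):
        output = input.contiguous()
        if devices != [-1]:
            # GPU branch: copy the tensor to the target device
            output = output.cuda(devices[0], non_blocking=True)
        else:
            # CPU branch: this unsqueeze adds the extra leading dimension,
            # which is why the DataContainer example above prints
            # [1, 20, 3, 128, 128] on CPU but [20, 3, 128, 128] on GPU
            output = output.unsqueeze(0)
        return output
    else:
        raise Exception(f'Unknown type {type(input)}.')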

Modification

To fix the problem, I made a simple modification:
delete the output = output.unsqueeze(0) and make the return of Scatter.forward be tuple(outputs) if isinstance(outputs, list) else (outputs, ),
so that, without the unsqueeze(0), the CPU and GPU outputs have the same shape.
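
A rough before/after sketch of the change (paraphrased; the actual diff of this PR is authoritative):

# Before the fix, the CPU branch of scatter() added a leading dimension:
#     output = output.unsqueeze(0)
# After the fix, that line is removed and the return statement of
# Scatter.forward wraps a single (non-list) result in a one-element tuple:
class Scatter:

    @staticmethod
    def forward(target_gpus, input):
        # ... device detection and stream handling unchanged ...
        outputs = scatter(input, target_gpus)
        return tuple(outputs) if isinstance(outputs, list) else (outputs, )

Either way the caller still receives a tuple, but the CPU path no longer changes the tensor shape, so it matches the GPU path.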

@CLAassistant commented Dec 27, 2021

CLA assistant check
All committers have signed the CLA.

@teamwong111 (Contributor)

I think this modification is great. Maybe @ZwwWayne could take a look at it, because PR #497 is related to this.

@zhouzaida (Collaborator)

Related PR #1282

@zhouzaida (Collaborator)

Hi @Bilibilee, thanks for your contribution. Could you resolve the failing CI checks?

Commit: Add spaces to comply with the code specification
@Bilibilee (Contributor, Author) commented Jan 17, 2022

Hi @zhouzaida, I have resolved the CI error.

@zhouzaida (Collaborator)

> Hi @zhouzaida, I have resolved the CI error.

Got it. It would be great if some unit tests were added at https://github.com/open-mmlab/mmcv/blob/master/tests/test_parallel.py

@zhouzaida (Collaborator)

Hi, could you update tests/test_parallel.py with the following block?

from unittest.mock import MagicMock, patch

import pytest
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel

from mmcv.parallel import (MODULE_WRAPPERS, MMDataParallel,
                           MMDistributedDataParallel, is_module_wrapper)
from mmcv.parallel._functions import Scatter, get_input_device, scatter
from mmcv.parallel.distributed_deprecated import \
    MMDistributedDataParallel as DeprecatedMMDDP


def mock(*args, **kwargs):
    pass


@patch('torch.distributed._broadcast_coalesced', mock)
@patch('torch.distributed.broadcast', mock)
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
def test_is_module_wrapper():

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(2, 2, 1)

        def forward(self, x):
            return self.conv(x)

    # _verify_model_across_ranks is added in torch1.9.0 so we should check
    # whether _verify_model_across_ranks is the member of torch.distributed
    # before mocking
    if hasattr(torch.distributed, '_verify_model_across_ranks'):
        torch.distributed._verify_model_across_ranks = mock

    model = Model()
    assert not is_module_wrapper(model)

    dp = DataParallel(model)
    assert is_module_wrapper(dp)

    mmdp = MMDataParallel(model)
    assert is_module_wrapper(mmdp)

    ddp = DistributedDataParallel(model, process_group=MagicMock())
    assert is_module_wrapper(ddp)

    mmddp = MMDistributedDataParallel(model, process_group=MagicMock())
    assert is_module_wrapper(mmddp)

    deprecated_mmddp = DeprecatedMMDDP(model)
    assert is_module_wrapper(deprecated_mmddp)

    # test module wrapper registry
    @MODULE_WRAPPERS.register_module()
    class ModuleWrapper(object):

        def __init__(self, module):
            self.module = module

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs)

    module_wrapper = ModuleWrapper(model)
    assert is_module_wrapper(module_wrapper)


def test_get_input_device():
    # if the device is CPU, return -1
    input = torch.zeros([1, 3, 3, 3])
    assert get_input_device(input) == -1
    inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
    assert get_input_device(inputs) == -1

    # if the device is GPU, return the index of device
    input = torch.zeros([1, 3, 3, 3]).cuda()
    assert get_input_device(input) == 0
    inputs = [
        torch.zeros([1, 3, 3, 3]).cuda(),
        torch.zeros([1, 4, 4, 4]).cuda()
    ]
    assert get_input_device(inputs) == 0

    # input should be a tensor or list of tensor
    with pytest.raises(Exception):
        get_input_device(5)


def test_scatter():
    # if the device is CPU, just return the input
    input = torch.zeros([1, 3, 3, 3])
    output = scatter(input=input, devices=[-1])
    assert torch.allclose(input, output)

    inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
    outputs = scatter(input=inputs, devices=[-1])
    for input, output in zip(inputs, outputs):
        assert torch.allclose(input, output)

    # if the device is GPU, copy the input from CPU to GPU
    if torch.cuda.is_available():
        input = torch.zeros([1, 3, 3, 3])
        output = scatter(input=input, devices=[0])
        assert torch.allclose(input.cuda(), output)

        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = scatter(input=inputs, devices=[0])
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.cuda(), output)

    # input should be a tensor or list of tensor
    with pytest.raises(Exception):
        scatter(5, [-1])


def test_Scatter():
    # if the device is CPU, just return the input
    target_gpus = [-1]
    input = torch.zeros([1, 3, 3, 3])
    outputs = Scatter.forward(target_gpus, input)
    assert isinstance(outputs, tuple)
    assert torch.allclose(input, outputs[0])

    target_gpus = [-1]
    inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
    outputs = Scatter.forward(target_gpus, inputs)
    assert isinstance(outputs, tuple)
    for input, output in zip(inputs, outputs):
        assert torch.allclose(input, output)

    # if the device is GPU, copy the input from CPU to GPU
    if torch.cuda.is_available():
        target_gpus = [0]
        input = torch.zeros([1, 3, 3, 3])
        outputs = Scatter.forward(target_gpus, input)
        assert isinstance(outputs, tuple)
        assert torch.allclose(input.cuda(), outputs[0])

        target_gpus = [0]
        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = Scatter.forward(target_gpus, inputs)
        assert isinstance(outputs, tuple)
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.cuda(), output)
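
If it helps, the updated tests can be run locally with pytest (the GPU assertions only run when torch.cuda.is_available() is True), for example from Python:

import pytest

# equivalent to running `pytest tests/test_parallel.py -v` from the shell
pytest.main(['tests/test_parallel.py', '-v'])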

@OpenMMLab-Assistant003

Hi @Bilibilee! First of all, we want to express our gratitude for your significant PR in the MMCV project. Your contribution is highly appreciated, and we are grateful for your efforts in helping improve this open-source project during your personal time. We believe that many developers will benefit from your PR.

We would also like to invite you to join our Special Interest Group (SIG) private channel on Discord, where you can share your experiences and ideas and build connections with like-minded peers. To join the SIG channel, simply message the moderator OpenMMLab on Discord, or briefly share your open-source contributions in the #introductions channel and we will assist you. We look forward to seeing you there! Join us: https://discord.gg/UjgXkPWNqA

If you have a WeChat account, welcome to join our community on WeChat. You can add our assistant: openmmlabwx. Please add "mmsig + GitHub ID" as a remark when adding friends. :)
Thank you again for your contribution ❤ @Bilibilee
