[Feature] Support PyTorch backend on MLU #1770

Merged: 6 commits, Apr 14, 2022

2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -45,7 +45,7 @@ jobs:
- name: Run unittests and generate coverage report
run: |
pip install -r requirements/test.txt
-        pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py --ignore=tests/test_utils/test_trace.py --ignore=tests/test_utils/test_hub.py
+        pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py --ignore=tests/test_utils/test_trace.py --ignore=tests/test_utils/test_hub.py --ignore=tests/test_device/test_mlu/test_mlu_parallel.py

build_without_ops:
runs-on: ubuntu-18.04
1 change: 1 addition & 0 deletions mmcv/__init__.py
@@ -13,3 +13,4 @@
# - runner
# - parallel
# - op
# - device
4 changes: 4 additions & 0 deletions mmcv/device/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import mlu

__all__ = ['mlu']
10 changes: 10 additions & 0 deletions mmcv/device/mlu/__init__.py
@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .data_parallel import MLUDataParallel
from .distributed import MLUDistributedDataParallel
from .scatter_gather import scatter, scatter_kwargs
from .utils import IS_MLU

__all__ = [
'MLUDataParallel', 'MLUDistributedDataParallel', 'scatter',
'scatter_kwargs', 'IS_MLU'
]
22 changes: 22 additions & 0 deletions mmcv/device/mlu/_functions.py
@@ -0,0 +1,22 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch


def scatter(input, devices):
"""scatter copies tensor to MLU directly."""
if isinstance(input, list):
outputs = [scatter(_input, devices) for _input in input]
return outputs
elif isinstance(input, torch.Tensor):
output = input.contiguous()
return output.to('mlu') if devices != [-1] else output
else:
        raise TypeError(f'Unknown type {type(input)}.')


class Scatter:

@staticmethod
def forward(target_mlus, input):
outputs = scatter(input, target_mlus)
return tuple(outputs) if isinstance(outputs, list) else (outputs, )
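
A minimal sketch of how this scatter behaves (hedged: the 'mlu' device type only exists once torch_mlu is installed; devices=[-1] means "stay on CPU"):

    import torch
    from mmcv.device.mlu._functions import scatter

    x = torch.zeros(1, 3, 3, 3)
    y = scatter(x, devices=[-1])        # CPU passthrough: contiguous tensor, not moved
    ys = scatter([x, x], devices=[-1])  # lists are scattered element-wise
    # with an MLU present, scatter(x, devices=[0]) copies x to the MLU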
41 changes: 41 additions & 0 deletions mmcv/device/mlu/data_parallel.py
@@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch

from mmcv.parallel import MMDataParallel
from .scatter_gather import scatter_kwargs


class MLUDataParallel(MMDataParallel):
"""The MLUDataParallel module that supports DataContainer.
Qiza-lyhm marked this conversation as resolved.
Show resolved Hide resolved

MLUDataParallel is a class inherited from MMDataParall, which supports
MLU training and inference only.

    The main differences from MMDataParallel:

    - It supports only a single MLU card, and uses only the first card to
      run training and inference.

    - It uses a direct host-to-device copy instead of scattering in
      background streams.

.. warning::
        MLUDataParallel only supports single-MLU training; if you need to
        train with multiple MLUs, please use MLUDistributedDataParallel
instead. If you have multiple MLUs, you can set the environment
variable ``MLU_VISIBLE_DEVICES=0`` (or any other card number(s))
to specify the running device.

Args:
module (:class:`nn.Module`): Module to be encapsulated.
dim (int): Dimension used to scatter the data. Defaults to 0.
"""

def __init__(self, *args, dim=0, **kwargs):
super(MLUDataParallel, self).__init__(*args, dim=dim, **kwargs)
self.device_ids = [0]
self.src_device_obj = torch.device('mlu:0')

def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
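
A minimal usage sketch (the model here is a placeholder; assumes torch_mlu is installed and at least one MLU card is visible):

    import torch.nn as nn
    from mmcv.device.mlu import MLUDataParallel

    model = nn.Conv2d(3, 8, 3).to('mlu')  # any nn.Module moved to mlu:0
    model = MLUDataParallel(model)
    # forward/train_step/val_step now behave like MMDataParallel,
    # with inputs copied to mlu:0 via scatter_kwargs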
20 changes: 20 additions & 0 deletions mmcv/device/mlu/distributed.py
@@ -0,0 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.

from mmcv.parallel import MMDistributedDataParallel
from .scatter_gather import scatter_kwargs


class MLUDistributedDataParallel(MMDistributedDataParallel):
"""The DDP module supports DataContainer.

MLUDDP has one difference from MMDDP which moves data to MLU with coping
instead of scattering.
"""

def to_kwargs(self, inputs, kwargs, device_id):
        # Use `self.to_kwargs` instead of `self.scatter` in PyTorch >= 1.8
        # to move all tensors to device_id
return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)

def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
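
A hedged sketch of the distributed counterpart (assumes the process group was already initialized with the cncl backend, e.g. through the init_dist change further below):

    import torch.nn as nn
    from mmcv.device.mlu import MLUDistributedDataParallel

    model = nn.Conv2d(3, 8, 3).to('mlu')
    ddp_model = MLUDistributedDataParallel(model)
    # data passed to ddp_model is moved to the local MLU by copying,
    # via the scatter_kwargs override above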
59 changes: 59 additions & 0 deletions mmcv/device/mlu/scatter_gather.py
@@ -0,0 +1,59 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmcv.parallel.data_container import DataContainer
from ._functions import Scatter


def scatter(inputs, target_mlus, dim=0):
"""Scatter inputs to target mlu.

The only difference from original :func:`scatter` is to add support for
:type:`~mmcv.parallel.DataContainer`.
"""

def scatter_map(obj):
if isinstance(obj, torch.Tensor):
if target_mlus != [-1]:
obj = obj.to('mlu')
return obj
else:
# for CPU inference we use self-implemented scatter
return Scatter.forward(target_mlus, obj)
if isinstance(obj, DataContainer):
if obj.cpu_only:
return obj.data
else:
return Scatter.forward(target_mlus, obj.data)
if isinstance(obj, tuple) and len(obj) > 0:
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
out = list(map(list, zip(*map(scatter_map, obj))))
return out
if isinstance(obj, dict) and len(obj) > 0:
out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
return out
return [obj for targets in target_mlus]

# After scatter_map is called, a scatter_map cell will exist. This cell
# has a reference to the actual function scatter_map, which has references
# to a closure that has a reference to the scatter_map cell (because the
# fn is recursive). To avoid this reference cycle, we set the function to
# None, clearing the cell
try:
return scatter_map(inputs)
finally:
scatter_map = None


def scatter_kwargs(inputs, kwargs, target_mlus, dim=0):
"""Scatter with support for kwargs dictionary."""
inputs = scatter(inputs, target_mlus, dim) if inputs else []
kwargs = scatter(kwargs, target_mlus, dim) if kwargs else []
if len(inputs) < len(kwargs):
inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
elif len(kwargs) < len(inputs):
kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
inputs = tuple(inputs)
kwargs = tuple(kwargs)
return inputs, kwargs
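
To make the padding at the end of scatter_kwargs concrete, a small CPU-only sketch (target_mlus=[-1] keeps everything on the host):

    import torch
    from mmcv.device.mlu.scatter_gather import scatter_kwargs

    data = torch.zeros(1, 3, 3, 3)
    inputs, kwargs = scatter_kwargs((data,), {}, target_mlus=[-1])
    # inputs -> ((tensor,),): one argument tuple per target device
    # kwargs -> ({},): padded with empty dicts to match len(inputs)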
11 changes: 11 additions & 0 deletions mmcv/device/mlu/utils.py
@@ -0,0 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
def is_mlu_available():
try:
import torch
return (hasattr(torch, 'is_mlu_available')
and torch.is_mlu_available())
except Exception:
return False


IS_MLU = is_mlu_available()
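
IS_MLU is evaluated once at import time, so it can serve as a cheap guard, e.g.:

    from mmcv.device.mlu import IS_MLU

    device = 'mlu' if IS_MLU else 'cpu'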
5 changes: 5 additions & 0 deletions mmcv/ops/csrc/common/pytorch_cpp_helper.hpp
@@ -8,13 +8,18 @@ using namespace at;

#define CHECK_CUDA(x) \
TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_MLU(x) \
  TORCH_CHECK(x.device().type() == at::kMLU, #x " must be an MLU tensor")
#define CHECK_CPU(x) \
TORCH_CHECK(!x.device().is_cuda(), #x " must be a CPU tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_CUDA_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
#define CHECK_MLU_INPUT(x) \
CHECK_MLU(x); \
CHECK_CONTIGUOUS(x)
#define CHECK_CPU_INPUT(x) \
CHECK_CPU(x); \
CHECK_CONTIGUOUS(x)
26 changes: 26 additions & 0 deletions mmcv/ops/csrc/common/pytorch_mlu_helper.hpp
@@ -0,0 +1,26 @@
/*************************************************************************
* Copyright (C) 2021 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef PYTORCH_MLU_HELPER_HPP_
#define PYTORCH_MLU_HELPER_HPP_

#ifdef MMCV_WITH_MLU
#include "aten.h"

#define NFU_ALIGN_SIZE 128

#define PAD_UP(x, y) (((x) / (y) + (int)((x) % (y) > 0)) * (y))

#define PAD_DOWN(x, y) (((x) / (y)) * (y))

#endif

#endif // PYTORCH_MLU_HELPER_HPP_
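
The two padding macros round up/down to a multiple of y (typically NFU_ALIGN_SIZE); a quick Python rendering of the same arithmetic, for illustration only:

    def pad_up(x, y):
        # smallest multiple of y that is >= x
        return (x // y + int(x % y > 0)) * y

    def pad_down(x, y):
        # largest multiple of y that is <= x
        return (x // y) * y

    assert pad_up(100, 128) == 128 and pad_down(100, 128) == 0
    assert pad_up(256, 128) == 256  # already-aligned values are unchanged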
17 changes: 14 additions & 3 deletions mmcv/runner/dist_utils.py
@@ -12,6 +12,8 @@
from torch._utils import (_flatten_dense_tensors, _take_tensors,
_unflatten_dense_tensors)

from mmcv.device.mlu import IS_MLU


def _find_free_port():
# Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
@@ -47,9 +49,18 @@ def init_dist(launcher, backend='nccl', **kwargs):
def _init_dist_pytorch(backend, **kwargs):
    # TODO: use local_rank instead of rank % num_gpus
    rank = int(os.environ['RANK'])
-    num_gpus = torch.cuda.device_count()
-    torch.cuda.set_device(rank % num_gpus)
-    dist.init_process_group(backend=backend, **kwargs)
+    if IS_MLU:
+        import torch_mlu  # noqa: F401
+        torch.mlu.set_device(rank)
+        dist.init_process_group(
+            backend='cncl',
+            rank=rank,
+            world_size=int(os.environ['WORLD_SIZE']),
+            **kwargs)
+    else:
+        num_gpus = torch.cuda.device_count()
+        torch.cuda.set_device(rank % num_gpus)
+        dist.init_process_group(backend=backend, **kwargs)


def _init_dist_mpi(backend, **kwargs):
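
With this in place, initializing distributed training on an MLU machine would presumably look like the following (the env vars normally come from the launcher; they are set here only to keep the sketch self-contained):

    import os
    from mmcv.runner import init_dist

    os.environ.setdefault('RANK', '0')
    os.environ.setdefault('WORLD_SIZE', '1')
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    init_dist('pytorch')  # dispatches to the cncl branch when IS_MLU is True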
18 changes: 18 additions & 0 deletions setup.py
@@ -12,6 +12,10 @@
if torch.__version__ == 'parrots':
from parrots.utils.build_extension import BuildExtension
EXT_TYPE = 'parrots'
elif (hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()) or \
os.getenv('FORCE_MLU', '0') == '1':
from torch_mlu.utils.cpp_extension import BuildExtension
EXT_TYPE = 'pytorch'
else:
from torch.utils.cpp_extension import BuildExtension
EXT_TYPE = 'pytorch'
@@ -288,6 +292,20 @@ def get_extensions():
extension = CUDAExtension
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/cuda'))
elif (hasattr(torch, 'is_mlu_available') and
torch.is_mlu_available()) or \
os.getenv('FORCE_MLU', '0') == '1':
from torch_mlu.utils.cpp_extension import MLUExtension
define_macros += [('MMCV_WITH_MLU', None)]
mlu_args = os.getenv('MMCV_MLU_ARGS')
extra_compile_args['cncc'] = [mlu_args] if mlu_args else []
op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \
glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \
glob.glob('./mmcv/ops/csrc/pytorch/mlu/*.cpp') + \
glob.glob('./mmcv/ops/csrc/pytorch/mlu/*.mlu')
extension = MLUExtension
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mlu'))
else:
print(f'Compiling {ext_name} without CUDA')
op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \
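
Taken together, these hooks mean an MLU build can presumably be forced on a machine without a visible device by exporting FORCE_MLU=1 before the usual build step, with MMCV_MLU_ARGS available for passing extra cncc compile flags.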
97 changes: 97 additions & 0 deletions tests/test_device/test_mlu/test_mlu_parallel.py
@@ -0,0 +1,97 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import MagicMock, patch

import pytest
import torch
import torch.nn as nn

from mmcv.device.mlu import IS_MLU, MLUDataParallel, MLUDistributedDataParallel
from mmcv.device.mlu._functions import Scatter, scatter
from mmcv.parallel import is_module_wrapper


def mock(*args, **kwargs):
pass


@patch('torch.distributed._broadcast_coalesced', mock)
@patch('torch.distributed.broadcast', mock)
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
def test_is_module_wrapper():

class Model(nn.Module):

def __init__(self):
super().__init__()
self.conv = nn.Conv2d(2, 2, 1)

def forward(self, x):
return self.conv(x)

model = Model()
assert not is_module_wrapper(model)

if IS_MLU:
mludp = MLUDataParallel(model)
assert is_module_wrapper(mludp)

mluddp = MLUDistributedDataParallel(model, process_group=MagicMock())
assert is_module_wrapper(mluddp)


def test_scatter():
# if the device is CPU, just return the input
input = torch.zeros([1, 3, 3, 3])
output = scatter(input=input, devices=[-1])
assert torch.allclose(input, output)

inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = scatter(input=inputs, devices=[-1])
for input, output in zip(inputs, outputs):
assert torch.allclose(input, output)

# if the device is MLU, copy the input from CPU to MLU
if IS_MLU:
input = torch.zeros([1, 3, 3, 3])
output = scatter(input=input, devices=[0])
assert torch.allclose(input.to('mlu'), output)

inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = scatter(input=inputs, devices=[0])
for input, output in zip(inputs, outputs):
assert torch.allclose(input.to('mlu'), output)

    # input should be a tensor or a list of tensors
with pytest.raises(Exception):
scatter(5, [-1])


def test_Scatter():
# if the device is CPU, just return the input
target_mlus = [-1]
input = torch.zeros([1, 3, 3, 3])
outputs = Scatter.forward(target_mlus, input)
assert isinstance(outputs, tuple)
assert torch.allclose(input, outputs[0])

target_mlus = [-1]
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = Scatter.forward(target_mlus, inputs)
assert isinstance(outputs, tuple)
for input, output in zip(inputs, outputs):
assert torch.allclose(input, output)

# if the device is MLU, copy the input from CPU to MLU
if IS_MLU:
target_mlus = [0]
input = torch.zeros([1, 3, 3, 3])
outputs = Scatter.forward(target_mlus, input)
assert isinstance(outputs, tuple)
assert torch.allclose(input.to('mlu'), outputs[0])

target_mlus = [0]
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = Scatter.forward(target_mlus, inputs)
assert isinstance(outputs, tuple)
for input, output in zip(inputs, outputs):
            assert torch.allclose(input.to('mlu'), output)