feat(mlu): Support PyTorch backend on MLU. (#1515)
Squashed commits:

* feat(mlu): Support PyTorch backend on MLU.
* Fix redundant device variable.
* Update mmseg/apis/train.py (suggestion from Miao Zheng).
* Update comments.
* Update mmseg/apis/train.py.
* Update is_mlu_available flag.
* Align util_distribution.py to mmdet (two commits).
* Add build_dp, build_ddp test cases.
* Apply review suggestions from Zaida Zhou to mmseg/utils/util_distribution.py (three commits) and tests/test_utils/test_util_distribution.py (four commits).
* Add mmcv version check for the mlu device.

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
1 parent aa50358 · commit 7628a61 · Showing 7 changed files with 176 additions and 18 deletions.
mmseg/utils/util_distribution.py (new file, 81 lines):
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel

from mmseg import digit_version

# Device-type -> wrapper-class lookup tables. The MLU entries are registered
# lazily inside build_dp/build_ddp, since the MLU wrappers only exist in
# MMCV >= 1.5.0.
dp_factory = {'cuda': MMDataParallel, 'cpu': MMDataParallel}

ddp_factory = {'cuda': MMDistributedDataParallel}


def build_dp(model, device='cuda', dim=0, *args, **kwargs):
    """Build a DataParallel module by device type.

    If device is cuda, return an MMDataParallel module; if device is mlu,
    return an MLUDataParallel module.

    Args:
        model (:class:`nn.Module`): module to be parallelized.
        device (str): device type, cuda, cpu or mlu. Defaults to cuda.
        dim (int): Dimension used to scatter the data. Defaults to 0.

    Returns:
        :class:`nn.Module`: the parallelized module.
    """
    if device == 'cuda':
        model = model.cuda()
    elif device == 'mlu':
        assert digit_version(mmcv.__version__) >= digit_version('1.5.0'), \
            'Please use MMCV >= 1.5.0 for MLU training!'
        from mmcv.device.mlu import MLUDataParallel
        dp_factory['mlu'] = MLUDataParallel
        model = model.mlu()

    return dp_factory[device](model, dim=dim, *args, **kwargs)


def build_ddp(model, device='cuda', *args, **kwargs):
    """Build a DistributedDataParallel module by device type.

    If device is cuda, return an MMDistributedDataParallel module;
    if device is mlu, return an MLUDistributedDataParallel module.

    Args:
        model (:class:`nn.Module`): module to be parallelized.
        device (str): device type, mlu or cuda.

    Returns:
        :class:`nn.Module`: the parallelized module.

    References:
        .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
               DistributedDataParallel.html
    """
    assert device in ['cuda', 'mlu'], 'Only available for cuda or mlu devices.'
    if device == 'cuda':
        model = model.cuda()
    elif device == 'mlu':
        assert digit_version(mmcv.__version__) >= digit_version('1.5.0'), \
            'Please use MMCV >= 1.5.0 for MLU training!'
        from mmcv.device.mlu import MLUDistributedDataParallel
        ddp_factory['mlu'] = MLUDistributedDataParallel
        model = model.mlu()

    return ddp_factory[device](model, *args, **kwargs)


def is_mlu_available():
    """Return a bool indicating whether MLU is currently available."""
    return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()


def get_device():
    """Return an available device: cpu, cuda or mlu."""
    is_device_available = {
        'cuda': torch.cuda.is_available(),
        'mlu': is_mlu_available()
    }
    # Use the single available accelerator; fall back to cpu when none
    # (or, unexpectedly, more than one) is detected.
    device_list = [k for k, v in is_device_available.items() if v]
    return device_list[0] if len(device_list) == 1 else 'cpu'
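The commit message notes that mmseg/apis/train.py was updated to use these helpers. A minimal sketch of how a trainer might wire them together; the wrap_model helper, the distributed flag, and the cfg.gpu_ids field are illustrative assumptions, not lines from this diff:

import os

from mmseg.utils import build_ddp, build_dp, get_device


def wrap_model(model, cfg, distributed):
    # Hypothetical helper: pick the device, then wrap the model in the
    # matching (distributed) data-parallel class.
    device = get_device()  # 'cuda' or 'mlu' if exactly one is present, else 'cpu'
    if distributed:
        # LOCAL_RANK is exported by the distributed launcher.
        model = build_ddp(
            model,
            device,
            device_ids=[int(os.environ['LOCAL_RANK'])],
            broadcast_buffers=False)
    elif device == 'cuda':
        model = build_dp(model, device, device_ids=cfg.gpu_ids)
    else:
        model = build_dp(model, device)
    return model

Because build_dp and build_ddp register the MLU wrapper classes into the factories on demand, calling code like this never has to import or branch on MLUDataParallel itself.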
tests/test_utils/test_util_distribution.py (new file, 68 lines):
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import MagicMock, patch

import mmcv
import torch
import torch.nn as nn
from mmcv.parallel import (MMDataParallel, MMDistributedDataParallel,
                           is_module_wrapper)

from mmseg import digit_version
from mmseg.utils import build_ddp, build_dp


def mock(*args, **kwargs):
    pass


class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 2, 1)

    def forward(self, x):
        return self.conv(x)


# DDP internals are patched out below so the wrappers can be constructed
# without an initialized process group or real multi-device communication.
@patch('torch.distributed._broadcast_coalesced', mock)
@patch('torch.distributed.broadcast', mock)
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
def test_build_dp():
    model = Model()
    assert not is_module_wrapper(model)

    mmdp = build_dp(model, 'cpu')
    assert isinstance(mmdp, MMDataParallel)

    if torch.cuda.is_available():
        mmdp = build_dp(model, 'cuda')
        assert isinstance(mmdp, MMDataParallel)

    if digit_version(mmcv.__version__) >= digit_version('1.5.0'):
        from mmcv.device.mlu import MLUDataParallel
        from mmcv.utils import IS_MLU_AVAILABLE
        if IS_MLU_AVAILABLE:
            mludp = build_dp(model, 'mlu')
            assert isinstance(mludp, MLUDataParallel)


@patch('torch.distributed._broadcast_coalesced', mock)
@patch('torch.distributed.broadcast', mock)
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
def test_build_ddp():
    model = Model()
    assert not is_module_wrapper(model)

    if torch.cuda.is_available():
        mmddp = build_ddp(
            model, 'cuda', device_ids=[0], process_group=MagicMock())
        assert isinstance(mmddp, MMDistributedDataParallel)

    if digit_version(mmcv.__version__) >= digit_version('1.5.0'):
        from mmcv.device.mlu import MLUDistributedDataParallel
        from mmcv.utils import IS_MLU_AVAILABLE
        if IS_MLU_AVAILABLE:
            mluddp = build_ddp(
                model, 'mlu', device_ids=[0], process_group=MagicMock())
            assert isinstance(mluddp, MLUDistributedDataParallel)
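Both tests skip the CUDA and MLU branches when the corresponding hardware is absent, so they are safe on a plain CPU machine. One way to run just this module; the file path comes from the commit message, while the pytest invocation is the usual one rather than part of the diff:

import pytest

# Run only the new distribution-utility tests, verbosely.
pytest.main(['tests/test_utils/test_util_distribution.py', '-v'])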