feat(mlu): Support PyTorch backend on MLU. (#1515)
* feat(mlu): Support PyTorch backend on MLU.

* fix redundant device variable.

* Update mmseg/apis/train.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* Update comments.

* Update mmseg/apis/train.py

* Update is_mlu_available flag.

* align util_distribution.py to mmdet.

* align util_distribution.py to mmdet.

* add build_dp, build_ddp testcase.

* Update mmseg/utils/util_distribution.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmseg/utils/util_distribution.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmseg/utils/util_distribution.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update tests/test_utils/test_util_distribution.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update tests/test_utils/test_util_distribution.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update tests/test_utils/test_util_distribution.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update tests/test_utils/test_util_distribution.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* add mmcv version check for mlu device.

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
3 people committed May 25, 2022
1 parent aa50358 commit 7628a61
Showing 7 changed files with 176 additions and 18 deletions.
19 changes: 11 additions & 8 deletions mmseg/apis/train.py
@@ -1,20 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import random
import warnings

import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
build_runner, get_dist_info)
from mmcv.utils import build_from_cfg

from mmseg import digit_version
from mmseg.core import DistEvalHook, EvalHook, build_optimizer
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.utils import find_latest_checkpoint, get_root_logger
from mmseg.utils import (build_ddp, build_dp, find_latest_checkpoint,
get_root_logger)


def init_random_seed(seed=None, device='cuda'):
@@ -99,21 +100,23 @@ def train_segmentor(model,
train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})}
data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]

# put model on gpus
# put model on devices
if distributed:
find_unused_parameters = cfg.get('find_unused_parameters', False)
# Sets the `find_unused_parameters` parameter in
# torch.nn.parallel.DistributedDataParallel
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
# DDP wrapper
model = build_ddp(
model,
cfg.device,
device_ids=[int(os.environ['LOCAL_RANK'])],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
else:
if not torch.cuda.is_available():
assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \
'Please use MMCV >= 1.4.4 for CPU training!'
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)

# build runner
optimizer = build_optimizer(model, cfg.optimizer)

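Worth noting for readers of this hunk: the DDP branch now takes the device index from the LOCAL_RANK environment variable instead of torch.cuda.current_device(), so the same call works on MLU. A minimal sketch of that assumption (hypothetical toy lines; LOCAL_RANK is expected to be exported per worker by the distributed launcher, e.g. torchrun or torch.distributed.launch):

import os

local_rank = int(os.environ['LOCAL_RANK'])  # one value per worker, set by the launcher
# build_ddp(model, cfg.device, device_ids=[local_rank], ...) then pins each
# replica to its local CUDA or MLU device.
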
4 changes: 3 additions & 1 deletion mmseg/datasets/samplers/distributed_sampler.py
@@ -7,6 +7,7 @@
from torch.utils.data import DistributedSampler as _DistributedSampler

from mmseg.core.utils import sync_random_seed
from mmseg.utils import get_device


class DistributedSampler(_DistributedSampler):
@@ -41,7 +42,8 @@ def __init__(self,
# in the same order based on the same seed. Then different ranks
# could use different indices to select non-overlapped data from the
# same data list.
self.seed = sync_random_seed(seed)
device = get_device()
self.seed = sync_random_seed(seed, device)

def __iter__(self) -> Iterator:
"""
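For context, the sampler's seed handshake now runs on whatever device get_device() reports. A hedged sketch of the call in isolation (sync_random_seed, as implemented in mmseg.core.utils, simply returns a seed when the world size is 1, so this also works single-process):

from mmseg.core.utils import sync_random_seed
from mmseg.utils import get_device

device = get_device()                       # 'cuda', 'mlu' or 'cpu'
seed = sync_random_seed(seed=None, device=device)
# all ranks now share one seed, broadcast over the detected device
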
3 changes: 2 additions & 1 deletion mmseg/utils/__init__.py
@@ -3,8 +3,9 @@
from .logger import get_root_logger
from .misc import find_latest_checkpoint
from .set_env import setup_multi_processes
from .util_distribution import build_ddp, build_dp, get_device

__all__ = [
'get_root_logger', 'collect_env', 'find_latest_checkpoint',
'setup_multi_processes'
'setup_multi_processes', 'build_ddp', 'build_dp', 'get_device'
]
81 changes: 81 additions & 0 deletions mmseg/utils/util_distribution.py
@@ -0,0 +1,81 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel

from mmseg import digit_version

dp_factory = {'cuda': MMDataParallel, 'cpu': MMDataParallel}

ddp_factory = {'cuda': MMDistributedDataParallel}


def build_dp(model, device='cuda', dim=0, *args, **kwargs):
"""build DataParallel module by device type.
if device is cuda, return a MMDataParallel module; if device is mlu,
return a MLUDataParallel module.
Args:
model (:class:`nn.Module`): module to be parallelized.
device (str): device type, cuda, cpu or mlu. Defaults to cuda.
dim (int): Dimension used to scatter the data. Defaults to 0.
Returns:
:class:`nn.Module`: parallelized module.
"""
if device == 'cuda':
model = model.cuda()
elif device == 'mlu':
assert digit_version(mmcv.__version__) >= digit_version('1.5.0'), \
'Please use MMCV >= 1.5.0 for MLU training!'
from mmcv.device.mlu import MLUDataParallel
dp_factory['mlu'] = MLUDataParallel
model = model.mlu()

return dp_factory[device](model, dim=dim, *args, **kwargs)


def build_ddp(model, device='cuda', *args, **kwargs):
"""Build DistributedDataParallel module by device type.
If device is cuda, return a MMDistributedDataParallel module;
if device is mlu, return a MLUDistributedDataParallel module.
Args:
model (:class:`nn.Module`): module to be parallelized.
device (str): device type, mlu or cuda.
Returns:
:class:`nn.Module`: parallelized module.
References:
.. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
DistributedDataParallel.html
"""
assert device in ['cuda', 'mlu'], 'Only available for cuda or mlu devices.'
if device == 'cuda':
model = model.cuda()
elif device == 'mlu':
assert digit_version(mmcv.__version__) >= digit_version('1.5.0'), \
'Please use MMCV >= 1.5.0 for MLU training!'
from mmcv.device.mlu import MLUDistributedDataParallel
ddp_factory['mlu'] = MLUDistributedDataParallel
model = model.mlu()

return ddp_factory[device](model, *args, **kwargs)


def is_mlu_available():
"""Returns a bool indicating if MLU is currently available."""
return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()


def get_device():
"""Returns an available device, cpu, cuda or mlu."""
is_device_available = {
'cuda': torch.cuda.is_available(),
'mlu': is_mlu_available()
}
device_list = [k for k, v in is_device_available.items() if v]
return device_list[0] if len(device_list) == 1 else 'cpu'
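
A short usage sketch of the new helpers (hedged: the nn.Conv2d is a stand-in for a real segmentor, and the build_ddp call is left commented out because it additionally requires an initialised process group):

import torch.nn as nn

from mmseg.utils import build_ddp, build_dp, get_device

model = nn.Conv2d(3, 3, 1)        # stand-in for a real segmentor
device = get_device()             # 'cuda', 'mlu' or 'cpu'

# Data-parallel wrapper; accepts cuda, mlu and cpu.
model = build_dp(model, device, device_ids=[0])

# Distributed wrapper; only cuda and mlu are accepted.
# model = build_ddp(model, device, device_ids=[0], broadcast_buffers=False)
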
68 changes: 68 additions & 0 deletions tests/test_utils/test_util_distribution.py
@@ -0,0 +1,68 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import MagicMock, patch

import mmcv
import torch
import torch.nn as nn
from mmcv.parallel import (MMDataParallel, MMDistributedDataParallel,
is_module_wrapper)

from mmseg import digit_version
from mmseg.utils import build_ddp, build_dp


def mock(*args, **kwargs):
pass


class Model(nn.Module):

def __init__(self):
super().__init__()
self.conv = nn.Conv2d(2, 2, 1)

def forward(self, x):
return self.conv(x)


@patch('torch.distributed._broadcast_coalesced', mock)
@patch('torch.distributed.broadcast', mock)
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
def test_build_dp():
model = Model()
assert not is_module_wrapper(model)

mmdp = build_dp(model, 'cpu')
assert isinstance(mmdp, MMDataParallel)

if torch.cuda.is_available():
mmdp = build_dp(model, 'cuda')
assert isinstance(mmdp, MMDataParallel)

if digit_version(mmcv.__version__) >= digit_version('1.5.0'):
from mmcv.device.mlu import MLUDataParallel
from mmcv.utils import IS_MLU_AVAILABLE
if IS_MLU_AVAILABLE:
mludp = build_dp(model, 'mlu')
assert isinstance(mludp, MLUDataParallel)


@patch('torch.distributed._broadcast_coalesced', mock)
@patch('torch.distributed.broadcast', mock)
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
def test_build_ddp():
model = Model()
assert not is_module_wrapper(model)

if torch.cuda.is_available():
mmddp = build_ddp(
model, 'cuda', device_ids=[0], process_group=MagicMock())
assert isinstance(mmddp, MMDistributedDataParallel)

if digit_version(mmcv.__version__) >= digit_version('1.5.0'):
from mmcv.device.mlu import MLUDistributedDataParallel
from mmcv.utils import IS_MLU_AVAILABLE
if IS_MLU_AVAILABLE:
mluddp = build_ddp(
model, 'mlu', device_ids=[0], process_group=MagicMock())
assert isinstance(mluddp, MLUDistributedDataParallel)
13 changes: 7 additions & 6 deletions tools/test.py
@@ -9,7 +9,6 @@
import mmcv
import torch
from mmcv.cnn.utils import revert_sync_batchnorm
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
wrap_fp16_model)
from mmcv.utils import DictAction
@@ -18,7 +17,7 @@
from mmseg.apis import multi_gpu_test, single_gpu_test
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.models import build_segmentor
from mmseg.utils import setup_multi_processes
from mmseg.utils import build_ddp, build_dp, get_device, setup_multi_processes


def parse_args():
@@ -260,6 +259,7 @@ def main():
else:
tmpdir = None

cfg.device = get_device()
if not distributed:
warnings.warn(
'SyncBN is only supported with DDP. To be compatible with DP, '
@@ -269,7 +269,7 @@
assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \
'Please use MMCV >= 1.4.4 for CPU training!'
model = revert_sync_batchnorm(model)
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)
results = single_gpu_test(
model,
data_loader,
@@ -281,9 +281,10 @@
format_only=args.format_only or eval_on_format_results,
format_args=eval_kwargs)
else:
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
model = build_ddp(
model,
cfg.device,
device_ids=[int(os.environ['LOCAL_RANK'])],
broadcast_buffers=False)
results = multi_gpu_test(
model,
6 changes: 4 additions & 2 deletions tools/train.py
@@ -17,7 +17,8 @@
from mmseg.apis import init_random_seed, set_random_seed, train_segmentor
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.utils import collect_env, get_root_logger, setup_multi_processes
from mmseg.utils import (collect_env, get_device, get_root_logger,
setup_multi_processes)


def parse_args():
@@ -184,7 +185,8 @@ def main():
logger.info(f'Config:\n{cfg.pretty_text}')

# set random seeds
seed = init_random_seed(args.seed)
cfg.device = get_device()
seed = init_random_seed(args.seed, device=cfg.device)
seed = seed + dist.get_rank() if args.diff_seed else seed
logger.info(f'Set random seed to {seed}, '
f'deterministic: {args.deterministic}')
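The seeding flow above, in isolation (a hedged sketch; init_random_seed draws one seed on rank 0 and broadcasts it on the given device, and degrades to a plain random seed when the run is not distributed):

from mmseg.apis import init_random_seed, set_random_seed
from mmseg.utils import get_device

device = get_device()
seed = init_random_seed(None, device=device)   # shared across ranks when distributed
set_random_seed(seed, deterministic=False)
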
