# [Feature] Support mmseg with NPU backend. (#2768)
## Motivation

Added Ascend NPU device support in mmseg.

## Modification

The main modification points are as follows: we added NPU device support in both the DP and DDP scenarios when training on an NPU.
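
For illustration, here is a minimal sketch (not part of this commit) of how the new device dispatch is expected to be used; it assumes `get_device` and `build_dp` are exported from `mmseg.utils`, as they are for the existing CUDA/MLU backends:

```python
import torch.nn as nn

from mmseg.utils import build_dp, get_device

# Stand-in module; a real segmentor built via build_segmentor() is wrapped the same way.
model = nn.Conv2d(3, 8, kernel_size=3)

device = get_device()  # 'npu' is preferred when torch.npu reports availability
if device != 'cpu':
    # build_dp picks the matching *DataParallel class (NPUDataParallel for
    # device == 'npu') and moves the model onto that device.
    model = build_dp(model, device, device_ids=[0])
```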

## BC-breaking (Optional)

None

## Use cases (Optional)

We tested with [fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/unet/fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py).
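
For reference, a hedged sketch of an equivalent single-device training run on an Ascend machine; the `work_dir` value is illustrative, and the config fields set here mirror what the standard training script normally fills in:

```python
from mmcv import Config

from mmseg.apis import train_segmentor
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.utils import get_device

cfg = Config.fromfile(
    'configs/unet/fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py')
cfg.work_dir = './work_dirs/fcn_unet_npu'  # illustrative output directory
cfg.gpu_ids = [0]
cfg.seed = None
cfg.device = get_device()  # resolves to 'npu' when torch.npu is available

datasets = [build_dataset(cfg.data.train)]
model = build_segmentor(cfg.model)

# train_segmentor wraps the model via build_dp/build_ddp and, on NPU, falls
# back to an Fp16OptimizerHook with dynamic loss scaling when the config does
# not set an optimizer hook itself.
train_segmentor(model, datasets, cfg, distributed=False, validate=False)
```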
luomaoling authored Mar 23, 2023
1 parent 49f2a71 commit ae78cb9
Showing 3 changed files with 51 additions and 4 deletions.
11 changes: 10 additions & 1 deletion mmseg/apis/train.py
@@ -136,6 +136,11 @@ def train_segmentor(model,
             logger=logger,
             meta=meta))
 
+    if cfg.device == 'npu':
+        optimizer_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
+        cfg.optimizer_config = optimizer_config if \
+            not cfg.optimizer_config else cfg.optimizer_config
+
     # register hooks
     runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                    cfg.checkpoint_config, cfg.log_config,
@@ -187,8 +192,12 @@ def train_segmentor(model,
         resume_from = find_latest_checkpoint(cfg.work_dir)
         if resume_from is not None:
             cfg.resume_from = resume_from
+
     if cfg.resume_from:
-        runner.resume(cfg.resume_from)
+        if cfg.device == 'npu':
+            runner.resume(cfg.resume_from, map_location='npu')
+        else:
+            runner.resume(cfg.resume_from)
     elif cfg.load_from:
         runner.load_checkpoint(cfg.load_from)
     runner.run(data_loaders, cfg.workflow)
29 changes: 26 additions & 3 deletions mmseg/utils/util_distribution.py
@@ -33,6 +33,14 @@ def build_dp(model, device='cuda', dim=0, *args, **kwargs):
         dp_factory['mlu'] = MLUDataParallel
         model = model.mlu()
 
+    elif device == 'npu':
+        assert digit_version(mmcv.__version__) >= digit_version('1.7.0'), \
+            'Please use MMCV >= 1.7.0 for NPU training!'
+        from mmcv.device.npu import NPUDataParallel
+        torch.npu.set_compile_mode(jit_compile=False)
+        dp_factory['npu'] = NPUDataParallel
+        model = model.npu()
+
     return dp_factory[device](model, dim=dim, *args, **kwargs)


@@ -53,7 +61,8 @@ def build_ddp(model, device='cuda', *args, **kwargs):
         .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
                  DistributedDataParallel.html
     """
-    assert device in ['cuda', 'mlu'], 'Only available for cuda or mlu devices.'
+    assert device in ['cuda', 'mlu', 'npu'], 'Only available for cuda, '\
+        'npu or mlu devices.'
     if device == 'cuda':
         model = model.cuda()
     elif device == 'mlu':
@@ -63,6 +72,14 @@ def build_ddp(model, device='cuda', *args, **kwargs):
         ddp_factory['mlu'] = MLUDistributedDataParallel
         model = model.mlu()
 
+    elif device == 'npu':
+        assert digit_version(mmcv.__version__) >= digit_version('1.7.0'), \
+            'Please use MMCV >= 1.7.0 for NPU training!'
+        from mmcv.device.npu import NPUDistributedDataParallel
+        torch.npu.set_compile_mode(jit_compile=False)
+        ddp_factory['npu'] = NPUDistributedDataParallel
+        model = model.npu()
+
     return ddp_factory[device](model, *args, **kwargs)


@@ -71,11 +88,17 @@ def is_mlu_available():
     return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
 
 
+def is_npu_available():
+    """Returns a bool indicating if NPU is currently available."""
+    return hasattr(torch, 'npu') and torch.npu.is_available()
+
+
 def get_device():
-    """Returns an available device, cpu, cuda or mlu."""
+    """Returns an available device, cpu, npu, cuda or mlu."""
     is_device_available = {
+        'npu': is_npu_available(),
         'cuda': torch.cuda.is_available(),
         'mlu': is_mlu_available()
     }
     device_list = [k for k, v in is_device_available.items() if v]
-    return device_list[0] if len(device_list) == 1 else 'cpu'
+    return device_list[0] if len(device_list) >= 1 else 'cpu'
15 changes: 15 additions & 0 deletions tests/test_utils/test_util_distribution.py
@@ -46,6 +46,13 @@ def test_build_dp():
             mludp = build_dp(model, 'mlu')
             assert isinstance(mludp, MLUDataParallel)
 
+    if digit_version(mmcv.__version__) >= digit_version('1.7.0'):
+        from mmcv.device.npu import NPUDataParallel
+        from mmcv.utils import IS_NPU_AVAILABLE
+        if IS_NPU_AVAILABLE:
+            npu_dp = build_dp(model, 'npu')
+            assert isinstance(npu_dp, NPUDataParallel)


@patch('torch.distributed._broadcast_coalesced', mock)
@patch('torch.distributed.broadcast', mock)
@@ -66,3 +73,11 @@ def test_build_ddp():
             mluddp = build_ddp(
                 model, 'mlu', device_ids=[0], process_group=MagicMock())
             assert isinstance(mluddp, MLUDistributedDataParallel)
+
+    if digit_version(mmcv.__version__) >= digit_version('1.7.0'):
+        from mmcv.device.npu import NPUDistributedDataParallel
+        from mmcv.utils import IS_NPU_AVAILABLE
+        if IS_NPU_AVAILABLE:
+            npu_ddp = build_ddp(
+                model, 'npu', device_ids=[0], process_group=MagicMock())
+            assert isinstance(npu_ddp, NPUDistributedDataParallel)
