Make Ascend NPU available (#3831)
* Make Ascend NPU available

NPU accelerator support was introduced in #3595.
This commit provides two enhancements:
  1. Add a new accelerator_name, 'npu', which can be selected via the
DS_ACCELERATOR environment variable or auto-detected (a usage sketch
follows this list).
  2. Optimize the auto-detect code in get_accelerator to avoid deeply
nested exception handling.
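
A rough usage sketch (not part of this commit; it assumes DeepSpeed and torch_npu
are installed) of both selection paths:

    # Hedged sketch: explicit override vs. auto detection.
    import os
    os.environ["DS_ACCELERATOR"] = "npu"      # explicit override; requires torch_npu

    from deepspeed.accelerator import get_accelerator
    print(get_accelerator().device_name())    # expected to print "npu"

    # Without DS_ACCELERATOR set, get_accelerator() probes installed extensions
    # (intel_extension_for_deepspeed, intel_extension_for_pytorch, torch_npu,
    # torch.mps) and falls back to "cuda".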

* Use DS_ACCELERATOR_LIST for overriding accelerators

When an accelerator override is detected, an error message lists all
supported accelerators; build that message from an accelerator list
instead of hard-coding accelerator names.

Also fix a code formatting issue (yapf).

* Add HCCL backend

HCCL is the distributed backend for Ascend NPU; it is already implemented in
the NPU plugin for PyTorch (https://gitee.com/ascend/pytorch). Add the HCCL
backend as a not-yet-implemented backend to avoid a "not supported" warning.
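
A hedged sketch (not part of this commit) of how HCCL is typically reached through
the torch_npu plugin rather than through DeepSpeed; it assumes torch_npu is installed
and that the launcher has set the usual rank/world-size environment variables:

    # Hedged sketch: HCCL is provided by the torch_npu plugin, not implemented in DeepSpeed.
    import torch
    import torch_npu                         # registers NPU devices and the "hccl" backend
    import torch.distributed as dist

    dist.init_process_group(backend="hccl")  # assumes RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are set
    t = torch.ones(1).to("npu")              # illustrative tensor on the NPU device
    dist.all_reduce(t)                       # collective runs over HCCL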

* Add NPUNotImplementedBuilder

Ascend NPU does not implement any op yet; leaving the npu folder empty would
throw NoneType[op_name] when an unsupported op is called. Add
NPUNotImplementedBuilder as the default builder.
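
A hedged sketch of the intended behaviour (the op name below is illustrative):

    # Hedged sketch: on NPU, every op lookup resolves to the placeholder builder, so callers
    # get a clear error instead of "'NoneType' object is not subscriptable".
    from deepspeed.accelerator import get_accelerator

    builder = get_accelerator().create_op_builder("TransformerBuilder")  # illustrative op name
    builder.load()  # raises ValueError because the op is not implemented on the NPU backend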

* Optimize builder search logic

1. cpu and other backends implement their ops in sub-directories under
op_builder; cuda_accelerator should skip these sub-directories.
2. Each backend will have its own NotImplementedBuilder; add a device
prefix to this class to distinguish them.

* Change the unimplemented builder name to be the same for each backend
hipudding authored Jul 22, 2023
1 parent 19d5c03 commit 23a11a3
Showing 8 changed files with 143 additions and 51 deletions.
9 changes: 6 additions & 3 deletions accelerator/cuda_accelerator.py
@@ -236,9 +236,12 @@ def _lazy_init_class_dict(self):
# put all valid class name <--> class type mapping into class_dict
op_builder_dir = self.op_builder_dir()
op_builder_module = importlib.import_module(op_builder_dir)
for _, module_name, _ in pkgutil.iter_modules([os.path.dirname(op_builder_module.__file__)]):
# avoid self references
if module_name != 'all_ops' and module_name != 'builder' and module_name != 'cpu':
op_builder_absolute_path = os.path.dirname(op_builder_module.__file__)
for _, module_name, _ in pkgutil.iter_modules([op_builder_absolute_path]):
# avoid self references,
# skip sub-directories which contain ops for other backends (cpu, npu, etc.).
if module_name != 'all_ops' and module_name != 'builder' and not os.path.isdir(
os.path.join(op_builder_absolute_path, module_name)):
module = importlib.import_module("{}.{}".format(op_builder_dir, module_name))
for member_name in module.__dir__():
if member_name.endswith(
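
For reference, a minimal standalone sketch of the filtering rule introduced above
(an illustration only: it presumes DeepSpeed is installed and that
deepspeed.ops.op_builder is the directory being scanned):

    # Hedged sketch: keep top-level modules in op_builder, skip self references
    # and per-backend sub-directories such as cpu/ and npu/.
    import importlib
    import os
    import pkgutil

    pkg = importlib.import_module("deepspeed.ops.op_builder")
    pkg_dir = os.path.dirname(pkg.__file__)
    for _, module_name, _ in pkgutil.iter_modules([pkg_dir]):
        if module_name in ("all_ops", "builder") or os.path.isdir(os.path.join(pkg_dir, module_name)):
            continue           # skipped: self references and backend sub-directories
        print(module_name)     # remaining modules may define op builder classes
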
42 changes: 15 additions & 27 deletions accelerator/npu_accelerator.py
@@ -3,13 +3,8 @@

# DeepSpeed Team

import torch
from .abstract_accelerator import DeepSpeedAccelerator
# During setup stage torch may not be installed, pass on no torch will
# allow op builder related API to be executed.
try:
import torch.npu
except ImportError:
pass


class NPU_Accelerator(DeepSpeedAccelerator):
@@ -209,32 +204,25 @@ def op_builder_dir(self):
except ImportError:
return "deepspeed.ops.op_builder.npu"

# dict that holds class name <--> class type mapping i.e.
# 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>
# this dict will be filled at init stage
class_dict = None

def _lazy_init_class_dict(self):
if self.class_dict != None:
return
else:
self.class_dict = {}

# create an instance of op builder and return, name specified by class_name
def create_op_builder(self, class_name):
self._lazy_init_class_dict()
if class_name in self.class_dict:
return self.class_dict[class_name]()
else:
return None
builder_class = self.get_op_builder(class_name)
if builder_class != None:
return builder_class()
return None

# return an op builder class, name specified by class_name
def get_op_builder(self, class_name):
self._lazy_init_class_dict()
if class_name in self.class_dict:
return self.class_dict[class_name]
else:
return None
try:
# is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
# if successful this also means we're doing a local install and not JIT compile path
from op_builder import __deepspeed__ # noqa: F401
from op_builder.npu import NotImplementedBuilder
except ImportError:
from deepspeed.ops.op_builder.npu import NotImplementedBuilder

# return an NPUNotImplementedBuilder to avoid getting NoneType[Name] in unit tests
return NotImplementedBuilder

def build_extension(self):
from torch.utils.cpp_extension import BuildExtension
67 changes: 46 additions & 21 deletions accelerator/real_accelerator.py
@@ -50,7 +50,7 @@ def get_accelerator():
accelerator_name = None
ds_set_method = None
# 1. Detect whether there is override of DeepSpeed accelerators from environment variable.
# DS_ACCELERATOR = 'cuda'|'xpu'|'cpu'
DS_ACCELERATOR_LIST = ['cuda', 'cpu', 'xpu', 'npu', 'mps']
if "DS_ACCELERATOR" in os.environ.keys():
accelerator_name = os.environ["DS_ACCELERATOR"]
if accelerator_name == "xpu":
@@ -65,7 +65,11 @@
except ImportError as e:
raise ValueError(
f"CPU_Accelerator requires intel_extension_for_pytorch, which is not installed on this system.")
elif accelerator_name == "cuda":
elif accelerator_name == "npu":
try:
import torch_npu # noqa: F401
except ImportError as e:
raise ValueError(f"NPU_Accelerator requires torch_npu, which is not installed on this system.")
pass
elif accelerator_name == "mps":
try:
@@ -75,41 +79,58 @@
torch.mps.current_allocated_memory()
except (RuntimeError, ImportError) as e:
raise ValueError(f"MPS_Accelerator requires torch.mps, which is not installed on this system.")
elif accelerator_name == "cuda":
pass
else:
raise ValueError(
f'DS_ACCELERATOR must be one of "cuda", "cpu", or "xpu". Value "{accelerator_name}" is not supported')
f'DS_ACCELERATOR must be one of {DS_ACCELERATOR_LIST}. Value "{accelerator_name}" is not supported')
ds_set_method = "override"

# 2. If no override, detect which accelerator to use automatically
if accelerator_name == None:
# We need a way to choose among different accelerator types.
# Currently we detect which accelerator extension is installed
# in the environment and use it if the installing answer is True.
# An alternative might be detect whether CUDA device is installed on
# the system but this comes with two pitfalls:
# 1. the system may not have torch pre-installed, so
# get_accelerator().is_available() may not work.
# 2. Some scenario like install on login node (without CUDA device)
# and run on compute node (with CUDA device) may cause mismatch
# between installation time and runtime.

try:
from intel_extension_for_deepspeed import XPU_Accelerator # noqa: F401,F811

accelerator_name = "xpu"
except ImportError as e:
# We need a way to choose between CUDA_Accelerator and CPU_Accelerator
# Currently we detect whether intel_extension_for_pytorch is installed
# in the environment and use CPU_Accelerator if the answer is True.
# An alternative might be detect whether CUDA device is installed on
# the system but this comes with two pitfalls:
# 1. the system may not have torch pre-installed, so
# get_accelerator().is_available() may not work.
# 2. Some scenario like install on login node (without CUDA device)
# and run on compute node (with CUDA device) may cause mismatch
# between installation time and runtime.
pass
if accelerator_name == None:
try:
import intel_extension_for_pytorch # noqa: F401,F811

accelerator_name = "cpu"
except ImportError as e:
try:
import torch.mps

# should use torch.mps.is_available() if it exists someday but this is used as proxy
torch.mps.current_allocated_memory()
accelerator_name = "mps"
except (RuntimeError, ImportError) as e:
accelerator_name = "cuda"
pass
if accelerator_name == None:
try:
import torch_npu # noqa: F401,F811

accelerator_name = "npu"
except ImportError as e:
pass
if accelerator_name == None:
try:
import torch.mps

# should use torch.mps.is_available() if it exists someday but this is used as proxy
torch.mps.current_allocated_memory()
accelerator_name = "mps"
except (RuntimeError, ImportError) as e:
pass
if accelerator_name == None:
accelerator_name = "cuda"

ds_set_method = "auto detect"

# 3. Set ds_accelerator accordingly
@@ -124,6 +145,10 @@ def get_accelerator():
elif accelerator_name == "xpu":
# XPU_Accelerator is already imported in detection stage
ds_accelerator = XPU_Accelerator()
elif accelerator_name == "npu":
from .npu_accelerator import NPU_Accelerator

ds_accelerator = NPU_Accelerator()
elif accelerator_name == "mps":
from .mps_accelerator import MPS_Accelerator

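
A condensed sketch (not the actual implementation) of the auto-detect priority this
diff establishes; module-presence checks below stand in for the try/except imports
used in real_accelerator.py:

    # Hedged sketch of the detection order: xpu -> cpu (IPEX) -> npu -> mps -> cuda fallback.
    import importlib.util

    def sketch_auto_detect():
        if importlib.util.find_spec("intel_extension_for_deepspeed"):
            return "xpu"
        if importlib.util.find_spec("intel_extension_for_pytorch"):
            return "cpu"
        if importlib.util.find_spec("torch_npu"):
            return "npu"
        try:
            import torch.mps
            torch.mps.current_allocated_memory()  # proxy check that MPS is actually usable
            return "mps"
        except (RuntimeError, ImportError):
            pass
        return "cuda"
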
8 changes: 8 additions & 0 deletions deepspeed/comm/comm.py
@@ -50,6 +50,7 @@
nccl_backend = None
mpi_backend = None
ccl_backend = None
hccl_backend = None

# This should be set here so all rank/size information from the launcher can be propagated
from deepspeed.comm.utils import *
@@ -144,6 +145,7 @@ def init_deepspeed_backend(ds_backend, timeout, init_method):
global nccl_backend
global mpi_backend
global ccl_backend
global hccl_backend

rank = int(os.getenv('RANK', '-1'))
size = int(os.getenv('WORLD_SIZE', '-1'))
@@ -157,6 +159,8 @@ def init_deepspeed_backend(ds_backend, timeout, init_method):
elif ds_backend == CCL_BACKEND:
ccl_backend = CCLBackend(rank=rank, world_size=size, timeout=timeout, init_method=init_method)
utils.logger.info(f"Initialize {ds_backend} backend")
elif ds_backend == HCCL_BACKEND:
utils.logger.warn("HCCL backend in DeepSpeed not yet implemented")
else:
utils.logger.debug(f"DeepSpeed does not support {ds_backend} backend")

@@ -196,6 +200,7 @@ def set_backend():
global nccl_backend
global mpi_backend
global ccl_backend
global hccl_backend

backend_name = get_accelerator().communication_backend_name()

@@ -208,6 +213,9 @@ def set_backend():
elif backend_name == CCL_BACKEND:
if ccl_backend is not None and ccl_backend.is_initialized():
cdb = ccl_backend
elif backend_name == HCCL_BACKEND:
if hccl_backend is not None and hccl_backend.is_initialized():
cdb = hccl_backend


@timed_op
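
A hedged usage sketch of how the accelerator's backend name feeds this dispatch
(assumptions: DeepSpeed is installed and the NPU accelerator is active and reports
"hccl"); with this commit the HCCL path only logs the warning shown above:

    # Hedged sketch: the communication backend name flows from the accelerator into comm init.
    import deepspeed
    from deepspeed.accelerator import get_accelerator

    backend = get_accelerator().communication_backend_name()  # presumably "hccl" on Ascend NPU
    deepspeed.init_distributed(dist_backend=backend)          # HCCL currently hits the warning path only
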
1 change: 1 addition & 0 deletions deepspeed/comm/constants.py
@@ -8,6 +8,7 @@
MPI_BACKEND = 'mpi'
GLOO_BACKEND = 'gloo'
SCCL_BACKEND = 'sccl'
HCCL_BACKEND = 'hccl'

DEFAULT_AML_MASTER_PORT = "54965"
DEFAULT_AML_NCCL_SOCKET_IFNAME = "^docker0,lo"
9 changes: 9 additions & 0 deletions op_builder/npu/__init__.py
@@ -0,0 +1,9 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team'''

# NPU related operators will be added in the future.

from .no_impl import NotImplementedBuilder
34 changes: 34 additions & 0 deletions op_builder/npu/builder.py
@@ -0,0 +1,34 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

try:
# is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
# if successful this also means we're doing a local install and not JIT compile path
from op_builder import __deepspeed__ # noqa: F401
from op_builder.builder import OpBuilder
except ImportError:
from deepspeed.ops.op_builder.builder import OpBuilder


class NPUOpBuilder(OpBuilder):

def builder(self):
from torch.utils.cpp_extension import CppExtension as ExtensionBuilder

compile_args = {'cxx': self.strip_empty_entries(self.cxx_args())}

cpp_ext = ExtensionBuilder(name=self.absolute_name(),
sources=self.strip_empty_entries(self.sources()),
include_dirs=self.strip_empty_entries(self.include_paths()),
libraries=self.strip_empty_entries(self.libraries_args()),
extra_compile_args=compile_args)

return cpp_ext

def cxx_args(self):
return []

def libraries_args(self):
return []
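
For context, a hypothetical sketch of how a future NPU op could build on NPUOpBuilder;
nothing like this ships in this commit, and every name and path below is made up:

    # Hypothetical file that would sit next to builder.py in op_builder/npu/ (illustration only).
    from .builder import NPUOpBuilder


    class ExampleNPUOpBuilder(NPUOpBuilder):
        BUILD_VAR = "DS_BUILD_EXAMPLE_NPU_OP"   # made-up build flag
        NAME = "example_npu_op"                 # made-up op name

        def __init__(self):
            super().__init__(name=self.NAME)

        def absolute_name(self):
            return f'deepspeed.ops.example.{self.NAME}_op'

        def sources(self):
            return ['csrc/npu/example_op.cpp']  # made-up source path

        def include_paths(self):
            return ['csrc/npu/includes']        # made-up include path
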
24 changes: 24 additions & 0 deletions op_builder/npu/no_impl.py
@@ -0,0 +1,24 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .builder import NPUOpBuilder


class NotImplementedBuilder(NPUOpBuilder):
BUILD_VAR = "DS_BUILD_NOT_IMPLEMENTED"
NAME = "deepspeed_not_implemented"

def __init__(self, name=None):
name = self.NAME if name is None else name
super().__init__(name=name)

def absolute_name(self):
return f'deepspeed.ops.comm.{self.NAME}_op'

def load(self, verbose=True):
raise ValueError("This op has not been implemented on NPU backend.")

def sources(self):
return []
