Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 241 additions & 0 deletions accelerator/npu_accelerator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .abstract_accelerator import DeepSpeedAccelerator
# During setup stage torch may not be installed, pass on no torch will
# allow op builder related API to be executed.
try:
import torch.npu
except ImportError:
pass


class NPU_Accelerator(DeepSpeedAccelerator):

def __init__(self):
self._name = 'npu'
self._communication_backend_name = 'hccl'

def is_synchronized_device(self):
return False

# Device APIs
def device_name(self, device_index=None):
if device_index == None:
return 'npu'
return 'npu:{}'.format(device_index)

def device(self, device_index=None):
return torch.npu.device(device_index)

def set_device(self, device_index):
torch.npu.set_device(device_index)

def current_device(self):
return torch.npu.current_device()

def current_device_name(self):
return 'npu:{}'.format(torch.npu.current_device())

def device_count(self):
return torch.npu.device_count()

def synchronize(self, device_index=None):
return torch.npu.synchronize(device_index)

# RNG APIs
def random(self):
return torch.random

def set_rng_state(self, new_state, device_index=None):
if device_index is None:
return torch.npu.set_rng_state(new_state)

return torch.npu.set_rng_state(new_state, device_index)

def get_rng_state(self, device_index=None):
if device_index is None:
return torch.npu.get_rng_state()

return torch.npu.get_rng_state(device_index)

def manual_seed(self, seed):
return torch.npu.manual_seed(seed)

def manual_seed_all(self, seed):
return torch.npu.manual_seed_all(seed)

def initial_seed(self, seed):
return torch.npu.initial_seed(seed)

def default_generator(self, device_index):
return torch.npu.default_generators[device_index]

# Streams/Events
@property
def Stream(self):
return torch.npu.Stream

def stream(self, stream):
return torch.npu.stream(stream)

def current_stream(self, device_index=None):
return torch.npu.current_stream(device_index)

def default_stream(self, device_index=None):
return torch.npu.default_stream(device_index)

@property
def Event(self):
return torch.npu.Event

# Memory management
def empty_cache(self):
return torch.npu.empty_cache()

def memory_allocated(self, device_index=None):
return torch.npu.memory_allocated(device_index)

def max_memory_allocated(self, device_index=None):
return torch.npu.max_memory_allocated(device_index)

def reset_max_memory_allocated(self, device_index=None):
return torch.npu.reset_max_memory_allocated(device_index)

def memory_cached(self, device_index=None):
return torch.npu.memory_cached(device_index)

def max_memory_cached(self, device_index=None):
return torch.npu.max_memory_cached(device_index)

def reset_max_memory_cached(self, device_index=None):
return torch.npu.reset_max_memory_cached(device_index)

def memory_stats(self, device_index=None):
if hasattr(torch.npu, 'memory_stats'):
return torch.npu.memory_stats(device_index)

def reset_peak_memory_stats(self, device_index=None):
if hasattr(torch.npu, 'reset_peak_memory_stats'):
return torch.npu.reset_peak_memory_stats(device_index)

def memory_reserved(self, device_index=None):
if hasattr(torch.npu, 'memory_reserved'):
return torch.npu.memory_reserved(device_index)

def max_memory_reserved(self, device_index=None):
if hasattr(torch.npu, 'max_memory_reserved'):
return torch.npu.max_memory_reserved(device_index)

def total_memory(self, device_index=None):
return torch.npu.get_device_properties(device_index).total_memory

# Data types
def is_bf16_supported(self):
return torch.npu.is_bf16_supported()

def is_fp16_supported(self):
return True

# Misc
def amp(self):
if hasattr(torch.npu, 'amp'):
return torch.npu.amp
return None

def is_available(self):
return torch.npu.is_available()

def range_push(self, msg):
return

def range_pop(self):
return

def lazy_call(self, callback):
return torch.npu._lazy_call(callback)

def communication_backend_name(self):
return self._communication_backend_name

# Tensor operations

@property
def BFloat16Tensor(self):
return torch.npu.BFloat16Tensor

@property
def ByteTensor(self):
return torch.npu.ByteTensor

@property
def DoubleTensor(self):
return torch.npu.DoubleTensor

@property
def FloatTensor(self):
return torch.npu.FloatTensor

@property
def HalfTensor(self):
return torch.npu.HalfTensor

@property
def IntTensor(self):
return torch.npu.IntTensor

@property
def LongTensor(self):
return torch.npu.LongTensor

def pin_memory(self, tensor):
return tensor.pin_memory()

def on_accelerator(self, tensor):
device_str = str(tensor.device)
if device_str.startswith('npu:'):
return True
else:
return False

def op_builder_dir(self):
try:
# is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
# if successful this also means we're doing a local install and not JIT compile path
from op_builder import __deepspeed__ # noqa: F401
return "op_builder.npu"
except ImportError:
return "deepspeed.ops.op_builder.npu"

# dict that holds class name <--> class type mapping i.e.
# 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>
# this dict will be filled at init stage
class_dict = None

def _lazy_init_class_dict(self):
if self.class_dict != None:
return
else:
self.class_dict = {}

# create an instance of op builder and return, name specified by class_name
def create_op_builder(self, class_name):
self._lazy_init_class_dict()
if class_name in self.class_dict:
return self.class_dict[class_name]()
else:
return None

# return an op builder class, name specified by class_name
def get_op_builder(self, class_name):
self._lazy_init_class_dict()
if class_name in self.class_dict:
return self.class_dict[class_name]
else:
return None

def build_extension(self):
from torch.utils.cpp_extension import BuildExtension
return BuildExtension