|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +# DeepSpeed Team |
| 5 | +import importlib |
| 6 | +import inspect |
| 7 | +import functools |
| 8 | + |
| 9 | +from .abstract_accelerator import DeepSpeedAccelerator |
| 10 | +import torch |
| 11 | +# During setup stage torch may not be installed, pass on no torch will |
| 12 | +# allow op builder related API to be executed. |
| 13 | + |
| 14 | + |
| 15 | +class MLU_Accelerator(DeepSpeedAccelerator): |
| 16 | + |
| 17 | + def __init__(self): |
| 18 | + self._name = 'mlu' |
| 19 | + self._communication_backend_name = 'cncl' |
| 20 | + self._compile_backend = "inductor" |
| 21 | + self.class_dict = None |
| 22 | + |
| 23 | + def is_synchronized_device(self): |
| 24 | + return False |
| 25 | + |
| 26 | + def use_host_timers(self): |
| 27 | + return self.is_synchronized_device() |
| 28 | + |
| 29 | + def resolves_data_dependency(self): |
| 30 | + return self.is_synchronized_device() |
| 31 | + |
| 32 | + def handles_memory_backpressure(self): |
| 33 | + return self.is_synchronized_device() |
| 34 | + |
| 35 | + # Device APIs |
| 36 | + def device_name(self, device_index=None): |
| 37 | + if device_index == None: |
| 38 | + return 'mlu' |
| 39 | + return 'mlu:{}'.format(device_index) |
| 40 | + |
| 41 | + def device(self, device_index=None): |
| 42 | + return torch.mlu.device(device_index) |
| 43 | + |
| 44 | + def set_device(self, device_index): |
| 45 | + torch.mlu.set_device(device_index) |
| 46 | + |
| 47 | + def current_device(self): |
| 48 | + return torch.mlu.current_device() |
| 49 | + |
| 50 | + def current_device_name(self): |
| 51 | + return 'mlu:{}'.format(torch.mlu.current_device()) |
| 52 | + |
| 53 | + def device_count(self): |
| 54 | + return torch.mlu.device_count() |
| 55 | + |
| 56 | + def synchronize(self, device_index=None): |
| 57 | + return torch.mlu.synchronize(device_index) |
| 58 | + |
| 59 | + # RNG APIs |
| 60 | + def random(self): |
| 61 | + return torch.random |
| 62 | + |
| 63 | + def set_rng_state(self, new_state, device_index=None): |
| 64 | + if device_index is None: |
| 65 | + return torch.mlu.set_rng_state(new_state) |
| 66 | + |
| 67 | + return torch.mlu.set_rng_state(new_state, device_index) |
| 68 | + |
| 69 | + def get_rng_state(self, device_index=None): |
| 70 | + if device_index is None: |
| 71 | + return torch.mlu.get_rng_state() |
| 72 | + |
| 73 | + return torch.mlu.get_rng_state(device_index) |
| 74 | + |
| 75 | + def manual_seed(self, seed): |
| 76 | + return torch.mlu.manual_seed(seed) |
| 77 | + |
| 78 | + def manual_seed_all(self, seed): |
| 79 | + return torch.mlu.manual_seed_all(seed) |
| 80 | + |
| 81 | + def initial_seed(self, seed): |
| 82 | + return torch.mlu.initial_seed(seed) |
| 83 | + |
| 84 | + def default_generator(self, device_index): |
| 85 | + return torch.mlu.default_generators[device_index] |
| 86 | + |
| 87 | + # Streams/Events |
| 88 | + @property |
| 89 | + def Stream(self): |
| 90 | + return torch.mlu.Stream |
| 91 | + |
| 92 | + def stream(self, stream): |
| 93 | + return torch.mlu.stream(stream) |
| 94 | + |
| 95 | + def current_stream(self, device_index=None): |
| 96 | + return torch.mlu.current_stream(device_index) |
| 97 | + |
| 98 | + def default_stream(self, device_index=None): |
| 99 | + return torch.mlu.default_stream(device_index) |
| 100 | + |
| 101 | + @property |
| 102 | + def Event(self): |
| 103 | + return torch.mlu.Event |
| 104 | + |
| 105 | + # Memory management |
| 106 | + def empty_cache(self): |
| 107 | + return torch.mlu.empty_cache() |
| 108 | + |
| 109 | + def memory_allocated(self, device_index=None): |
| 110 | + return torch.mlu.memory_allocated(device_index) |
| 111 | + |
| 112 | + def max_memory_allocated(self, device_index=None): |
| 113 | + return torch.mlu.max_memory_allocated(device_index) |
| 114 | + |
| 115 | + def reset_max_memory_allocated(self, device_index=None): |
| 116 | + return torch.mlu.reset_max_memory_allocated(device_index) |
| 117 | + |
| 118 | + def memory_cached(self, device_index=None): |
| 119 | + return torch.mlu.memory_cached(device_index) |
| 120 | + |
| 121 | + def max_memory_cached(self, device_index=None): |
| 122 | + return torch.mlu.max_memory_cached(device_index) |
| 123 | + |
| 124 | + def reset_max_memory_cached(self, device_index=None): |
| 125 | + return torch.mlu.reset_max_memory_cached(device_index) |
| 126 | + |
| 127 | + def memory_stats(self, device_index=None): |
| 128 | + if hasattr(torch.mlu, 'memory_stats'): |
| 129 | + return torch.mlu.memory_stats(device_index) |
| 130 | + |
| 131 | + def reset_peak_memory_stats(self, device_index=None): |
| 132 | + if hasattr(torch.mlu, 'reset_peak_memory_stats'): |
| 133 | + return torch.mlu.reset_peak_memory_stats(device_index) |
| 134 | + |
| 135 | + def memory_reserved(self, device_index=None): |
| 136 | + if hasattr(torch.mlu, 'memory_reserved'): |
| 137 | + return torch.mlu.memory_reserved(device_index) |
| 138 | + |
| 139 | + def max_memory_reserved(self, device_index=None): |
| 140 | + if hasattr(torch.mlu, 'max_memory_reserved'): |
| 141 | + return torch.mlu.max_memory_reserved(device_index) |
| 142 | + |
| 143 | + def total_memory(self, device_index=None): |
| 144 | + return torch.mlu.get_device_properties(device_index).total_memory |
| 145 | + |
| 146 | + def available_memory(self, device_index=None): |
| 147 | + return self.total_memory(device_index) - self.memory_allocated(device_index) |
| 148 | + |
| 149 | + # Data types |
| 150 | + def is_bf16_supported(self): |
| 151 | + return torch.mlu.is_bf16_supported() |
| 152 | + |
| 153 | + def is_fp16_supported(self): |
| 154 | + return True |
| 155 | + |
| 156 | + def supported_dtypes(self): |
| 157 | + supported_dtypes = [torch.float] |
| 158 | + if self.is_fp16_supported(): |
| 159 | + supported_dtypes.append(torch.half) |
| 160 | + if self.is_bf16_supported(): |
| 161 | + supported_dtypes.append(torch.bfloat16) |
| 162 | + return supported_dtypes |
| 163 | + |
| 164 | + # Misc |
| 165 | + def amp(self): |
| 166 | + if hasattr(torch.mlu, 'amp'): |
| 167 | + return torch.mlu.amp |
| 168 | + return None |
| 169 | + |
| 170 | + def is_available(self): |
| 171 | + return torch.mlu.is_available() |
| 172 | + |
| 173 | + def range_push(self, msg): |
| 174 | + if hasattr(torch.mlu.cnpx, 'range_push'): |
| 175 | + return torch.mlu.cnpx.range_push(msg) |
| 176 | + |
| 177 | + def range_pop(self): |
| 178 | + if hasattr(torch.mlu.cnpx, 'range_pop'): |
| 179 | + return torch.mlu.cnpx.range_pop() |
| 180 | + |
| 181 | + def lazy_call(self, callback): |
| 182 | + return torch.mlu._lazy_call(callback) |
| 183 | + |
| 184 | + def communication_backend_name(self): |
| 185 | + return self._communication_backend_name |
| 186 | + |
| 187 | + def is_triton_supported(self): |
| 188 | + return True |
| 189 | + |
| 190 | + # Graph operations |
| 191 | + def create_graph(self): |
| 192 | + torch.mlu.MLUGraph() |
| 193 | + |
| 194 | + def capture_to_graph(self, graph, pool=None, stream=None): |
| 195 | + return torch.mlu.graph(graph, pool, stream) |
| 196 | + |
| 197 | + def replay_graph(self, graph): |
| 198 | + graph.replay() |
| 199 | + return |
| 200 | + |
| 201 | + # Tensor operations |
| 202 | + |
| 203 | + @property |
| 204 | + def BFloat16Tensor(self): |
| 205 | + return functools.partial(torch.tensor, dtype=torch.bfloat16, device='mlu') |
| 206 | + |
| 207 | + @property |
| 208 | + def ByteTensor(self): |
| 209 | + return functools.partial(torch.tensor, dtype=torch.uint8, device='mlu') |
| 210 | + |
| 211 | + @property |
| 212 | + def DoubleTensor(self): |
| 213 | + return functools.partial(torch.tensor, dtype=torch.double, device='mlu') |
| 214 | + |
| 215 | + @property |
| 216 | + def FloatTensor(self): |
| 217 | + return functools.partial(torch.tensor, dtype=torch.float, device='mlu') |
| 218 | + |
| 219 | + @property |
| 220 | + def HalfTensor(self): |
| 221 | + return functools.partial(torch.tensor, dtype=torch.half, device='mlu') |
| 222 | + |
| 223 | + @property |
| 224 | + def IntTensor(self): |
| 225 | + return functools.partial(torch.tensor, dtype=torch.int, device='mlu') |
| 226 | + |
| 227 | + @property |
| 228 | + def LongTensor(self): |
| 229 | + return functools.partial(torch.tensor, dtype=torch.long, device='mlu') |
| 230 | + |
| 231 | + def pin_memory(self, tensor): |
| 232 | + return tensor.pin_memory() |
| 233 | + |
| 234 | + def is_pinned(self, tensor): |
| 235 | + return tensor.is_pinned() |
| 236 | + |
| 237 | + def on_accelerator(self, tensor): |
| 238 | + device_str = str(tensor.device) |
| 239 | + if device_str.startswith('mlu:'): |
| 240 | + return True |
| 241 | + else: |
| 242 | + return False |
| 243 | + |
| 244 | + def op_builder_dir(self): |
| 245 | + try: |
| 246 | + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed |
| 247 | + # if successful this also means we're doing a local install and not JIT compile path |
| 248 | + from op_builder import __deepspeed__ # noqa: F401 # type: ignore |
| 249 | + return "op_builder.mlu" |
| 250 | + except ImportError: |
| 251 | + return "deepspeed.ops.op_builder.mlu" |
| 252 | + |
| 253 | + def _lazy_init_class_dict(self): |
| 254 | + if self.class_dict: |
| 255 | + return |
| 256 | + |
| 257 | + op_builder_module = importlib.import_module(self.op_builder_dir()) |
| 258 | + |
| 259 | + # get op builder class from op_builder/mlu/__init__.py |
| 260 | + self.class_dict = {} |
| 261 | + for class_name, class_obj in inspect.getmembers(op_builder_module, inspect.isclass): |
| 262 | + self.class_dict[class_name] = class_obj |
| 263 | + |
| 264 | + # create an instance of op builder and return, name specified by class_name |
| 265 | + def create_op_builder(self, class_name): |
| 266 | + builder_class = self.get_op_builder(class_name) |
| 267 | + return builder_class() |
| 268 | + |
| 269 | + # return an op builder class, name specified by class_name |
| 270 | + def get_op_builder(self, class_name): |
| 271 | + self._lazy_init_class_dict() |
| 272 | + if class_name in self.class_dict: |
| 273 | + return self.class_dict[class_name] |
| 274 | + else: |
| 275 | + return self.class_dict['NotImplementedBuilder'] |
| 276 | + |
| 277 | + def build_extension(self): |
| 278 | + from torch.utils.cpp_extension import BuildExtension |
| 279 | + return BuildExtension |
| 280 | + |
| 281 | + def export_envs(self): |
| 282 | + return ['NEUWARE_HOME', 'CNCL', 'LD_LIBRARY', 'PATH'] |
| 283 | + |
| 284 | + def visible_devices_envs(self): |
| 285 | + return ['MLU_VISIBLE_DEVICES'] |
| 286 | + |
| 287 | + def set_visible_devices_envs(self, current_env, local_accelerator_ids): |
| 288 | + for env in self.visible_devices_envs(): |
| 289 | + current_env[env] = ",".join(map(str, local_accelerator_ids)) |
| 290 | + |
| 291 | + def get_compile_backend(self): |
| 292 | + return self._compile_backend |
| 293 | + |
| 294 | + def set_compile_backend(self, backend): |
| 295 | + supported_backends = torch._dynamo.list_backends(exclude_tags=()) |
| 296 | + if backend in supported_backends: |
| 297 | + self._compile_backend = backend |
| 298 | + else: |
| 299 | + raise ValueError( |
| 300 | + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends }") |
0 commit comments