
Commit 6110f70

Merge branch 'master' into olruwase/set_zero_opt_grad
2 parents 18ee6cf + 8cded57


42 files changed (+2564, −47 lines)

.github/workflows/no-torch.yml

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ on:
   pull_request:
     paths:
       - '.github/workflows/no-torch.yml'
+      - 'op_builder/**'
   schedule:
     - cron: "0 0 * * *"
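Adding op_builder/** to the trigger paths runs the no-torch check whenever op builder code changes; that check verifies the op builder APIs import and run without torch installed (the same constraint noted in the accelerator source below). A minimal, hypothetical sketch of the guarded-import pattern this relies on — the names here are illustrative, not from this commit:

    # Hypothetical sketch: op-builder-style code guards its torch import so
    # that metadata queries still work in an environment without torch.
    try:
        import torch
    except ImportError:
        torch = None  # tolerated during setup; builder APIs must still load


    def builder_is_usable():
        # Illustrative helper: answer conservatively when torch is absent.
        return torch is not None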

.github/workflows/xpu-compile.yml

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
name: xpu-compile
2+
3+
on:
4+
workflow_dispatch:
5+
schedule:
6+
- cron: "0 0 * * *"
7+
pull_request:
8+
paths:
9+
- ".github/workflows/xpu-compile.yml"
10+
11+
concurrency:
12+
group: ${{ github.workflow }}-${{ github.ref }}
13+
cancel-in-progress: true
14+
15+
permissions:
16+
contents: read
17+
issues: write
18+
19+
jobs:
20+
compile-tests:
21+
runs-on: [self-hosted, intel, xpu]
22+
container:
23+
image: intel/oneapi-basekit:2024.2.1-0-devel-ubuntu22.04
24+
ports:
25+
- 80
26+
options: --privileged -it --rm --device /dev/dri:/dev/dri -v /dev/dri/by-path:/dev/dri/by-path --ipc=host --cap-add=ALL
27+
28+
steps:
29+
- uses: actions/checkout@v4
30+
- name: Install prerequisite
31+
run: |
32+
apt-get update
33+
apt-get install clinfo libaio-dev python3-pip -y
34+
pip install torch==2.3.1 -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/torch/
35+
pip install intel-extension-for-pytorch==2.3.110+xpu -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/intel-extension-for-pytorch/
36+
pip install oneccl_bind_pt==2.3.100+xpu -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/oneccl-bind-pt/
37+
pip install torchvision==0.18.1 -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/torchvision/
38+
pip install https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v3.0.0b2/triton_xpu-3.0.0b2-cp310-cp310-linux_x86_64.whl
39+
pip install py-cpuinfo numpy
40+
pip install .[dev,autotuning]
41+
42+
- name: Check container state
43+
run: |
44+
ldd --version
45+
ds_report
46+
python3 -c "import torch; print('torch:', torch.__version__, torch)"
47+
python3 -c "import torch; import intel_extension_for_pytorch; print('XPU available:', torch.xpu.is_available())"
48+
python3 -c "from deepspeed.accelerator import get_accelerator; print('accelerator:', get_accelerator()._name)"
49+
pip list
50+
51+
- name: Compile Status
52+
shell: bash
53+
run: |
54+
export FI_HMEM=system
55+
ulimit -n 1048575
56+
cd tests/torch_compile
57+
export ZE_AFFINITY_MASK=0,1
58+
deepspeed test_compile.py --deepspeed_config ds_config.json 2>&1 | tee log.txt
59+
cat log.txt | grep "'graph_breaks'" | sed 's/,/ /g' | awk '{print $2}' >> $GITHUB_STEP_SUMMARY
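The final step greps the run log for a 'graph_breaks' counter and appends it to the job summary. test_compile.py itself is not part of this diff; as a hedged illustration only, graph breaks for a callable under torch.compile can be counted with torch._dynamo.explain (fn and sample_args are placeholders):

    # Sketch (assumes torch >= 2.1): count torch.compile graph breaks.
    import torch


    def count_graph_breaks(fn, *sample_args):
        # torch._dynamo.explain traces fn and reports the graphs produced,
        # the number of graph breaks, and the reason for each break.
        explanation = torch._dynamo.explain(fn)(*sample_args)
        print("graph_breaks:", explanation.graph_break_count)
        for reason in explanation.break_reasons:
            print("  break:", reason)
        return explanation.graph_break_count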

accelerator/mlu_accelerator.py

Lines changed: 300 additions & 0 deletions
@@ -0,0 +1,300 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import importlib
import inspect
import functools

from .abstract_accelerator import DeepSpeedAccelerator
import torch
# During the setup stage torch may not be installed; tolerating a missing
# torch allows the op-builder-related APIs to be executed.


class MLU_Accelerator(DeepSpeedAccelerator):

    def __init__(self):
        self._name = 'mlu'
        self._communication_backend_name = 'cncl'
        self._compile_backend = "inductor"
        self.class_dict = None

    def is_synchronized_device(self):
        return False

    def use_host_timers(self):
        return self.is_synchronized_device()

    def resolves_data_dependency(self):
        return self.is_synchronized_device()

    def handles_memory_backpressure(self):
        return self.is_synchronized_device()

    # Device APIs
    def device_name(self, device_index=None):
        if device_index is None:
            return 'mlu'
        return 'mlu:{}'.format(device_index)

    def device(self, device_index=None):
        return torch.mlu.device(device_index)

    def set_device(self, device_index):
        torch.mlu.set_device(device_index)

    def current_device(self):
        return torch.mlu.current_device()

    def current_device_name(self):
        return 'mlu:{}'.format(torch.mlu.current_device())

    def device_count(self):
        return torch.mlu.device_count()

    def synchronize(self, device_index=None):
        return torch.mlu.synchronize(device_index)

    # RNG APIs
    def random(self):
        return torch.random

    def set_rng_state(self, new_state, device_index=None):
        if device_index is None:
            return torch.mlu.set_rng_state(new_state)

        return torch.mlu.set_rng_state(new_state, device_index)

    def get_rng_state(self, device_index=None):
        if device_index is None:
            return torch.mlu.get_rng_state()

        return torch.mlu.get_rng_state(device_index)

    def manual_seed(self, seed):
        return torch.mlu.manual_seed(seed)

    def manual_seed_all(self, seed):
        return torch.mlu.manual_seed_all(seed)

    def initial_seed(self, seed):
        return torch.mlu.initial_seed(seed)

    def default_generator(self, device_index):
        return torch.mlu.default_generators[device_index]

    # Streams/Events
    @property
    def Stream(self):
        return torch.mlu.Stream

    def stream(self, stream):
        return torch.mlu.stream(stream)

    def current_stream(self, device_index=None):
        return torch.mlu.current_stream(device_index)

    def default_stream(self, device_index=None):
        return torch.mlu.default_stream(device_index)

    @property
    def Event(self):
        return torch.mlu.Event

    # Memory management
    def empty_cache(self):
        return torch.mlu.empty_cache()

    def memory_allocated(self, device_index=None):
        return torch.mlu.memory_allocated(device_index)

    def max_memory_allocated(self, device_index=None):
        return torch.mlu.max_memory_allocated(device_index)

    def reset_max_memory_allocated(self, device_index=None):
        return torch.mlu.reset_max_memory_allocated(device_index)

    def memory_cached(self, device_index=None):
        return torch.mlu.memory_cached(device_index)

    def max_memory_cached(self, device_index=None):
        return torch.mlu.max_memory_cached(device_index)

    def reset_max_memory_cached(self, device_index=None):
        return torch.mlu.reset_max_memory_cached(device_index)

    def memory_stats(self, device_index=None):
        if hasattr(torch.mlu, 'memory_stats'):
            return torch.mlu.memory_stats(device_index)

    def reset_peak_memory_stats(self, device_index=None):
        if hasattr(torch.mlu, 'reset_peak_memory_stats'):
            return torch.mlu.reset_peak_memory_stats(device_index)

    def memory_reserved(self, device_index=None):
        if hasattr(torch.mlu, 'memory_reserved'):
            return torch.mlu.memory_reserved(device_index)

    def max_memory_reserved(self, device_index=None):
        if hasattr(torch.mlu, 'max_memory_reserved'):
            return torch.mlu.max_memory_reserved(device_index)

    def total_memory(self, device_index=None):
        return torch.mlu.get_device_properties(device_index).total_memory

    def available_memory(self, device_index=None):
        return self.total_memory(device_index) - self.memory_allocated(device_index)

    # Data types
    def is_bf16_supported(self):
        return torch.mlu.is_bf16_supported()

    def is_fp16_supported(self):
        return True

    def supported_dtypes(self):
        supported_dtypes = [torch.float]
        if self.is_fp16_supported():
            supported_dtypes.append(torch.half)
        if self.is_bf16_supported():
            supported_dtypes.append(torch.bfloat16)
        return supported_dtypes

    # Misc
    def amp(self):
        if hasattr(torch.mlu, 'amp'):
            return torch.mlu.amp
        return None

    def is_available(self):
        return torch.mlu.is_available()

    def range_push(self, msg):
        if hasattr(torch.mlu.cnpx, 'range_push'):
            return torch.mlu.cnpx.range_push(msg)

    def range_pop(self):
        if hasattr(torch.mlu.cnpx, 'range_pop'):
            return torch.mlu.cnpx.range_pop()

    def lazy_call(self, callback):
        return torch.mlu._lazy_call(callback)

    def communication_backend_name(self):
        return self._communication_backend_name

    def is_triton_supported(self):
        return True

    # Graph operations
    def create_graph(self):
        # Return the graph object so it can be passed to capture_to_graph.
        return torch.mlu.MLUGraph()

    def capture_to_graph(self, graph, pool=None, stream=None):
        return torch.mlu.graph(graph, pool, stream)

    def replay_graph(self, graph):
        graph.replay()
        return

    # Tensor operations

    @property
    def BFloat16Tensor(self):
        return functools.partial(torch.tensor, dtype=torch.bfloat16, device='mlu')

    @property
    def ByteTensor(self):
        return functools.partial(torch.tensor, dtype=torch.uint8, device='mlu')

    @property
    def DoubleTensor(self):
        return functools.partial(torch.tensor, dtype=torch.double, device='mlu')

    @property
    def FloatTensor(self):
        return functools.partial(torch.tensor, dtype=torch.float, device='mlu')

    @property
    def HalfTensor(self):
        return functools.partial(torch.tensor, dtype=torch.half, device='mlu')

    @property
    def IntTensor(self):
        return functools.partial(torch.tensor, dtype=torch.int, device='mlu')

    @property
    def LongTensor(self):
        return functools.partial(torch.tensor, dtype=torch.long, device='mlu')

    def pin_memory(self, tensor):
        return tensor.pin_memory()

    def is_pinned(self, tensor):
        return tensor.is_pinned()

    def on_accelerator(self, tensor):
        device_str = str(tensor.device)
        if device_str.startswith('mlu:'):
            return True
        else:
            return False

    def op_builder_dir(self):
        try:
            # Is op_builder from deepspeed or a 3p version? This should only succeed if it's deepspeed.
            # If successful, this also means we're doing a local install and not the JIT compile path.
            from op_builder import __deepspeed__  # noqa: F401 # type: ignore
            return "op_builder.mlu"
        except ImportError:
            return "deepspeed.ops.op_builder.mlu"

    def _lazy_init_class_dict(self):
        if self.class_dict:
            return

        op_builder_module = importlib.import_module(self.op_builder_dir())

        # get op builder class from op_builder/mlu/__init__.py
        self.class_dict = {}
        for class_name, class_obj in inspect.getmembers(op_builder_module, inspect.isclass):
            self.class_dict[class_name] = class_obj

    # create an instance of op builder and return, name specified by class_name
    def create_op_builder(self, class_name):
        builder_class = self.get_op_builder(class_name)
        return builder_class()

    # return an op builder class, name specified by class_name
    def get_op_builder(self, class_name):
        self._lazy_init_class_dict()
        if class_name in self.class_dict:
            return self.class_dict[class_name]
        else:
            return self.class_dict['NotImplementedBuilder']

    def build_extension(self):
        from torch.utils.cpp_extension import BuildExtension
        return BuildExtension

    def export_envs(self):
        return ['NEUWARE_HOME', 'CNCL', 'LD_LIBRARY', 'PATH']

    def visible_devices_envs(self):
        return ['MLU_VISIBLE_DEVICES']

    def set_visible_devices_envs(self, current_env, local_accelerator_ids):
        for env in self.visible_devices_envs():
            current_env[env] = ",".join(map(str, local_accelerator_ids))

    def get_compile_backend(self):
        return self._compile_backend

    def set_compile_backend(self, backend):
        supported_backends = torch._dynamo.list_backends(exclude_tags=())
        if backend in supported_backends:
            self._compile_backend = backend
        else:
            raise ValueError(
                f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
