
Commit 19b01e1

delock, Liangliang-Ma, tjruwase, and loadams authored
Add accelerator setup guides (#5827)
This document provides a place to hold accelerator setup guides. It is intended to be a single place to look up installation guides for different accelerators. Currently CPU and XPU setup guides are added to this document, and it could be extended to other accelerators. --------- Co-authored-by: Liangliang Ma <1906710196@qq.com> Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com> Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
1 parent a8d1b44 commit 19b01e1

File tree

2 files changed: +137 -64 lines changed


docs/_tutorials/accelerator-abstraction-interface.md

Lines changed: 3 additions & 64 deletions
@@ -12,7 +12,6 @@ tags: getting-started
- [Tensor operations](#tensor-operations)
- [Communication backend](#communication-backend)
- [Run DeepSpeed model on different accelerators](#run-deepspeed-model-on-different-accelerators)
- [Run DeepSpeed model on CPU](#run-deepspeed-model-on-cpu)
- [Implement new accelerator extension](#implement-new-accelerator-extension)

# Introduction
@@ -79,69 +78,9 @@ torch.distributed.init_process_group(get_accelerator().communication_backend_nam
```

# Run DeepSpeed model on different accelerators
Once a model is ported to the DeepSpeed Accelerator Abstraction Interface, we can run it on different accelerators using an extension to DeepSpeed. DeepSpeed checks whether a certain extension is installed in the environment to decide whether to use the accelerator backend from that extension. For example, to run a model on an Intel GPU, install _Intel Extension for DeepSpeed_ following the instructions at this [link](https://github.com/intel/intel-extension-for-deepspeed/).

After the extension is installed, install DeepSpeed and run the model; the model will run on top of DeepSpeed. Because DeepSpeed installation is also accelerator-related, it is recommended to install the DeepSpeed accelerator extension before installing DeepSpeed.

`CUDA_Accelerator` is the default accelerator in DeepSpeed. If no other DeepSpeed accelerator extension is installed, `CUDA_Accelerator` will be used.

When running a model on different accelerators in a cloud environment, the recommended practice is to provision a separate environment for each accelerator with tools such as _anaconda/miniconda/virtualenv_, and to load the matching environment when running on a given accelerator.

Note that different accelerators may have different 'flavors' of float16 or bfloat16. It is therefore recommended to make the model configurable for both float16 and bfloat16, so that model code does not need to change when running on different accelerators.
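As a rough illustration of that recommendation, here is a minimal sketch (not from the original tutorial) that picks the dtype through the accelerator abstraction, assuming the accelerator object exposes `is_bf16_supported()` and `is_fp16_supported()`:

```
import torch
from deepspeed.accelerator import get_accelerator

acc = get_accelerator()
# Prefer bf16 where the accelerator supports it, fall back to fp16, then fp32,
# so the same model code runs unchanged across accelerators.
if acc.is_bf16_supported():
    dtype = torch.bfloat16
elif acc.is_fp16_supported():
    dtype = torch.float16
else:
    dtype = torch.float32

model = torch.nn.Linear(16, 16).to(device=acc.device_name(), dtype=dtype)
print("running on", acc.device_name(), "with", dtype)
```
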
# Run DeepSpeed model on CPU
DeepSpeed supports using the CPU as an accelerator. A DeepSpeed model using the DeepSpeed Accelerator Abstraction Interface can run on CPU without changes to model code. DeepSpeed detects whether _Intel Extension for PyTorch_ is installed in the environment. If this package is installed, DeepSpeed will use the CPU as the accelerator; otherwise the CUDA device will be used.

To run a DeepSpeed model on CPU, use the following steps to prepare the environment:

```
python -m pip install intel_extension_for_pytorch
python -m pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable-cpu
git clone https://github.com/oneapi-src/oneCCL
cd oneCCL
mkdir build
cd build
cmake ..
make
make install
```

Before running a CPU workload, source the oneCCL environment variables:
```
source <path-to-oneCCL>/build/_install/env/setvars.sh
```

After the environment is prepared, launch DeepSpeed inference with the following command:
```
deepspeed --bind_cores_to_rank <deepspeed-model-script>
```

This command launches a number of workers equal to the number of CPU sockets on the system. Currently DeepSpeed supports running inference models with AutoTP on CPU. The argument `--bind_cores_to_rank` distributes the CPU cores on the system evenly among workers, so that each worker runs on a dedicated set of CPU cores.

On a CPU system, there might be daemon processes that activate periodically and increase the variance of each worker. One practice is to leave a couple of cores for daemon processes using the `--bind_core_list` argument:

```
deepspeed --bind_cores_to_rank --bind_core_list 0-51,56-107 <deepspeed-model-script>
```

The command above leaves 4 cores on each socket to daemon processes (assuming two sockets with 56 cores each).

We can also set an arbitrary number of workers. Unlike GPUs, CPU cores on a host can be further divided into subgroups. When this number is not set, DeepSpeed detects the number of NUMA nodes on the system and launches one worker per NUMA node.

```
deepspeed --num_accelerators 4 --bind_cores_to_rank <deepspeed-model-script>
```

Launching a DeepSpeed model on multiple CPU nodes is similar to other accelerators. Specify `impi` as the launcher and `--bind_cores_to_rank` for better core binding. Also set the `slots` number according to the number of CPU sockets in the host file.

```
# hostfile content should follow the format
# worker-1-hostname slots=<#sockets>
# worker-2-hostname slots=<#sockets>
# ...

deepspeed --hostfile=<hostfile> --bind_cores_to_rank --launcher impi --master_addr <master-ip> <deepspeed-model-script>
```
[Accelerator Setup Guide](accelerator-setup-guide.md) provides a guide on how to set up different accelerators for DeepSpeed. It also comes with simple examples of how to run DeepSpeed on different accelerators. The following guides are provided:
1. Run DeepSpeed model on CPU
2. Run DeepSpeed model on XPU

# Implement new accelerator extension
It is possible to implement a new DeepSpeed accelerator extension to support a new accelerator in DeepSpeed. An example to follow is _[Intel Extension For DeepSpeed](https://github.com/intel/intel-extension-for-deepspeed/)_. An accelerator extension contains the following components:
docs/_tutorials/accelerator-setup-guide.md

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
---
title: DeepSpeed Accelerator Setup Guides
tags: getting-started
---

# Contents
- [Contents](#contents)
- [Introduction](#introduction)
- [Intel Architecture (IA) CPU](#intel-architecture-ia-cpu)
- [Intel XPU](#intel-xpu)

# Introduction
DeepSpeed supports different accelerators from different companies. Setup steps to run DeepSpeed on certain accelerators might differ. This guide lets users look up setup instructions for the accelerator family and hardware they are using.

# Intel Architecture (IA) CPU
DeepSpeed supports CPUs with the Intel Architecture instruction set. It is recommended that the CPU supports at least the AVX2 instruction set; the AMX instruction set is also recommended.

DeepSpeed has been verified on the following CPU processors:
* 4th Gen Intel® Xeon® Scalable Processors
* 5th Gen Intel® Xeon® Scalable Processors
* 6th Gen Intel® Xeon® Scalable Processors
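As a quick, hedged way to check the AVX2/AMX recommendation above on a Linux host, one can inspect the CPU flags (a small helper sketch, not part of the official setup steps):

```
# Parse /proc/cpuinfo (Linux only) and report whether AVX2 / AMX are advertised.
def cpu_flags():
    with open("/proc/cpuinfo") as f:
        for line in f:
            if line.startswith("flags"):
                return set(line.split(":")[1].split())
    return set()

flags = cpu_flags()
print("AVX2:", "avx2" in flags)
print("AMX :", "amx_tile" in flags)  # amx_bf16 / amx_int8 flag the AMX data types
```
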
## Installation steps for Intel Architecture CPU
To install DeepSpeed on Intel Architecture CPU, use the following steps:
1. Install gcc compiler
DeepSpeed requires gcc-9 or above to build kernels on Intel Architecture CPU; install gcc-9 or above.

2. Install numactl
DeepSpeed uses `numactl` for fine-grained CPU core allocation for load-balancing; install numactl on your system.
For example, on an Ubuntu system, use the following command:
`sudo apt-get install numactl`

3. Install PyTorch
`pip install torch`

4. Install DeepSpeed
`pip install deepspeed`
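After these steps, a quick sanity check such as the sketch below (the file name and expected output are illustrative, not part of the official guide) can confirm that DeepSpeed imports and report which accelerator it selected:

```
# check_install.py (hypothetical name): verify the installation from the steps above.
import torch
import deepspeed
from deepspeed.accelerator import get_accelerator

print("torch      :", torch.__version__)
print("deepspeed  :", deepspeed.__version__)
print("accelerator:", get_accelerator()._name)        # expected "cpu" when the CPU accelerator is active
print("device     :", get_accelerator().device_name())
```
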
## How to launch DeepSpeed on Intel Architecture CPU
DeepSpeed can be launched on Intel Architecture CPU with the default deepspeed command. However, for compute-intensive workloads, Intel Architecture CPU works best when each worker process runs on a different set of physical CPU cores, so that workers do not compete with each other for CPU cores. To bind cores to each worker (rank), use the following command line switch for better performance.
```
deepspeed --bind_cores_to_rank <deepspeed-model-script>
```
This switch automatically detects the number of CPU NUMA nodes on the host, launches the same number of workers, and binds each worker to the cores and memory of a different NUMA node. This improves performance by ensuring that workers do not interfere with each other and that all memory allocation is from local memory.
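One way to observe the effect of `--bind_cores_to_rank` is a small script like the following sketch (the file name is hypothetical), which prints the cores each rank was pinned to:

```
# check_affinity.py: launch with `deepspeed --bind_cores_to_rank check_affinity.py`
import os

rank = int(os.environ.get("RANK", 0))       # set by the DeepSpeed launcher
cores = sorted(os.sched_getaffinity(0))     # cores this process may run on (Linux only)
print(f"rank {rank}: {len(cores)} cores -> {cores[0]}-{cores[-1]}")
```
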
If a user wishes to have more control over the number of workers and the specific cores used by the workload, the following command line switches can be used.
```
deepspeed --num_accelerators <number-of-workers> --bind_cores_to_rank --bind_core_list <comma-separated-dash-range> <deepspeed-model-script>
```
For example:
```
deepspeed --num_accelerators 4 --bind_cores_to_rank --bind_core_list <0-27,32-59> inference.py
```
This starts 4 workers for the workload. The core list range is divided evenly between the 4 workers: worker 0 takes cores 0-13, worker 1 takes 14-27, worker 2 takes 32-45, and worker 3 takes 46-59. Cores 28-31 and 60-63 are left out because there might be background processes running on the system, and leaving some cores idle reduces performance jitter and straggler effects.
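The arithmetic behind that split can be illustrated with a small sketch; this mirrors the even division described above and is not the launcher's actual implementation:

```
# Divide a core-list string such as "0-27,32-59" evenly across workers.
def split_core_list(core_list, num_workers):
    cores = []
    for part in core_list.split(","):
        lo, hi = (int(x) for x in part.split("-"))
        cores.extend(range(lo, hi + 1))
    per_worker = len(cores) // num_workers
    return [cores[i * per_worker:(i + 1) * per_worker] for i in range(num_workers)]

for rank, chunk in enumerate(split_core_list("0-27,32-59", 4)):
    print(f"worker {rank}: cores {chunk[0]}-{chunk[-1]}")
# worker 0: cores 0-13, worker 1: 14-27, worker 2: 32-45, worker 3: 46-59
```
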
Launching a DeepSpeed model on multiple CPU nodes is similar to other accelerators. Specify `impi` as the launcher and `--bind_cores_to_rank` for better core binding. Also set the `slots` number according to the number of CPU sockets in the host file.

```
# hostfile content should follow the format
# worker-1-hostname slots=<#sockets>
# worker-2-hostname slots=<#sockets>
# ...

deepspeed --hostfile=<hostfile> --bind_cores_to_rank --launcher impi --master_addr <master-ip> <deepspeed-model-script>
```

## Install with Intel Extension for PyTorch and oneCCL
Although not mandatory, Intel Extension for PyTorch and Intel oneCCL provide better optimizations for LLM models. Intel oneCCL also provides optimizations when running an LLM model on multiple nodes. To use DeepSpeed with Intel Extension for PyTorch and oneCCL, use the following steps:
1. Install Intel Extension for PyTorch. This is suggested if you want to get better LLM inference performance on CPU.
`pip install intel-extension-for-pytorch`

The following steps install the oneCCL binding for PyTorch. This is suggested if you are running DeepSpeed on multiple CPU nodes, for better communication performance. On a single node with multiple CPU sockets, these steps are not needed.

2. Install Intel oneCCL binding for PyTorch
`python -m pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable-cpu`

3. Install Intel oneCCL. This will be used to build direct oneCCL kernels (CCLBackend kernels)
```
pip install oneccl-devel
pip install impi-devel
```
Then set the environment variables for Intel oneCCL (assuming a conda environment).
```
export CPATH=${CONDA_PREFIX}/include:$CPATH
export CCL_ROOT=${CONDA_PREFIX}
export I_MPI_ROOT=${CONDA_PREFIX}
export LD_LIBRARY_PATH=${CONDA_PREFIX}/lib/ccl/cpu:${CONDA_PREFIX}/lib/libfabric:${CONDA_PREFIX}/lib
```
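Once the oneCCL pieces are installed and the variables above are exported, a quick check of the communication backend name that DeepSpeed will use can look like the sketch below; the reported name is expected to be `ccl`, but that is an assumption and may vary by DeepSpeed version:

```
# Report the communication backend the current accelerator will use.
from deepspeed.accelerator import get_accelerator

print("accelerator          :", get_accelerator()._name)
print("communication backend:", get_accelerator().communication_backend_name())  # assumed "ccl" here
```
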
## Optimize LLM inference with Intel Extension for PyTorch
Intel Extension for PyTorch is compatible with DeepSpeed AutoTP tensor-parallel inference. It allows CPU inference to benefit from both DeepSpeed Automatic Tensor Parallelism and the LLM optimizations of Intel Extension for PyTorch. To use Intel Extension for PyTorch, after calling deepspeed.init_inference, call
```
ipex_model = ipex.llm.optimize(deepspeed_model)
```
to get the model optimized by Intel Extension for PyTorch.
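A slightly fuller sketch of that flow is shown below. The checkpoint name, tensor-parallel size, and dtype are placeholders, and passing the engine's `.module` to `ipex.llm.optimize` is one common pattern rather than the only one:

```
# Hypothetical end-to-end example: AutoTP sharding, then IPEX LLM optimizations.
import torch
import deepspeed
import intel_extension_for_pytorch as ipex
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "facebook/opt-1.3b"                      # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# Shard the model with Automatic Tensor Parallelism across the launched ranks.
deepspeed_model = deepspeed.init_inference(model,
                                           tensor_parallel={"tp_size": 2},
                                           dtype=torch.bfloat16,
                                           replace_with_kernel_inject=False)

# Apply Intel Extension for PyTorch LLM optimizations on top of the sharded model.
ipex_model = ipex.llm.optimize(deepspeed_model.module, dtype=torch.bfloat16, inplace=True)

inputs = tokenizer("DeepSpeed is", return_tensors="pt")
with torch.no_grad():
    output = ipex_model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

Such a script would typically be launched with `deepspeed --bind_cores_to_rank <script>` as described earlier.
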
## More examples for using DeepSpeed with Intel Extension for PyTorch on Intel Architecture CPU
Refer to https://github.com/intel/intel-extension-for-pytorch/tree/main/examples/cpu/inference/python/llm for a more extensive guide.

# Intel XPU
The DeepSpeed XPU accelerator supports the Intel® Data Center GPU Max Series.

DeepSpeed has been verified on the following GPU products:
* Intel® Data Center GPU Max 1100
* Intel® Data Center GPU Max 1550

## Installation steps for Intel XPU
To install DeepSpeed on Intel XPU, use the following steps:
1. Install oneAPI base toolkit \
The Intel® oneAPI Base Toolkit (Base Kit) is a core set of tools and libraries, including a DPC++/C++ compiler for building DeepSpeed XPU kernels such as fusedAdam and CPUAdam, as well as the high-performance computation libraries required by IPEX.
For easy download, usage and more details, check [Intel oneAPI base-toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html).
2. Install PyTorch, Intel Extension for PyTorch, and Intel oneCCL Bindings for PyTorch. These packages are required by `xpu_accelerator` for torch functionality and performance, as well as for the communication backend on Intel platforms. The recommended installation reference is:
https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu.

3. Install DeepSpeed \
`pip install deepspeed`

## How to use DeepSpeed on Intel XPU
DeepSpeed can be launched on Intel XPU with the deepspeed launch command. Before that, the user needs to activate the oneAPI environment by: \
`source <oneAPI installed path>/setvars.sh`

To validate XPU availability and that the XPU accelerator is correctly chosen, here is an example:
```
$ python
>>> import torch; print('torch:', torch.__version__)
torch: 2.3.0
>>> import intel_extension_for_pytorch; print('XPU available:', torch.xpu.is_available())
XPU available: True
>>> from deepspeed.accelerator import get_accelerator; print('accelerator:', get_accelerator()._name)
accelerator: xpu
```
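As a small follow-up check (a sketch, not part of the original guide), a tensor can be placed on the XPU through the same accelerator abstraction and exercised with a simple op:

```
# Allocate a tensor on the first XPU device via the accelerator abstraction.
import torch
import intel_extension_for_pytorch  # registers the 'xpu' device with PyTorch
from deepspeed.accelerator import get_accelerator

device = get_accelerator().device_name(0)   # e.g. "xpu:0"
x = torch.ones(4, 4, device=device)
print("device:", x.device, "| sum of x @ x:", (x @ x).sum().item())  # expected 64.0
```
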
## More examples for using DeepSpeed on Intel XPU
Refer to https://github.com/intel/intel-extension-for-pytorch/tree/release/xpu/2.1.40/examples/gpu/inference/python/llm for a more extensive guide.
