-
Notifications
You must be signed in to change notification settings - Fork 171
[ Docs ] Overhaul accelerate
user guide
#76
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
39 commits
Select commit
Hold shift + click to select a range
e8227e5
update for fp8 dyanmic
4243c8f
Merge branch 'main' of https://github.com/vllm-project/llm-compressor
a0817a8
stash
0b6fd0f
updated accelerate examples
b4c60a8
updated big model examples
7ac50dc
revert fp8 changes
538610a
style and quality
2a0c245
fix main README
bfd64eb
anothger nit
b78a0e7
remove unnessary changes
94c4415
remove spurious changes
1edb96a
anoter silly change
8f6a39d
udpate
977d7f6
tweak language
b6f41d3
cleanup
82644d2
adjust title
f450c52
cleanup example more
25a9475
cleanup readme more
80333bd
cleanup
1558c24
update
f0ada4f
update
f89ccaf
update
ad27b4e
final cleanup
4ad73e9
update doc
e636000
update
a82e910
further cleanup
290d984
typo
3cac6a2
cleanup
d1e702a
make example inline
130d11a
more cleanup
6adec54
more nits
544ab37
update
92bca06
update examples
66ca93a
update
5930f90
tweak int8 example to make it run
7b31b06
update big model wording
24ce02e
Merge branch 'main' into switch-big-model-example
8f220e4
fix repeat in README
a9ffcae
revert readme to main
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# Quantizing Big Models with HF Accelerate | ||
|
||
`llmcompressor` integrates with `accelerate` to support quantizing large models such as Llama 70B and 405B, or quantizing any model with limited GPU resources. | ||
|
||
## Overview | ||
|
||
[`accelerate`]((https://huggingface.co/docs/accelerate/en/index)) is a highly useful library in the Hugging Face ecosystem that supports for working with large models, including: | ||
- Offloading parameters to CPU | ||
- Sharding models across multiple GPUs with pipeline-parallelism | ||
|
||
|
||
### Using `device_map` | ||
|
||
To enable `accelerate` features with `llmcompressor`, simple insert `device_map` in `from_pretrained` during model load. | ||
|
||
```python | ||
from llmcompressor.transformers import SparseAutoModelForCausalLM | ||
MODEL_ID = "meta-llama/Meta-Llama-3-70B-Instruct" | ||
|
||
# device_map="auto" triggers usage of accelerate | ||
# if > 1 GPU, the model will be sharded across the GPUs | ||
# if not enough GPU memory to fit the model, parameters are offloaded to the CPU | ||
model = SparseAutoModelForCausalLM.from_pretrained( | ||
MODEL_ID, device_map="auto", torch_dtype="auto") | ||
``` | ||
|
||
`llmcompressor` is designed to respect the `device_map`, so calls to `oneshot` | ||
will work properly out of the box for basic quantization with `QuantizationModifier`, | ||
even for CPU offloaded models. | ||
|
||
To enable CPU offloading for second-order quantization methods such as GPTQ, we need to | ||
allocate additional memory upfront when computing the device map. Note that this | ||
device map will only compatible with `GPTQModifier(sequential_update=True, ...)` | ||
|
||
```python | ||
from llmcompressor.transformers.compression.helpers import calculate_offload_device_map | ||
from llmcompressor.transformers import SparseAutoModelForCausalLM, | ||
MODEL_ID = "meta-llama/Meta-Llama-3-70B-Instruct" | ||
|
||
# Load model, reserving memory in the device map for sequential GPTQ (adjust num_gpus as needed) | ||
device_map = calculate_offload_device_map(MODEL_ID, reserve_for_hessians=True, num_gpus=1) | ||
model = SparseAutoModelForCausalLM.from_pretrained( | ||
MODEL_ID, | ||
device_map=device_map, | ||
torch_dtype="auto", | ||
) | ||
``` | ||
|
||
### Practical Advice | ||
|
||
When working with `accelerate`, it is important to keep in mind that CPU offloading and naive pipeline-parallelism will slow down forward passes through the model. As a result, we need to take care to ensure that the quantization methods used fit well with the offloading scheme as methods that require many forward passes though the model will be slowed down. | ||
|
||
General rules of thumb: | ||
- CPU offloading is best used with data-free quantization methods (e.g. PTQ with `FP8_DYNAMIC`) | ||
- Multi-GPU is fast enough to be used with calibration data-based methods with `sequential_update=False` | ||
- It is possible to use Multi-GPU with `sequential_update=True` to save GPU memory, but the runtime will be slower | ||
|
||
## Examples | ||
|
||
We will show working examples for each use case: | ||
- **CPU Offloading**: Quantize `Llama-70B` to `FP8` using `PTQ` with a single GPU | ||
- **Multi-GPU**: Quantize `Llama-70B` to `INT8` using `GPTQ` and `SmoothQuant` with 8 GPUs | ||
|
||
### Installation | ||
|
||
Install `llmcompressor`: | ||
|
||
```bash | ||
pip install llmcompressor==0.1.0 | ||
``` | ||
|
||
### CPU Offloading: `FP8` Quantization with `PTQ` | ||
|
||
CPU offloading is slow. As a result, we recommend using this feature only with data-free quantization methods. For example, when quantizing a model to `fp8`, we typically use simple `PTQ` to statically quantize the weights and use dynamic quantization for the activations. These methods do not require calibration data. | ||
|
||
- `cpu_offloading_fp8.py` demonstrates quantizing the weights and activations of `Llama-70B` to `fp8` on a single GPU: | ||
|
||
```bash | ||
export CUDA_VISIBLE_DEVICES=0 | ||
python cpu_offloading_fp8.py | ||
``` | ||
|
||
The resulting model `./Meta-Llama-3-70B-Instruct-FP8-Dynamic` is ready to run with `vllm`! | ||
|
||
### Multi-GPU: `INT8` Quantization with `GPTQ` | ||
|
||
For quantization methods that require calibration data (e.g. `GPTQ`), CPU offloading is too slow. For these methods, `llmcompressor` can use `accelerate` multi-GPU to quantize models that are larger than a single GPU. For example, when quantizing a model to `int8`, we typically use `GPTQ` to statically quantize the weights, which requires calibration data. | ||
|
||
Note that running non-sequential `GPTQ` requires significant additional memory beyond the model size. As a rough rule of thumb, running `GPTQModifier` non-sequentially will take up 3x the model size for a 16-bit model and 2x the model size for a 32-bit model (these estimates include the memory required to store the model itself in GPU). | ||
|
||
- `multi_gpu_int8.py` demonstrates quantizing the weights and activations of `Llama-70B` to `int8` on 8 A100s: | ||
|
||
```python | ||
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 | ||
python multi_gpu_int8.py | ||
``` | ||
|
||
The resulting model `./Meta-Llama-3-70B-Instruct-INT8-Dynamic` is quantized and ready to run with `vllm`! | ||
|
||
## Questions or Feature Request? | ||
|
||
Please open up an issue on `vllm-project/llm-compressor` |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from transformers import AutoTokenizer | ||
|
||
from llmcompressor.modifiers.quantization import QuantizationModifier | ||
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot | ||
|
||
MODEL_ID = "meta-llama/Meta-Llama-3-70B-Instruct" | ||
OUTPUT_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic" | ||
|
||
# Load model | ||
# Note: device_map="auto" will offload to CPU if not enough space on GPU. | ||
model = SparseAutoModelForCausalLM.from_pretrained( | ||
MODEL_ID, device_map="auto", torch_dtype="auto" | ||
) | ||
|
||
# Configure the quantization scheme and algorithm (PTQ + FP8_DYNAMIC). | ||
recipe = QuantizationModifier( | ||
targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"] | ||
) | ||
|
||
# Apply quantization and save in `compressed-tensors` format. | ||
oneshot( | ||
model=model, | ||
recipe=recipe, | ||
tokenizer=AutoTokenizer.from_pretrained(MODEL_ID), | ||
output_dir=OUTPUT_DIR, | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.