
Commit 182f40e

Add NVIDIA TensorRT Model Optimizer in vLLM documentation (#17561)
1 parent 3e887d2

3 files changed: +90, -1 lines changed

docs/source/features/quantization/index.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -17,6 +17,7 @@ gptqmodel
 int4
 int8
 fp8
+modelopt
 quark
 quantized_kvcache
 torchao
```
docs/source/features/quantization/modelopt.md

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
# NVIDIA TensorRT Model Optimizer

The [NVIDIA TensorRT Model Optimizer](https://github.com/NVIDIA/TensorRT-Model-Optimizer) is a library designed to optimize models for inference with NVIDIA GPUs. It includes tools for Post-Training Quantization (PTQ) and Quantization Aware Training (QAT) of Large Language Models (LLMs), Vision Language Models (VLMs), and diffusion models.

We recommend installing the library with:

```console
pip install nvidia-modelopt
```
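As a quick sanity check, the following minimal sketch (the `__version__` attribute is assumed to be present, as is common for Python packages, rather than documented here) confirms that the package imports and exposes the FP8 configuration used later on this page:

```python
import modelopt
import modelopt.torch.quantization as mtq

# Report the installed version if the package exposes one (assumed attribute).
print(getattr(modelopt, "__version__", "unknown"))

# FP8_DEFAULT_CFG is the quantization config used in the PTQ example below.
print(hasattr(mtq, "FP8_DEFAULT_CFG"))
```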
## Quantizing HuggingFace Models with PTQ

You can quantize HuggingFace models using the example scripts provided in the TensorRT Model Optimizer repository. The primary script for LLM PTQ is typically found within the `examples/llm_ptq` directory.

Below is an example showing how to quantize a model using modelopt's PTQ API:
```python
import modelopt.torch.quantization as mtq
from transformers import AutoModelForCausalLM

# Load the model from HuggingFace
model = AutoModelForCausalLM.from_pretrained("<path_or_model_id>")

# Select the quantization config, for example, FP8
config = mtq.FP8_DEFAULT_CFG

# Define a forward loop function for calibration
# (calib_set is an iterable of calibration inputs; see the sketch below)
def forward_loop(model):
    for data in calib_set:
        model(data)

# PTQ with in-place replacement of quantized modules
model = mtq.quantize(model, config, forward_loop)
```
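The example above leaves `calib_set` undefined. The sketch below shows one way it might be constructed; the tokenizer and the handful of prompts are illustrative assumptions rather than part of the original example, and real calibration typically uses a few hundred samples from a representative dataset:

```python
from transformers import AutoTokenizer

# Illustrative calibration inputs (assumed); replace with data that is
# representative of your workload.
tokenizer = AutoTokenizer.from_pretrained("<path_or_model_id>")
calib_prompts = [
    "The quick brown fox jumps over the lazy dog.",
    "Quantization reduces inference cost on NVIDIA GPUs.",
]

# Each element is a batch of token IDs that forward_loop feeds to the model.
calib_set = [
    tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    for prompt in calib_prompts
]
```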
After the model is quantized, you can export it to a quantized checkpoint using the export API:
```python
import torch
from modelopt.torch.export import export_hf_checkpoint

with torch.inference_mode():
    export_hf_checkpoint(
        model,  # The quantized model.
        export_dir,  # The directory where the exported files will be stored.
    )
```
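`export_dir` above is simply a path of your choosing. As a usage note (the directory name and the tokenizer save below are illustrative assumptions, not requirements of the export API), keeping the tokenizer next to the exported weights makes the checkpoint directory self-contained:

```python
from transformers import AutoTokenizer

export_dir = "llama-3.1-8b-instruct-fp8"  # illustrative output directory

# Store the tokenizer alongside the exported quantized weights so the
# directory can later be passed directly to vLLM or pushed to the Hub.
AutoTokenizer.from_pretrained("<path_or_model_id>").save_pretrained(export_dir)
```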
The quantized checkpoint can then be deployed with vLLM. As an example, the following code shows how to deploy `nvidia/Llama-3.1-8B-Instruct-FP8`, the FP8 quantized checkpoint derived from `meta-llama/Llama-3.1-8B-Instruct`:
```python
from vllm import LLM, SamplingParams

def main():
    model_id = "nvidia/Llama-3.1-8B-Instruct-FP8"
    # Ensure you specify quantization='modelopt' when loading the modelopt checkpoint
    llm = LLM(model=model_id, quantization="modelopt", trust_remote_code=True)

    sampling_params = SamplingParams(temperature=0.8, top_p=0.9)

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    outputs = llm.generate(prompts, sampling_params)

    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

if __name__ == "__main__":
    main()
```
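A checkpoint exported locally with the steps above can be loaded the same way; the sketch below assumes the illustrative `llama-3.1-8b-instruct-fp8` directory from the export section, containing both the quantized weights and tokenizer files:

```python
from vllm import LLM, SamplingParams

# vLLM accepts a local directory in place of a HuggingFace model ID.
llm = LLM(model="llama-3.1-8b-instruct-fp8", quantization="modelopt")

sampling_params = SamplingParams(temperature=0.8, top_p=0.9)
outputs = llm.generate(["The capital of France is"], sampling_params)
print(outputs[0].outputs[0].text)
```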

docs/source/features/quantization/supported_hardware.md

Lines changed: 11 additions & 1 deletion
```diff
@@ -129,7 +129,17 @@ The table below shows the compatibility of various quantization implementations
   *
   *
   *
-
+- * modelopt
+  * ✅︎
+  * ✅︎
+  * ✅︎
+  * ✅︎
+  * ✅︎
+  *
+  *
+  *
+  *
+  *
 :::
 
 - Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0.
```
