
Commit 19942a9

Add recipes for Llama3.3 70B and Llama4 Scout (#13)
Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
1 parent 985b860 commit 19942a9

File tree

3 files changed, +570 -0 lines changed


Llama/Llama3.3-70B.md

Lines changed: 283 additions & 0 deletions

# Quick Start Recipe for Llama 3.3 70B on vLLM - NVIDIA Blackwell Hardware

## Introduction

This quick start recipe provides step-by-step instructions for running the Llama 3.3-70B Instruct model using vLLM with FP8 and NVFP4 quantization, optimized for NVIDIA Blackwell architecture GPUs. It covers the complete setup required: from accessing model weights and preparing the software environment to configuring vLLM parameters, launching the server, and validating inference output.

The recipe is intended for developers and practitioners seeking high-throughput or low-latency inference on NVIDIA's accelerated stack: a Docker image is built with vLLM for model serving, FlashInfer for optimized CUDA kernels, and ModelOpt to enable FP8 and NVFP4 quantized execution.

## Access & Licensing

### License

To use Llama 3.3-70B, you must first agree to Meta's Llama 3 Community License (https://ai.meta.com/resources/models-and-libraries/llama-downloads/). NVIDIA's quantized versions (FP8 and FP4) are built on top of the base model and are available for research and commercial use under the same license.

### Weights

You only need to download one version of the model weights, depending on the precision you intend to use:

- FP8 model for Blackwell: [nvidia/Llama-3.3-70B-Instruct-FP8](https://huggingface.co/nvidia/Llama-3.3-70B-Instruct-FP8)
- FP4 model for Blackwell: [nvidia/Llama-3.3-70B-Instruct-FP4](https://huggingface.co/nvidia/Llama-3.3-70B-Instruct-FP4)

No Hugging Face authentication token is required to download these weights.

Note on Quantization Choice:
For Blackwell, NVFP4 provides additional memory savings and throughput gains, but may require tuning to maintain accuracy on certain tasks.

## Prerequisites

- OS: Linux
- Drivers: CUDA Driver 575 or above
- GPU: Blackwell architecture
- [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/index.html)
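
Before building anything, it can help to confirm that the driver and container runtime meet these prerequisites. The snippet below is a minimal check; the CUDA image tag is only an illustrative example, not part of the original recipe.

```
# Check the GPU model and driver version (expect a Blackwell GPU and driver 575+).
nvidia-smi --query-gpu=name,driver_version --format=csv,noheader

# Check that Docker can reach the GPUs through the NVIDIA Container Toolkit.
# The image tag below is just an example; any CUDA-enabled image works.
docker run --rm --gpus all nvidia/cuda:12.8.1-base-ubuntu24.04 nvidia-smi
```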

## Deployment Steps

### Build Docker Image

Build a Docker image with vLLM and other dependencies installed. We will use both the [official vLLM Dockerfile](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile) and an additional `Dockerfile.nvidia` containing the necessary packages and environment settings for NVIDIA GPUs.

First, create the `Dockerfile.nvidia` file as follows:

`Dockerfile.nvidia`
```
ARG base_image
FROM ${base_image}

WORKDIR /workspace

# Required environment variables for optimal performance on NVIDIA GPUs.
# These will be removed once the optimizations are enabled by default in vLLM.
# Use the V1 engine
ENV VLLM_USE_V1=1
# Use the FlashInfer backend for attention
ENV VLLM_ATTENTION_BACKEND=FLASHINFER
# Use FlashInfer trtllm-gen attention kernels
ENV VLLM_USE_TRTLLM_DECODE_ATTENTION=1

# Install an lm_eval version that is compatible with the latest vLLM
RUN pip3 install --no-build-isolation "lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness@4f8195f"

ENTRYPOINT ["/bin/bash"]
```

Build the Docker image named `vllm-llama-deploy` using the official vLLM Dockerfile and the `Dockerfile.nvidia` we just created.

`build_image.sh`
```
# Clone the vLLM GitHub repo and check out the specific commit.
git clone -b main --single-branch https://github.com/vllm-project/vllm.git
cd vllm
git checkout 055bd3978ededea015fb8f0cb6aa3cc48d84cde8

# Copy your Dockerfile.nvidia to the docker/ directory.
cp ../Dockerfile.nvidia docker/Dockerfile.nvidia

# Build the Docker image using the official vLLM Dockerfile and Dockerfile.nvidia.
DOCKER_BUILDKIT=1 docker build . \
    --file docker/Dockerfile \
    --target vllm-openai \
    --build-arg CUDA_VERSION=12.8.1 \
    --build-arg max_jobs=32 \
    --build-arg nvcc_threads=2 \
    --build-arg USE_SCCACHE=1 \
    --build-arg SCCACHE_S3_NO_CREDENTIALS=1 \
    --build-arg RUN_WHEEL_CHECK=false \
    --build-arg torch_cuda_arch_list="9.0+PTX 10.0+PTX" \
    --build-arg vllm_fa_cmake_gpu_arches="90-real;100-real" \
    -t vllm-official

DOCKER_BUILDKIT=1 docker build . \
    --file docker/Dockerfile.nvidia \
    --build-arg base_image=vllm-official \
    -t vllm-llama-deploy
```

### Run Docker Container

Run the Docker container using the Docker image `vllm-llama-deploy`.

`run_container.sh`
```
docker run -e HF_TOKEN="$HF_TOKEN" -e HF_HOME="$HF_HOME" --ipc=host --gpus all --rm -it vllm-llama-deploy
```

Note: You can mount additional directories and paths using the `-v <local_path>:<container_path>` flag if needed, such as mounting the downloaded weight paths.

The `-e HF_TOKEN="$HF_TOKEN" -e HF_HOME="$HF_HOME"` flags are added so that the models are downloaded using your Hugging Face token and cached in `$HF_HOME`. Refer to the Hugging Face documentation for more information.
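
For example, if you keep the Hugging Face cache and pre-downloaded weights on the host, a run command along these lines mounts them into the container. The host paths here are placeholders for illustration, not part of the original recipe.

```
# /data/hf_cache and /data/models are example host paths; adjust them to your setup.
docker run -e HF_TOKEN="$HF_TOKEN" -e HF_HOME=/data/hf_cache \
    -v /data/hf_cache:/data/hf_cache \
    -v /data/models:/data/models \
    --ipc=host --gpus all --rm -it vllm-llama-deploy
```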

### Launch the vLLM Server

Below is an example command to launch the vLLM server with the Llama-3.3-70B-Instruct-FP8 model. Each flag is explained in the `Configs and Parameters` section.

`launch_server.sh`
```
vllm serve nvidia/Llama-3.3-70B-Instruct-FP8 \
    --host 0.0.0.0 \
    --port 8080 \
    --tokenizer nvidia/Llama-3.3-70B-Instruct-FP8 \
    --quantization modelopt \
    --kv-cache-dtype fp8 \
    --trust-remote-code \
    --gpu-memory-utilization 0.9 \
    --compilation-config '{"pass_config": {"enable_fi_allreduce_fusion": true}, "custom_ops": ["+rms_norm"], "level": 3}' \
    --enable-chunked-prefill \
    --async-scheduling \
    --no-enable-prefix-caching \
    --disable-log-requests \
    --pipeline-parallel-size 1 \
    --tensor-parallel-size 1 \
    --max-num-seqs 512 \
    --max-num-batched-tokens 8192 \
    --max-model-len 9216 &
```
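
To serve the NVFP4 checkpoint instead, the same command can be adapted by swapping the model, tokenizer, and `--quantization` values. This is a sketch that keeps every other flag from `launch_server.sh` unchanged; tune them as described in `Configs and Parameters`.

```
# Sketch: serve the FP4 model; only the model, tokenizer, and quantization values
# differ from the FP8 example above.
vllm serve nvidia/Llama-3.3-70B-Instruct-FP4 \
    --host 0.0.0.0 \
    --port 8080 \
    --tokenizer nvidia/Llama-3.3-70B-Instruct-FP4 \
    --quantization modelopt_fp4 \
    --kv-cache-dtype fp8 \
    --trust-remote-code \
    --gpu-memory-utilization 0.9 \
    --compilation-config '{"pass_config": {"enable_fi_allreduce_fusion": true}, "custom_ops": ["+rms_norm"], "level": 3}' \
    --enable-chunked-prefill \
    --async-scheduling \
    --no-enable-prefix-caching \
    --disable-log-requests \
    --pipeline-parallel-size 1 \
    --tensor-parallel-size 1 \
    --max-num-seqs 512 \
    --max-num-batched-tokens 8192 \
    --max-model-len 9216 &
```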

After the server is set up, clients can send prompt requests to it and receive results.

### Configs and Parameters

You can specify the IP address and the port that the server listens on using these flags:

- `--host`: IP address of the server.
- `--port`: The port the server listens on.

Below are the config flags that we do not recommend changing or tuning:

- `--tokenizer`: Name or path of the tokenizer to use; here it is set to the model name.
- `--quantization`: Must be `modelopt` for the FP8 model and `modelopt_fp4` for the FP4 model.
- `--kv-cache-dtype`: KV-cache data type. We recommend setting it to `fp8` for best performance.
- `--trust-remote-code`: Trust remote code shipped with the model.
- `--gpu-memory-utilization`: The fraction of GPU memory to be used for the model executor. We recommend setting it to `0.9` to use up to 90% of the GPU memory.
- `--compilation-config`: Configuration for the vLLM compilation stage. We recommend setting it to `'{"pass_config": {"enable_fi_allreduce_fusion": true}, "custom_ops": ["+rms_norm"], "level": 3}'` to enable all the necessary fusions for the best performance.
  - We are working on enabling these fusions by default so that this flag will no longer be needed in the future.
- `--enable-chunked-prefill`: Enable the chunked prefill stage. We recommend always adding this flag for best performance.
- `--async-scheduling`: Enable asynchronous scheduling to reduce the host overheads between decoding steps. We recommend always adding this flag for best performance on the Blackwell architecture.
- `--no-enable-prefix-caching`: Disable prefix caching. We recommend always adding this flag when running with a synthetic dataset for consistent performance measurement.
- `--disable-log-requests`: Disable verbose request logging from the server.
- `--pipeline-parallel-size`: Pipeline parallelism size. We recommend setting it to `1` for best performance.

Below are a few tunable parameters you can modify based on your serving requirements (see the example after this list):

- `--tensor-parallel-size`: Tensor parallelism size. Increasing this increases the number of GPUs used for inference.
  - Set this to `1` to achieve the best throughput, and set this to `2`, `4`, or `8` to achieve better per-user latencies.
- `--max-num-seqs`: Maximum number of sequences per batch.
  - Set this to a large number like `512` to achieve the best throughput, and set this to a small number like `16` to achieve better per-user latencies.
- `--max-num-batched-tokens`: Maximum number of tokens per batch.
  - We recommend setting this to `8192`. Increasing this value may yield slight performance improvements if the sequences have long input sequence lengths.
- `--max-model-len`: Maximum number of total tokens, including the input tokens and output tokens, for each request.
  - This must be set to a larger number if the expected input/output sequence lengths are large.
  - For example, if the maximum input sequence length is 1024 tokens and the maximum output sequence length is 1024 tokens, then this must be set to at least 2048.
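
As an illustration of these trade-offs, the sketch below reuses the `launch_server.sh` flags but overrides the two tunables toward lower per-user latency. The override values are examples, not a tuned configuration.

```
# Sketch: same launch as launch_server.sh, overriding only the two tunables that
# trade throughput for per-user latency (example values; tune for your workload).
LATENCY_OVERRIDES=(--tensor-parallel-size 4 --max-num-seqs 16)

vllm serve nvidia/Llama-3.3-70B-Instruct-FP8 \
    --host 0.0.0.0 --port 8080 \
    --tokenizer nvidia/Llama-3.3-70B-Instruct-FP8 \
    --quantization modelopt --kv-cache-dtype fp8 \
    --trust-remote-code --gpu-memory-utilization 0.9 \
    --compilation-config '{"pass_config": {"enable_fi_allreduce_fusion": true}, "custom_ops": ["+rms_norm"], "level": 3}' \
    --enable-chunked-prefill --async-scheduling --no-enable-prefix-caching \
    --disable-log-requests --pipeline-parallel-size 1 \
    --max-num-batched-tokens 8192 --max-model-len 9216 \
    "${LATENCY_OVERRIDES[@]}" &
```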

## Validation & Expected Behavior

### Basic Test

After the vLLM server is set up and shows `Application startup complete`, you can send requests to the server.

`run_basic_test.sh`
```
curl http://0.0.0.0:8080/v1/completions -H "Content-Type: application/json" -d '{ "model": "nvidia/Llama-3.3-70B-Instruct-FP8", "prompt": "San Francisco is a", "max_tokens": 20, "temperature": 0 }'
```

Here is an example response, showing that the vLLM server returns "*city that is known for its vibrant culture, stunning architecture, and breathtaking natural beauty. From the iconic...*", completing the input sequence with up to 20 tokens.

```
{"id":"cmpl-36133d6f81384b308dd7c0d0e4327dd6","object":"text_completion","created":1753940624,"model":"nvidia/Llama-3.3-70B-Instruct-FP8","choices":[{"index":0,"text":" city that is known for its vibrant culture, stunning architecture, and breathtaking natural beauty. From the iconic","logprobs":null,"finish_reason":"length","stop_reason":null,"prompt_logprobs":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":5,"total_tokens":25,"completion_tokens":20,"prompt_tokens_details":null},"kv_transfer_params":null}
```
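
The deployment also exposes the OpenAI-compatible chat endpoint, so you can exercise it with a chat-style request as well. This is a minimal sketch; the payload simply follows the standard chat completions format.

```
# Chat-style request against the same server (sketch; adjust the prompt as needed).
curl http://0.0.0.0:8080/v1/chat/completions -H "Content-Type: application/json" -d '{ "model": "nvidia/Llama-3.3-70B-Instruct-FP8", "messages": [{"role": "user", "content": "Give me three facts about San Francisco."}], "max_tokens": 64, "temperature": 0 }'
```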

### Verify Accuracy

While the server is still running, we can run accuracy tests using the lm_eval tool.

`run_accuracy.sh`
```
lm_eval \
--model local-completions \
--tasks gsm8k \
--model_args \
base_url=http://0.0.0.0:8080/v1/completions,\
model=nvidia/Llama-3.3-70B-Instruct-FP8,\
tokenized_requests=False,tokenizer_backend=None,\
num_concurrent=128,timeout=120,max_retries=5
```

Here is an example accuracy result with the nvidia/Llama-3.3-70B-Instruct-FP8 model on one B200 GPU:

```
local-completions (base_url=http://0.0.0.0:8080/v1/completions,model=nvidia/Llama-3.3-70B-Instruct-FP8,tokenized_requests=False,tokenizer_backend=None,num_concurrent=128,timeout=120,max_retries=5), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9386|±  |0.0066|
|     |       |strict-match    |     5|exact_match|↑  |0.7695|±  |0.0116|
```

### Benchmarking Performance

To benchmark the performance, you can use the `benchmark_serving.py` script in the vLLM repository.

`run_performance.sh`
```
python3 /vllm-workspace/benchmarks/benchmark_serving.py \
    --host 0.0.0.0 \
    --port 8080 \
    --model nvidia/Llama-3.3-70B-Instruct-FP8 \
    --trust-remote-code \
    --dataset-name random \
    --random-input-len 1024 \
    --random-output-len 1024 \
    --ignore-eos \
    --max-concurrency 512 \
    --num-prompts 2560 \
    --save-result --result-filename vllm_benchmark_serving_results.json
```

Explanations for the flags:

- `--dataset-name`: Which dataset to use for benchmarking. We use a `random` dataset here.
- `--random-input-len`: Specifies the average input sequence length.
- `--random-output-len`: Specifies the average output sequence length.
- `--ignore-eos`: Disables early stopping when the EOS (end-of-sequence) token is generated. This ensures that the output sequence lengths match our expected range.
- `--max-concurrency`: Maximum number of in-flight requests. We recommend matching this with the `--max-num-seqs` flag used to launch the server (see the sweep sketch after this list).
- `--num-prompts`: Total number of prompts used for performance benchmarking. We recommend setting it to at least five times the `--max-concurrency` value to measure steady-state performance.
- `--save-result --result-filename`: Output location for the performance benchmarking result.
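
To see how throughput and latency move as load changes, the benchmark can be repeated at several concurrency levels. This loop is a sketch that mirrors the flags above; the chosen concurrency values are arbitrary examples.

```
# Sweep a few concurrency levels, keeping num-prompts at 5x the concurrency.
for concurrency in 16 64 256 512; do
  python3 /vllm-workspace/benchmarks/benchmark_serving.py \
    --host 0.0.0.0 --port 8080 \
    --model nvidia/Llama-3.3-70B-Instruct-FP8 \
    --trust-remote-code \
    --dataset-name random \
    --random-input-len 1024 --random-output-len 1024 \
    --ignore-eos \
    --max-concurrency "$concurrency" \
    --num-prompts $((concurrency * 5)) \
    --save-result --result-filename "vllm_benchmark_serving_${concurrency}.json"
done
```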

### Interpreting `benchmark_serving.py` Output

Sample output from the `benchmark_serving.py` script:
```
============ Serving Benchmark Result ============
Successful requests:              xxxxxx
Benchmark duration (s):           xxx.xx
Total input tokens:               xxxxxx
Total generated tokens:           xxxxxx
Request throughput (req/s):       xxx.xx
Output token throughput (tok/s):  xxx.xx
Total Token throughput (tok/s):   xxx.xx
---------------Time to First Token----------------
Mean TTFT (ms):                   xxx.xx
Median TTFT (ms):                 xxx.xx
P99 TTFT (ms):                    xxx.xx
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                   xxx.xx
Median TPOT (ms):                 xxx.xx
P99 TPOT (ms):                    xxx.xx
---------------Inter-token Latency----------------
Mean ITL (ms):                    xxx.xx
Median ITL (ms):                  xxx.xx
P99 ITL (ms):                     xxx.xx
----------------End-to-end Latency----------------
Mean E2EL (ms):                   xxx.xx
Median E2EL (ms):                 xxx.xx
P99 E2EL (ms):                    xxx.xx
==================================================
```

Explanations for key metrics (a rough sanity check relating them follows the list):

- `Median Time to First Token (TTFT)`: The typical time elapsed from when a request is sent until the first output token is generated.
- `Median Time Per Output Token (TPOT)`: The typical time required to generate each token after the first one.
- `Median Inter-Token Latency (ITL)`: The typical time delay between the completion of one token and the completion of the next.
- `Median End-to-End Latency (E2EL)`: The typical total time from when a request is submitted until the final token of the response is received.
- `Total Token Throughput`: The combined rate at which the system processes both input (prompt) tokens and output (generated) tokens.
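
As a rough sanity check (an approximation, not part of the script's output), for a single request the end-to-end latency is roughly TTFT plus TPOT times the number of output tokens after the first:

```
# Illustrative numbers only: E2EL (ms) is approximately TTFT + TPOT * (output_tokens - 1).
ttft_ms=100
tpot_ms=20
output_tokens=1024
echo "approx E2EL (ms): $(( ttft_ms + tpot_ms * (output_tokens - 1) ))"
```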
