# Quick Start Recipe for Llama 3.3 70B on vLLM - NVIDIA Blackwell Hardware

## Introduction

This quick start recipe provides step-by-step instructions for running the Llama 3.3-70B Instruct model using vLLM with FP8 and NVFP4 quantization, optimized for NVIDIA Blackwell architecture GPUs. It covers the complete setup required, from accessing model weights and preparing the software environment to configuring vLLM parameters, launching the server, and validating inference output.

The recipe is intended for developers and practitioners seeking high-throughput or low-latency inference on NVIDIA’s accelerated stack: a Docker image built with vLLM for model serving, FlashInfer for optimized CUDA kernels, and ModelOpt to enable FP8 and NVFP4 quantized execution.

## Access & Licensing

### License

To use Llama 3.3-70B, you must first agree to Meta’s Llama 3.3 Community License (https://ai.meta.com/resources/models-and-libraries/llama-downloads/). NVIDIA’s quantized versions (FP8 and FP4) are built on top of the base model and are available for research and commercial use under the same license.

### Weights

You only need to download one version of the model weights, depending on the precision in use:

- FP8 model for Blackwell: [nvidia/Llama-3.3-70B-Instruct-FP8](https://huggingface.co/nvidia/Llama-3.3-70B-Instruct-FP8)
- FP4 model for Blackwell: [nvidia/Llama-3.3-70B-Instruct-FP4](https://huggingface.co/nvidia/Llama-3.3-70B-Instruct-FP4)

No Hugging Face authentication token is required to download these weights.
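
vLLM fetches the weights from Hugging Face automatically at launch, but you can also pre-download them into your local cache. A minimal sketch using `huggingface-cli` (assuming the tool is installed and `HF_HOME` points at your desired cache directory):

```
# Optional: pre-fetch the FP8 checkpoint into the local Hugging Face cache
huggingface-cli download nvidia/Llama-3.3-70B-Instruct-FP8

# Or the FP4 checkpoint, if you plan to serve NVFP4
huggingface-cli download nvidia/Llama-3.3-70B-Instruct-FP4
```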

Note on Quantization Choice:
For Blackwell, NVFP4 provides additional memory savings and throughput gains, but may require tuning to maintain accuracy on certain tasks.

## Prerequisites

- OS: Linux
- Drivers: CUDA Driver 575 or above
- GPU: Blackwell architecture
- [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/index.html)
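
A quick way to sanity-check the driver and container runtime before building anything; the CUDA base image tag below is only an example, so substitute any CUDA image available to you:

```
# Check driver version and GPU visibility on the host
nvidia-smi --query-gpu=name,driver_version --format=csv

# Confirm the NVIDIA Container Toolkit exposes GPUs inside containers
docker run --rm --gpus all nvidia/cuda:12.8.1-base-ubuntu22.04 nvidia-smi
```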

## Deployment Steps

### Build Docker Image

Build a Docker image with vLLM and other dependencies installed. We will use both the [official vLLM Dockerfile](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile) and an additional `Dockerfile.nvidia` containing the necessary packages and environment settings for NVIDIA GPUs.

First, create the `Dockerfile.nvidia` file as follows:

`Dockerfile.nvidia`
```
ARG base_image
FROM ${base_image}

WORKDIR /workspace

# Required environment variables for optimal performance on NVIDIA GPUs.
# These will be removed once the optimizations are enabled by default in vLLM.
# Use the V1 engine
ENV VLLM_USE_V1=1
# Use the FlashInfer backend for attention
ENV VLLM_ATTENTION_BACKEND=FLASHINFER
# Use FlashInfer trtllm-gen attention kernels
ENV VLLM_USE_TRTLLM_DECODE_ATTENTION=1

# Install lm_eval that is compatible with the latest vLLM
RUN pip3 install --no-build-isolation "lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness@4f8195f"

ENTRYPOINT ["/bin/bash"]
```

Build the Docker image named `vllm-llama-deploy` using the official vLLM Dockerfile and the `Dockerfile.nvidia` we just created.

`build_image.sh`
```
# Clone the vLLM GitHub repo and check out the specific commit.
git clone -b main --single-branch https://github.com/vllm-project/vllm.git
cd vllm
git checkout 055bd3978ededea015fb8f0cb6aa3cc48d84cde8

# Copy your Dockerfile.nvidia to the docker/ directory
cp ../Dockerfile.nvidia docker/Dockerfile.nvidia

# Build the docker image using the official vLLM Dockerfile and Dockerfile.nvidia.
DOCKER_BUILDKIT=1 docker build . \
        --file docker/Dockerfile \
        --target vllm-openai \
        --build-arg CUDA_VERSION=12.8.1 \
        --build-arg max_jobs=32 \
        --build-arg nvcc_threads=2 \
        --build-arg USE_SCCACHE=1 \
        --build-arg SCCACHE_S3_NO_CREDENTIALS=1 \
        --build-arg RUN_WHEEL_CHECK=false \
        --build-arg torch_cuda_arch_list="9.0+PTX 10.0+PTX" \
        --build-arg vllm_fa_cmake_gpu_arches="90-real;100-real" \
        -t vllm-official

DOCKER_BUILDKIT=1 docker build . \
        --file docker/Dockerfile.nvidia \
        --build-arg base_image=vllm-official \
        -t vllm-llama-deploy
```
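
A quick check that both images were built before moving on:

```
# Confirm the base and deployment images exist locally
docker images | grep -E 'vllm-official|vllm-llama-deploy'
```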

### Run Docker Container

Run a Docker container from the `vllm-llama-deploy` image.

`run_container.sh`
```
docker run -e HF_TOKEN="$HF_TOKEN" -e HF_HOME="$HF_HOME" --ipc=host --gpus all --rm -it vllm-llama-deploy
```

Note: You can mount additional directories and paths using the `-v <local_path>:<container_path>` flag if needed, such as mounting the downloaded weight paths.

The `-e HF_TOKEN="$HF_TOKEN" -e HF_HOME="$HF_HOME"` flags are added so that models are downloaded using your Hugging Face token and cached in `$HF_HOME`. Refer to the Hugging Face documentation for more information.
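
For example, to persist the Hugging Face cache across container runs, you can mount a host directory and point `HF_HOME` at it; the paths below are placeholders:

```
# Mount a host directory as the Hugging Face cache inside the container (paths are placeholders)
docker run -e HF_TOKEN="$HF_TOKEN" -e HF_HOME=/workspace/hf_cache \
  -v /path/on/host/hf_cache:/workspace/hf_cache \
  --ipc=host --gpus all --rm -it vllm-llama-deploy
```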

### Launch the vLLM Server

Below is an example command to launch the vLLM server with the Llama-3.3-70B-Instruct-FP8 model. Each flag is explained in the `Configs and Parameters` section.

`launch_server.sh`
```
vllm serve nvidia/Llama-3.3-70B-Instruct-FP8 \
  --host 0.0.0.0 \
  --port 8080 \
  --tokenizer nvidia/Llama-3.3-70B-Instruct-FP8 \
  --quantization modelopt \
  --kv-cache-dtype fp8 \
  --trust-remote-code \
  --gpu-memory-utilization 0.9 \
  --compilation-config '{"pass_config": {"enable_fi_allreduce_fusion": true}, "custom_ops": ["+rms_norm"], "level": 3}' \
  --enable-chunked-prefill \
  --async-scheduling \
  --no-enable-prefix-caching \
  --disable-log-requests \
  --pipeline-parallel-size 1 \
  --tensor-parallel-size 1 \
  --max-num-seqs 512 \
  --max-num-batched-tokens 8192 \
  --max-model-len 9216 &
```
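
To serve the NVFP4 checkpoint instead, swap in the FP4 model name and set `--quantization modelopt_fp4` as described in `Configs and Parameters`; the remaining flags are assumed to carry over unchanged:

```
vllm serve nvidia/Llama-3.3-70B-Instruct-FP4 \
  --host 0.0.0.0 \
  --port 8080 \
  --tokenizer nvidia/Llama-3.3-70B-Instruct-FP4 \
  --quantization modelopt_fp4 \
  --kv-cache-dtype fp8 \
  --trust-remote-code \
  --gpu-memory-utilization 0.9 \
  --compilation-config '{"pass_config": {"enable_fi_allreduce_fusion": true}, "custom_ops": ["+rms_norm"], "level": 3}' \
  --enable-chunked-prefill \
  --async-scheduling \
  --no-enable-prefix-caching \
  --disable-log-requests \
  --pipeline-parallel-size 1 \
  --tensor-parallel-size 1 \
  --max-num-seqs 512 \
  --max-num-batched-tokens 8192 \
  --max-model-len 9216 &
```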

After the server is set up, the client can send prompt requests to the server and receive results.

### Configs and Parameters

You can specify the IP address and the port that the server listens on using these flags:

- `--host`: IP address of the server.
- `--port`: The port the server listens on.

Below are the config flags that we do not recommend changing or tuning:

- `--tokenizer`: Name or path of the tokenizer to use; here it matches the model repository.
- `--quantization`: Must be `modelopt` for the FP8 model and `modelopt_fp4` for the FP4 model.
- `--kv-cache-dtype`: KV-cache data type. We recommend setting it to `fp8` for best performance.
- `--trust-remote-code`: Trust remote code from the model repository.
- `--gpu-memory-utilization`: The fraction of GPU memory to be used for the model executor. We recommend setting it to `0.9` to use up to 90% of the GPU memory.
- `--compilation-config`: Configuration for the vLLM compilation stage. We recommend setting it to `'{"pass_config": {"enable_fi_allreduce_fusion": true}, "custom_ops": ["+rms_norm"], "level": 3}'` to enable all the necessary fusions for the best performance.
  - We are working to enable these fusions by default so that this flag is no longer needed in the future.
- `--enable-chunked-prefill`: Enable the chunked prefill stage. We recommend always adding this flag for best performance.
- `--async-scheduling`: Enable asynchronous scheduling to reduce the host overheads between decoding steps. We recommend always adding this flag for best performance on the Blackwell architecture.
- `--no-enable-prefix-caching`: Disable prefix caching. We recommend always adding this flag when running with a synthetic dataset for consistent performance measurement.
- `--disable-log-requests`: Disable verbose logging from the server.
- `--pipeline-parallel-size`: Pipeline parallelism size. We recommend setting it to `1` for best performance.

Below are a few tunable parameters you can modify based on your serving requirements (a latency-oriented example follows this list):

- `--tensor-parallel-size`: Tensor parallelism size. Increasing this increases the number of GPUs used for inference.
  - Set this to `1` to achieve the best throughput, and set this to `2`, `4`, or `8` to achieve better per-user latencies.
- `--max-num-seqs`: Maximum number of sequences per batch.
  - Set this to a large number like `512` to achieve the best throughput, and set this to a small number like `16` to achieve better per-user latencies.
- `--max-num-batched-tokens`: Maximum number of tokens per batch.
  - We recommend setting this to `8192`. Increasing this value may yield slight performance improvements when input sequence lengths are long.
- `--max-model-len`: Maximum number of total tokens, including the input tokens and output tokens, for each request.
  - This must be set to a larger number if the expected input/output sequence lengths are large.
  - For example, if the maximum input sequence length is 1024 tokens and the maximum output sequence length is 1024 tokens, then this must be set to at least 2048.
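
For instance, a latency-oriented launch might use more GPUs per replica and a smaller batch; the values below are illustrative, not measured recommendations, and the remaining flags follow `launch_server.sh` above:

```
# Illustrative latency-oriented variant (tune for your own workload)
vllm serve nvidia/Llama-3.3-70B-Instruct-FP8 \
  --host 0.0.0.0 \
  --port 8080 \
  --quantization modelopt \
  --kv-cache-dtype fp8 \
  --tensor-parallel-size 4 \
  --max-num-seqs 16 \
  --max-num-batched-tokens 8192 \
  --max-model-len 9216 &
```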

## Validation & Expected Behavior

### Basic Test

After the vLLM server is set up and shows `Application startup complete`, you can send requests to the server.

`run_basic_test.sh`
```
curl http://0.0.0.0:8080/v1/completions -H "Content-Type: application/json" -d '{ "model": "nvidia/Llama-3.3-70B-Instruct-FP8", "prompt": "San Francisco is a", "max_tokens": 20, "temperature": 0 }'
```

Here is an example response, showing that the vLLM server returns "*city that is known for its vibrant culture, stunning architecture, and breathtaking natural beauty. From the iconic...*", completing the input sequence with up to 20 tokens.

```
{"id":"cmpl-36133d6f81384b308dd7c0d0e4327dd6","object":"text_completion","created":1753940624,"model":"nvidia/Llama-3.3-70B-Instruct-FP8","choices":[{"index":0,"text":" city that is known for its vibrant culture, stunning architecture, and breathtaking natural beauty. From the iconic","logprobs":null,"finish_reason":"length","stop_reason":null,"prompt_logprobs":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":5,"total_tokens":25,"completion_tokens":20,"prompt_tokens_details":null},"kv_transfer_params":null}
```
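
The server also exposes the OpenAI-compatible chat endpoint, which applies the model's chat template for you; a minimal sketch:

```
curl http://0.0.0.0:8080/v1/chat/completions -H "Content-Type: application/json" -d '{
  "model": "nvidia/Llama-3.3-70B-Instruct-FP8",
  "messages": [{"role": "user", "content": "Name one landmark in San Francisco."}],
  "max_tokens": 50,
  "temperature": 0
}'
```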

### Verify Accuracy

While the server is still running, we can run accuracy tests using the lm_eval tool.

`run_accuracy.sh`
```
lm_eval \
  --model local-completions \
  --tasks gsm8k \
  --model_args \
base_url=http://0.0.0.0:8080/v1/completions,\
model=nvidia/Llama-3.3-70B-Instruct-FP8,\
tokenized_requests=False,tokenizer_backend=None,\
num_concurrent=128,timeout=120,max_retries=5
```

Here is an example accuracy result with the nvidia/Llama-3.3-70B-Instruct-FP8 model on one B200 GPU:

```
local-completions (base_url=http://0.0.0.0:8080/v1/completions,model=nvidia/Llama-3.3-70B-Instruct-FP8,tokenized_requests=False,tokenizer_backend=None,num_concurrent=128,timeout=120,max_retries=5), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9386|±  |0.0066|
|     |       |strict-match    |     5|exact_match|↑  |0.7695|±  |0.0116|
```

### Benchmarking Performance

To benchmark the performance, you can use the `benchmark_serving.py` script in the vLLM repository.

`run_performance.sh`
```
python3 /vllm-workspace/benchmarks/benchmark_serving.py \
  --host 0.0.0.0 \
  --port 8080 \
  --model nvidia/Llama-3.3-70B-Instruct-FP8 \
  --trust-remote-code \
  --dataset-name random \
  --random-input-len 1024 \
  --random-output-len 1024 \
  --ignore-eos \
  --max-concurrency 512 \
  --num-prompts 2560 \
  --save-result --result-filename vllm_benchmark_serving_results.json
```

Explanations for the flags:

- `--dataset-name`: Which dataset to use for benchmarking. We use a `random` dataset here.
- `--random-input-len`: Specifies the average input sequence length.
- `--random-output-len`: Specifies the average output sequence length.
- `--ignore-eos`: Disables early stopping when an EOS (end-of-sequence) token is generated. This ensures that the output sequence lengths match our expected range.
- `--max-concurrency`: Maximum number of in-flight requests. We recommend matching this with the `--max-num-seqs` flag used to launch the server.
- `--num-prompts`: Total number of prompts used for performance benchmarking. We recommend setting it to at least five times the `--max-concurrency` value to measure steady-state performance.
- `--save-result --result-filename`: Output location for the performance benchmarking results.
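
The same script can also be pointed at a latency-oriented server configuration by lowering the concurrency; the numbers below are illustrative and assume the server was launched with a correspondingly small `--max-num-seqs`:

```
# Illustrative low-concurrency run for per-user latency measurements
python3 /vllm-workspace/benchmarks/benchmark_serving.py \
  --host 0.0.0.0 \
  --port 8080 \
  --model nvidia/Llama-3.3-70B-Instruct-FP8 \
  --trust-remote-code \
  --dataset-name random \
  --random-input-len 1024 \
  --random-output-len 1024 \
  --ignore-eos \
  --max-concurrency 16 \
  --num-prompts 80 \
  --save-result --result-filename vllm_latency_benchmark_results.json
```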

### Interpreting `benchmark_serving.py` Output

Sample output by the `benchmark_serving.py` script:

```
============ Serving Benchmark Result ============
Successful requests:                     xxxxxx
Benchmark duration (s):                  xxx.xx
Total input tokens:                      xxxxxx
Total generated tokens:                  xxxxxx
Request throughput (req/s):              xxx.xx
Output token throughput (tok/s):         xxx.xx
Total Token throughput (tok/s):          xxx.xx
---------------Time to First Token----------------
Mean TTFT (ms):                          xxx.xx
Median TTFT (ms):                        xxx.xx
P99 TTFT (ms):                           xxx.xx
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          xxx.xx
Median TPOT (ms):                        xxx.xx
P99 TPOT (ms):                           xxx.xx
---------------Inter-token Latency----------------
Mean ITL (ms):                           xxx.xx
Median ITL (ms):                         xxx.xx
P99 ITL (ms):                            xxx.xx
----------------End-to-end Latency----------------
Mean E2EL (ms):                          xxx.xx
Median E2EL (ms):                        xxx.xx
P99 E2EL (ms):                           xxx.xx
==================================================
```

Explanations for key metrics:

- `Median Time to First Token (TTFT)`: The typical time elapsed from when a request is sent until the first output token is generated.
- `Median Time Per Output Token (TPOT)`: The typical time required to generate each token after the first one.
- `Median Inter-Token Latency (ITL)`: The typical time delay between the completion of one token and the completion of the next.
- `Median End-to-End Latency (E2EL)`: The typical total time from when a request is submitted until the final token of the response is received.
- `Total Token Throughput`: The combined rate at which the system processes both input (prompt) tokens and output (generated) tokens.