Skip to content

Conversation

@pavanimajety
Copy link
Collaborator

@pavanimajety pavanimajety commented Jul 14, 2025

This fix first checks if the current device is SM100 to validate whether kv-cache-dtype=fp8 is supported.

Essential Elements of an Effective PR Description Checklist

  • [N/A] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Removes the necessity for specifying VLLM_ATTENTION_BACKEND=FLASHINFER when kv-cache-dtype=auto is used on SM100 GPUs

Test Plan

The accuracy itself is not affected in this PR, it only fixes the bail out condition for SM100 device capability. Without this PR, when the elif branch determines supported=False, it doesn't evaluate other if branches.

Test Result

Before:

WARNING 07-14 16:30:08 [flashinfer.py:141] Using TRTLLM decode attention (auto-detected).
2025-07-14 16:30:08,665 - INFO - flashinfer.jit: Loading from /root/.cache/flashinfer/cubins/4c623163877c8fef5751c9c7a59940cd2baae02e/fmha/trtllm-gen/fmhaSm100Kernel_QkvBfloat16OBfloat16H128PagedKvDenseP16MultiCtasKvVarSeqQ8Kv128StaticSwapsAbForGen.cubin
2025-07-14 16:30:08,666 - INFO - flashinfer.jit: Loading from /root/.cache/flashinfer/cubins/4c623163877c8fef5751c9c7a59940cd2baae02e/fmha/trtllm-gen/fmhaSm100Kernel_QkvBfloat16OBfloat16H128PagedKvDenseP16MultiCtasKvCgaVarSeqQ8Kv128StaticSwapsAbForGen.cubin
2025-07-14 16:30:09,990 - INFO - flashinfer.jit: Loading from /root/.cache/flashinfer/cubins/4c623163877c8fef5751c9c7a59940cd2baae02e/fmha/trtllm-gen/fmhaSm100Kernel_QkvBfloat16OBfloat16H128PagedKvDenseP16VarSeqQ8Kv128PersistentSwapsAbForGen.cubin
Running generate_until requests: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:53<00:00,  9.35it/s]

[rank0]:[W714 16:31:02.350916466 ProcessGroupNCCL.cpp:1505] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
INFO:lm_eval.loggers.evaluation_tracker:Output path not provided, skipping saving results aggregated
vllm (pretrained=nvidia/Llama-3.3-70b-Instruct-FP8,quantization=modelopt,max_model_len=2048,kv_cache_dtype=auto,trust_remote_code=True), gen_kwargs: (temperature=0.0), limit: 500.0, num_fewshot: 5, batch_size: 200
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.938|±  |0.0108|
|     |       |strict-match    |     5|exact_match|↑  |0.902|±  |0.0133|

and kv-cache-dtype="fp8"

ache_dir":null}, use_cached_outputs=False,                                                                                                  
INFO 07-14 16:34:50 [cuda.py:379] Cannot use FlashAttention backend for FP8 KV cache.
WARNING 07-14 16:34:50 [cuda.py:381] Please use FlashInfer backend with FP8 KV Cache for better performance by setting environment variable VLLM_ATTENTION_BACKEND=FLASHINFER
INFO 07-14 16:34:50 [cuda.py:395] Using XFormers backend.        
Traceback (most recent call last):                                  
  File "/usr/local/bin/lm_eval", line 8, in <module>                                                                                        
    sys.exit(cli_evaluate())                                       
             ^^^^^^^^^^^^^^             

After:

kv_cache_dtype="auto"

VLLM_USE_STANDALONE_COMPILE=0 lm_eval --model vllm --model_args pretrained=nvidia/Llama-3.3-70b-Instruct-FP8,quantization=modelopt,max_model_len=2048,kv_cache_dtype=auto --gen_kwargs temperature=0.0 --limit 500 --trust
_remote_code --tasks gsm8k --num_fewshot 5 --batch_size 200 
vllm (pretrained=nvidia/Llama-3.3-70b-Instruct-FP8,quantization=modelopt,max_model_len=2048,kv_cache_dtype=auto,trust_remote_code=True), gen_kwargs: (temperature=0.0), limit: 500.0, num_fewshot: 5, batch_size: 200
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.938|±  |0.0108|
|     |       |strict-match    |     5|exact_match|↑  |0.898|±  |0.0135|

kv_cache_dtype="fp8"

$vllm# VLLM_USE_STANDALONE_COMPILE=0 lm_eval --model vllm --model_args pretrained=nvidia/Llama-3.3-70b-Instruct-FP8,quantization=modelopt,max_model_len=2048,kv_cache_dtype=fp8 --gen_kwargs temperature=0.0 --limit 500[325/325$
remote_code --tasks gsm8k --num_fewshot 5 --batch_size 200                                                                                                                                                                                                                              
INFO 07-14 16:10:52 [__init__.py:253] Automatically detected platform cuda.                                                                 
WARNING:lm_eval.__main__: --limit SHOULD ONLY BE USED FOR TESTING.REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.                                                                                                                                                                      
INFO:lm_eval.__main__:Passed `--trust_remote_code`, setting environment variable `HF_DATASETS_TRUST_REMOTE_CODE=true`                                                                                                                                                                   
INFO:lm_eval.__main__:Selected Tasks: ['gsm8k']                                                                                                                                                                                                                                         
WARNING:lm_eval.evaluator:Model appears to be an instruct variant but chat template is not applied. Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`).                                                                                                        
INFO:lm_eval.evaluator:Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234 | Setting fewshot manual seed to 1234                                                                                                                                  
WARNING:lm_eval.evaluator:generation_kwargs: {'temperature': 0.0} specified through cli, these settings will update set parameters in yaml tasks. Ensure 'do_sample=True' for non-greedy decoding!                                                                                      
INFO:lm_eval.evaluator:Initializing vllm model, with arguments: {'pretrained': 'nvidia/Llama-3.3-70b-Instruct-FP8', 'quantization': 'modelopt', 'max_model_len': 2048, 'kv_cache_dtype': 'fp8', 'trust_remote_code': True}                                                              
INFO 07-14 16:11:00 [config.py:1561] Using max model len 2048                                                                                                                                                                                                                           
**INFO 07-14 16:11:00 [arg_utils.py:1425] Using TRTLLM decode attention ---> debug print(removed in PR)**                                                                                                                                                                                                                
INFO 07-14 16:11:00 [config.py:1687] Using fp8 data type to store kv cache. It reduces the GPU memory footprint and boosts the performance. Meanwhile, it may cause accuracy drop without a proper scaling factor                                                                       
INFO 07-14 16:11:00 [config.py:2380] Chunked prefill is enabled with max_num_batched_tokens=16384.                                                                                                                                                                                      
WARNING 07-14 16:11:01 [modelopt.py:53] Detected ModelOpt fp8 checkpoint. Please note that the format is experimental and could change.     
INFO 07-14 16:11:00 [config.py:1687] Using fp8 data type to store kv cache. It reduces the GPU memory footprint and boosts the performance. Meanwhile, it may cause accuracy drop without a proper scaling factor                                                                       
INFO 07-14 16:11:00 [config.py:2380] Chunked prefill is enabled with max_num_batched_tokens=16384.                                                                                                                                                                                      
WARNING 07-14 16:11:01 [modelopt.py:53] Detected ModelOpt fp8 checkpoint. Please note that the format is experimental and could change.                                                                                                                                                 
WARNING 07-14 16:11:02 [__init__.py:2870] We must use the `spawn` multiprocessing start method. Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. See https://docs.vllm.ai/en/latest/usage/troubleshooting.html#python-multiprocessing for more information. Reason: CUDA is initializ
ed                                                                                                                                                                                                                                                                                      
INFO 07-14 16:11:06 [__init__.py:253] Automatically detected platform cuda.                                                                                                                                                                                                             
INFO 07-14 16:11:08 [core.py:526] Waiting for init message from front-end.                                                                                                                                                                                                              
INFO 07-14 16:11:08 [core.py:69] Initializing a V1 LLM engine (v0.9.2rc2.dev226+g667624659.d20250714) with config: model='nvidia/Llama-3.3-70b-Instruct-FP8', speculative_config=None, tokenizer='nvidia/Llama-3.3-70b-Instruct-FP8', skip_tokenizer_init=False, tokenizer_mode=auto, re
vision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=modelop
t, enforce_eager=False, kv_cache_dtype=fp8,  device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_m
etrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=1234, served_model_name=nvidia/Llama-3.3-70b-Instruct-FP8, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_pro
c=True, pooler_config=None, compilation_config={"level":3,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":[],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_aut
o_functionalized_v2":false},"inductor_passes":{},"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,20
8,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"max_capture_size":512,"local_cache_dir":null}                                                                                    
[W714 16:11:09.983633192 ProcessGroupNCCL.cpp:959] Warning: TORCH_NCCL_AVOID_RECORD_STREAMS is the default now, this environment variable is thus deprecated. (function operator()) 




vllm (pretrained=nvidia/Llama-3.3-70b-Instruct-FP8,quantization=modelopt,max_model_len=2048,kv_cache_dtype=fp8,trust_remote_code=True), gen_kwargs: (temperature=0.0), limit: 500.0, num_fewshot: 5, batch_size: 200                                                                    
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|                                                                                                                                                                                                                
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|                                                                                                                                                                                                                
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.940|±  |0.0106|                                                                                                                                                                                                                
|     |       |strict-match    |     5|exact_match|↑  |0.904|±  |0.0132|        

(Optional) Documentation Update

…r Backend

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @pavanimajety, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a bug in the kv-cache-dtype auto-detection logic, specifically for SM100 GPUs. It refines how the system determines support for FP8 KV cache, ensuring that users on SM100 devices can leverage kv-cache-dtype=auto without needing to manually specify the FlashInfer backend, thereby improving usability and correctness.

Highlights

  • Bugfix for KV-cache dtype support on SM100: I've reordered the logic for determining kv-cache-dtype=fp8 support to correctly identify SM100 GPUs. This ensures that kv-cache-dtype=auto functions as expected without requiring the VLLM_ATTENTION_BACKEND=FLASHINFER environment variable.
  • Improved kv-cache-dtype auto-detection: The change ensures that the system properly recognizes and enables FP8 KV cache on SM100 devices, streamlining the user experience and preventing unnecessary configuration requirements that previously led to bailouts.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request modifies the logic for detecting FP8 support on SM100 GPUs, removing the requirement for a specific environment variable. The change improves the accuracy of device capability detection.

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
@mgoin
Copy link
Member

mgoin commented Jul 14, 2025

I just realized this isn't right either as we need to know that we will be using the FlashInfer backend.. Maybe we could check if FlashInfer is installed and VLLM_ATTENTION_BACKEND is set to FI or unset?

@pavanimajety
Copy link
Collaborator Author

I just realized this isn't right either as we need to know that we will be using the FlashInfer backend.. Maybe we could check if FlashInfer is installed and VLLM_ATTENTION_BACKEND is set to FI or unset?

This path should work because we no longer bail out and proceed with general check based on platform and set the backend to Flashinfer. Why do we need to check if flashinfer is installed or not at this stage?

@mgoin
Copy link
Member

mgoin commented Jul 14, 2025

My concern is if we are on Blackwell and don't have FlashInfer installed on V1, then we are stuck using FA2 which doesn't support FP8 kv cache. I guess if the user has already set kv_cache_dtype="fp8", then we aren't breaking any flow though

@mgoin mgoin added bug Something isn't working ready ONLY add when PR is ready to merge/full CI is needed labels Jul 14, 2025
@mgoin mgoin enabled auto-merge (squash) July 15, 2025 01:37
@mgoin mgoin merged commit 9ad0a45 into vllm-project:main Jul 15, 2025
77 checks passed
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
vllm-project#20934)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Signed-off-by: x22x22 <wadeking@qq.com>
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
vllm-project#20934)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
vllm-project#20934)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Signed-off-by: Paul Pak <paulpak58@gmail.com>
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
vllm-project#20934)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Signed-off-by: Diego-Castan <diego.castan@ibm.com>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 27, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants