Skip to content

Conversation

@mgoin
Copy link
Member

@mgoin mgoin commented Jun 11, 2025

Purpose

In the previous PR to use FlashInfer by default (#19118), this inadventantly prevents FlashAttention from being used if FlashInfer is installed since we don't have an explicit case to check for the selected_backend to be FA.

Test Plan

Test locally on a B200

Test Result

Before (main):

VLLM_ATTENTION_BACKEND=FLASH_ATTN vllm serve meta-llama/Llama-3.1-8B-Instruct
...
INFO 06-11 10:42:28 [cuda.py:240] Using FlashInfer backend on V1 engine by default for Blackwell (SM 10.0) GPUs.

After (PR):

VLLM_ATTENTION_BACKEND=FLASH_ATTN vllm serve meta-llama/Llama-3.1-8B-Instruct
...
INFO 06-11 10:41:12 [cuda.py:237] Using Flash Attention backend on V1 engine.

mgoin added 2 commits June 11, 2025 10:33
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
@mgoin mgoin added bug Something isn't working v1 labels Jun 11, 2025
@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @mgoin, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a bug in the V1 engine's attention backend selection logic. Previously, on Blackwell GPUs, the presence of FlashInfer would prevent a user from explicitly choosing FlashAttention via the environment variable. This change ensures that explicit backend selection takes precedence over the default architecture-based choice, allowing users to force FlashAttention if desired.

Highlights

  • Bugfix: I've fixed an issue where manually selecting the FlashAttention backend using VLLM_ATTENTION_BACKEND=FLASH_ATTN was being overridden by the default FlashInfer selection on Blackwell (SM 10.0) GPUs if FlashInfer was installed.
  • Backend Selection Logic: The logic in vllm/platforms/cuda.py for the V1 engine now correctly prioritizes the explicitly selected backend (selected_backend) before falling back to the default backend determination based on GPU architecture.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configureGemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

The pull request correctly addresses the bug where FlashAttention could not be manually selected on Blackwell GPUs when FlashInfer was installed. The changes add the necessary explicit check for the FlashAttention backend and improve the conditional logic by using elif for mutually exclusive backend selections.

logger.info_once("Using Triton backend on V1 engine.")
return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend")
elif selected_backend == _Backend.FLASH_ATTN:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
elif selected_backend == _Backend.FLASH_ATTN:
elif selected_backend == _Backend.FLASH_ATTN_VLLM_V1:

Seems the v1 FA enum should be FLASH_ATTN_VLLM_V1, same to FLASHINFER_VLLM_V1:

FLASH_ATTN_VLLM_V1 = enum.auto()
TRITON_ATTN_VLLM_V1 = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
ROCM_AITER_MLA = enum.auto() # Supported by V1
ROCM_AITER_MLA_VLLM_V1 = enum.auto()
TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto()
FLASHINFER_VLLM_V1 = enum.auto()

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We seem inconsistent here between using the V1 vs "V0" attention backend names

        if use_v1:
            if selected_backend == _Backend.FLASHINFER:
                logger.info_once("Using FlashInfer backend on V1 engine.")
                return "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
            if selected_backend == _Backend.FLEX_ATTENTION:
                logger.info("Using FlexAttenion backend on V1 engine.")
                return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
            if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
                logger.info_once("Using Triton backend on V1 engine.")
                return ("vllm.v1.attention.backends."
                        "triton_attn.TritonAttentionBackend")

I'm also not sure that it makes sense as a user to specify FLASH_ATTN and have FlashInfer be used by default on V1 then

Copy link
Member

@Isotr0py Isotr0py Jun 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We seem inconsistent here between using the V1 vs "V0" attention backend names

Yea, and _VLLM_V1 suffix is also sometimes annoying, because it's easy to type _V1_VLLM and I finally found the engine initialized with unexpected backend. 😅

Given we have had use_v1 to control the v1 enablement, I think it should be OK to use enum without _VLLM_V1 suffix.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can do a followup to allow for both variants

@Isotr0py Isotr0py enabled auto-merge (squash) June 12, 2025 03:54
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 12, 2025
@houseroad
Copy link
Collaborator

The test failed on the old flaky tests, so retried it, let's see.

@Isotr0py Isotr0py merged commit af09b3f into vllm-project:main Jun 12, 2025
75 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants