-
-
Notifications
You must be signed in to change notification settings - Fork 11.3k
[Attention][V1] Toggle for v1 attention backend #18275
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
ac2b381
Toggle for v1 attention
gshtras 9540f72
Caching the env variable in the __init__
gshtras dc0f0d1
Merge remote-tracking branch 'origin/main' into attention_toggle_upst…
gshtras f82da97
Better naming and logic extracted to a variable
gshtras 3b264d5
Merge remote-tracking branch 'origin/main' into attention_toggle_upst…
gshtras File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,7 @@ | |
| VLLM_NCCL_SO_PATH: Optional[str] = None | ||
| LD_LIBRARY_PATH: Optional[str] = None | ||
| VLLM_USE_TRITON_FLASH_ATTN: bool = False | ||
| VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False | ||
| VLLM_FLASH_ATTN_VERSION: Optional[int] = None | ||
| LOCAL_RANK: int = 0 | ||
| CUDA_VISIBLE_DEVICES: Optional[str] = None | ||
|
|
@@ -290,6 +291,13 @@ def get_vllm_port() -> Optional[int]: | |
| lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in | ||
| ("true", "1")), | ||
|
|
||
| # Use separate prefill and decode kernels for V1 attention instead of | ||
| # the unified triton kernel. | ||
| "VLLM_V1_USE_PREFILL_DECODE_ATTENTION": | ||
| lambda: | ||
| (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in | ||
| ("true", "1")), | ||
|
|
||
| # Force vllm to use a specific flash-attention version (2 or 3), only valid | ||
| # when using the flash-attention backend. | ||
| "VLLM_FLASH_ATTN_VERSION": | ||
|
|
@@ -323,8 +331,8 @@ def get_vllm_port() -> Optional[int]: | |
|
|
||
| # Whether to log responses from API Server for debugging | ||
| "VLLM_DEBUG_LOG_API_SERVER_RESPONSE": | ||
| lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False"). | ||
| lower() == "true", | ||
| lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: accidental change?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ruff or yapf change, reformatting the whole file now results in this |
||
| ).lower() == "true", | ||
|
|
||
| # S3 access information, used for tensorizer to load model from S3 | ||
| "S3_ACCESS_KEY_ID": | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit:
VLLM_V1_TRITON_ATTN_FORCE_PREFILL_DECODEsounds slightly more accurate to me, but feel free to use the name that works best in your eyes