Add batch invariant kernel override for FlashInfer backend [2/n] #25769
Conversation
Code Review
This pull request introduces an optional batch-invariant mode, controlled by the VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT environment variable, to ensure deterministic outputs regardless of batch size. This is achieved by overriding several performance-optimized but non-deterministic kernels with deterministic alternatives, including custom Triton kernels for matmul, log_softmax, and mean, and by forcing deterministic configurations in attention backends like FlashInfer and FlexAttention. The changes are well-integrated and include a comprehensive test suite to validate the batch invariance. My review focuses on improving the robustness of how the controlling environment variable is parsed in both C++ and Python code to handle common boolean string values and prevent potential issues.
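For context, the kind of robust boolean parsing the review asks for might look like the sketch below. The helper name and the accepted spellings are assumptions for illustration, not vLLM's actual implementation.

```python
import os


def env_flag(name: str, default: bool = False) -> bool:
    """Parse a boolean environment variable robustly (hypothetical helper).

    Treats common truthy spellings ("1", "true", "yes", "on") as True and
    everything else as False, so values like "TRUE" or "Yes" don't silently
    disable the feature.
    """
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")


# Example: gate the batch-invariant kernel overrides on the flag.
BATCH_INVARIANT = env_flag("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT")
```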
It seems support wasn't added for the trtllm path in flashinfer. Should we update supports_trtllm_attention to also check against this environment variable so we force flashinfer?
Lines 184 to 192 in 984d184:

```python
@functools.cache
def supports_trtllm_attention() -> bool:
    """
    TRTLLM attention is supported if the platform is SM100 and
    NVIDIA artifactory is accessible
    """
    # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
    return current_platform.is_device_capability(
        100) and has_nvidia_artifactory()
```
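For illustration, the reviewer's suggestion might look roughly like the sketch below. The `vllm_kernel_override_batch_invariant()` helper name is hypothetical, and `current_platform` / `has_nvidia_artifactory` are assumed to come from the surrounding module; per the discussion that follows, this change was ultimately judged unnecessary.

```python
import functools

# current_platform, has_nvidia_artifactory, and the batch-invariant flag
# helper are assumed to exist in the surrounding module.


@functools.cache
def supports_trtllm_attention() -> bool:
    # Hypothetical: skip the TRTLLM path entirely when batch-invariant
    # mode is requested, forcing the plain FlashInfer kernels.
    if vllm_kernel_override_batch_invariant():
        return False
    # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
    return current_platform.is_device_capability(
        100) and has_nvidia_artifactory()
```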
hmm, not sure I follow. are you suggesting we force trtllm on top of forcing flashinfer (in the case of batch_invariant=1)?
from what I gather -- trtllm is supported quite cleanly as an option independent of the batch invariance:
https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/flashinfer.py#L540-L541
Oh okay maybe I misunderstood. I saw that you only used the new parameters in `plan` in the `if not attn_metadata.prefill_use_trtllm:` case, so I assumed that this only works for the non-trtllm backend. If it works for both backends, then my comment can be disregarded
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed 6760e2d to bf4df9e
Force-pushed bf4df9e to 68024d1
LGTM, just a few nits
Force-pushed 35d4192 to 64930d4
Thanks for the work!
Please also fix the pre-commit issue
Signed-off-by: Bram Wasti <bwasti@meta.com>
addressed all comments in the latest!
Thanks for the work! A few more thoughts
```diff
@@ -42,6 +45,7 @@
 from vllm.v1.kv_cache_interface import AttentionSpec

 FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
+FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024
```
Add a comment for the number here
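One way the requested comment could read is sketched below; the stated rationale is an assumption for illustration, not taken from the PR.

```python
# 256 MiB default FlashInfer workspace.
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
# Larger (2 GiB) workspace for batch-invariant mode, which uses fixed,
# batch-size-independent scheduling and therefore needs more scratch space.
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024
```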
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: Bram Wasti <bwasti@fb.com>
LGTM, thanks for the work!
This PR is causing fullgraph test to fail on main: https://buildkite.com/vllm/ci/builds/33518/steps/canvas?sid=0199ad61-7880-4598-9503-66481c15c00c Reverting for now.
…m-project#25769) Signed-off-by: Bram Wasti <bwasti@meta.com> Signed-off-by: Bram Wasti <bwasti@fb.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
…m-project#25769) Signed-off-by: Bram Wasti <bwasti@meta.com> Signed-off-by: Bram Wasti <bwasti@fb.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: Karan Goel <3261985+karan@users.noreply.github.com>
…m-project#25769) Signed-off-by: Bram Wasti <bwasti@meta.com> Signed-off-by: Bram Wasti <bwasti@fb.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Signed-off-by: Bram Wasti <bwasti@meta.com>
…m-project#25769) Signed-off-by: Bram Wasti <bwasti@meta.com> Signed-off-by: Bram Wasti <bwasti@fb.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Continuing from #25603, this patch extends the batch-invariant kernel overrides to the much faster FlashInfer backend.
(This might look like a big change, but I am going to rebase onto #25603 and most of it will go away; mostly just look at the flashinfer.py file.)
Purpose
Add optional determinism to the FlashInfer backend.
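A minimal usage sketch, assuming the environment variable from this PR and vLLM's FlashInfer backend selector; the model name and sampling settings are placeholders.

```python
import os

# Enable the batch-invariant kernel overrides and select the FlashInfer
# backend before vLLM is initialized.
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model
params = SamplingParams(temperature=0.0, max_tokens=64)

# With the override active, the same prompt should produce identical output
# regardless of how many other requests are batched alongside it.
out = llm.generate(["What makes a kernel batch-invariant?"], params)
print(out[0].outputs[0].text)
```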
Test Plan
Test Result
Pass.
Essential Elements of an Effective PR Description Checklist
`supported_models.md` and `examples` for a new model.