[Bugfix][Attention][DCP] Set reorder_batch_threshold back to 1 when using FlashMLA and enable DCP #27023
Conversation
Code Review
This pull request addresses a bug that causes an assertion error when using the FlashMLA backend with Decode Context Parallelism (DCP) and a reorder_batch_threshold greater than 1. The fix correctly identifies this unsupported configuration and resets reorder_batch_threshold to 1, along with query_len_support, to prevent the crash. My review includes a suggestion to improve the maintainability of the implementation by using a class attribute for feature detection instead of checking the class name as a string. This will make the code more robust against future refactoring.
```python
if (
    self.dcp_world_size > 1 and self.reorder_batch_threshold > 1
    and self.__class__.__name__ != "FlashAttnMLAMetadataBuilder"
):
```
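For context, a minimal sketch of what the guard's body could look like, based on the review summary above (the fix resets both the threshold and `query_len_support`); the `QueryLenSupport.SINGLE_ONLY` member is an assumption, not quoted from the diff:

```python
if (
    self.dcp_world_size > 1 and self.reorder_batch_threshold > 1
    and self.__class__.__name__ != "FlashAttnMLAMetadataBuilder"
):
    # Sketch: fall back to decode-only (q_len == 1) batch reordering,
    # since DCP with q_len > 1 is only supported by FlashAttnMLA.
    self.reorder_batch_threshold = 1
    self.query_len_support = QueryLenSupport.SINGLE_ONLY  # assumed enum
```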
Checking the class name as a string (self.__class__.__name__) is fragile and can lead to silent bugs if the class is ever renamed. A more robust and maintainable approach is to use a class attribute to indicate feature support.
You can define a class attribute in MLACommonMetadataBuilder and override it in the specific subclass that supports this feature.
For example:
- Add a new class attribute to `MLACommonMetadataBuilder` (e.g., right after `reorder_batch_threshold`):

```python
class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
    ...
    reorder_batch_threshold: int = 1
    _supports_dcp_and_reorder: ClassVar[bool] = False
    ...
```

- In the `FlashAttnMLAMetadataBuilder` class, override this attribute:

```python
class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder):
    ...
    _supports_dcp_and_reorder: ClassVar[bool] = True
    ...
```

- Then, update the condition here to use this new attribute, which is more idiomatic and safer:
```diff
 if (
     self.dcp_world_size > 1 and self.reorder_batch_threshold > 1
-    and self.__class__.__name__ != "FlashAttnMLAMetadataBuilder"
+    and not self._supports_dcp_and_reorder
 ):
```
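As a standalone illustration of why the class-attribute flag is more robust than a name check (class names below are hypothetical, not vLLM code):

```python
from typing import ClassVar


class BaseBuilder:
    # Default: the feature is unsupported.
    _supports_dcp_and_reorder: ClassVar[bool] = False


class FlashAttnMLABuilder(BaseBuilder):
    # A subclass opts in by overriding the flag.
    _supports_dcp_and_reorder: ClassVar[bool] = True


assert not BaseBuilder._supports_dcp_and_reorder
assert FlashAttnMLABuilder._supports_dcp_and_reorder
# Renaming FlashAttnMLABuilder cannot silently break the attribute
# check, whereas a __class__.__name__ comparison would keep passing
# and re-enable the unsupported path.
```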
In the following code from gpu_model_runner.py (vllm/v1/worker/gpu_model_runner.py, lines 586 to 596 at 00417f4):
```python
if self.reorder_batch_threshold is not None:
    # NOTE(lucas): currently no backend supports the custom masking
    # required for DCP with q_len > 1, so we assert here. Remove this
    # assert once the custom mask is support is added to FA3.
    if (
        self.dcp_world_size > 1
        and envs.VLLM_ATTENTION_BACKEND != "FLASH_ATTN_MLA"
    ):
        assert self.reorder_batch_threshold == 1, (
            "DCP not support reorder_batch_threshold > 1 now."
        )
```
Could we instead just change that assert to set the threshold? The metadata builder's threshold won't be updated, but what ultimately matters is the GPU model runner's threshold, i.e.:
```python
if self.reorder_batch_threshold is not None:
    # NOTE(lucas): currently no backend supports the custom masking
    # required for DCP with q_len > 1, so we assert here. Remove this
    # assert once the custom mask is support is added to FA3.
    if (
        self.dcp_world_size > 1
        and envs.VLLM_ATTENTION_BACKEND != "FLASH_ATTN_MLA"
    ):
        logger.warning(
            "This backend does not support DCP with q_len > 1. "
            "Setting reorder_batch_threshold to 1."
        )
        self.reorder_batch_threshold = 1
```
Thanks for the fix!
Purpose
For the FlashMLA backend, #26541 set the default value of `reorder_batch_threshold` to 512 (vllm/v1/attention/backends/mla/flashmla.py, lines 71 to 76 at 00417f4).
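A hedged sketch of the referenced declaration (illustrative; the exact lines at 00417f4 may differ):

```python
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
    # Per #26541: reorder decode batches with query lengths up to 512,
    # rather than the base-class default of 1.
    reorder_batch_threshold: int = 512
```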
However, DCP supports `reorder_batch_threshold` > 1 only when the FlashAttnMLA backend is used (#25049). Therefore, the assertion in vllm/v1/worker/gpu_model_runner.py (lines 586 to 596 at 00417f4, quoted above) fails when the FlashMLA backend is used.
This PR temporarily fixes the issue by setting `reorder_batch_threshold` back to 1. Looking forward to DCP supporting `reorder_batch_threshold` > 1 with FlashMLA in the future :).

Test Plan
Test Result
main
this PR
INFO: 127.0.0.1:47140 - "POST /v1/chat/completions HTTP/1.1" 200 OK

cc @minosfuture @MatthewBonanni @youkaichao @LucasWilkinson