
Conversation

Contributor

@FENP FENP commented Oct 16, 2025

Purpose

For the FlashMLA backend, #26541 set the default value of reorder_batch_threshold to 512:

class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
    reorder_batch_threshold: int = 512  # process small prefills with decode pathway
    # ^ TODO(matt): tune this

However, DCP supports reorder_batch_threshold > 1 only when the FlashAttnMLA backend is used (#25049). As a result, the following assertion fails when the FlashMLA backend is used:

if self.reorder_batch_threshold is not None:
    # NOTE(lucas): currently no backend supports the custom masking
    # required for DCP with q_len > 1, so we assert here. Remove this
    # assert once custom mask support is added to FA3.
    if (
        self.dcp_world_size > 1
        and envs.VLLM_ATTENTION_BACKEND != "FLASH_ATTN_MLA"
    ):
        assert self.reorder_batch_threshold == 1, (
            "DCP not support reorder_batch_threshold > 1 now."
        )

This PR temporarily fixes the issue by setting reorder_batch_threshold back to 1.

Looking forward to DCP supporting reorder_batch_threshold > 1 with FlashMLA in the future :).
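For reference, a minimal, self-contained sketch of the revert described above (the enum and base-class stubs stand in for the real vLLM types so the snippet runs on its own; they are not the actual vLLM definitions):

```python
from enum import Enum, auto
from typing import ClassVar


# Stubs standing in for the real vLLM enums, only so the sketch is runnable.
class AttentionCGSupport(Enum):
    UNIFORM_BATCH = auto()


class QueryLenSupport(Enum):
    UNIFORM = auto()


class MLACommonMetadataBuilder:
    reorder_batch_threshold: int = 1


class FlashMLAMetadataBuilder(MLACommonMetadataBuilder):
    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
    # Reverted from 512: DCP currently only supports a threshold > 1 on the
    # FlashAttnMLA backend, so FlashMLA goes back to 1 for now.
    reorder_batch_threshold: int = 1


print(FlashMLAMetadataBuilder.reorder_batch_threshold)  # 1
```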

Test Plan

export VLLM_ATTENTION_BACKEND="FLASHMLA"
vllm serve /deepseek-ai/DeepSeek-R1/ --gpu-memory-utilization 0.9 --tensor-parallel-size 8 --decode-context-parallel-size 8

curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{"model": "/ossfs/workspace/DeepSeek-R1/", "messages": [{"role": "user", "content": "who is there"}], "temperature": 0.0, "max_tokens": 100}'

Test Result

main

...
ERROR 10-16 20:49:34 [multiproc_executor.py:700]  assert self.reorder_batch_threshold == 1, (
ERROR 10-16 20:49:34 [multiproc_executor.py:700] AssertionError: DCP not support reorder_batch_threshold > 1 now.
...
INFO:     127.0.0.1:46314 - "POST /v1/chat/completions HTTP/1.1" 500 Internal Server Error

this PR

INFO:     127.0.0.1:47140 - "POST /v1/chat/completions HTTP/1.1" 200 OK

cc @minosfuture @MatthewBonanni @youkaichao @LucasWilkinson


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@FENP FENP requested a review from LucasWilkinson as a code owner October 16, 2025 12:51
@mergify mergify bot added the v1 label Oct 16, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request addresses a bug that causes an assertion error when using the FlashMLA backend with Decode Context Parallelism (DCP) and a reorder_batch_threshold greater than 1. The fix correctly identifies this unsupported configuration and resets reorder_batch_threshold to 1, along with query_len_support, to prevent the crash. My review includes a suggestion to improve the maintainability of the implementation by using a class attribute for feature detection instead of checking the class name as a string. This will make the code more robust against future refactoring.

Comment on lines 561 to 565

if (
    self.dcp_world_size > 1 and self.reorder_batch_threshold > 1
    and self.__class__.__name__ != "FlashAttnMLAMetadataBuilder"
):

Severity: high

Checking the class name as a string (self.__class__.__name__) is fragile and can lead to silent bugs if the class is ever renamed. A more robust and maintainable approach is to use a class attribute to indicate feature support.

You can define a class attribute in MLACommonMetadataBuilder and override it in the specific subclass that supports this feature.

For example:

  1. Add a new class attribute to MLACommonMetadataBuilder (e.g., right after reorder_batch_threshold):
class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
    ...
    reorder_batch_threshold: int = 1
    _supports_dcp_and_reorder: ClassVar[bool] = False
    ...
  2. In the FlashAttnMLAMetadataBuilder class, override this attribute:
class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder):
    ...
    _supports_dcp_and_reorder: ClassVar[bool] = True
    ...
  3. Then, update the condition here to use this new attribute, which is more idiomatic and safer.
Suggested change
- if (
-     self.dcp_world_size > 1 and self.reorder_batch_threshold > 1
-     and self.__class__.__name__ != "FlashAttnMLAMetadataBuilder"
- ):
+ if (
+     self.dcp_world_size > 1 and self.reorder_batch_threshold > 1
+     and not self._supports_dcp_and_reorder
+ ):
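To illustrate why the attribute-based check is more robust than matching on `__class__.__name__`, here is a self-contained sketch of the pattern (the class names are shortened stand-ins and `effective_threshold` is a hypothetical helper, not part of vLLM):

```python
from typing import ClassVar


class MLACommonMetadataBuilderSketch:
    reorder_batch_threshold: int = 1
    # Subclasses that support DCP with reorder_batch_threshold > 1 flip this.
    _supports_dcp_and_reorder: ClassVar[bool] = False


class FlashAttnMLABuilderSketch(MLACommonMetadataBuilderSketch):
    reorder_batch_threshold: int = 512
    _supports_dcp_and_reorder: ClassVar[bool] = True


def effective_threshold(builder: MLACommonMetadataBuilderSketch, dcp_world_size: int) -> int:
    # The runtime check no longer depends on a class-name string, so renaming
    # a builder class cannot silently break it.
    if dcp_world_size > 1 and not builder._supports_dcp_and_reorder:
        return 1
    return builder.reorder_batch_threshold


print(effective_threshold(MLACommonMetadataBuilderSketch(), dcp_world_size=8))  # 1
print(effective_threshold(FlashAttnMLABuilderSketch(), dcp_world_size=8))       # 512
```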

@FENP FENP force-pushed the bugfix/flashmla_reorder_batch_threshold branch from 8d76fee to 4af3874 on October 16, 2025 12:56
Contributor

@MatthewBonanni MatthewBonanni left a comment

In the following code from gpu_model_runner.py,

if self.reorder_batch_threshold is not None:
    # NOTE(lucas): currently no backend supports the custom masking
    # required for DCP with q_len > 1, so we assert here. Remove this
    # assert once custom mask support is added to FA3.
    if (
        self.dcp_world_size > 1
        and envs.VLLM_ATTENTION_BACKEND != "FLASH_ATTN_MLA"
    ):
        assert self.reorder_batch_threshold == 1, (
            "DCP not support reorder_batch_threshold > 1 now."
        )

Could we instead just change that assert to set the threshold? The metadata builder's threshold won't be updated, but what ultimately matters is the GPU model runner's threshold, i.e.:

if self.reorder_batch_threshold is not None:
    # NOTE(lucas): currently no backend supports the custom masking
    # required for DCP with q_len > 1, so we assert here. Remove this
    # assert once custom mask support is added to FA3.
    if (
        self.dcp_world_size > 1
        and envs.VLLM_ATTENTION_BACKEND != "FLASH_ATTN_MLA"
    ):
        logger.warning(
            "This backend does not support DCP with q_len > 1. "
            "Setting reorder_batch_threshold to 1."
        )
        self.reorder_batch_threshold = 1
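As a self-contained illustration of this warn-and-clamp approach (RunnerSketch and clamp_reorder_threshold are hypothetical stand-ins, not the real GPUModelRunner API):

```python
import logging

logger = logging.getLogger("sketch")


class RunnerSketch:
    """Stand-in for a model runner that may need to clamp its threshold."""

    def __init__(self, dcp_world_size: int, backend: str, reorder_batch_threshold: int):
        self.dcp_world_size = dcp_world_size
        self.backend = backend
        self.reorder_batch_threshold = reorder_batch_threshold

    def clamp_reorder_threshold(self) -> None:
        # Instead of asserting, warn and fall back to the supported value.
        if self.reorder_batch_threshold is not None:
            if self.dcp_world_size > 1 and self.backend != "FLASH_ATTN_MLA":
                logger.warning(
                    "This backend does not support DCP with q_len > 1. "
                    "Setting reorder_batch_threshold to 1."
                )
                self.reorder_batch_threshold = 1


runner = RunnerSketch(dcp_world_size=8, backend="FLASHMLA", reorder_batch_threshold=512)
runner.clamp_reorder_threshold()
print(runner.reorder_batch_threshold)  # 1
```

With this shape, the unsupported configuration degrades gracefully instead of crashing the server with an AssertionError.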

@LucasWilkinson

… enable DCP

Signed-off-by: FENP <32334296+FENP@users.noreply.github.com>
@FENP FENP force-pushed the bugfix/flashmla_reorder_batch_threshold branch from 4af3874 to a2d5ef0 on October 16, 2025 13:55
Contributor

@minosfuture minosfuture left a comment

Thanks for the fix!

