15 changes: 13 additions & 2 deletions vllm/v1/attention/backends/mla/common.py
@@ -454,6 +454,16 @@ def __init__(self,
parallel_config)
self.mla_dims = get_mla_dims(self.model_config)
self.aot_schedule = current_platform.is_cuda()

self.speculative_config = vllm_config.speculative_config
# Set reorder_batch_threshold based on speculative config
if (self.speculative_config is not None and
Collaborator:
I think this might have negative consequences for backends which do not have kernel support for spec-friendly decodes. If so, we might want to have a per-backend flag to modulate when we apply this. Something like:

    reorder_batch_threshold: ClassVar[int] = 1
    supports_spec_decodes: ClassVar[bool] = False
...
        self.speculative_config = vllm_config.speculative_config
        # Set reorder_batch_threshold based on speculative config
        if (self.supports_spec_decodes and 
                self.speculative_config is not None and
                self.speculative_config.num_speculative_tokens is not None):
            self.reorder_batch_threshold = (  # type: ignore[misc]
                1 + self.speculative_config.num_speculative_tokens)
        else:
            self.reorder_batch_threshold = 1  # type: ignore[misc]

        self.speculative_config.num_speculative_tokens is not None):
    self.reorder_batch_threshold = (  # type: ignore[misc]
        1 + self.speculative_config.num_speculative_tokens)
else:
    self.reorder_batch_threshold = 1  # type: ignore[misc]

try:
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
@@ -662,9 +672,10 @@ def build(self,
num_computed_tokens_cpu = (common_attn_metadata.seq_lens_cpu -
query_seq_lens_cpu)


num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=self.reorder_batch_threshold)
split_decodes_and_prefills(common_attn_metadata, self.reorder_batch_threshold,
                           require_uniform=True)

# Note(hc): update seq_lens of decode reqs under DCP.
if self.dcp_world_size > 1:
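Taken together, the changes to this file mean that under speculative decoding each decode request carries 1 + num_speculative_tokens query tokens, and the metadata builder now asks split_decodes_and_prefills for a uniform decode prefix of exactly that length. A rough numeric sketch of the intended effect (hypothetical values, not vLLM API calls):

# Hypothetical batch: 4 decode requests verifying 2 draft tokens each,
# followed by a single prefill of 37 tokens.
num_speculative_tokens = 2
reorder_batch_threshold = 1 + num_speculative_tokens  # 3
query_lens = [3, 3, 3, 3, 37]

# With require_uniform=True, the builder should report a uniform decode prefix
# of 4 requests / 12 tokens, which FlashMLA can then view as a
# (batch=4, seq_len=3) query tensor in _forward_decode.
expected = (4, 1, 12, 37)  # (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
print(reorder_batch_threshold, query_lens, expected)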
19 changes: 18 additions & 1 deletion vllm/v1/attention/backends/mla/flashmla.py
@@ -184,8 +184,23 @@ def _forward_decode(
q = torch.cat(q, dim=-1)

assert isinstance(q, torch.Tensor)

batch_size = attn_metadata.decode.seq_lens.shape[0]
Collaborator:
Could you refactor this into a utility function? It will likely need to be called in each backend that supports this feature (FlashInfer-MLA at least), so it will be nice to be able to reuse the logic.

total_tokens = q.shape[0]
num_heads = q.shape[1]
head_dim = q.shape[2]

# support uniform batch
if total_tokens % batch_size == 0:
    seq_len = total_tokens // batch_size
    q = q.view(batch_size, seq_len, num_heads, head_dim)
else:
    raise ValueError(
        f"total_tokens={total_tokens}, batch_size={batch_size}. "
        f"Expected uniform batches with seq_len=1 or seq_len=2.")

o, lse = flash_mla_with_kvcache(
q=q.unsqueeze(1), # Add seqlen dim of 1 (decode)
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
@@ -199,4 +214,6 @@ def _forward_decode(
descale_k=layer._k_scale.reshape(1),
)

o = o.view(total_tokens, num_heads, self.kv_lora_rank)

return o, lse
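A minimal sketch of the shared reshape helper suggested in the review comment above, so other backends (e.g. FlashInfer-MLA) could reuse the same logic; the helper's name and placement are assumptions, not part of this PR:

import torch


def reshape_query_for_uniform_decode(q: torch.Tensor,
                                     batch_size: int) -> torch.Tensor:
    """Reshape a flattened decode query of shape (total_tokens, num_heads,
    head_dim) into (batch_size, seq_len, num_heads, head_dim), assuming every
    decode request has the same query length."""
    total_tokens, num_heads, head_dim = q.shape
    if total_tokens % batch_size != 0:
        raise ValueError(
            f"total_tokens={total_tokens} is not divisible by "
            f"batch_size={batch_size}; the decode batch is not uniform.")
    seq_len = total_tokens // batch_size
    return q.view(batch_size, seq_len, num_heads, head_dim)

The inline block in _forward_decode above would then collapse to a single call, while the view of the output back to (total_tokens, num_heads, kv_lora_rank) stays as it is.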
63 changes: 58 additions & 5 deletions vllm/v1/attention/backends/utils.py
@@ -645,35 +645,41 @@ def subclass_attention_backend(
def split_decodes_and_prefills(
    common_attn_metadata: CommonAttentionMetadata,
    decode_threshold: int = 1,
    require_uniform: bool = False,
) -> tuple[int, int, int, int]:
"""
Assuming a reordered batch, finds the boundary between prefill and decode
requests.

Args:
common_attn_metadata: CommonAttentionMetadata object containing the
batch metadata.
decode_threshold: The maximum query length to be considered a decode.

require_uniform: If True, only selects decode requests with the same
query length for uniform batching
If False, selects all decode requests regardless of
length variation.

Returns:
num_decodes: The number of decode requests.
num_prefills: The number of prefill requests.
num_decode_tokens: The number of tokens in the decode requests.
num_prefill_tokens: The number of tokens in the prefill requests.
"""

    if require_uniform:
        return split_decodes_and_prefills_uniform(common_attn_metadata,
Collaborator:
Instead of a separate function, couldn't we just do something like:

if require_uniform:
    decode_threshold = min(decode_threshold, min(query_lens))

argmax should return the first instance of is_prefill, so it should be safe; we just need to drop:

assert torch.all(query_lens[first_prefill:] > decode_threshold)

but we have to drop that for #24845 anyway.
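For concreteness, on the [2, 2, 2, 1, 5] case raised in the follow-up below, clamping the threshold to min(query_lens) on its own would classify every request as a prefill, since the smallest query length is 1 (a quick check, not code from the PR, and assuming the assert above is dropped):

import torch

query_lens = torch.tensor([2, 2, 2, 1, 5])
decode_threshold = min(2, int(query_lens.min()))  # clamped to 1
is_prefill = query_lens > decode_threshold        # [True, True, True, False, True]
first_prefill = int(is_prefill.int().argmax())    # 0 -> no decode prefix is kept
print(decode_threshold, first_prefill)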

LucasWilkinson (Collaborator) commented on Sep 17, 2025:
Oh sorry, I see you want to handle the [2, 2, 2, 1, 5] case; I think this is quite unlikely, but I think we can handle it pretty simply by doing something like:

# all prefills fast out
if query_lens[0] > decode_threshold:
    return 0, num_reqs, 0, num_tokens

if require_uniform:
    is_prefill = query_lens != query_lens[0]
else:
    is_prefill = query_lens > decode_threshold

Collaborator:
(and still dropping assert torch.all(query_lens[first_prefill:] > decode_threshold))

benchislett (Collaborator) commented on Sep 17, 2025:
@LucasWilkinson I think the current implementation is probably correct. In the case of [1, 2, 2, 1, 10] with decode_threshold = 2, we want to return [1] for decodes and not [2, 2] or [1, 1]. The decode sequence must be a prefix of the requests, since we only return num_decodes and that is used to determine how far from the front we should slice.

To handle this more thoroughly, you would have to modify the batch reordering code. This PR doesn't, and only does a best-effort pass to read uniform decodes from the front, falling back to prefills if there's a mismatch. I think that is fine for now.

(Edited to make the example a better counterexample.)

LucasWilkinson (Collaborator) commented on Sep 17, 2025:
Oh ya, sorry, I'm not doubting the correctness of the current implementation, sorry for the confusion! I was just suggesting we can modify the existing implementation and do:

# all prefills fast out
if query_lens[0] > decode_threshold:
    return 0, num_reqs, 0, num_tokens

if require_uniform:
    is_prefill = query_lens != query_lens[0]
else:
    is_prefill = query_lens > decode_threshold

instead of the current

is_prefill = query_lens > decode_threshold

(and remove assert torch.all(query_lens[first_prefill:] > decode_threshold))

Then we wouldn't need the separate function and could achieve the same effect with a lot less code (and it would be vectorized).

Collaborator:
@LucasWilkinson It's not clear to me why this is doable. You're talking about a modification to split_decodes_and_prefills, right? In this case, I think it's possible that it could receive an input like [2, 1, 2, 1, 2, 1], in which case you would need to split into decodes [2] and prefills [1, 2, 1, 2, 1]. You would not be able to do seq_lens == 2 and split into [2, 2, 2] and [1, 1, 1], since these are not contiguous in the input request array.

LucasWilkinson (Collaborator) commented on Sep 17, 2025:
Oh, it's because 'is_prefill' is fed into 'argmax' to find the split point, which should return the index of the first prefill and ignore any subsequent decodes.
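A small self-contained illustration of the argmax behaviour described here, using the [2, 1, 2, 1, 2, 1] pattern from the question above (standalone check with hypothetical tensors):

import torch

query_lens = torch.tensor([2, 1, 2, 1, 2, 1])

# Uniform-style mask: anything whose length differs from the first request is
# treated as a prefill.
is_prefill = query_lens != query_lens[0]        # [False, True, False, True, False, True]

# argmax over the int mask returns the index of the FIRST True entry, so the
# later length-2 requests are not pulled back into the decode prefix.
first_prefill = int(is_prefill.int().argmax())  # 1 -> decodes = query_lens[:1]
print(first_prefill)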

                                                  decode_threshold)

    max_query_len = common_attn_metadata.max_query_len
    num_reqs = common_attn_metadata.num_reqs
    num_tokens = common_attn_metadata.num_actual_tokens
    query_start_loc = common_attn_metadata.query_start_loc_cpu

    if max_query_len <= decode_threshold:
        return num_reqs, 0, num_tokens, 0

    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    is_prefill = query_lens > decode_threshold
    if not torch.any(is_prefill):
        return num_reqs, 0, num_tokens, 0

    first_prefill = is_prefill.int().argmax(dim=-1).item()
    assert torch.all(query_lens[first_prefill:] > decode_threshold)
    assert torch.all(query_lens[:first_prefill] <= decode_threshold)
@@ -684,6 +690,53 @@ def split_decodes_and_prefills(
    return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)


def split_decodes_and_prefills_uniform(
    common_attn_metadata: CommonAttentionMetadata,
    decode_threshold: int = 1,
) -> tuple[int, int, int, int]:
    """
    Similar to split_decodes_and_prefills, but ensures the decode batch is
    uniform by only selecting leading decode requests with the same query
    length.
    """
    num_reqs = common_attn_metadata.num_reqs
    num_tokens = common_attn_metadata.num_actual_tokens
    query_start_loc = common_attn_metadata.query_start_loc_cpu
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    # find all candidates that satisfy the threshold
    decode_candidates = query_lens <= decode_threshold

    if not torch.any(decode_candidates):
        return 0, num_reqs, 0, num_tokens

    first_len = None
    first_prefill = 0

    # find the longest contiguous uniform run of decodes from the front
    for i in range(num_reqs):
        current_len = query_lens[i].item()
        if current_len > decode_threshold:
            # prefill request, stop
            break
        if first_len is None:
            # the first decode request
            first_len = current_len
            first_prefill = 1
        elif current_len == first_len:
            # same length, continue
            first_prefill = i + 1
        else:
            # different length, stop
            break

    num_decodes = first_prefill
    num_prefills = num_reqs - num_decodes
    num_decode_tokens = (query_start_loc[first_prefill].item()
                         if first_prefill < len(query_start_loc) else num_tokens)
    num_prefill_tokens = num_tokens - num_decode_tokens

    return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
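As a sanity check, here is how the uniform split behaves on the query-length patterns discussed in the review thread above. The SimpleNamespace stand-in only mimics the three CommonAttentionMetadata fields the function reads and is not the real class:

from types import SimpleNamespace

import torch


def make_meta(query_lens: list[int]) -> SimpleNamespace:
    # Hypothetical stand-in exposing only num_reqs, num_actual_tokens and
    # query_start_loc_cpu, as read by split_decodes_and_prefills_uniform.
    lens = torch.tensor(query_lens, dtype=torch.int32)
    query_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
    query_start_loc[1:] = torch.cumsum(lens, dim=0)
    return SimpleNamespace(num_reqs=len(query_lens),
                           num_actual_tokens=int(lens.sum()),
                           query_start_loc_cpu=query_start_loc)


# decode_threshold=2 in both cases; only the leading uniform prefix counts as decode.
print(split_decodes_and_prefills_uniform(make_meta([1, 2, 2, 1, 10]), 2))
# -> (1, 4, 1, 15)
print(split_decodes_and_prefills_uniform(make_meta([2, 1, 2, 1, 2, 1]), 2))
# -> (1, 5, 2, 7)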


def reorder_batch_to_split_decodes_and_prefills(
input_batch: "InputBatch",
scheduler_output: "SchedulerOutput",