[Bugfix][Performance] Better MTP Support when using FlashMLA #24045
Changes from all commits: af963eb, e39e008, fbd4dee, 8d5e8a6, f02168a, 2fa1ed4, d1dcc97
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -184,8 +184,23 @@ def _forward_decode( | |
| q = torch.cat(q, dim=-1) | ||
|  | ||
| assert isinstance(q, torch.Tensor) | ||
|  | ||
| batch_size = attn_metadata.decode.seq_lens.shape[0] | ||
| Review comment: Could you refactor this into a utility function? It will likely need to be called in each backend that supports this feature (FlashInfer-MLA at least), so it will be nice to be able to reuse the logic. (A sketch of such a helper follows this diff.) | ||
| total_tokens = q.shape[0] | ||
| num_heads = q.shape[1] | ||
| head_dim = q.shape[2] | ||
|  | ||
| # support uniform batch | ||
| if total_tokens % batch_size == 0: | ||
| seq_len = total_tokens // batch_size | ||
| q = q.view(batch_size, seq_len, num_heads, head_dim) | ||
| else: | ||
| raise ValueError( | ||
| f"total_tokens={total_tokens}, batch_size={batch_size}. " | ||
| f"Expected uniform batches with seq_len=1 or seq_len=2.") | ||
|  | ||
| o, lse = flash_mla_with_kvcache( | ||
| q=q.unsqueeze(1), # Add seqlen dim of 1 (decode) | ||
| q=q, | ||
| k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 | ||
| block_table=attn_metadata.decode.block_table, | ||
| cache_seqlens=attn_metadata.decode.seq_lens, | ||
|  | @@ -199,4 +214,6 @@ def _forward_decode( | |
| descale_k=layer._k_scale.reshape(1), | ||
| ) | ||
|  | ||
| o = o.view(total_tokens, num_heads, self.kv_lora_rank) | ||
|  | ||
| return o, lse | ||
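
A minimal sketch of the shared helper suggested in the review comment above, assuming it only needs to collapse a flat decode query tensor into the `(batch, seq_len, heads, head_dim)` layout expected by `flash_mla_with_kvcache`. The name `reshape_query_for_uniform_decode` and its standalone form are illustrative assumptions, not part of this PR.

```python
import torch


def reshape_query_for_uniform_decode(q: torch.Tensor,
                                     batch_size: int) -> torch.Tensor:
    """Reshape a flat decode query [total_tokens, num_heads, head_dim]
    into [batch_size, seq_len, num_heads, head_dim].

    Assumes a uniform decode batch: every request contributes the same
    number of query tokens, so seq_len = total_tokens // batch_size.
    """
    total_tokens, num_heads, head_dim = q.shape
    if total_tokens % batch_size != 0:
        raise ValueError(
            f"total_tokens={total_tokens} is not divisible by "
            f"batch_size={batch_size}; expected a uniform decode batch.")
    seq_len = total_tokens // batch_size
    return q.view(batch_size, seq_len, num_heads, head_dim)
```

With such a helper, the inline reshape in `_forward_decode` (and the mirror-image `o.view(total_tokens, num_heads, self.kv_lora_rank)` flattening of the output) could be reused by other backends such as FlashInfer-MLA.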
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -645,35 +645,41 @@ def subclass_attention_backend( | |
| def split_decodes_and_prefills( | ||
| common_attn_metadata: CommonAttentionMetadata, | ||
| decode_threshold: int = 1, | ||
| require_uniform: bool = False, | ||
| ) -> tuple[int, int, int, int]: | ||
| """ | ||
| Assuming a reordered batch, finds the boundary between prefill and decode | ||
| requests. | ||
|  | ||
| Args: | ||
| common_attn_metadata: CommonAttentionMetadata object containing the | ||
| batch metadata. | ||
| decode_threshold: The maximum query length to be considered a decode. | ||
|  | ||
| require_uniform: If True, only selects decode requests with the same | ||
| query length for uniform batching. | ||
| If False, selects all decode requests regardless of | ||
| length variation. | ||
|  | ||
| Returns: | ||
| num_decodes: The number of decode requests. | ||
| num_prefills: The number of prefill requests. | ||
| num_decode_tokens: The number of tokens in the decode requests. | ||
| num_prefill_tokens: The number of tokens in the prefill requests. | ||
| """ | ||
|  | ||
| if require_uniform: | ||
| return split_decodes_and_prefills_uniform(common_attn_metadata, | ||
| Review comment (LucasWilkinson): Instead of a separate function, couldn't we just do something like […]? But we have to drop that for #24845 anyways. | ||
| LucasWilkinson: Oh sorry, I see you want to handle the […] | ||
| LucasWilkinson: (and still dropping […]) | ||
| Reviewer: @LucasWilkinson I think the current implementation is probably correct. In the case of […], to handle this more thoroughly you would have to modify the batch reordering code. This PR doesn't; it only does a best-effort pass to read uniform decodes from the front, falling back to prefills if there's a mismatch. I think that is fine for now. (Edited to make the example a better counterexample.) | ||
| LucasWilkinson: Oh ya, sorry, I'm not doubting the correctness of the current implementation, sorry for the confusion! I was just suggesting we can modify the existing implementation and do […] instead of the current […] (and remove […]); then we wouldn't need the separate function and could achieve the same effect with a lot less code (and it would be vectorized). | ||
| Reviewer: @LucasWilkinson It's not clear to me why this is doable. You're talking about a modification to […]? | ||
| LucasWilkinson: Oh, it's because `is_prefill` is fed into `argmax` to find the split point, which should return the index of the first prefill and ignore any subsequent decodes. | ||
| (A sketch of this vectorized alternative appears after this diff.) | ||
| decode_threshold) | ||
|  | ||
| max_query_len = common_attn_metadata.max_query_len | ||
| num_reqs = common_attn_metadata.num_reqs | ||
| num_tokens = common_attn_metadata.num_actual_tokens | ||
| query_start_loc = common_attn_metadata.query_start_loc_cpu | ||
|  | ||
| if max_query_len <= decode_threshold: | ||
| return num_reqs, 0, num_tokens, 0 | ||
|  | ||
| query_lens = query_start_loc[1:] - query_start_loc[:-1] | ||
| is_prefill = query_lens > decode_threshold | ||
| if not torch.any(is_prefill): | ||
| return num_reqs, 0, num_tokens, 0 | ||
|  | ||
| first_prefill = is_prefill.int().argmax(dim=-1).item() | ||
| assert torch.all(query_lens[first_prefill:] > decode_threshold) | ||
| assert torch.all(query_lens[:first_prefill] <= decode_threshold) | ||
|  | @@ -684,6 +690,53 @@ def split_decodes_and_prefills( | |
| return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) | ||
|  | ||
|  | ||
| def split_decodes_and_prefills_uniform( | ||
| common_attn_metadata: CommonAttentionMetadata, | ||
| decode_threshold: int = 1, | ||
| ) -> tuple[int, int, int, int]: | ||
| """ | ||
| Similar to split_decodes_and_prefills but ensures decode batch is uniform. | ||
| Only selects decode requests with the same query length. | ||
| """ | ||
| num_reqs = common_attn_metadata.num_reqs | ||
| num_tokens = common_attn_metadata.num_actual_tokens | ||
| query_start_loc = common_attn_metadata.query_start_loc_cpu | ||
| query_lens = query_start_loc[1:] - query_start_loc[:-1] | ||
| # find all candidates that satisfy the threshold | ||
| decode_candidates = query_lens <= decode_threshold | ||
|  | ||
| if not torch.any(decode_candidates): | ||
| return 0, num_reqs, 0, num_tokens | ||
|  | ||
| first_len = None | ||
| first_prefill = 0 | ||
|  | ||
| # find the longest continuous uniform sequence from the front | ||
| for i in range(num_reqs): | ||
| current_len = query_lens[i].item() | ||
| if current_len > decode_threshold: | ||
| # prefill request, stop | ||
| break | ||
| if first_len is None: | ||
| # the first decode request | ||
| first_len = current_len | ||
| first_prefill = 1 | ||
| elif current_len == first_len: | ||
| # same length, continue | ||
| first_prefill = i + 1 | ||
| else: | ||
| # different length, stop | ||
| break | ||
|  | ||
| num_decodes = first_prefill | ||
| num_prefills = num_reqs - num_decodes | ||
| num_decode_tokens = query_start_loc[first_prefill].item( | ||
| ) if first_prefill < len(query_start_loc) else num_tokens | ||
| num_prefill_tokens = num_tokens - num_decode_tokens | ||
|  | ||
| return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) | ||
|  | ||
|  | ||
| def reorder_batch_to_split_decodes_and_prefills( | ||
| input_batch: "InputBatch", | ||
| scheduler_output: "SchedulerOutput", | ||
|  | ||
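
For context on the review thread above, here is a hedged sketch of the vectorized alternative being discussed: instead of the separate `split_decodes_and_prefills_uniform`, the uniformity requirement is folded into the existing `is_prefill` mask so that the existing `argmax` split keeps only the leading uniform run of decodes. The metadata object is duck-typed here (it only needs `num_reqs`, `num_actual_tokens`, and `query_start_loc_cpu`); this is an illustration, not the PR's code or the reviewer's elided snippet.

```python
import torch


def split_decodes_and_prefills_sketch(
    common_attn_metadata,  # needs .num_reqs, .num_actual_tokens, .query_start_loc_cpu
    decode_threshold: int = 1,
    require_uniform: bool = False,
) -> tuple[int, int, int, int]:
    num_reqs = common_attn_metadata.num_reqs
    num_tokens = common_attn_metadata.num_actual_tokens
    query_start_loc = common_attn_metadata.query_start_loc_cpu

    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    is_prefill = query_lens > decode_threshold
    if require_uniform and num_reqs > 0:
        # Treat any request whose query length differs from the first
        # request's length as a "prefill" for splitting purposes; argmax
        # then returns the first such index, so only the leading uniform
        # run of decodes stays on the decode path and the rest fall back
        # to the prefill path (the same best-effort behavior as the PR).
        is_prefill = is_prefill | (query_lens != query_lens[0])

    if not torch.any(is_prefill):
        return num_reqs, 0, num_tokens, 0

    first_prefill = int(is_prefill.int().argmax(dim=-1).item())
    num_decodes = first_prefill
    num_prefills = num_reqs - num_decodes
    num_decode_tokens = int(query_start_loc[first_prefill].item())
    num_prefill_tokens = num_tokens - num_decode_tokens
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens
```

For example, with query lengths `[2, 2, 1, 3]` and `decode_threshold=2`, setting `require_uniform=True` keeps the first two requests as decodes and routes the remaining two through the prefill path, matching what `split_decodes_and_prefills_uniform` in this PR returns. Note that the early return on `max_query_len <= decode_threshold` and the asserts in the current `split_decodes_and_prefills` would have to be made conditional on `require_uniform`, since an all-decode batch can still be non-uniform.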
Review comment: I think this might have negative consequences for backends which do not have kernel support for spec-friendly decodes. If so, we might want to have a per-backend flag to modulate when we apply this. Something like: […]
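
The snippet after "Something like:" was not captured above. Purely as a hypothetical illustration of such a per-backend switch (the attribute name `supports_uniform_spec_decodes`, the simplified class shapes, and the wiring are invented here and do not mirror vLLM's actual backend hierarchy):

```python
# Assumed import path for the function modified in this PR.
from vllm.v1.attention.backends.utils import split_decodes_and_prefills


class AttentionMetadataBuilderBase:
    # Hypothetical flag: backends with kernel support for spec-friendly
    # (multi-token) uniform decodes opt in by overriding this attribute.
    supports_uniform_spec_decodes: bool = False


class FlashMLAMetadataBuilder(AttentionMetadataBuilderBase):
    supports_uniform_spec_decodes = True


def split_for_backend(builder, common_attn_metadata, decode_threshold=1):
    # Only request uniform decode batching when the backend opted in,
    # so backends without spec-friendly decode kernels keep the current
    # (non-uniform) splitting behavior.
    return split_decodes_and_prefills(
        common_attn_metadata,
        decode_threshold=decode_threshold,
        require_uniform=builder.supports_uniform_spec_decodes,
    )
```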