190 | 190 | import functools |
191 | 191 | from abc import abstractmethod |
192 | 192 | from dataclasses import dataclass, field |
| 193 | +from enum import Enum |
193 | 194 | from typing import ClassVar, Generic, TypeVar |
194 | 195 |
|
195 | 196 | import torch |
|
224 | 225 | from vllm.v1.attention.backends.utils import ( |
225 | 226 | AttentionMetadataBuilder, |
226 | 227 | CommonAttentionMetadata, |
227 | | - QueryLenSupport, |
228 | | - ReorderSpec, |
229 | 228 | get_per_layer_parameters, |
230 | 229 | infer_global_hyperparameters, |
231 | 230 | split_decodes_and_prefills, |
232 | 231 | ) |
233 | 232 | from vllm.v1.kv_cache_interface import AttentionSpec |
234 | 233 |
|
| 234 | + |
| 235 | +class QueryLenSupport(Enum): |
| 236 | + """Defines the level of query length support for an attention backend's |
| 237 | + decode pipeline. |
| 238 | +
|
| 239 | + - SINGLE_ONLY: Decode pipeline only supports single-token queries |
| 240 | + (query_len=1) |
| 241 | + - UNIFORM: Decode pipeline supports uniform multi-token queries |
| 242 | + (all requests must have same query_len > 1) |
| 243 | + - VARLEN: Decode pipeline supports variable-length queries |
| 244 | + (mixed query lengths in same batch) |
| 245 | + """ |
| 246 | + |
| 247 | + SINGLE_ONLY = "single_only" |
| 248 | + UNIFORM = "uniform" |
| 249 | + VARLEN = "varlen" |
| 250 | + |
| 251 | + |
235 | 252 | try: |
236 | 253 | from vllm.vllm_flash_attn import flash_attn_varlen_func |
237 | 254 |
|
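For context on how this enum is consumed: when speculative decoding drafts `k` tokens, each decode step verifies `k + 1` tokens per request, so a backend whose decode kernel can batch multi-token queries wants its decode threshold widened to `1 + k`. A minimal, self-contained sketch of that gating (the helper `effective_decode_threshold` and its arguments are illustrative, not part of this diff):

```python
from enum import Enum


class QueryLenSupport(Enum):
    # Mirrors the enum added in this diff.
    SINGLE_ONLY = "single_only"
    UNIFORM = "uniform"
    VARLEN = "varlen"


def effective_decode_threshold(
    support: QueryLenSupport, base_threshold: int, num_spec_tokens: int
) -> int:
    # SINGLE_ONLY decode kernels cannot batch multi-token queries, so
    # speculative queries must fall through to the prefill pipeline.
    if support is QueryLenSupport.SINGLE_ONLY:
        return base_threshold
    # UNIFORM/VARLEN kernels can keep spec-decode queries on the decode
    # path: widen the threshold to cover 1 bonus token + k draft tokens.
    return max(base_threshold, 1 + num_spec_tokens)


assert effective_decode_threshold(QueryLenSupport.SINGLE_ONLY, 1, 4) == 1
assert effective_decode_threshold(QueryLenSupport.UNIFORM, 1, 4) == 5
```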
@@ -465,14 +482,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): |
465 | 482 | understand this class |
466 | 483 | """ |
467 | 484 |
|
| 485 | + # Defines the level of query length support for this backend. |
| 486 | + # - SINGLE_ONLY: Only single-token queries (no spec decode support) |
| 487 | + # - UNIFORM: Supports uniform multi-token queries (spec decode with uniform lengths) |
| 488 | + # - VARLEN: Supports variable-length queries (spec decode with mixed lengths) |
| 489 | + # If set to UNIFORM or VARLEN, this will increase `reorder_batch_threshold` when |
| 490 | + # speculative decoding is enabled. |
| 491 | + query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.SINGLE_ONLY |
| 492 | + |
468 | 493 | # The threshold for reordering the batch into decode and prefill requests. |
469 | 494 | # If > 1, the batch will be reordered such that requests with |
470 | 495 | # query length <= threshold are classified as decode requests. |
471 | | - # Use `decode_query_len_support` (above) to set this automatically |
| 496 | + # Use `query_len_support` (above) to set this automatically |
472 | 497 | # when speculative decoding is enabled. |
473 | | - reorder_spec: ClassVar[ReorderSpec] = ReorderSpec( |
474 | | - 1, decode_query_len_support=QueryLenSupport.SINGLE_ONLY |
475 | | - ) |
| 498 | + reorder_batch_threshold: int = 1 |
476 | 499 |
|
477 | 500 | @staticmethod |
478 | 501 | def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int: |
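A backend opts into spec-decode-as-decode by overriding the class attribute. A hypothetical subclass, using stub class names rather than the real vLLM builders (reuses the `QueryLenSupport` sketch above):

```python
from typing import ClassVar


class DemoBuilderBase:
    """Stub standing in for MLACommonMetadataBuilder."""

    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.SINGLE_ONLY
    reorder_batch_threshold: int = 1


class DemoUniformDecodeBuilder(DemoBuilderBase):
    # This decode pipeline batches uniform multi-token queries, so
    # speculative decoding may raise reorder_batch_threshold above 1
    # at init time.
    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
```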
@@ -597,19 +620,16 @@ def __init__( |
597 | 620 | device=device, |
598 | 621 | ) |
599 | 622 |
|
600 | | - assert self.reorder_spec.decode_threshold is not None |
601 | | - supports_spec_decode = ( |
602 | | - self.reorder_spec.decode_query_len_support != QueryLenSupport.SINGLE_ONLY |
603 | | - ) |
604 | | - self._init_decode_threshold( |
605 | | - self.reorder_spec.decode_threshold, supports_spec_decode |
| 623 | + supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY |
| 624 | + self._init_reorder_batch_threshold( |
| 625 | + self.reorder_batch_threshold, supports_spec_decode |
606 | 626 | ) |
607 | 627 |
|
608 | | - # Validate consistency between decode_query_len_support and decode_threshold |
609 | | - if self.reorder_spec.decode_query_len_support == QueryLenSupport.SINGLE_ONLY: |
610 | | - assert self.reorder_spec.decode_threshold == 1, ( |
611 | | - f"decode_threshold must be 1 when decode_query_len_support is " |
612 | | - f"SINGLE_ONLY, got {self.reorder_spec.decode_threshold}" |
| 628 | + # Validate consistency between query_len_support and reorder_batch_threshold |
| 629 | + if self.query_len_support == QueryLenSupport.SINGLE_ONLY: |
| 630 | + assert self.reorder_batch_threshold == 1, ( |
| 631 | + f"reorder_batch_threshold must be 1 when query_len_support is " |
| 632 | + f"SINGLE_ONLY, got {self.reorder_batch_threshold}" |
613 | 633 | ) |
614 | 634 |
|
615 | 635 | def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): |
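`_init_reorder_batch_threshold` lives in the shared builder base in `vllm/v1/attention/backends/utils.py` and is not shown in this diff; the sketch below is an approximation of the widening it plausibly performs when a speculative config is present, not the actual implementation:

```python
def init_reorder_batch_threshold_sketch(
    base_threshold: int,
    supports_spec_decode: bool,
    num_speculative_tokens: int | None,
) -> int:
    """Hedged approximation of the helper's effect on the threshold."""
    # Each spec-decode step carries 1 bonus token plus the drafted tokens.
    # Backends whose decode kernels accept multi-token queries widen the
    # threshold so those requests stay on the decode path.
    if supports_spec_decode and num_speculative_tokens is not None:
        return max(base_threshold, 1 + num_speculative_tokens)
    return base_threshold
```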
@@ -712,14 +732,12 @@ def build_for_cudagraph_capture( |
712 | 732 | Currently, only decode is supported for full cudagraphs with MLA. |
713 | 733 | """ |
714 | 734 | m = common_attn_metadata |
715 | | - assert self.reorder_spec.decode_threshold is not None |
716 | | - assert m.num_reqs <= ( |
717 | | - m.num_actual_tokens * self.reorder_spec.decode_threshold |
718 | | - ), ( |
| 735 | + assert m.num_reqs <= (m.num_actual_tokens * self.reorder_batch_threshold), ( |
719 | 736 | "MLA only supports decode-only full CUDAGraph capture. " |
720 | 737 | "Make sure all cudagraph capture sizes <= max_num_seq." |
721 | 738 | ) |
722 | | - assert m.max_query_len <= self.reorder_spec.decode_threshold # decode only |
| 739 | + |
| 740 | + assert m.max_query_len <= self.reorder_batch_threshold # decode only |
723 | 741 |
|
724 | 742 | return self.build(0, m) |
725 | 743 |
|
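For intuition on the capture-time asserts: with `reorder_batch_threshold = 3` (one bonus token plus two drafted tokens), a decode-only capture of 8 requests carries up to 24 tokens, and it is the `max_query_len` check that actually enforces decode-only. Toy numbers, not taken from this diff:

```python
# Toy values mirroring the decode-only full-CUDAGraph capture asserts.
reorder_batch_threshold = 3  # 1 bonus + 2 speculative tokens
num_reqs = 8
num_actual_tokens = 24       # 8 requests x 3 tokens each
max_query_len = 3

assert num_reqs <= num_actual_tokens * reorder_batch_threshold
assert max_query_len <= reorder_batch_threshold  # decode only
```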
@@ -751,14 +769,11 @@ def build( |
751 | 769 |
|
752 | 770 | num_computed_tokens_cpu = common_attn_metadata.seq_lens_cpu - query_seq_lens_cpu |
753 | 771 |
|
754 | | - assert self.reorder_spec.decode_threshold is not None |
755 | 772 | num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( |
756 | 773 | split_decodes_and_prefills( |
757 | 774 | common_attn_metadata, |
758 | | - decode_threshold=self.reorder_spec.decode_threshold, |
759 | | - require_uniform=( |
760 | | - self.reorder_spec.decode_query_len_support != QueryLenSupport.VARLEN |
761 | | - ), |
| 775 | + decode_threshold=self.reorder_batch_threshold, |
| 776 | + require_uniform=(self.query_len_support != QueryLenSupport.VARLEN), |
762 | 777 | ) |
763 | 778 | ) |
764 | 779 |
|
|
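`split_decodes_and_prefills` (from `vllm/v1/attention/backends/utils.py`) partitions the already-reordered batch by query length. The real helper operates on `CommonAttentionMetadata`; the list-based function below is only a sketch of its contract, assuming decode requests come first after reordering:

```python
def split_decodes_and_prefills_sketch(
    query_lens: list[int], decode_threshold: int, require_uniform: bool
) -> tuple[int, int, int, int]:
    """Return (num_decodes, num_prefills, num_decode_tokens,
    num_prefill_tokens) for a batch already reordered decodes-first."""
    num_decodes = 0
    for qlen in query_lens:
        if qlen > decode_threshold:
            break
        # Non-VARLEN backends require every decode query in the batch to
        # share one length; the first mismatch ends the decode run.
        if require_uniform and num_decodes > 0 and qlen != query_lens[0]:
            break
        num_decodes += 1
    num_decode_tokens = sum(query_lens[:num_decodes])
    num_prefills = len(query_lens) - num_decodes
    num_prefill_tokens = sum(query_lens) - num_decode_tokens
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens


# With threshold 2 and uniform decodes required: [2, 2, 5, 7] splits into
# 2 decodes (4 tokens) and 2 prefills (12 tokens).
assert split_decodes_and_prefills_sketch([2, 2, 5, 7], 2, True) == (2, 2, 4, 12)
```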