
Commit 4eccb9b (parent: 72cb937)

remove the changes in ReorderSpec

Signed-off-by: ganyi <ygan@amd.com>

16 files changed (+82, −104 lines)


tests/v1/attention/test_mla_backends.py

Lines changed: 9 additions & 13 deletions
@@ -352,7 +352,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
        simulated paged KV cache.
     5. Comparing the vLLM backend's output to the ground-truth SDPA output.
     """
-    from vllm.v1.attention.backends.utils import QueryLenSupport
+    from vllm.v1.attention.backends.mla.common import QueryLenSupport
 
     batch_spec = BATCH_SPECS[batch_spec_name]
     is_spec_decode_test = batch_spec_name.startswith("spec_decode")
@@ -372,7 +372,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
         block_size=block_size,
     )
 
-    # For spec decode tests, add a speculative_config to set the decode_threshold
+    # For spec decode tests, add a speculative_config to set the reorder_batch_threshold
     if is_spec_decode_test:
         from vllm.config import SpeculativeConfig
 
@@ -460,24 +460,20 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
     for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
         builder_cls, _ = try_get_attention_backend(backend)
         if is_spec_decode_test:
-            decode_query_len_support = getattr(
-                builder_cls.reorder_spec,
-                "decode_query_len_support",
-                QueryLenSupport.SINGLE_ONLY,
+            query_len_support = getattr(
+                builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
             )
-            supports_spec = decode_query_len_support != QueryLenSupport.SINGLE_ONLY
+            supports_spec = query_len_support != QueryLenSupport.SINGLE_ONLY
             is_decode.append(supports_spec)
         else:
-            threshold = getattr(builder_cls.reorder_spec, "decode_threshold", None)
-            decode_query_len_support = getattr(
-                builder_cls.reorder_spec,
-                "decode_query_len_support",
-                QueryLenSupport.SINGLE_ONLY,
+            threshold = getattr(builder_cls, "reorder_batch_threshold", None)
+            query_len_support = getattr(
+                builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
            )
            within_threshold = q_len <= threshold if threshold else False
            if (
                within_threshold
-               and decode_query_len_support == QueryLenSupport.UNIFORM
+               and query_len_support == QueryLenSupport.UNIFORM
                and i > 0
            ):
                first_q_len = query_lens[0]
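
Note: with ReorderSpec gone, the test reads both settings as plain class attributes on the builder class instead of fields on a nested reorder_spec object. A minimal, self-contained sketch of the lookup pattern (DummyBuilder is hypothetical; the real builders are the ones edited below):

from enum import Enum


class QueryLenSupport(Enum):
    SINGLE_ONLY = "single_only"
    UNIFORM = "uniform"
    VARLEN = "varlen"


class DummyBuilder:
    # Flat class attributes, matching the shape the test now expects.
    reorder_batch_threshold: int = 2
    query_len_support = QueryLenSupport.UNIFORM


threshold = getattr(DummyBuilder, "reorder_batch_threshold", None)
support = getattr(DummyBuilder, "query_len_support", QueryLenSupport.SINGLE_ONLY)
supports_spec = support != QueryLenSupport.SINGLE_ONLY
assert threshold == 2 and supports_spec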

vllm/v1/attention/backends/cpu_attn.py

Lines changed: 4 additions & 6 deletions
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from dataclasses import dataclass
-from typing import ClassVar, Optional
+from typing import Optional
 
 import numpy as np
 import torch
@@ -20,7 +20,6 @@
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
-    ReorderSpec,
     split_decodes_and_prefills,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -349,7 +348,7 @@ def get_seq_len_block_table_args(
 
 
 class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
-    reorder_spec: ClassVar[ReorderSpec] = ReorderSpec(1)
+    reorder_batch_threshold: int = 1
 
     def __init__(
         self,
@@ -361,7 +360,7 @@ def __init__(
         super().__init__(kv_cache_spec, layer_names, vllm_config, device)
 
         self.scheduler_config = vllm_config.scheduler_config
-        self._init_decode_threshold(1, False)
+        self._init_reorder_batch_threshold(1, False)
 
         self.seq_start_loc_cpu = torch.zeros(
             vllm_config.scheduler_config.max_num_seqs + 1,
@@ -385,11 +384,10 @@ def build(
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         query_start_loc_np = query_start_loc_cpu.numpy()
 
-        assert self.reorder_spec.decode_threshold is not None
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(
                 common_attn_metadata,
-                decode_threshold=self.reorder_spec.decode_threshold,
+                decode_threshold=self.reorder_batch_threshold,
                 require_uniform=True,
             )
         )
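
For reference, split_decodes_and_prefills partitions a batch that has already been reordered (decodes first) at decode_threshold. A simplified, illustrative sketch over plain query lengths; the real helper in vllm/v1/attention/backends/utils.py operates on CommonAttentionMetadata tensors and may differ in detail:

def split_decodes_and_prefills_sketch(
    query_lens: list[int],
    decode_threshold: int = 1,
    require_uniform: bool = False,
) -> tuple[int, int, int, int]:
    # Requests are assumed reordered so short (decode-like) queries come
    # first; the first query longer than the threshold ends the decode run.
    num_decodes = 0
    first_len = query_lens[0] if query_lens else 0
    for q_len in query_lens:
        if q_len > decode_threshold:
            break
        if require_uniform and q_len != first_len:
            break  # uniform mode: every decode must share one query length
        num_decodes += 1
    num_decode_tokens = sum(query_lens[:num_decodes])
    num_prefills = len(query_lens) - num_decodes
    num_prefill_tokens = sum(query_lens[num_decodes:])
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens

With decode_threshold=2 and require_uniform=True, query lengths [2, 2, 2, 15] split into 3 decodes (6 tokens) and 1 prefill (15 tokens).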

vllm/v1/attention/backends/flashinfer.py

Lines changed: 4 additions & 6 deletions
@@ -44,7 +44,6 @@
     AttentionCGSupport,
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
-    ReorderSpec,
     get_kv_cache_layout,
     get_per_layer_parameters,
     infer_global_hyperparameters,
@@ -276,7 +275,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
     )
 
-    reorder_spec: ClassVar[ReorderSpec] = ReorderSpec(1)
+    reorder_batch_threshold: int = 1
 
     def __init__(
         self,
@@ -354,7 +353,7 @@ def __init__(
         else:
             self.q_data_type = self.model_config.dtype
 
-        self._init_decode_threshold(1, supports_spec_as_decode=can_use_trtllm)
+        self._init_reorder_batch_threshold(1, supports_spec_as_decode=can_use_trtllm)
 
         self._cascade_wrapper = None  # Wrapper for cascade attention
 
@@ -470,11 +469,10 @@ def build(
     ) -> FlashInferMetadata:
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
-        assert self.reorder_spec.decode_threshold is not None
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(
                 common_attn_metadata,
-                decode_threshold=self.reorder_spec.decode_threshold,
+                decode_threshold=self.reorder_batch_threshold,
                 require_uniform=True,
             )
         )
@@ -552,7 +550,7 @@ def build(
             paged_kv_last_page_len_np,
         )
 
-        uses_spec_reorder = self.reorder_spec.decode_threshold > 1
+        uses_spec_reorder = self.reorder_batch_threshold > 1
         prefill_use_trtllm = use_trtllm_attention(
             self.num_qo_heads,
             self.num_kv_heads,
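
The renamed helper itself is not part of this commit. A hypothetical sketch of its likely behavior, inferred from the call sites here and the comments added in mla/common.py below (the speculative_config access is an assumption):

def init_reorder_batch_threshold_sketch(
    builder, reorder_batch_threshold: int, supports_spec_as_decode: bool
) -> None:
    # Hypothetical reconstruction; the real _init_reorder_batch_threshold
    # lives on AttentionMetadataBuilder and may differ.
    builder.reorder_batch_threshold = reorder_batch_threshold
    spec_config = builder.vllm_config.speculative_config  # assumed attribute
    if supports_spec_as_decode and spec_config is not None:
        # Classify "1 accepted token + k speculative tokens" queries as
        # decodes by raising the threshold, as the mla/common.py comment
        # describes for UNIFORM/VARLEN backends.
        builder.reorder_batch_threshold = max(
            reorder_batch_threshold, 1 + spec_config.num_speculative_tokens
        )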

vllm/v1/attention/backends/gdn_attn.py

Lines changed: 2 additions & 4 deletions
@@ -3,7 +3,6 @@
 """Backend for GatedDeltaNet attention."""
 
 from dataclasses import dataclass
-from typing import ClassVar
 
 import torch
 
@@ -14,7 +13,6 @@
     AttentionCGSupport,
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
-    ReorderSpec,
     compute_causal_conv1d_metadata,
     split_decodes_and_prefills,
 )
@@ -63,7 +61,7 @@ class GDNAttentionMetadata:
 class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]):
     cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
 
-    reorder_spec: ClassVar[ReorderSpec] = ReorderSpec(1)
+    reorder_batch_threshold: int = 1
 
     def __init__(
         self,
@@ -82,7 +80,7 @@ def __init__(
         else:
             self.num_spec = 0
         self.use_spec_decode = self.num_spec > 0
-        self._init_decode_threshold(1, self.use_spec_decode)
+        self._init_reorder_batch_threshold(1, self.use_spec_decode)
 
         self.use_full_cuda_graph = (
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()

vllm/v1/attention/backends/linear_attn.py

Lines changed: 2 additions & 6 deletions
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from dataclasses import dataclass
-from typing import ClassVar
 
 import torch
 
@@ -10,7 +9,6 @@
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
-    ReorderSpec,
     split_decodes_and_prefills,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
@@ -35,7 +33,7 @@ class LinearAttentionMetadata:
 
 
 class LinearAttentionMetadataBuilder(AttentionMetadataBuilder[LinearAttentionMetadata]):
-    reorder_spec: ClassVar[ReorderSpec] = ReorderSpec(1)
+    reorder_batch_threshold: int = 1
 
     def __init__(
         self,
@@ -58,11 +56,9 @@ def build(
 
         state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
 
-        assert self.reorder_spec.decode_threshold is not None
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(
-                common_attn_metadata,
-                decode_threshold=self.reorder_spec.decode_threshold,
+                common_attn_metadata, decode_threshold=self.reorder_batch_threshold
             )
         )

vllm/v1/attention/backends/mamba1_attn.py

Lines changed: 1 addition & 3 deletions
@@ -49,11 +49,9 @@ def build(
             query_start_loc.device
         )
 
-        assert self.reorder_spec.decode_threshold is not None
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(
-                common_attn_metadata,
-                decode_threshold=self.reorder_spec.decode_threshold,
+                common_attn_metadata, decode_threshold=self.reorder_batch_threshold
             )
         )

vllm/v1/attention/backends/mamba2_attn.py

Lines changed: 1 addition & 3 deletions
@@ -223,11 +223,9 @@ def build(
         block_idx_last_scheduled_token = None
         block_idx_last_computed_token = None
 
-        assert self.reorder_spec.decode_threshold is not None
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(
-                common_attn_metadata,
-                decode_threshold=self.reorder_spec.decode_threshold,
+                common_attn_metadata, decode_threshold=self.reorder_batch_threshold
             )
         )

vllm/v1/attention/backends/mamba_attn.py

Lines changed: 1 addition & 2 deletions
@@ -11,15 +11,14 @@
     AttentionCGSupport,
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
-    ReorderSpec,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
 
 M = TypeVar("M")
 
 
 class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
-    reorder_spec: ClassVar[ReorderSpec] = ReorderSpec(1)
+    reorder_batch_threshold: int = 1
     cudagraph_support: ClassVar[AttentionCGSupport] = (
         AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
     )

vllm/v1/attention/backends/mla/common.py

Lines changed: 42 additions & 27 deletions
@@ -190,6 +190,7 @@
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass, field
+from enum import Enum
 from typing import ClassVar, Generic, TypeVar
 
 import torch
@@ -224,14 +225,30 @@
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
-    QueryLenSupport,
-    ReorderSpec,
     get_per_layer_parameters,
     infer_global_hyperparameters,
     split_decodes_and_prefills,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
 
+
+class QueryLenSupport(Enum):
+    """Defines the level of query length support for an attention backend's
+    decode pipeline.
+
+    - SINGLE_ONLY: Decode pipeline only supports single-token queries
+      (query_len=1)
+    - UNIFORM: Decode pipeline supports uniform multi-token queries
+      (all requests must have same query_len > 1)
+    - VARLEN: Decode pipeline supports variable-length queries
+      (mixed query lengths in same batch)
+    """
+
+    SINGLE_ONLY = "single_only"
+    UNIFORM = "uniform"
+    VARLEN = "varlen"
+
+
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
 
@@ -465,14 +482,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     understand this class
     """
 
+    # Defines the level of query length support for this backend.
+    # - SINGLE_ONLY: Only single-token queries (no spec decode support)
+    # - UNIFORM: Supports uniform multi-token queries (spec decode with uniform lengths)
+    # - VARLEN: Supports variable-length queries (spec decode with mixed lengths)
+    # If set to UNIFORM or VARLEN, this will increase `reorder_batch_threshold` when
+    # speculative decoding is enabled.
+    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.SINGLE_ONLY
+
     # The threshold for reordering the batch into decode and prefill requests.
     # If > 1, the batch will be reordered such that requests with
     # query length <= threshold are classified as decode requests.
-    # Use `decode_query_len_support` (above) to set this automatically
+    # Use `query_len_support` (above) to set this automatically
     # when speculative decoding is enabled.
-    reorder_spec: ClassVar[ReorderSpec] = ReorderSpec(
-        1, decode_query_len_support=QueryLenSupport.SINGLE_ONLY
-    )
+    reorder_batch_threshold: int = 1
 
     @staticmethod
     def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
@@ -597,19 +620,16 @@ def __init__(
             device=device,
         )
 
-        assert self.reorder_spec.decode_threshold is not None
-        supports_spec_decode = (
-            self.reorder_spec.decode_query_len_support != QueryLenSupport.SINGLE_ONLY
-        )
-        self._init_decode_threshold(
-            self.reorder_spec.decode_threshold, supports_spec_decode
+        supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY
+        self._init_reorder_batch_threshold(
+            self.reorder_batch_threshold, supports_spec_decode
         )
 
-        # Validate consistency between decode_query_len_support and decode_threshold
-        if self.reorder_spec.decode_query_len_support == QueryLenSupport.SINGLE_ONLY:
-            assert self.reorder_spec.decode_threshold == 1, (
-                f"decode_threshold must be 1 when decode_query_len_support is "
-                f"SINGLE_ONLY, got {self.reorder_spec.decode_threshold}"
+        # Validate consistency between query_len_support and reorder_batch_threshold
+        if self.query_len_support == QueryLenSupport.SINGLE_ONLY:
+            assert self.reorder_batch_threshold == 1, (
+                f"reorder_batch_threshold must be 1 when query_len_support is "
+                f"SINGLE_ONLY, got {self.reorder_batch_threshold}"
             )
 
     def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
@@ -712,14 +732,12 @@ def build_for_cudagraph_capture(
         Currently, only decode is supported for full cudagraphs with MLA.
         """
         m = common_attn_metadata
-        assert self.reorder_spec.decode_threshold is not None
-        assert m.num_reqs <= (
-            m.num_actual_tokens * self.reorder_spec.decode_threshold
-        ), (
+        assert m.num_reqs <= (m.num_actual_tokens * self.reorder_batch_threshold), (
             "MLA only supports decode-only full CUDAGraph capture. "
             "Make sure all cudagraph capture sizes <= max_num_seq."
         )
-        assert m.max_query_len <= self.reorder_spec.decode_threshold  # decode only
+
+        assert m.max_query_len <= self.reorder_batch_threshold  # decode only
 
         return self.build(0, m)
 
@@ -751,14 +769,11 @@ def build(
 
         num_computed_tokens_cpu = common_attn_metadata.seq_lens_cpu - query_seq_lens_cpu
 
-        assert self.reorder_spec.decode_threshold is not None
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(
                 common_attn_metadata,
-                decode_threshold=self.reorder_spec.decode_threshold,
-                require_uniform=(
-                    self.reorder_spec.decode_query_len_support != QueryLenSupport.VARLEN
-                ),
+                decode_threshold=self.reorder_batch_threshold,
+                require_uniform=(self.query_len_support != QueryLenSupport.VARLEN),
             )
         )
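
A short sketch of how a concrete backend might opt in under the new flat attributes (MyMLABuilder is hypothetical; MLACommonMetadataBuilder and QueryLenSupport are the classes from this file):

from typing import ClassVar

from vllm.v1.attention.backends.mla.common import (
    MLACommonMetadataBuilder,
    QueryLenSupport,
)


class MyMLABuilder(MLACommonMetadataBuilder):
    # The decode kernel tolerates uniform multi-token queries, so spec-decode
    # batches may be classified as decodes; __init__ then raises
    # reorder_batch_threshold above 1 when speculative decoding is enabled.
    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM

Backends that keep the default SINGLE_ONLY must leave reorder_batch_threshold at 1; the consistency assert in MLACommonMetadataBuilder.__init__ enforces this.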
