
Commit c89a532

tdoublep authored and xuebwang-amd committed
[V1] Remove V0 code paths for Hybrid models (vllm-project#25400)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 828ed45 commit c89a532

31 files changed, +359 -2303 lines changed

tests/models/language/generation/test_hybrid.py

Lines changed: 19 additions & 36 deletions
@@ -20,7 +20,9 @@
 SSM_MODELS = [
     "state-spaces/mamba-130m-hf",
     "tiiuae/falcon-mamba-tiny-dev",
-    "yujiepan/mamba2-codestral-v0.1-tiny-random",
+    # mamba2-codestral in transformers is broken pending:
+    # https://github.com/huggingface/transformers/pull/40861
+    #"yujiepan/mamba2-codestral-v0.1-tiny-random",
 ]
 
 HYBRID_MODELS = [
@@ -31,18 +33,7 @@
     "ibm-granite/granite-4.0-tiny-preview",
     "tiiuae/Falcon-H1-0.5B-Base",
     "LiquidAI/LFM2-1.2B",
-]
-
-V1_SUPPORTED_MODELS = [
-    "state-spaces/mamba-130m-hf",
-    "ai21labs/Jamba-tiny-dev",
-    "pfnet/plamo-2-1b",
-    "yujiepan/mamba2-codestral-v0.1-tiny-random",
-    "Zyphra/Zamba2-1.2B-instruct",
-    "hmellor/tiny-random-BambaForCausalLM",
-    "ibm-granite/granite-4.0-tiny-preview",
-    "tiiuae/Falcon-H1-0.5B-Base",
-    "LiquidAI/LFM2-1.2B",
+    "tiny-random/qwen3-next-moe",
 ]
 
 FULL_CUDA_GRAPH_MODELS = [
@@ -51,10 +42,6 @@
     "Zyphra/Zamba2-1.2B-instruct",
 ]
 
-V0_UNSUPPORTED_MODELS = [
-    "LiquidAI/LFM2-1.2B",
-]
-
 FP32_STATE_MODELS = [
     "state-spaces/mamba-130m-hf",
     "Zyphra/Zamba2-1.2B-instruct",
@@ -88,20 +75,16 @@ def test_models(
     hf_outputs = hf_model.generate_greedy_logprobs_limit(
         example_prompts, max_tokens, num_logprobs)
 
-    if model in V1_SUPPORTED_MODELS:
-        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
-    else:
-        vllm_v1_outputs = None
+    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
 
-    if model in V1_SUPPORTED_MODELS:
-        check_logprobs_close(
-            outputs_0_lst=hf_outputs,
-            outputs_1_lst=vllm_v1_outputs,
-            name_0="hf",
-            name_1="vllm-v1",
-        )
+    check_logprobs_close(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
 
 
 @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@@ -299,14 +282,14 @@ def test_full_cuda_graph(
         example_prompts, max_tokens, num_logprobs)
 
     with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-        vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)
 
     check_logprobs_close(
         outputs_0_lst=hf_outputs,
-        outputs_1_lst=vllm_v1_outputs,
+        outputs_1_lst=vllm_outputs,
         name_0="hf",
-        name_1="vllm-v1",
+        name_1="vllm",
     )
 
 
@@ -340,12 +323,12 @@ def test_fp32_cache_state(
     with vllm_runner(model,
                      max_num_seqs=MAX_NUM_SEQS,
                      **{cache_dtype_param: "float32"}) as vllm_model:
-        vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)
 
     check_logprobs_close(
         outputs_0_lst=hf_outputs,
-        outputs_1_lst=vllm_v1_outputs,
+        outputs_1_lst=vllm_outputs,
         name_0="hf",
-        name_1="vllm-v1",
+        name_1="vllm",
     )
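
With the V1_SUPPORTED_MODELS and V0_UNSUPPORTED_MODELS gates removed above, every entry in HYBRID_MODELS takes the same path: generate reference logprobs with the HF runner, generate with vLLM, and compare. A minimal standalone sketch of that flow follows; the hf_runner/vllm_runner/example_prompts pytest fixtures and check_logprobs_close come from vLLM's test suite, while the import path, model choice, and parameter values are illustrative assumptions rather than the contents of test_hybrid.py.

# Sketch only: the single code path hybrid-model tests now follow.
# Assumes vLLM's hf_runner/vllm_runner/example_prompts pytest fixtures;
# the import path and parametrize values below are illustrative.
import pytest

from tests.models.utils import check_logprobs_close  # path assumed

MAX_NUM_SEQS = 4  # illustrative value; test_hybrid.py defines its own constant


@pytest.mark.parametrize("model", ["ibm-granite/granite-4.0-tiny-preview"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_hybrid_matches_hf(hf_runner, vllm_runner, example_prompts, model,
                           max_tokens, num_logprobs):
    # Reference logprobs from the Hugging Face implementation.
    with hf_runner(model) as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs)

    # vLLM logprobs: no V1_SUPPORTED_MODELS gate, every model is exercised.
    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )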

tests/models/registry.py

Lines changed: 6 additions & 7 deletions
@@ -312,14 +312,12 @@ def check_available_online(
     "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
     "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
     "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
-    "Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning",  # noqa: E501
-                                            trust_remote_code=True,
-                                            v0_only=True,
-                                            max_model_len=10240),
     "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
                                          trust_remote_code=True),
     "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
-                                         trust_remote_code=True),
+                                         max_transformers_version="4.55.4",
+                                         transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers",  # noqa: E501
+                                         trust_remote_code=True),
     "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
                                        max_transformers_version="4.53",
                                        transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers",  # noqa: E501
@@ -330,7 +328,8 @@ def check_available_online(
     "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
     "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
     "Qwen3NextForCausalLM": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
-                                            min_transformers_version="4.56.2"),
+                                            extras={"tiny-random": "tiny-random/qwen3-next-moe"},  # noqa: E501
+                                            min_transformers_version="4.56.3"),
     "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
     "SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct",  # noqa: E501
                                           trust_remote_code=True,
@@ -644,7 +643,7 @@ def check_available_online(
                                      trust_remote_code=True,
                                      speculative_model="XiaomiMiMo/MiMo-7B-RL"),
     "Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
-                                    min_transformers_version="4.56.2"),
+                                    min_transformers_version="4.56.3"),
 }
 
 _TRANSFORMERS_BACKEND_MODELS = {
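
The entries above all follow one pattern: a default checkpoint plus optional constraints (version bounds, a reason string, trust_remote_code) and an extras mapping for smaller CI variants such as the new tiny-random Qwen3-Next checkpoint. A hedged sketch of that pattern follows; _HfExamplesInfoSketch is a stand-in for the real _HfExamplesInfo and the model IDs are invented, only the field names mirror those used in registry.py.

# Illustration of the registry entry pattern only. _HfExamplesInfoSketch is a
# stand-in for _HfExamplesInfo in tests/models/registry.py; model IDs invented.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class _HfExamplesInfoSketch:
    default: str  # default HF checkpoint
    extras: dict[str, str] = field(default_factory=dict)
    min_transformers_version: Optional[str] = None
    max_transformers_version: Optional[str] = None
    transformers_version_reason: Optional[str] = None
    trust_remote_code: bool = False


_EXAMPLE_MODELS_SKETCH = {
    "MyHybridForCausalLM": _HfExamplesInfoSketch(
        "my-org/my-hybrid-model",
        # Tiny checkpoint CI can pull instead of the full-size default,
        # mirroring the "tiny-random" extra added for Qwen3NextForCausalLM.
        extras={"tiny-random": "my-org/tiny-random-hybrid"},
        min_transformers_version="4.56.3",
        trust_remote_code=True),
}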

vllm/model_executor/layers/mamba/abstract.py

Lines changed: 1 addition & 4 deletions
@@ -20,10 +20,7 @@ class MambaBase(AttentionLayerBase):
 
     # Contains the KV cache (mamba state) for the layer
     # in the shape specified by `self.get_state_shape`.
-    # The outer list is for v0 PP virtual engine. Though this code path
-    # only runs for v1, we have to do this to unify with the interface
-    # of Attention + v0 PP.
-    kv_cache: list[Iterable[torch.Tensor]]
+    kv_cache: tuple[torch.Tensor, ...]
 
     @abstractmethod
     def get_state_shape(self) -> Iterable[tuple[int, ...]]:
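
For layer implementers the narrowed annotation means a Mamba-style layer holds a flat tuple of state tensors, typically one per shape returned by get_state_shape(), with no outer per-virtual-engine list left over from V0 pipeline parallelism. A toy sketch under that assumption (the class name, shapes, and bind_state helper are invented; the real interface is MambaBase above):

# Toy illustration of the tuple-typed kv_cache; not vLLM's allocator logic.
from collections.abc import Iterable

import torch


class ToyMambaLayer:
    # One state tensor per shape reported by get_state_shape().
    kv_cache: tuple[torch.Tensor, ...]

    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
        # e.g. a conv state and an SSM state (sizes invented).
        return ((16, 4), (16, 8))

    def bind_state(self, num_blocks: int) -> None:
        # In vLLM the cache allocator sizes and places these tensors;
        # here we simply allocate zeros of the declared shapes.
        self.kv_cache = tuple(
            torch.zeros((num_blocks, *shape))
            for shape in self.get_state_shape())


layer = ToyMambaLayer()
layer.bind_state(num_blocks=2)
assert len(layer.kv_cache) == 2
assert layer.kv_cache[0].shape == (2, 16, 4)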

vllm/model_executor/layers/mamba/linear_attn.py

Lines changed: 40 additions & 67 deletions
@@ -15,7 +15,6 @@
 from einops import rearrange
 from torch import nn
 
-from vllm import envs
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
 from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
@@ -42,8 +41,6 @@
 import torch
 import torch.distributed
 
-from vllm.model_executor.models.minimax_cache import MinimaxCacheParams
-
 
 class MiniMaxText01RMSNormTP(CustomOp):
     name = "MiniMaxText01RMSNormTP"
@@ -225,11 +222,10 @@ def __init__(
                                       self.tp_heads:(self.tp_rank + 1) *
                                       self.tp_heads].contiguous()
 
-        if envs.VLLM_USE_V1:
-            compilation_config = get_current_vllm_config().compilation_config
-            if prefix in compilation_config.static_forward_context:
-                raise ValueError(f"Duplicate layer name: {prefix}")
-            compilation_config.static_forward_context[prefix] = self
+        compilation_config = get_current_vllm_config().compilation_config
+        if prefix in compilation_config.static_forward_context:
+            raise ValueError(f"Duplicate layer name: {prefix}")
+        compilation_config.static_forward_context[prefix] = self
 
     @staticmethod
     def weight_direct_load(param: torch.Tensor,
@@ -268,8 +264,7 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
                 break
             if _prefill_idx >= len(state_indices_tensor):
                 break
-            # prefills are packed at end of batch in V1
-            offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
+            offset = attn_metadata.num_decode_tokens
             _start = attn_metadata.query_start_loc[offset + _prefill_idx]
             _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
             slot_id = state_indices_tensor[offset + _prefill_idx]
@@ -291,10 +286,7 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
             hidden_decode = self._decode_infer(q, k, v, kv_cache,
                                                state_indices_tensor,
                                                attn_metadata)
-            if envs.VLLM_USE_V1:
-                hidden.insert(0, hidden_decode)
-            else:
-                hidden.append(hidden_decode)
+            hidden.insert(0, hidden_decode)
 
         if not hidden:
             return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
@@ -304,40 +296,28 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
 
     def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
                       attn_metadata):
-        if not envs.VLLM_USE_V1:
-            q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-            k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-            v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-            num_prefills = getattr(attn_metadata, "num_prefills", 0)
-            slot_id = state_indices_tensor[num_prefills:]
-        else:
-            q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
-            k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
-            v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
-            slot_id = state_indices_tensor[:attn_metadata.num_decodes]
+        q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+        k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+        v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+        slot_id = state_indices_tensor[:attn_metadata.num_decodes]
         hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
                                               slot_id, 32)
         return hidden
 
     def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
-                positions: torch.Tensor,
-                kv_caches: MinimaxCacheParams) -> None:
-        if not envs.VLLM_USE_V1:
-            self._forward(hidden_states, output, positions, kv_caches)
-        else:
-            torch.ops.vllm.linear_attention(
-                hidden_states,
-                output,
-                positions,
-                self.prefix,
-            )
+                positions: torch.Tensor) -> None:
+        torch.ops.vllm.linear_attention(
+            hidden_states,
+            output,
+            positions,
+            self.prefix,
+        )
 
     def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
-                 positions: torch.Tensor,
-                 kv_caches: Optional[MinimaxCacheParams]) -> None:
+                 positions: torch.Tensor) -> None:
         forward_context = get_forward_context()
         attn_metadata: AttentionMetadata = forward_context.attn_metadata
-        if envs.VLLM_USE_V1 and attn_metadata is not None:
+        if attn_metadata is not None:
             assert isinstance(attn_metadata, dict)
             attn_metadata = attn_metadata[self.prefix]
             assert isinstance(attn_metadata, LinearAttentionMetadata)
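
The rewritten forward above always dispatches through torch.ops.vllm.linear_attention, which resolves the layer by its prefix from the forward context and calls its _forward; the self-registration into static_forward_context in __init__ is what makes that lookup possible. Below is a stripped-down sketch of this name-based dispatch, using a plain dict in place of vLLM's static_forward_context/no_compile_layers and skipping the actual custom-op registration.

# Simplified sketch of the name-based dispatch pattern used above. A global
# registry maps layer prefixes to layer objects; the op-like free function
# resolves the prefix back to the layer and invokes its _forward. This mirrors
# the structure only, not vLLM's actual registration API.
import torch

_FORWARD_CONTEXT: dict[str, "SketchLinearAttention"] = {}


class SketchLinearAttention:

    def __init__(self, prefix: str) -> None:
        if prefix in _FORWARD_CONTEXT:
            raise ValueError(f"Duplicate layer name: {prefix}")
        _FORWARD_CONTEXT[prefix] = self
        self.prefix = prefix

    def forward(self, hidden_states: torch.Tensor,
                output: torch.Tensor) -> None:
        # In vLLM this call goes through torch.ops.vllm.linear_attention so
        # torch.compile sees an opaque op; here we call the helper directly.
        sketch_linear_attention(hidden_states, output, self.prefix)

    def _forward(self, hidden_states: torch.Tensor,
                 output: torch.Tensor) -> None:
        # Stand-in computation; the real layer runs linear attention kernels.
        output.copy_(hidden_states)


def sketch_linear_attention(hidden_states: torch.Tensor,
                            output: torch.Tensor, layer_name: str) -> None:
    # Look the layer back up by name, exactly as linear_attention() does with
    # forward_context.no_compile_layers[layer_name].
    layer = _FORWARD_CONTEXT[layer_name]
    layer._forward(hidden_states, output)


layer = SketchLinearAttention("model.layers.0.linear_attn")
x = torch.randn(4, 8)
out = torch.empty_like(x)
layer.forward(x, out)
assert torch.equal(out, x)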
@@ -351,32 +331,26 @@ def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
         qkvact = torch.nn.functional.silu(qkv32)
         qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
         q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
-        if envs.VLLM_USE_V1:
-            if attn_metadata is not None:
-                kv_cache = self.kv_cache[forward_context.virtual_engine][0]
-                state_indices_tensor = attn_metadata.state_indices_tensor
-
-                num_prefills = getattr(attn_metadata, "num_prefills", 0)
-                if num_prefills > 0:
-                    num_decode_tokens = getattr(attn_metadata,
-                                                "num_decode_tokens", 0)
-                    for prefill_idx in range(num_prefills):
-                        q_start = attn_metadata.query_start_loc[
-                            num_decode_tokens + prefill_idx]
-                        q_end = attn_metadata.query_start_loc[num_decode_tokens
-                                                              + prefill_idx +
-                                                              1]
-                        query_len = q_end - q_start
-                        context_len = attn_metadata.seq_lens[
-                            num_decode_tokens + prefill_idx] - query_len
-                        if context_len == 0:
-                            block_to_clear = state_indices_tensor[
-                                num_decode_tokens + prefill_idx]
-                            kv_cache[block_to_clear, ...] = 0
-        else:
-            assert kv_caches is not None
-            kv_cache = kv_caches.minimax_cache
-            state_indices_tensor = kv_caches.state_indices_tensor
+        if attn_metadata is not None:
+            kv_cache = self.kv_cache[forward_context.virtual_engine][0]
+            state_indices_tensor = attn_metadata.state_indices_tensor
+
+            num_prefills = getattr(attn_metadata, "num_prefills", 0)
+            if num_prefills > 0:
+                num_decode_tokens = getattr(attn_metadata, "num_decode_tokens",
+                                            0)
+                for prefill_idx in range(num_prefills):
+                    q_start = attn_metadata.query_start_loc[num_decode_tokens +
+                                                            prefill_idx]
+                    q_end = attn_metadata.query_start_loc[num_decode_tokens +
+                                                          prefill_idx + 1]
+                    query_len = q_end - q_start
+                    context_len = attn_metadata.seq_lens[
+                        num_decode_tokens + prefill_idx] - query_len
+                    if context_len == 0:
+                        block_to_clear = state_indices_tensor[num_decode_tokens
+                                                              + prefill_idx]
+                        kv_cache[block_to_clear, ...] = 0
 
         decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
         if attn_metadata is None:
@@ -410,8 +384,7 @@ def linear_attention(
     self = forward_context.no_compile_layers[layer_name]
     self._forward(hidden_states=hidden_states,
                   output=output,
-                  positions=positions,
-                  kv_caches=None)
+                  positions=positions)
 
 
 def linear_attention_fake(
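
The offset arithmetic in _prefill_and_mix_infer and _forward relies on the batch layout where decode tokens are packed at the front of the flattened batch and prefill sequences follow, so prefill i is addressed via query_start_loc[num_decode_tokens + i]. A small numeric sketch of that indexing (all values invented):

# Numeric sketch of the decode-first batch layout assumed above.
# Three decode requests (one token each) are followed by two prefills of
# lengths 4 and 2; every value here is invented for illustration.
import torch

num_decode_tokens = 3          # decodes contribute one token each
prefill_lens = [4, 2]

# Cumulative per-request token offsets: decode entries first, then prefills.
query_start_loc = torch.tensor([0, 1, 2, 3, 7, 9])
tokens = torch.arange(int(query_start_loc[-1]))  # stand-in token batch

# Decode path: the first num_decode_tokens tokens.
decode_tokens = tokens[:num_decode_tokens]
assert decode_tokens.tolist() == [0, 1, 2]

# Prefill path: skip the decode entries when indexing query_start_loc.
offset = num_decode_tokens
for prefill_idx, expected_len in enumerate(prefill_lens):
    start = query_start_loc[offset + prefill_idx]
    end = query_start_loc[offset + prefill_idx + 1]
    assert int(end - start) == expected_len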
