
Commit 4382192

Patching mamba2 and zamba2
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
1 parent c0ead4d commit 4382192

File tree: 3 files changed, 28 insertions(+), 38 deletions(-)
vllm/model_executor/models/bamba.py

Lines changed: 1 addition & 1 deletion
@@ -109,7 +109,7 @@ def forward(
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
         mamba_cache_params: MambaCacheParams,
-        mamba2_metadata: Optional[Mamba2Metadata] = None,
+        mamba2_metadata: Mamba2Metadata,
         **kwargs,
     ):
         if residual is None:

vllm/model_executor/models/mamba2.py

Lines changed: 12 additions & 17 deletions
@@ -13,6 +13,8 @@
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.mamba2_metadata import (
+    Mamba2Metadata, prepare_mamba2_metadata)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import (
     MambaMixer2, extra_groups_for_head_shards)
 from vllm.model_executor.layers.quantization.base_config import (
@@ -57,7 +59,6 @@ def __init__(self,
             head_dim=config.head_dim,
             rms_norm_eps=config.layer_norm_epsilon,
             activation=config.hidden_act,
-            chunk_size=config.chunk_size,
             quant_config=quant_config)

         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -67,7 +68,7 @@ def forward(
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
         mamba_cache_params: MambaCacheParams,
-        sequence_idx: Optional[torch.Tensor],
+        mamba2_metadata: Mamba2Metadata,
         **kwargs,
     ):
         if residual is None:
@@ -77,7 +78,7 @@ def forward(
             hidden_states, residual = self.norm(hidden_states, residual)

         hidden_states = self.mixer(hidden_states, mamba_cache_params,
-                                   sequence_idx)
+                                   mamba2_metadata)
         return hidden_states, residual


@@ -138,20 +139,14 @@ def forward(
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        # pass a sequence index tensor, that is required for
-        # proper continuous batching computation including
-        # chunked prefill
-        seq_idx = None
         attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
-        if attn_metadata.num_prefills > 0:
-            seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
-            for i, (srt, end) in enumerate(
-                    zip(
-                        attn_metadata.query_start_loc,
-                        attn_metadata.query_start_loc[1:],
-                    )):
-                seq_idx[srt:end] = i
-            seq_idx.unsqueeze_(0)
+
+        mamba2_metadata = prepare_mamba2_metadata(
+            chunk_size=self.config.chunk_size,
+            has_prefills=attn_metadata.num_prefills > 0,
+            input_ids=input_ids,
+            query_start_loc=attn_metadata.query_start_loc,
+        )

         for i in range(len(self.layers)):
             layer = self.layers[i]
@@ -162,7 +157,7 @@ def forward(
                 residual=residual,
                 mamba_cache_params=mamba_cache_params.at_layer_idx(
                     i - self.start_layer),
-                sequence_idx=seq_idx)
+                mamba2_metadata=mamba2_metadata)

         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
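For context, here is a small self-contained example of what the removed loop computed (and what prepare_mamba2_metadata presumably reproduces): two prefill requests of 3 and 2 tokens, with query_start_loc holding the cumulative token offsets, yield one request index per token. The tensor values are illustrative only.

# Worked example of the seq_idx construction deleted above.
import torch

input_ids = torch.tensor([11, 12, 13, 21, 22])    # 5 prefill tokens, 2 requests
query_start_loc = torch.tensor([0, 3, 5])         # cumulative request boundaries

seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
for i, (srt, end) in enumerate(zip(query_start_loc, query_start_loc[1:])):
    seq_idx[srt:end] = i
seq_idx.unsqueeze_(0)

print(seq_idx)  # tensor([[0, 0, 0, 1, 1]], dtype=torch.int32)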

vllm/model_executor/models/zamba2.py

Lines changed: 15 additions & 20 deletions
@@ -25,6 +25,8 @@
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.mamba2_metadata import (
+    Mamba2Metadata, prepare_mamba2_metadata)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import (
     MambaMixer2, extra_groups_for_head_shards)
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -495,7 +497,6 @@ def __init__(
             head_dim=intermediate_size // config.n_mamba_heads,
             rms_norm_eps=config.rms_norm_eps,
             activation="silu",
-            chunk_size=config.chunk_size,
             quant_config=quant_config,
         )

@@ -507,7 +508,7 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         mamba_cache_params: MambaCacheParams,
-        sequence_idx: Optional[torch.Tensor] = None,
+        mamba2_metadata: Mamba2Metadata,
         transformer_hidden_states: Optional[torch.Tensor] = None,
         positions: Optional[torch.Tensor] = None,
         original_hidden_states: Optional[torch.Tensor] = None,
@@ -547,7 +548,7 @@ def forward(
         hidden_states = self.mamba(
             hidden_states,
             mamba_cache_params=mamba_cache_params,
-            sequence_idx=sequence_idx,
+            mamba2_metadata=mamba2_metadata,
         )

         # residual connection after mamba
@@ -594,8 +595,8 @@ def forward(
         hidden_states: torch.Tensor,
         original_hidden_states: torch.Tensor,
         positions: torch.Tensor,
-        mamba_cache_params: Optional[MambaCacheParams] = None,
-        sequence_idx: Optional[torch.Tensor] = None,
+        mamba_cache_params: MambaCacheParams,
+        mamba2_metadata: Mamba2Metadata,
     ) -> torch.Tensor:
         """Forward pass through the hybrid layer.

@@ -634,7 +635,7 @@ def forward(
             hidden_states,
             transformer_hidden_states=transformer_hidden_states,
             mamba_cache_params=mamba_cache_params,
-            sequence_idx=sequence_idx,
+            mamba2_metadata=mamba2_metadata,
         )

         return layer_outputs
@@ -747,20 +748,14 @@ def forward(
             inputs_embeds = self.get_input_embeddings(input_ids)
         hidden_states = inputs_embeds

-        # pass a sequence index tensor, that is required for
-        # proper continuous batching computation including
-        # chunked prefill
-        seq_idx = None
         attn_metadata = get_forward_context().attn_metadata
-        if attn_metadata.num_prefills > 0:
-            seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
-            for i, (srt, end) in enumerate(
-                    zip(
-                        attn_metadata.query_start_loc,
-                        attn_metadata.query_start_loc[1:],
-                    )):
-                seq_idx[srt:end] = i
-            seq_idx.unsqueeze_(0)
+
+        mamba2_metadata = prepare_mamba2_metadata(
+            chunk_size=self.config.chunk_size,
+            has_prefills=attn_metadata.num_prefills > 0,
+            input_ids=input_ids,
+            query_start_loc=attn_metadata.query_start_loc,
+        )

         # Process through layers
         original_hidden_states = torch.clone(hidden_states)
@@ -770,7 +765,7 @@ def forward(
                 original_hidden_states=original_hidden_states,
                 positions=positions,
                 mamba_cache_params=mamba_cache_params.at_layer_idx(layer_idx),
-                sequence_idx=seq_idx,
+                mamba2_metadata=mamba2_metadata,
             )
             hidden_states = layer_outputs
