
Commit c0ead4d

pack more into mamba2 metadata

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Parent: 017597e

File tree: 3 files changed, +51 -45 lines


vllm/model_executor/layers/mamba/mamba2_metadata.py
Lines changed: 38 additions & 2 deletions

```diff
@@ -8,12 +8,13 @@
 @dataclass
 class Mamba2Metadata:
     chunk_size: int
+    seq_idx: torch.Tensor
     chunk_indices: torch.Tensor
     chunk_offsets: torch.Tensor
 
 
-def prepare_mamba2_metadata(seq_idx: torch.Tensor,
-                            chunk_size: int) -> Mamba2Metadata:
+def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
+
     # convert seq_idx to chunk indices and offsets
     # - derive the cu_seqlens
     _, cu_seqlens = torch.where(seq_idx.diff())
@@ -43,6 +44,41 @@ def prepare_mamba2_metadata(seq_idx: torch.Tensor,
         chunk_indices[_s:_e] -= p
         chunk_offsets[_s] = s % chunk_size
 
+    return chunk_indices, chunk_offsets
+
+
+def prepare_mamba2_metadata(
+    chunk_size: int,
+    has_prefills: bool,
+    input_ids: torch.Tensor,
+    query_start_loc: torch.Tensor,
+) -> Mamba2Metadata:
+
+    seq_idx = None
+    chunk_indices, chunk_offsets = None, None
+    if has_prefills:
+        seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
+        for i, (srt,
+                end) in enumerate(zip(
+                    query_start_loc,
+                    query_start_loc[1:],
+                )):
+            seq_idx[srt:end] = i
+        seq_idx.unsqueeze_(0)
+
+        # compute metadata for chunked prefill.
+        # actually this is only needed if there are
+        # initial states, but this is determinable
+        # only from attention metadata yet
+        # unavailable from the top-level model forward.
+        # Rather than complicating things to extract said
+        # metadata, we simply just compute redundently and
+        # will be silently ignored inside the mamba kernels.
+        # if not needed.
+        chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
+            seq_idx, chunk_size)
+
     return Mamba2Metadata(chunk_size=chunk_size,
+                          seq_idx=seq_idx,
                           chunk_indices=chunk_indices,
                           chunk_offsets=chunk_offsets)
```
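
For intuition, here is a minimal standalone sketch (not part of the commit) of the `seq_idx` construction that `prepare_mamba2_metadata` now owns, assuming a toy continuous batch of three prefill sequences with lengths 4, 2, and 3 flattened into one token stream; the tensors below are illustrative stand-ins, not vLLM objects:

```python
import torch

# Illustrative stand-ins: three prefill sequences of lengths 4, 2, 3,
# flattened into a single stream of 9 tokens (continuous batching).
query_start_loc = torch.tensor([0, 4, 6, 9])
input_ids = torch.arange(9)

# Same construction as in prepare_mamba2_metadata: label each token
# with the index of the sequence it belongs to.
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
for i, (srt, end) in enumerate(zip(query_start_loc, query_start_loc[1:])):
    seq_idx[srt:end] = i
seq_idx.unsqueeze_(0)  # the mamba kernels expect a leading batch dim

print(seq_idx)
# tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2]], dtype=torch.int32)
```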

vllm/model_executor/layers/mamba/mamba_mixer2.py
Lines changed: 5 additions & 13 deletions

```diff
@@ -222,7 +222,6 @@ def __init__(self,
                  head_dim: int = 64,
                  rms_norm_eps: float = 1e-5,
                  activation="silu",
-                 chunk_size: int = 256,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
 
@@ -258,7 +257,6 @@ def __init__(self,
         self.ssm_state_size = ssm_state_size
         self.activation = activation
 
-        self.chunk_size = chunk_size
         self.intermediate_size = intermediate_size
         self.head_dim = head_dim
         self.num_heads = num_heads
@@ -389,8 +387,7 @@ def forward_cuda(
         self,
         hidden_states: torch.Tensor,
         mamba_cache_params: MambaCacheParams,
-        sequence_idx: Optional[torch.Tensor] = None,
-        mamba2_metadata: Optional[Mamba2Metadata] = None,
+        mamba2_metadata: Mamba2Metadata,
     ):
         # For the mamba2 triton kernels to operate in continuous batching,
         # the sequence_idx is needed to be passed in. Also, for the kernels
@@ -400,11 +397,6 @@ def forward_cuda(
         # layers.
         attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
 
-        chunk_indices, chunk_offsets = None, None
-        if mamba2_metadata is not None:
-            chunk_indices = mamba2_metadata.chunk_indices
-            chunk_offsets = mamba2_metadata.chunk_offsets
-
         seq_len, _ = hidden_states.shape
         groups_time_state_size = self.n_groups * self.ssm_state_size
 
@@ -496,13 +488,13 @@ def forward_cuda(
             self.A,
             B.view(1, seq_len, self.n_groups // self.tp_size, -1),
             C.view(1, seq_len, self.n_groups // self.tp_size, -1),
-            chunk_size=self.chunk_size,
+            chunk_size=mamba2_metadata.chunk_size,
             D=self.D,
             z=None,
             dt_bias=self.dt_bias,
-            seq_idx=sequence_idx,
-            chunk_indices=chunk_indices,
-            chunk_offsets=chunk_offsets,
+            seq_idx=mamba2_metadata.seq_idx,
+            chunk_indices=mamba2_metadata.chunk_indices,
+            chunk_offsets=mamba2_metadata.chunk_offsets,
             cu_seqlens=attn_metadata.query_start_loc,
             initial_states=initial_states,
             return_varlen_states=True,
```
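
The net effect on the mixer's interface: `chunk_size` no longer lives on the module and `sequence_idx` is no longer a separate argument; everything the chunked scan needs arrives through the now-required `Mamba2Metadata`. A minimal sketch of that contract, where `forward_cuda_sketch` is a hypothetical stand-in for the real `forward_cuda` and its kernel call (the `Optional` field types reflect that `seq_idx` and the chunk tensors stay `None` for decode-only batches):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Mamba2Metadata:
    # Mirrors the fields the commit packs into one object.
    chunk_size: int
    seq_idx: Optional[torch.Tensor]
    chunk_indices: Optional[torch.Tensor]
    chunk_offsets: Optional[torch.Tensor]


def forward_cuda_sketch(hidden_states: torch.Tensor,
                        mamba2_metadata: Mamba2Metadata) -> dict:
    # Hypothetical stand-in for the kernel invocation: every chunked-scan
    # parameter now comes off the metadata object instead of a mix of
    # self.chunk_size and a separate sequence_idx argument.
    return dict(
        chunk_size=mamba2_metadata.chunk_size,
        seq_idx=mamba2_metadata.seq_idx,
        chunk_indices=mamba2_metadata.chunk_indices,
        chunk_offsets=mamba2_metadata.chunk_offsets,
    )


md = Mamba2Metadata(chunk_size=256, seq_idx=None,
                    chunk_indices=None, chunk_offsets=None)
print(forward_cuda_sketch(torch.zeros(4, 8), md)["chunk_size"])  # 256
```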

vllm/model_executor/models/bamba.py
Lines changed: 8 additions & 30 deletions

```diff
@@ -96,7 +96,6 @@ def __init__(self,
                                  head_dim=config.mamba_d_head,
                                  rms_norm_eps=config.rms_norm_eps,
                                  activation=config.hidden_act,
-                                 chunk_size=config.mamba_chunk_size,
                                  quant_config=quant_config)
 
         self.feed_forward = BambaMLP(config, quant_config=quant_config)
@@ -110,7 +109,6 @@ def forward(
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
         mamba_cache_params: MambaCacheParams,
-        sequence_idx: Optional[torch.Tensor] = None,
         mamba2_metadata: Optional[Mamba2Metadata] = None,
         **kwargs,
     ):
@@ -122,7 +120,7 @@ def forward(
                 hidden_states, residual)
 
         hidden_states = self.mamba(hidden_states, mamba_cache_params,
-                                   sequence_idx, mamba2_metadata)
+                                   mamba2_metadata)
         # Fully Connected
         hidden_states, residual = self.pre_ff_layernorm(
             hidden_states, residual)
@@ -312,33 +310,14 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
-        # pass a sequence index tensor, that is required for
-        # proper continuous batching computation including
-        # chunked prefill
-        seq_idx = None
-        mamba2_metadata = None
         attn_metadata = get_forward_context().attn_metadata
-        if attn_metadata.num_prefills > 0:
-            seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
-            for i, (srt, end) in enumerate(
-                    zip(
-                        attn_metadata.query_start_loc,
-                        attn_metadata.query_start_loc[1:],
-                    )):
-                seq_idx[srt:end] = i
-            seq_idx.unsqueeze_(0)
-
-            # compute metadata for chunked prefill.
-            # actually this is only needed if there are
-            # initial states, but this is determinable
-            # only from attention metadata yet
-            # unavailable from the current top-level forward.
-            # Rather than complicating things to extract said
-            # metadata, we simply just compute redundently and
-            # will be silently ignored inside the mamba kernels.
-            # if not needed.
-            mamba2_metadata = prepare_mamba2_metadata(
-                seq_idx, self.config.mamba_chunk_size)
+
+        mamba2_metadata = prepare_mamba2_metadata(
+            chunk_size=self.config.mamba_chunk_size,
+            has_prefills=attn_metadata.num_prefills > 0,
+            input_ids=input_ids,
+            query_start_loc=attn_metadata.query_start_loc,
+        )
 
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
@@ -368,7 +347,6 @@ def forward(
                 hidden_states=hidden_states,
                 residual=residual,
                 mamba_cache_params=layer_mamba_cache_params,
-                sequence_idx=seq_idx,
                 mamba2_metadata=mamba2_metadata,
             )
```
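
Putting the three files together, the model now builds the metadata once per forward and threads the same object through every layer. Below is a runnable toy version of that flow; `StubAttnMetadata` and `stub_layer` are illustrative stand-ins (not vLLM APIs), and the import assumes a vLLM checkout at this commit so `prepare_mamba2_metadata` has the signature shown above:

```python
from dataclasses import dataclass

import torch

# Assumes a vLLM checkout at this commit, where prepare_mamba2_metadata
# takes (chunk_size, has_prefills, input_ids, query_start_loc).
from vllm.model_executor.layers.mamba.mamba2_metadata import (
    prepare_mamba2_metadata)


@dataclass
class StubAttnMetadata:
    # Illustrative stand-in for vLLM's AttentionMetadata.
    num_prefills: int
    query_start_loc: torch.Tensor


def stub_layer(hidden_states, mamba2_metadata):
    # Stand-in for BambaMixerDecoderLayer.forward: layers now receive the
    # shared metadata object instead of rebuilding seq_idx themselves.
    assert mamba2_metadata is not None
    return hidden_states


def model_forward(input_ids, attn_metadata, chunk_size=256, num_layers=4):
    # Built once per forward, as in bamba.py after this commit.
    mamba2_metadata = prepare_mamba2_metadata(
        chunk_size=chunk_size,
        has_prefills=attn_metadata.num_prefills > 0,
        input_ids=input_ids,
        query_start_loc=attn_metadata.query_start_loc,
    )
    hidden_states = torch.zeros(input_ids.shape[0], 8)  # stub embedding
    for _ in range(num_layers):
        hidden_states = stub_layer(hidden_states, mamba2_metadata)
    return hidden_states


attn = StubAttnMetadata(num_prefills=1,
                        query_start_loc=torch.tensor([0, 4, 6, 9]))
print(model_forward(torch.arange(9), attn).shape)  # torch.Size([9, 8])
```

For a decode-only batch (`num_prefills == 0`), `seq_idx`, `chunk_indices`, and `chunk_offsets` all stay `None`, matching the skipped `if has_prefills:` branch in the new `prepare_mamba2_metadata`.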
