
Commit 28c13c8

Reduce redundancy in mamba2 blocks
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
1 parent 3f04a7f commit 28c13c8

4 files changed: 68 additions, 42 deletions


vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 4 additions & 0 deletions
@@ -389,6 +389,8 @@ def forward_cuda(
         hidden_states: torch.Tensor,
         mamba_cache_params: MambaCacheParams,
         sequence_idx: Optional[torch.Tensor] = None,
+        chunk_indices: Optional[torch.Tensor] = None,
+        chunk_offsets: Optional[torch.Tensor] = None,
     ):
         attn_metadata: AttentionMetadata = get_forward_context().attn_metadata

@@ -490,6 +492,8 @@ def forward_cuda(
                 z=None,
                 dt_bias=self.dt_bias,
                 seq_idx=sequence_idx,
+                chunk_indices=chunk_indices,
+                chunk_offsets=chunk_offsets,
                 cu_seqlens=attn_metadata.query_start_loc,
                 initial_states=initial_states,
                 return_varlen_states=True,

vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py

Lines changed: 6 additions & 40 deletions
@@ -5,8 +5,6 @@

 # ruff: noqa: E501,SIM102

-import math
-
 import torch
 import triton
 import triton.language as tl

@@ -442,40 +440,6 @@ def _chunk_scan_fwd_kernel(
                 (offs_out_n[None, :] < hdim))


-def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
-
-    # convert seq_idx to chunk indices and offsets
-    # - derive the cu_seqlens
-    _, cu_seqlens = torch.where(seq_idx.diff())
-    cu_seqlens += 1
-
-    # outputs will have length expansion of chunks that do not divide
-    # chunk_size
-    N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size
-                                                     > 0).sum()
-    chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device)
-    chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device)
-
-    cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
-    p = 0  # num of insertions
-    for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
-
-        # if s is not a multiple of chunk_size, then there is one chunk insertion
-        p += (s % chunk_size > 0)
-
-        # get the dimensions
-        # - the + 1 for _e is to shift the boundary by one chunk
-        # - this shifting is not needed if chunk_size divides e
-        _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
-                                                             > 0)
-
-        # adjust indices and offsets
-        chunk_indices[_s:_e] -= p
-        chunk_offsets[_s] = s % chunk_size
-
-    return chunk_indices, chunk_offsets
-
-
 def _chunk_scan_fwd(
     cb,
     x,

@@ -486,6 +450,8 @@ def _chunk_scan_fwd(
     D=None,
     z=None,
     seq_idx=None,
+    chunk_indices=None,
+    chunk_offsets=None,
     initial_states=None,
 ):
     batch, seqlen, nheads, headdim = x.shape

@@ -502,7 +468,6 @@ def _chunk_scan_fwd(
     assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
     assert states.shape == (batch, nchunks, nheads, headdim, dstate)

-    chunk_indices, chunk_offsets = None, None
     if seq_idx is not None:
         assert seq_idx.shape == (batch, seqlen)

@@ -516,9 +481,9 @@ def _chunk_scan_fwd(
             if initial_states.shape[0] == 1:
                 # in this case there is no point to use initial states
                 initial_states = None
-            else:
-                chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
-                    seq_idx, chunk_size)
+
+        if initial_states is None:
+            chunk_indices, chunk_offsets = None, None

     # Allocates output.
     out = torch.empty(batch,

@@ -544,6 +509,7 @@ def _chunk_scan_fwd(
         if chunk_offsets is None else len(chunk_offsets), nheads)
     z_strides = ((z.stride(0), z.stride(1), z.stride(2),
                   z.stride(3)) if z is not None else (0, 0, 0, 0))
+
     _chunk_scan_fwd_kernel[grid](
         cb,
         x,

vllm/model_executor/layers/mamba/ops/ssd_combined.py

Lines changed: 8 additions & 0 deletions
@@ -30,6 +30,8 @@ def _mamba_chunk_scan_combined_fwd(x,
                                    dt_bias=None,
                                    initial_states=None,
                                    seq_idx=None,
+                                   chunk_indices=None,
+                                   chunk_offsets=None,
                                    cu_seqlens=None,
                                    dt_softplus=False,
                                    dt_limit=(0.0, float("inf"))):

@@ -141,6 +143,8 @@ def _mamba_chunk_scan_combined_fwd(x,
         D=D,
         z=z,
         seq_idx=seq_idx,
+        chunk_indices=chunk_indices,
+        chunk_offsets=chunk_offsets,
         initial_states=initial_states,
     )
     if cu_seqlens is None:

@@ -170,6 +174,8 @@ def mamba_chunk_scan_combined(x,
                               dt_bias=None,
                               initial_states=None,
                               seq_idx=None,
+                              chunk_indices=None,
+                              chunk_offsets=None,
                               cu_seqlens=None,
                               dt_softplus=False,
                               dt_limit=(0.0, float("inf")),

@@ -210,6 +216,8 @@ def mamba_chunk_scan_combined(x,
         dt_bias=dt_bias,
         initial_states=initial_states,
         seq_idx=seq_idx,
+        chunk_indices=chunk_indices,
+        chunk_offsets=chunk_offsets,
         cu_seqlens=cu_seqlens,
         dt_softplus=dt_softplus,
         dt_limit=dt_limit)

vllm/model_executor/models/bamba.py

Lines changed: 50 additions & 2 deletions
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 """Inference-only Bamba model."""
 # Added by the IBM Team, 2024
+import math
 from typing import Iterable, Optional, Set, Tuple

 import torch

@@ -109,6 +110,8 @@ def forward(
         residual: Optional[torch.Tensor],
         mamba_cache_params: MambaCacheParams,
         sequence_idx: Optional[torch.Tensor] = None,
+        chunk_indices: Optional[torch.Tensor] = None,
+        chunk_offsets: Optional[torch.Tensor] = None,
         **kwargs,
     ):
         if residual is None:

@@ -119,7 +122,7 @@ def forward(
                 hidden_states, residual)

         hidden_states = self.mamba(hidden_states, mamba_cache_params,
-                                   sequence_idx)
+                                   sequence_idx, chunk_indices, chunk_offsets)
         # Fully Connected
         hidden_states, residual = self.pre_ff_layernorm(
             hidden_states, residual)
@@ -254,12 +257,46 @@ def forward(
         }


+def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
+
+    # convert seq_idx to chunk indices and offsets
+    # - derive the cu_seqlens
+    _, cu_seqlens = torch.where(seq_idx.diff())
+    cu_seqlens += 1
+
+    # outputs will have length expansion of chunks that do not divide
+    # chunk_size
+    N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size
+                                                     > 0).sum()
+    chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device)
+    chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device)
+
+    cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
+    p = 0  # num of insertions
+    for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
+
+        # if s is not a multiple of chunk_size, then there is one chunk insertion
+        p += (s % chunk_size > 0)
+
+        # get the dimensions
+        # - the + 1 for _e is to shift the boundary by one chunk
+        # - this shifting is not needed if chunk_size divides e
+        _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
+                                                             > 0)
+
+        # adjust indices and offsets
+        chunk_indices[_s:_e] -= p
+        chunk_offsets[_s] = s % chunk_size
+
+    return chunk_indices, chunk_offsets
+
+
 class BambaModel(nn.Module):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()

-        config = vllm_config.model_config.hf_config
+        config: BambaConfig = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
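
Editor's note: to make the relocated helper concrete, here is a minimal sketch of what it produces on a toy input, assuming this commit is applied so that the function can be imported from vllm.model_executor.models.bamba. The expected values follow from walking through the loop above by hand; the sequence lengths and chunk size are illustrative only.

import torch

from vllm.model_executor.models.bamba import _seq_idx_to_chunk_indices_offsets

# Toy packed batch: two sequences of lengths 5 and 7, mamba chunk size 8.
seq_idx = torch.tensor([[0] * 5 + [1] * 7], dtype=torch.int32)

chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(seq_idx, 8)

# Chunk 0 is scheduled twice: once for the tail of sequence 0 (offset 0)
# and once for the start of sequence 1 (offset 5 within the chunk);
# chunk 1 is scheduled once.
print(chunk_indices)  # tensor([0, 0, 1], dtype=torch.int32)
print(chunk_offsets)  # tensor([0, 5, 0], dtype=torch.int32)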
@@ -313,6 +350,7 @@ def forward(
         # proper continuous batching computation including
         # chunked prefill
         seq_idx = None
+        chunk_indices, chunk_offsets = None, None
         attn_metadata = get_forward_context().attn_metadata
         if attn_metadata.num_prefills > 0:
             seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)

@@ -323,6 +361,9 @@ def forward(
                     )):
                 seq_idx[srt:end] = i
             seq_idx.unsqueeze_(0)
+            # Compute mamba2 metadata tensors that are reused across layers
+            chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
+                seq_idx, self.config.mamba_chunk_size)

         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:

@@ -337,6 +378,7 @@ def forward(

         residual = None
         num_attn = 0
+        extra_args = {}
         for i in range(len(self.layers)):
             layer = self.layers[i]
             if isinstance(layer, BambaAttentionDecoderLayer):

@@ -346,13 +388,19 @@ def forward(
             if isinstance(layer, BambaMixerDecoderLayer):
                 layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
                     i - num_attn)
+                extra_args = {
+                    'chunk_indices': chunk_indices,
+                    'chunk_offsets': chunk_offsets,
+                }

+            # print(f"{len(extra_args)=}")
             hidden_states, residual = layer(
                 positions=positions,
                 hidden_states=hidden_states,
                 residual=residual,
                 mamba_cache_params=layer_mamba_cache_params,
                 sequence_idx=seq_idx,
+                **extra_args,
             )

         if not get_pp_group().is_last_rank:
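
Editor's note: taken together, the commit hoists the chunk-metadata computation out of the per-layer kernel wrapper and into the model forward. A rough standalone sketch of the redundancy being removed is below; the chunk size, sequence lengths, and layer count are hypothetical, and the helper import assumes this commit is applied.

import torch

from vllm.model_executor.models.bamba import _seq_idx_to_chunk_indices_offsets

# Hypothetical prefill workload: two packed sequences, chunk size 256,
# and 32 mamba2 layers (all values illustrative).
chunk_size = 256
num_mamba_layers = 32
seq_idx = torch.tensor([[0] * 1000 + [1] * 3000], dtype=torch.int32)

# After this commit the metadata is computed once per model forward and
# passed to each BambaMixerDecoderLayer via keyword arguments ...
chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
    seq_idx, chunk_size)

# ... whereas the old path recomputed it inside _chunk_scan_fwd for every
# mamba2 layer (whenever per-sequence initial states were used), always
# producing identical tensors:
for _ in range(num_mamba_layers):
    ci, co = _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size)
    assert torch.equal(ci, chunk_indices) and torch.equal(co, chunk_offsets)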
