Commit 017597e

Pack mamba2 metadata

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
1 parent 01d5a33 · commit 017597e

File tree

vllm/model_executor/layers/mamba/mamba2_metadata.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
vllm/model_executor/models/bamba.py

4 files changed: +68 -60 lines changed

vllm/model_executor/layers/mamba/mamba2_metadata.py

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+# SPDX-License-Identifier: Apache-2.0
+import math
+from dataclasses import dataclass
+
+import torch
+
+
+@dataclass
+class Mamba2Metadata:
+    chunk_size: int
+    chunk_indices: torch.Tensor
+    chunk_offsets: torch.Tensor
+
+
+def prepare_mamba2_metadata(seq_idx: torch.Tensor,
+                            chunk_size: int) -> Mamba2Metadata:
+    # convert seq_idx to chunk indices and offsets
+    # - derive the cu_seqlens
+    _, cu_seqlens = torch.where(seq_idx.diff())
+    cu_seqlens += 1
+
+    # outputs will have length expansion of chunks that do not divide
+    # chunk_size
+    N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size
+                                                     > 0).sum()
+    chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device)
+    chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device)
+
+    cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
+    p = 0  # num of insertions
+    for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
+
+        # if does not divide chunk_size, then there is one chunk insertion
+        p += (s % chunk_size > 0)
+
+        # get the dimensions
+        # - the + 1 for _e is to shift the boundary by one chunk
+        # - this shifting is not needed if chunk_size divides e
+        _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
+                                                             > 0)
+
+        # adjust indices and offsets
+        chunk_indices[_s:_e] -= p
+        chunk_offsets[_s] = s % chunk_size
+
+    return Mamba2Metadata(chunk_size=chunk_size,
+                          chunk_indices=chunk_indices,
+                          chunk_offsets=chunk_offsets)
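For a concrete sense of what the new helper computes, here is a small worked example. The (1, seqlen) layout of seq_idx mirrors how bamba.py builds it below; the sequence lengths and chunk_size are made up for illustration, and the expected outputs were worked out by hand from the code above.

import torch

from vllm.model_executor.layers.mamba.mamba2_metadata import (
    prepare_mamba2_metadata)

# two prefill requests packed back to back: 5 tokens of sequence 0
# followed by 11 tokens of sequence 1, scanned with chunk_size=8
seq_idx = torch.tensor([[0] * 5 + [1] * 11], dtype=torch.int32)
meta = prepare_mamba2_metadata(seq_idx, chunk_size=8)

# sequence 0 ends in the middle of physical chunk 0, so one extra logical
# chunk is inserted; that chunk re-reads chunk 0 starting at offset 5:
#   meta.chunk_indices -> tensor([0, 0, 1], dtype=torch.int32)
#   meta.chunk_offsets -> tensor([0, 5, 0], dtype=torch.int32)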

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 11 additions & 7 deletions
@@ -18,6 +18,7 @@
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
@@ -389,18 +390,21 @@ def forward_cuda(
         hidden_states: torch.Tensor,
         mamba_cache_params: MambaCacheParams,
         sequence_idx: Optional[torch.Tensor] = None,
-        chunk_indices: Optional[torch.Tensor] = None,
-        chunk_offsets: Optional[torch.Tensor] = None,
+        mamba2_metadata: Optional[Mamba2Metadata] = None,
     ):
         # For the mamba2 triton kernels to operate in continuous batching,
         # the sequence_idx is needed to be passed in. Also, for the kernels
-        # to operate in chunked prefill, the chunk_indices and chunk_offsets
-        # can be optionally passed in; it is more efficient to pre-compute
-        # once since they are common to all layers. If they are not provided
-        # then they will be derived from sequence_idx inside the kernels
-
+        # to operate in chunked prefill, the mamba2_metadata containing
+        # chunk_indices and chunk_offsets must be passed in; it is
+        # more efficient to pre-compute once since they are common to all
+        # layers.
         attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
 
+        chunk_indices, chunk_offsets = None, None
+        if mamba2_metadata is not None:
+            chunk_indices = mamba2_metadata.chunk_indices
+            chunk_offsets = mamba2_metadata.chunk_offsets
+
         seq_len, _ = hidden_states.shape
         groups_time_state_size = self.n_groups * self.ssm_state_size
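Since the mixer now receives the chunk bookkeeping as one object rather than two loose tensors, a hand-built instance shows what gets packed together. The values below are purely illustrative, for a request whose length happens to be an exact multiple of the chunk size, so every chunk is visited once from offset 0.

import torch

from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata

# four chunks, each scanned exactly once from the start of the chunk
meta = Mamba2Metadata(chunk_size=256,
                      chunk_indices=torch.arange(4, dtype=torch.int),
                      chunk_offsets=torch.zeros(4, dtype=torch.int))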

vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py

Lines changed: 1 addition & 43 deletions
@@ -5,8 +5,6 @@
 
 # ruff: noqa: E501,SIM102
 
-import math
-
 import torch
 import triton
 import triton.language as tl
@@ -442,40 +440,6 @@ def _chunk_scan_fwd_kernel(
              (offs_out_n[None, :] < hdim))
 
 
-def seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
-
-    # convert seq_idx to chunk indices and offsets
-    # - derive the cu_seqlens
-    _, cu_seqlens = torch.where(seq_idx.diff())
-    cu_seqlens += 1
-
-    # outputs will have length expansion of chunks that do not divide
-    # chunk_size
-    N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size
-                                                     > 0).sum()
-    chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device)
-    chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device)
-
-    cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
-    p = 0  # num of insertions
-    for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
-
-        # if does not divide chunk_size, then there is one chunk insertion
-        p += (s % chunk_size > 0)
-
-        # get the dimensions
-        # - the + 1 for _e is to shift the boundary by one chunk
-        # - this shifting is not needed if chunk_size divides e
-        _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
-                                                             > 0)
-
-        # adjust inidces and offsets
-        chunk_indices[_s:_e] -= p
-        chunk_offsets[_s] = s % chunk_size
-
-    return chunk_indices, chunk_offsets
-
-
 def _chunk_scan_fwd(
     cb,
     x,
@@ -515,16 +479,10 @@ def _chunk_scan_fwd(
             if initial_states.shape[0] == 1:
                 # no in this case no point to use initial states
                 initial_states = None
-            elif chunk_indices is None and chunk_offsets is None:
-                # if chunk_indices and chunk_offsets both unset, then derive
-                # from seq_idx
-                chunk_indices, chunk_offsets = seq_idx_to_chunk_indices_offsets(
-                    seq_idx, chunk_size)
             else:
                 assert chunk_indices is not None and chunk_offsets is not None, \
                 (
-                    "chunk_indices and chunk_offsets should either "
-                    "be left unset, or else both should be set."
+                    "chunk_indices and chunk_offsets should have been set"
                 )
         else:
            chunk_indices, chunk_offsets = None, None
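With the helper deleted from this module, anything that still imported seq_idx_to_chunk_indices_offsets from ssd_chunk_scan has to switch to the module added by this commit, as bamba.py does below. A minimal sketch of the import change (note the replacement returns a packed Mamba2Metadata rather than a tuple):

# old import, removed by this commit:
# from vllm.model_executor.layers.mamba.ops.ssd_chunk_scan import (
#     seq_idx_to_chunk_indices_offsets)

# new location of the chunk-index/offset computation:
from vllm.model_executor.layers.mamba.mamba2_metadata import (
    Mamba2Metadata, prepare_mamba2_metadata)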

vllm/model_executor/models/bamba.py

Lines changed: 8 additions & 10 deletions
@@ -18,10 +18,10 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.mamba2_metadata import (
+    Mamba2Metadata, prepare_mamba2_metadata)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import (
     MambaMixer2, extra_groups_for_head_shards)
-from vllm.model_executor.layers.mamba.ops.ssd_chunk_scan import (
-    seq_idx_to_chunk_indices_offsets)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -111,8 +111,7 @@ def forward(
         residual: Optional[torch.Tensor],
         mamba_cache_params: MambaCacheParams,
         sequence_idx: Optional[torch.Tensor] = None,
-        chunk_indices: Optional[torch.Tensor] = None,
-        chunk_offsets: Optional[torch.Tensor] = None,
+        mamba2_metadata: Optional[Mamba2Metadata] = None,
         **kwargs,
     ):
         if residual is None:
@@ -123,7 +122,7 @@
                 hidden_states, residual)
 
         hidden_states = self.mamba(hidden_states, mamba_cache_params,
-                                   sequence_idx, chunk_indices, chunk_offsets)
+                                   sequence_idx, mamba2_metadata)
         # Fully Connected
         hidden_states, residual = self.pre_ff_layernorm(
             hidden_states, residual)
@@ -317,7 +316,7 @@ def forward(
         # proper continuous batching computation including
         # chunked prefill
         seq_idx = None
-        chunk_indices, chunk_offsets = None, None
+        mamba2_metadata = None
         attn_metadata = get_forward_context().attn_metadata
         if attn_metadata.num_prefills > 0:
             seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
@@ -338,7 +337,7 @@
             # metadata, we simply just compute redundantly and
             # will be silently ignored inside the mamba kernels.
             # if not needed.
-            chunk_indices, chunk_offsets = seq_idx_to_chunk_indices_offsets(
+            mamba2_metadata = prepare_mamba2_metadata(
                 seq_idx, self.config.mamba_chunk_size)
 
         if get_pp_group().is_first_rank:
@@ -370,8 +369,7 @@
                 residual=residual,
                 mamba_cache_params=layer_mamba_cache_params,
                 sequence_idx=seq_idx,
-                chunk_indices=chunk_indices,
-                chunk_offsets=chunk_offsets,
+                mamba2_metadata=mamba2_metadata,
             )
 
         if not get_pp_group().is_last_rank:
@@ -574,4 +572,4 @@ def sample(
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights)
+        return loader.load_weights(weights)
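Taken together, the model's forward pass now builds the metadata once per step and hands the same object to every Mamba layer. A condensed sketch of that flow follows; the seq_idx construction and the layer loop are abbreviated, and only the argument names visible in the hunks above are taken from the diff:

mamba2_metadata = None
attn_metadata = get_forward_context().attn_metadata
if attn_metadata.num_prefills > 0:
    # seq_idx tags every prefill token with its request index (construction
    # elided); the chunk bookkeeping is derived from it exactly once here
    mamba2_metadata = prepare_mamba2_metadata(
        seq_idx, self.config.mamba_chunk_size)

for layer in self.layers:  # abbreviated
    hidden_states, residual = layer(
        hidden_states=hidden_states,
        residual=residual,
        mamba_cache_params=layer_mamba_cache_params,
        sequence_idx=seq_idx,
        mamba2_metadata=mamba2_metadata,  # shared by all layers
    )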
