 from vllm.model_executor.layers.mamba.mamba_mixer2 import (
     MambaMixer2, extra_groups_for_head_shards)
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.mamba.ops.ssd_chunk_scan import (
+    seq_idx_to_chunk_indices_offsets)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -256,41 +258,6 @@ def forward(
256258 "mamba" : BambaMixerDecoderLayer
257259}
258260
259-
260- def _seq_idx_to_chunk_indices_offsets (seq_idx , chunk_size : int ):
261-
262- # convert seq_idx to chunk indices and offsets
263- # - derive the cu_seqlens
264- _ , cu_seqlens = torch .where (seq_idx .diff ())
265- cu_seqlens += 1
266-
267- # outputs will have length expansion of chunks that do not divide
268- # chunk_size
269- N = math .ceil (seq_idx .shape [- 1 ] / chunk_size ) + (cu_seqlens % chunk_size
270- > 0 ).sum ()
271- chunk_indices = torch .arange (N , dtype = torch .int , device = seq_idx .device )
272- chunk_offsets = torch .zeros ((N , ), dtype = torch .int , device = seq_idx .device )
273-
274- cu_seqlens = cu_seqlens .tolist () + [seq_idx .shape [- 1 ]]
275- p = 0 # num of insertions
276- for s , e in zip (cu_seqlens [:- 1 ], cu_seqlens [1 :]):
277-
278- # if does not divide chunk_size, then there is one chunk insertion
279- p += (s % chunk_size > 0 )
280-
281- # get the dimensions
282- # - the + 1 for _e is to shift the boundary by one chunk
283- # - this shifting is not needed if chunk_size divides e
284- _s , _e = s // chunk_size + p , e // chunk_size + p + (e % chunk_size
285- > 0 )
286-
-        # adjust indices and offsets
-        chunk_indices[_s:_e] -= p
-        chunk_offsets[_s] = s % chunk_size
-
-    return chunk_indices, chunk_offsets
-
-
 class BambaModel(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -361,8 +328,17 @@ def forward(
                 )):
             seq_idx[srt:end] = i
         seq_idx.unsqueeze_(0)
-        # Compute mamba2 metadata tensors that are reused across layers
-        chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
+
+        # Compute metadata for chunked prefill. Strictly speaking this is
+        # only needed when there are initial states, but that can only be
+        # determined from attention metadata that is not available in this
+        # top-level forward. Rather than complicating things to extract
+        # that metadata, we compute it unconditionally; the mamba kernels
+        # silently ignore it when it is not needed.
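+        # chunk_indices and chunk_offsets let the chunked scan re-align
+        # chunk boundaries with sequence boundaries: a sequence whose start
+        # does not fall on a multiple of chunk_size gets one extra logical
+        # chunk plus a non-zero offset into it.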
+        chunk_indices, chunk_offsets = seq_idx_to_chunk_indices_offsets(
             seq_idx, self.config.mamba_chunk_size)
 
         if get_pp_group().is_first_rank:
@@ -378,7 +354,6 @@ def forward(
 
         residual = None
         num_attn = 0
-        extra_args = {}
         for i in range(len(self.layers)):
             layer = self.layers[i]
             if isinstance(layer, BambaAttentionDecoderLayer):
@@ -388,19 +363,15 @@ def forward(
             if isinstance(layer, BambaMixerDecoderLayer):
                 layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
                     i - num_attn)
-                extra_args = {
-                    'chunk_indices': chunk_indices,
-                    'chunk_offsets': chunk_offsets,
-                }
 
-            # print(f"{len(extra_args)=}")
             hidden_states, residual = layer(
                 positions=positions,
                 hidden_states=hidden_states,
                 residual=residual,
                 mamba_cache_params=layer_mamba_cache_params,
                 sequence_idx=seq_idx,
-                **extra_args,
+                chunk_indices=chunk_indices,
+                chunk_offsets=chunk_offsets,
             )
 
         if not get_pp_group().is_last_rank:
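
For reference, a minimal sketch of what the imported helper produces on a toy input, assuming the relocated seq_idx_to_chunk_indices_offsets keeps the same signature and semantics as the local helper removed above:

import torch
from vllm.model_executor.layers.mamba.ops.ssd_chunk_scan import (
    seq_idx_to_chunk_indices_offsets)

# Two packed sequences of lengths 6 and 4 with chunk_size = 4. The second
# sequence starts mid-chunk (position 6), so one extra logical chunk is
# inserted for it.
seq_idx = torch.tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 1]], dtype=torch.int32)
chunk_indices, chunk_offsets = seq_idx_to_chunk_indices_offsets(seq_idx, 4)
# Expected: chunk_indices == [0, 1, 1, 2] and chunk_offsets == [0, 0, 2, 0],
# i.e. the third logical chunk re-reads physical chunk 1 from offset 2.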