                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.mamba2_metadata import (
+    Mamba2Metadata, prepare_mamba2_metadata)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import (
     MambaMixer2, extra_groups_for_head_shards)
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -495,7 +497,6 @@ def __init__(
             head_dim=intermediate_size // config.n_mamba_heads,
             rms_norm_eps=config.rms_norm_eps,
             activation="silu",
-            chunk_size=config.chunk_size,
             quant_config=quant_config,
         )
 
@@ -507,7 +508,7 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         mamba_cache_params: MambaCacheParams,
-        sequence_idx: Optional[torch.Tensor] = None,
+        mamba2_metadata: Mamba2Metadata,
         transformer_hidden_states: Optional[torch.Tensor] = None,
         positions: Optional[torch.Tensor] = None,
         original_hidden_states: Optional[torch.Tensor] = None,
@@ -547,7 +548,7 @@ def forward(
         hidden_states = self.mamba(
             hidden_states,
             mamba_cache_params=mamba_cache_params,
-            sequence_idx=sequence_idx,
+            mamba2_metadata=mamba2_metadata,
         )
 
         # residual connection after mamba
@@ -594,8 +595,8 @@ def forward(
         hidden_states: torch.Tensor,
         original_hidden_states: torch.Tensor,
         positions: torch.Tensor,
-        mamba_cache_params: Optional[MambaCacheParams] = None,
-        sequence_idx: Optional[torch.Tensor] = None,
+        mamba_cache_params: MambaCacheParams,
+        mamba2_metadata: Mamba2Metadata,
     ) -> torch.Tensor:
         """Forward pass through the hybrid layer.
 
@@ -634,7 +635,7 @@ def forward(
             hidden_states,
             transformer_hidden_states=transformer_hidden_states,
             mamba_cache_params=mamba_cache_params,
-            sequence_idx=sequence_idx,
+            mamba2_metadata=mamba2_metadata,
         )
 
         return layer_outputs
@@ -747,20 +748,14 @@ def forward(
             inputs_embeds = self.get_input_embeddings(input_ids)
         hidden_states = inputs_embeds
 
-        # pass a sequence index tensor, that is required for
-        # proper continuous batching computation including
-        # chunked prefill
-        seq_idx = None
         attn_metadata = get_forward_context().attn_metadata
-        if attn_metadata.num_prefills > 0:
-            seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
-            for i, (srt, end) in enumerate(
-                    zip(
-                        attn_metadata.query_start_loc,
-                        attn_metadata.query_start_loc[1:],
-                    )):
-                seq_idx[srt:end] = i
-            seq_idx.unsqueeze_(0)
+
+        mamba2_metadata = prepare_mamba2_metadata(
+            chunk_size=self.config.chunk_size,
+            has_prefills=attn_metadata.num_prefills > 0,
+            input_ids=input_ids,
+            query_start_loc=attn_metadata.query_start_loc,
+        )
 
         # Process through layers
         original_hidden_states = torch.clone(hidden_states)
@@ -770,7 +765,7 @@ def forward(
                 original_hidden_states=original_hidden_states,
                 positions=positions,
                 mamba_cache_params=mamba_cache_params.at_layer_idx(layer_idx),
-                sequence_idx=seq_idx,
+                mamba2_metadata=mamba2_metadata,
             )
             hidden_states = layer_outputs
 
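The body of the new helper is not part of this diff. As a rough sketch only, reconstructed from the per-token loop removed above and the keyword arguments at the new call site (the field names on `Mamba2Metadata` and the exact return shape are assumptions, not the actual contents of `vllm/model_executor/layers/mamba/mamba2_metadata.py`):

```python
# Illustrative sketch, not the real vLLM implementation.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Mamba2Metadata:
    chunk_size: int
    # Per-token sequence index of shape (1, num_tokens); None for decode-only
    # batches. Field layout is assumed for illustration.
    seq_idx: Optional[torch.Tensor]


def prepare_mamba2_metadata(
    chunk_size: int,
    has_prefills: bool,
    input_ids: torch.Tensor,
    query_start_loc: torch.Tensor,
) -> Mamba2Metadata:
    seq_idx = None
    if has_prefills:
        # Mirror the loop this change removes from the model forward:
        # map every token position to the index of the sequence it belongs to,
        # which chunked-prefill Mamba2 kernels need under continuous batching.
        seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
        for i, (srt, end) in enumerate(
                zip(query_start_loc, query_start_loc[1:])):
            seq_idx[srt:end] = i
        seq_idx.unsqueeze_(0)
    return Mamba2Metadata(chunk_size=chunk_size, seq_idx=seq_idx)
```

The effect of the refactor is that this per-batch bookkeeping is computed once per forward pass and threaded through every layer as a single `Mamba2Metadata` object instead of a bare `seq_idx` tensor; `chunk_size` presumably travels with that metadata as well, which is why it is dropped from the `MambaMixer2` constructor.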