from torch import nn

from vllm.attention.backends.abstract import AttentionMetadata
- from vllm.attention.backends.flash_attn import FlashAttentionMetadata
- from vllm.attention.backends.placeholder_attn import (
-     PlaceholderAttentionMetadata)
- from vllm.attention.backends.xformers import XFormersMetadata
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather,
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
+ from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
@@ -221,7 +218,6 @@ def __init__(self,
                 head_dim: int = 64,
                 rms_norm_eps: float = 1e-5,
                 activation="silu",
-                  chunk_size: int = 256,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()

@@ -257,7 +253,6 @@ def __init__(self,
        self.ssm_state_size = ssm_state_size
        self.activation = activation

-         self.chunk_size = chunk_size
        self.intermediate_size = intermediate_size
        self.head_dim = head_dim
        self.num_heads = num_heads
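Taken together, the new `Mamba2Metadata` import and the removal of the per-layer `chunk_size` argument indicate that chunking configuration and batch-level flags now arrive through a single metadata object rather than being stored on, or recomputed in, every mixer. A minimal sketch of such a container, using only the fields this diff actually reads (`has_prefill`, `has_initial_states`, `prep_initial_states`, `chunk_size`, `seq_idx`, `chunk_indices`, `chunk_offsets`); the real class in `mamba2_metadata.py` may define more:

```python
# Hypothetical sketch only -- not the actual vLLM class definition.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Mamba2MetadataSketch:
    # Batch-level flags shared by all mamba layers in one forward pass.
    has_prefill: bool                           # does the batch contain prefill requests?
    has_initial_states: Optional[torch.Tensor]  # per-sequence bools: prior context exists
    prep_initial_states: bool                   # gather cached SSM states for the scan?
    # Chunked-scan configuration and bookkeeping for the triton kernels.
    chunk_size: int
    seq_idx: Optional[torch.Tensor]
    chunk_indices: Optional[torch.Tensor]
    chunk_offsets: Optional[torch.Tensor]
```

Computing these once per model forward and passing the same object to every layer avoids repeating identical work in each of the stacked mamba mixers.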
@@ -388,25 +383,17 @@ def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        mamba_cache_params: MambaCacheParams,
-         sequence_idx: Optional[torch.Tensor] = None,
+         mamba2_metadata: Mamba2Metadata,
    ):
+         # mamba2_metadata contains metadata necessary for the mamba2 triton
+         # kernels to operate in continuous batching and in chunked prefill
+         # modes; they are computed at top-level model forward since they
+         # are the same and reused for all mamba layers in the same iteration
        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata

        seq_len, _ = hidden_states.shape
        groups_time_state_size = self.n_groups * self.ssm_state_size

-         # detect if there are prefills
-         has_prefill = attn_metadata.num_prefills > 0
- 
-         # - also need flags to indicate if there are initial states
-         # - currently we really only support the FlashAttention backend
-         has_initial_states = None
-         if (isinstance(attn_metadata,
-                        (FlashAttentionMetadata, XFormersMetadata,
-                         PlaceholderAttentionMetadata))
-                 and attn_metadata.context_lens_tensor is not None):
-             has_initial_states = attn_metadata.context_lens_tensor > 0
- 
        # 1. Gated MLP's linear projection
        projected_states, _ = self.in_proj(hidden_states)
        gate, hidden_states_B_C, dt = torch.split(
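The prefill and initial-state detection deleted in the hunk above used to run once per mixer layer and only recognized a few attention backends. Per the new comment, the equivalent work now happens once in the top-level model forward. A hedged sketch of what that one-time computation could look like, reusing the `Mamba2MetadataSketch` container sketched earlier; the helper name and signature are illustrative, not the actual vLLM API:

```python
# Hedged sketch: names and signature are illustrative, not the real vLLM helper.
def build_mamba2_metadata_sketch(attn_metadata, chunk_size: int) -> "Mamba2MetadataSketch":
    # Same detection the mixer used to do per layer, now done once per forward.
    has_prefill = attn_metadata.num_prefills > 0

    has_initial_states = None
    prep_initial_states = False
    if getattr(attn_metadata, "context_lens_tensor", None) is not None:
        # A sequence carries an initial SSM state iff it already has context.
        has_initial_states = attn_metadata.context_lens_tensor > 0
        prep_initial_states = bool(has_initial_states.any())

    return Mamba2MetadataSketch(
        has_prefill=has_prefill,
        has_initial_states=has_initial_states,
        prep_initial_states=prep_initial_states,
        chunk_size=chunk_size,
        seq_idx=None,        # would be derived per token for chunked prefill
        chunk_indices=None,  # chunk bookkeeping for the chunked-scan kernel
        chunk_offsets=None,
    )
```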
@@ -423,7 +410,7 @@ def forward_cuda(
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                               self.conv1d.weight.size(2))

-         if has_prefill:
+         if mamba2_metadata.has_prefill:
            # |---------- N-1 iteration --------|
            # |---------------- N iteration ---------------------|
            # |- tokenA -|......................|-- newTokens ---|
@@ -439,7 +426,7 @@ def forward_cuda(
                self.conv1d.bias,
                activation=self.activation,
                conv_states=mamba_cache_params.conv_state,
-                 has_initial_state=has_initial_states,
+                 has_initial_state=mamba2_metadata.has_initial_states,
                cache_indices=mamba_cache_params.state_indices_tensor,
                query_start_loc=attn_metadata.query_start_loc).transpose(
                    0, 1)[:seq_len]
@@ -467,16 +454,15 @@ def forward_cuda(
        )

        # 3. State Space Model sequence transformation
-         if has_prefill:
- 
+         if mamba2_metadata.has_prefill:
            initial_states = None
-             if has_initial_states is not None and torch.any(
-                     has_initial_states):
-                 zero_init_indices = mamba_cache_params.state_indices_tensor[
-                     ~has_initial_states]
-                 mamba_cache_params.ssm_state[zero_init_indices] = 0
-                 initial_states = mamba_cache_params.ssm_state[
-                     mamba_cache_params.state_indices_tensor]
+             if (mamba2_metadata.has_initial_states is not None
+                     and mamba2_metadata.prep_initial_states):
+                 # making a copy of the states
+                 initial_states = torch.where(
+                     mamba2_metadata.has_initial_states[:, None, None, None],
+                     mamba_cache_params.ssm_state[
+                         mamba_cache_params.state_indices_tensor], 0)

            scan_output, varlen_state = mamba_chunk_scan_combined(
                hidden_states.view(1, seq_len, self.num_heads // self.tp_size,
@@ -485,11 +471,13 @@ def forward_cuda(
                self.A,
                B.view(1, seq_len, self.n_groups // self.tp_size, -1),
                C.view(1, seq_len, self.n_groups // self.tp_size, -1),
-                 chunk_size=self.chunk_size,
+                 chunk_size=mamba2_metadata.chunk_size,
                D=self.D,
                z=None,
                dt_bias=self.dt_bias,
-                 seq_idx=sequence_idx,
+                 seq_idx=mamba2_metadata.seq_idx,
+                 chunk_indices=mamba2_metadata.chunk_indices,
+                 chunk_offsets=mamba2_metadata.chunk_offsets,
                cu_seqlens=attn_metadata.query_start_loc,
                initial_states=initial_states,
                return_varlen_states=True,
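Note the behavioral cleanup in the initial-state handling above: instead of zeroing selected `ssm_state` cache rows in place and then gathering them, the new code builds `initial_states` as a masked copy with `torch.where`, leaving the cache untouched. A small self-contained illustration of that masking pattern (shapes, slot indices, and variable names here are invented for the example):

```python
# Standalone illustration of the masked-copy pattern: rows of the cached SSM
# state are kept only where the sequence really has prior context, and zeroed
# otherwise, without mutating the cache in place.
import torch

batch, nheads, headdim, dstate = 3, 2, 4, 8
ssm_state_cache = torch.randn(16, nheads, headdim, dstate)   # per-slot cache
state_indices = torch.tensor([5, 0, 9])                      # cache slot per sequence
has_initial_states = torch.tensor([True, False, True])       # per-sequence flag

initial_states = torch.where(
    has_initial_states[:, None, None, None],                 # broadcast over state dims
    ssm_state_cache[state_indices],                          # gathered states: [3, 2, 4, 8]
    0)                                                       # zeros where no prior context

assert torch.all(initial_states[1] == 0)                     # seq 1 starts from scratch
assert torch.equal(initial_states[0], ssm_state_cache[5])    # seq 0 resumes its cached state
```

Building a masked copy keeps the shared state cache read-only during the scan setup, which matters when the same cache slots may be read again later in the iteration.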