@@ -573,8 +573,8 @@ def forward_cuda(
573573            x  =  hidden_states_B_C_p .transpose (
574574                0 , 1 )  # this is the form that causal-conv see 
575575            if  mamba2_metadata .cu_seqlen  is  None :
576-                 mamba2_metadata  =  update_metadata (
577-                     x ,  attn_metadata . query_start_loc ,  mamba2_metadata )
576+                 mamba2_metadata  =  update_metadata (x ,  query_start_loc_p , 
577+                                                    mamba2_metadata )
578578            hidden_states_B_C_p  =  causal_conv1d_fn (
579579                x ,
580580                conv_weights ,
@@ -583,6 +583,7 @@ def forward_cuda(
583583                conv_states = conv_state ,
584584                has_initial_state = has_initial_states_p ,
585585                cache_indices = state_indices_tensor_p ,
586+                 metadata = mamba2_metadata ,
586587                query_start_loc = query_start_loc_p ).transpose (
587588                    0 , 1 )[:num_prefill_tokens ]
588589
@@ -593,9 +594,14 @@ def forward_cuda(
593594            initial_states  =  None 
594595            if  (has_initial_states_p  is  not None  and  prep_initial_states ):
595596                # making a copy of the states 
596-                 initial_states  =  torch .where (
597-                     has_initial_states_p [:, None , None , None ],
598-                     ssm_state [state_indices_tensor_p ], 0 )
597+                 if  envs .VLLM_USE_V1 :
598+                     initial_states  =  torch .where (
599+                         has_initial_states_p [:, None , None , None ],
600+                         ssm_state [state_indices_tensor_p ], 0 )
601+                 else :
602+                     initial_states  =  torch .where (
603+                         has_initial_states_p [:num_prefills , None , None , None ],
604+                         ssm_state [state_indices_tensor_p ], 0 )
599605
600606            scan_output , varlen_state  =  mamba_chunk_scan_combined (
601607                hidden_states_p .view (1 , num_prefill_tokens ,
0 commit comments