@@ -7,7 +7,7 @@
 from einops import rearrange, repeat
 
 from vllm.model_executor.layers.mamba.ops.ssd_combined import (
-    mamba_chunk_scan_combined)
+    mamba_chunk_scan_combined_varlen)
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.mamba2_attn import (
     _query_start_loc_to_chunk_indices_offsets)
@@ -185,9 +185,14 @@ def end_boundary(n: int):
         IND_S = [x % full_length for x in IND_E]
         IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]
 
+        # varlen has implicit batch=1
+        dt2 = dt2.squeeze(0)
+        X2 = X2.squeeze(0)
+        B2 = B2.squeeze(0)
+        C2 = C2.squeeze(0)
         yield ([Y_min[s, IND_S[s]:IND_E[s]]
                 for s in range(num_examples)] if return_naive_ref else None,
-               cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
+               cu_seqlens, seq_idx, (A, dt2, X2, B2, C2))
 
 
 @pytest.mark.parametrize("itype",
@@ -219,6 +224,20 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
 
     Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
                                                   B, C, chunk_size)
+
+    cu_seqlens = torch.tensor((0, seqlen), device='cuda').cumsum(dim=0)
+    seq_idx = torch.zeros(seqlen, dtype=torch.int32, device=cu_seqlens.device)
+
+    chunk_indices, chunk_offsets = \
+        _query_start_loc_to_chunk_indices_offsets(
+            cu_seqlens, chunk_size, cu_seqlens[-1])
+
+    # varlen has implicit batch=1
+    X = X.squeeze(0)
+    dt = dt.squeeze(0)
+    A = A.squeeze(0)
+    B = B.squeeze(0)
+    C = C.squeeze(0)
     Y = torch.empty_like(X)
     final_state = mamba_chunk_scan_combined_varlen(X,
                                                    dt,
@@ -234,11 +253,11 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
                                                    out=Y)
 
     # just test the last in sequence
-    torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol)
+    torch.testing.assert_close(Y[-1], Y_min[0, -1], atol=atol, rtol=rtol)
 
     # just test the last head
     # NOTE, in the kernel we always cast states to fp32
-    torch.testing.assert_close(final_state[:, -1],
+    torch.testing.assert_close(final_state[:, -1].to(torch.float32),
                                final_state_min[:, -1].to(torch.float32),
                                atol=atol,
                                rtol=rtol)
@@ -303,7 +322,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
                 cu_seqlens, chunk_size, cu_seqlens[-1])
 
         Y = torch.empty_like(X)
-        new_states = mamba_chunk_scan_combined(
+        new_states = mamba_chunk_scan_combined_varlen(
             X,
             dt,
             A,
@@ -315,7 +334,6 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
             seq_idx=seq_idx,
             chunk_indices=chunk_indices,
             chunk_offsets=chunk_offsets,
-            return_varlen_states=True,
             initial_states=states,
             out=Y,
         )
@@ -324,7 +342,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
         for i in range(num_examples):
 
             # just test one dim and dstate
-            Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
+            Y_eg = Y[cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
             Y_min_eg = Y_min[i][:, 0, 0]
             torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)
 
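Note on the convention the updated tests exercise: with the varlen entry point, sequences are no longer laid out as (batch, seqlen, ...); they are concatenated along a single token dimension (an implicit batch of 1), and cu_seqlens plus a per-token seq_idx record where each sequence starts and ends. The snippet below is an illustrative sketch of that packing only; the sequence lengths and tensor shapes are made up for the example and are not part of this diff, and it uses plain PyTorch without calling the kernel itself.

    import torch

    # two made-up sequences of lengths 5 and 3, packed back to back
    seqlens = torch.tensor([5, 3])
    cu_seqlens = torch.cat(
        [torch.zeros(1, dtype=torch.long), seqlens.cumsum(dim=0)])  # [0, 5, 8]

    # per-token sequence index, analogous to seq_idx in the tests above
    seq_idx = torch.repeat_interleave(
        torch.arange(len(seqlens), dtype=torch.int32), seqlens)  # [0,0,0,0,0,1,1,1]

    # packed activations: (total_tokens, n_heads, d_head) rather than
    # (batch, seqlen, n_heads, d_head) -- the batch dim has been squeezed away
    n_heads, d_head = 4, 8
    X = torch.randn(int(cu_seqlens[-1]), n_heads, d_head)

    # recovering one sequence mirrors Y[cu_seqlens[i]:cu_seqlens[i + 1]] above
    X_seq0 = X[cu_seqlens[0]:cu_seqlens[1]]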