@@ -193,6 +193,13 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
193193
194194 # this tests the kernels on a single example (no batching)
195195
196+ # TODO: the bfloat16 case requires higher thresholds. To be investigated
197+
198+ if itype == torch .bfloat16 :
199+ atol , rtol = 5e-2 , 5e-2
200+ else :
201+ atol , rtol = 8e-3 , 5e-3
202+
196203 # set seed
197204 batch_size = 1 # batch_size
198205 # ssd_minimal_discrete requires chunk_size divide seqlen
@@ -216,14 +223,14 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
216223 return_final_states = True )
217224
218225 # just test the last in sequence
219- torch .allclose (Y [:, - 1 ], Y_min [:, - 1 ], atol = 1e-3 , rtol = 1e-3 )
226+ torch .testing . assert_close (Y [:, - 1 ], Y_min [:, - 1 ], atol = atol , rtol = rtol )
220227
221228 # just test the last head
222229 # NOTE, in the kernel we always cast states to fp32
223- torch .allclose (final_state [:, - 1 ],
224- final_state_min [:, - 1 ].to (torch .float32 ),
225- atol = 1e-3 ,
226- rtol = 1e-3 )
230+ torch .testing . assert_close (final_state [:, - 1 ],
231+ final_state_min [:, - 1 ].to (torch .float32 ),
232+ atol = atol ,
233+ rtol = rtol )
227234
228235
229236@pytest .mark .parametrize ("itype" , [torch .float32 , torch .float16 ])
@@ -263,6 +270,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
263270
264271 seqlen , chunk_size , num_examples , cases = seq_len_chunk_size_cases
265272
273+ # TODO: the irregular chunk size cases have some issues and require higher
274+ # tolerance. This is to be invesigated
275+ if chunk_size not in {8 , 256 }:
276+ atol , rtol = 5e-1 , 5e-1
277+ else :
278+ atol , rtol = 5e-3 , 5e-3
279+
266280 # hold state during the cutting process so we know if an
267281 # example has been exhausted and needs to cycle
268282 last_taken : dict = {} # map: eg -> pointer to last taken sample
@@ -300,7 +314,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
300314 # just test one dim and dstate
301315 Y_eg = Y [0 , cu_seqlens [i ]:cu_seqlens [i + 1 ], 0 , 0 ]
302316 Y_min_eg = Y_min [i ][:, 0 , 0 ]
303- torch .allclose (Y_eg , Y_min_eg , atol = 1e-3 , rtol = 1e-3 )
317+ torch .testing . assert_close (Y_eg , Y_min_eg , atol = atol , rtol = rtol )
304318
305319 # update states
306320 states = new_states
0 commit comments