diff --git a/tests/kernels/mamba/test_mamba_mixer2.py b/tests/kernels/mamba/test_mamba_mixer2.py
index f5c6a18614ff..16c310726ad1 100644
--- a/tests/kernels/mamba/test_mamba_mixer2.py
+++ b/tests/kernels/mamba/test_mamba_mixer2.py
@@ -119,7 +119,8 @@ def mixer2_gated_norm_tensor_parallel(
         gate_states[..., local_rank * N:(local_rank + 1) * N],
     )
     ref_output = mixer_single_gpu(hidden_states, gate_states)
-    torch.allclose(output,
-                   ref_output[..., local_rank * N:(local_rank + 1) * N],
-                   atol=1e-3,
-                   rtol=1e-3)
+    torch.testing.assert_close(output,
+                               ref_output[...,
+                                          local_rank * N:(local_rank + 1) * N],
+                               atol=5e-3,
+                               rtol=1e-3)
diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py
index 6a3f21ba543f..00c1a2911d7d 100644
--- a/tests/kernels/mamba/test_mamba_ssm_ssd.py
+++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py
@@ -193,6 +193,13 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
     # this tests the kernels on a single example (no batching)
 
+    # TODO: the bfloat16 case requires higher thresholds. To be investigated
+
+    if itype == torch.bfloat16:
+        atol, rtol = 5e-2, 5e-2
+    else:
+        atol, rtol = 8e-3, 5e-3
+
     # set seed
     batch_size = 1  # batch_size
 
     # ssd_minimal_discrete requires chunk_size divide seqlen
@@ -216,14 +223,14 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
                                                return_final_states=True)
 
     # just test the last in sequence
-    torch.allclose(Y[:, -1], Y_min[:, -1], atol=1e-3, rtol=1e-3)
+    torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol)
 
     # just test the last head
     # NOTE, in the kernel we always cast states to fp32
-    torch.allclose(final_state[:, -1],
-                   final_state_min[:, -1].to(torch.float32),
-                   atol=1e-3,
-                   rtol=1e-3)
+    torch.testing.assert_close(final_state[:, -1],
+                               final_state_min[:, -1].to(torch.float32),
+                               atol=atol,
+                               rtol=rtol)
 
 
 @pytest.mark.parametrize("itype", [torch.float32, torch.float16])
@@ -263,6 +270,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
 
     seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
 
+    # TODO: the irregular chunk size cases have some issues and require higher
+    # tolerance. This is to be investigated
+    if chunk_size not in {8, 256}:
+        atol, rtol = 5e-1, 5e-1
+    else:
+        atol, rtol = 5e-3, 5e-3
+
     # hold state during the cutting process so we know if an
     # example has been exhausted and needs to cycle
     last_taken: dict = {}  # map: eg -> pointer to last taken sample
@@ -300,7 +314,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
         # just test one dim and dstate
         Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
         Y_min_eg = Y_min[i][:, 0, 0]
-        torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3)
+        torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)
 
         # update states
         states = new_states
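
Note: the substantive fix above is that torch.allclose only returns a bool, so the old bare calls never asserted anything and these tests could not fail on a mismatch; torch.testing.assert_close raises instead. A minimal standalone sketch of the difference (not part of the patch; the tensor values are illustrative):

    import torch

    a = torch.zeros(4)
    b = torch.full((4,), 0.1)

    # Old pattern: evaluates to False, and the result is silently
    # discarded, so a test written this way passes regardless.
    torch.allclose(a, b, atol=1e-3, rtol=1e-3)

    # New pattern: raises AssertionError with a mismatch report,
    # so the comparison actually fails the test.
    try:
        torch.testing.assert_close(a, b, atol=1e-3, rtol=1e-3)
    except AssertionError as err:
        print(err)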
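
On the widened, dtype-dependent tolerances: one way to see why bfloat16 needs coarser thresholds than float32/float16 is machine epsilon per dtype, since bfloat16 carries only 7 explicit mantissa bits versus 10 for float16 and 23 for float32. A quick standalone check (again, not part of the patch):

    import torch

    # Machine epsilon per dtype; coarser precision motivates the
    # per-dtype atol/rtol selection added in the patch.
    for dt in (torch.float32, torch.float16, torch.bfloat16):
        print(dt, torch.finfo(dt).eps)
    # torch.float32  1.1920928955078125e-07
    # torch.float16  0.0009765625
    # torch.bfloat16 0.0078125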