diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py index b4424b717d02..06caf65932d2 100644 --- a/tests/kernels/mamba/test_mamba_ssm_ssd.py +++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py @@ -539,15 +539,15 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): Y_ref_seq[: chunked_seqlens[i], ...], atol=atol, rtol=rtol, - msg=lambda x: f"seq{i} output part1 " + x, - ) # noqa: B023 + msg=lambda x, i=i: f"seq{i} output part1 " + x, + ) torch.testing.assert_close( Y_seq[chunked_seqlens[i] :, ...], Y_ref_seq[chunked_seqlens[i] :, ...], atol=atol, rtol=rtol, - msg=lambda x: f"seq{i} output part2 " + x, - ) # noqa: B023 + msg=lambda x, i=i: f"seq{i} output part2 " + x, + ) state_seq = state_chunked[i] state_seq_ref = state_ref[i] @@ -556,5 +556,5 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): state_seq_ref, atol=atol, rtol=rtol, - msg=lambda x: f"seq{i} state " + x, - ) # noqa: B023 + msg=lambda x, i=i: f"seq{i} state " + x, + )