Skip to content

Commit 158751d

Browse files
committed
fix(tests): Resolve late binding of loop variable in assert message lambda
The `msg` parameter for `torch.testing.assert_close` within the `test_mamba_chunk_scan_cont_batch_prefill_chunking` test function used a lambda defined inside a for-loop. This lambda captured the loop variable `i`, triggering the `B023` warning from the Ruff linter. Ruff implements this rule (originally from flake8-bugbear) to detect "late binding" issues, where a lambda captures a reference to a variable, not its value at the time of definition. Although the current test runner likely executes the lambda immediately upon assertion failure, this pattern is a latent bug. Future changes could defer message generation, causing all failure messages to incorrectly display the final value of `i`, which would be misleading for debugging. This commit fixes the issue by using a default argument in the lambda (`lambda x, i=i: ...`) to capture the value of `i` at definition time. This robustly resolves the potential bug and allows for the removal of the `# noqa: B023` suppressions. Signed-off-by: lyd1992 <liuyudong@iscas.ac.cn> Signed-off-by: ihb2032 <1355790728@qq.com
1 parent d6953be commit 158751d

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

tests/kernels/mamba/test_mamba_ssm_ssd.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -539,15 +539,15 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
539539
Y_ref_seq[: chunked_seqlens[i], ...],
540540
atol=atol,
541541
rtol=rtol,
542-
msg=lambda x: f"seq{i} output part1 " + x,
543-
) # noqa: B023
542+
msg=lambda x, i=i: f"seq{i} output part1 " + x,
543+
)
544544
torch.testing.assert_close(
545545
Y_seq[chunked_seqlens[i] :, ...],
546546
Y_ref_seq[chunked_seqlens[i] :, ...],
547547
atol=atol,
548548
rtol=rtol,
549-
msg=lambda x: f"seq{i} output part2 " + x,
550-
) # noqa: B023
549+
msg=lambda x, i=i: f"seq{i} output part2 " + x,
550+
)
551551

552552
state_seq = state_chunked[i]
553553
state_seq_ref = state_ref[i]
@@ -556,5 +556,5 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
556556
state_seq_ref,
557557
atol=atol,
558558
rtol=rtol,
559-
msg=lambda x: f"seq{i} state " + x,
560-
) # noqa: B023
559+
msg=lambda x, i=i: f"seq{i} state " + x,
560+
)

0 commit comments

Comments
 (0)