Skip to content

Commit b97b99a

Browse files
committed
fix(tests): Resolve late binding of loop variable in assert message lambda
The `msg` parameter for `torch.testing.assert_close` within the `test_mamba_chunk_scan_cont_batch_prefill_chunking` test function used a lambda defined inside a for-loop. This lambda captured the loop variable `i`, triggering the `B023` warning from the Ruff linter. Ruff implements this rule (originally from flake8-bugbear) to detect "late binding" issues, where a lambda captures a reference to a variable, not its value at the time of definition. Although the current test runner likely executes the lambda immediately upon assertion failure, this pattern is a latent bug. Future changes could defer message generation, causing all failure messages to incorrectly display the final value of `i`, which would be misleading for debugging. This commit fixes the issue by using a default argument in the lambda (`lambda x, i=i: ...`) to capture the value of `i` at definition time. This robustly resolves the potential bug and allows for the removal of the `# noqa: B023` suppressions. Signed-off-by: lyd1992 <liuyudong@iscas.ac.cn> Signed-off-by: ihb2032 <1355790728@qq.com Signed-off-by: ihb2032 <1355790728@qq.com>
1 parent d6953be commit b97b99a

File tree

1 file changed

+7
-11
lines changed

1 file changed

+7
-11
lines changed

tests/kernels/mamba/test_mamba_ssm_ssd.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -539,22 +539,18 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
539539
Y_ref_seq[: chunked_seqlens[i], ...],
540540
atol=atol,
541541
rtol=rtol,
542-
msg=lambda x: f"seq{i} output part1 " + x,
543-
) # noqa: B023
542+
msg=lambda x, i=i: f"seq{i} output part1 " + x)
544543
torch.testing.assert_close(
545544
Y_seq[chunked_seqlens[i] :, ...],
546545
Y_ref_seq[chunked_seqlens[i] :, ...],
547546
atol=atol,
548547
rtol=rtol,
549-
msg=lambda x: f"seq{i} output part2 " + x,
550-
) # noqa: B023
548+
msg=lambda x, i=i: f"seq{i} output part2 " + x)
551549

552550
state_seq = state_chunked[i]
553551
state_seq_ref = state_ref[i]
554-
torch.testing.assert_close(
555-
state_seq,
556-
state_seq_ref,
557-
atol=atol,
558-
rtol=rtol,
559-
msg=lambda x: f"seq{i} state " + x,
560-
) # noqa: B023
552+
torch.testing.assert_close(state_seq,
553+
state_seq_ref,
554+
atol=atol,
555+
rtol=rtol,
556+
msg=lambda x, i=i: f"seq{i} state " + x)

0 commit comments

Comments
 (0)