Skip to content

Conversation

@chengyupku
Copy link
Contributor

@chengyupku chengyupku commented Nov 2, 2025

This pull request refactors and improves the example_gqa_fwd_varlen.py implementation for variable-length Flash Attention with grouped query attention (GQA). The changes enhance correctness, efficiency, and clarity, especially around sequence masking, causal logic, and kernel launch parameters.

Attention masking and causal logic improvements

  • Refactored the reference attention computation in attention_ref to use explicit window-based masking for visible tokens, improving correctness for variable-length and causal attention. Key and query padding masks are now combined with the visibility mask for more accurate masking.
  • Updated causal logic in the main kernel to use a more robust mask and loop range calculation, ensuring proper handling of causal and non-causal cases. [1] [2]

Kernel parameter and launch optimizations

  • Increased block sizes and parallelism (block_M, block_N, num_stages, threads) for the tile-lang kernel launch, improving performance.
  • Adjusted benchmarking to use multiple warmup and repeat runs for more reliable latency measurements.

Code simplification and correctness fixes

  • Removed unnecessary requires_grad=True from input tensor creation, as gradients are not needed for this example.
  • Unified key and value start/end indices and sequence lengths for grouped attention, reducing code duplication and potential errors. [1] [2]

Summary by CodeRabbit

  • Bug Fixes

    • Improved numerical stability in attention computation
    • Enhanced masking for windowed attention and causal patterns
  • Performance Improvements

    • Optimized kernel tiling configuration for better throughput
    • Refined benchmarking measurements for more accurate performance assessment

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Nov 2, 2025

Walkthrough

The PR updates the GQA variable-length flash attention example with improved windowed attention handling through visibility masking, enhanced numerical stability via explicit score masking, consolidated kernel indexing parameters, and increased tiling configuration for performance optimization across reference implementation and main lighting kernels.

Changes

Cohort / File(s) Summary
GQA Flash Attention Example Updates
examples/flash_attention/example_gqa_fwd_varlen.py
attention_ref: Replaced implicit dimension calculations with explicit shape unpacking (b, T, Hq, D); added windowed attention with visibility mask across time/source positions applied to scores via neg_inf; integrated key padding masking with visibility mask; updated tiling g-factors; applied softmax on masked scores; updated output masking logic. main/lighting: Increased block_M/block_N from 64→128, num_stages from 1→2, threads from 128→256; removed requires_grad guards; added is_causal-driven flops branch; consolidated k_start_idx/v_start_idx to kv_start_idx; replaced strict -infinity with -1e9 checks; added explicit max-sum updates; expanded do_bench parameters; added layout annotations for O_shared/Q_shared.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

  • attention_ref masking logic: Verify windowed attention visibility mask implementation, intersection with key padding mask, and softmax application order for correctness
  • Numerical stability changes: Validate -1e9 replacement for strict -infinity checks across causal/non-causal paths and confirm score normalization adjustments
  • Kernel indexing consolidation: Ensure kv_start_idx correctly replaces separate k_start_idx/v_start_idx throughout kernel wiring
  • Tiling configuration impact: Confirm new block sizes (128/128) and stage counts (2) are compatible with memory layout annotations and V_unpad indexing

Possibly related PRs

Suggested reviewers

  • LeiWang1999

Poem

🐰 With masked attention, windowed and bright,
Visibility layers now set things right,
Tiling expanded, performance takes flight,
Kernels refined with consolidated might,
GQA varlen shines in the light! ✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check ✅ Passed The pull request title "[Example] Update GQA varlen fwd" clearly and specifically identifies the main change: updating the example file for grouped query attention (GQA) with variable-length forward pass. The title is concise, free of vague terminology or noise, and specific enough that a reviewer scanning the pull request history would immediately understand this concerns updates to the example_gqa_fwd_varlen.py file. While the title doesn't enumerate all implementation details (masking refactoring, kernel optimizations, etc.), that level of detail is not expected in a PR title per the guidelines. The title successfully captures the primary change from the developer's perspective.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions
Copy link

github-actions bot commented Nov 2, 2025

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c85bb3a and 1cf3bd7.

📒 Files selected for processing (1)
  • examples/flash_attention/example_gqa_fwd_varlen.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/flash_attention/example_gqa_fwd_varlen.py (7)
tilelang/language/annotations.py (1)
  • annotate_layout (25-36)
tilelang/layout/swizzle.py (1)
  • make_swizzled_layout (10-18)
tilelang/language/copy.py (1)
  • copy (11-87)
tilelang/language/fill.py (2)
  • fill (9-21)
  • clear (24-48)
tilelang/language/pipeline.py (1)
  • Pipelined (9-46)
tilelang/language/tir/op.py (1)
  • if_then_else (2907-2937)
tilelang/language/reduce.py (1)
  • reduce_max (50-68)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint

Comment on lines +130 to 135
loop_range = (
T.min(
T.ceildiv(q_current_seqlen +
(bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N))
if is_causal else T.ceildiv(kv_current_seqlen, block_N))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Fix causal tile loop_range regression

For is_causal=True the new loop_range ends up as ceildiv(kv_current_seqlen, block_N) for every tile (e.g., q=kv=2048, block_M=block_N=128 ⇒ the very first tile now iterates all 16 K tiles instead of just the first one). We lose the triangular work reduction, so causal runs pay the full non-causal cost—latency blows up roughly ×16 on long sequences. Replace the sum with a clamp to the end of the current query tile.

-            loop_range = (
-                T.min(
-                    T.ceildiv(q_current_seqlen +
-                              (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N))
-                if is_causal else T.ceildiv(kv_current_seqlen, block_N))
+            loop_range = (
+                T.min(
+                    T.ceildiv(T.min(q_current_seqlen, (bx + 1) * block_M), block_N),
+                    T.ceildiv(kv_current_seqlen, block_N))
+                if is_causal else T.ceildiv(kv_current_seqlen, block_N))
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_fwd_varlen.py around lines 130-135, the
causal branch builds loop_range by adding (bx+1)*block_M to q_current_seqlen
which causes every causal tile to iterate up to the full KV length; replace the
sum with a clamp to the end of the current query tile so the loop range uses the
minimum of q_current_seqlen and (bx+1)*block_M before taking ceildiv and then
min that with the KV ceildiv; implement this by computing the tile-end =
min(q_current_seqlen, (bx+1)*block_M), using ceildiv(tile-end, block_N), and
then min(...) with ceildiv(kv_current_seqlen, block_N) for the causal case
(non-causal case unchanged).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants