[Example] Update GQA varlen fwd #1173
Conversation
Walkthrough

The PR updates the GQA variable-length flash attention example with improved windowed attention handling through visibility masking, enhanced numerical stability via explicit score masking, consolidated kernel indexing parameters, and an increased tiling configuration for performance, across both the reference implementation and the main tile-lang kernel.
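The "explicit score masking" mentioned in the walkthrough is a standard flash-attention stability trick: invalid scores are forced to -inf before the row max is taken, so masked positions contribute exactly zero to the softmax normalizer. A minimal numpy illustration (not the kernel's actual code, which operates on register tiles with `T.if_then_else`):

```python
import numpy as np

# Illustrative scores and visibility mask for two query rows over
# three key positions (values are made up for the demo).
scores = np.array([[0.5, 2.0, -1.0],
                   [1.0, 0.0, 3.0]])
visible = np.array([[True, True, False],
                    [True, False, False]])

# Mask BEFORE the row max: masked entries never enter the max or
# the normalizer, and exp(-inf - row_max) evaluates to exactly 0.
masked = np.where(visible, scores, -np.inf)
row_max = masked.max(axis=-1, keepdims=True)
exp = np.exp(masked - row_max)
probs = exp / exp.sum(axis=-1, keepdims=True)
```

Masking after the max (or with a merely "large negative" constant) can leak probability mass to invisible positions; masking to -inf before the max avoids that entirely.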
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning) · ✅ Passed checks (2 passed)
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run … before submitting. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/flash_attention/example_gqa_fwd_varlen.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/flash_attention/example_gqa_fwd_varlen.py (7)
- tilelang/language/annotations.py (1): `annotate_layout` (25-36)
- tilelang/layout/swizzle.py (1): `make_swizzled_layout` (10-18)
- tilelang/language/copy.py (1): `copy` (11-87)
- tilelang/language/fill.py (2): `fill` (9-21), `clear` (24-48)
- tilelang/language/pipeline.py (1): `Pipelined` (9-46)
- tilelang/language/tir/op.py (1): `if_then_else` (2907-2937)
- tilelang/language/reduce.py (1): `reduce_max` (50-68)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
```python
loop_range = (
    T.min(
        T.ceildiv(q_current_seqlen +
                  (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N))
    if is_causal else T.ceildiv(kv_current_seqlen, block_N))
```
Fix causal tile loop_range regression
For is_causal=True the new loop_range ends up as ceildiv(kv_current_seqlen, block_N) for every tile (e.g., q=kv=2048, block_M=block_N=128 ⇒ the very first tile now iterates all 16 K tiles instead of just the first one). We lose the triangular work reduction, so causal runs pay the full non-causal cost—latency blows up roughly ×16 on long sequences. Replace the sum with a clamp to the end of the current query tile.
```diff
-                loop_range = (
-                    T.min(
-                        T.ceildiv(q_current_seqlen +
-                                  (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N))
-                    if is_causal else T.ceildiv(kv_current_seqlen, block_N))
+                loop_range = (
+                    T.min(
+                        T.ceildiv(T.min(q_current_seqlen, (bx + 1) * block_M), block_N),
+                        T.ceildiv(kv_current_seqlen, block_N))
+                    if is_causal else T.ceildiv(kv_current_seqlen, block_N))
```

🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_fwd_varlen.py around lines 130-135, the
causal branch builds loop_range by adding (bx+1)*block_M to q_current_seqlen
which causes every causal tile to iterate up to the full KV length; replace the
sum with a clamp to the end of the current query tile so the loop range uses the
minimum of q_current_seqlen and (bx+1)*block_M before taking ceildiv and then
min that with the KV ceildiv; implement this by computing the tile-end =
min(q_current_seqlen, (bx+1)*block_M), using ceildiv(tile-end, block_N), and
then min(...) with ceildiv(kv_current_seqlen, block_N) for the causal case
(non-causal case unchanged).
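The claimed regression can be checked with plain integer arithmetic. Below is a standalone model of the two formulas; the variable names mirror the kernel and the values come from the reviewer's q = kv = 2048, block_M = block_N = 128 example:

```python
def ceildiv(a, b):
    # Integer ceiling division, matching T.ceildiv semantics.
    return -(-a // b)

q_current_seqlen = kv_current_seqlen = 2048
block_M = block_N = 128
bx = 0  # first query tile

# Formula currently in the PR: adds the tile end to the full query length,
# so the min() against the KV tile count almost always saturates.
pr_range = min(
    ceildiv(q_current_seqlen + (bx + 1) * block_M, block_N),
    ceildiv(kv_current_seqlen, block_N))

# Reviewer's suggested clamp: stop at the end of the current query tile.
fixed_range = min(
    ceildiv(min(q_current_seqlen, (bx + 1) * block_M), block_N),
    ceildiv(kv_current_seqlen, block_N))

print(pr_range, fixed_range)  # 16 1 — the first causal tile scans all 16 K tiles vs just 1
```

The triangular work reduction of causal attention depends on this clamp: summing tile counts over all 16 query tiles gives 16×16 = 256 K-tile iterations under the PR formula versus 1+2+…+16 = 136 under the clamp.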
* [Example] Update GQA varlen fwd
* fix
This pull request refactors and improves the `example_gqa_fwd_varlen.py` implementation for variable-length Flash Attention with grouped query attention (GQA). The changes enhance correctness, efficiency, and clarity, especially around sequence masking, causal logic, and kernel launch parameters.

Attention masking and causal logic improvements
- `attention_ref` now uses explicit window-based masking for visible tokens, improving correctness for variable-length and causal attention. Key and query padding masks are combined with the visibility mask for more accurate masking.

Kernel parameter and launch optimizations
- Increased the tiling configuration (`block_M`, `block_N`, `num_stages`, `threads`) for the tile-lang kernel launch, improving performance.

Code simplification and correctness fixes
- Removed `requires_grad=True` from input tensor creation, as gradients are not needed for this example.

Summary by CodeRabbit
Bug Fixes
Performance Improvements
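The combined padding-plus-visibility masking described for `attention_ref` can be sketched as follows. This is an illustrative stand-in, not the example's actual code: the `combined_mask` helper and the right-aligned causal window (query i sees keys up to i + kv_len − q_len, the usual varlen flash-attention convention) are assumptions.

```python
import numpy as np

def combined_mask(q_len, kv_len, max_q, max_kv, causal=True):
    """Hypothetical varlen mask: True where a query may attend to a key.

    q_len/kv_len are the valid (unpadded) lengths inside padded buffers
    of size max_q/max_kv; padded rows and columns are masked out entirely.
    """
    q_idx = np.arange(max_q)[:, None]   # (max_q, 1)
    k_idx = np.arange(max_kv)[None, :]  # (1, max_kv)
    q_pad = q_idx < q_len               # query padding mask
    k_pad = k_idx < kv_len              # key padding mask
    mask = q_pad & k_pad                # both positions must be real tokens
    if causal:
        # Right-aligned causal window (assumed convention):
        # query i may see keys j with j <= i + (kv_len - q_len).
        mask &= k_idx <= q_idx + (kv_len - q_len)
    return mask

m = combined_mask(q_len=3, kv_len=5, max_q=4, max_kv=6)
# Typical use in a reference impl:
#   scores = np.where(m, scores, -np.inf)  before softmax
```

Combining the padding masks with the visibility window in one boolean array keeps the reference implementation's masking in one place, so padded queries produce fully-masked rows rather than spurious attention weights.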