[Feature][Example] Support TMA reduce operation and update GQA bwd example #969
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run … We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough

Introduces a TMA-enabled atomic add path configurable via a new `use_tma` flag, adjusts argument positions, threads a `need_reduce` flag through bulk TMA copies, adds CUDA codegen support for `tma_store_add`, exposes a Python `atomic_add(use_tma=...)` parameter, and adds a flash attention example demonstrating TMA-reduction-based backward.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Python as Python API
    participant IR as AtomicAddNode(Lower)
    participant CG as CUDA Codegen
    participant GPU as CUDA Kernels
    Python->>IR: tl.atomicadd(value, dst, use_tma)
    alt use_tma != 0
        IR->>CG: emit tl::tma_store(..., need_reduce=1, ...)
        CG->>GPU: extern "tma_store_add"(smem_addr, gmem_ptr, size_bytes)
    else
        IR->>CG: emit SIMT atomic add loop
        CG->>GPU: atomicAdd(...) per element
    end
```
```mermaid
sequenceDiagram
    autonumber
    participant Copy as Bulk Copy (IR)
    participant CG as CUDA Codegen
    participant GPU as CUDA Kernels
    Copy->>CG: tma_store(..., need_reduce, eviction_policy)
    alt need_reduce != 0
        CG->>GPU: tma_store_add(smem_ptr, gmem_ptr, store_bytes)
    else
        CG->>GPU: tma_store(smem_ptr, gmem_ptr, store_bytes, ...)
    end
```
```mermaid
sequenceDiagram
    autonumber
    participant App as Example main()
    participant TL as TileLang Kernels
    participant Torch as PyTorch Autograd
    App->>Torch: attention(Q,K,V, causal, use_atomic)
    Torch->>TL: flashattn_fwd(...)
    TL-->>Torch: O, lse
    Torch->>TL: backward (atomic_add or split)
    TL-->>Torch: dQ, dK, dV
    Torch-->>App: outputs and grads
```
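To make the third diagram concrete, here is a hedged driver sketch. The module path, tensor shapes, dtypes, and the assumption that the wrapper returns a single output tensor are illustrative guesses rather than the example's confirmed interface; `attention` is the `_attention.apply` wrapper discussed later in this review.

```python
# Hedged sketch only: module name, shapes, dtypes, and return convention are assumptions.
import torch
from example_gqa_bwd_tma_reduce import attention  # hypothetical import, run from the example directory

B, S, H, D, groups = 1, 1024, 32, 128, 8
q = torch.randn(B, S, H, D, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(B, S, H // groups, D, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(B, S, H // groups, D, device="cuda", dtype=torch.float16, requires_grad=True)

# Positional call: (q, k, v, causal, groups, use_atomic); use_atomic=True selects the
# atomic/TMA-reduce backward, False selects the split backward.
o = attention(q, k, v, True, groups, True)
o.backward(torch.randn_like(o))
print(q.grad.shape, k.grad.shape, v.grad.shape)
```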
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs
Suggested reviewers
Poem
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
```python
def make_dq_layout(dQ):
    # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
    return T.Layout(dQ.shape,
```
but this will still be helpful for Ampere?
I see. TMA reduces the need to cache data into shared memory, allowing the naive atomic to utilize this layout. However, why do we need to transpose BLHD into BHLD?
> I see. TMA reduces the need to cache data into shared memory, allowing the naive atomic to utilize this layout. However, why do we need to transpose BLHD into BHLD?
Because FA3 did this?
I think we can create another example named example_gqa_bwd_wgmma_pipelined.py, where the GQA kernel is implemented with customized pipelines for Hopper and tma reduce. We can keep using atomic add in this file then?
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/flash_attention/example_gqa_bwd.py (1)
444-446: Fix `mod_post` call in split path.

`flashattn_bwd_postprocess` now expects `(dQ, dK, dV)` but this branch still calls it with a single tensor, so runtime will raise `TypeError: missing 2 required positional arguments`. Update the split path to either supply the full `(dQ, dK, dV)` inputs or avoid calling `mod_post` altogether.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- examples/flash_attention/example_gqa_bwd.py (5 hunks)
- src/op/atomic_add.cc (3 hunks)
- src/op/atomic_add.h (4 hunks)
- src/op/copy.cc (3 hunks)
- src/target/codegen_cuda.cc (2 hunks)
- src/tl_templates/cuda/copy_sm90.h (1 hunks)
- tilelang/language/atomic.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (6)
tilelang/language/atomic.py (1)
- tilelang/language/tir/op.py (1)
  - call_intrin (119-144)

src/tl_templates/cuda/copy_sm90.h (2)
- src/tl_templates/cuda/copy.h (1)
  - void (16-18)
- src/tl_templates/cuda/ldsm.h (11)
  - void (7-14), void (16-23), void (25-33), void (35-42), void (44-52), void (54-62), void (64-70), void (72-79), void (81-89), void (91-98), void (100-108)

src/op/copy.cc (1)
- src/tl_templates/cuda/copy_sm90.h (6)
  - tma_store (175-182), tma_store (185-195), tma_store (198-209), tma_store (212-223), tma_store (226-238), tma_store (241-253)

examples/flash_attention/example_gqa_bwd.py (3)
- tilelang/language/__init__.py (1)
  - annotate_layout (107-145)
- tilelang/language/copy.py (1)
  - copy (10-86)
- tilelang/language/atomic.py (1)
  - atomic_add (116-229)

src/op/atomic_add.cc (3)
- src/transform/atomicadd_vectorize.cc (12)
  - node (36-43), node (36-37), node (46-51), node (46-46), node (53-75), node (53-53), node (180-201), node (180-180), node (203-284), node (203-203), indices (77-125), indices (77-77)
- tilelang/language/tir/op.py (2)
  - address_of (463-479), div (2939-2961)
- src/tl_templates/cuda/copy_sm90.h (6)
  - tma_store (175-182), tma_store (185-195), tma_store (198-209), tma_store (212-223), tma_store (226-238), tma_store (241-253)

src/op/atomic_add.h (1)
- src/op/atomic_add.cc (2)
  - ReturnIndicesAndSize (173-183), ReturnIndicesAndSize (174-174)
🪛 Ruff (0.13.3)
examples/flash_attention/example_gqa_bwd.py
118-118: Ambiguous variable name: l
(E741)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: format-check
🔇 Additional comments (4)
src/op/atomic_add.cc (4)
173-183: LGTM! The method correctly computes the base indices (using range minimums) and total size (product of extents) for TMA operations. The logic is sound for calculating the starting address and transfer size.
386-386: Confirm single-thread execution model for TMA operations.

The TMA path guards execution with `T.thread_var == T.thread_bounds->min`, ensuring only one thread performs the TMA transfer. This differs from the SIMT path, where all threads participate. Please confirm this single-thread execution model is correct for TMA-based atomic reductions. If this is intentional, consider adding a comment explaining why only one thread should issue the TMA instruction.
Additionally, verify that other threads are properly synchronized (e.g., via barriers) before and after this operation to ensure memory consistency.
380-386: Flags correct for TMA atomic-add: `need_reduce=1` maps to the “.add” reduction opcode, and `eviction_policy=0` corresponds to EVICT_NORMAL.
366-372: use_tma is always initialized in AtomicAdd constructor
`AtomicAdd::AtomicAdd` unconditionally assigns `use_tma = Downcast<IntImm>(args[2])` (calls always pass ≥3 args), so it's never left uninitialized.
Actionable comments posted: 4
🧹 Nitpick comments (2)
src/op/atomic_add.cc (2)
175-185: Consider adding documentation for the TMA use case.

The implementation correctly computes starting indices (using the `min` of each range) and the total size (product of extents). This is suitable for TMA bulk operations. Consider adding a brief comment explaining that this method is specifically designed for TMA operations, where the entire range is transferred as a contiguous block, unlike the SIMT path which iterates over individual elements.
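As a toy illustration of the computation this comment describes (base index per dimension plus total element count for one contiguous transfer) — this is plain Python, not the C++ implementation:

```python
from math import prod

def base_indices_and_size(ranges):
    # ranges: list of (min, extent) pairs, one per buffer dimension
    base = [lo for lo, _ in ranges]              # starting index in each dimension
    size = prod(extent for _, extent in ranges)  # total elements moved as one block
    return base, size

print(base_indices_and_size([(2, 64), (0, 128)]))  # ([2, 0], 8192)
```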
382-383: Consider making reduction flags configurable.

The `need_reduce` and `eviction_policy` values are hardcoded as constants. While this may be appropriate for the current atomic add reduction use case, it limits flexibility if different reduction strategies or cache policies are needed in the future. If these values should remain constant for atomic add operations, consider adding a brief comment explaining why these specific values are used.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- examples/flash_attention/example_gqa_bwd_tma_reduce.py (1 hunks)
- src/op/atomic_add.cc (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/op/atomic_add.cc (2)
- tilelang/language/tir/op.py (1)
  - address_of (463-479)
- src/tl_templates/cuda/copy_sm90.h (6)
  - tma_store (175-182), tma_store (185-195), tma_store (198-209), tma_store (212-223), tma_store (226-238), tma_store (241-253)

examples/flash_attention/example_gqa_bwd_tma_reduce.py (1)
- examples/flash_attention/example_gqa_bwd.py (17)
  - flashattn_fwd (12-80), flashattn_bwd_preprocess (87-113), make_dq_layout (116-119), flashattn_bwd_postprocess (126-144), flash_bwd_post (133-142), flashattn_bwd_atomic_add (150-244), flash_bwd (171-242), flash_bwd (273-347), flashattn_bwd_split (250-349), _attention (353-433), forward (356-366), backward (369-433), maybe_contiguous (375-378), ref_program (439-461), main (464-522), run (509-510), run1 (512-513)
🪛 Ruff (0.13.3)
examples/flash_attention/example_gqa_bwd_tma_reduce.py
96-96: Ambiguous variable name: O
(E741)
119-119: Ambiguous variable name: l
(E741)
507-507: Ambiguous variable name: O
(E741)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: build-test-metal
- GitHub Check: build-test-amd
- GitHub Check: format-check
🔇 Additional comments (2)
src/op/atomic_add.cc (2)
368-389: The TMA lowering path implementation looks correct.

The new TMA-based lowering path:
- Correctly computes indices and sizes for src and dst
- Validates that source and destination have matching sizes
- Creates proper address_of expressions for buffer loads
- Guards execution to a single thread using the thread variable bounds
- Uses `ceildiv` for byte-size calculation (correctly handling non-byte-aligned types)

The logic appropriately branches between TMA and SIMT paths based on the `use_tma` flag.
384-387: Verify IR-op registration for tma_store with five parameters
No matching `tma_store` Op registration found under `src/op`. Confirm that an Op accepting (gmem_ptr, smem_ptr, size, need_reduce, eviction_policy) is defined and registered as a builtin.
```python
@torch.compile
class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
        BATCH, N_CTX, H, D_HEAD_QK = q.shape
```
Remove torch.compile from _attention.
Line 370 decorates the torch.autograd.Function subclass with torch.compile, which replaces the class object and strips the .apply attribute. The very next line (attention = _attention.apply) will then raise AttributeError, breaking both the example and any external users. Drop the decorator (or compile the forward path separately) so _attention remains a Function subclass.
```diff
-@torch.compile
-class _attention(torch.autograd.Function):
+class _attention(torch.autograd.Function):
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
        BATCH, N_CTX, H, D_HEAD_QK = q.shape
```
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd_tma_reduce.py around lines 370 to
375, the torch.compile decorator is applied to the torch.autograd.Function
subclass _attention which replaces the class object and removes the .apply
attribute; remove the @torch.compile decorator from the class definition so
_attention remains a proper Function subclass (or alternatively, compile only
the forward path separately while keeping the original class intact) to ensure
_attention.apply exists and the subsequent assignment attention =
_attention.apply works correctly.
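For completeness, a minimal sketch of the "compile the forward path separately" alternative mentioned above. The kernel launches are replaced by placeholders, and whether torch.compile can actually trace through `Function.apply` depends on the PyTorch 2.x version in use; this is not the example's real code.

```python
import torch

class _attention(torch.autograd.Function):
    # Left undecorated so the class object (and .apply) is preserved.
    @staticmethod
    def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
        # ... launch the TileLang forward kernel here (elided in this sketch) ...
        return q.clone()

    @staticmethod
    def backward(ctx, do):
        # ... launch the TileLang backward kernels here (elided) ...
        return do, None, None, None, None, None

attention = _attention.apply  # still available, since the class was not replaced

def attention_fn(q, k, v, causal, groups=1, use_atomic=True):
    return attention(q, k, v, causal, groups, use_atomic)

# Compile the call path instead of the class definition.
compiled_attention = torch.compile(attention_fn)
```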
```python
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = mod_post(dq)
dk, dv = dk.sum(0), dv.sum(0)
```
Call flashattn_bwd_postprocess with all required tensors.
flashattn_bwd_postprocess now expects the accumulated dQ, dK, and dV buffers (Line 126), but the split path (Line 446) still calls it with just dQ. As soon as use_atomic=False, this raises TypeError: flash_bwd_post() missing 2 required positional arguments. Either supply the matching accumulators (even if you later discard the returned dK/dV) or keep a dedicated postprocess for the split path.
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_bwd_tma_reduce.py around lines 445-447,
the split path currently calls the postprocess with only dQ which no longer
matches flashattn_bwd_postprocess signature; update the call to pass the
accumulated dQ, dK, and dV buffers (the same accumulator variables used in the
atomic path) so you call flashattn_bwd_postprocess(dq_accum, dk_accum, dv_accum)
even if you later reduce/sum or discard the returned dK/dV, or alternatively
implement a separate postprocess function for the split path that accepts only
dQ and adapts to the new API.
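A small self-contained illustration of the signature mismatch and one possible shape of the fix. `mod_post` here is a stub with made-up shapes, not the example's compiled kernel, and whether `dK`/`dV` really need the postprocess at all is left open by the comment above.

```python
import torch

def mod_post(dq, dk, dv):            # stub for the new-style postprocess: takes all three buffers
    return dq.to(torch.float16), dk.to(torch.float16), dv.to(torch.float16)

dq = torch.zeros(2, 8, 4, 16)        # accumulator (fp32 in the real example)
dk = torch.zeros(3, 2, 8, 2, 16)     # split path keeps one slice per K/V split
dv = torch.zeros(3, 2, 8, 2, 16)

# Old split-path call, mod_post(dq), now raises:
#   TypeError: flash_bwd_post() missing 2 required positional arguments
dk, dv = dk.sum(0), dv.sum(0)        # reduce the per-split accumulators first
dq, dk, dv = mod_post(dq, dk, dv)    # pass all three buffers to match the new signature
```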
```diff
   if (args.size() >= 3) {
-    node->coalesced_width = Downcast<IntImm>(args[2]);
+    node->use_tma = Downcast<IntImm>(args[2]);
   }
+  if (args.size() >= 4) {
+    node->coalesced_width = Downcast<IntImm>(args[3]);
+  }
```
🧩 Analysis chain
Verify default initialization of use_tma.
The bounds checks properly guard access to args[2] and args[3]. However, when args.size() < 3, use_tma is not assigned. Ensure that use_tma has a safe default value (e.g., initialized to 0 in the node's declaration) to prevent undefined behavior when accessed in the Lower method at line 368.
Run the following script to check the initialization of use_tma in the header file:
🏁 Script executed:
```bash
#!/bin/bash
# Description: Check if use_tma has default initialization in AtomicAddNode
# Search for AtomicAddNode class definition and use_tma member declaration
ast-grep --pattern $'class AtomicAddNode : $_ {
  $$$
  use_tma
  $$$
}'

# Also search for any constructor or initialization
rg -n "use_tma" src/op/atomic_add.h -A 2 -B 2
```

Length of output: 874
Initialize default values for use_tma and coalesced_width in AtomicAddNode.
In src/op/atomic_add.h, both IntImm use_tma and IntImm coalesced_width lack default initializers, leading to null-handle dereference when args.size() < 3/4. Add inline or constructor defaults (e.g., constant 0).
🤖 Prompt for AI Agents
In src/op/atomic_add.cc around lines 82-87, the code assumes
AtomicAddNode::use_tma and ::coalesced_width are valid IntImm handles when
args.size() < 3/4, causing null-handle deref; to fix, add default initializers
(e.g., IntImm(0) or constant 0) for use_tma and coalesced_width in
src/op/atomic_add.h—either inline where members are declared or in the
AtomicAddNode constructor—so they always have a valid IntImm value when not
provided in args.
```cpp
 */
Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  Target target = T.target;
  if (use_tma->value != 0) {
```
Add null check before accessing use_tma->value.
If the constructor is called with fewer than 3 arguments, use_tma may be uninitialized or null. Accessing use_tma->value without verification could lead to undefined behavior or a crash.
Apply this diff to add a safety check:
```diff
-  if (use_tma->value != 0) {
+  if (use_tma.defined() && use_tma->value != 0) {
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
  if (use_tma.defined() && use_tma->value != 0) {
```
🤖 Prompt for AI Agents
In src/op/atomic_add.cc around line 368, the code directly accesses
use_tma->value which can be null if the constructor was called with fewer than
three arguments; add a null check before accessing use_tma->value (e.g., verify
use_tma is non-null) and only evaluate use_tma->value when use_tma != nullptr,
otherwise treat as false/zero or handle the missing argument path appropriately
to avoid undefined behavior or crashes.
LGTM
…ample (tile-ai#969)

* [Feature][Example] Support TMA reduce operation and update GQA bwd example
* move GQA bwd with TMA reduce to new example
* [Lint]: [pre-commit.ci] auto fixes [...]

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This pull request introduces significant changes to enable and optimize the use of Tensor Memory Access (TMA)-based atomic add operations for improved performance in the FlashAttention backward pass implementation. The changes span both the Python and C++ codebases, adding new features, updating APIs, and refactoring the kernel logic to leverage TMA reductions where appropriate.
Key changes include:
FlashAttention kernel and Python API updates:
- `flashattn_bwd_postprocess` and related functions now support TMA-based reductions for the gradients (`dQ`, `dK`, `dV`), with adjusted signatures and logic to handle the new layouts and additional tensors. [1] [2] [3] [4] [5]
- The `atomic_add` API and its internal logic now accept a `use_tma` argument, enabling selection between standard and TMA-based atomic add operations. [1] [2]

Core operator and lowering changes (C++):
- `AtomicAddNode` and related logic now support a `use_tma` flag, with a new lowering path that generates TMA-based reduction code when this flag is set. This includes new methods for index/size calculation and code generation. [1] [2] [3] [4] [5] [6] [7]
- A new `tma_store_add` device function for CUDA implements the actual TMA-based atomic add using the appropriate PTX instruction.

Bulk copy and codegen adjustments:
- The bulk copy and codegen paths emit `tl::tma_store_add` when reductions are requested, and handle the new argument conventions. [1] [2]

These changes collectively enable more efficient and scalable reductions in the FlashAttention backward pass by leveraging hardware-accelerated TMA instructions, while maintaining backward compatibility and flexibility in the API.
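As a sanity check for either backward path, a pure-PyTorch grouped-query attention reference (in the spirit of the example's `ref_program`; the exact shape conventions used there are assumptions here) can be used to compare outputs and gradients against the TileLang kernels with `use_atomic` set to either value:

```python
import torch
import torch.nn.functional as F

def ref_gqa(q, k, v, causal, groups):
    # q: (B, S, H, D); k, v: (B, S, H // groups, D). Replicate each K/V head across its group.
    k = k.repeat_interleave(groups, dim=2)
    v = v.repeat_interleave(groups, dim=2)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))   # SDPA expects (B, H, S, D)
    o = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    return o.transpose(1, 2).contiguous()
```

Comparing `q.grad`, `k.grad`, and `v.grad` from this reference against both backward variants exercises the TMA-reduce and split reduction strategies alike.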
FlashAttention kernel and Python API updates:
- `flashattn_bwd_postprocess` and related kernel logic now support TMA-based reductions for `dQ`, `dK`, and `dV`, including changes to tensor layout, kernel signatures, and usage in the main backward path. [1] [2] [3] [4] [5]
- The `atomic_add` API and internal region logic now accept and propagate a `use_tma` argument for selecting TMA-based reductions. [1] [2]

Core operator and lowering changes (C++):
- Added a `use_tma` flag to `AtomicAddNode`, implemented a new lowering path for TMA-based reductions, and added helper methods for index/size calculation. [1] [2] [3] [4] [5] [6] [7]
- Added a `tma_store_add` CUDA device function for TMA-based atomic add operations using the appropriate PTX instruction.

Bulk copy and codegen adjustments:
- The bulk copy and codegen paths emit `tl::tma_store_add` when reductions are required and handle the new argument conventions. [1] [2]

Summary by CodeRabbit
New Features
Performance
Documentation