-
Notifications
You must be signed in to change notification settings - Fork 332
[Feature][Example] Support TMA reduce operation and update GQA bwd example #969
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -80,7 +80,10 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) { | |||||
| std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); | ||||||
| std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); | ||||||
| if (args.size() >= 3) { | ||||||
| node->coalesced_width = Downcast<IntImm>(args[2]); | ||||||
| node->use_tma = Downcast<IntImm>(args[2]); | ||||||
| } | ||||||
| if (args.size() >= 4) { | ||||||
| node->coalesced_width = Downcast<IntImm>(args[3]); | ||||||
| } | ||||||
| data_ = std::move(node); | ||||||
| } | ||||||
|
|
@@ -169,6 +172,18 @@ Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs, | |||||
| return indices; | ||||||
| } | ||||||
|
|
||||||
| std::pair<Array<PrimExpr>, PrimExpr> | ||||||
| AtomicAddNode::ReturnIndicesAndSize(int src_dst) const { | ||||||
| Array<PrimExpr> indices; | ||||||
| Array<Range> ranges = src_dst == 0 ? src_range : dst_range; | ||||||
| PrimExpr size = 1; | ||||||
| for (size_t i = 0; i < ranges.size(); i++) { | ||||||
| indices.push_back(ranges[i]->min); | ||||||
| size *= ranges[i]->extent; | ||||||
| } | ||||||
| return {indices, size}; | ||||||
| } | ||||||
|
|
||||||
| /** | ||||||
| * @brief Build a combined bound-check predicate for indexed access. | ||||||
| * | ||||||
|
|
@@ -350,6 +365,28 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { | |||||
| */ | ||||||
| Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { | ||||||
| Target target = T.target; | ||||||
| if (use_tma->value != 0) { | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Add null check before accessing `use_tma`. If the constructor is called with fewer than 3 arguments, `use_tma` is never assigned, so dereferencing it with `->value` reads a null handle. Apply this diff to add a safety check:
- if (use_tma->value != 0) {
+ if (use_tma.defined() && use_tma->value != 0) {
📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||
| Array<PrimExpr> src_indices, dst_indices; | ||||||
| PrimExpr src_size, dst_size; | ||||||
| std::tie(src_indices, src_size) = ReturnIndicesAndSize(0); | ||||||
| std::tie(dst_indices, dst_size) = ReturnIndicesAndSize(1); | ||||||
| ICHECK(analyzer->CanProveEqual(src_size, dst_size)) | ||||||
| << "src_size = " << src_size << ", dst_size = " << dst_size; | ||||||
| BufferLoad src_node = BufferLoad(src, src_indices); | ||||||
| BufferLoad dst_node = BufferLoad(dst, dst_indices); | ||||||
| Call address_of_src = | ||||||
| Call(DataType::Handle(), builtin::address_of(), {src_node}); | ||||||
| Call address_of_dst = | ||||||
| Call(DataType::Handle(), builtin::address_of(), {dst_node}); | ||||||
|
|
||||||
| int need_reduce = 1; | ||||||
| int eviction_policy = 0; | ||||||
| auto body = Evaluate(Call(DataType::Handle(), tma_store(), | ||||||
| {address_of_src, address_of_dst, | ||||||
| ceildiv(src_size * src->dtype.bits(), 8), | ||||||
| need_reduce, eviction_policy})); | ||||||
| return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), body); | ||||||
| } | ||||||
| auto simt_loop = MakeSIMTLoop(analyzer); | ||||||
| auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop)); | ||||||
| auto par_op = ParallelOp(fused_loop); | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
Verify default initialization of `use_tma`.
The bounds checks properly guard access to `args[2]` and `args[3]`. However, when `args.size() < 3`, `use_tma` is not assigned. Ensure that `use_tma` has a safe default value (e.g., initialized to 0 in the node's declaration) to prevent undefined behavior when accessed in the `Lower` method at line 368.
Run the following script to check the initialization of `use_tma` in the header file:
🏁 Script executed:
Length of output: 874
Initialize default values for use_tma and coalesced_width in AtomicAddNode.
In `src/op/atomic_add.h`, both `IntImm use_tma` and `IntImm coalesced_width` lack default initializers, leading to a null-handle dereference when `args.size()` is less than 3 or 4, respectively. Add inline or constructor defaults (e.g., constant 0).
🤖 Prompt for AI Agents