569 changes: 569 additions & 0 deletions examples/flash_attention/example_gqa_bwd_tma_reduce.py

Large diffs are not rendered by default.

39 changes: 38 additions & 1 deletion src/op/atomic_add.cc
@@ -80,7 +80,10 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
   std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
   std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
   if (args.size() >= 3) {
-    node->coalesced_width = Downcast<IntImm>(args[2]);
+    node->use_tma = Downcast<IntImm>(args[2]);
+  }
+  if (args.size() >= 4) {
+    node->coalesced_width = Downcast<IntImm>(args[3]);
   }
Comment on lines 82 to 87
⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify default initialization of use_tma.

The bounds checks properly guard access to args[2] and args[3]. However, when args.size() < 3, use_tma is not assigned. Ensure that use_tma has a safe default value (e.g., initialized to 0 in the node's declaration) to prevent undefined behavior when accessed in the Lower method at line 368.

Run the following script to check the initialization of use_tma in the header file:


🏁 Script executed:

#!/bin/bash
# Description: Check if use_tma has default initialization in AtomicAddNode

# Search for AtomicAddNode class definition and use_tma member declaration
ast-grep --pattern $'class AtomicAddNode : $_ {
  $$$
  use_tma
  $$$
}'

# Also search for any constructor or initialization
rg -n "use_tma" src/op/atomic_add.h -A 2 -B 2

Length of output: 874


Initialize default values for use_tma and coalesced_width in AtomicAddNode.
In src/op/atomic_add.h, both IntImm use_tma and IntImm coalesced_width lack default initializers, leading to a null-handle dereference when args.size() < 3 (for use_tma) or < 4 (for coalesced_width). Add inline or constructor defaults (e.g., constant 0).

🤖 Prompt for AI Agents
In src/op/atomic_add.cc around lines 82-87, the code assumes
AtomicAddNode::use_tma and ::coalesced_width are valid IntImm handles when
args.size() < 3/4, causing null-handle deref; to fix, add default initializers
(e.g., IntImm(0) or constant 0) for use_tma and coalesced_width in
src/op/atomic_add.h—either inline where members are declared or in the
AtomicAddNode constructor—so they always have a valid IntImm value when not
provided in args.

   data_ = std::move(node);
 }
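
Following up on the finding above, a minimal sketch of the suggested fix, assuming TVM's `IntImm(DataType, int64_t)` constructor; the wrapper struct is illustrative, not the actual `AtomicAddNode` declaration:

#include <tvm/ir/expr.h>

// Sketch: default both members to constant 0 so that Lower() can always read
// use_tma->value safely, even when the constructor saw fewer than 3/4 args.
struct AtomicAddDefaultsSketch {
  tvm::IntImm use_tma{tvm::DataType::Int(32), 0};          // 0 = TMA disabled
  tvm::IntImm coalesced_width{tvm::DataType::Int(32), 0};  // 0 = unconstrained
};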
@@ -169,6 +172,18 @@ Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
   return indices;
 }
 
+std::pair<Array<PrimExpr>, PrimExpr>
+AtomicAddNode::ReturnIndicesAndSize(int src_dst) const {
+  Array<PrimExpr> indices;
+  Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
+  PrimExpr size = 1;
+  for (size_t i = 0; i < ranges.size(); i++) {
+    indices.push_back(ranges[i]->min);
+    size *= ranges[i]->extent;
+  }
+  return {indices, size};
+}
+
 /**
  * @brief Build a combined bound-check predicate for indexed access.
  *
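
For readers skimming the hunk above: the helper returns each range's `min` as the base coordinate and the product of the `extent`s as the flattened element count. A plain-integer sketch of the same computation (illustrative, not part of the PR):

#include <cassert>
#include <utility>
#include <vector>

struct RangeSketch { int min, extent; };

std::pair<std::vector<int>, int>
return_indices_and_size(const std::vector<RangeSketch> &ranges) {
  std::vector<int> indices;
  int size = 1;
  for (const RangeSketch &r : ranges) {
    indices.push_back(r.min);  // base coordinate of the accessed region
    size *= r.extent;          // flattened element count
  }
  return {indices, size};
}

int main() {
  // Ranges [2, 10) and [0, 4): base indices {2, 0}, 8 * 4 = 32 elements.
  auto [indices, size] = return_indices_and_size({{2, 8}, {0, 4}});
  assert(indices[0] == 2 && indices[1] == 0 && size == 32);
  return 0;
}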
@@ -350,6 +365,28 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
  */
 Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   Target target = T.target;
+  if (use_tma->value != 0) {

⚠️ Potential issue | 🔴 Critical

Add null check before accessing use_tma->value.

If the constructor is called with fewer than 3 arguments, use_tma may be uninitialized or null. Accessing use_tma->value without verification could lead to undefined behavior or a crash.

Apply this diff to add a safety check:

-  if (use_tma->value != 0) {
+  if (use_tma.defined() && use_tma->value != 0) {
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-  if (use_tma->value != 0) {
+  if (use_tma.defined() && use_tma->value != 0) {
🤖 Prompt for AI Agents
In src/op/atomic_add.cc around line 368, the code directly accesses
use_tma->value which can be null if the constructor was called with fewer than
three arguments; add a null check before accessing use_tma->value (e.g., verify
use_tma is non-null) and only evaluate use_tma->value when use_tma != nullptr,
otherwise treat as false/zero or handle the missing argument path appropriately
to avoid undefined behavior or crashes.

+    Array<PrimExpr> src_indices, dst_indices;
+    PrimExpr src_size, dst_size;
+    std::tie(src_indices, src_size) = ReturnIndicesAndSize(0);
+    std::tie(dst_indices, dst_size) = ReturnIndicesAndSize(1);
+    ICHECK(analyzer->CanProveEqual(src_size, dst_size))
+        << "src_size = " << src_size << ", dst_size = " << dst_size;
+    BufferLoad src_node = BufferLoad(src, src_indices);
+    BufferLoad dst_node = BufferLoad(dst, dst_indices);
+    Call address_of_src =
+        Call(DataType::Handle(), builtin::address_of(), {src_node});
+    Call address_of_dst =
+        Call(DataType::Handle(), builtin::address_of(), {dst_node});
+
+    int need_reduce = 1;
+    int eviction_policy = 0;
+    auto body = Evaluate(Call(DataType::Handle(), tma_store(),
+                              {address_of_src, address_of_dst,
+                               ceildiv(src_size * src->dtype.bits(), 8),
+                               need_reduce, eviction_policy}));
+    return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), body);
+  }
   auto simt_loop = MakeSIMTLoop(analyzer);
   auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
   auto par_op = ParallelOp(fused_loop);
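
The TMA branch above proves the flattened source and destination sizes equal, converts the element count to bytes, and issues the store from a single thread (`T.thread_var == thread_bounds->min`). A worked example of the byte arithmetic, assuming a 64×64 float32 tile (shape illustrative):

#include <cassert>

int main() {
  long long src_size = 64LL * 64;               // product of src range extents
  long long dst_size = 64LL * 64;               // product of dst range extents
  assert(src_size == dst_size);                 // mirrors the CanProveEqual check
  long long bits = 32;                          // float32 element width
  long long bytes = (src_size * bits + 7) / 8;  // ceildiv(size * bits, 8)
  assert(bytes == 16384);                       // length handed to tma_store
  return 0;
}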
6 changes: 6 additions & 0 deletions src/op/atomic_add.h
@@ -20,6 +20,7 @@ class AtomicAddNode : public TileOperatorNode {
   Buffer src, dst; ///< Source and destination buffers
   Array<Range> src_range,
       dst_range; ///< Access ranges for source and destination
+  IntImm use_tma; ///< Whether to use TMA for memory operations
   IntImm coalesced_width; ///< Width for memory coalescing optimization
 
   mutable ParallelOp par_op_; ///< Associated parallel operation
@@ -39,13 +40,15 @@ class AtomicAddNode : public TileOperatorNode {
         .def_ro("src", &AtomicAddNode::src)
         .def_ro("dst", &AtomicAddNode::dst)
         .def_ro("src_range", &AtomicAddNode::src_range)
         .def_ro("dst_range", &AtomicAddNode::dst_range)
+        .def_ro("use_tma", &AtomicAddNode::use_tma)
         .def_ro("coalesced_width", &AtomicAddNode::coalesced_width);
   }
 
   bool SEqualReduce(const AtomicAddNode *other, SEqualReducer equal) const {
     return equal(src, other->src) && equal(dst, other->dst) &&
            equal(src_range, other->src_range) &&
            equal(dst_range, other->dst_range) &&
+           equal(use_tma, other->use_tma) &&
            equal(coalesced_width, other->coalesced_width);
   }

@@ -54,6 +57,7 @@ class AtomicAddNode : public TileOperatorNode {
     hash_reduce(src);
     hash_reduce(dst);
     hash_reduce(src_range);
     hash_reduce(dst_range);
+    hash_reduce(use_tma);
     hash_reduce(coalesced_width);
   }

@@ -67,6 +71,8 @@ class AtomicAddNode : public TileOperatorNode {
   Array<IterVar> MakeIterVars() const;
   /// Generate buffer indices from iteration variables
   Array<PrimExpr> MakeIndices(const Array<IterVar> &ivs, int src_dst) const;
+  /// Return buffer indices and size
+  std::pair<Array<PrimExpr>, PrimExpr> ReturnIndicesAndSize(int src_dst) const;
   /// Create boundary predicate for memory safety
   PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs,
                          Array<PrimExpr> extents, int src_dst) const;
9 changes: 8 additions & 1 deletion src/op/copy.cc
@@ -1571,6 +1571,9 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
     global_coords.Set(0, global_coords[0] + instruction_dim * loop_var);
     for (auto coord : global_coords)
       args.push_back(coord);
+    int need_reduce = 0;
+    if (!is_load)
+      args.push_back(need_reduce);
     args.push_back(this->eviction_policy);
     tma_copy = For(loop_var, 0, loop_extent, ForKind::kUnrolled,
                    Evaluate(Call(DataType::Handle(), op, args)));
@@ -1580,6 +1583,9 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
     args.push_back(shared_addr);
     for (auto coord : global_coords)
      args.push_back(coord);
+    int need_reduce = 0;
+    if (!is_load)
+      args.push_back(need_reduce);
     args.push_back(this->eviction_policy);
     tma_copy = Evaluate(Call(DataType::Handle(), op, args));
   }
@@ -1654,10 +1660,11 @@ Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer,
              {shared_addr, global_addr, 0,
               elements * shared_tensor->dtype.bytes(), this->eviction_policy}));
   } else {
+    int need_reduce = 0;
     tma_copy = Evaluate(
         Call(DataType::Handle(), tma_store(),
              {global_addr, shared_addr, elements * shared_tensor->dtype.bytes(),
-              this->eviction_policy}));
+              need_reduce, this->eviction_policy}));
   }
   tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
   return tma_copy;
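
Note the ordering contract this introduces: on the store paths, `need_reduce` is appended immediately before `eviction_policy`, and the CUDA codegen (next file) locates both positionally from the end of the argument list. A small sketch of that contract (values illustrative):

#include <cassert>
#include <vector>

int main() {
  // Tail of a tl::tma_store argument list after this change:
  // {..., size_bytes, need_reduce, eviction_policy}
  std::vector<long long> args = {/*dst*/ 0, /*src*/ 0, /*bytes*/ 16384,
                                 /*need_reduce*/ 0, /*eviction*/ 0};
  long long need_reduce = args[args.size() - 2];  // what codegen_cuda.cc reads
  long long eviction = args[args.size() - 1];
  assert(need_reduce == 0 && eviction == 0);
  return 0;
}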
7 changes: 6 additions & 1 deletion src/target/codegen_cuda.cc
@@ -1345,6 +1345,11 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     print_extern_call_stmt(ss.str(), 0, 1);
   } else if (op->op.same_as(tl::tma_store())) {
     std::stringstream ss;
+    auto need_reduce = op->args[op->args.size() - 2].as<IntImmNode>()->value;
+    if (need_reduce) {
+      print_extern_call_stmt("tl::tma_store_add", 0, 2);
+      return;
+    }
     auto eviction_policy =
         this->eviction_policy_names_
             [op->args[op->args.size() - 1].as<IntImmNode>()->value];
@@ -1353,7 +1358,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     } else {
       ss << "tl::tma_store";
     }
-    print_extern_call_stmt(ss.str(), 0, 1);
+    print_extern_call_stmt(ss.str(), 0, 2);
   } else if (op->op.same_as(tl::ptx_ldmatrix())) {
     int trans = Downcast<IntImm>(op->args[0])->value;
     int num = Downcast<IntImm>(op->args[1])->value;
10 changes: 10 additions & 0 deletions src/tl_templates/cuda/copy_sm90.h
@@ -252,6 +252,16 @@ TL_DEVICE void tma_store(const CUtensorMap &descriptor,
       : "memory");
 }
 
+TL_DEVICE void tma_store_add(float *const smem_ptr, float *gmem_ptr,
+                             int32_t const &store_bytes) {
+  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
+  asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 "
+               "[%0], [%1], %2;\n"
+               :
+               : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes)
+               : "memory");
+}
+
 TL_DEVICE void prefetch_tma_descriptor(const CUtensorMap &descriptor) {
   uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
   asm volatile("prefetch.tensormap [%0];" : : "l"(gmem_int_desc) : "memory");
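
A hypothetical call site for the new helper (a sketch under assumptions: the `tl::` namespace matches what the codegen emits, and the kernel name, launch shape, and byte count are invented for illustration). `cp.reduce.async.bulk` is asynchronous, so completion must still be awaited through the bulk async-group mechanism used by the other TMA store helpers; that part is omitted here.

__global__ void flush_tile(float *gmem_dst) {
  extern __shared__ float smem_src[];
  // ... accumulate partial results into smem_src ...
  __syncthreads();
  if (threadIdx.x == 0) {
    // Bulk reduce-store: element-wise add 64x64 floats from shared memory
    // into the values already resident in gmem_dst.
    tl::tma_store_add(smem_src, gmem_dst,
                      static_cast<int32_t>(64 * 64 * sizeof(float)));
  }
}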
5 changes: 3 additions & 2 deletions tilelang/language/atomic.py
@@ -116,7 +116,8 @@ def atomic_min(dst: Buffer,
 def atomic_add(dst: Buffer,
                value: PrimExpr,
                memory_order: Optional[str] = None,
-               return_prev: bool = False) -> PrimExpr:
+               return_prev: bool = False,
+               use_tma: bool = False) -> PrimExpr:
     """
     Atomically add `value` into `dst`, returning a handle to the operation.
 
@@ -225,7 +226,7 @@ def _to_region(data, access_type):
         raise NotImplementedError(
             "return_prev is not supported for tile-region-based atomic operations")
 
-    return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst)
+    return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst, use_tma)
 
 def atomic_addx2(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> PrimExpr:
def atomic_addx2(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> PrimExpr:
Expand Down
Loading