Skip to content

Commit 0550703

Browse files
[Feature][Example] Support TMA reduce operation and update GQA bwd example (#969)
* [Feature][Example] Support TMA reduce operation and update GQA bwd example
* Move GQA bwd with TMA reduce to a new example
* [Lint]: [pre-commit.ci] auto fixes [...]
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 77b9d08 commit 0550703

File tree

7 files changed

+640
-5
lines changed

7 files changed

+640
-5
lines changed

examples/flash_attention/example_gqa_bwd_tma_reduce.py

Lines changed: 569 additions & 0 deletions
Large diffs are not rendered by default.

src/op/atomic_add.cc

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,10 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
8080
std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
8181
std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
8282
if (args.size() >= 3) {
83-
node->coalesced_width = Downcast<IntImm>(args[2]);
83+
node->use_tma = Downcast<IntImm>(args[2]);
84+
}
85+
if (args.size() >= 4) {
86+
node->coalesced_width = Downcast<IntImm>(args[3]);
8487
}
8588
data_ = std::move(node);
8689
}
@@ -169,6 +172,18 @@ Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
169172
return indices;
170173
}
171174

175+
std::pair<Array<PrimExpr>, PrimExpr>
176+
AtomicAddNode::ReturnIndicesAndSize(int src_dst) const {
177+
Array<PrimExpr> indices;
178+
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
179+
PrimExpr size = 1;
180+
for (size_t i = 0; i < ranges.size(); i++) {
181+
indices.push_back(ranges[i]->min);
182+
size *= ranges[i]->extent;
183+
}
184+
return {indices, size};
185+
}
186+
172187
/**
173188
* @brief Build a combined bound-check predicate for indexed access.
174189
*
@@ -350,6 +365,28 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
350365
*/
351366
Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
352367
Target target = T.target;
368+
if (use_tma->value != 0) {
369+
Array<PrimExpr> src_indices, dst_indices;
370+
PrimExpr src_size, dst_size;
371+
std::tie(src_indices, src_size) = ReturnIndicesAndSize(0);
372+
std::tie(dst_indices, dst_size) = ReturnIndicesAndSize(1);
373+
ICHECK(analyzer->CanProveEqual(src_size, dst_size))
374+
<< "src_size = " << src_size << ", dst_size = " << dst_size;
375+
BufferLoad src_node = BufferLoad(src, src_indices);
376+
BufferLoad dst_node = BufferLoad(dst, dst_indices);
377+
Call address_of_src =
378+
Call(DataType::Handle(), builtin::address_of(), {src_node});
379+
Call address_of_dst =
380+
Call(DataType::Handle(), builtin::address_of(), {dst_node});
381+
382+
int need_reduce = 1;
383+
int eviction_policy = 0;
384+
auto body = Evaluate(Call(DataType::Handle(), tma_store(),
385+
{address_of_src, address_of_dst,
386+
ceildiv(src_size * src->dtype.bits(), 8),
387+
need_reduce, eviction_policy}));
388+
return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), body);
389+
}
353390
auto simt_loop = MakeSIMTLoop(analyzer);
354391
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
355392
auto par_op = ParallelOp(fused_loop);

src/op/atomic_add.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class AtomicAddNode : public TileOperatorNode {
2020
Buffer src, dst; ///< Source and destination buffers
2121
Array<Range> src_range,
2222
dst_range; ///< Access ranges for source and destination
23+
IntImm use_tma; ///< Whether to use TMA for memory operations
2324
IntImm coalesced_width; ///< Width for memory coalescing optimization
2425

2526
mutable ParallelOp par_op_; ///< Associated parallel operation
@@ -39,13 +40,15 @@ class AtomicAddNode : public TileOperatorNode {
3940
.def_ro("dst", &AtomicAddNode::dst)
4041
.def_ro("src_range", &AtomicAddNode::src_range)
4142
.def_ro("dst_range", &AtomicAddNode::dst_range)
43+
.def_ro("use_tma", &AtomicAddNode::use_tma)
4244
.def_ro("coalesced_width", &AtomicAddNode::coalesced_width);
4345
}
4446

4547
bool SEqualReduce(const AtomicAddNode *other, SEqualReducer equal) const {
4648
return equal(src, other->src) && equal(dst, other->dst) &&
4749
equal(src_range, other->src_range) &&
4850
equal(dst_range, other->dst_range) &&
51+
equal(use_tma, other->use_tma) &&
4952
equal(coalesced_width, other->coalesced_width);
5053
}
5154

@@ -54,6 +57,7 @@ class AtomicAddNode : public TileOperatorNode {
5457
hash_reduce(dst);
5558
hash_reduce(src_range);
5659
hash_reduce(dst_range);
60+
hash_reduce(use_tma);
5761
hash_reduce(coalesced_width);
5862
}
5963

@@ -67,6 +71,8 @@ class AtomicAddNode : public TileOperatorNode {
6771
Array<IterVar> MakeIterVars() const;
6872
/// Generate buffer indices from iteration variables
6973
Array<PrimExpr> MakeIndices(const Array<IterVar> &ivs, int src_dst) const;
74+
/// Return buffer indices and size
75+
std::pair<Array<PrimExpr>, PrimExpr> ReturnIndicesAndSize(int src_dst) const;
7076
/// Create boundary predicate for memory safety
7177
PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs,
7278
Array<PrimExpr> extents, int src_dst) const;

src/op/copy.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1571,6 +1571,9 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
15711571
global_coords.Set(0, global_coords[0] + instruction_dim * loop_var);
15721572
for (auto coord : global_coords)
15731573
args.push_back(coord);
1574+
int need_reduce = 0;
1575+
if (!is_load)
1576+
args.push_back(need_reduce);
15741577
args.push_back(this->eviction_policy);
15751578
tma_copy = For(loop_var, 0, loop_extent, ForKind::kUnrolled,
15761579
Evaluate(Call(DataType::Handle(), op, args)));
@@ -1580,6 +1583,9 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
15801583
args.push_back(shared_addr);
15811584
for (auto coord : global_coords)
15821585
args.push_back(coord);
1586+
int need_reduce = 0;
1587+
if (!is_load)
1588+
args.push_back(need_reduce);
15831589
args.push_back(this->eviction_policy);
15841590
tma_copy = Evaluate(Call(DataType::Handle(), op, args));
15851591
}
@@ -1654,10 +1660,11 @@ Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer,
16541660
{shared_addr, global_addr, 0,
16551661
elements * shared_tensor->dtype.bytes(), this->eviction_policy}));
16561662
} else {
1663+
int need_reduce = 0;
16571664
tma_copy = Evaluate(
16581665
Call(DataType::Handle(), tma_store(),
16591666
{global_addr, shared_addr, elements * shared_tensor->dtype.bytes(),
1660-
this->eviction_policy}));
1667+
need_reduce, this->eviction_policy}));
16611668
}
16621669
tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
16631670
return tma_copy;

src/target/codegen_cuda.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1345,6 +1345,11 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
13451345
print_extern_call_stmt(ss.str(), 0, 1);
13461346
} else if (op->op.same_as(tl::tma_store())) {
13471347
std::stringstream ss;
1348+
auto need_reduce = op->args[op->args.size() - 2].as<IntImmNode>()->value;
1349+
if (need_reduce) {
1350+
print_extern_call_stmt("tl::tma_store_add", 0, 2);
1351+
return;
1352+
}
13481353
auto eviction_policy =
13491354
this->eviction_policy_names_
13501355
[op->args[op->args.size() - 1].as<IntImmNode>()->value];
@@ -1353,7 +1358,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
13531358
} else {
13541359
ss << "tl::tma_store";
13551360
}
1356-
print_extern_call_stmt(ss.str(), 0, 1);
1361+
print_extern_call_stmt(ss.str(), 0, 2);
13571362
} else if (op->op.same_as(tl::ptx_ldmatrix())) {
13581363
int trans = Downcast<IntImm>(op->args[0])->value;
13591364
int num = Downcast<IntImm>(op->args[1])->value;

src/tl_templates/cuda/copy_sm90.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,16 @@ TL_DEVICE void tma_store(const CUtensorMap &descriptor,
252252
: "memory");
253253
}
254254

255+
TL_DEVICE void tma_store_add(float *const smem_ptr, float *gmem_ptr,
256+
int32_t const &store_bytes) {
257+
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
258+
asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 "
259+
"[%0], [%1], %2;\n"
260+
:
261+
: "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes)
262+
: "memory");
263+
}
264+
255265
TL_DEVICE void prefetch_tma_descriptor(const CUtensorMap &descriptor) {
256266
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
257267
asm volatile("prefetch.tensormap [%0];" : : "l"(gmem_int_desc) : "memory");

tilelang/language/atomic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ def atomic_min(dst: Buffer,
116116
def atomic_add(dst: Buffer,
117117
value: PrimExpr,
118118
memory_order: Optional[str] = None,
119-
return_prev: bool = False) -> PrimExpr:
119+
return_prev: bool = False,
120+
use_tma: bool = False) -> PrimExpr:
120121
"""
121122
Atomically add `value` into `dst`, returning a handle to the operation.
122123
@@ -225,7 +226,7 @@ def _to_region(data, access_type):
225226
raise NotImplementedError(
226227
"return_prev is not supported for tile-region-based atomic operations")
227228

228-
return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst)
229+
return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst, use_tma)
229230

230231

231232
def atomic_addx2(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> PrimExpr:

0 commit comments

Comments (0)