Skip to content

Commit 30f6b6f

Browse files
committed
Merge remote-tracking branch 'upstream/main' into retire-format-sh
2 parents 8c7ce10 + 1d4b718 commit 30f6b6f

19 files changed

+1435
-439
lines changed

.github/workflows/ci.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,6 @@ jobs:
226226
echo "PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}" | tee -a "${GITHUB_ENV}"
227227
echo "UV_INDEX=${UV_INDEX}" | tee -a "${GITHUB_ENV}"
228228
fi
229-
230229
export CLANG_TIDY_CMAKE_OPTIONS="${CLANG_TIDY_CMAKE_OPTIONS} -DUSE_METAL=ON"
231230
232231
echo "USE_METAL=ON" | tee -a "${GITHUB_ENV}"

examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py

Lines changed: 792 additions & 0 deletions
Large diffs are not rendered by default.

examples/flash_attention/example_gqa_fwd_varlen.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
from tilelang.profiler import do_bench
99
from varlen_utils import generate_random_padding_mask, generate_qkv
1010

11-
tilelang.disable_cache()
12-
1311

1412
def attention_ref(
1513
q,

src/op/atomic_add.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,12 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
5858
if (args.size() >= 3) {
5959
node->use_tma = Downcast<IntImm>(args[2]);
6060
}
61+
node->memory_order = IntImm(0);
6162
if (args.size() >= 4) {
62-
node->coalesced_width = Downcast<IntImm>(args[3]);
63+
node->memory_order = Downcast<IntImm>(args[3]);
64+
}
65+
if (args.size() >= 5) {
66+
node->coalesced_width = Downcast<IntImm>(args[4]);
6367
}
6468
data_ = std::move(node);
6569
}
@@ -285,6 +289,7 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
285289

286290
new_args.push_back(dst_value);
287291
new_args.push_back(src_value);
292+
new_args.push_back(memory_order);
288293

289294
Call atomicadd_call =
290295
tvm::tir::Call(dst->dtype, atomicadd_elem_op(), new_args);

src/op/atomic_add.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class AtomicAddNode : public TileOperatorNode {
2222
dst_range; ///< Access ranges for source and destination
2323
IntImm use_tma; ///< Whether to use TMA for memory operations
2424
IntImm coalesced_width; ///< Width for memory coalescing optimization
25+
IntImm memory_order; ///< Memory order for atomic operations
2526

2627
mutable ParallelOp par_op_; ///< Associated parallel operation
2728
static constexpr const char *_type_key = "tl.AtomicAdd";
@@ -41,15 +42,17 @@ class AtomicAddNode : public TileOperatorNode {
4142
.def_ro("src_range", &AtomicAddNode::src_range)
4243
.def_ro("dst_range", &AtomicAddNode::dst_range)
4344
.def_ro("use_tma", &AtomicAddNode::use_tma)
44-
.def_ro("coalesced_width", &AtomicAddNode::coalesced_width);
45+
.def_ro("coalesced_width", &AtomicAddNode::coalesced_width)
46+
.def_ro("memory_order", &AtomicAddNode::memory_order);
4547
}
4648

4749
bool SEqualReduce(const AtomicAddNode *other, SEqualReducer equal) const {
4850
return equal(src, other->src) && equal(dst, other->dst) &&
4951
equal(src_range, other->src_range) &&
5052
equal(dst_range, other->dst_range) &&
5153
equal(use_tma, other->use_tma) &&
52-
equal(coalesced_width, other->coalesced_width);
54+
equal(coalesced_width, other->coalesced_width) &&
55+
equal(memory_order, other->memory_order);
5356
}
5457

5558
void SHashReduce(SHashReducer hash_reduce) const {
@@ -59,6 +62,7 @@ class AtomicAddNode : public TileOperatorNode {
5962
hash_reduce(dst_range);
6063
hash_reduce(use_tma);
6164
hash_reduce(coalesced_width);
65+
hash_reduce(memory_order);
6266
}
6367

6468
static constexpr bool _type_has_method_sequal_reduce = true;

src/op/builtin.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ TIR_DEFINE_TL_BUILTIN(increase_descriptor_offset)
296296
Integer(CallEffectKind::kOpaque));
297297

298298
TIR_DEFINE_TL_BUILTIN(atomicadd_elem_op)
299-
.set_num_inputs(2)
299+
.set_num_inputs(3)
300300
.set_attr<TCallEffectKind>("TCallEffectKind",
301301
Integer(CallEffectKind::kOpaque));
302302

0 commit comments

Comments
 (0)